## Data loading

In [4]:
import h5py
import numpy as np
from pathlib import Path

In [None]:
colab = True

if colab == True:
    from google.colab import drive
    drive.mount('/content/drive')
    !unzip '/content/drive/My Drive/Colab Notebooks/data.zip' -d '/content/data'
    data_dir = Path("/content/data")
else:
    data_dir = Path("./data/")
assert data_dir.is_dir()
intra_dir = data_dir / "Intra"
cross_dir = data_dir / "Cross"
intra_train_glob = list((intra_dir / "train").glob("*.h5"))
intra_test_glob = list((intra_dir / "test").glob("*.h5"))
len(intra_train_glob)

In [6]:
def load_labels(path: Path) -> np.ndarray:
    """Derive the label triple encoded in a recording's file name.

    The file stem is split on "_". The last two fields are parsed as integers
    (subject identifier and chunk number); every field before them is treated
    as the task description.

    Args:
        path: Path of the recording, e.g. ``task_motor_105923_1.h5``.

    Returns:
        Integer array ``[task_class, subject_identifier, chunk]`` where
        ``task_class`` is 0 (rest), 1 (math), 2 (working) or 3 (motor).

    Raises:
        AssertionError: If none of the known task keywords is present.
        ValueError: If the stem has fewer than two "_"-separated fields or
            the last two fields are not integers.
    """
    *task, subject_identifier, chunk = path.stem.split("_")
    # Keywords are checked in this order; the first one found among the task
    # fields decides the class.
    task_classes = {"rest": 0, "math": 1, "working": 2, "motor": 3}
    y = next((cls for keyword, cls in task_classes.items() if keyword in task), None)
    if y is None:
        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError(f"unknown task: {path.name}")
    return np.array([y, int(subject_identifier), int(chunk)])

In [7]:
def load_h5(path: Path) -> np.ndarray:
    """Read the single dataset stored in an HDF5 file into memory.

    Args:
        path: Location of the ``.h5`` file.

    Returns:
        The full contents of the file's only dataset.

    Raises:
        AssertionError: If the file does not contain exactly one top-level key.
    """
    # Open read-only: loading must never modify or create the file.
    with h5py.File(path, "r") as f:
        keys = list(f.keys())
        # Report the file that was actually opened, not an unrelated global.
        assert len(keys) == 1, f"Only one key per file, right? {path}"
        # Indexing with [()] copies the whole dataset out before the file closes.
        matrix = f[keys[0]][()]
    return matrix


# Read every training file and its label triple, stacking each along a new
# leading (sample) axis.
intra_train_X = np.stack([load_h5(h5_path) for h5_path in intra_train_glob])
intra_train_labels = np.stack([load_labels(h5_path) for h5_path in intra_train_glob])
intra_train_X.shape, intra_train_labels.shape

((32, 248, 35624), (32, 3))

In [8]:
# Same as for the training split: one array of recordings, one of label triples.
intra_test_X = np.stack([load_h5(h5_path) for h5_path in intra_test_glob])
intra_test_labels = np.stack([load_labels(h5_path) for h5_path in intra_test_glob])
intra_test_X.shape, intra_test_labels.shape

((8, 248, 35624), (8, 3))

## Data preprocessing

In [26]:
import numpy as np
import torch 

def downsample(data, old_freq, new_freq):
    """Downsample along the last axis by averaging consecutive samples.

    The integer factor is ``round(old_freq / new_freq)``. Trailing samples
    that do not fill a complete group of ``factor`` samples are dropped, and
    each complete group is replaced by its mean.

    Args:
        data: Array with time on the last axis; any number of leading axes.
        old_freq: Sampling frequency of ``data``.
        new_freq: Desired sampling frequency. Because the factor is rounded to
            an integer, the effective output rate is ``old_freq / factor``.

    Returns:
        Array with the same leading axes and ``data.shape[-1] // factor``
        samples on the last axis.

    Raises:
        ValueError: If the rounded downsampling factor is smaller than 1.
    """
    # Calculate the downsampling factor
    downsample_factor = int(np.round(old_freq / new_freq))
    if downsample_factor < 1:
        raise ValueError(
            f"new_freq={new_freq} is too high for old_freq={old_freq}: "
            "the downsampling factor rounds to 0"
        )
    # Keep only as many timesteps as fill whole groups of `downsample_factor`
    n_groups = data.shape[-1] // downsample_factor
    trimmed = data[..., :n_groups * downsample_factor]
    # Split the time axis into (group, sample-within-group) and average each group
    grouped = trimmed.reshape(*data.shape[:-1], n_groups, downsample_factor)
    return grouped.mean(axis=-1)

def z_score_normalize(data):
    """Standardise each series along axis 2 to zero mean and unit variance.

    Args:
        data: Array-like or tensor with at least three axes; statistics are
            computed over axis 2.

    Returns:
        ``float32`` tensor of the same shape as ``data``. Series whose standard
        deviation is zero or undefined (constant, or a single timestep) come
        out as all zeros instead of NaN/inf.
    """
    # as_tensor accepts both NumPy arrays and tensors without the copy warning
    # that torch.tensor() emits for tensor inputs.
    data_tensor = torch.as_tensor(data, dtype=torch.float32)
    # Calculate mean and std along the timesteps
    mean = data_tensor.mean(dim=2, keepdim=True)
    std = data_tensor.std(dim=2, keepdim=True)
    # Guard against division by zero. A fixed epsilon is avoided on purpose:
    # it would distort series whose amplitude is itself very small.
    safe_std = torch.where(std > 0, std, torch.ones_like(std))
    return (data_tensor - mean) / safe_std

# Step 1: reduce the sampling rate of both splits (old_freq=2034 -> new_freq=125).
intra_train_X_downsampled = downsample(intra_train_X, old_freq=2034, new_freq=125)
intra_test_X_downsampled = downsample(intra_test_X, old_freq=2034, new_freq=125)

# Step 2: standardise every series of the downsampled data.
intra_train_X_norm = z_score_normalize(intra_train_X_downsampled)
intra_test_X_norm = z_score_normalize(intra_test_X_downsampled)

print(intra_train_X_norm.shape)

## VAR-CNN Architecture
Implementation based on the VAR-CNN model described in: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6609925/

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

#Define based on the input data shape
n_classes = 4
input_channels = 248  # Second axis of the loaded recordings
input_height = 2226  # Timesteps per recording after downsampling (35624 // 16)
input_width = 1  # The series are 1-D in time, so the image is one column wide
k = 10  # Number of kernels
l = 5  # Kernel height

# Define the neural network module
class VectorAutoregressiveCNN(nn.Module):
    """Conv2d -> ReLU -> MaxPool2d -> Linear classifier over time series.

    The input is treated as an image with ``input_channels`` channels whose
    height is the time axis.

    Args:
        input_channels: Number of input channels.
        k: Number of convolution kernels; also the requested kernel width.
        l: Kernel height, i.e. the number of timesteps per kernel.
        n_classes: Number of output classes.
        height: Input height in timesteps. Defaults to the notebook-level
            ``input_height``.
        width: Input width. Defaults to the notebook-level ``input_width``.

    Raises:
        ValueError: If the input is too small for the convolution and pooling.
    """

    def __init__(self, input_channels, k, l, n_classes, height=None, width=None):
        super().__init__()
        # Fall back to the notebook-level configuration when no size is given,
        # so existing four-argument calls keep working.
        height = input_height if height is None else height
        width = input_width if width is None else width
        if width < 1:
            raise ValueError(f"width must be at least 1, got {width}")

        # The kernel cannot be wider than the input. With a one-column input
        # a (l, k) kernel would give a negative output width, so clamp it.
        kernel_width = min(k, width)
        conv_height = height - l + 1
        conv_width = width - kernel_width + 1
        if conv_height < 2:
            raise ValueError(
                f"height={height} is too small for kernel height l={l} followed by 2x pooling"
            )
        # Likewise, only pool across the width when there is something to pool.
        pool_width = min(2, conv_width)

        #2D Conv
        self.conv = nn.Conv2d(input_channels, k, (l, kernel_width))
        # Max Pooling
        self.pool = nn.MaxPool2d((2, pool_width), stride=(2, pool_width))
        # Calculate output shape after conv and pool
        conv_output_height = conv_height // 2
        conv_output_width = conv_width // pool_width
        ninputs = k * conv_output_height * conv_output_width
        #Fully Connected Layer
        self.fc = nn.Linear(ninputs, n_classes)
        self.l1_penalty = 3e-4

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, n_classes).

        Accepts (batch, channels, height, width) input, or
        (batch, channels, height), which is given a trailing width axis of 1.
        """
        if x.dim() == 3:
            x = x.unsqueeze(-1)
        x = self.conv(x)
        x = F.relu(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

    def l1_regularization(self):
        """Return ``l1_penalty`` times the summed absolute value of all parameters."""
        l1_norm = sum(p.abs().sum() for p in self.parameters())
        return self.l1_penalty * l1_norm

# Instantiate the model
model = VectorAutoregressiveCNN(input_channels, k, l, n_classes)

# Define the optimizer and loss function
optimizer = Adam(model.parameters(), lr=3e-4)
# CrossEntropyLoss applies log_softmax itself; since log_softmax is idempotent,
# feeding it the model's log-probabilities gives the ordinary NLL loss.
loss_function = nn.CrossEntropyLoss()

## Training

In [None]:
import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset
from torch.utils.data.dataset import TensorDataset

criterion = nn.CrossEntropyLoss()
batch_size = 8
num_epochs = 1

# The label matrix has one row per recording: (task class, subject, chunk).
# Only the task class (column 0) is the classification target, and
# CrossEntropyLoss needs it as an int64 tensor.
train_targets = torch.as_tensor(intra_train_labels[:, 0], dtype=torch.long)
# Conv2d expects (batch, channels, height, width); add the width axis if missing.
train_inputs = intra_train_X_norm
if train_inputs.dim() == 3:
    train_inputs = train_inputs.unsqueeze(-1)
dataset = TensorDataset(train_inputs, train_targets)

# 5fold cross-validation
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Variables for early stopping
early_stopping_patience = 100


def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean loss, accuracy)``.

    When ``optimizer`` is given the model is trained: the L1 penalty is added
    to the loss and a gradient step is taken per batch. Without an optimizer
    the model is only evaluated, with gradients disabled and no penalty.
    The loss is averaged per sample, so a smaller final batch is weighted fairly.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    with torch.set_grad_enabled(training):
        for data, target in loader:
            output = model(data)
            loss = criterion(output, target)
            if training:
                loss = loss + model.l1_regularization()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            n_batch = target.size(0)
            total_loss += loss.item() * n_batch
            n_correct += (output.argmax(dim=1) == target).sum().item()
            n_seen += n_batch
    return total_loss / n_seen, n_correct / n_seen


# Training and validation loop
for fold, (train_index, val_index) in enumerate(kf.split(np.arange(len(dataset)))):
    print(f"Fold {fold + 1}")

    # Create data loaders for current fold
    train_data = Subset(dataset, train_index.tolist())
    val_data = Subset(dataset, val_index.tolist())
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

    # Start every fold from fresh weights. Reusing one model across folds would
    # validate on samples it was already trained on in an earlier fold.
    model = VectorAutoregressiveCNN(input_channels, k, l, n_classes)
    optimizer = Adam(model.parameters(), lr=3e-4)

    # Reset the early stopping patience
    patience_counter = 0
    best_loss = np.inf

    for epoch in range(num_epochs):
        # Training
        train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, optimizer)
        print(f"Epoch {epoch}: Training loss: {train_loss}, Training accuracy: {train_accuracy}")

        # Validation
        val_loss, val_accuracy = _run_epoch(model, val_loader, criterion)
        print(f"Epoch {epoch}: Validation loss: {val_loss}, Validation accuracy: {val_accuracy}")

        # Early stopping logic
        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
            # Save everything needed to resume training. 'loss' duplicates
            # 'best_loss' for code that reads the checkpoint under that key.
            # NOTE: every fold writes to the same file, so it ends up holding
            # the best epoch of the last fold that improved.
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'fold': fold,
                'epoch': epoch,
                'best_loss': best_loss,
                'loss': best_loss}
            torch.save(checkpoint, 'cnn_checkpoint.pt')
            print(f"Checkpoint saved for fold {fold} at epoch {epoch}")
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print("Early stopping triggered.")
                break



#### Further Training (optional)

Load the model checkpoint saved during training in order to continue training from it.

In [None]:
model = VectorAutoregressiveCNN(input_channels, k, l, n_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Load the model and optimizer state_dict.
# The training cell writes 'cnn_checkpoint.pt', so read that same file.
checkpoint = torch.load('cnn_checkpoint.pt')
if 'model_state_dict' in checkpoint:
    # Full checkpoint: model weights, optimizer state and bookkeeping.
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint.get('loss', checkpoint.get('best_loss'))
else:
    # The file holds only the bare model state_dict: there is no optimizer
    # state or bookkeeping to restore.
    model.load_state_dict(checkpoint)
    epoch = None
    loss = None

# Training mode, since this cell prepares further training.
# Call model.eval() instead when the model is only used for inference.
model.train()