In [2]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from scipy.stats import pearsonr
from torch.utils.data import DataLoader, TensorDataset

In [3]:
# Load the three pre-split datasets from disk, in train/val/test order.
# NOTE(review): torch.load unpickles arbitrary Python objects — only load
# files from a trusted source.
train_data, val_data, test_data = (
    torch.load(split_path) for split_path in ('train.pt', 'val.pt', 'test.pt')
)

In [4]:
# Sanity-check the raw splits before `prepare_dataset` unpacks each one into
# stacked (EEG, stimulus) tensors.
for split_name, split in (("train", train_data), ("val", val_data), ("test", test_data)):
    assert len(split) > 0, f"{split_name} split is empty"

In [5]:
def prepare_dataset(data):
    """Stack an iterable of ``(eeg, stim)`` pairs into two float32 tensors.

    Args:
        data: iterable whose items unpack into an EEG tensor and a stimulus
            tensor. All EEG tensors must share one shape, and all stimulus
            tensors another (required by ``torch.stack``).

    Returns:
        ``(eeg_tensor, stim_tensor)``: the items stacked along a new leading
        dimension and cast with ``.float()``. Expected per-item shapes are
        (320, 64) for EEG and (320,) for the stimulus — TODO confirm against
        the data files.
    """
    pairs = [(eeg.float(), stim.float()) for eeg, stim in data]
    eeg_tensor = torch.stack([eeg for eeg, _ in pairs])      # (N, *eeg_shape)
    stim_tensor = torch.stack([stim for _, stim in pairs])   # (N, *stim_shape)
    return eeg_tensor, stim_tensor

In [6]:
# Unpack every split into stacked (EEG, stimulus) float tensors.
(X_train, y_train), (X_val, y_val), (X_test, y_test) = (
    prepare_dataset(split) for split in (train_data, val_data, test_data)
)

In [7]:
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [8]:
# Larger codec variant: five encoder stages (one MaxPool each, pool_size 2)
# shrink seq_len 320 down to 10 time steps.
# NOTE(review): 'stride' is not read by MLAcodec; it has no effect.
hyperparams = dict(
    eeg_channels=64,
    seq_len=320,
    conv_channels=[64, 128, 256, 512, 512],  # per encoder layer
    kernel_size=5,
    stride=1,
    pool_size=2,
    upsample_mode='nearest',
    dropout=0.3,
)

In [None]:
class MLAEncoder(nn.Module):
    """1-D convolutional encoder: one Conv→BN→ReLU→MaxPool→Dropout stage per width.

    Args:
        in_channels: channel count of the input (dimension 1 of ``x``).
        conv_channels: output width of each successive stage.
        kernel_size: Conv1d kernel size; padding ``kernel_size // 2`` keeps the
            length unchanged for odd kernels.
        pool_size: MaxPool1d window/stride, applied once per stage.
        dropout: dropout probability after each stage.
    """

    def __init__(self, in_channels, conv_channels, kernel_size, pool_size, dropout):
        super().__init__()
        widths = [in_channels, *conv_channels]
        stages = []
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.Conv1d(width_in, width_out, kernel_size, padding=kernel_size // 2),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.MaxPool1d(pool_size),
                nn.Dropout(dropout),
            ])
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` of shape (B, in_channels, T) into (B, conv_channels[-1], T')."""
        return self.encoder(x)

class MLADecoder(nn.Module):
    """Upsampling convolutional decoder that mirrors ``MLAEncoder``.

    The encoder pools once per entry of ``conv_channels``, so this decoder
    upsamples exactly as many times (by ``scale_factor``) to recover the
    original temporal resolution. It then projects to a single output channel.

    Args:
        conv_channels: the encoder's stage widths (consumed in reverse order).
        kernel_size: Conv1d kernel size; padding ``kernel_size // 2``.
        upsample_mode: ``nn.Upsample`` interpolation mode.
        dropout: dropout probability after each stage.
        output_length: length T of the returned sequence.
        scale_factor: upsampling factor per stage; should equal the encoder's
            ``pool_size`` (default 2).
    """

    def __init__(self, conv_channels, kernel_size, upsample_mode, dropout, output_length, scale_factor=2):
        super().__init__()
        layers = []
        channels = list(reversed(conv_channels))
        # One upsample stage per encoder pooling stage. Each stage narrows to the
        # next (shallower) encoder width; the final stage keeps its width.
        out_widths = channels[1:] + channels[-1:]
        for in_ch, out_ch in zip(channels, out_widths):
            layers.append(nn.Upsample(scale_factor=scale_factor, mode=upsample_mode))
            layers.append(nn.Conv1d(in_ch, out_ch, kernel_size, padding=kernel_size // 2))
            layers.append(nn.BatchNorm1d(out_ch))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
        self.decoder = nn.Sequential(*layers)
        self.final_conv = nn.Conv1d(channels[-1], 1, kernel_size=1)
        self.output_length = output_length

    def forward(self, x):
        """Decode (B, C, T') latents into a (B, output_length) sequence."""
        x = self.decoder(x)
        x = self.final_conv(x)  # (B, 1, T'')
        deficit = self.output_length - x.size(-1)
        if deficit > 0:
            # MaxPool floors lengths not divisible by the pool size, so the
            # reconstruction can fall short; extend by repeating the last sample.
            x = F.pad(x, (0, deficit), mode='replicate')
        return x.squeeze(1)[:, :self.output_length]  # Output shape: (B, T)


class MLAcodec(nn.Module):
    """EEG→stimulus encoder/decoder built from a config dict.

    Reads the keys 'eeg_channels', 'conv_channels', 'kernel_size',
    'pool_size', 'dropout', 'upsample_mode' and 'seq_len'.
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = MLAEncoder(
            in_channels=config['eeg_channels'],
            conv_channels=config['conv_channels'],
            kernel_size=config['kernel_size'],
            pool_size=config['pool_size'],
            dropout=config['dropout']
        )
        self.decoder = MLADecoder(
            conv_channels=config['conv_channels'],
            kernel_size=config['kernel_size'],
            upsample_mode=config['upsample_mode'],
            dropout=config['dropout'],
            output_length=config['seq_len'],
            # Upsample by the same factor the encoder pools by, so lengths line up.
            scale_factor=config['pool_size'],
        )

    def forward(self, x):
        """Map (B, T, eeg_channels) EEG to a (B, seq_len) stimulus prediction."""
        x = x.permute(0, 2, 1)  # (B, T, C) → (B, C, T)
        z = self.encoder(x)
        out = self.decoder(z)
        return out

In [None]:
# Instantiate the larger `hyperparams` variant on CPU.
# NOTE(review): the training cell rebinds `model` from `config`, so this
# instance is never trained.
model = MLAcodec(config=hyperparams)

In [None]:

batch_size = 32  # Adjust as needed for memory

# Pair each split's EEG inputs with its stimulus targets.
train_dataset, val_dataset, test_dataset = (
    TensorDataset(features, targets)
    for features, targets in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

# Only training batches are shuffled; evaluation keeps the stored order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_split, batch_size=batch_size) for eval_split in (val_dataset, test_dataset)
)

In [None]:
def pearson_loss(pred, target):
    """Return ``1 - r`` averaged over the batch, with r the per-row Pearson correlation.

    Rows are correlated along dimension 1, so ``pred`` and ``target`` are
    expected as (B, T) with matching shapes. A 1e-8 term under the square root
    guards against division by zero for constant rows.
    """
    pred_centered = pred - pred.mean(dim=1, keepdim=True)
    target_centered = target - target.mean(dim=1, keepdim=True)
    covariance = torch.sum(pred_centered * target_centered, dim=1)
    norm_product = torch.sqrt(
        torch.sum(pred_centered ** 2, dim=1) * torch.sum(target_centered ** 2, dim=1) + 1e-8
    )
    per_row = 1 - covariance / norm_product
    return per_row.mean()

In [None]:
# Configuration actually used for training (this cell rebinds `model`).
SEED = 0
config = {
    'eeg_channels': 64,
    'seq_len': 320,
    'conv_channels': [64, 128, 256, 256],
    'kernel_size': 5,
    'pool_size': 2,
    'upsample_mode': 'nearest',
    'dropout': 0.3
}
epochs = 20

# Seed the global torch RNG before building the model so that both weight
# initialisation and `train_loader`'s shuffling are reproducible on re-run.
torch.manual_seed(SEED)

model = MLAcodec(config).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def _run_epoch(model, loader, optimizer=None):
    """Run one pass over `loader` and return the mean per-batch Pearson loss.

    With an `optimizer` the model is put in training mode and stepped after
    every batch. Without one the pass runs in eval mode with gradients
    disabled. Returns NaN if `loader` yields no batches.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, n_batches = 0.0, 0
    with torch.set_grad_enabled(training):
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            loss = pearson_loss(model(X_batch), y_batch)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            n_batches += 1
    return total_loss / n_batches if n_batches else float('nan')


for epoch in range(epochs):
    train_loss = _run_epoch(model, train_loader, optimizer)
    val_loss = _run_epoch(model, val_loader)
    print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

In [None]:
def evaluate_model(model, data_loader, device, criterion=None):
    """Evaluate `model` on every batch of `data_loader`.

    Args:
        model: module mapping an input batch to predictions shaped like the targets.
        data_loader: iterable of ``(inputs, targets)`` batches.
        device: device each batch is moved to before the forward pass.
        criterion: loss ``criterion(pred, target) -> scalar tensor``. Defaults
            to ``pearson_loss``, the objective used for training.

    Returns:
        ``(mean_batch_loss, pearson_r, cosine_sim)``. The last two are computed
        after flattening all predictions and all targets into single vectors.
        They are therefore global statistics, not means of per-sample values.

    Raises:
        ValueError: if `data_loader` yields no batches.
    """
    loss_fn = pearson_loss if criterion is None else criterion
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            total_loss += loss_fn(y_pred, y).item()
            all_preds.append(y_pred.cpu())
            all_targets.append(y.cpu())

    if not all_preds:
        raise ValueError("data_loader yielded no batches")

    # Concatenate all batches into flat vectors.
    preds = torch.cat(all_preds, dim=0).flatten()
    targets = torch.cat(all_targets, dim=0).flatten()

    pearson_corr = pearsonr(preds.numpy(), targets.numpy())[0]
    cosine_sim = F.cosine_similarity(preds, targets, dim=0).item()

    return total_loss / len(all_preds), pearson_corr, cosine_sim

In [None]:
# Final Test Evaluation: mean per-batch Pearson loss (the training objective)
# over the test split.
model.eval()
with torch.no_grad():
    test_batch_losses = [
        pearson_loss(model(x.to(device)), y.to(device)).item() for x, y in test_loader
    ]
test_loss = sum(test_batch_losses) / len(test_batch_losses)
print(f"\n✅ Final Test Loss: {test_loss:.4f}")

In [None]:
# Report loss, global Pearson r and cosine similarity for both held-out splits.
# Test metrics get their own names so the validation results are not overwritten.
val_loss, val_pearson, val_cosine = evaluate_model(model, val_loader, device)
print(
    f"val data: Epoch {epochs:02d} | Val Loss: {val_loss:.4f} | Pearson: {val_pearson:.4f} | Cosine: {val_cosine:.4f}"
)

test_loss, test_pearson, test_cosine = evaluate_model(model, test_loader, device)
print(
    f"test data: Epoch {epochs:02d} | test Loss: {test_loss:.4f} | Pearson: {test_pearson:.4f} | Cosine: {test_cosine:.4f}"
)

In [None]:
# Overlay the true stimulus and the model's prediction for the first test example.
x_sample, y_true = X_test[0:1], y_test[0:1]  # slicing keeps the batch dimension

model.eval()
with torch.no_grad():
    # Run on the model's device, then bring the prediction back to CPU for plotting.
    y_pred = model(x_sample.to(device)).squeeze().cpu()

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(y_true.squeeze().numpy(), label='True')
ax.plot(y_pred.numpy(), label='Predicted')
ax.set(title="Stimulus Prediction from EEG", xlabel="Sample index", ylabel="Stimulus")
ax.legend()
plt.show()
