In [1]:
import torch
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
import torch.nn as nn
from tqdm import tqdm
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader, TensorDataset
from scipy.stats import pearsonr
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

In [2]:
# Read the three serialized splits from the working directory, in
# train / val / test order.
# NOTE(review): torch.load unpickles the file contents, so these files must
# come from a trusted source.
train_data, val_data, test_data = map(
    torch.load, ('train.pt', 'val.pt', 'test.pt')
)

In [3]:
# Prepare dataset
def prepare_dataset(data):
    """Collate an iterable of ``(eeg, stim)`` tensor pairs into two batched tensors.

    Every tensor is cast to float32 and the per-sample tensors are stacked
    along a new leading dimension.

    Args:
        data: Iterable yielding ``(eeg, stim)`` pairs of tensors. All ``eeg``
            tensors must share one shape and all ``stim`` tensors must share
            one shape, otherwise ``torch.stack`` raises.

    Returns:
        Tuple ``(eeg_tensor, stim_tensor)`` whose first dimension indexes
        the samples in iteration order.
    """
    # Single pass over `data`, so one-shot iterables are supported as well.
    samples = [(eeg.float(), stim.float()) for eeg, stim in data]

    eeg_tensor = torch.stack([eeg for eeg, _ in samples])     # (N, *eeg.shape)
    stim_tensor = torch.stack([stim for _, stim in samples])  # (N, *stim.shape)
    return eeg_tensor, stim_tensor

# Collate each split into an (inputs, targets) tensor pair, processed in
# train / val / test order.
(X_train, y_train), (X_val, y_val), (X_test, y_test) = (
    prepare_dataset(split) for split in (train_data, val_data, test_data)
)

In [4]:
# Pearson Correlation Loss
class PearsonCorrelationLoss(nn.Module):
    """Negative Pearson correlation between predictions and targets.

    The correlation is computed along ``dim=1`` of the inputs and the
    negated mean of the resulting correlations is returned, so minimising
    this loss maximises the correlation.

    Args:
        eps: Constant added under the square root of the denominator so the
            division stays finite when either input has zero variance along
            ``dim=1``. Defaults to ``1e-8``, the value that was previously
            hard-coded.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        """Return ``-mean(corr(pred, target))`` with corr taken along ``dim=1``.

        Args:
            pred: Predicted tensor with at least two dimensions.
            target: Reference tensor, broadcast-compatible with ``pred``.

        Returns:
            Scalar tensor; close to -1 for perfectly correlated inputs,
            close to +1 for perfectly anti-correlated inputs.
        """
        # Remove the mean along dim 1 so the products below are covariances.
        pred = pred - pred.mean(dim=1, keepdim=True)
        target = target - target.mean(dim=1, keepdim=True)

        numerator = (pred * target).sum(dim=1)
        # eps sits inside the sqrt: a zero-variance input gives a denominator
        # of sqrt(eps) instead of 0, so the correlation becomes 0 rather than NaN.
        denominator = torch.sqrt(
            (pred ** 2).sum(dim=1) * (target ** 2).sum(dim=1) + self.eps
        )

        correlation = numerator / denominator
        return -correlation.mean()

In [5]:
# Model without skip connection
class CNNLSTMNoSkip(nn.Module):
    """Conv1d feature extractor followed by two BiLSTMs and a linear head.

    Maps a ``(batch, time, in_channels)`` input to one scalar per time step,
    i.e. an output of shape ``(batch, time)``. No residual path links the
    convolutional features to the recurrent layers.

    Args:
        in_channels: Number of input channels per time step.
        lstm_hidden_dim: Hidden size of each LSTM direction; both LSTMs are
            bidirectional, so they emit ``2 * lstm_hidden_dim`` features.
    """

    def __init__(self, in_channels=64, lstm_hidden_dim=64):
        super().__init__()

        # Two length-preserving convolutions (kernel 5, padding 2), each
        # followed by ReLU and then batch normalisation.
        conv_layers = [
            nn.Conv1d(in_channels, 128, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.BatchNorm1d(64),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        bi_width = 2 * lstm_hidden_dim  # forward + backward hidden states
        self.lstm1 = nn.LSTM(
            input_size=64,
            hidden_size=lstm_hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.lstm2 = nn.LSTM(
            input_size=bi_width,
            hidden_size=lstm_hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.regressor = nn.Linear(bi_width, 1)

    def forward(self, x):
        """Run the network on ``x`` of shape ``(B, T, in_channels)``; returns ``(B, T)``."""
        # Conv1d expects channels before time, the LSTMs expect time first.
        features = self.cnn(x.permute(0, 2, 1))        # (B, 64, T)
        features = features.permute(0, 2, 1)           # (B, T, 64)

        for recurrent in (self.lstm1, self.lstm2):
            features, _ = recurrent(features)          # (B, T, 2*H)

        per_step = self.regressor(features)            # (B, T, 1)
        return per_step.squeeze(-1)                    # (B, T)

NameError: name 'CNNLSTMWithSkip' is not defined