In [None]:
import torch
import os
from torch.utils.data import TensorDataset,ConcatDataset,DataLoader
import matplotlib.pyplot as plt
from torch import nn
from torch.nn.functional import relu
from tqdm import tqdm
from sklearn.preprocessing import RobustScaler
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Use the GPU when one is present; a hard-coded 'cuda' makes model.to(device) fail on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
conditions = ['PF','Vehicle']
path_to_pt_ekyn = '../pt_ekyn'
path_to_pt_snezana_mice = '../pt_snezana_mice'

# Recording ids are recovered from the filenames: '<id>_<condition>.pt' for EKYN and '<id>.pt' for the
# Snezana mice set (the same patterns load_ekyn / load_snezana_mice read back).
ekyn_ids = sorted({recording_filename.split('_')[0] for recording_filename in os.listdir(path_to_pt_ekyn)})
snezana_mice_ids = sorted({recording_filename.split('.')[0] for recording_filename in os.listdir(path_to_pt_snezana_mice)})
print(len(ekyn_ids),ekyn_ids)
print(len(snezana_mice_ids),snezana_mice_ids)

def load_ekyn(id,condition):
    """Load the (X, y) pair saved for one EKYN recording under one condition.

    Reads '<path_to_pt_ekyn>/<id>_<condition>.pt'; the file must hold exactly two objects.
    weights_only=False lets torch.load unpickle arbitrary Python objects, so only use it on trusted files.
    """
    recording_file = f'{path_to_pt_ekyn}/{id}_{condition}.pt'
    features, labels = torch.load(recording_file, weights_only=False)
    return features, labels
def load_snezana_mice(id):
    """Load the (X, y) pair stored in '<path_to_pt_snezana_mice>/<id>.pt'.

    The file must hold exactly two objects. weights_only=False lets torch.load unpickle
    arbitrary Python objects, so only point it at trusted files.
    """
    recording_file = f'{path_to_pt_snezana_mice}/{id}.pt'
    features, labels = torch.load(recording_file, weights_only=False)
    return features, labels

class SimpleNorm(nn.Module):
    """Standardise a tensor with ONE scalar mean and ONE scalar std, then apply a learnable affine.

    The statistics are taken over every element of the input (all samples of a batch
    together), so a sample's output depends on what else is in its batch. torch's std is the
    unbiased estimator, so a single-element input yields NaN.

    Args:
        eps: added to the std before dividing, to avoid division by zero.
    """
    def __init__(self,eps):
        super().__init__()
        self.eps = eps
        # Learnable affine parameters, initialised to the identity transform.
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.shift = nn.Parameter(torch.tensor(0.0))
    def forward(self,x):
        all_values = x.flatten()
        centred = x - all_values.mean()
        standardised = centred / (all_values.std() + self.eps)
        return self.shift + self.scale * standardised
    
class CNNSleepStager(nn.Module):
    """Small 1-D CNN mapping a single-channel signal window to 3 class logits.

    Architecture: SimpleNorm -> 4 x (Conv1d -> ReLU -> MaxPool1d(2)) with channels
    1 -> 4 -> 8 -> 16 -> 32, global average pooling to a 32-d latent vector,
    then Linear(32, 16) -> Linear(16, 3).

    Input is (batch, 1, length), or unbatched (1, length), which Conv1d also accepts.
    `length` must be at least 16 to survive the four halvings by max-pooling.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = SimpleNorm(1e-5)
        # padding='same' keeps the length unchanged; only the max-pool halves it.
        self.c1 = nn.Conv1d(in_channels=1,out_channels=4,kernel_size=7,padding='same')
        self.c2 = nn.Conv1d(in_channels=4,out_channels=8,kernel_size=5,padding='same')
        self.c3 = nn.Conv1d(in_channels=8,out_channels=16,kernel_size=3,padding='same')
        self.c4 = nn.Conv1d(in_channels=16,out_channels=32,kernel_size=3,padding='same')

        self.mp = nn.MaxPool1d(kernel_size=2)
        self.gap = nn.AdaptiveAvgPool1d(1)

        self.fc1 = nn.Linear(in_features=32,out_features=16)
        self.classifier = nn.Linear(in_features=16,out_features=3)

    def forward(self,x):
        """Return class logits: (batch, 3) for batched input, (3,) for unbatched input."""
        x = self.get_latent_space(x)
        x = self.fc1(x)
        # NOTE(review): there is no activation between fc1 and classifier, so the two layers
        # compose to a single affine map. Kept as-is so existing weights behave identically.
        return self.classifier(x)

    def get_latent_space(self,x):
        """Return pooled conv features: (batch, 32) for batched input, (32,) for unbatched input."""
        x = self.norm(x)
        for conv in (self.c1, self.c2, self.c3, self.c4):
            x = self.mp(relu(conv(x)))
        x = self.gap(x)
        # Drop only the pooled length axis. A bare squeeze() would also drop the batch axis
        # when batch == 1 (e.g. a final DataLoader batch of one sample), breaking fc1/loss shapes.
        return x.squeeze(-1)
    
# The model, optimizer and loss are constructed once, after the data loaders are built
# (model = CNNSleepStager(), Adam(lr=3e-4), CrossEntropyLoss further down). A second,
# earlier AdamW(lr=3e-3) setup here was overwritten before use, so it has been removed.

class EEGDataset(torch.utils.data.Dataset):
    """Map-style dataset over a single EKYN recording (one id + condition).

    Items are (X[i], y[i]) pairs. X gets a singleton axis inserted at dim 1, so each
    sample carries a leading channel dimension of size 1.
    """
    def __init__(self, id, condition):
        features, labels = load_ekyn(id=id, condition=condition)
        # (N, ...) -> (N, 1, ...): add the channel axis a Conv1d with in_channels=1 expects.
        self.X = features.unsqueeze(1)
        self.y = labels

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample, target = self.X[idx], self.y[idx]
        return sample, target

def compute_mmd(source, target, kernel_sigma=1.0):
    """
    Biased (V-statistic) estimate of the squared MMD between two feature sets under a Gaussian kernel.

    Kernel: k(a, b) = exp(-||a - b||^2 / (2 * kernel_sigma^2)).
    Result: mean(K_ss) + mean(K_tt) - 2 * mean(K_st), where the within-set means include
    the diagonal (self-pairs), so the estimate is >= 0 but biased upward for small batches.

    Args:
        source: Tensor of shape (n_source, feature_dim), e.g. (32, 64).
        target: Tensor of shape (n_target, feature_dim); n_target may differ from n_source,
            but feature_dim must match.
        kernel_sigma: Bandwidth of the Gaussian kernel.
    Returns:
        0-dim tensor holding the MMD^2 estimate (built from torch ops, so it can be backpropagated).
    """
    scale = 2 * (kernel_sigma ** 2)

    def _mean_kernel(a, b):
        # Mean Gaussian-kernel value over every (a_i, b_j) pair.
        sq_dist = torch.cdist(a, b, p=2) ** 2
        return torch.exp(-sq_dist / scale).mean()

    return _mean_kernel(source, source) + _mean_kernel(target, target) - 2 * _mean_kernel(source, target)

# Seed once so loader shuffling and weight initialisation are reproducible.
torch.manual_seed(0)

# Source domain (labelled, trained on) is recording A1-0; target domain is recording A1-1.
traindataset = ConcatDataset([EEGDataset(id='A1-0',condition='PF')])
testdataset  = ConcatDataset([EEGDataset(id='A1-1',condition='PF')])

trainloader = DataLoader(traindataset, batch_size=512, shuffle=True)
testloader = DataLoader(testdataset, batch_size=512, shuffle=True)

# Sanity-check one batch. print() because a bare expression mid-cell is never displayed.
Xi,yi = next(iter(trainloader))
print(Xi.shape, yi.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNNSleepStager()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

print(device)

In [None]:
model.to(device)

# Hyperparameters
use_mmd = False  # True adds the MMD domain-alignment term; False trains on source cross-entropy only
lambda_mmd = 4  # Weight for MMD loss
kernel_sigma = 1.0  # Gaussian kernel bandwidth
num_epochs = 200
lossi = []  # per-step loss across all epochs

# One optimisation step per (source, target) batch pair, so an epoch is limited by the shorter loader.
num_batches = min(len(trainloader), len(testloader))
if num_batches == 0:
    raise ValueError("trainloader and testloader must each yield at least one batch")

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_mmd = 0.0

    # zip pairs the i-th source batch with the i-th target batch and stops at the shorter loader.
    batch_pairs = zip(trainloader, testloader)
    for (source_x, source_y), (target_x, _) in tqdm(batch_pairs, total=num_batches, desc=f"Epoch {epoch+1}/{num_epochs}"):
        source_x, source_y = source_x.to(device), source_y.to(device)

        optimizer.zero_grad()
        source_logits = model(source_x)
        ce_loss = criterion(source_logits, source_y)  # Cross-entropy on labelled source data
        loss = ce_loss

        if use_mmd:
            # Target labels are never used: only the latent feature distributions are aligned.
            target_x = target_x.to(device)
            source_latent = model.get_latent_space(source_x)
            target_latent = model.get_latent_space(target_x)
            mmd_loss = compute_mmd(source_latent, target_latent, kernel_sigma) * lambda_mmd
            loss = ce_loss + mmd_loss
            total_mmd += mmd_loss.item()

        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        total_loss += step_loss
        lossi.append(step_loss)

    avg_loss = total_loss / num_batches
    summary = f"Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}"
    if use_mmd:
        summary += f", Avg MMD: {total_mmd / num_batches:.4f}"
    print(summary)

    # Refresh the running loss curve on disk after every epoch.
    fig, ax = plt.subplots()
    ax.plot(lossi)
    ax.set(title='Training loss', xlabel='step', ylabel='loss')
    fig.savefig('loss.jpg')
    plt.close(fig)

In [None]:
from sklearn.manifold import TSNE
import pandas as pd

# 2-D t-SNE of the latent features of one shuffled training batch, coloured by class.
model.eval()
Xi,yi = next(iter(trainloader))
with torch.no_grad():  # inference only: no autograd graph needed
    latent = model.get_latent_space(Xi.to(device)).cpu()
tsne = TSNE(random_state=0)
Xi_tsne = tsne.fit_transform(latent.numpy())
# Class index via argmax over dim 1 (assumes one-hot / per-class label vectors, as elsewhere in this notebook).
df = pd.DataFrame({
    'tsne_1': Xi_tsne[:, 0],
    'tsne_2': Xi_tsne[:, 1],
    'label': yi.argmax(dim=1).numpy(),
})
sns.scatterplot(data=df, x='tsne_1', y='tsne_2', hue='label')

In [None]:
from sklearn.manifold import TSNE
import pandas as pd

# 2-D t-SNE of the latent features of one shuffled test (target-domain) batch, coloured by class.
model.eval()
Xi,yi = next(iter(testloader))
with torch.no_grad():  # inference only: no autograd graph needed
    latent = model.get_latent_space(Xi.to(device)).cpu()
tsne = TSNE(random_state=0)
Xi_tsne = tsne.fit_transform(latent.numpy())
# Class index via argmax over dim 1 (assumes one-hot / per-class label vectors, as elsewhere in this notebook).
df = pd.DataFrame({
    'tsne_1': Xi_tsne[:, 0],
    'tsne_2': Xi_tsne[:, 1],
    'label': yi.argmax(dim=1).numpy(),
})
sns.scatterplot(data=df, x='tsne_1', y='tsne_2', hue='label')

In [None]:
from sklearn.manifold import TSNE
import pandas as pd

# Overlay a train batch (o) and a test batch (x) in ONE t-SNE embedding. Coordinates from two
# separate t-SNE fits live in unrelated spaces, so both batches are embedded jointly and split after.
model.eval()
X_train_batch, y_train_batch = next(iter(trainloader))
X_test_batch, y_test_batch = next(iter(testloader))
with torch.no_grad():  # inference only: no autograd graph needed
    latent_train = model.get_latent_space(X_train_batch.to(device)).cpu()
    latent_test = model.get_latent_space(X_test_batch.to(device)).cpu()

tsne = TSNE(random_state=0)
joint_embedding = tsne.fit_transform(torch.cat([latent_train, latent_test]).numpy())
n_train = len(latent_train)
train_embedding, test_embedding = joint_embedding[:n_train], joint_embedding[n_train:]

# Class index via argmax over dim 1 (assumes one-hot / per-class label vectors, as elsewhere in this notebook).
df_train = pd.DataFrame({
    'tsne_1': train_embedding[:, 0],
    'tsne_2': train_embedding[:, 1],
    'label': y_train_batch.argmax(dim=1).numpy(),
})
df_test = pd.DataFrame({
    'tsne_1': test_embedding[:, 0],
    'tsne_2': test_embedding[:, 1],
    'label': y_test_batch.argmax(dim=1).numpy(),
})

fig, ax = plt.subplots(figsize=(20,20))
sns.scatterplot(data=df_train, x='tsne_1', y='tsne_2', hue='label', facecolors='none', marker='o', ax=ax)
sns.scatterplot(data=df_test, x='tsne_1', y='tsne_2', hue='label', marker='x', s=50, linewidths=2, ax=ax)
ax.set_title('Joint t-SNE of latent features: train batch (o) vs test batch (x)')

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay,classification_report

def predict_labels(loader):
    """Run the model over every batch of `loader` and return (y_true, y_pred) as 1-D class-index arrays.

    Labels are reduced with argmax over dim 1 (assumes one-hot / per-class label vectors,
    as elsewhere in this notebook).
    """
    model.eval()
    true_parts, pred_parts = [], []
    with torch.no_grad():  # inference only: no autograd graph needed
        for X_batch, y_batch in loader:
            logits = model(X_batch.to(device))
            # argmax of the logits equals argmax of their softmax, so softmax is skipped.
            pred_parts.append(logits.argmax(dim=1).cpu())
            true_parts.append(y_batch.argmax(dim=1).cpu())
    return torch.cat(true_parts).numpy(), torch.cat(pred_parts).numpy()

for split_name, loader in [('train', trainloader), ('test', testloader)]:
    y_true, y_pred = predict_labels(loader)
    print(f'=== {split_name} ===')
    print(classification_report(y_true=y_true, y_pred=y_pred))
    ConfusionMatrixDisplay.from_predictions(y_true, y_pred).ax_.set_title(split_name)