In [1]:
import os
from torch.utils.data import TensorDataset,ConcatDataset,DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.preprocessing import RobustScaler
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import relu
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from sklearn.metrics import f1_score
import numpy as np

from lib.utils import get_dataloader
from lib.utils import SleepStageClassifier
from lib.utils import ekyn_ids,snezana_mice_ids,courtney_ids
from sklearn.metrics import ConfusionMatrixDisplay,classification_report
from lib.utils import calculate_f1,plot_training_progress

batch_size = 1024


def _split_last_quarter(ids):
    """Split ``ids`` into (train, test), holding out the trailing quarter.

    The held-out count is ceil(len(ids) / 4), which is what the slice bound
    ``-len(ids)//4`` evaluates to (unary minus binds before floor division).
    """
    held_out = -(-len(ids) // 4)
    boundary = len(ids) - held_out
    return ids[:boundary], ids[boundary:]


# Keep only the first eight IDs of each cohort (rebinding the imported names).
ekyn_ids = ekyn_ids[:8]
snezana_mice_ids = snezana_mice_ids[:8]
courtney_ids = courtney_ids[:8]
print(ekyn_ids, snezana_mice_ids, courtney_ids)

# Per-cohort train/test partitions; each print shows both sizes, then both ID lists.
train_ekyn_ids, test_ekyn_ids = _split_last_quarter(ekyn_ids)
print(len(train_ekyn_ids), len(test_ekyn_ids), train_ekyn_ids, test_ekyn_ids)

train_snezana_mice_ids, test_snezana_mice_ids = _split_last_quarter(snezana_mice_ids)
print(len(train_snezana_mice_ids), len(test_snezana_mice_ids), train_snezana_mice_ids, test_snezana_mice_ids)

train_courtney_ids, test_courtney_ids = _split_last_quarter(courtney_ids)
print(len(train_courtney_ids), len(test_courtney_ids), train_courtney_ids, test_courtney_ids)

['A1-0', 'A1-1', 'A4-0', 'B1-0', 'B3-1', 'C1-0', 'C4-0', 'C4-1'] ['21-HET-1', '21-HET-10', '21-HET-11', '21-HET-12', '21-HET-13', '21-HET-2', '21-HET-3', '21-HET-4'] ['22-Aug-B', '22-Aug-D', '22-Aug-E', '22-Aug-F', '22-Aug-H', '22-Oct-D', '22-Oct-E', '22-Oct-G']
6 2 ['A1-0', 'A1-1', 'A4-0', 'B1-0', 'B3-1', 'C1-0'] ['C4-0', 'C4-1']
6 2 ['21-HET-1', '21-HET-10', '21-HET-11', '21-HET-12', '21-HET-13', '21-HET-2'] ['21-HET-3', '21-HET-4']
6 2 ['22-Aug-B', '22-Aug-D', '22-Aug-E', '22-Aug-F', '22-Aug-H', '22-Oct-D'] ['22-Oct-E', '22-Oct-G']


In [2]:
# Single-head baseline: three conv/ReLU/max-pool stages feeding one classifier
class BaselineCNN(nn.Module):
    """Three-stage 1-D CNN with a single classification head.

    Every stage is ``Conv1d(kernel_size=5, padding=2) -> ReLU -> MaxPool1d(2)``
    and the channel width grows ``input_channels -> 16 -> 32 -> 64``. The last
    feature map is averaged over its final axis and passed through a linear
    layer that emits ``num_classes`` scores.

    Args:
        input_channels: number of channels in the input signal.
        num_classes: number of scores produced per example.
    """

    def __init__(self, input_channels=1, num_classes=5):
        super().__init__()

        # Register conv{i}/relu{i}/pool{i} for i = 1..3, in that order, so the
        # submodule names and construction order stay fixed.
        widths = (input_channels, 16, 32, 64)
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(
                self,
                f"conv{stage}",
                nn.Conv1d(in_channels=c_in, out_channels=c_out, kernel_size=5, stride=1, padding=2),
            )
            setattr(self, f"relu{stage}", nn.ReLU())
            # Pooling with kernel 2 / stride 2 halves the temporal length
            setattr(self, f"pool{stage}", nn.MaxPool1d(kernel_size=2, stride=2))

        # Global average pool -> flatten -> linear scores
        self.head3 = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Run the three stages on ``x`` and return the head's class scores."""
        stages = (
            (self.conv1, self.relu1, self.pool1),
            (self.conv2, self.relu2, self.pool2),
            (self.conv3, self.relu3, self.pool3),
        )
        features = x
        for conv, activation, pool in stages:
            features = pool(activation(conv(features)))
        return self.head3(features)
class MultiLevelCNN(nn.Module):
    """Three-stage 1-D CNN that emits a prediction after every stage.

    The backbone is ``Conv1d(kernel_size=5, padding=2) -> ReLU -> MaxPool1d(2)``
    repeated three times with channel widths ``input_channels -> 16 -> 32 -> 64``.
    Each stage's feature map also feeds its own head (global average pool ->
    flatten -> linear), so ``forward`` returns three sets of class scores.

    Args:
        input_channels: number of channels in the input signal.
        num_classes: number of scores produced per example by every head.
    """

    def __init__(self, input_channels=1, num_classes=5):
        super().__init__()

        widths = (input_channels, 16, 32, 64)

        # Backbone first: conv{i}/relu{i}/pool{i} for i = 1..3
        for level, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(
                self,
                f"conv{level}",
                nn.Conv1d(in_channels=c_in, out_channels=c_out, kernel_size=5, stride=1, padding=2),
            )
            setattr(self, f"relu{level}", nn.ReLU())
            # Pooling with kernel 2 / stride 2 halves the temporal length
            setattr(self, f"pool{level}", nn.MaxPool1d(kernel_size=2, stride=2))

        # Then one classifier head per latent space: head1 (16 ch), head2 (32 ch), head3 (64 ch)
        for level, width in enumerate(widths[1:], start=1):
            setattr(
                self,
                f"head{level}",
                nn.Sequential(
                    nn.AdaptiveAvgPool1d(1),  # one averaged value per channel
                    nn.Flatten(),
                    nn.Linear(width, num_classes),
                ),
            )

    def forward(self, x):
        """Return ``(out1, out2, out3)``: the class scores from heads 1, 2 and 3."""
        levels = (
            (self.conv1, self.relu1, self.pool1, self.head1),
            (self.conv2, self.relu2, self.pool2, self.head2),
            (self.conv3, self.relu3, self.pool3, self.head3),
        )
        predictions = []
        latent = x
        for conv, activation, pool, head in levels:
            latent = pool(activation(conv(latent)))
            predictions.append(head(latent))
        return tuple(predictions)
    
# Seed torch's global RNG before building the model so that weight
# initialisation is repeatable under "Restart Kernel -> Run All".
# NOTE(review): batch order is only repeatable too if get_dataloader draws its
# shuffling from torch's global RNG -- confirm in lib.utils.
torch.manual_seed(0)

# Swap in the commented-out constructor to train the multi-head variant; the
# training loop wraps a single-tensor output in a list, so it handles both.
# model = MultiLevelCNN(input_channels=1, num_classes=3)
model = BaselineCNN(input_channels=1, num_classes=3)

# Training data: the first ID of the EKYN training split only.
trainloader = get_dataloader(train_ekyn_ids[:1],snezana_mice_ids=None,courtney_ids=None,batch_size=batch_size,shuffle=True,downsample=False)
# Validation data: the second ID of the EKYN *training* split (test_ekyn_ids is
# not used here). Shuffling is disabled because evaluation does not need it, and
# a fixed batch composition keeps the mean-of-batch-losses comparable between
# validation passes when the last batch is smaller than batch_size.
testloader = get_dataloader(train_ekyn_ids[1:2],snezana_mice_ids=None,courtney_ids=None,batch_size=batch_size,shuffle=False,downsample=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.AdamW(model.parameters(),lr=3e-4)
criterion = nn.CrossEntropyLoss()

In [3]:
# Train `model` on `trainloader`, validate on `testloader`, checkpoint the best
# weights (by validation loss and by validation macro F1) and refresh the
# training-curve figure after every epoch.
validation_frequency_epochs = 1  # validate every N epochs
best_dev_loss = torch.inf
best_dev_loss_epoch = 0
best_dev_f1 = 0
best_dev_f1_epoch = 0
ma_window_size = 10  # forwarded to plot_training_progress (presumably a moving-average window -- confirm in lib.utils)

# torch.save does not create missing parent directories, so make sure the
# checkpoint/figure directory exists before the first write.
os.makedirs('../models', exist_ok=True)

model.to(device)

trainlossi = []  # per-batch training losses, concatenated over all epochs
testlossi = []   # one mean validation loss per validation pass
train_f1s = []   # one value per epoch: mean of the per-batch calculate_f1 results
test_f1s = []    # one macro F1 per validation pass
progressbar = tqdm(range(1000))
for epoch in progressbar:
    # The validation pass below puts the model in eval mode; switch back to
    # training mode at the start of every epoch so mode-dependent layers
    # (dropout, batch norm, ...) behave correctly if they are ever added.
    model.train()

    epoch_train_f1s = []  # Collect F1s for each batch in this epoch
    epoch_train_losses = []  # Collect losses for each batch in this epoch

    for Xi, yi in trainloader:
        Xi, yi = Xi.to(device), yi.to(device)
        optimizer.zero_grad()
        outputs = model(Xi)
        # A single-head model returns one tensor, a multi-head model a tuple;
        # normalise to a sequence so both are handled the same way.
        if isinstance(outputs,torch.Tensor):
            outputs = [outputs]
        # Every head contributes its own loss; the total is their plain sum.
        losses = [criterion(output,yi) for output in outputs]
        loss = sum(losses)
        loss.backward()
        optimizer.step()

        # Calculate and store F1 for this batch (last/deepest head only)
        batch_f1 = calculate_f1(outputs[-1], yi)
        epoch_train_f1s.append(batch_f1)
        epoch_train_losses.append(loss.item())

    # Record every batch loss of this epoch, plus the epoch's mean batch F1
    trainlossi.extend(epoch_train_losses)
    train_f1s.append(np.mean(epoch_train_f1s))

    if epoch % validation_frequency_epochs == 0:
        model.eval()
        all_test_preds = []
        all_test_labels = []
        test_losses = []

        with torch.no_grad():
            for Xi, yi in testloader:
                Xi, yi = Xi.to(device), yi.to(device)
                outputs = model(Xi)
                if isinstance(outputs,torch.Tensor):
                    outputs = [outputs]
                losses = [criterion(output,yi) for output in outputs]
                loss = sum(losses)
                test_losses.append(loss.item())
                # Predictions come from the last head; yi carries one score per
                # class along dim 1 (presumably one-hot -- confirm), so argmax
                # recovers the class index.
                all_test_preds.append(outputs[-1].argmax(dim=1))
                all_test_labels.append(yi.argmax(dim=1))

        all_test_preds = torch.cat(all_test_preds).cpu()
        all_test_labels = torch.cat(all_test_labels).cpu()
        avg_test_loss = np.mean(test_losses)
        test_f1 = f1_score(all_test_labels, all_test_preds, average='macro')

        testlossi.append(avg_test_loss)
        test_f1s.append(test_f1)

        # Track best model by loss
        if avg_test_loss < best_dev_loss:
            best_dev_loss = avg_test_loss
            best_dev_loss_epoch = epoch
            torch.save(model.state_dict(), '../models/best_model_by_loss.pt')

        # Track best model by F1
        if test_f1 > best_dev_f1:
            best_dev_f1 = test_f1
            best_dev_f1_epoch = epoch
            torch.save(model.state_dict(), '../models/best_model_by_f1.pt')

        progressbar.set_description(
            f"Epoch {epoch}: Train Loss: {np.mean(epoch_train_losses):.4f}, Test Loss: {avg_test_loss:.4f} | "
            f"Train F1: {train_f1s[-1]:.4f}, Test F1: {test_f1:.4f} | "
            f"Best Test Loss: {best_dev_loss:.4f} (Ep {best_dev_loss_epoch}), "
            f"Best Test F1: {best_dev_f1:.4f} (Ep {best_dev_f1_epoch})"
        )
    # Refresh the loss/F1 figure after every epoch
    plot_training_progress(
        trainlossi,
        testlossi,
        train_f1s,
        test_f1s,
        ma_window_size,
        '../models/training_metrics.jpg'
    )

Epoch 999: Train Loss: 0.2201, Test Loss: 0.2719 | Train F1: 0.8874, Test F1: 0.7509 | Best Test Loss: 0.2571 (Ep 790), Best Test F1: 0.7771 (Ep 977): 100%|██████████| 1000/1000 [36:52<00:00,  2.21s/it]


In [4]:
# Bare expression as the cell's last line: the notebook shows the module's repr
# (its registered submodules and their hyperparameters).
model

BaselineCNN(
  (conv1): Conv1d(1, 16, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu1): ReLU()
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(16, 32, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu2): ReLU()
  (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(32, 64, kernel_size=(5,), stride=(1,), padding=(2,))
  (relu3): ReLU()
  (pool3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (head3): Sequential(
    (0): AdaptiveAvgPool1d(output_size=1)
    (1): Flatten(start_dim=1, end_dim=-1)
    (2): Linear(in_features=64, out_features=3, bias=True)
  )
)