In [1]:
# Set the seed for reproducibility

import torch
import numpy as np
import random

def set_seed(seed_value):
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)  # For multi-GPU
    np.random.seed(seed_value)
    random.seed(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(333)

# Model Summary

In [2]:
from torchsummary import summary
from model import model_dict
m = model_dict['simplevit']

from torcheeg.models import ArjunViT
m = ArjunViT(
    num_electrodes=14,
    chunk_size=128,
    t_patch_size=128//4,
    hid_channels=256,
    depth=6,
    heads=8,
    head_channels=128,
    mlp_channels=1024,
    embed_dropout=0.2,
    dropout=0.1
)
summary(m.cuda(), (14, 128))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1               [-1, 4, 448]               0
            Linear-2               [-1, 4, 256]         114,944
           Dropout-3               [-1, 5, 256]               0
         LayerNorm-4               [-1, 5, 256]             512
            Linear-5              [-1, 5, 3072]         786,432
           Softmax-6              [-1, 8, 5, 5]               0
           Dropout-7              [-1, 8, 5, 5]               0
            Linear-8               [-1, 5, 256]         262,400
           Dropout-9               [-1, 5, 256]               0
        Attention-10               [-1, 5, 256]               0
          PreNorm-11               [-1, 5, 256]               0
        LayerNorm-12               [-1, 5, 256]             512
           Linear-13              [-1, 5, 1024]         263,168
             GELU-14              [-1, 

# Dataloaders

In [3]:
from dataset import *
deap_raw, seed_raw, dreamer_raw = prepare_dataset(feature_type='raw_normalized', class_type='binary', overlap_percent=75)

BATCH_SIZE = 64

# deap_raw_loader = DataLoader(CustomDataset(deap_raw), batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
# seed_raw_loader = DataLoader(CustomDataset(seed_raw), batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
# dreamer_raw_loader = DataLoader(CustomDataset(dreamer_raw), batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

trainloaders, valloaders, testloaders = prepare_dataloaders([deap_raw, seed_raw, dreamer_raw], batch_size=BATCH_SIZE, test_ratio=0.2)
deap_raw_train, deap_raw_val, deap_raw_test = trainloaders[0], valloaders[0], testloaders[0]
seed_raw_train, seed_raw_val, seed_raw_test = trainloaders[1], valloaders[1], testloaders[1]
dreamer_raw_train, dreamer_raw_val, dreamer_raw_test = trainloaders[2], valloaders[2], testloaders[2]

[2024-04-09 20:58:30] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ../processed_data/deap_raw_normalized_75_percent_overlap.
[2024-04-09 20:58:30] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ../processed_data/seed_binary_raw_normalized_75_percent_overlap.
[2024-04-09 20:58:31] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ../processed_data/dreamer_raw_normalized_75_percent_overlap.


In [4]:
from dataset import *
deap_grid, seed_grid, dreamer_grid = prepare_dataset(feature_type='de_grid', class_type='binary', overlap_percent=75)

BATCH_SIZE = 64

# deap_grid_loader = DataLoader(CustomDataset(deap_grid), batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
# seed_grid_loader = DataLoader(CustomDataset(seed_grid), batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
# dreamer_grid_loader = DataLoader(CustomDataset(dreamer_grid), batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

trainloaders, valloaders, testloaders = prepare_dataloaders([deap_grid, seed_grid, dreamer_grid], batch_size=BATCH_SIZE, test_ratio=0.2)
deap_grid_train, deap_grid_val, deap_grid_test = trainloaders[0], valloaders[0], testloaders[0]
seed_grid_train, seed_grid_val, seed_grid_test = trainloaders[1], valloaders[1], testloaders[1]
dreamer_grid_train, dreamer_grid_val, dreamer_grid_test = trainloaders[2], valloaders[2], testloaders[2]

[2024-04-09 20:58:50] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ../processed_data/deap_de_grid_75_percent_overlap.
[2024-04-09 20:58:50] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ../processed_data/seed_binary_de_grid_75_percent_overlap.
[2024-04-09 20:58:51] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ../processed_data/dreamer_de_grid_75_percent_overlap.


In [1]:
from dataset import *
raw_datasets = prepare_dataset(feature_type='raw_normalized', class_type='binary', overlap_percent=75)
raw_trainloaders, raw_valloaders, raw_testloaders = prepare_dataloaders(
    raw_datasets, 64, 0.2
)

grid_datasets = prepare_dataset(feature_type='de_grid', class_type='binary', overlap_percent=75)
grid_trainloaders, grid_valloaders, grid_testloaders = prepare_dataloaders(
    grid_datasets, 64, 0.2
)

[2024-04-09 21:25:07] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ../processed_data/deap_raw_normalized_75_percent_overlap.
[2024-04-09 21:25:07] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ../processed_data/seed_binary_raw_normalized_75_percent_overlap.
[2024-04-09 21:25:08] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ../processed_data/dreamer_raw_normalized_75_percent_overlap.
[2024-04-09 21:25:09] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ../processed_data/deap_de_grid_75_percent_overlap.
[2024-04-09 21:25:09] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ../processed_data/seed_binary_de_grid_75_percent_overlap.
[2024-04-09 21:25:10] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from ../processed_data/dreamer_de_grid_75_percent_overlap.


In [14]:
trainloaders = list(zip(raw_trainloaders, grid_trainloaders))
trainloader = trainloaders[2]
raw_trainloader, grid_trainloader = trainloader
dataset = raw_trainloader.dataset.dataset.__dict__['dataset']
dataset_name = dataset.__class__.__name__.lower().replace('dataset', '').replace('binary', '')
dataset_name

'dreamer'

# Load Checkpoint

In [5]:
from torcheeg.models import ArjunViT
from torcheeg.trainers import ClassifierTrainer

def load_checkpoint_old(checkpoint_path, chunk_size, num_channel, patch_per_chunk, num_classes):

    # Initialize the model with parameters
    model = ArjunViT(chunk_size=chunk_size,
                     t_patch_size=chunk_size // patch_per_chunk,
                     num_electrodes=num_channel,
                     num_classes=num_classes)

    # Load the checkpoint
    trainer = ClassifierTrainer.load_from_checkpoint(checkpoint_path, model=model, num_classes=num_classes)

    return trainer

In [6]:
import torch
from torcheeg.trainers import ClassifierTrainer

def load_checkpoint(checkpoint_path):

    checkpoint = torch.load(checkpoint_path)
    hparams = checkpoint['hyper_parameters']
    trainer = ClassifierTrainer(**hparams)

    return trainer

## DEAP Checkpoint

In [9]:
# deap
checkpoint_path = '../federated_construct_2/arjunvit_binary_logs/75_percent_overlap/deap/fit/lightning_logs/version_1/checkpoints/last.ckpt'
deap_trainer = load_checkpoint(checkpoint_path)

In [6]:
# print(deap_trainer.test(deap_train, enable_checkpointing=False, logger=False))
# print(deap_trainer.test(deap_val, enable_checkpointing=False, logger=False))
# print(deap_trainer.test(deap_test, enable_checkpointing=False, logger=False))

In [7]:
# deap_trainer.test(deap_loader, enable_checkpointing=False, logger=False)

## SEED Checkpoint

In [10]:
# seed
checkpoint_path = '../federated_construct_2/arjunvit_binary_logs/75_percent_overlap/seed/fit/lightning_logs/version_1/checkpoints/last.ckpt'
seed_trainer = load_checkpoint(checkpoint_path)

/home/server-171/anaconda3/envs/eeg-gpu/lib/python3.9/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


In [9]:
# print(seed_trainer.test(seed_train, enable_checkpointing=False, logger=False))
# print(seed_trainer.test(seed_val, enable_checkpointing=False, logger=False))
# print(seed_trainer.test(seed_test, enable_checkpointing=False, logger=False))

In [10]:
# seed_trainer.test(seed_loader, enable_checkpointing=False, logger=False)

## DREAMER Checkpoint

In [11]:
# dreamer
checkpoint_path = '../federated_construct_2/arjunvit_binary_logs/75_percent_overlap/dreamer/fit/lightning_logs/version_1/checkpoints/last.ckpt'
dreamer_trainer = load_checkpoint(checkpoint_path)

/home/server-171/anaconda3/envs/eeg-gpu/lib/python3.9/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


In [12]:
# print(dreamer_trainer.test(dreamer_train, enable_checkpointing=False, logger=False))
# print(dreamer_trainer.test(dreamer_val, enable_checkpointing=False, logger=False))
# print(dreamer_trainer.test(dreamer_test, enable_checkpointing=False, logger=False))

In [13]:
# dreamer_trainer.test(dreamer_loader, enable_checkpointing=False, logger=False)

# Knowledge Distillation

In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
from tqdm import tqdm

from torcheeg.models import FBCCNN
from torcheeg import transforms

from torcheeg.datasets.constants.emotion_recognition.deap import DEAP_CHANNEL_LOCATION_DICT, DEAP_CHANNEL_LIST
from torcheeg.datasets.constants.emotion_recognition.seed import SEED_CHANNEL_LOCATION_DICT, SEED_CHANNEL_LIST
from torcheeg.datasets.constants.emotion_recognition.dreamer import DREAMER_CHANNEL_LOCATION_DICT, DREAMER_CHANNEL_LIST

```
t = transforms.Compose([
    transforms.BandPowerSpectralDensity(),
    transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT),
])

t = transforms.Compose([
    transforms.BandDifferentialEntropy(),
    transforms.ToGrid(DEAP_CHANNEL_LOCATION_DICT),
])
```

## Offline KD

In [32]:
def train_offline_kd(epoch, teacher_model, student_model, data_raw, data_grid, temperature=1.0, alpha=0.5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion_hard = nn.CrossEntropyLoss().to(device)
    criterion_soft = nn.KLDivLoss(reduction='batchmean').to(device)
    
    student_model.to(device)
    teacher_model.to(device)
    
    optimizer = optim.Adam(student_model.parameters(), lr=0.001)
    student_model.train()  # only update student model parameters
    
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(enumerate(zip(data_raw, data_grid)), total=len(data_raw), desc=f"Training Epoch {epoch}", leave=True)
    for i, (raw, grid) in progress_bar:
        # Unpacking raw and grid data
        X_raw, y_raw = raw
        X_grid, y_grid = grid
        
        X_raw, y_raw = X_raw.to(device), y_raw.to(device)
        X_grid, y_grid = X_grid.to(device), y_grid.to(device)
        
        assert torch.equal(y_raw, y_grid), "Both y must be equal"
        y = y_raw
        
        optimizer.zero_grad()

        # Forward pass teacher model
        with torch.no_grad():
            teacher_outputs = teacher_model(X_raw)

        # Forward pass student model
        student_outputs = student_model(X_grid)

        # Calculate the loss for hard label
        loss_hard = criterion_hard(student_outputs, y)

        # Calculate the loss for soft label
        loss_soft = criterion_soft(
            F.log_softmax(student_outputs / temperature, dim=1),
            F.softmax(teacher_outputs / temperature, dim=1)
        )

        # Backpropagation
        loss = alpha * loss_soft + (1 - alpha) * loss_hard
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(student_outputs.data, 1)
        total += y.size(0)
        correct += (predicted == y).sum().item()

        # Update progress bar
        current_loss = running_loss / (i + 1)
        accuracy = 100 * correct / total
        progress_bar.set_postfix(Loss=f'{current_loss:.4f}', Accuracy=f'{accuracy:.2f}%')

    average_loss = running_loss / len(data_raw)
    accuracy = 100 * correct / total
    return average_loss, accuracy

def validate_offline_kd(epoch, teacher_model, student_model, data_raw, data_grid, temperature=1.0, alpha=0.5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion_hard = nn.CrossEntropyLoss().to(device)
    criterion_soft = nn.KLDivLoss(reduction='batchmean').to(device)
    
    student_model.to(device)
    teacher_model.to(device)
    
    student_model.eval()  # Set the student model to evaluation mode
    
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(enumerate(zip(data_raw, data_grid)), total=len(data_raw), desc=f"Validation Epoch {epoch}", leave=True)
    with torch.no_grad():  # No gradients needed
        for i, (raw, grid) in progress_bar:
            # Unpacking raw and grid data
            X_raw, y_raw = raw
            X_grid, y_grid = grid
            
            X_raw, y_raw = X_raw.to(device), y_raw.to(device)
            X_grid, y_grid = X_grid.to(device), y_grid.to(device)
            
            assert torch.equal(y_raw, y_grid), "Both y must be equal"
            y = y_raw

            # Forward pass
            teacher_outputs = teacher_model(X_raw)
            student_outputs = student_model(X_grid)

            # Loss calculation
            loss_hard = criterion_hard(student_outputs, y)
            loss_soft = criterion_soft(
                F.log_softmax(student_outputs / temperature, dim=1),
                F.softmax(teacher_outputs / temperature, dim=1)
            )

            loss = alpha * loss_soft + (1 - alpha) * loss_hard

            running_loss += loss.item()
            _, predicted = torch.max(student_outputs.data, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()

            # Update progress bar
            current_loss = running_loss / (i + 1)
            accuracy = 100 * correct / total
            progress_bar.set_postfix(Loss=f'{current_loss:.4f}', Accuracy=f'{accuracy:.2f}%')

    average_loss = running_loss / len(data_raw)
    accuracy = 100 * correct / total
    return average_loss, accuracy

def test_model(model, dataloader):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss().to(device)
    
    model.to(device)
    model.eval()  # Set the model to evaluation mode
    
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc="Testing", leave=True)
    with torch.no_grad():
        for i, (X, y) in progress_bar:
            X, y = X.to(device), y.to(device)
            
            outputs = model(X)
            loss = criterion(outputs, y)

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()

            # Update progress bar
            current_loss = running_loss / (i + 1)
            accuracy = 100 * correct / total
            progress_bar.set_postfix(Loss=f'{current_loss:.4f}', Accuracy=f'{accuracy:.2f}%')

    average_loss = running_loss / len(dataloader)
    accuracy = 100 * correct / total
    return average_loss, accuracy

In [41]:
data_raw_train = deap_raw_train
data_raw_val = deap_raw_val
data_raw_test = deap_raw_test

data_grid_train = deap_grid_train
data_grid_val = deap_grid_val
data_grid_test = deap_grid_test

teacher_model = deap_trainer
student_model = FBCCNN(num_classes=2, in_channels=4, grid_size=(9, 9))

# Lists to store metrics for each epoch
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

for epoch in range(5):
    # Training
    train_loss, train_accuracy = train_offline_kd(epoch, teacher_model, student_model, data_raw_train, data_grid_train, temperature=1.0, alpha=0.5)
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)
    
    # Validation
    val_loss, val_accuracy = validate_offline_kd(epoch, teacher_model, student_model, data_raw_val, data_grid_val, temperature=1.0, alpha=0.5)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

# Testing
test_loss_teacher, test_accuracy_teacher = test_model(teacher_model, data_raw_test)
test_loss_student, test_accuracy_student = test_model(student_model, data_grid_test)

Training Epoch 0: 100%|██████████| 768/768 [00:16<00:00, 46.25it/s, Accuracy=62.41%, Loss=0.3684]
Validation Epoch 0: 100%|██████████| 192/192 [00:03<00:00, 59.96it/s, Accuracy=68.14%, Loss=0.3427]
Training Epoch 1: 100%|██████████| 768/768 [00:17<00:00, 44.02it/s, Accuracy=75.97%, Loss=0.3148]
Validation Epoch 1: 100%|██████████| 192/192 [00:02<00:00, 66.47it/s, Accuracy=78.52%, Loss=0.3059]
Training Epoch 2: 100%|██████████| 768/768 [00:15<00:00, 49.61it/s, Accuracy=84.86%, Loss=0.2719]
Validation Epoch 2: 100%|██████████| 192/192 [00:04<00:00, 47.16it/s, Accuracy=82.19%, Loss=0.2890]
Training Epoch 3: 100%|██████████| 768/768 [00:18<00:00, 42.00it/s, Accuracy=89.81%, Loss=0.2463]
Validation Epoch 3: 100%|██████████| 192/192 [00:03<00:00, 57.56it/s, Accuracy=84.76%, Loss=0.2788]
Training Epoch 4: 100%|██████████| 768/768 [00:20<00:00, 38.29it/s, Accuracy=92.29%, Loss=0.2318]
Validation Epoch 4: 100%|██████████| 192/192 [00:04<00:00, 45.22it/s, Accuracy=85.89%, Loss=0.2732]
Testing: 1

In [42]:
print(train_accuracies)
print(val_accuracies)
print(test_accuracy_teacher)
print(test_accuracy_student)

[62.406412760416664, 75.970458984375, 84.85921223958333, 89.80712890625, 92.29329427083333]
[68.1396484375, 78.52376302083333, 82.19401041666667, 84.75748697916667, 85.888671875]
68.33984375
85.70963541666667


## Online KD

In [43]:
def train_online_kd(epoch, teacher_model, student_model, data_raw, data_grid, temperature=1.0, alpha=0.5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion_hard = nn.CrossEntropyLoss().to(device)
    criterion_soft = nn.KLDivLoss(reduction='batchmean').to(device)
    
    student_model.to(device)
    teacher_model.to(device)
    
    # Separate optimizers for teacher and student models
    optimizer_teacher = optim.Adam(teacher_model.parameters(), lr=0.001)
    optimizer_student = optim.Adam(student_model.parameters(), lr=0.001)
    
    student_model.train()
    teacher_model.train()
    
    running_loss_student = 0.0
    running_loss_teacher = 0.0
    correct_student = 0
    correct_teacher = 0
    total = 0

    progress_bar = tqdm(enumerate(zip(data_raw, data_grid)), total=len(data_raw), desc=f"Training Epoch {epoch}", leave=True)
    for i, (raw, grid) in progress_bar:
        # Unpacking raw and grid data
        X_raw, y_raw = raw
        X_grid, y_grid = grid
        
        X_raw, y_raw = X_raw.to(device), y_raw.to(device)
        X_grid, y_grid = X_grid.to(device), y_grid.to(device)
        
        assert torch.equal(y_raw, y_grid), "Both y must be equal"
        y = y_raw
        
        optimizer_teacher.zero_grad()
        optimizer_student.zero_grad()

        # Forward pass for both teacher and student model
        teacher_outputs = teacher_model(X_raw)
        student_outputs = student_model(X_grid)

        # Calculate the hard loss
        loss_teacher_hard = criterion_hard(teacher_outputs, y)
        loss_student_hard = criterion_hard(student_outputs, y)

        # Calculate the soft loss
        loss_soft = criterion_soft(
            F.log_softmax(student_outputs / temperature, dim=1),
            F.softmax(teacher_outputs / temperature, dim=1)
        )

        # Backpropagation
        # For teacher
        loss_teacher = alpha * loss_soft.detach() + (1 - alpha) * loss_teacher_hard
        loss_teacher.backward(retain_graph=True)
        
        # For student
        loss_student = alpha * loss_soft + (1 - alpha) * loss_student_hard
        loss_student.backward()
        
        optimizer_teacher.step()
        optimizer_student.step()

        running_loss_student += loss_student.item()
        running_loss_teacher += loss_teacher.item()
        _, predicted_student = torch.max(student_outputs.data, 1)
        _, predicted_teacher = torch.max(teacher_outputs.data, 1)
        total += y.size(0)
        correct_student += (predicted_student == y).sum().item()
        correct_teacher += (predicted_teacher == y).sum().item()

        # Update progress bar
        accuracy_student = 100 * correct_student / total
        accuracy_teacher = 100 * correct_teacher / total
        current_loss_student = running_loss_student / (i + 1)
        current_loss_teacher = running_loss_teacher / (i + 1)
        
        progress_bar.set_postfix(Student_Loss=f'{current_loss_student:.4f}', Student_Accuracy=f'{accuracy_student:.2f}%', Teacher_Loss=f'{current_loss_teacher:.4f}', Teacher_Accuracy=f'{accuracy_teacher:.2f}%')

    average_loss_student = running_loss_student / len(data_raw)
    average_loss_teacher = running_loss_teacher / len(data_raw)
    accuracy_student = 100 * correct_student / total
    accuracy_teacher = 100 * correct_teacher / total
    return average_loss_student, average_loss_teacher, accuracy_student, accuracy_teacher

def validate_online_kd(epoch, teacher_model, student_model, data_raw, data_grid, temperature=1.0, alpha=0.5):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion_hard = torch.nn.CrossEntropyLoss().to(device)
    criterion_soft = torch.nn.KLDivLoss(reduction='batchmean').to(device)
    
    student_model.to(device)
    teacher_model.to(device)
    
    student_model.eval()
    teacher_model.eval()
    
    running_loss_student = 0.0
    running_loss_teacher = 0.0
    correct_student = 0
    correct_teacher = 0
    total = 0

    progress_bar = tqdm(enumerate(zip(data_raw, data_grid)), total=len(data_raw), desc=f"Validation Epoch {epoch}", leave=True)
    with torch.no_grad():
        for i, (raw, grid) in progress_bar:
            # Unpacking raw and grid data
            X_raw, y_raw = raw
            X_grid, y_grid = grid
            
            X_raw, y_raw = X_raw.to(device), y_raw.to(device)
            X_grid, y_grid = X_grid.to(device), y_grid.to(device)
            
            assert torch.equal(y_raw, y_grid), "Both y must be equal"
            y = y_raw

            # Forward pass
            teacher_outputs = teacher_model(X_raw)
            student_outputs = student_model(X_grid)

            # Loss calculation
            loss_teacher_hard = criterion_hard(teacher_outputs, y)
            loss_student_hard = criterion_hard(student_outputs, y)

            loss_soft = criterion_soft(
                F.log_softmax(student_outputs / temperature, dim=1),
                F.softmax(teacher_outputs / temperature, dim=1)
            )

            loss_student = alpha * loss_soft + (1 - alpha) * loss_student_hard
            running_loss_student += loss_student.item()

            # For teacher, we're interested in the hard loss only for monitoring
            running_loss_teacher += loss_teacher_hard.item()

            _, predicted_student = torch.max(student_outputs.data, 1)
            _, predicted_teacher = torch.max(teacher_outputs.data, 1)
            total += y.size(0)
            correct_student += (predicted_student == y).sum().item()
            correct_teacher += (predicted_teacher == y).sum().item()

            # Update progress bar with the latest losses and accuracies
            accuracy_student = 100 * correct_student / total
            accuracy_teacher = 100 * correct_teacher / total
            current_loss_student = running_loss_student / (i + 1)
            current_loss_teacher = running_loss_teacher / (i + 1)
            
            progress_bar.set_postfix(Student_Loss=f'{current_loss_student:.4f}', Student_Accuracy=f'{accuracy_student:.2f}%', Teacher_Loss=f'{current_loss_teacher:.4f}', Teacher_Accuracy=f'{accuracy_teacher:.2f}%')

    average_loss_student = running_loss_student / len(data_raw)
    average_loss_teacher = running_loss_teacher / len(data_raw)
    accuracy_student = 100 * correct_student / total
    accuracy_teacher = 100 * correct_teacher / total
    return average_loss_student, average_loss_teacher, accuracy_student, accuracy_teacher

In [44]:
data_raw_train = deap_raw_train
data_raw_val = deap_raw_val
data_raw_test = deap_raw_test

data_grid_train = deap_grid_train
data_grid_val = deap_grid_val
data_grid_test = deap_grid_test

teacher_model = deap_trainer
student_model = FBCCNN(num_classes=2, in_channels=4, grid_size=(9, 9))

# Lists to store metrics for each epoch
train_losses_student = []
train_losses_teacher = []
train_accuracies_student = []
train_accuracies_teacher = []
val_losses_student = []
val_losses_teacher = []
val_accuracies_student = []
val_accuracies_teacher = []

for epoch in range(10):
    # Training
    train_loss_student, train_loss_teacher, train_accuracy_student, train_accuracy_teacher = train_online_kd(epoch, teacher_model, student_model, data_raw_train, data_grid_train, temperature=1.0, alpha=0.5)
    train_losses_student.append(train_loss_student)
    train_losses_teacher.append(train_loss_teacher)
    train_accuracies_student.append(train_accuracy_student)
    train_accuracies_teacher.append(train_accuracy_teacher)
    
    # Validation
    val_loss_student, val_loss_teacher, val_accuracy_student, val_accuracy_teacher = validate_online_kd(epoch, teacher_model, student_model, data_raw_val, data_grid_val, temperature=1.0, alpha=0.5)
    val_losses_student.append(val_loss_student)
    val_losses_teacher.append(val_loss_teacher)
    val_accuracies_student.append(val_accuracy_student)
    val_accuracies_teacher.append(val_accuracy_teacher)

# Testing
test_loss_teacher, test_accuracy_teacher = test_model(teacher_model, data_raw_test)
test_loss_student, test_accuracy_student = test_model(student_model, data_grid_test)

Training Epoch 0: 100%|██████████| 768/768 [00:27<00:00, 28.36it/s, Student_Accuracy=56.60%, Student_Loss=0.3637, Teacher_Accuracy=71.73%, Teacher_Loss=0.3104]
Validation Epoch 0: 100%|██████████| 192/192 [00:06<00:00, 31.87it/s, Student_Accuracy=56.77%, Student_Loss=0.3712, Teacher_Accuracy=77.38%, Teacher_Loss=0.5231]
Training Epoch 1: 100%|██████████| 768/768 [00:26<00:00, 28.86it/s, Student_Accuracy=60.22%, Student_Loss=0.3688, Teacher_Accuracy=82.13%, Teacher_Loss=0.2757]
Validation Epoch 1: 100%|██████████| 192/192 [00:05<00:00, 36.45it/s, Student_Accuracy=63.47%, Student_Loss=0.3615, Teacher_Accuracy=80.38%, Teacher_Loss=0.4647]
Training Epoch 2: 100%|██████████| 768/768 [00:28<00:00, 27.31it/s, Student_Accuracy=70.18%, Student_Loss=0.3451, Teacher_Accuracy=86.49%, Teacher_Loss=0.2553]
Validation Epoch 2: 100%|██████████| 192/192 [00:04<00:00, 42.77it/s, Student_Accuracy=73.03%, Student_Loss=0.3332, Teacher_Accuracy=81.79%, Teacher_Loss=0.4250]
Training Epoch 3: 100%|██████████|

In [49]:
print("TEACHER")
print(train_accuracies_teacher)
print(val_accuracies_teacher)
print(test_accuracy_teacher)

print("STUDENT")
print(train_accuracies_student)
print(val_accuracies_student)
print(test_accuracy_student)

TEACHER
[99.19637044270833, 98.91357421875, 98.699951171875, 98.65926106770833, 98.69588216145833, 98.67146809895833, 98.663330078125, 98.681640625, 98.85660807291667, 98.92578125]
[84.11458333333333, 83.59375, 84.33430989583333, 83.92740885416667, 84.26106770833333, 83.82161458333333, 83.82975260416667, 84.1796875, 83.9599609375, 83.80533854166667]
83.11848958333333
STUDENT
[44.038899739583336, 56.170654296875, 70.60750325520833, 79.88484700520833, 85.82763671875, 89.41853841145833, 91.86197916666667, 93.26578776041667, 94.46614583333333, 95.34912109375]
[48.396809895833336, 62.605794270833336, 71.91569010416667, 76.09049479166667, 77.92154947916667, 80.28971354166667, 80.18391927083333, 82.08821614583333, 82.3486328125, 83.056640625]
83.15755208333333
