In [20]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # Must be first!

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import (
    ReduceLROnPlateau,
    LambdaLR,
    CosineAnnealingWarmRestarts
)

from transformers import AdamW, get_cosine_schedule_with_warmup


from torch.utils.tensorboard import SummaryWriter

import polars as pl
from tqdm.notebook import tqdm

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import random
import numpy as np

import optuna

###################
from utils import collate_fn
###################


import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dataset import EEGPTDataset  # Your existing dataset class
from model import EEGPTWrapper

# Set seeds and deterministic flags
random.seed(69)
np.random.seed(69)
torch.manual_seed(69)
torch.cuda.manual_seed(69)
torch.use_deterministic_algorithms(True)  # Enable full determinism
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [21]:
import pandas as pd
columns = pd.read_parquet('/home/owner/Documents/DEV/BrainLabyrinth/data/combined_GPT.parquet').columns.str.upper()
required = pd.Series([
    'FP1', 'FPZ', 'FP2', 'AF3', 'AF4', 
    'F7', 'F5', 'F3', 'F1', 'FZ', 
    'F2', 'F4', 'F6', 'F8', 'FT7', 
    'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 
    'FC4', 'FC6', 'FT8', 'T7', 'C5', 
    'C3', 'C1', 'CZ', 'C2', 'C4', 
    'C6', 'T8', 'TP7', 'CP5', 'CP3', 
    'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 
    'TP8', 'P7', 'P5', 'P3', 'P1', 
    'PZ', 'P2', 'P4', 'P6', 'P8', 
    'PO7', 'PO3', 'POZ',  'PO4', 'PO8', 
    'O1', 'OZ', 'O2' 
])

print(required[~required.isin(columns)])

Series([], dtype: object)


In [22]:
def train_model(config, train_set, train_loader, val_loader, writer):
    # -------------------- MODEL --------------------
    model = EEGPTWrapper(
        pretrained_path="checkpoint/eegpt_mcae_58chs_4s_large4E.ckpt",
        channel_list=[
            'FP1', 'FPZ', 'FP2', 'AF3', 'AF4', 
            'F7', 'F5', 'F3', 'F1', 'FZ', 
            'F2', 'F4', 'F6', 'F8', 'FT7', 
            'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 
            'FC4', 'FC6', 'FT8', 'T7', 'C5', 
            'C3', 'C1', 'CZ', 'C2', 'C4', 
            'C6', 'T8', 'TP7', 'CP5', 'CP3', 
            'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 
            'TP8', 'P7', 'P5', 'P3', 'P1', 
            'PZ', 'P2', 'P4', 'P6', 'P8', 
            'PO7', 'PO3', 'POZ',  'PO4', 'PO8', 
            'O1', 'OZ', 'O2' 
        ],
        num_classes=1
    ).to(config['device'])

    
    # Log model architecture and config
    writer.add_text("Model/Structure", str(model))
    writer.add_text("Training Config", str(config))
    
    # ------------------ LOSS FUNCTION ------------------
    pos_weight = torch.tensor([
        train_set.class_weights['Left'] / train_set.class_weights['Right']
    ]).to(config['device'])
    criterion = torch.nn.BCEWithLogitsLoss(weight=pos_weight)
    
    # ------------------- OPTIMIZER ---------------------
    lr = config.get('lr', 1e-3)
    weight_decay = config.get('weight_decay', 1e-2)    

    optimizer = AdamW(model.eegpt.parameters(), lr=lr, weight_decay=weight_decay)
    
    # ------------------- SCHEDULER ---------------------
    
    total_steps = config['epochs'] * len(train_loader)
    warmup_epochs = config.get('warmup_epochs', 0)
    num_cycles = config.get('num_cycles', 0.5)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_epochs, num_training_steps=total_steps, num_cycles=num_cycles
    )
    
    # # ------------------- WARMUP SCHEDULER ---------------
    # warmup_epochs = config.get('warmup_epochs', 0)
    # if warmup_epochs > 0:
    #     warmup_scheduler = LambdaLR(
    #         optimizer,
    #         lambda epoch: min(1.0, (epoch + 1) / warmup_epochs)
    #     )
    # else:
    #     warmup_scheduler = None
    
    # -------------------- TRAINING LOOP --------------------
    best_metric = -float('inf')
    
    for epoch in tqdm(range(config['epochs']), desc="Training"):
        # ---------- TRAIN ----------
        model.train()
        train_loss = 0.0
        
        for labels, features in train_loader:
            features = features.to(config['device']).float()
            labels = labels.to(config['device']).float()
            
            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            
            # Gradient clipping (if specified)
            if config.get('grad_clip') is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['grad_clip'])
            
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()
        
        train_loss /= len(train_loader)
        
        # ---------- VALIDATION ----------
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for labels, features in val_loader:
                features = features.to(config['device']).float()
                labels = labels.to(config['device']).float()
                
                outputs = model(features)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                
                preds = torch.sigmoid(outputs)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        val_loss /= len(val_loader)
        predictions = (np.array(all_preds) > 0.5).astype(int)
        
        # ---------- METRICS ----------
        accuracy = accuracy_score(all_labels, predictions)
        precision = precision_score(all_labels, predictions)
        recall = recall_score(all_labels, predictions)
        f1 = f1_score(all_labels, predictions)
        
        # ---------- SCHEDULER UPDATE ----------
        current_lr = optimizer.param_groups[0]['lr']
        
        # if warmup_scheduler is not None and epoch < warmup_epochs:
        #     warmup_scheduler.step()
        # else:
        #     if scheduler is not None:
        #         scheduler.step(val_loss)
        
        # ---------- LOGGING ----------
        writer.add_scalar('LR', current_lr, epoch)
        writer.add_scalar('Loss/Train', train_loss, epoch)
        writer.add_scalar('Loss/Val', val_loss, epoch)
        writer.add_scalar('Accuracy', accuracy, epoch)
        writer.add_scalar('Precision', precision, epoch)
        writer.add_scalar('Recall', recall, epoch)
        writer.add_scalar('F1', f1, epoch)
        
        # You can also combine them in a single dictionary
        metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        }
        writer.add_scalars('Metrics', metrics, epoch)
        
        # ---------- SAVE BEST MODEL ----------
        if accuracy > best_metric:
            best_metric = accuracy
            torch.save(model.state_dict(), f"{config['log_dir']}/best_model.pth")
    
    writer.close()
    return model


In [23]:
config_study = {
    'lr': 3e-4,
    'weight_decay': 0, #2.215232012031863e-05,
 }

config = {
    'data_path': '/home/owner/Documents/DEV/BrainLabyrinth/data/combined_GPT.parquet',
    'pretrained_path': 'checkpoint/eegpt_mcae_58chs_4s_large4E.ckpt',
    'split_ratios': (0.7, 0.15, 0.15),
    'batch_size': 64,
    'epochs': 600,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'log_dir': './runs/EEGPT',
    'lr': config_study['lr'],
    'weight_decay': config_study['weight_decay'],
    'grad_clip': 5,
    'warmup_epochs': 100,
    'num_cycles': 1.5,
}


In [24]:
#============================================================
# Training Pipeline
#============================================================
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# Initialize dataset
print("Creating full dataset...")
full_dataset = EEGPTDataset(config['data_path'], max_length=1024)

print("Splitting the dataset...")
# Split dataset
train_set, val_set, test_set = full_dataset.split_dataset(
    ratios=config['split_ratios']
)

del full_dataset

len_dataset = len(train_set)
sample = train_set[0]
label_shape = sample[0].shape
feature_shape = sample[1].shape

print(f"unbalanced train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")

torch.save(train_set, 'train_set_smol_GPT.pt')
torch.save(val_set, 'val_set_GPT.pt')
torch.save(test_set, 'test_set_GPT.pt')

Creating full dataset...
Precomputing samples...


Precomputing Samples:   0%|          | 0/2772 [00:00<?, ?it/s]

Computing class weights...
Splitting the dataset...
Precomputing samples...


Precomputing Samples:   0%|          | 0/1940 [00:00<?, ?it/s]

Computing class weights...
Precomputing samples...


Precomputing Samples:   0%|          | 0/415 [00:00<?, ?it/s]

Computing class weights...
Precomputing samples...


Precomputing Samples:   0%|          | 0/417 [00:00<?, ?it/s]

Computing class weights...
unbalanced train dataset shape: (1940, [labels: torch.Size([]), features: [1024, 58]])


In [None]:
train_set = torch.load('train_set_smol_GPT.pt', weights_only=False)
val_set = torch.load('val_set_GPT.pt', weights_only=False)
test_set = torch.load('test_set_GPT.pt', weights_only=False)


generator = torch.Generator().manual_seed(69)  # Set seed
train_loader = DataLoader(
    train_set,
    batch_size=config['batch_size'],
    shuffle=True,
    generator=generator,  # Add this line
    num_workers=0,
    pin_memory=True,
    # persistent_workers=True,
    # collate_fn=collate_fn
)
val_loader = DataLoader(val_set, batch_size=config['batch_size'], collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=config['batch_size'], collate_fn=collate_fn)

len_dataset = len(train_set)
sample = train_set[0]
label_shape = sample[0].shape
feature_shape = sample[1].shape

print(f"train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")


# Set up logging
writer = SummaryWriter(log_dir=config['log_dir'])

train dataset shape: (1940, [labels: torch.Size([]), features: [1024, 58]])


In [26]:
# Start training
trained_model = train_model(config, train_set, train_loader, val_loader, writer)

Training:   0%|          | 0/600 [00:00<?, ?it/s]

In [None]:
import torch
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

best_model = EEGPTWrapper(
    pretrained_path="checkpoint/eegpt_mcae_58chs_4s_large4E.ckpt",
    channel_list=[
            "FP2", "FPZ", "FP1", "AF4", "AF3", "F7", "F5", "F3", "F6", "F1",
            "FZ", "F2", "F4", "F8", "FT7", "FC5", "FC3", "FC6", "FC1", "FCZ",
            "FC2", "FC4", "FT8", "T7", "C5", "C3", "C6", "C1", "CZ", "C2",
            "C4", "T8", "TP7", "CP5", "CP3", "CP6", "CP1", "CPZ", "CP2", "CP4",
            "TP8", "P7", "P5", "P3", "P6", "P1", "PZ", "P2", "P4", "P8",
            "O1", "PO7", "PO3", "O2", "OZ", "PO4", "PO8", "POZ"
        ],
    num_classes=1
)  # Adjust parameters as needed

# Load the state dictionary
state_dict = torch.load(f"{config['log_dir']}/best_model.pth", map_location=config['device'])
best_model.load_state_dict(state_dict)

# Move model to the correct device
best_model = best_model.to(config['device'])

# Set model to evaluation mode
best_model.eval()

test_loss = 0
all_test_markers = []
all_test_predictions = []
with torch.no_grad():
    for markers, features in tqdm(test_loader):
        features = features.to(config['device'])
        markers = markers.to(config['device'])

        outputs = best_model(features)

        # Collect markers and predictions for metrics calculation
        all_test_markers.extend(markers.cpu().numpy().flatten())
        all_test_predictions.extend(torch.sigmoid(outputs).cpu().numpy().flatten())

# Calculate test metrics
test_accuracy = accuracy_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_precision = precision_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_recall = recall_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_f1 = f1_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_roc_auc = roc_auc_score(all_test_markers, all_test_predictions)

# Log test metrics to TensorBoard
writer.add_scalar('Metrics/test_accuracy', test_accuracy, 1)
writer.add_scalar('Metrics/test_precision', test_precision, 1)
writer.add_scalar('Metrics/test_recall', test_recall, 1)
writer.add_scalar('Metrics/test_f1', test_f1, 1)
writer.add_scalar('Metrics/test_roc_auc', test_roc_auc, 1)

# Close the TensorBoard writer
writer.close()

KeyError: 'state_dict'

In [None]:
print(f"""
{test_accuracy=}
{test_precision=}
{test_recall=}
{test_f1=}
{test_roc_auc=}
"""
)


test_accuracy=0.4844124700239808
test_precision=0.4823529411764706
test_recall=0.8078817733990148
test_f1=0.6040515653775322
test_roc_auc=np.float64(0.4886400257815018)



In [None]:
from sklearn.metrics import f1_score
import numpy as np
best_threshold = 0.0
best_f1 = 0.0
thresholds = np.arange(0.1, 1.0, 0.01)

for threshold in tqdm(thresholds):
    binary_predictions = (all_test_predictions > threshold).astype(int)
    current_accuracy = f1_score(all_test_markers, binary_predictions)

    if current_accuracy > best_f1:
        best_f1 = current_accuracy
        best_threshold = threshold
        test_accuracy = accuracy_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])

print(f"{best_threshold=}")
print(f"{best_f1=}")

  0%|          | 0/90 [00:00<?, ?it/s]

best_threshold=np.float64(0.1)
best_f1=0.6548387096774193


In [None]:
from sklearn.metrics import accuracy_score
import numpy as np
best_threshold = 0.1
best_accuracy = 0.0
thresholds = np.arange(0.005, 1.0, 0.005)

for threshold in tqdm(thresholds):
    binary_predictions = (all_test_predictions > threshold).astype(int)
    current_accuracy = accuracy_score(all_test_markers, binary_predictions)

    if current_accuracy > best_accuracy:
        best_accuracy = current_accuracy
        best_threshold = threshold

print(f"{best_threshold=}")
print(f"{best_accuracy=}")

  0%|          | 0/199 [00:00<?, ?it/s]

best_threshold=np.float64(0.53)
best_accuracy=0.5251798561151079


In [None]:
from sklearn.metrics import accuracy_score
import numpy as np
best_threshold = 0.1
best_accuracy = 0.0
thresholds = np.arange(0.005, 1.0, 0.005)



for threshold in tqdm(thresholds):
    binary_predictions = (all_test_predictions > threshold).astype(int)
    current_accuracy = accuracy_score(all_test_markers, binary_predictions)

    if current_accuracy > best_accuracy:
        best_accuracy = current_accuracy
        best_threshold = threshold
        precision = precision_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
        recall = recall_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
        f1 = f1_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
        roc_auc = roc_auc_score(all_test_markers, all_test_predictions)

print(f"{best_threshold=}")
print(f"""
      {best_accuracy=}
    {precision=}
    {recall=}
    {f1=}
    {roc_auc=}
""")

  0%|          | 0/199 [00:00<?, ?it/s]

best_threshold=np.float64(0.53)

      best_accuracy=0.5251798561151079
    precision=0.4823529411764706
    recall=0.8078817733990148
    f1=0.6040515653775322
    roc_auc=np.float64(0.4886400257815018)

