In [1]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # Must be first!

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import (
    ReduceLROnPlateau,
    LambdaLR
)

import polars as pl
import copy


from sklearn.metrics import accuracy_score
import random
import numpy as np

import optuna

###################
from model import EEGMobileNet
from dataset import EEGDataset
from utils import collate_fn
###################

# Deterministic-execution flags (no RNG seeds are set in this cell).

# cuDNN: turn off the benchmark autotuner and request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Ask PyTorch to use deterministic algorithms globally.
torch.use_deterministic_algorithms(mode=True)

In [2]:
# Only the column names are needed, so read the Parquet schema instead of
# loading the whole dataset into memory.
list(pl.read_parquet_schema('/home/owner/Documents/DEV/BrainLabyrinth/data/combined.parquet'))


['event_id',
 'orig_marker',
 'time',
 'Fp1',
 'Fpz',
 'Fp2',
 'F7',
 'F3',
 'Fz',
 'F4',
 'F8',
 'FC5',
 'FC1',
 'FC2',
 'FC6',
 'M1',
 'T7',
 'C3',
 'Cz',
 'C4',
 'T8',
 'M2',
 'CP5',
 'CP1',
 'CP2',
 'CP6',
 'P7',
 'P3',
 'Pz',
 'P4',
 'P8',
 'POz',
 'O1',
 'O2',
 'AF7',
 'AF3',
 'AF4',
 'AF8',
 'F5',
 'F1',
 'F2',
 'F6',
 'FC3',
 'FCz',
 'FC4',
 'C5',
 'C1',
 'C2',
 'C6',
 'CP3',
 'CP4',
 'P5',
 'P1',
 'P2',
 'P6',
 'PO5',
 'PO3',
 'PO4',
 'PO6',
 'FT7',
 'FT8',
 'TP7',
 'TP8',
 'PO7',
 'PO8',
 'Oz',
 'marker',
 'prev_marker']

In [3]:
def train_model(config, train_set, train_loader, val_loader):
    """Train an ``EEGMobileNet`` and return it with its best-epoch weights.

    The model is optimized with AdamW and ``BCEWithLogitsLoss``; the learning
    rate is driven by ``ReduceLROnPlateau`` on the validation loss, optionally
    preceded by a linear warm-up. After every epoch the validation accuracy
    (sigmoid output thresholded at 0.5) is computed, and the weights of the
    epoch with the highest accuracy are restored before returning.

    Args:
        config: Dict of settings. Required keys: ``'device'``, ``'dropout'``,
            ``'epochs'``. Optional keys: ``'lr'`` (default 1e-3),
            ``'weight_decay'`` (default 1e-2), ``'warmup_epochs'`` (default 0),
            ``'grad_clip'`` (default None, i.e. no clipping) and
            ``'scheduler'``, a dict with the ``ReduceLROnPlateau`` arguments
            ``mode``, ``factor``, ``patience``, ``threshold``, ``cooldown``
            and ``min_lr``.
        train_set: Object exposing ``class_weights`` with ``'Left'`` and
            ``'Right'`` entries.
        train_loader: Iterable of ``(labels, features)`` training batches.
        val_loader: Iterable of ``(labels, features)`` validation batches.

    Returns:
        The trained model carrying the weights of the best validation epoch.
        If ``config['epochs']`` is 0 the freshly initialised model is returned.
    """
    device = config['device']

    # -------------------- MODEL --------------------
    model = EEGMobileNet(
        in_channels=64,
        num_classes=1,
        dropout=config['dropout']
    ).to(device)

    # ------------------ LOSS FUNCTION ------------------
    # The class ratio has to go in as `pos_weight`. Passing a one-element
    # tensor as `weight` multiplies every sample's loss by the same constant,
    # which does not rebalance the classes at all.
    # NOTE(review): assumes label 1 is the class that should be scaled by
    # Left/Right -- confirm against the dataset's label encoding.
    pos_weight = torch.tensor(
        [train_set.class_weights['Left'] / train_set.class_weights['Right']],
        dtype=torch.float32,
        device=device,
    )
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    # ------------------- OPTIMIZER ---------------------
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.get('lr', 1e-3),
        weight_decay=config.get('weight_decay', 1e-2),
    )

    # ------------------- SCHEDULER ---------------------
    scheduler_config = config.get('scheduler', {})
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode=scheduler_config.get('mode', 'min'),
        factor=scheduler_config.get('factor', 0.1),
        patience=scheduler_config.get('patience', 10),
        threshold=scheduler_config.get('threshold', 0.0001),
        cooldown=scheduler_config.get('cooldown', 0),
        min_lr=scheduler_config.get('min_lr', 0),
    )

    # ------------------- WARMUP SCHEDULER ---------------
    # Linear ramp of the learning rate over the first `warmup_epochs` epochs.
    warmup_epochs = config.get('warmup_epochs', 0)
    if warmup_epochs > 0:
        warmup_scheduler = LambdaLR(
            optimizer,
            lambda epoch: min(1.0, (epoch + 1) / warmup_epochs)
        )
    else:
        warmup_scheduler = None

    grad_clip = config.get('grad_clip')

    # -------------------- TRAINING LOOP --------------------
    best_metric = -float("inf")
    best_model_weights = None  # state_dict snapshot of the best epoch so far

    for epoch in range(config['epochs']):
        # ---------- TRAIN ----------
        model.train()
        for labels, features in train_loader:
            features = features.to(device).float()
            labels = labels.to(device).float()

            optimizer.zero_grad()
            loss = criterion(model(features), labels)
            loss.backward()

            # Gradient clipping (if specified)
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

        # ---------- VALIDATION ----------
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for labels, features in val_loader:
                features = features.to(device).float()
                labels = labels.to(device).float()

                outputs = model(features)
                val_loss += criterion(outputs, labels).item()

                # Flatten so predictions and labels are 1-D regardless of
                # whether the model emits (batch,) or (batch, 1).
                all_preds.extend(torch.sigmoid(outputs).cpu().numpy().flatten())
                all_labels.extend(labels.cpu().numpy().flatten())

        val_loss /= len(val_loader)
        predictions = (np.array(all_preds) > 0.5).astype(int)

        # ---------- METRICS ----------
        accuracy = accuracy_score(all_labels, predictions)

        # ---------- SCHEDULER UPDATE ----------
        # Warm-up drives the LR first; afterwards the plateau scheduler reacts
        # to the validation loss.
        if warmup_scheduler is not None and epoch < warmup_epochs:
            warmup_scheduler.step()
        else:
            scheduler.step(val_loss)

        # ---------- SAVE BEST MODEL ----------
        if best_model_weights is None or accuracy > best_metric:
            best_metric = accuracy
            best_model_weights = copy.deepcopy(model.state_dict())

    # Restore the best epoch's weights. Nothing was recorded when no epoch ran,
    # in which case the model is returned as initialised.
    if best_model_weights is not None:
        model.load_state_dict(best_model_weights)

    return model

In [None]:
# Batch size shared by all three loaders.
config_data = {'batch_size': 32}

# Load the pickled dataset objects. weights_only=False performs a full
# unpickle, so these files must come from a trusted source.
train_set = torch.load('train_set_smol.pt', weights_only=False)
val_set = torch.load('val_set.pt', weights_only=False)
test_set = torch.load('test_set.pt', weights_only=False)

# Dedicated RNG for shuffling the training data; its starting state is kept so
# the shuffle order can be rewound later via generator.set_state(initial_state).
generator = torch.Generator()
generator.manual_seed(69)
initial_state = generator.get_state()

# Only the training loader shuffles (using the seeded generator).
train_loader = DataLoader(
    dataset=train_set,
    batch_size=config_data['batch_size'],
    collate_fn=collate_fn,
    shuffle=True,
    generator=generator,
    num_workers=0,
)
val_loader = DataLoader(
    dataset=val_set,
    batch_size=config_data['batch_size'],
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    dataset=test_set,
    batch_size=config_data['batch_size'],
    collate_fn=collate_fn,
)

In [None]:
def objective_adamw_plateau(trial, generator, initial_state, train_set, train_loader, val_loader, device, eval_loader=None):
    """Optuna objective: train with AdamW + ReduceLROnPlateau and score the result.

    All RNGs (``random``, NumPy, PyTorch CPU/CUDA) are re-seeded and the
    DataLoader ``generator`` is rewound to ``initial_state`` so that every
    trial starts from the same random state.

    Args:
        trial: Optuna trial used to sample ``lr``, ``weight_decay``,
            ``dropout`` and the scheduler's ``factor``, ``patience`` and
            ``cooldown``.
        generator: ``torch.Generator`` whose state is reset before training.
        initial_state: State to restore on ``generator``.
        train_set: Passed through to ``train_model``.
        train_loader: Passed through to ``train_model``.
        val_loader: Validation batches, passed through to ``train_model``.
        device: Device on which the model is trained and evaluated.
        eval_loader: Loader of ``(markers, features)`` batches used to compute
            the returned score. Defaults to ``val_loader``: the score steers
            the hyperparameter search, so using the held-out test split here
            would leak it into model selection. Pass a loader explicitly to
            score on other data.

    Returns:
        Accuracy of the best model on ``eval_loader`` (sigmoid output
        thresholded at 0.5).
    """
    # Reset every source of randomness so trials are comparable.
    random.seed(69)
    np.random.seed(69)
    torch.manual_seed(69)
    torch.cuda.manual_seed(69)
    generator.set_state(initial_state)

    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
    dropout = trial.suggest_float("dropout", 0.0, 0.5)

    # Scheduler params
    factor = trial.suggest_float("factor", 0.1, 0.8)
    patience = trial.suggest_int("patience", 2, 25)
    cooldown = trial.suggest_int("cooldown", 10, 25)

    config = {
        "device": device,
        "dropout": dropout,
        "epochs": 300,
        "log_dir": "./runs/OptunaTest",
        "warmup_epochs": 0,
        "grad_clip": None,
        "lr": lr,
        "weight_decay": weight_decay,

        # Scheduler config
        "scheduler": {
            "mode": "min",
            "factor": factor,
            "patience": patience,
            "threshold": 0.0001,
            "cooldown": cooldown,
            "min_lr": 1e-8,
        }
    }

    if eval_loader is None:
        eval_loader = val_loader

    best_model = train_model(config, train_set, train_loader, val_loader)
    best_model.eval()

    all_markers = []
    all_probabilities = []
    with torch.no_grad():
        for markers, features in eval_loader:
            # Cast to float32, matching what train_model feeds the network.
            features = features.to(device).float()

            outputs = best_model(features)
            all_markers.extend(markers.cpu().numpy().flatten())
            all_probabilities.extend(torch.sigmoid(outputs).cpu().numpy().flatten())

    predictions = (np.asarray(all_probabilities) > 0.5).astype(int)
    return accuracy_score(all_markers, predictions)


In [None]:
# Tune AdamW + ReduceLROnPlateau hyperparameters with Optuna.
optuna.logging.set_verbosity(optuna.logging.WARNING)

# Resolve the device once instead of on every trial.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the sampler so the sequence of suggested hyperparameters is
# reproducible, in line with the seeding done inside the objective.
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=69),
)
study.optimize(
    lambda trial: objective_adamw_plateau(
        trial, generator, initial_state, train_set, train_loader, val_loader, device=device
    ),
    n_trials=10,
    show_progress_bar=True
)

print("Best Trial:", study.best_trial.value)
print("Best Hyperparams:", study.best_trial.params)


  0%|          | 0/10 [00:00<?, ?it/s]

[W 2025-02-27 21:18:58,897] Trial 0 failed with parameters: {'lr': 9.151674276545755e-05, 'weight_decay': 1.8100353374705348e-05, 'dropout': 0.12700141396574932, 'factor': 0.11322237165676294, 'patience': 18, 'cooldown': 14} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/home/owner/Documents/DEV/BrainLabyrinth/.venv/lib/python3.10/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_620734/4072054517.py", line 5, in <lambda>
    lambda trial: objective_adamw_plateau(
  File "/tmp/ipykernel_620734/3898594924.py", line 38, in objective_adamw_plateau
    best_model = train_model(config, train_set, train_loader, val_loader)
  File "/tmp/ipykernel_620734/767373109.py", line 60, in train_model
    loss.backward()
  File "/home/owner/Documents/DEV/BrainLabyrinth/.venv/lib/python3.10/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/

KeyboardInterrupt: 

In [None]:
import pickle 
from datetime import datetime
# Persist the best trial's hyperparameters. The timestamp is formatted without
# ':' (which datetime.isoformat() emits) because that character is not allowed
# in file names on Windows.
timestamp = datetime.now().strftime('%Y-%m-%dT%H-%M-%S-%f')
with open(f'study_{timestamp}.pkl', 'wb') as f:
    pickle.dump(study.best_trial.params, f)

In [None]:
# Display the hyperparameters of the study's best trial.
study.best_trial.params

{'lr': 9.765820104617875e-05,
 'weight_decay': 1.8290404005473213e-05,
 'dropout': 0.35994387648460596,
 'factor': 0.19990106143116088,
 'patience': 7,
 'cooldown': 16}