In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.optim.lr_scheduler import (
    ReduceLROnPlateau,
    CosineAnnealingLR,
    CyclicLR,
    OneCycleLR,
    LambdaLR
)
from torch.utils.tensorboard import SummaryWriter

import polars as pl
from imblearn.over_sampling import SMOTE

from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import random
import numpy as np

from scipy.interpolate import CubicSpline

import os
# cuBLAS only produces reproducible results with a fixed workspace size; the
# variable has to be in the environment before the first cuBLAS call.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Single source of truth for every random number generator seeded below.
SEED = 69

# Set seeds and deterministic flags
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # seed every visible GPU, not only the current one
torch.use_deterministic_algorithms(True)  # raise on ops without a deterministic implementation
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may pick different kernels between runs

In [21]:
# Only the column names are wanted here, so read the parquet schema instead of
# loading the whole recording table into memory.
# NOTE(review): absolute, machine-specific path; it is repeated in the training
# config further down and should come from a single configurable location.
list(pl.read_parquet_schema('/home/owner/Documents/DEV/BrainLabyrinth/data/combined.parquet'))


['event_id',
 'orig_marker',
 'time',
 'Fp1',
 'Fpz',
 'Fp2',
 'F7',
 'F3',
 'Fz',
 'F4',
 'F8',
 'FC5',
 'FC1',
 'FC2',
 'FC6',
 'M1',
 'T7',
 'C3',
 'Cz',
 'C4',
 'T8',
 'M2',
 'CP5',
 'CP1',
 'CP2',
 'CP6',
 'P7',
 'P3',
 'Pz',
 'P4',
 'P8',
 'POz',
 'O1',
 'O2',
 'AF7',
 'AF3',
 'AF4',
 'AF8',
 'F5',
 'F1',
 'F2',
 'F6',
 'FC3',
 'FCz',
 'FC4',
 'C5',
 'C1',
 'C2',
 'C6',
 'CP3',
 'CP4',
 'P5',
 'P1',
 'P2',
 'P6',
 'PO5',
 'PO3',
 'PO4',
 'PO6',
 'FT7',
 'FT8',
 'TP7',
 'TP8',
 'PO7',
 'PO8',
 'Oz',
 'marker',
 'prev_marker']

In [22]:
# #============================================================
# # Model Architecture
# #============================================================
# class EEGDSConv(nn.Module):
#     def __init__(self, dropout=0.5):
#         super().__init__()
#         self.block = nn.Sequential(
#             nn.Conv1d(64, 64, 15, padding='same', groups=64),
#             nn.Conv1d(64, 16, 1),
#             nn.BatchNorm1d(16),
#             nn.ReLU(),
#             nn.MaxPool1d(4),
#             nn.Dropout(dropout),
#             nn.Conv1d(16, 16, 7, padding='same', groups=16),
#             nn.Conv1d(16, 8, 1),
#             nn.BatchNorm1d(8),
#             nn.ReLU(),
#             nn.AdaptiveAvgPool1d(1),
#             nn.Flatten(),
#             nn.Linear(8, 1)
#         )
    
#     def forward(self, x):
#         x = x.permute(0, 2, 1)
#         return self.block(x).squeeze(-1)  # Squeeze last dimension to match target shape
    
    


In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class EEGMobileNet(nn.Module):
    """Small 1-D CNN built from a strided stem and two depthwise-separable blocks.

    :param in_channels: number of input channels fed to the first convolution
    :param num_classes: width of the final linear layer
    :param dropout:     drop probability used after the stem and after each
                        separable block
    """

    def __init__(self, in_channels=64, num_classes=1, dropout=0.5):
        super().__init__()

        # Wide strided convolution that halves the temporal resolution.
        stem = [
            nn.Conv1d(in_channels, 32, kernel_size=15, stride=2, padding=7),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        ]

        # Depthwise (groups == channels) followed by pointwise 1x1, 32 -> 64.
        separable_32_to_64 = [
            nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1, groups=32),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.Conv1d(32, 64, kernel_size=1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        ]

        # Second separable block, strided depthwise, 64 -> 128.
        separable_64_to_128 = [
            nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1, groups=64),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Conv1d(64, 128, kernel_size=1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        ]

        # Global average pooling over time, then the linear classifier.
        head = [
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(128, num_classes),
        ]

        # Layers are registered in the same order (and therefore under the same
        # indices) as a single flat nn.Sequential definition.
        self.model = nn.Sequential(
            *stem, *separable_32_to_64, *separable_64_to_128, *head
        )

    def forward(self, x):
        """Run the network on ``x`` after swapping its dimensions 1 and 2.

        Conv1d consumes (batch, channels, length), so the input is presumably
        laid out as (batch, time, channels). Dimension 1 of the output is
        squeezed, which yields one logit per sample when ``num_classes == 1``.
        """
        channels_first = x.transpose(1, 2)
        logits = self.model(channels_first)
        return logits.squeeze(1)


In [24]:
import numpy as np
from tslearn.neighbors import KNeighborsTimeSeries
from numba import njit

@njit(fastmath=True)
def fast_interpolate(original, neighbor, alpha):
    """Numba-compiled linear blend of two values.

    Weights ``original`` by ``1 - alpha`` and ``neighbor`` by ``alpha``.
    """
    keep = 1 - alpha
    return keep * original + alpha * neighbor

class TSMOTE:
    """SMOTE-style oversampler for fixed-length, multichannel time series.

    Every series is cut into ``time_slices`` equal slices. A synthetic series is
    assembled slice by slice: numeric channels are interpolated between a
    minority sample and one of its DTW nearest neighbours (searched within the
    same slice position), boolean channels are drawn from a Bernoulli
    distribution fitted on the minority class.
    """

    def __init__(self, 
                 n_neighbors=3, 
                 time_slices=10, 
                 bool_cols=None):
        """
        :param n_neighbors: Number of neighbors for KNN
        :param time_slices: Number of slices to split each time series
        :param bool_cols:   List (or array) of indices for boolean columns
        """
        self.n_neighbors = n_neighbors
        self.time_slices = time_slices
        self.slice_size = None  # set by fit_resample once the series length is known
        self.bool_cols = list(bool_cols) if bool_cols is not None else []

    def _slice_time_series(self, X):
        """Split into time slices: (N, T, ch) -> (N, time_slices, slice_size, ch)."""
        return X.reshape(X.shape[0], self.time_slices, self.slice_size, X.shape[2])

    def _generate_synthetic(self, minority_samples, bool_probs, n_samples=None):
        """
        Generate full-length synthetic samples.

        :param minority_samples: Array of shape (N_minority, T, ch)
        :param bool_probs:       Dict mapping boolean column index -> probability of 1
        :param n_samples:        Number of synthetic series to create. Defaults to one
                                 per minority sample. When more are requested the
                                 minority samples are reused cyclically as seeds.
        :return: float32 array of shape (n_samples, T, ch)
        """
        n_minority = minority_samples.shape[0]
        n_channels = minority_samples.shape[2]
        if n_samples is None:
            n_samples = n_minority
        if n_samples <= 0 or n_minority == 0:
            # Nothing to generate, or nothing to interpolate from.
            return np.empty(
                (0, self.time_slices * self.slice_size, n_channels),
                dtype=np.float32,
            )

        sliced_data = self._slice_time_series(minority_samples)  # (N, slices, slice_size, ch)
        all_cols = list(range(n_channels))
        numeric_cols = [c for c in all_cols if c not in self.bool_cols]

        # Ask for one extra neighbour so the query sample itself (distance 0)
        # can be discarded; interpolating a sample with itself would merely
        # duplicate it. Capped by the number of available minority samples.
        k_query = min(self.n_neighbors + 1, n_minority)

        # The neighbour index of a slice position depends only on the minority
        # set, so it is fitted once per slice instead of once per sample.
        numeric_slices = []
        knn_models = []
        for slice_idx in range(self.time_slices):
            slice_incl = sliced_data[:, slice_idx, :, :][:, :, numeric_cols]  # (N, slice_size, #numeric)
            knn = KNeighborsTimeSeries(n_neighbors=k_query, metric='dtw')
            knn.fit(slice_incl)
            numeric_slices.append(slice_incl)
            knn_models.append(knn)

        syn_samples = []
        for out_idx in tqdm(range(n_samples), desc="Generating synthetic"):
            sample_idx = out_idx % n_minority  # cycle through the minority seeds
            synthetic_slices = []

            for slice_idx in range(self.time_slices):
                slice_incl = numeric_slices[slice_idx]
                original_slice_incl = slice_incl[sample_idx]  # (slice_size, #numeric)

                neighbors = np.asarray(
                    knn_models[slice_idx].kneighbors(
                        original_slice_incl[np.newaxis],
                        return_distance=False,
                    )[0]
                )
                candidates = neighbors[neighbors != sample_idx]
                if candidates.size == 0:
                    # Single minority sample: nothing else to interpolate with.
                    candidates = neighbors
                else:
                    candidates = candidates[:self.n_neighbors]
                neighbor_idx = np.random.choice(candidates)
                neighbor_slice_incl = slice_incl[neighbor_idx]

                # Stay away from both endpoints so the result differs from each parent.
                alpha = np.random.uniform(0.2, 0.8)

                synthetic_slice = np.zeros((self.slice_size, n_channels), dtype=np.float32)
                synthetic_slice[:, numeric_cols] = fast_interpolate(
                    original_slice_incl, neighbor_slice_incl, alpha
                )

                # Boolean channels cannot be interpolated; sample 0/1 per time
                # step with the minority-class probability of a 1.
                for bcol in self.bool_cols:
                    synthetic_slice[:, bcol] = np.random.binomial(
                        n=1, p=bool_probs[bcol], size=self.slice_size
                    )

                synthetic_slices.append(synthetic_slice)

            # Concatenate slices back into one full-length series (T, ch)
            syn_samples.append(np.concatenate(synthetic_slices, axis=0))

        return np.array(syn_samples)

    def fit_resample(self, X, y):
        """
        Perform TSMOTE oversampling of the minority class of a binary problem.

        :param X: shape (N, T, ch); T must be divisible by ``time_slices``
        :param y: shape (N,), labels 0/1
        :return: ``(X_resampled, y_resampled)`` with the synthetic minority series
                 appended after the original samples. ``X`` and ``y`` are returned
                 unchanged when the classes are already balanced or when one
                 class is absent.
        :raises ValueError: if more than two classes are present or T is not
                 divisible by ``time_slices``
        """
        import warnings

        y_int = np.asarray(y).astype(int)
        # minlength=2 keeps both classes addressable even if one never occurs.
        class_counts = np.bincount(y_int, minlength=2)
        if len(class_counts) > 2:
            raise ValueError(
                f"TSMOTE supports binary labels 0/1 only, got labels up to {len(class_counts) - 1}"
            )
        minority_class = np.argmin(class_counts)
        majority_class = 1 - minority_class

        n_needed = class_counts[majority_class] - class_counts[minority_class]
        if n_needed <= 0:
            return X, y  # no oversampling needed

        if class_counts[minority_class] == 0:
            warnings.warn(
                "TSMOTE: only one class present, nothing to interpolate from; "
                "returning the data unchanged."
            )
            return X, y

        if X.shape[1] % self.time_slices != 0:
            raise ValueError(
                f"Series length {X.shape[1]} is not divisible by time_slices={self.time_slices}"
            )
        self.slice_size = X.shape[1] // self.time_slices

        # Get only minority samples
        minority_samples = X[y_int == minority_class]

        # Fraction of 1s per boolean column over all minority samples and time steps.
        bool_probs = {
            bcol: minority_samples[:, :, bcol].flatten().mean()
            for bcol in self.bool_cols
        }

        # Request exactly the deficit so the classes end up balanced even when
        # it exceeds the number of minority samples.
        synthetic = self._generate_synthetic(minority_samples, bool_probs, n_samples=n_needed)

        # Ensure matching dimensions
        assert X.shape[1:] == synthetic.shape[1:], \
            f"Dimension mismatch: Original {X.shape[1:]}, Synthetic {synthetic.shape[1:]}"

        # Use only as many synthetic as needed
        synthetic = synthetic[:n_needed]

        # Concatenate
        X_resampled = np.concatenate([X, synthetic], axis=0)
        y_resampled = np.concatenate([y, np.full(len(synthetic), minority_class)], axis=0)
        return X_resampled, y_resampled


In [25]:
#============================================================
# Enhanced Dataset Class with Proper Encapsulation
#============================================================
class EEGDataset(Dataset):
    """Event-level EEG dataset backed by a long-format polars DataFrame.

    Every ``event_id`` becomes one sample ``(label, features)``: ``features`` is
    a float32 tensor of shape ``(max_length, n_features)`` (rows ordered by
    ``time``, zero padded or truncated), ``label`` is the event's ``marker``
    encoded as 0 ("Left") or 1 ("Right").
    """

    def __init__(self, source, max_length=2000):
        """
        :param source:     parquet path (``str`` or ``os.PathLike``) or an already
                           loaded polars DataFrame
        :param max_length: number of time steps every sample is padded/truncated to
        """
        if isinstance(source, (str, os.PathLike)):
            self.df = pl.read_parquet(source)
        else:
            self.df = source
        if 'orig_marker' in self.df.columns:
            self.df = self.df.drop('orig_marker')

        # Encode "Left"/"Right" as 0/1. The round trip through Utf8 keeps this
        # idempotent: frames whose markers are already integers (as produced by
        # split_dataset) pass through unchanged. replace_all substitutes every
        # occurrence of the substring, not only exact matches.
        self.df = self.df.with_columns([
            pl.col(name)
            .cast(pl.Utf8)
            .str.replace_all("Left", "0")
            .str.replace_all("Right", "1")
            .cast(pl.Int32)
            .alias(name)
            for name in ("marker", "prev_marker")
        ])

        # unique() gives no ordering guarantee; sort so that sample indices and
        # seeded splits are reproducible between runs.
        self.event_ids = self.df['event_id'].unique().sort().to_list()
        self.max_length = max_length
        # Keep time for sorting but exclude from features
        self.feature_cols = [c for c in self.df.columns 
                           if c not in {'event_id', 'marker', 'time'}]

        print("Precomputing samples...")
        self._precompute_samples()
        print("Computing class weights...")
        self._class_weights = self.compute_class_weights()

    @property
    def class_weights(self):
        """Inverse-frequency class weights, ``{"Left": ..., "Right": ...}``."""
        return self._class_weights 

    def __len__(self):
        return len(self.event_ids)

    def __getitem__(self, idx):
        """Return the precomputed ``(label, features)`` pair at position ``idx``."""
        return self.samples[idx]

    def _precompute_samples(self):
        """Build ``self.samples`` with one ``(label, features)`` tuple per event id."""
        self.samples = []
        for event_id in tqdm(self.event_ids, desc='precomputing_samples'):
            # Sort by time within each event!
            event_data = self.df.filter(pl.col("event_id") == event_id).sort("time")
            features = torch.tensor(
                event_data.select(self.feature_cols).to_numpy(),
                dtype=torch.float32
            )
            features = self._pad_sequence(features)

            # The marker of the first row is used as the label of the event.
            label = event_data['marker'][0]
            self.samples.append((
                torch.tensor(label, dtype=torch.float32), 
                features
            ))

    def compute_class_weights(self):
        """
        Compute inverse frequency weights based on the 'marker' column.

        Events are counted once each (unique ``event_id``/``marker`` pairs).
        Marker 0 is reported as "Left" and marker 1 as "Right"; a class that
        does not occur gets weight 1.0.
        """
        unique_events = self.df.select(["event_id", "marker"]).unique()
        counts_df = unique_events["marker"].value_counts()

        counts = {}
        for row in counts_df.to_dicts():
            # The label column is called "values" or "marker" and the count
            # column "count" or "counts", depending on the polars version.
            # Test key membership explicitly: the label 0 is falsy, so an
            # `a or b` fallback would silently drop it.
            label = row["values"] if "values" in row else row.get("marker")
            n_events = row["count"] if "count" in row else row["counts"]
            counts[label] = n_events

        weight_L = 1.0 / counts.get(0, 1)
        weight_R = 1.0 / counts.get(1, 1)
        return {"Left": weight_L, "Right": weight_R}

    def split_dataset(self, ratios=(0.7, 0.15, 0.15), seed=None):
        """
        Splits the dataset into three EEGDataset instances for train, val, and test.

        The event ids are shuffled and partitioned by ``ratios``; the test split
        receives whatever remains after train and val. With ``seed`` set, a
        private generator is used so the global NumPy random state is left
        untouched; without it the global state drives the shuffle.
        """
        rng = np.random.RandomState(seed) if seed is not None else np.random

        # Copy and shuffle the event_ids
        event_ids = self.event_ids.copy()
        rng.shuffle(event_ids)
        total = len(event_ids)

        n_train = int(ratios[0] * total)
        n_val   = int(ratios[1] * total)

        train_ids = event_ids[:n_train]
        val_ids   = event_ids[n_train:n_train+n_val]
        test_ids  = event_ids[n_train+n_val:]

        # Filter self.df for the selected event_ids
        train_df = self.df.filter(pl.col("event_id").is_in(train_ids))
        val_df   = self.df.filter(pl.col("event_id").is_in(val_ids))
        test_df  = self.df.filter(pl.col("event_id").is_in(test_ids))

        # Create new EEGDataset instances using the filtered data
        train_set = EEGDataset(train_df, self.max_length)
        val_set   = EEGDataset(val_df, self.max_length)
        test_set  = EEGDataset(test_df, self.max_length)

        return train_set, val_set, test_set

    def _pad_sequence(self, tensor):
        """Zero pad (or truncate) ``tensor`` along dim 0 to ``self.max_length`` rows."""
        padded = torch.zeros((self.max_length, tensor.size(1)), dtype=tensor.dtype)
        length = min(tensor.size(0), self.max_length)
        padded[:length] = tensor[:length]
        return padded

    def rebalance_by_tsmote(self):
        """Oversample the minority class in place with TSMOTE.

        Synthetic events are appended to ``self.df`` under fresh event ids, then
        the samples and class weights are recomputed. Returns ``self``.
        """
        # Extract time-ordered features as 3D array (samples, timesteps, features)
        X = np.stack([features.numpy() for _, features in self.samples])
        y = np.array([label.item() for label, _ in self.samples])

        # 'prev_marker' is the only non-continuous feature; TSMOTE samples it
        # instead of interpolating it.
        prev_marker_idx = self.feature_cols.index('prev_marker')
        tsmote = TSMOTE(bool_cols=[prev_marker_idx])
        X_res, y_res = tsmote.fit_resample(X, y)

        # Generate synthetic temporal events
        new_events = []
        new_event_id = self.df['event_id'].max() + 1
        time_base = np.arange(self.max_length)
        original_schema = self.df.schema
        feature_index = {col: idx for idx, col in enumerate(self.feature_cols)}

        # Create dtype conversion map
        dtype_map = {
            pl.Float64: np.float64,
            pl.Float32: np.float32,
            pl.Int64: np.int64,
            pl.Int32: np.int32,
            pl.Utf8: str,
        }

        # Process synthetic samples (original samples come first in X_res)
        n_original = len(self.samples)
        for features_3d, label in zip(X_res[n_original:], y_res[n_original:]):
            event_data = {}

            # Ensure columns are added in the original DataFrame's order
            for col in self.df.columns:
                if col == 'event_id':
                    event_data[col] = [new_event_id] * self.max_length
                elif col == 'marker':
                    event_data[col] = [int(label)] * self.max_length  # Ensure label is integer
                elif col == 'time':
                    # Synthetic events get a plain 0..max_length-1 time axis.
                    event_data[col] = time_base.copy().astype(np.int32)
                else:
                    # Feature columns (excluding event_id, marker, time)
                    if col not in feature_index:
                        continue  # Shouldn't happen as feature_cols covers all else
                    col_data = features_3d[:, feature_index[col]]
                    schema_type = original_schema[col]

                    # Handle data types
                    if isinstance(schema_type, pl.List):
                        target_type = dtype_map.get(type(schema_type.inner), np.float64)
                    else:
                        target_type = dtype_map.get(type(schema_type), np.float64)

                    # Round *before* the integer cast: casting first truncates
                    # towards zero and turns the rounding into a no-op.
                    if schema_type in (pl.Int64, pl.Int32):
                        col_data = np.round(col_data)

                    event_data[col] = col_data.astype(target_type)

            # Create DataFrame with strict schema adherence
            event_df = pl.DataFrame(event_data).cast(original_schema)
            new_events.append(event_df)
            new_event_id += 1

        # Update dataset with synthetic temporal events
        self.df = pl.concat([self.df, *new_events])
        self.event_ids = self.df['event_id'].unique().sort().to_list()
        self._precompute_samples()
        self._class_weights = self.compute_class_weights()
        return self


In [26]:


def create_optimizer(model, optimizer_config):
    """
    Build the optimizer described by ``optimizer_config`` for ``model``.

    Recognised ``'type'`` values are ``'AdamW'`` (default), ``'SGD'`` and
    ``'RMSProp'``. ``'lr'`` (default 1e-3) and ``'weight_decay'`` (default 1e-2)
    apply to all of them; SGD additionally reads ``'momentum'`` (0.9) and
    RMSProp ``'alpha'`` (0.99). Any other type raises ``ValueError``.
    """
    kind = optimizer_config.get('type', 'AdamW')
    shared = {
        'lr': optimizer_config.get('lr', 1e-3),
        'weight_decay': optimizer_config.get('weight_decay', 1e-2),
    }

    if kind == 'AdamW':
        return optim.AdamW(model.parameters(), **shared)

    if kind == 'SGD':
        return optim.SGD(
            model.parameters(),
            momentum=optimizer_config.get('momentum', 0.9),
            **shared,
        )

    if kind == 'RMSProp':
        return optim.RMSprop(
            model.parameters(),
            alpha=optimizer_config.get('alpha', 0.99),
            **shared,
        )

    # Add more optimizers above if needed
    raise ValueError(f"Unsupported optimizer type: {kind}")

def create_scheduler(optimizer, scheduler_config, total_epochs=None, steps_per_epoch=None):
    """
    Create and return a scheduler based on scheduler_config.

    :param optimizer:        optimizer whose learning rate is scheduled
    :param scheduler_config: dict with a ``'type'`` key (``'ReduceLROnPlateau'``,
                             ``'CosineAnnealingLR'``, ``'CyclicLR'`` or
                             ``'OneCycleLR'``) plus that scheduler's options;
                             empty/``None`` or a missing type means "no scheduler"
    :param total_epochs:     used with ``steps_per_epoch`` to size OneCycleLR when
                             the config has no ``'total_steps'``
    :param steps_per_epoch:  see ``total_epochs``
    :return: ``(scheduler, requires_val_loss)``; ``requires_val_loss`` is True
             when ``scheduler.step`` must be given the validation loss
             (ReduceLROnPlateau). ``(None, False)`` when no scheduler is used.
    :raises ValueError: for an unknown type, or when OneCycleLR cannot
             determine its total number of steps
    """
    if not scheduler_config or scheduler_config.get('type') is None:
        # No scheduler used
        return None, False

    sched_type = scheduler_config['type']

    if sched_type == 'ReduceLROnPlateau':
        plateau_kwargs = dict(
            mode=scheduler_config.get('mode', 'min'),
            factor=scheduler_config.get('factor', 0.1),
            patience=scheduler_config.get('patience', 10),
            threshold=scheduler_config.get('threshold', 0.0001),
            cooldown=scheduler_config.get('cooldown', 0),
            min_lr=scheduler_config.get('min_lr', 0),
        )
        # `verbose` is deprecated in recent PyTorch releases; forward it only
        # when the caller asked for it so the default path stays warning-free.
        if 'verbose' in scheduler_config:
            plateau_kwargs['verbose'] = scheduler_config['verbose']
        return ReduceLROnPlateau(optimizer, **plateau_kwargs), True

    elif sched_type == 'CosineAnnealingLR':
        T_max = scheduler_config.get('T_max', 10)
        eta_min = scheduler_config.get('eta_min', 0)
        scheduler = CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)
        return scheduler, False

    elif sched_type == 'CyclicLR':
        scheduler = CyclicLR(
            optimizer,
            base_lr=scheduler_config.get('base_lr', 1e-3),
            max_lr=scheduler_config.get('max_lr', 1e-2),
            step_size_up=scheduler_config.get('step_size_up', 2000),
            mode=scheduler_config.get('mode', 'triangular'),
            # Configurable so CyclicLR can be paired with optimizers that do not
            # support momentum cycling; defaults to the library default.
            cycle_momentum=scheduler_config.get('cycle_momentum', True),
        )
        return scheduler, False

    elif sched_type == 'OneCycleLR':
        # OneCycleLR requires total_steps or (epochs * steps_per_epoch).
        # If not specified in config, try to compute from total_epochs and steps_per_epoch.
        max_lr = scheduler_config.get('max_lr', 1e-2)
        if 'total_steps' in scheduler_config:
            total_steps = scheduler_config['total_steps']
        else:
            if total_epochs is None or steps_per_epoch is None:
                raise ValueError(
                    "OneCycleLR requires either 'total_steps' in config "
                    "or (total_epochs and steps_per_epoch) arguments."
                )
            total_steps = total_epochs * steps_per_epoch

        scheduler = OneCycleLR(
            optimizer,
            max_lr=max_lr,
            total_steps=total_steps,
            pct_start=scheduler_config.get('pct_start', 0.3),
            anneal_strategy=scheduler_config.get('anneal_strategy', 'cos'),
            cycle_momentum=scheduler_config.get('cycle_momentum', True),
            base_momentum=scheduler_config.get('base_momentum', 0.85),
            max_momentum=scheduler_config.get('max_momentum', 0.95),
            div_factor=scheduler_config.get('div_factor', 25.0),
            final_div_factor=scheduler_config.get('final_div_factor', 1e4)
        )
        return scheduler, False

    else:
        raise ValueError(f"Unsupported scheduler type: {sched_type}")


In [27]:
def collate_fn(batch):
    """
    Merge a list of ``(label, feature)`` samples into one batch.

    Labels are stacked into a single tensor. Feature tensors are passed to
    ``pad_sequence`` with ``batch_first=True``, so shorter sequences are padded
    at the end along their first dimension up to the longest one in the batch.

    Returns ``(labels, padded_features)``.
    """
    columns = tuple(zip(*batch))
    label_seq, feature_seq = columns

    stacked_labels = torch.stack(label_seq)
    return stacked_labels, pad_sequence(feature_seq, batch_first=True)


def train_model(config, train_set, train_loader, val_loader, writer):
    """
    Train an EEGMobileNet binary classifier and log progress to TensorBoard.

    :param config:       dict with 'dropout', 'device', 'optimizer', 'lr',
                         'weight_decay', 'epochs' and 'log_dir'; optional keys are
                         'scheduler', 'factor', 'patience', 'cooldown',
                         'warmup_epochs', 'grad_clip' and 'in_channels' (default 64).
                         The dict and its nested configs are not modified.
    :param train_set:    dataset exposing ``class_weights`` with keys 'Left'/'Right'
    :param train_loader: iterable of ``(labels, features)`` training batches
    :param val_loader:   iterable of ``(labels, features)`` validation batches
    :param writer:       TensorBoard ``SummaryWriter``; closed before returning
    :return: the model as it is after the last epoch. The weights of the epoch
             with the best validation accuracy are saved to
             ``<log_dir>/best_model.pth``.
    """
    device = config['device']

    # -------------------- MODEL --------------------
    model = EEGMobileNet(
        in_channels=config.get('in_channels', 64),
        num_classes=1,
        dropout=config['dropout']
    ).to(device)

    # Log model architecture and config
    writer.add_text("Model/Type", f"EEGMobileNet with dropout={config['dropout']}")
    writer.add_text("Model/Structure", str(model))
    writer.add_text("Training Config", str(config))

    # ------------------ LOSS FUNCTION ------------------
    # class_weights holds inverse class frequencies, so Right/Left equals
    # n_left / n_right = n_negative / n_positive ("Right" is label 1), which is
    # the quantity pos_weight expects. It has to be passed as `pos_weight`:
    # `weight` would only rescale the whole loss by a constant.
    class_weights = train_set.class_weights
    pos_weight = torch.tensor(
        [class_weights['Right'] / class_weights['Left']],
        dtype=torch.float32,
    ).to(device)
    criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    # ------------------- OPTIMIZER ---------------------
    # Work on copies: the config dicts are shared, module-level objects.
    optimizer_config = dict(config['optimizer'])
    optimizer_config['lr'] = config['lr']
    optimizer_config['weight_decay'] = config['weight_decay']

    optimizer = create_optimizer(model, optimizer_config)

    # ------------------- SCHEDULER ---------------------
    steps_per_epoch = len(train_loader)
    scheduler_config = dict(config.get('scheduler') or {})
    for key in ('factor', 'patience', 'cooldown'):
        if key in config:
            scheduler_config[key] = config[key]

    scheduler, requires_val_loss = create_scheduler(
        optimizer,
        scheduler_config,
        total_epochs=config['epochs'],
        steps_per_epoch=steps_per_epoch
    )
    # CyclicLR and OneCycleLR are defined in optimizer steps (OneCycleLR is
    # sized as epochs * steps_per_epoch), so they advance after every batch.
    steps_per_batch = isinstance(scheduler, (CyclicLR, OneCycleLR))

    # ------------------- WARMUP SCHEDULER ---------------
    warmup_epochs = config.get('warmup_epochs', 0)
    if warmup_epochs > 0:
        warmup_scheduler = LambdaLR(
            optimizer,
            lambda epoch: min(1.0, (epoch + 1) / warmup_epochs)
        )
    else:
        warmup_scheduler = None

    # -------------------- TRAINING LOOP --------------------
    best_metric = -float('inf')
    os.makedirs(config['log_dir'], exist_ok=True)  # torch.save does not create directories
    best_model_path = os.path.join(config['log_dir'], 'best_model.pth')

    for epoch in tqdm(range(config['epochs']), desc="Training"):
        in_warmup = warmup_scheduler is not None and epoch < warmup_epochs

        # ---------- TRAIN ----------
        model.train()
        train_loss = 0.0

        for labels, features in train_loader:
            features = features.to(device).float()
            labels = labels.to(device).float()

            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()

            # Gradient clipping (if specified)
            if config.get('grad_clip') is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config['grad_clip'])

            optimizer.step()
            if steps_per_batch and not in_warmup:
                scheduler.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)

        # ---------- VALIDATION ----------
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for labels, features in val_loader:
                features = features.to(device).float()
                labels = labels.to(device).float()

                outputs = model(features)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                preds = torch.sigmoid(outputs)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_loss /= len(val_loader)
        predictions = (np.array(all_preds) > 0.5).astype(int)

        # ---------- METRICS ----------
        # zero_division=0 keeps the value sklearn already reports when a class
        # is never predicted, without a warning on every such epoch.
        accuracy = accuracy_score(all_labels, predictions)
        precision = precision_score(all_labels, predictions, zero_division=0)
        recall = recall_score(all_labels, predictions, zero_division=0)
        f1 = f1_score(all_labels, predictions, zero_division=0)

        # ---------- SCHEDULER UPDATE ----------
        # Read the learning rate before the epoch-level update below.
        current_lr = optimizer.param_groups[0]['lr']

        if in_warmup:
            warmup_scheduler.step()
        elif scheduler is not None and not steps_per_batch:
            if requires_val_loss:
                # e.g. ReduceLROnPlateau
                scheduler.step(val_loss)
            else:
                scheduler.step()

        # ---------- LOGGING ----------
        writer.add_scalar('LR', current_lr, epoch)
        writer.add_scalar('Loss/Train', train_loss, epoch)
        writer.add_scalar('Loss/Val', val_loss, epoch)
        writer.add_scalar('Accuracy', accuracy, epoch)
        writer.add_scalar('Precision', precision, epoch)
        writer.add_scalar('Recall', recall, epoch)
        writer.add_scalar('F1', f1, epoch)

        # The same metrics once more as a single grouped chart
        metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1
        }
        writer.add_scalars('Metrics', metrics, epoch)

        # ---------- SAVE BEST MODEL ----------
        if accuracy > best_metric:
            best_metric = accuracy
            torch.save(model.state_dict(), best_model_path)

    writer.close()
    return model


In [28]:
# ---------- OPTIMIZERS CONFIGS ----------
# One preset per optimizer type understood by create_optimizer.

adamw_config = dict(type='AdamW', weight_decay=1e-4)

sgd_config = dict(type='SGD', momentum=0.9, weight_decay=1e-4)

rmsprop_config = dict(type='RMSProp', alpha=0.99, weight_decay=1e-5)

# ---------- SCHEDULERS CONFIGS ----------
# One preset per scheduler type understood by create_scheduler.

reduce_on_plateau_config = dict(
    type='ReduceLROnPlateau',
    mode='min',
    factor=0.1,
    patience=10,
    threshold=0.0001,
    cooldown=10,
    min_lr=1e-8,
)

cosine_annealing_config = dict(
    type='CosineAnnealingLR',
    T_max=10,    # scheduler steps to go from the initial LR down to eta_min
    eta_min=0,   # minimum learning rate
)

cycliclr_config = dict(
    type='CyclicLR',
    base_lr=1e-4,        # lower learning rate bound
    max_lr=1e-3,         # upper learning rate bound
    step_size_up=2000,   # training iterations (batches) in the increasing half of a cycle
    mode='triangular',   # 'triangular', 'triangular2', or 'exp_range'
)

onecyclelr_config = dict(
    type='OneCycleLR',
    max_lr=1e-3,
    pct_start=0.3,
    anneal_strategy='cos',   # 'cos' or 'linear'
    cycle_momentum=True,
    base_momentum=0.85,
    max_momentum=0.95,
    div_factor=25.0,
    final_div_factor=1e4,
    # total_steps=...  # provide explicitly OR let (epochs * steps_per_epoch) be used
)


In [29]:
import pickle
# Load the tuned hyper-parameters (lr, weight_decay, dropout, factor, patience,
# cooldown) that a previous run stored in 'study.pkl' — presumably the best
# parameters of an Optuna study; verify against the notebook that wrote it.
# NOTE(review): pickle.load can execute arbitrary code during unpickling; only
# open a 'study.pkl' that was produced locally / comes from a trusted source.
with open('study.pkl', 'rb') as f:
    config_study = pickle.load(f)

In [30]:
config_study

{'lr': 0.00014782765350118218,
 'weight_decay': 4.39923324740438e-06,
 'dropout': 0.44201278804845473,
 'factor': 0.7470511809208503,
 'patience': 5,
 'cooldown': 17}

In [None]:
# Run configuration handed to train_model. Tuned values come from `config_study`.
config = {
    'data_path': '/home/owner/Documents/DEV/BrainLabyrinth/data/combined.parquet',
    'split_ratios': (0.7, 0.15, 0.15),
    'batch_size': 32,
    'dropout': config_study['dropout'],
    'epochs': 300,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'log_dir': './runs/CNN',

    # <<< Global LR and Weight Decay here >>>
    'lr': config_study['lr'],
    'weight_decay': config_study['weight_decay'],
    'factor': config_study['factor'],
    'patience': config_study['patience'],
    'cooldown': config_study['cooldown'],

    # Optimizer config (without lr/weight_decay)
    'optimizer': adamw_config,

    # Scheduler config: ReduceLROnPlateau driven by the tuned parameters.
    # Built here directly instead of first assigning the hard-coded
    # `reduce_on_plateau_config` preset and overwriting it right afterwards,
    # so the dict shows the scheduler settings that are actually in effect.
    'scheduler': {
        'type': 'ReduceLROnPlateau',
        'mode': 'min',
        'factor': config_study['factor'],      # From Optuna
        'patience': config_study['patience'],  # From Optuna
        'cooldown': config_study['cooldown'],  # From Optuna
        'min_lr': 1e-8,
        'threshold': 0.0001,
    },

    'warmup_epochs': 10,
    'grad_clip': None
}




In [32]:
# #============================================================
# # Training Pipeline
# #============================================================
# import warnings
# warnings.filterwarnings("ignore", category=FutureWarning)

# # Initialize dataset
# print("Creating full dataset...")
# full_dataset = EEGDataset(config['data_path'])

# print("Splitting the dataset...")
# # Split dataset
# train_set, val_set, test_set = full_dataset.split_dataset(
#     ratios=config['split_ratios']
# )

# del full_dataset

# len_dataset = len(train_set)
# sample = train_set[0]
# label_shape = sample[0].shape
# feature_shape = sample[1].shape

# print(f"unbalanced train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")

# # # Balance training set
# # print("Applying SMOTE to train dataset...")
# # train_set.rebalance_by_tsmote()

# len_dataset = len(train_set)
# sample = train_set[0]
# label_shape = sample[0].shape
# feature_shape = sample[1].shape

# print(f"balanced train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")

# torch.save(train_set, 'train_set.pt')
# torch.save(val_set, 'val_set.pt')
# torch.save(test_set, 'test_set.pt')

In [33]:

# Restore the three dataset splits from disk.
# weights_only=False lets torch.load unpickle arbitrary Python objects
# (presumably whole Dataset instances — confirm against the cell that saved
# them), so these files must come from a trusted source.
train_set, val_set, test_set = (
    torch.load(split_file, weights_only=False)
    for split_file in ('train_set_smol.pt', 'val_set.pt', 'test_set.pt')
)

# Dedicated, seeded RNG so the training shuffle order is repeatable across runs.
generator = torch.Generator().manual_seed(69)

train_loader = DataLoader(
    dataset=train_set,
    batch_size=config['batch_size'],
    shuffle=True,
    generator=generator,
    collate_fn=collate_fn,
    num_workers=12,
    persistent_workers=True,
    pin_memory=True,
)
# Validation and test loaders keep the dataset order (no shuffling, no workers).
val_loader = DataLoader(
    dataset=val_set, batch_size=config['batch_size'], collate_fn=collate_fn
)
test_loader = DataLoader(
    dataset=test_set, batch_size=config['batch_size'], collate_fn=collate_fn
)

# Quick sanity summary: dataset size plus the shapes of the first sample,
# where element 0 is the label and element 1 the feature tensor.
sample = train_set[0]
len_dataset = len(train_set)
label_shape, feature_shape = sample[0].shape, sample[1].shape

print(f"train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")


# TensorBoard writer used for logging
writer = SummaryWriter(config['log_dir'])

KeyboardInterrupt: 

In [None]:
# Start training
# train_model (defined earlier in this notebook) is given the run config, the
# training dataset and loader, the validation loader and the TensorBoard writer.
# It returns the model object; whenever accuracy improves it also saves the
# weights to "<log_dir>/best_model.pth", and it closes `writer` before returning.
trained_model = train_model(config, train_set, train_loader, val_loader, writer)



Training:   0%|          | 0/300 [00:00<?, ?it/s]

In [None]:
import torch
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# ---------- TEST-SET EVALUATION OF THE BEST CHECKPOINT ----------

# Positive-class weight for the loss: ratio of the training set's 'Left' and
# 'Right' class weights.
pos_weight = torch.tensor([
    train_set.class_weights['Left'] / train_set.class_weights['Right']
]).to(config['device'])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# Step index under which the test scalars are written to TensorBoard
epoch = 1

# Instantiate the model with its default constructor arguments
best_model = EEGMobileNet()  # Adjust parameters as needed

# Load the state dictionary.
# The checkpoint is written with torch.save(model.state_dict(), ...), i.e. it
# holds only tensors, so the restricted unpickler (weights_only=True) suffices.
# This avoids executing arbitrary pickled code and silences the FutureWarning
# that torch.load emits when weights_only is left unspecified.
state_dict = torch.load(
    f"{config['log_dir']}/best_model.pth",
    map_location=config['device'],
    weights_only=True,
)
best_model.load_state_dict(state_dict)

# Move model to the correct device
best_model = best_model.to(config['device'])

# Set model to evaluation mode
best_model.eval()

test_loss = 0
all_test_markers = []
all_test_predictions = []
with torch.no_grad():
    for markers, features in tqdm(test_loader):
        features = features.to(config['device'])
        markers = markers.to(config['device'])

        outputs = best_model(features)
        loss = criterion(outputs, markers)
        test_loss += loss.item()

        # Collect markers and sigmoid probabilities for metrics calculation
        all_test_markers.extend(markers.cpu().numpy().flatten())
        all_test_predictions.extend(torch.sigmoid(outputs).cpu().numpy().flatten())

# Mean loss per batch
test_loss /= len(test_loader)

# Hard 0/1 decisions at the default 0.5 cut-off, computed once and shared by
# every threshold-dependent metric below.
test_binary_predictions = [1 if p > 0.5 else 0 for p in all_test_predictions]

# Calculate test metrics (ROC AUC uses the raw probabilities, not the decisions)
test_accuracy = accuracy_score(all_test_markers, test_binary_predictions)
test_precision = precision_score(all_test_markers, test_binary_predictions)
test_recall = recall_score(all_test_markers, test_binary_predictions)
test_f1 = f1_score(all_test_markers, test_binary_predictions)
test_roc_auc = roc_auc_score(all_test_markers, all_test_predictions)

# Log test metrics to TensorBoard
writer.add_scalar('Metrics/test_accuracy', test_accuracy, epoch)
writer.add_scalar('Metrics/test_precision', test_precision, epoch)
writer.add_scalar('Metrics/test_recall', test_recall, epoch)
writer.add_scalar('Metrics/test_f1', test_f1, epoch)
writer.add_scalar('Metrics/test_roc_auc', test_roc_auc, epoch)

# Close the TensorBoard writer
writer.close()

  state_dict = torch.load(f"{config['log_dir']}/best_model.pth", map_location=config['device'])


  0%|          | 0/14 [00:00<?, ?it/s]

In [None]:
# Echo the test-set metrics computed in the evaluation cell. The `{name=}`
# f-string form prints each variable's name next to its value.
print(f"""
{test_accuracy=}
{test_precision=}
{test_recall=}
{test_f1=}
{test_roc_auc=}
"""
)


test_accuracy=0.5323741007194245
test_precision=0.5176991150442478
test_recall=0.5763546798029556
test_f1=0.5454545454545454
test_roc_auc=np.float64(0.5681368261129782)



In [None]:
from sklearn.metrics import f1_score
import numpy as np

# Sweep decision thresholds over [0.1, 1.0) in 0.01 steps and keep the one that
# maximises F1 on the collected test probabilities.
# NOTE(review): the threshold is selected on the test set itself, so best_f1 is
# an optimistic estimate; consider tuning on the validation split instead.

# Explicit ndarray so the element-wise `> threshold` comparison does not depend
# on `all_test_predictions` (a Python list) being compared against a NumPy scalar.
test_probabilities = np.asarray(all_test_predictions)

best_threshold = 0.0
best_f1 = 0.0
thresholds = np.arange(0.1, 1.0, 0.01)

for threshold in tqdm(thresholds):
    binary_predictions = (test_probabilities > threshold).astype(int)
    current_f1 = f1_score(all_test_markers, binary_predictions)

    # Strict '>' keeps the lowest threshold among ties
    if current_f1 > best_f1:
        best_f1 = current_f1
        best_threshold = threshold

print(f"{best_threshold=}")
print(f"{best_f1=}")

  0%|          | 0/90 [00:00<?, ?it/s]

best_threshold=np.float64(0.2599999999999999)
best_f1=0.657439446366782


In [None]:
from sklearn.metrics import recall_score
import numpy as np

# Sweep decision thresholds over [0.1, 1.0) in 0.01 steps and keep the one that
# maximises recall on the collected test probabilities.
# NOTE(review): recall can only fall as the threshold rises, so this sweep is
# expected to pick the lowest threshold tried; confirm that is the intent.
# Like any threshold tuned on the test set, the result is optimistic.

# Explicit ndarray so the element-wise `> threshold` comparison does not depend
# on `all_test_predictions` (a Python list) being compared against a NumPy scalar.
test_probabilities = np.asarray(all_test_predictions)

best_threshold = 0.1
best_recall = 0.0
thresholds = np.arange(0.1, 1.0, 0.01)

for threshold in tqdm(thresholds):
    binary_predictions = (test_probabilities > threshold).astype(int)
    current_recall = recall_score(all_test_markers, binary_predictions)

    # Strict '>' keeps the lowest threshold among ties
    if current_recall > best_recall:
        best_recall = current_recall
        best_threshold = threshold

print(f"{best_threshold=}")
print(f"{best_recall=}")

  0%|          | 0/90 [00:00<?, ?it/s]

best_threshold=np.float64(0.1)
best_recall=0.9852216748768473
