In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

import polars as pl
from imblearn.over_sampling import SMOTE

from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import random
import numpy as np

from scipy.interpolate import CubicSpline

# Set the seed for Python's built-in random module
random.seed(69)

# Set the seed for NumPy's random number generator
np.random.seed(69)

# Set the seed for PyTorch's random number generators
torch.manual_seed(69)
torch.cuda.manual_seed(69)
torch.backends.cudnn.deterministic = True

In [2]:
pl.read_parquet('/home/owner/Documents/DEV/BrainLabyrinth/data/combined.parquet')\
    # ['marker'].unique()


event_id,orig_marker,time,Fp1,Fpz,Fp2,F7,F3,Fz,F4,F8,FC5,FC1,FC2,FC6,M1,T7,C3,Cz,C4,T8,M2,CP5,CP1,CP2,CP6,P7,P3,Pz,P4,P8,POz,O1,O2,AF7,AF3,AF4,AF8,F5,F1,F2,F6,FC3,FCz,FC4,C5,C1,C2,C6,CP3,CP4,P5,P1,P2,P6,PO5,PO3,PO4,PO6,FT7,FT8,TP7,TP8,PO7,PO8,Oz,marker,prev_marker
i64,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,str,str
0,"""Stimulus/A""",16.458,17.170373,13.906436,21.251943,0.04914,6.744807,-11.924581,-0.963751,0.169659,5.941923,10.144309,-5.712161,0.856205,-13.592378,5.55773,5.277309,4.191759,11.859419,-5.135765,-22.437368,-15.228711,-12.038605,-0.75959,-1.466066,-21.842841,13.373856,-9.149643,0.330315,4.335297,-5.321761,16.798603,12.262668,30.010249,16.320185,7.076928,21.956538,11.408192,-0.31748,2.813686,6.930768,4.998787,0.411273,4.038051,10.464517,5.762318,1.730282,-0.464392,-5.006052,-8.962569,13.681179,-9.582485,-17.106364,13.691397,12.681116,11.514871,25.205208,14.079291,11.501145,1.636257,5.387385,6.44589,14.401535,13.353615,8.559186,"""Right""","""Left"""
0,"""Stimulus/A""",16.46,18.195809,14.819681,21.679231,1.756323,7.864059,-11.064782,-0.080928,0.735955,6.969142,10.664451,-4.989593,1.752117,-12.19725,7.671375,5.210584,4.463182,12.858593,-2.887153,-20.226918,-15.193608,-11.675535,-0.408616,-0.500061,-21.86162,12.080171,-9.111754,0.608741,5.29489,-6.487817,14.594936,11.257213,30.903666,16.686429,7.859665,21.840979,13.346964,0.369015,3.588863,7.620327,5.759861,1.229387,4.697295,10.818292,6.213349,2.547987,0.577626,-4.975261,-8.451353,12.374722,-9.802303,-17.080294,13.566208,10.967429,9.950173,24.596992,13.09233,12.683138,3.16226,6.085818,8.821207,12.746989,12.518362,5.916264,"""Right""","""Left"""
0,"""Stimulus/A""",16.462,19.208807,15.713978,22.097658,3.589434,8.999892,-10.254515,0.710523,1.346892,7.978441,11.145332,-4.357971,2.793821,-10.682037,10.741887,5.124484,4.718824,13.820216,-1.052706,-18.077094,-15.224916,-11.403306,-0.094469,0.433875,-22.097145,10.625836,-9.156111,0.810305,6.464841,-7.659126,12.052806,10.762715,31.842351,17.043632,8.56405,22.021347,15.409027,1.007149,4.277586,8.126715,6.484071,2.037813,5.304125,11.000511,6.634963,3.363777,1.715886,-4.994935,-7.956395,10.896297,-10.095891,-17.036856,13.747517,8.995882,8.208103,24.074887,12.511693,13.800554,4.745254,6.843331,11.128809,10.785815,12.08669,3.017775,"""Right""","""Left"""
0,"""Stimulus/A""",16.464,20.14623,16.545905,22.489029,5.424029,10.097336,-9.551693,1.345182,1.98236,8.906685,11.551465,-3.870977,3.916049,-9.132715,14.534485,5.035828,4.941408,14.682348,0.284907,-16.081929,-15.290949,-11.246743,0.171751,1.3134,-22.523215,9.146861,-9.263138,0.941635,7.810634,-8.737122,9.412096,10.861548,32.818187,17.386197,9.129862,22.509357,17.445482,1.540144,4.819262,8.413686,7.13795,2.780287,5.815162,11.045499,6.992499,4.122729,2.89388,-5.057835,-7.497601,9.387137,-10.433307,-16.954144,14.257214,6.958701,6.45094,23.693334,12.429921,14.822983,6.309243,7.631156,13.27648,8.708844,12.141982,0.15928,"""Right""","""Left"""
0,"""Stimulus/A""",16.466,20.946355,17.271059,22.827998,7.140907,11.096112,-9.006329,1.771387,2.620716,9.695722,11.849742,-3.56821,5.041098,-7.624744,18.712668,4.959592,5.110745,15.390717,1.107496,-14.316858,-15.355903,-11.212635,0.38525,2.12328,-23.090999,7.784682,-9.401254,1.018174,9.261804,-9.633229,6.924365,11.55411,33.814531,17.705811,9.512561,23.27303,19.314331,1.917181,5.166991,8.46995,7.688362,3.401579,6.191096,11.003568,7.254085,4.772919,4.046689,-5.150528,-7.09336,7.995443,-10.775587,-16.813751,15.07551,5.057434,4.844601,23.494136,12.871215,15.72843,7.780457,8.422234,15.192707,6.723113,12.701932,-2.368745,"""Right""","""Left"""
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2771,"""Stimulus/A""",956.836,-18.873559,-24.227101,-13.486943,-0.403397,-11.98354,-16.557347,-17.518873,-14.536315,-6.105563,-10.793705,-4.853747,-7.456738,10.414136,-2.954412,-5.026366,5.687796,-2.821944,3.946126,4.366952,-0.86041,-2.708516,3.370458,7.267404,2.82558,2.815112,-1.849568,9.221112,7.698055,2.700318,3.765467,4.554788,6.173636,-7.883982,-18.15556,-1.624224,-10.617139,-18.317587,-17.6747,-16.561863,-4.705041,-4.589931,1.813413,-4.839258,-4.689196,0.565952,-2.825098,-4.87889,4.368522,-1.258749,4.727739,7.158111,5.224539,2.009805,2.215446,9.938263,10.332691,0.287628,-6.543679,5.453187,0.753634,3.273086,10.787134,5.02536,"""Left""","""Right"""
2771,"""Stimulus/A""",956.838,-18.704103,-24.969565,-14.585048,-1.047141,-11.943701,-16.45773,-17.959561,-15.664664,-7.082881,-11.007107,-5.073525,-8.186207,8.59455,-4.468591,-5.374392,5.653907,-2.543513,2.916711,1.86508,-2.226537,-2.766756,3.701948,7.189953,0.664609,2.01553,-1.72148,9.666301,7.382395,2.65323,2.507406,4.086716,5.953678,-7.897042,-18.654489,-2.907517,-10.716993,-18.135785,-18.222283,-17.620759,-5.103536,-4.95171,1.615306,-5.864539,-4.802275,0.860669,-3.010969,-5.317478,4.972083,-2.620815,4.514949,7.682226,5.731373,0.435993,0.715036,10.155613,10.351784,-0.944356,-7.749362,3.527842,-0.227262,1.671753,10.779771,4.207392,"""Left""","""Right"""
2771,"""Stimulus/A""",956.84,-18.061667,-25.302962,-15.126122,-1.543149,-11.638041,-16.168372,-18.234343,-16.689475,-7.844465,-11.080084,-5.22795,-8.864504,6.983161,-5.717354,-5.553369,5.590639,-2.268625,1.711394,-0.604882,-3.276609,-2.707251,3.97878,6.988798,-1.132993,1.500356,-1.638895,10.013091,6.84536,2.529954,1.345095,3.459218,5.952675,-7.653053,-18.907577,-3.972871,-10.577918,-17.735399,-18.616241,-18.516636,-5.395328,-5.275623,1.45044,-6.564501,-4.839488,1.135505,-3.189045,-5.473588,5.565337,-3.532208,4.463409,8.272403,6.121259,-0.84287,-0.486312,10.264764,10.23355,-2.014752,-9.047787,1.966788,-1.368454,0.333598,10.611117,3.436481,"""Left""","""Right"""
2771,"""Stimulus/A""",956.842,-17.023192,-25.220666,-15.079919,-1.852834,-11.10314,-15.743369,-18.338312,-17.531294,-8.309149,-11.006394,-5.293303,-9.428791,5.715629,-6.592615,-5.547356,5.501853,-2.008694,0.434555,-2.88224,-3.923841,-2.530573,4.190703,6.70415,-2.430261,1.293626,-1.613443,10.266805,6.137341,2.336787,0.363123,2.739203,6.169374,-7.181344,-18.903561,-4.734289,-10.219783,-17.170485,-18.840451,-19.18934,-5.546836,-5.528552,1.357458,-6.879141,-4.784823,1.373481,-3.355288,-5.336791,6.114987,-3.940727,4.575404,8.896693,6.400383,-1.731706,-1.306744,10.309981,10.048173,-2.819302,-10.318089,0.896298,-2.548241,-0.638066,10.358187,2.779875,"""Left""","""Right"""


In [3]:
# #============================================================
# # Model Architecture
# #============================================================
# class EEGDSConv(nn.Module):
#     def __init__(self, dropout=0.5):
#         super().__init__()
#         self.block = nn.Sequential(
#             nn.Conv1d(64, 64, 15, padding='same', groups=64),
#             nn.Conv1d(64, 16, 1),
#             nn.BatchNorm1d(16),
#             nn.ReLU(),
#             nn.MaxPool1d(4),
#             nn.Dropout(dropout),
#             nn.Conv1d(16, 16, 7, padding='same', groups=16),
#             nn.Conv1d(16, 8, 1),
#             nn.BatchNorm1d(8),
#             nn.ReLU(),
#             nn.AdaptiveAvgPool1d(1),
#             nn.Flatten(),
#             nn.Linear(8, 1)
#         )
    
#     def forward(self, x):
#         x = x.permute(0, 2, 1)
#         return self.block(x).squeeze(-1)  # Squeeze last dimension to match target shape
    
    


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class EEGMobileNet(nn.Module):
    def __init__(self, in_channels=64, num_classes=1, dropout=0.5):
        super().__init__()
        self.model = nn.Sequential(
            # Initial Conv
            nn.Conv1d(in_channels, 32, kernel_size=15, stride=2, padding=7),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),  # ← Insert dropout here

            # Depthwise
            nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1, groups=32),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            # Pointwise
            nn.Conv1d(32, 64, kernel_size=1),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),  # ← Insert dropout here

            # Another Depthwise Separable block
            nn.Conv1d(64, 64, kernel_size=3, stride=2, padding=1, groups=64),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Conv1d(64, 128, kernel_size=1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),  # ← Insert dropout here

            # Global Average Pool
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = x.transpose(1, 2)  # your original transpose
        return self.model(x).squeeze(1)


In [5]:
import numpy as np
from tslearn.neighbors import KNeighborsTimeSeries
from numba import njit

# Standalone JIT-compiled function outside the class
@njit(fastmath=True)
def fast_interpolate(original, neighbor, alpha):
    """Numba-accelerated linear interpolation"""
    return (1 - alpha) * original + alpha * neighbor

class TSMOTE:
    def __init__(self, n_neighbors=3, time_slices=10):  # Keep original 2000 timesteps
        self.n_neighbors = n_neighbors
        self.time_slices = time_slices
        self.slice_size = 200  # 2000/10=200

    def _slice_time_series(self, X):
        """Split into time slices while maintaining full series structure"""
        return X.reshape(X.shape[0], self.time_slices, self.slice_size, X.shape[2])

    def _reconstruct_full_series(self, synthetic_slices):
        """Combine synthetic slices into full-length time series"""
        return synthetic_slices.reshape(-1, self.time_slices * self.slice_size, synthetic_slices.shape[-1])

    def _generate_synthetic(self, minority_samples):
        """Generate full-length synthetic samples"""
        sliced_data = self._slice_time_series(minority_samples)  # (N, slices, 200, ch)
        syn_samples = []
        
        # Generate 1 full synthetic sample per minority sample
        for sample_idx in tqdm(range(sliced_data.shape[0])):
            synthetic_slices = []
            
            # Process each time slice
            for slice_idx in range(self.time_slices):
                knn = KNeighborsTimeSeries(n_neighbors=self.n_neighbors, metric='dtw')
                knn.fit(sliced_data[:, slice_idx, :, :])
                
                # Find neighbors for this slice
                neighbors = knn.kneighbors(sliced_data[sample_idx, slice_idx][np.newaxis], 
                                         return_distance=False)[0]
                neighbor_idx = np.random.choice(neighbors)
                
                # Interpolate within slice
                alpha = np.random.uniform(0.2, 0.8)
                synthetic_slice = (1 - alpha) * sliced_data[sample_idx, slice_idx] + \
                                 alpha * sliced_data[neighbor_idx, slice_idx]
                synthetic_slices.append(synthetic_slice)
            
            # Combine slices into full series
            full_series = np.concatenate(synthetic_slices, axis=0)
            syn_samples.append(full_series)
        
        return np.array(syn_samples)

    def fit_resample(self, X, y):
        y_int = y.astype(int)
        class_counts = np.bincount(y_int)
        minority_class = np.argmin(class_counts)
        n_needed = class_counts[1 - minority_class] - class_counts[minority_class]
        
        if n_needed <= 0:
            return X, y
        
        minority_samples = X[y_int == minority_class]
        synthetic = self._generate_synthetic(minority_samples)
        
        # Ensure matching dimensions
        assert X.shape[1:] == synthetic.shape[1:], \
            f"Dimension mismatch: Original {X.shape[1:]}, Synthetic {synthetic.shape[1:]}"
        
        return (np.concatenate([X, synthetic[:n_needed]], axis=0),
                np.concatenate([y, [minority_class] * n_needed]))

In [None]:
#============================================================
# Enhanced Dataset Class with Proper Encapsulation
#============================================================
class EEGDataset(Dataset):
    def __init__(self, source, max_length=2000):
        self.df = pl.read_parquet(source) if isinstance(source, str) else source
        if 'orig_marker' in self.df.columns:
            self.df = self.df.drop('orig_marker')
        
        # label_map = {"Left": 0, "Right": 1}  
        # self.df = self.df.with_columns([
        #     pl.col("marker").replace(label_map).cast(pl.Int32).alias("marker"),
        #     pl.col("prev_marker").replace(label_map).cast(pl.Int32).alias("prev_marker")
        # ])
        
        self.df = self.df.with_columns([
            pl.col("marker")
            .cast(pl.Utf8)
            .str.replace_all("Left", "0")      # replace exact string "Left" with "0"
            .str.replace_all("Right", "1")     # replace exact string "Right" with "1"
            .cast(pl.Int32)                      # now cast the string "0"/"1" -> int
            .alias("marker"),
            
            pl.col("prev_marker")
            .cast(pl.Utf8)
            .str.replace_all("Left", "0")
            .str.replace_all("Right", "1")
            .cast(pl.Int32)
            .alias("prev_marker"),
        ])
        
        self.event_ids = self.df['event_id'].unique().to_list()
        self.max_length = max_length
        # Keep time for sorting but exclude from features
        self.feature_cols = [c for c in self.df.columns 
                           if c not in {'event_id', 'marker', 'time'}]
        
        print("Precomputing samples...")
        self._precompute_samples()
        print("Computing class weights...")
        self._class_weights = self.compute_class_weights()
    
    @property
    def class_weights(self):
        # Expose the computed weights as a property.
        return self._class_weights 

    def __len__(self):
        return len(self.event_ids)

    def __getitem__(self, idx):
        return self.samples[idx]
    
    def _precompute_samples(self):
        self.samples = []
        for event_id in tqdm(self.event_ids, desc='precomputing_samples'):
            # Sort by time within each event!
            event_data = self.df.filter(pl.col("event_id") == event_id).sort("time")
            features = torch.tensor(
                event_data.select(self.feature_cols).to_numpy(),
                dtype=torch.float32
            )
            features = self._pad_sequence(features)
            
            label = event_data['marker'][0]
            self.samples.append((
                torch.tensor(label, dtype=torch.float32), 
                features
            ))
    
    def compute_class_weights(self):
        """
        Compute inverse frequency weights based on the 'marker' column.
        Assumes markers are "Stimulus/A" and "Stimulus/P".
        """
        # Get unique combinations of event_id and marker.
        unique_events = self.df.select(["event_id", "marker"]).unique()
        
        # Use value_counts on the "marker" column.
        counts_df = unique_events["marker"].value_counts()

        # We'll use 'values' if it exists, otherwise 'marker'.
        d = { (row.get("values") or row.get("marker")): row["count"] 
            for row in counts_df.to_dicts() }
        
        weight_L = 1.0 / d.get(0, 1)
        weight_R = 1.0 / d.get(1, 1)
        return {"Left": weight_L, "Right": weight_R}
   
    def split_dataset(self, ratios=(0.7, 0.15, 0.15), seed=None):
        """
        Splits the dataset into three EEGDataset instances for train, val, and test.
        This method shuffles the event_ids and then partitions them based on the given ratios.
        """
        if seed is not None:
            np.random.seed(seed)
        
        # Copy and shuffle the event_ids
        event_ids = self.event_ids.copy()
        np.random.shuffle(event_ids)
        total = len(event_ids)
        
        n_train = int(ratios[0] * total)
        n_val   = int(ratios[1] * total)
        
        train_ids = event_ids[:n_train]
        val_ids   = event_ids[n_train:n_train+n_val]
        test_ids  = event_ids[n_train+n_val:]
        
        # Filter self.df for the selected event_ids
        train_df = self.df.filter(pl.col("event_id").is_in(train_ids))
        val_df   = self.df.filter(pl.col("event_id").is_in(val_ids))
        test_df  = self.df.filter(pl.col("event_id").is_in(test_ids))
        
        # Create new EEGDataset instances using the filtered data
        train_set = EEGDataset(train_df, self.max_length)
        val_set   = EEGDataset(val_df, self.max_length)
        test_set  = EEGDataset(test_df, self.max_length)
        
        return train_set, val_set, test_set

    def _pad_sequence(self, tensor):
        # Pre-allocate tensor for maximum efficiency
        padded = torch.zeros((self.max_length, tensor.size(1)), dtype=tensor.dtype)
        length = min(tensor.size(0), self.max_length)
        padded[:length] = tensor[:length]
        return padded
    
    def rebalance_by_tsmote(self):
        """TSMOTE implementation for temporal EEG data"""
        # Extract time-ordered features as 3D array (samples, timesteps, features)
        X = np.stack([features.numpy() for _, features in self.samples])
        y = np.array([label.item() for label, _ in self.samples])
        
        # Apply TSMOTE with temporal awareness
        tsmote = TSMOTE()
        X_res, y_res = tsmote.fit_resample(X, y)

        # Generate synthetic temporal events
        new_events = []
        new_event_id = self.df['event_id'].max() + 1
        time_base = np.arange(self.max_length)
        original_schema = self.df.schema

        # Create dtype conversion map
        dtype_map = {
            pl.Float64: np.float64,
            pl.Float32: np.float32,
            pl.Int64: np.int64,
            pl.Int32: np.int32,
            pl.Utf8: str,
        }

        # Process synthetic samples (original samples come first in X_res)
        for features_3d, label in zip(X_res[len(self.samples):], y_res[len(self.samples):]):
            event_data = {
                "event_id": [new_event_id] * self.max_length,
                "marker": [label] * self.max_length,
                "time": time_base.copy()
            }
            
            # Add features with proper temporal structure
            for col_idx, col in enumerate(self.feature_cols):
                col_data = features_3d[:, col_idx]
                schema_type = original_schema[col]
                
                # Handle data types
                if isinstance(schema_type, pl.List):
                    base_type = schema_type.inner
                    target_type = dtype_map.get(type(base_type), np.float64)
                else:
                    target_type = dtype_map.get(type(schema_type), np.float64)
                
                col_data = col_data.astype(target_type)
                
                # Maintain integer precision
                if schema_type in (pl.Int64, pl.Int32):
                    col_data = np.round(col_data).astype(int)
                
                event_data[col] = col_data

            # Create DataFrame with strict schema adherence
            event_df = pl.DataFrame(event_data).cast(original_schema)
            new_events.append(event_df)
            new_event_id += 1

        # Update dataset with synthetic temporal events
        self.df = pl.concat([self.df, *new_events])
        self.event_ids = self.df['event_id'].unique().to_list()
        self._precompute_samples()
        self._class_weights = self.compute_class_weights()
        return self


In [7]:
def collate_fn(batch):
    """
    Collate function for variable-length EEG feature sequences.

    Each sample is expected to be a tuple (label, feature), where:
    - label is a scalar tensor (or 1D tensor) representing the class/target.
    - feature is a tensor of shape (seq_len, num_channels), where seq_len may vary.

    This function stacks labels and pads features along the time dimension so that
    all sequences in the batch have the same length.
    """
    # Unzip the batch into labels and features
    labels, features = zip(*batch)
    
    labels = torch.stack(labels)
    # Optionally: labels = labels.unsqueeze(1)  # Uncomment if required by your loss function
    padded_features = pad_sequence(features, batch_first=True)
    
    return labels, padded_features


def train_model(config, train_set, train_loader, val_loader, writer):    
    # Model initialization
    model = EEGMobileNet(
        in_channels=64,  # Or from config if you wish
        num_classes=1,
        dropout=config['dropout']  # Use the dropout from config
    ).to(config['device'])
    #EEGDSConv(dropout=config['dropout']).to(config['device'])
    
    # Log model architecture and config
    writer.add_text("Model/Type", f"EEGDSConv with dropout={config['dropout']}")
    writer.add_text("Model/Structure", str(model))
    writer.add_text("Training Config", str(config))

    # Loss function and optimizer
    pos_weight = torch.tensor([
        train_set.class_weights['Left'] / train_set.class_weights['Right']
    ]).to(config['device'])
    criterion = torch.nn.BCEWithLogitsLoss(weight=pos_weight)
    
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay']
    )
    
    # Learning rate schedulers
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=10,
        verbose=True
    )
    
    warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda epoch: min(1.0, (epoch + 1) / config['warmup_epochs'])
    )
    
    # Training loop
    best_metric = float('inf')
    
    for epoch in tqdm(range(config['epochs'])):
        # Training phase
        model.train()
        train_loss = 0
        for labels, features in train_loader:
            features = features.to(config['device']).float()
            labels = labels.to(config['device']).float()
            
            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), 
                config['grad_clip']
            )
            
            optimizer.step()
            train_loss += loss.item()
        
        # Validation phase
        model.eval()
        val_loss = 0
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for labels, features in val_loader:
                features = features.to(config['device']).float()
                labels = labels.to(config['device']).float()
                
                outputs = model(features)
                val_loss += criterion(outputs, labels).item()
                
                preds = torch.sigmoid(outputs)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        # Calculate metrics
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        predictions = (np.array(all_preds) > 0.5).astype(int)
        
        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        
        # Update schedulers
        if epoch < config['warmup_epochs']:
            warmup_scheduler.step()
        else:
            scheduler.step(val_loss)
        
        # Log metrics and learning rate
        metrics = {
            'accuracy': accuracy_score(all_labels, predictions),
            'precision': precision_score(all_labels, predictions),
            'recall': recall_score(all_labels, predictions),
            'f1': f1_score(all_labels, predictions)
        }
        
        writer.add_scalar('LR', current_lr, epoch)
        writer.add_scalar('Loss/Train', train_loss, epoch)
        writer.add_scalar('Loss/Val', val_loss, epoch)
        writer.add_scalar('Accuracy', metrics['accuracy'], epoch)
        writer.add_scalar('Precision', metrics['precision'], epoch)
        writer.add_scalar('Recall', metrics['recall'], epoch)
        writer.add_scalar('F1', metrics['f1'], epoch)
        writer.add_scalars('Metrics', metrics, epoch)
        
        # Save best model
        if val_loss < best_metric:
            best_metric = val_loss
            torch.save(model.state_dict(), f"{config['log_dir']}/best_model.pth")
    
    writer.close()
    return model

In [8]:
# Configuration
config = {
    'data_path': '/home/owner/Documents/DEV/BrainLabyrinth/data/combined.parquet',
    'split_ratios': (0.7, 0.15, 0.15),
    'batch_size': 32,          # Increased for better generalization
    'dropout': 0.3,            # Reduced from 0.6 for better information flow
    'lr': 1e-4,                # Base learning rate (sweet spot between 1e-5 and 3e-3)
    'weight_decay': 0,      # Increased regularization
    'epochs': 200,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'log_dir': './runs/CNN',
    'lr_scheduler': {
        'mode': 'min',
        'factor': 0.1,         # More aggressive LR reduction
        'patience': 5,         # Faster response to plateaus
        'threshold': 0.001,
        'cooldown': 3
    },
    'grad_clip': 1.0,          # Add gradient clipping
    'warmup_epochs': 10        # Linear LR warmup
}


In [9]:
#============================================================
# Training Pipeline
#============================================================
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# Initialize dataset
print("Creating full dataset...")
full_dataset = EEGDataset(config['data_path'])

print("Splitting the dataset...")
# Split dataset
train_set, val_set, test_set = full_dataset.split_dataset(
    ratios=config['split_ratios']
)

del full_dataset

len_dataset = len(train_set)
sample = train_set[0]
label_shape = sample[0].shape
feature_shape = sample[1].shape

print(f"unbalanced train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")

# Balance training set
print("Applying SMOTE to train dataset...")
train_set.rebalance_by_tsmote()

len_dataset = len(train_set)
sample = train_set[0]
label_shape = sample[0].shape
feature_shape = sample[1].shape

print(f"balanced train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")

# print("Augmenting train dataset...")
# train_set.augment_dataset(
#     n_times=3,              # Nx dataset expansion
#     noise_std=0.15,         # Moderate noise
#     scale_range=(0.7, 1.3), # ±30% amplitude variation
#     warp_range=(0.85, 1.15),# ±15% time warping
#     max_shift=15,           # 150ms temporal shifts
#     freq_shift=3,           # ±3Hz frequency shifts
#     mask_size=75,           # 750ms masking
#     drop_prob=0.15,         # 15% channel dropout
#     mixup_alpha=0.2         # Mixup
# )

torch.save(train_set, 'train_set.pt')
torch.save(val_set, 'val_set.pt')
torch.save(test_set, 'test_set.pt')

Creating full dataset...
Precomputing samples...


precomputing_samples:   0%|          | 0/2772 [00:00<?, ?it/s]

Computing class weights...
Splitting the dataset...
Precomputing samples...


precomputing_samples:   0%|          | 0/1940 [00:00<?, ?it/s]

Computing class weights...
Precomputing samples...


precomputing_samples:   0%|          | 0/415 [00:00<?, ?it/s]

Computing class weights...
Precomputing samples...


precomputing_samples:   0%|          | 0/417 [00:00<?, ?it/s]

Computing class weights...
unbalanced train dataset shape: (1940, [labels: torch.Size([]), features: [2000, 64]])
Applying SMOTE to train dataset...


  0%|          | 0/952 [00:00<?, ?it/s]

InvalidOperationError: conversion from `str` to `i32` failed in column 'marker' for 2000 out of 2000 values: ["Stimulus/A", "Stimulus/A", … "Stimulus/A"]

In [None]:

train_set = torch.load('train_set.pt', weights_only=False)
val_set = torch.load('val_set.pt', weights_only=False)
test_set = torch.load('test_set.pt', weights_only=False)


train_loader = DataLoader(
    train_set,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=12,
    pin_memory=True,      # For GPU acceleration
    persistent_workers=True,
    collate_fn=collate_fn
)
val_loader = DataLoader(val_set, batch_size=config['batch_size'], collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=config['batch_size'], collate_fn=collate_fn)

len_dataset = len(train_set)
sample = train_set[0]
label_shape = sample[0].shape
feature_shape = sample[1].shape

print(f"train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")


# Set up logging
writer = SummaryWriter(log_dir=config['log_dir'])

train dataset shape: (1940, [labels: torch.Size([]), features: [2000, 64]])


In [None]:
# Start training
trained_model = train_model(config, train_set, train_loader, val_loader, writer)



  0%|          | 0/200 [00:00<?, ?it/s]

In [None]:
import torch
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# pos_weight = torch.tensor([
#     train_set.class_weights['A'] / train_set.class_weights['P']
# ]).to(config['device'])
criterion = nn.BCEWithLogitsLoss()#pos_weight=pos_weight)

epoch = 1
# Assuming model, criterion, test_loader, device, writer, and epoch are already defined
# Instantiate your model
best_model = EEGMobileNet()  # Adjust parameters as needed

# Load the state dictionary
state_dict = torch.load(f"{config['log_dir']}/best_model.pth", map_location=config['device'])
best_model.load_state_dict(state_dict)

# Move model to the correct device
best_model = best_model.to(config['device'])

# Set model to evaluation mode
best_model.eval()

test_loss = 0
all_test_markers = []
all_test_predictions = []
with torch.no_grad():
    for markers, features in tqdm(test_loader):
        features = features.to(config['device'])
        markers = markers.to(config['device'])

        outputs = best_model(features)
        loss = criterion(outputs, markers)
        test_loss += loss.item()

        # Collect markers and predictions for metrics calculation
        all_test_markers.extend(markers.cpu().numpy().flatten())
        all_test_predictions.extend(torch.sigmoid(outputs).cpu().numpy().flatten())

test_loss /= len(test_loader)
# Calculate test metrics
test_accuracy = accuracy_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_precision = precision_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_recall = recall_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_f1 = f1_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_roc_auc = roc_auc_score(all_test_markers, all_test_predictions)

# Log test metrics to TensorBoard
writer.add_scalar('Metrics/test_accuracy', test_accuracy, epoch)
writer.add_scalar('Metrics/test_precision', test_precision, epoch)
writer.add_scalar('Metrics/test_recall', test_recall, epoch)
writer.add_scalar('Metrics/test_f1', test_f1, epoch)
writer.add_scalar('Metrics/test_roc_auc', test_roc_auc, epoch)

# Close the TensorBoard writer
writer.close()

  0%|          | 0/14 [00:00<?, ?it/s]

In [None]:
print(f"""
{test_accuracy=}
{test_precision=}
{test_recall=}
{test_f1=}
{test_roc_auc=}
"""
)


test_accuracy=0.5563549160671463
test_precision=0.5362903225806451
test_recall=0.6551724137931034
test_f1=0.5898004434589801
test_roc_auc=np.float64(0.6026656231296901)



In [None]:
from sklearn.metrics import f1_score
import numpy as np
best_threshold = 0.0
best_f1 = 0.0
thresholds = np.arange(0.1, 1.0, 0.01)

for threshold in tqdm(thresholds):
    binary_predictions = (all_test_predictions > threshold).astype(int)
    current_recall = f1_score(all_test_markers, binary_predictions)

    if current_recall > best_f1:
        best_f1 = current_recall
        best_threshold = threshold

print(f"{best_threshold=}")
print(f"{best_f1=}")

  0%|          | 0/90 [00:00<?, ?it/s]

best_threshold=np.float64(0.34999999999999987)
best_f1=0.6568265682656826


In [None]:
from sklearn.metrics import recall_score
import numpy as np
best_threshold = 0.1
best_recall = 0.0
thresholds = np.arange(0.1, 1.0, 0.01)

for threshold in tqdm(thresholds):
    binary_predictions = (all_test_predictions > threshold).astype(int)
    current_recall = recall_score(all_test_markers, binary_predictions)

    if current_recall > best_recall:
        best_recall = current_recall
        best_threshold = threshold

print(f"{best_threshold=}")
print(f"{best_recall=}")

  0%|          | 0/90 [00:00<?, ?it/s]

best_threshold=np.float64(0.1)
best_recall=0.9852216748768473
