In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter

import polars as pl
from imblearn.over_sampling import SMOTE

from sklearn.model_selection import train_test_split

from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import random
import numpy as np

from scipy.interpolate import CubicSpline

# Set the seed for Python's built-in random module
random.seed(69)

# Set the seed for NumPy's random number generator
np.random.seed(69)

# Set the seed for PyTorch's random number generators
torch.manual_seed(69)
torch.cuda.manual_seed(69)
torch.backends.cudnn.deterministic = True

In [2]:
pl.read_parquet('/home/owner/Documents/DEV/BrainLabyrinth/data/combined.parquet')\
    # .drop(['__null_dask_index__'])


event_id,marker,time,Fp1,Fpz,Fp2,F7,F3,Fz,F4,F8,FC5,FC1,FC2,FC6,M1,T7,C3,Cz,C4,T8,M2,CP5,CP1,CP2,CP6,P7,P3,Pz,P4,P8,POz,O1,O2,AF7,AF3,AF4,AF8,F5,F1,F2,F6,FC3,FCz,FC4,C5,C1,C2,C6,CP3,CP4,P5,P1,P2,P6,PO5,PO3,PO4,PO6,FT7,FT8,TP7,TP8,PO7,PO8,Oz,__null_dask_index__
i64,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64
0,"""Stimulus/1""",9.438,18.26099,16.979025,1.769803,2.441713,2.818692,10.922015,5.462004,-24.56427,16.242172,11.220099,11.027495,7.121442,-46.429313,-7.675472,16.945733,7.783233,20.514471,-26.908206,-47.461382,-4.487977,7.371627,2.944143,7.080501,-0.390081,-3.870911,16.039242,11.825064,0.929866,-0.176483,-10.09977,5.692728,-9.288844,3.758388,4.741033,-8.940138,5.21949,10.749776,7.172626,-17.153415,6.482526,13.326968,10.335974,1.308579,7.458288,4.62902,4.661003,10.387417,6.23994,-4.072595,10.350186,21.404569,5.531532,-5.54077,-4.60115,-10.072189,5.858149,-6.076038,-3.457559,-10.564211,-3.465435,-7.45234,5.723619,-10.401716,0
0,"""Stimulus/1""",9.44,18.279552,17.167071,1.736013,2.509046,3.320807,11.62726,6.56513,-23.024749,16.535206,11.428954,11.867022,8.841448,-45.932599,-8.458003,16.839472,7.635926,21.321651,-31.699151,-46.974629,-4.697016,7.125317,3.220561,7.522331,-0.681903,-4.617159,15.836581,11.514513,-0.039646,-0.730373,-10.741345,3.741726,-9.673243,4.203654,5.521279,-8.967844,5.960125,11.30632,7.854544,-16.257055,6.743387,13.844415,11.380757,1.33881,7.527234,4.991126,7.051428,10.11894,6.555342,-4.722692,9.938277,21.260168,4.689456,-6.292506,-5.309328,-10.643972,4.269458,-6.793052,-2.42154,-10.742051,-3.052358,-7.947506,4.051073,-10.757459,1
0,"""Stimulus/1""",9.442,18.187657,17.195618,1.650717,2.619466,3.746728,12.211854,7.602871,-21.453501,16.712602,11.625626,12.706299,10.281649,-45.249379,-9.02851,16.747652,7.526668,21.978811,-34.687629,-46.088848,-4.949806,6.908232,3.4634,7.742669,-0.509808,-5.343742,15.644091,11.098233,-1.040845,-1.317047,-11.424223,1.463857,-10.139782,4.563526,6.382203,-8.837637,6.864058,11.892868,8.451628,-15.265568,6.959286,14.266943,12.201225,1.263118,7.552298,5.227928,8.960626,9.860875,6.811119,-5.368829,9.622294,21.128975,3.794791,-7.027202,-5.970655,-11.33015,2.269126,-7.381371,-1.549136,-10.951647,-2.573051,-8.422673,1.975578,-11.568488,2
0,"""Stimulus/1""",9.444,17.988738,17.064795,1.536693,2.780392,4.108368,12.665638,8.47708,-19.98386,16.771917,11.816801,13.476138,11.291504,-44.341041,-9.36457,16.66188,7.476511,22.42124,-35.46427,-44.869256,-5.247474,6.718476,3.639658,7.712424,0.094348,-6.073221,15.46495,10.593378,-1.921139,-1.925938,-12.227119,-0.95602,-10.635339,4.843329,7.244236,-8.559468,7.826785,12.502754,8.935521,-14.259979,7.13315,14.590537,12.723916,1.071259,7.535559,5.311675,10.193048,9.607037,6.978706,-6.039169,9.398759,21.022386,2.90171,-7.784325,-6.628905,-12.088001,0.015087,-7.783466,-0.885222,-11.181192,-2.05307,-8.959252,-0.346699,-12.895379,3
0,"""Stimulus/1""",9.446,17.697904,16.795464,1.420282,2.984275,4.421709,12.995026,9.112953,-18.737067,16.720522,12.012921,14.115175,11.774442,-43.205799,-9.5169,16.576344,7.49871,22.614753,-33.835611,-43.398607,-5.583868,6.552261,3.723286,7.438797,1.040869,-6.831202,15.294334,10.025392,-2.545744,-2.555724,-13.217374,-3.333748,-11.107698,5.055388,8.033596,-8.161774,8.723206,13.129667,9.293896,-13.315702,7.273473,14.822615,12.913112,0.770779,7.485243,5.235299,10.630651,9.349522,7.039159,-6.762365,9.247924,20.94341,2.056106,-8.602303,-7.331279,-12.87647,-2.320403,-7.966721,-0.441055,-11.417038,-1.496988,-9.632474,-2.74458,-14.742146,4
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1381,"""Stimulus/A""",928.476,-24.531312,-34.040322,-39.377348,-26.813586,-50.33518,-60.351025,4.800104,-8.681692,-22.494535,-34.812206,-23.985204,-18.962696,-70.486005,-15.036631,-48.358653,-36.797774,-37.581036,-10.546239,-113.35881,-13.174225,-36.488425,-15.568561,8.284046,-25.535649,-14.563875,-30.695258,-80.903069,-73.376497,-19.759668,-16.861085,-30.762578,-27.605438,-26.903408,-11.824485,-16.397532,-31.682237,-50.844306,4.70073,-13.917587,20.762969,-43.113357,-18.824065,-0.801963,-51.926852,-28.117319,-8.311726,-33.086039,-28.030511,-34.279603,-30.71269,-43.739573,-41.626016,-23.895303,-24.468659,-51.56217,-33.35497,-24.890898,-6.698645,-0.427032,2.81708,-23.857567,-33.41483,-28.665009,275996
1381,"""Stimulus/A""",928.478,-25.037469,-34.506067,-39.484176,-26.978022,-50.013997,-59.910298,5.519787,-8.487145,-22.16917,-34.808108,-22.969049,-18.345097,-70.520443,-14.521938,-48.600807,-36.721548,-36.762816,-11.203407,-113.115348,-12.976387,-36.500322,-15.216633,8.792874,-25.335829,-14.095602,-30.442555,-80.504525,-72.604594,-19.146617,-16.718005,-30.454802,-27.996757,-26.544687,-11.544485,-16.534841,-32.275165,-50.363906,5.478738,-13.681101,21.007286,-42.716912,-17.8001,-0.705913,-51.988347,-27.745902,-7.574538,-33.135223,-27.514517,-33.821243,-30.579896,-43.33259,-40.926079,-23.535917,-24.139827,-51.252728,-32.873982,-24.360851,-6.914038,0.18291,3.346472,-23.528204,-33.089643,-28.450162,275997
1381,"""Stimulus/A""",928.48,-25.620556,-35.009315,-39.657826,-26.904605,-49.721285,-59.499187,6.209261,-8.234116,-21.743517,-34.988863,-21.930894,-17.584589,-70.335446,-13.765635,-48.915458,-36.627762,-35.817168,-11.387478,-112.422118,-12.591521,-36.460628,-14.724547,9.497388,-24.645177,-13.541097,-30.280727,-80.109167,-71.66989,-18.517414,-16.08701,-29.877011,-28.460028,-26.272782,-11.284485,-16.734728,-32.675661,-49.959505,6.218403,-13.581815,20.960643,-42.455285,-16.764513,-0.610304,-52.093277,-27.316784,-6.802836,-33.153469,-26.83351,-33.055253,-30.375729,-42.902987,-39.930319,-22.712706,-23.350402,-50.835199,-32.165625,-23.772179,-7.21466,0.826389,4.131466,-22.842623,-32.538119,-27.981675,275998
1381,"""Stimulus/A""",928.482,-26.238952,-35.513318,-39.901309,-26.594463,-49.483993,-59.151322,6.795769,-7.967241,-21.255433,-35.329352,-20.969524,-16.78674,-69.982398,-12.855594,-49.282959,-36.526949,-34.840783,-11.073768,-111.400716,-12.068991,-36.399162,-14.130256,10.335449,-23.548158,-12.943701,-30.205676,-79.746777,-70.648907,-17.896494,-14.986795,-29.04357,-28.964238,-26.107885,-11.073393,-17.012841,-32.852155,-49.660259,6.842462,-13.651779,20.64128,-42.328665,-15.820084,-0.52463,-52.230964,-26.872209,-6.109631,-33.145934,-26.047849,-32.062022,-30.112813,-42.472899,-38.711372,-21.477316,-22.149748,-50.330373,-31.281584,-23.187269,-7.57299,1.419604,5.113047,-21.838447,-31.800604,-27.246317,275999


In [3]:
# #============================================================
# # Model Architecture
# #============================================================
# class EEGDSConv(nn.Module):
#     def __init__(self, dropout=0.5):
#         super().__init__()
#         self.block = nn.Sequential(
#             nn.Conv1d(64, 64, 15, padding='same', groups=64),
#             nn.Conv1d(64, 16, 1),
#             nn.BatchNorm1d(16),
#             nn.ReLU(),
#             nn.MaxPool1d(4),
#             nn.Dropout(dropout),
#             nn.Conv1d(16, 16, 7, padding='same', groups=16),
#             nn.Conv1d(16, 8, 1),
#             nn.BatchNorm1d(8),
#             nn.ReLU(),
#             nn.AdaptiveAvgPool1d(1),
#             nn.Flatten(),
#             nn.Linear(8, 1)
#         )
    
#     def forward(self, x):
#         x = x.permute(0, 2, 1)
#         return self.block(x).squeeze(-1)  # Squeeze last dimension to match target shape
    
    


In [4]:
class EEGMobileNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv1d(63, 32, 15, padding=7),
            nn.BatchNorm1d(32),
            nn.ReLU6(),
            
            DepthwiseSeparable(32, 64),
            DepthwiseSeparable(64, 128),
            DepthwiseSeparable(128, 256),
            
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
            nn.Linear(256, 1)
        )

    def forward(self, x):
        x = x.permute(0, 2, 1)
        return self.model(x).squeeze(-1)

class DepthwiseSeparable(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.depthwise = nn.Conv1d(in_ch, in_ch, 15, 
                                 padding=7, groups=in_ch)
        self.pointwise = nn.Conv1d(in_ch, out_ch, 1)
        self.bn = nn.BatchNorm1d(out_ch)
        self.act = nn.ReLU6()
    
    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return self.act(self.bn(x))

In [5]:
import numpy as np
from tslearn.neighbors import KNeighborsTimeSeries
from numba import njit

# Standalone JIT-compiled function outside the class
@njit(fastmath=True)
def fast_interpolate(original, neighbor, alpha):
    """Numba-accelerated linear interpolation"""
    return (1 - alpha) * original + alpha * neighbor

class TSMOTE:
    def __init__(self, n_neighbors=3, time_slices=10):  # Keep original 2000 timesteps
        self.n_neighbors = n_neighbors
        self.time_slices = time_slices
        self.slice_size = 200  # 2000/10=200

    def _slice_time_series(self, X):
        """Split into time slices while maintaining full series structure"""
        return X.reshape(X.shape[0], self.time_slices, self.slice_size, X.shape[2])

    def _reconstruct_full_series(self, synthetic_slices):
        """Combine synthetic slices into full-length time series"""
        return synthetic_slices.reshape(-1, self.time_slices * self.slice_size, synthetic_slices.shape[-1])

    def _generate_synthetic(self, minority_samples):
        """Generate full-length synthetic samples"""
        sliced_data = self._slice_time_series(minority_samples)  # (N, slices, 200, ch)
        syn_samples = []
        
        # Generate 1 full synthetic sample per minority sample
        for sample_idx in tqdm(range(sliced_data.shape[0])):
            synthetic_slices = []
            
            # Process each time slice
            for slice_idx in range(self.time_slices):
                knn = KNeighborsTimeSeries(n_neighbors=self.n_neighbors, metric='dtw')
                knn.fit(sliced_data[:, slice_idx, :, :])
                
                # Find neighbors for this slice
                neighbors = knn.kneighbors(sliced_data[sample_idx, slice_idx][np.newaxis], 
                                         return_distance=False)[0]
                neighbor_idx = np.random.choice(neighbors)
                
                # Interpolate within slice
                alpha = np.random.uniform(0.2, 0.8)
                synthetic_slice = (1 - alpha) * sliced_data[sample_idx, slice_idx] + \
                                 alpha * sliced_data[neighbor_idx, slice_idx]
                synthetic_slices.append(synthetic_slice)
            
            # Combine slices into full series
            full_series = np.concatenate(synthetic_slices, axis=0)
            syn_samples.append(full_series)
        
        return np.array(syn_samples)

    def fit_resample(self, X, y):
        y_int = y.astype(int)
        class_counts = np.bincount(y_int)
        minority_class = np.argmin(class_counts)
        n_needed = class_counts[1 - minority_class] - class_counts[minority_class]
        
        if n_needed <= 0:
            return X, y
        
        minority_samples = X[y_int == minority_class]
        synthetic = self._generate_synthetic(minority_samples)
        
        # Ensure matching dimensions
        assert X.shape[1:] == synthetic.shape[1:], \
            f"Dimension mismatch: Original {X.shape[1:]}, Synthetic {synthetic.shape[1:]}"
        
        return (np.concatenate([X, synthetic[:n_needed]], axis=0),
                np.concatenate([y, [minority_class] * n_needed]))

In [6]:
#============================================================
# Enhanced Dataset Class with Proper Encapsulation
#============================================================
class EEGDataset(Dataset):
    def __init__(self, source, max_length=2000):
        self.df = self._load_and_filter(source)
        self.event_ids = self.df['event_id'].unique().to_list()
        self.max_length = max_length
        # Keep time for sorting but exclude from features
        self.feature_cols = [c for c in self.df.columns 
                           if c not in {'event_id', 'marker', 'time'}]
        self._precompute_samples()
        self._class_weights = self.compute_class_weights()
    
    @property
    def class_weights(self):
        # Expose the computed weights as a property.
        return self._class_weights 

    def __len__(self):
        return len(self.event_ids)

    def __getitem__(self, idx):
        return self.samples[idx]
    
    def _precompute_samples(self):
        """Cache time-ordered samples with revolutionary discipline"""
        self.samples = []
        for event_id in self.event_ids:
            # Sort by time within each event!
            event_data = self.df.filter(pl.col("event_id") == event_id).sort("time")
            
            features = torch.tensor(
                event_data.select(self.feature_cols).to_numpy(),
                dtype=torch.float32
            )
            features = self._pad_sequence(features)
            
            label = 1.0 if event_data['marker'][0] == "Stimulus/P" else 0.0
            self.samples.append((
                torch.tensor(label, dtype=torch.float32), 
                features
            ))
    
    def augment_dataset(self, n_times=5, **kwargs):
        new_event_id = self.df["event_id"].max() + 1
        original_count = len(self.event_ids)
        
        # Store original samples for mixup
        original_samples = [s[1].numpy() for s in self.samples]
        original_labels = [s[0].item() for s in self.samples]

        new_events = []
        for idx in range(original_count):
            base_features = original_samples[idx]
            base_label = original_labels[idx]
            
            # Generate N-1 augmented versions
            for _ in range(n_times-1):
                # Apply augmentation
                aug_features = self._apply_random_augmentation(
                    base_features, 
                    mixup_samples=original_samples,
                    mixup_labels=original_labels,
                    **kwargs
                )
                
                # Create event with unique ID
                event_data = self._create_augmented_event(
                    aug_features, new_event_id, base_label
                )
                new_events.append(event_data)
                new_event_id += 1 

        # Add mixup combinations
        mixup_events = self._generate_mixup_combinations(
            original_samples, original_labels,
            n_combinations=original_count//2,
            n_times=n_times,
            **kwargs
        )
        new_events += mixup_events

        self.df = pl.concat([self.df, *new_events])
        self.event_ids = self.df['event_id'].unique().to_list()
        self._precompute_samples()
        return self

    def _apply_random_augmentation(self, features, **kwargs):
        aug_type = np.random.choice([
            lambda x: self._gaussian_noise(x, kwargs['noise_std']),
            lambda x: self._amplitude_scale(x, kwargs['scale_range']),
            lambda x: self._time_warp(x, kwargs['warp_range']),
            lambda x: self._channel_shift(x, kwargs['max_shift']),
            lambda x: self._frequency_warp(x, kwargs['freq_shift']),
            lambda x: self._time_mask(x, kwargs['mask_size']),
            lambda x: self._channel_dropout(x, kwargs['drop_prob']),
            lambda x: self._mixup(x, kwargs['mixup_samples'], 
                                kwargs['mixup_labels'], 
                                kwargs['mixup_alpha'])
        ])
        return aug_type(features)

    # === Core Augmentations ===
    def _gaussian_noise(self, features, noise_std=0.1, **kwargs):
        noise = np.random.normal(0, noise_std*np.std(features), features.shape)
        return features + noise

    def _amplitude_scale(self, features, scale_range=(0.8, 1.2), **kwargs):
        return features * np.random.uniform(*scale_range)

    def _time_warp(self, features, warp_range=(0.8, 1.2)):

        orig_length = features.shape[0]
        warp_factor = np.random.uniform(*warp_range)
        
        # People's interpolation ensuring max_length compliance
        x_original = np.linspace(0, 1, orig_length)
        x_warped = np.linspace(0, 1, self.max_length)  # Always output max_length
        
        # Proletarian cubic spline interpolation
        warped_features = np.array([
            CubicSpline(x_original, channel)(x_warped)
            for channel in features.T
        ]).T
        
        return warped_features

    def _channel_shift(self, features, max_shift=10, **kwargs):
        return np.roll(features, np.random.randint(-max_shift, max_shift), axis=0)

    def _frequency_warp(self, features, freq_shift=2, **kwargs):
        f = np.fft.fft(features, axis=0)
        shifted = np.roll(f, np.random.randint(-freq_shift, freq_shift), axis=0)
        return np.real(np.fft.ifft(shifted))

    def _time_mask(self, features, mask_size=50, **kwargs):
        start = np.random.randint(0, len(features)-mask_size)
        features[start:start+mask_size] *= np.hanning(mask_size)[:,None]
        return features

    def _channel_dropout(self, features, drop_prob=0.1, **kwargs):
        mask = np.random.rand(features.shape[1]) > drop_prob
        return features * mask

    def _mixup(self, features, all_samples, all_labels, alpha=0.4, **kwargs):
        idx = np.random.randint(0, len(all_samples))
        lam = np.random.beta(alpha, alpha)
        return lam*features + (1-lam)*all_samples[idx]

    def _generate_mixup_combinations(self, samples, labels, n_times, n_combinations=1000, alpha=0.4,**kwargs):
        mixup_events = []
        new_event_id = self.df['event_id'].max() + 1 + len(samples)*(n_times-1)
        
        for _ in range(n_combinations):
            idx1, idx2 = np.random.choice(len(samples), 2, replace=False)
            lam = np.random.beta(alpha, alpha)
            
            mixed_features = lam*samples[idx1] + (1-lam)*samples[idx2]
            mixed_label = lam*labels[idx1] + (1-lam)*labels[idx2]
            
            event_data = self._create_augmented_event(
                mixed_features,
                new_event_id,
                mixed_label
            )
            mixup_events.append(event_data)
            new_event_id += 1
            
        return mixup_events

    def _create_augmented_event(self, features, event_id, label):
        # Ensure features match max_length
        features_tensor = torch.tensor(features, dtype=torch.float32)
        padded_features = self._pad_sequence(features_tensor).numpy()

        event_data = {
            "event_id": [event_id] * self.max_length,
            "marker": ["Stimulus/P" if label else "Stimulus/A"] * self.max_length,
            "time": np.arange(self.max_length)  # This creates Int64 values by default
        }

        # Revolutionary dtype conversion map
        dtype_map = {
            pl.Float64: np.float64,
            pl.Float32: np.float32,
            pl.Int64: np.int64,
            pl.Int32: np.int32,
            pl.Utf8: str
        }

        # Convert non-feature columns to the expected types
        for key in ["event_id", "marker", "time"]:
            expected_dtype = self.df.schema[key]
            # Look up the target type based on the expected polars dtype.
            target_type = dtype_map.get(type(expected_dtype), np.float64)
            event_data[key] = np.array(event_data[key]).astype(target_type)

        # Now handle the feature columns using the same logic as before.
        for col_idx, col in enumerate(self.feature_cols):
            dtype = self.df.schema[col]
            if isinstance(dtype, pl.List):
                base_type = dtype.inner
                target_type = dtype_map.get(type(base_type), np.float64)
            else:
                target_type = dtype_map.get(type(dtype), np.float64)
            event_data[col] = padded_features[:, col_idx].astype(target_type)

        return pl.DataFrame(event_data)

    
    def compute_class_weights(self):
        """
        Compute inverse frequency weights based on the 'marker' column.
        Assumes markers are "Stimulus/A" and "Stimulus/P".
        """
        # Get unique combinations of event_id and marker.
        unique_events = self.df.select(["event_id", "marker"]).unique()
        
        # Use value_counts on the "marker" column.
        counts_df = unique_events["marker"].value_counts()

        # We'll use 'values' if it exists, otherwise 'marker'.
        d = { (row.get("values") or row.get("marker")): row["count"] 
            for row in counts_df.to_dicts() }
        
        weight_A = 1.0 / d.get("Stimulus/A", 1)
        weight_P = 1.0 / d.get("Stimulus/P", 1)
        return {"A": weight_A, "P": weight_P}
   
    def split_dataset(self, ratios=(0.7, 0.15, 0.15), seed=None):
        """
        Splits the dataset into three EEGDataset instances for train, val, and test.
        This method shuffles the event_ids and then partitions them based on the given ratios.
        """
        if seed is not None:
            np.random.seed(seed)
        
        # Copy and shuffle the event_ids
        event_ids = self.event_ids.copy()
        np.random.shuffle(event_ids)
        total = len(event_ids)
        
        n_train = int(ratios[0] * total)
        n_val   = int(ratios[1] * total)
        
        train_ids = event_ids[:n_train]
        val_ids   = event_ids[n_train:n_train+n_val]
        test_ids  = event_ids[n_train+n_val:]
        
        # Filter self.df for the selected event_ids
        train_df = self.df.filter(pl.col("event_id").is_in(train_ids))
        val_df   = self.df.filter(pl.col("event_id").is_in(val_ids))
        test_df  = self.df.filter(pl.col("event_id").is_in(test_ids))
        
        # Create new EEGDataset instances using the filtered data
        train_set = EEGDataset(train_df, self.max_length)
        val_set   = EEGDataset(val_df, self.max_length)
        test_set  = EEGDataset(test_df, self.max_length)
        
        return train_set, val_set, test_set
    
    def _load_and_filter(self, source):
        df = pl.read_parquet(source) if isinstance(source, str) else source
        df = df.filter(pl.col('marker').is_in(["Stimulus/A", "Stimulus/P"]))
        if '__null_dask_index__' in df.columns:
            df = df.drop('__null_dask_index__')
        return df

    def _pad_sequence(self, tensor):
        # Pre-allocate tensor for maximum efficiency
        padded = torch.zeros((self.max_length, tensor.size(1)), dtype=tensor.dtype)
        length = min(tensor.size(0), self.max_length)
        padded[:length] = tensor[:length]
        return padded
    
    def rebalance_by_tsmote(self):
        """TSMOTE implementation for temporal EEG data"""
        # Extract time-ordered features as 3D array (samples, timesteps, features)
        X = np.stack([features.numpy() for _, features in self.samples])
        y = np.array([label.item() for label, _ in self.samples])
        
        # Apply TSMOTE with temporal awareness
        tsmote = TSMOTE()
        X_res, y_res = tsmote.fit_resample(X, y)

        # Generate synthetic temporal events
        new_events = []
        new_event_id = self.df['event_id'].max() + 1
        time_base = np.arange(self.max_length)
        original_schema = self.df.schema

        # Create dtype conversion map
        dtype_map = {
            pl.Float64: np.float64,
            pl.Float32: np.float32,
            pl.Int64: np.int64,
            pl.Int32: np.int32,
            pl.Utf8: str,
        }

        # Process synthetic samples (original samples come first in X_res)
        for features_3d, label in zip(X_res[len(self.samples):], y_res[len(self.samples):]):
            event_data = {
                "event_id": [new_event_id] * self.max_length,
                "marker": ["Stimulus/P" if label > 0.5 else "Stimulus/A"] * self.max_length,
                "time": time_base.copy()
            }
            
            # Add features with proper temporal structure
            for col_idx, col in enumerate(self.feature_cols):
                col_data = features_3d[:, col_idx]
                schema_type = original_schema[col]
                
                # Handle data types
                if isinstance(schema_type, pl.List):
                    base_type = schema_type.inner
                    target_type = dtype_map.get(type(base_type), np.float64)
                else:
                    target_type = dtype_map.get(type(schema_type), np.float64)
                
                col_data = col_data.astype(target_type)
                
                # Maintain integer precision
                if schema_type in (pl.Int64, pl.Int32):
                    col_data = np.round(col_data).astype(int)
                
                event_data[col] = col_data

            # Create DataFrame with strict schema adherence
            event_df = pl.DataFrame(event_data).cast(original_schema)
            new_events.append(event_df)
            new_event_id += 1

        # Update dataset with synthetic temporal events
        self.df = pl.concat([self.df, *new_events])
        self.event_ids = self.df['event_id'].unique().to_list()
        self._precompute_samples()
        self._class_weights = self.compute_class_weights()
        return self


In [7]:
def collate_fn(batch):
    """
    Collate function for variable-length EEG feature sequences.

    Each sample is expected to be a tuple (label, feature), where:
    - label is a scalar tensor (or 1D tensor) representing the class/target.
    - feature is a tensor of shape (seq_len, num_channels), where seq_len may vary.

    This function stacks labels and pads features along the time dimension so that
    all sequences in the batch have the same length.
    """
    # Unzip the batch into labels and features
    labels, features = zip(*batch)
    
    labels = torch.stack(labels)
    # Optionally: labels = labels.unsqueeze(1)  # Uncomment if required by your loss function
    padded_features = pad_sequence(features, batch_first=True)
    
    return labels, padded_features


def train_model(config, train_set, train_loader, val_loader, writer):    
    # Model initialization
    model = EEGMobileNet().to(config['device'])
    #EEGDSConv(dropout=config['dropout']).to(config['device'])
    
    # Log model architecture and config
    writer.add_text("Model/Type", f"EEGDSConv with dropout={config['dropout']}")
    writer.add_text("Model/Structure", str(model))
    writer.add_text("Training Config", str(config))

    # Loss function and optimizer
    # pos_weight = torch.tensor([
    #     train_set.class_weights['A'] / train_set.class_weights['P']
    # ]).to(config['device'])
    criterion = nn.BCEWithLogitsLoss()#pos_weight=pos_weight)
    
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay']
    )
    
    # Learning rate schedulers
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=10,
        verbose=True
    )
    
    warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda epoch: min(1.0, (epoch + 1) / config['warmup_epochs'])
    )
    
    # Training loop
    best_metric = float('inf')
    
    for epoch in tqdm(range(config['epochs'])):
        # Training phase
        model.train()
        train_loss = 0
        for labels, features in train_loader:
            features = features.to(config['device'])
            labels = labels.to(config['device'])
            
            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), 
                config['grad_clip']
            )
            
            optimizer.step()
            train_loss += loss.item()
        
        # Validation phase
        model.eval()
        val_loss = 0
        all_preds = []
        all_labels = []
        
        with torch.no_grad():
            for labels, features in val_loader:
                features = features.to(config['device'])
                labels = labels.to(config['device'])
                
                outputs = model(features)
                val_loss += criterion(outputs, labels).item()
                
                preds = torch.sigmoid(outputs)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
        
        # Calculate metrics
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        predictions = (np.array(all_preds) > 0.5).astype(int)
        
        # Get current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        
        # Update schedulers
        if epoch < config['warmup_epochs']:
            warmup_scheduler.step()
        else:
            scheduler.step(val_loss)
        
        # Log metrics and learning rate
        metrics = {
            'accuracy': accuracy_score(all_labels, predictions),
            'precision': precision_score(all_labels, predictions),
            'recall': recall_score(all_labels, predictions),
            'f1': f1_score(all_labels, predictions)
        }
        
        writer.add_scalar('LR', current_lr, epoch)
        writer.add_scalar('Loss/Train', train_loss, epoch)
        writer.add_scalar('Loss/Val', val_loss, epoch)
        writer.add_scalar('Accuracy', metrics['accuracy'], epoch)
        writer.add_scalar('Precision', metrics['precision'], epoch)
        writer.add_scalar('Recall', metrics['recall'], epoch)
        writer.add_scalar('F1', metrics['f1'], epoch)
        writer.add_scalars('Metrics', metrics, epoch)
        
        # Save best model
        if val_loss < best_metric:
            best_metric = val_loss
            torch.save(model.state_dict(), f"{config['log_dir']}/best_model.pth")
    
    writer.close()
    return model

In [8]:
# Configuration
config = {
    'data_path': '/home/owner/Documents/DEV/BrainLabyrinth/data/combined.parquet',
    'split_ratios': (0.7, 0.15, 0.15),
    'batch_size': 128,          # Increased for better generalization
    'dropout': 0.6,            # Reduced from 0.6 for better information flow
    'lr': 7e-5,                # Base learning rate (sweet spot between 1e-5 and 3e-3)
    'weight_decay': 1e-5,      # Increased regularization
    'epochs': 2000,
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'log_dir': './runs/CNN',
    'lr_scheduler': {
        'mode': 'min',
        'factor': 0.1,         # More aggressive LR reduction
        'patience': 5,         # Faster response to plateaus
        'threshold': 0.001,
        'cooldown': 3
    },
    'grad_clip': 1.0,          # Add gradient clipping
    'warmup_epochs': 10        # Linear LR warmup
}


In [9]:
# #============================================================
# # Training Pipeline
# #============================================================
# import warnings
# warnings.filterwarnings("ignore", category=FutureWarning)

# # Initialize dataset
# print("Creating full dataset...")
# full_dataset = EEGDataset(config['data_path'])

# print("Splitting the dataset...")
# # Split dataset
# train_set, val_set, test_set = full_dataset.split_dataset(
#     ratios=config['split_ratios']
# )

# del full_dataset

# len_dataset = len(train_set)
# sample = train_set[0]
# label_shape = sample[0].shape
# feature_shape = sample[1].shape

# print(f"unbalanced train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")

# # Balance training set
# print("Applying SMOTE to train dataset...")
# train_set.rebalance_by_tsmote()

# # print("Augmenting train dataset...")
# # train_set.augment_dataset(
# #     n_times=3,              # Nx dataset expansion
# #     noise_std=0.15,         # Moderate noise
# #     scale_range=(0.7, 1.3), # ±30% amplitude variation
# #     warp_range=(0.85, 1.15),# ±15% time warping
# #     max_shift=15,           # 150ms temporal shifts
# #     freq_shift=3,           # ±3Hz frequency shifts
# #     mask_size=75,           # 750ms masking
# #     drop_prob=0.15,         # 15% channel dropout
# #     mixup_alpha=0.2         # Mixup
# # )

# torch.save(train_set, 'train_set.pt')
# torch.save(val_set, 'val_set.pt')
# torch.save(test_set, 'test_set.pt')

In [10]:

train_set = torch.load('train_set.pt')
val_set = torch.load('val_set.pt')
test_set = torch.load('test_set.pt')


train_loader = DataLoader(
    train_set,
    batch_size=config['batch_size'],
    shuffle=True,
    num_workers=12,
    pin_memory=True,      # For GPU acceleration
    persistent_workers=True,
    collate_fn=collate_fn
)
val_loader = DataLoader(val_set, batch_size=config['batch_size'], collate_fn=collate_fn)
test_loader = DataLoader(test_set, batch_size=config['batch_size'], collate_fn=collate_fn)

len_dataset = len(train_set)
sample = train_set[0]
label_shape = sample[0].shape
feature_shape = sample[1].shape

print(f"train dataset shape: ({len_dataset}, [labels: {label_shape}, features: {list(feature_shape)}])")


# Set up logging
writer = SummaryWriter(log_dir=config['log_dir'])

# Start training
trained_model = train_model(config, train_set, train_loader, val_loader, writer)

  train_set = torch.load('train_set.pt')
  val_set = torch.load('val_set.pt')
  test_set = torch.load('test_set.pt')


train dataset shape: (2549, [labels: torch.Size([]), features: [2000, 63]])




  0%|          | 0/2000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
import torch
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# pos_weight = torch.tensor([
#     train_set.class_weights['A'] / train_set.class_weights['P']
# ]).to(config['device'])
criterion = nn.BCEWithLogitsLoss()#pos_weight=pos_weight)

epoch = 1
# Assuming model, criterion, test_loader, device, writer, and epoch are already defined
best_model = torch.load('best_model.torch')
best_model.eval()
test_loss = 0
all_test_markers = []
all_test_predictions = []
with torch.no_grad():
    for markers, features in tqdm(test_loader):
        features = features.to(config['device'])
        markers = markers.unsqueeze(-1).to(config['device'])

        outputs = best_model(features)
        loss = criterion(outputs, markers)
        test_loss += loss.item()

        # Collect markers and predictions for metrics calculation
        all_test_markers.extend(markers.cpu().numpy().flatten())
        all_test_predictions.extend(torch.sigmoid(outputs).cpu().numpy().flatten())

test_loss /= len(test_loader)
# Calculate test metrics
test_accuracy = accuracy_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_precision = precision_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_recall = recall_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_f1 = f1_score(all_test_markers, [1 if p > 0.5 else 0 for p in all_test_predictions])
test_roc_auc = roc_auc_score(all_test_markers, all_test_predictions)

# Log test metrics to TensorBoard
writer.add_scalar('Metrics/test_accuracy', test_accuracy, epoch)
writer.add_scalar('Metrics/test_precision', test_precision, epoch)
writer.add_scalar('Metrics/test_recall', test_recall, epoch)
writer.add_scalar('Metrics/test_f1', test_f1, epoch)
writer.add_scalar('Metrics/test_roc_auc', test_roc_auc, epoch)

# Close the TensorBoard writer
writer.close()

In [None]:
print(f"""
{test_accuracy=}
{test_precision=}
{test_recall=}
{test_f1=}
{test_roc_auc=}
"""
)

In [None]:
from sklearn.metrics import f1_score
import numpy as np
best_threshold = 0.0
best_f1 = 0.0
thresholds = np.arange(0.0, 1.0, 0.01)

for threshold in tqdm(thresholds):
    binary_predictions = (all_test_predictions > threshold).astype(int)
    current_recall = f1_score(all_test_markers, binary_predictions)

    if current_recall > best_f1:
        best_f1 = current_recall
        best_threshold = threshold

print(f"{best_threshold=}")
print(f"{best_f1=}")

In [None]:
from sklearn.metrics import recall_score
import numpy as np
best_threshold = 0.1
best_recall = 0.0
thresholds = np.arange(0.1, 1.0, 0.01)

for threshold in tqdm(thresholds):
    binary_predictions = (all_test_predictions > threshold).astype(int)
    current_recall = recall_score(all_test_markers, binary_predictions)

    if current_recall > best_recall:
        best_recall = current_recall
        best_threshold = threshold

print(f"{best_threshold=}")
print(f"{best_recall=}")