# Pipeline 1: Baseline Methods Evaluation

This notebook trains and evaluates 5 baseline methods on PhysioNet Motor Imagery:
1. **FBCSP** - Filter Bank Common Spatial Patterns
2. **CNN-SAE** - CNN with Spatial Attention
3. **EEGNet** - Compact temporal convolutional network
4. **ACS-SE-CNN** - Adaptive Channel Selection SE-CNN
5. **G-CARM** - Graph Channel Active Reasoning Module

**Plus: Retention Analysis**
- Tests how baseline methods (EEGNet) perform with reduced channels
- Uses variance-based channel selection (simple baseline)
- k-values: [10, 15, 20, 25, 30, 35]
- Provides comparison baseline for Pipeline 2 (adaptive gating)

**Training Configuration:**
- Epochs: 30 (with early stopping patience=5)
- Learning rate: 0.002
- Batch size: 64
- 3-fold cross-validation

**Expected Runtime:** ~7-8 hours on Kaggle GPU
- Baseline training: ~3-4 hours
- Retention analysis: ~3.5 hours (6 k-values × 3 folds)

**Outputs:**
- `models/baseline_*.pt` - Trained model checkpoints
- `results/baseline_methods_results.csv` - Complete metrics for all models
- `results/baseline_methods_summary.csv` - Summary statistics
- `results/baseline_retention_analysis.csv` - Retention curve data (for comparison with Pipeline 2)

## 1. Setup and Configuration

In [None]:
import os
import gc
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import mne
from sklearn.model_selection import StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import pickle
from copy import deepcopy
import warnings
warnings.filterwarnings('ignore')
mne.set_log_level('ERROR')

print("All imports successful!")

In [None]:
# Configuration
# Single settings dictionary read by the data loader, the training helpers and
# the cross-validation cells further down in this notebook.
CONFIG = {
    # Input data root and output directories (the two output dirs are created below).
    'data_path': '/kaggle/input/eeg-preprocessed-data/derived',
    'models_dir': './models',
    'results_dir': './results',
    
    # Cross-validation, reproducibility and compute device.
    'n_folds': 3,
    'random_seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    
    # Training hyperparameters
    'batch_size': 64,
    'epochs': 30,
    'learning_rate': 0.002,
    'weight_decay': 1e-4,
    'patience': 5,              # early-stopping patience (epochs without improvement)
    'scheduler_patience': 2,    # ReduceLROnPlateau patience
    'scheduler_factor': 0.5,    # ReduceLROnPlateau multiplicative factor
    'use_early_stopping': True,
    'min_lr': 1e-6,
    
    # Data parameters
    'n_channels': 64,
    'n_classes': 2,
    'sfreq': 128,
    'tmin': 0.0,
    'tmax': 4.0,
    # 513 == (tmax - tmin) * sfreq + 1; presumably both epoch endpoints are
    # included by the epoching step -- TODO confirm against the loaded data shape.
    'n_timepoints': 513,
    # NOTE(review): in the PhysioNet EEGMMIDB run numbering, runs 7 and 11 appear
    # to be executed-movement runs while 8 and 12 are imagined-movement runs --
    # confirm that mixing both is intended for a "motor imagery" benchmark.
    'mi_runs': [7, 8, 11, 12],
    
    # FBCSP parameters
    # Nine contiguous 4 Hz-wide bands covering 4-40 Hz.
    'fbcsp_bands': [(4, 8), (8, 12), (12, 16), (16, 20), (20, 24), (24, 28), (28, 32), (32, 36), (36, 40)],
    'fbcsp_n_components': 4,
}

# Make sure the output directories exist before any model or CSV is written.
os.makedirs(CONFIG['models_dir'], exist_ok=True)
os.makedirs(CONFIG['results_dir'], exist_ok=True)

# Seed NumPy and PyTorch (CPU, plus the current CUDA device when available).
np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(CONFIG['random_seed'])

print(f"Device: {CONFIG['device']}")
print(f"Epochs: {CONFIG['epochs']}")
print(f"Learning rate: {CONFIG['learning_rate']}")
print(f"Early stopping patience: {CONFIG['patience']}")

# Estimate runtime
# Rough figure only: assumes ~12 minutes per (model, fold) run for 5 models.
n_models = 5
total_runs = n_models * CONFIG['n_folds']
print(f"\nEstimated training runs: {total_runs}")
print(f"Estimated runtime (~12 min/run): {total_runs * 12 / 60:.1f} hours")

## 2. Data Loading

In [None]:
def load_physionet_data(data_path):
    """Load preprocessed PhysioNet runs and cut them into labelled trials.

    For every ``S*`` subject directory under ``data_path`` (or under its
    ``preprocessed`` sub-directory when that exists), each run listed in
    ``CONFIG['mi_runs']`` is read from ``<subject>R<run>_preproc_raw.fif``,
    epoched on the ``T1``/``T2`` annotations between ``CONFIG['tmin']`` and
    ``CONFIG['tmax']``, and appended to the output.

    Args:
        data_path: Directory containing the subject folders.

    Returns:
        Tuple ``(X, y, subjects)``: the epoch data as returned by
        ``Epochs.get_data()``, integer labels (``T1`` -> 0, ``T2`` -> 1), and
        the subject number of every trial (``-1`` when the directory name has
        no numeric suffix).

    Raises:
        FileNotFoundError: If ``data_path`` is not an existing directory.
        ValueError: If no trial could be loaded at all.
    """
    data_root = os.path.abspath(data_path)
    if not os.path.isdir(data_root):
        raise FileNotFoundError(f"Data path not found: {data_root}")

    tmin = CONFIG['tmin']
    tmax = CONFIG['tmax']
    mi_runs = CONFIG['mi_runs']
    event_id = {'T1': 1, 'T2': 2}
    label_map = {1: 0, 2: 1}

    preprocessed_dir = os.path.join(data_root, 'preprocessed')
    if os.path.isdir(preprocessed_dir):
        data_root = preprocessed_dir

    subject_dirs = [d for d in sorted(os.listdir(data_root))
                    if os.path.isdir(os.path.join(data_root, d)) and d.upper().startswith('S')]

    all_X, all_y, all_subjects = [], [], []
    failed_runs = []

    print(f"Loading data from {len(subject_dirs)} subjects...")

    for subject_dir in subject_dirs:
        # Directory names that merely start with "S" (e.g. "scripts") must not
        # abort the whole load; they get the same -1 sentinel as a bare "S".
        try:
            subject_num = int(subject_dir[1:])
        except ValueError:
            subject_num = -1
        subject_path = os.path.join(data_root, subject_dir)

        for run_id in mi_runs:
            run_file = f"{subject_dir}R{run_id:02d}_preproc_raw.fif"
            run_path = os.path.join(subject_path, run_file)

            if not os.path.exists(run_path):
                continue

            try:
                raw = mne.io.read_raw_fif(run_path, preload=True, verbose=False)
                picks = mne.pick_types(raw.info, eeg=True, meg=False, stim=False, eog=False)
                if len(picks) == 0:
                    continue

                events, _ = mne.events_from_annotations(raw, event_id=event_id)
                if len(events) == 0:
                    continue

                epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
                                    baseline=None, preload=True, picks=picks, verbose=False)

                data = epochs.get_data()
                labels = np.array([label_map.get(code, -1) for code in epochs.events[:, 2]])
                valid = labels >= 0

                if np.any(valid):
                    all_X.append(data[valid])
                    all_y.append(labels[valid])
                    all_subjects.append(np.full(int(np.sum(valid)), subject_num))
            except Exception as exc:
                # Best effort: one unreadable run should not stop the load, but
                # it must be visible. print() is used because this notebook
                # silences the warnings module globally.
                failed_runs.append(run_file)
                print(f"  Skipping {run_file}: {type(exc).__name__}: {exc}")
                continue

    if failed_runs:
        print(f"Skipped {len(failed_runs)} run(s) that could not be loaded")

    if not all_X:
        raise ValueError(
            f"No usable trials found under {data_root} "
            f"(subject dirs: {len(subject_dirs)}, failed runs: {len(failed_runs)})"
        )

    X = np.concatenate(all_X, axis=0)
    y = np.concatenate(all_y, axis=0)
    subjects = np.concatenate(all_subjects, axis=0)

    print(f"Loaded {len(X)} trials from {len(np.unique(subjects))} subjects")
    print(f"Data shape: {X.shape}")
    print(f"Labels: {np.bincount(y)}")

    return X, y, subjects


class EEGDataset(Dataset):
    """In-memory dataset that pairs EEG trials with their class labels."""

    def __init__(self, X, y):
        # Convert once up front so that indexing hands back ready-made tensors:
        # float32 trials and int64 (long) labels.
        trials = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.long)
        self.X = trials
        self.y = targets

    def __len__(self):
        n_trials = len(self.X)
        return n_trials

    def __getitem__(self, idx):
        trial = self.X[idx]
        target = self.y[idx]
        return trial, target

## 3. Model Architectures

In [None]:
# FBCSP
class FBCSP:
    """Filter Bank Common Spatial Patterns followed by an LDA classifier.

    Each trial is band-pass filtered into every band of ``freq_bands``; one CSP
    decomposition per band yields log-power features, and the concatenated
    features of all bands are classified with linear discriminant analysis.

    Args:
        freq_bands: Iterable of ``(low_hz, high_hz)`` pass-bands.
        n_components: Number of CSP components kept per band.
        sfreq: Sampling frequency of the input signals in Hz.
    """

    def __init__(self, freq_bands, n_components=4, sfreq=128):
        self.freq_bands = freq_bands
        self.n_components = n_components
        self.sfreq = sfreq
        self.csp_list = []
        self.classifier = None

    def fit(self, X, y):
        """Fit one CSP per band and the LDA on the stacked features.

        Args:
            X: Array of shape ``(n_trials, n_channels, n_times)``.
            y: Class label per trial.

        Returns:
            ``self``, to allow call chaining.
        """
        from mne.decoding import CSP

        # Start from a clean slate: predict() indexes csp_list by band position,
        # so CSPs left over from an earlier fit() would otherwise be used
        # instead of the freshly fitted ones.
        self.csp_list = []
        band_features = []

        for low, high in self.freq_bands:
            X_filtered = self._bandpass_filter(X, low, high)
            csp = CSP(n_components=self.n_components, reg=None, log=True, norm_trace=False)
            band_features.append(csp.fit_transform(X_filtered, y))
            self.csp_list.append(csp)

        all_features = np.concatenate(band_features, axis=1)
        self.classifier = LinearDiscriminantAnalysis()
        self.classifier.fit(all_features, y)
        return self

    def predict(self, X):
        """Return the predicted class label for every trial in ``X``."""
        band_features = [
            csp.transform(self._bandpass_filter(X, low, high))
            for (low, high), csp in zip(self.freq_bands, self.csp_list)
        ]
        all_features = np.concatenate(band_features, axis=1)
        return self.classifier.predict(all_features)

    def score(self, X, y):
        """Return the fraction of trials in ``X`` classified as ``y``."""
        return np.mean(self.predict(X) == y)

    def _bandpass_filter(self, X, low, high):
        """Zero-phase 4th-order Butterworth band-pass along the last axis.

        Floating-point input keeps its dtype; any other dtype (e.g. integers)
        yields a float64 result so the filtered signal is not truncated.
        """
        from scipy.signal import butter, filtfilt

        nyq = self.sfreq / 2
        b, a = butter(4, [low / nyq, high / nyq], btype='band')

        X = np.asarray(X)
        # filtfilt filters every 1-D slice along `axis` independently, which is
        # exactly what the per-trial / per-channel Python loop computed.
        filtered = filtfilt(b, a, X, axis=-1)
        if np.issubdtype(X.dtype, np.floating):
            filtered = filtered.astype(X.dtype, copy=False)
        # filtfilt may hand back a reversed view; return a plain contiguous array.
        return np.ascontiguousarray(filtered)

In [None]:
# CNN-SAE
class SpatialAttention(nn.Module):
    """Per-channel gating computed from each channel's temporal mean.

    The input ``(batch, n_channels, time)`` is averaged over time, passed
    through a bottleneck MLP ending in a sigmoid, and the resulting weights in
    (0, 1) rescale the corresponding channels.
    """

    def __init__(self, n_channels):
        super().__init__()
        # Keep at least one hidden unit so very small channel counts do not
        # produce a zero-width bottleneck (same guard as SEBlock uses).
        hidden = max(1, n_channels // 4)
        self.attention = nn.Sequential(
            nn.Linear(n_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        pooled = torch.mean(x, dim=2)          # (batch, n_channels)
        weights = self.attention(pooled)       # (batch, n_channels)
        return x * weights.unsqueeze(2)        # broadcast over time


class CNNSAE(nn.Module):
    """1-D CNN with a spatial-attention front end.

    Args:
        n_channels: Number of EEG channels (input feature maps of ``conv1``).
        n_classes: Number of output logits.
        n_timepoints: Samples per trial; only used to size the first dense layer.
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513):
        super().__init__()
        self.spatial_attention = SpatialAttention(n_channels)
        self.conv1 = nn.Conv1d(n_channels, 64, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(2)
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool3 = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(0.5)

        # Infer the flattened feature size with a dummy pass. The pass runs in
        # eval mode: in train mode the all-zero dummy batch would be folded into
        # the BatchNorm running statistics before any real data is seen.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            probe = torch.zeros(1, n_channels, n_timepoints)
            flattened_size = self._forward_features(probe).view(1, -1).size(1)
        self.train(was_training)

        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def _forward_features(self, x):
        """Attention gating followed by three conv -> BN -> ReLU -> max-pool stages."""
        x = self.spatial_attention(x)
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        return x

    def forward(self, x):
        """Map ``(batch, n_channels, n_timepoints)`` to ``(batch, n_classes)`` logits."""
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

In [None]:
# EEGNet
class EEGNet(nn.Module):
    """Compact EEGNet-style network: temporal conv, depthwise spatial conv, then a second temporal conv.

    Args:
        n_channels: Number of EEG channels (height of the spatial kernel).
        n_classes: Number of output logits.
        n_timepoints: Samples per trial; only used to size the dense layer.
        F1: Number of temporal filters in the first convolution.
        D: Depth multiplier of the spatial (grouped) convolution.
        F2: Number of filters in the final temporal convolution.
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, F1=8, D=2, F2=16):
        super().__init__()
        # The 64 below is the temporal kernel length, unrelated to n_channels.
        self.conv1 = nn.Conv2d(1, F1, (1, 64), padding=(0, 32), bias=False)
        self.bn1 = nn.BatchNorm2d(F1)
        # Grouped conv spanning all channels: D spatial filters per temporal filter.
        self.conv2 = nn.Conv2d(F1, F1 * D, (n_channels, 1), groups=F1, bias=False)
        self.bn2 = nn.BatchNorm2d(F1 * D)
        self.pool1 = nn.AvgPool2d((1, 4))
        self.dropout1 = nn.Dropout(0.5)
        self.conv3 = nn.Conv2d(F1 * D, F2, (1, 16), padding=(0, 8), bias=False)
        self.bn3 = nn.BatchNorm2d(F2)
        self.pool2 = nn.AvgPool2d((1, 8))
        self.dropout2 = nn.Dropout(0.5)

        # Infer the flattened feature size with a dummy pass. The pass runs in
        # eval mode so that it neither folds an all-zero batch into the
        # BatchNorm running statistics nor draws dropout masks from the RNG.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            probe = torch.zeros(1, 1, n_channels, n_timepoints)
            flattened_size = self._forward_features(probe).view(1, -1).size(1)
        self.train(was_training)

        self.fc = nn.Linear(flattened_size, n_classes)

    def _forward_features(self, x):
        """Run the convolutional stack on a ``(batch, 1, n_channels, time)`` input."""
        x = self.bn1(self.conv1(x))
        x = self.dropout1(self.pool1(F.elu(self.bn2(self.conv2(x)))))
        x = self.dropout2(self.pool2(F.elu(self.bn3(self.conv3(x)))))
        return x

    def forward(self, x):
        """Map ``(batch, n_channels, n_timepoints)`` to ``(batch, n_classes)`` logits."""
        x = x.unsqueeze(1)  # add the single "image channel" dimension Conv2d expects
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

In [None]:
# ACS-SE-CNN
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate over the feature-map dimension.

    The input ``(batch, channels, time)`` is averaged over time ("squeeze"),
    passed through a two-layer bottleneck ending in a sigmoid ("excitation"),
    and the resulting weights rescale each feature map.
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        hidden = max(1, channels // reduction)  # never a zero-width bottleneck
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)

    def forward(self, x):
        squeeze = torch.mean(x, dim=2)
        excitation = F.relu(self.fc1(squeeze))
        excitation = torch.sigmoid(self.fc2(excitation))
        return x * excitation.unsqueeze(2)


class ACSECNN(nn.Module):
    """1-D CNN with a learned per-channel gate and SE blocks before each conv stage.

    Args:
        n_channels: Number of EEG channels.
        n_classes: Number of output logits.
        n_timepoints: Samples per trial; sizes both the channel-attention MLP
            input and the first dense layer.
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513):
        super().__init__()
        # Shared MLP that maps one channel's full time series to a scalar gate.
        self.channel_attention = nn.Sequential(
            nn.Linear(n_timepoints, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        self.se1 = SEBlock(n_channels)
        self.se2 = SEBlock(128)
        self.se3 = SEBlock(256)
        self.conv1 = nn.Conv1d(n_channels, 128, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(128)
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(256)
        self.pool2 = nn.MaxPool1d(2)
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(512)
        self.pool3 = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(0.5)

        # Infer the flattened feature size with a dummy pass. The pass runs in
        # eval mode: in train mode the all-zero dummy batch would be folded into
        # the BatchNorm running statistics before any real data is seen.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            probe = torch.zeros(1, n_channels, n_timepoints)
            flattened_size = self._forward_features(probe).view(1, -1).size(1)
        self.train(was_training)

        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def _forward_features(self, x):
        """Channel gating, then three SE -> conv -> BN -> ReLU -> max-pool stages."""
        # nn.Linear acts on the last dimension, so a single call gates every
        # channel at once: (batch, n_channels, time) -> (batch, n_channels, 1).
        # This replaces one MLP call per channel with identical weights.
        channel_weights = self.channel_attention(x)
        x = x * channel_weights
        x = self.se1(x)
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.se2(x)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.se3(x)
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        return x

    def forward(self, x):
        """Map ``(batch, n_channels, n_timepoints)`` to ``(batch, n_classes)`` logits."""
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

In [None]:
# G-CARM
class CARMBlock(nn.Module):
    """Learned channel-mixing layer with a row-softmax-normalised adjacency matrix.

    Every output channel is a convex combination of the input channels; the
    mixing weights are the softmax of the corresponding row of ``A``.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.A = nn.Parameter(torch.randn(n_channels, n_channels) * 0.01)

    def forward(self, x):
        A_norm = torch.softmax(self.A, dim=1)
        # (C, C) @ (batch, C, time) broadcasts over the batch dimension; this is
        # the same product as permute -> matmul(A_norm.t()) -> permute.
        return torch.matmul(A_norm, x)


class GCARM(nn.Module):
    """Two channel-mixing (CARM) layers followed by a three-stage 1-D CNN.

    Args:
        n_channels: Number of EEG channels.
        n_classes: Number of output logits.
        n_timepoints: Samples per trial; only used to size the first dense layer.
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513):
        super().__init__()
        self.carm1 = CARMBlock(n_channels)
        self.carm2 = CARMBlock(n_channels)
        self.conv1 = nn.Conv1d(n_channels, 128, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(128)
        self.pool1 = nn.MaxPool1d(2)
        self.conv2 = nn.Conv1d(128, 256, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(256)
        self.pool2 = nn.MaxPool1d(2)
        self.conv3 = nn.Conv1d(256, 512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(512)
        self.pool3 = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(0.5)

        # Infer the flattened feature size with a dummy pass. The pass runs in
        # eval mode: in train mode the all-zero dummy batch would be folded into
        # the BatchNorm running statistics before any real data is seen.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            probe = torch.zeros(1, n_channels, n_timepoints)
            flattened_size = self._forward_features(probe).view(1, -1).size(1)
        self.train(was_training)

        self.fc1 = nn.Linear(flattened_size, 256)
        self.fc2 = nn.Linear(256, n_classes)

    def _forward_features(self, x):
        """Channel mixing followed by three conv -> BN -> ReLU -> max-pool stages."""
        x = self.carm1(x)
        x = self.carm2(x)
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        return x

    def forward(self, x):
        """Map ``(batch, n_channels, n_timepoints)`` to ``(batch, n_classes)`` logits."""
        x = self._forward_features(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

## 4. Training Utilities

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Returns:
        Tuple ``(mean_batch_loss, accuracy)`` where the loss is averaged over
        batches and the accuracy over samples; both are 0.0 for an empty loader.
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Bookkeeping uses the logits computed before this step's update.
        running_loss += batch_loss.item()
        predictions = logits.detach().argmax(dim=1)
        n_seen += targets.size(0)
        n_correct += (predictions == targets).sum().item()

    n_batches = max(1, len(dataloader))
    mean_loss = running_loss / n_batches
    accuracy = n_correct / max(1, n_seen)
    return mean_loss, accuracy


def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating any weights.

    Returns:
        Tuple ``(mean_batch_loss, accuracy)`` where the loss is averaged over
        batches and the accuracy over samples; both are 0.0 for an empty loader.
    """
    model.eval()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            running_loss += criterion(logits, targets).item()

            predictions = logits.argmax(dim=1)
            n_seen += targets.size(0)
            n_correct += (predictions == targets).sum().item()

    n_batches = max(1, len(dataloader))
    mean_loss = running_loss / n_batches
    accuracy = n_correct / max(1, n_seen)
    return mean_loss, accuracy


def train_pytorch_model(model, train_loader, val_loader, config, model_name=''):
    """Train ``model`` with Adam, plateau-based LR decay and early stopping.

    Args:
        model: Network to train; it is moved to ``config['device']``.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        config: Mapping providing ``device``, ``learning_rate``,
            ``weight_decay``, ``scheduler_factor``, ``scheduler_patience``,
            ``min_lr``, ``epochs``, ``use_early_stopping`` and ``patience``.
        model_name: Label used only in progress messages.

    Returns:
        Tuple ``(best_state, best_val_acc)``: a deep copy of the state dict
        from the best epoch and its validation accuracy. The best epoch is the
        one with the highest validation accuracy, ties broken by lower
        validation loss. ``model`` is left loaded with ``best_state``.
    """
    device = config['device']
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'],
                           weight_decay=config['weight_decay'])
    # `verbose` is deliberately not passed: False is the default, the argument
    # is deprecated, and newer PyTorch releases may reject it outright.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=config['scheduler_factor'],
        patience=config['scheduler_patience'], min_lr=config['min_lr']
    )

    # Snapshot the initial weights so a valid state is returned even if no
    # epoch ever counts as an improvement.
    best_state = deepcopy(model.state_dict())
    best_val_acc = 0.0
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(config['epochs']):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        # The LR schedule follows validation loss; model selection below
        # follows validation accuracy.
        scheduler.step(val_loss)

        improved = val_acc > best_val_acc or (val_acc == best_val_acc and val_loss < best_val_loss)
        if improved:
            best_state = deepcopy(model.state_dict())
            best_val_acc = val_acc
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if epoch % 5 == 0 or improved:
            print(f"[{model_name}] Epoch {epoch+1}/{config['epochs']} - "
                  f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_val_acc:.4f}")

        if config['use_early_stopping'] and patience_counter >= config['patience']:
            print(f"Early stopping at epoch {epoch+1}")
            break

    model.load_state_dict(best_state)
    return best_state, best_val_acc

## 5. Load Data

In [None]:
# Load all trials into memory once. X and y are consumed by the
# cross-validation cells below; `subjects` is returned by the loader but is not
# used by the cells visible in this notebook.
print("Loading PhysioNet data...")
X, y, subjects = load_physionet_data(CONFIG['data_path'])
print(f"\nData ready for training!")

## 6. Train All Baseline Models

In [None]:
# Define models
# 'type' selects the training path in the loop below: 'sklearn' fits the FBCSP
# pipeline directly, anything else is trained through train_pytorch_model.
models_to_train = [
    {'name': 'FBCSP', 'type': 'sklearn'},
    {'name': 'CNN-SAE', 'type': 'pytorch'},
    {'name': 'EEGNet', 'type': 'pytorch'},
    {'name': 'ACS-SE-CNN', 'type': 'pytorch'},
    {'name': 'G-CARM', 'type': 'pytorch'},
]

# Per-model summaries are appended here by the training loop.
all_results = []
# One seeded, shuffled, label-stratified splitter shared by every model and by
# the retention analysis. It is called on (X, y) only, so folds are stratified
# by label and not grouped by subject.
skf = StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['random_seed'])

print(f"\n{'='*60}")
print("TRAINING BASELINE METHODS")
print(f"{'='*60}\n")

In [None]:
# Training loop
# For every baseline: run the shared stratified K-fold split, train on each
# fold, save the fitted model, and record the per-fold validation accuracy.

# Name -> class lookup for the PyTorch baselines. An unknown name fails loudly
# instead of silently reusing the `model` left over from a previous iteration.
pytorch_model_classes = {
    'CNN-SAE': CNNSAE,
    'EEGNet': EEGNet,
    'ACS-SE-CNN': ACSECNN,
    'G-CARM': GCARM,
}

for model_info in models_to_train:
    model_name = model_info['name']
    model_type = model_info['type']

    print(f"\n{'='*60}")
    print(f"Training {model_name}")
    print(f"{'='*60}\n")

    fold_results = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        print(f"\nFold {fold + 1}/{CONFIG['n_folds']}")

        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        if model_type == 'sklearn':
            model = FBCSP(freq_bands=CONFIG['fbcsp_bands'],
                          n_components=CONFIG['fbcsp_n_components'],
                          sfreq=CONFIG['sfreq'])
            model.fit(X_train, y_train)
            val_acc = model.score(X_val, y_val)

            model_path = os.path.join(CONFIG['models_dir'], f"baseline_{model_name}_fold{fold+1}.pkl")
            with open(model_path, 'wb') as f:
                pickle.dump(model, f)
        else:
            if model_name not in pytorch_model_classes:
                raise ValueError(f"Unknown PyTorch model: {model_name}")

            train_dataset = EEGDataset(X_train, y_train)
            val_dataset = EEGDataset(X_val, y_val)
            train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False)

            base_kwargs = {
                'n_channels': CONFIG['n_channels'],
                'n_classes': CONFIG['n_classes'],
                'n_timepoints': CONFIG['n_timepoints'],
            }

            model = pytorch_model_classes[model_name](**base_kwargs)

            best_state, val_acc = train_pytorch_model(model, train_loader, val_loader, CONFIG, model_name)

            model_path = os.path.join(CONFIG['models_dir'], f"baseline_{model_name}_fold{fold+1}.pt")
            torch.save(best_state, model_path)

            # Release the model and cached GPU memory before the next fold.
            del model
            torch.cuda.empty_cache()
            gc.collect()

        fold_results.append({'fold': fold + 1, 'accuracy': val_acc})
        print(f"Fold {fold + 1} Accuracy: {val_acc:.4f}")

    fold_accs = [r['accuracy'] for r in fold_results]
    mean_acc = np.mean(fold_accs)
    std_acc = np.std(fold_accs)

    print(f"\n{model_name} Summary:")
    print(f"Mean Accuracy: {mean_acc:.4f} +/- {std_acc:.4f}")

    all_results.append({
        'model': model_name,
        'mean_accuracy': mean_acc,
        'std_accuracy': std_acc,
        'fold_results': fold_results
    })

print(f"\n{'='*60}")
print("ALL BASELINE MODELS TRAINED!")
print(f"{'='*60}")

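## 7. Save Results

The cell that writes the baseline results CSVs is the final cell of this notebook, after the retention analysis in Section 8.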

## 8. Retention Analysis

Test how baseline methods perform with reduced channels using variance-based selection.
This provides a fair comparison baseline for the adaptive gating approach in Pipeline 2.

In [None]:
# Variance-based channel importance
def get_channel_importance_variance(X_train):
    """Score every channel by its signal variance pooled over trials and time."""
    # X_train is laid out as (n_trials, n_channels, n_timepoints); reducing over
    # the trial and time axes leaves one variance value per channel.
    trial_axis, time_axis = 0, 2
    return np.var(X_train, axis=(trial_axis, time_axis))


def select_top_k_channels(importance_scores, k):
    """Return the indices of the ``k`` highest-scoring channels, in ascending order.

    Args:
        importance_scores: 1-D array-like with one score per channel.
        k: Number of channels to keep. Values larger than the number of
            channels select every channel.

    Returns:
        Sorted list of integer channel indices (empty when ``k`` is 0).

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if k == 0:
        # A bare [-k:] slice would be [-0:] == [0:], i.e. *all* channels.
        return []
    ranked = np.argsort(importance_scores)  # ascending by score
    return np.sort(ranked[-k:]).tolist()


def apply_channel_selection(X, selected_channels):
    """Keep only the given channels of a 3-D trial array.

    Indexes the second axis of ``X`` with ``selected_channels`` and leaves the
    first and last axes untouched; the result lists channels in the order given
    by ``selected_channels``.
    """
    return X[:, selected_channels, :]

In [None]:
# Retention analysis configuration
# Channel budgets to evaluate, and the label recorded with every result row.
RETENTION_K_VALUES = [10, 15, 20, 25, 30, 35]
RETENTION_MODEL = 'EEGNet'  # Use EEGNet as baseline representative

print(f"Running retention analysis for {RETENTION_MODEL} with k-values: {RETENTION_K_VALUES}")
# Same rough ~12 min/run figure as the baseline estimate. The trailing newline
# is written as an escape: a raw line break inside the literal is a syntax error.
print(f"Estimated time: {len(RETENTION_K_VALUES) * CONFIG['n_folds'] * 12 / 60:.1f} hours\n")

In [None]:
# Run retention analysis
# For each channel budget k: pick the k highest-variance channels on the
# training split of every fold, train a fresh EEGNet on those channels only,
# and record the validation accuracy.
retention_results = []

for k in RETENTION_K_VALUES:
    # Leading newline written as an escape; a raw line break inside the literal
    # is a syntax error.
    print(f"\nTesting with k={k} channels:", end=' ')
    fold_accuracies = []

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        # Compute channel importance on training data only, so the validation
        # split does not influence which channels are kept.
        importance_scores = get_channel_importance_variance(X_train)
        selected_channels = select_top_k_channels(importance_scores, k)
        # Size the model from what was actually selected: when k exceeds the
        # number of available channels fewer than k indices come back.
        n_selected = len(selected_channels)

        # Apply channel selection
        X_train_selected = apply_channel_selection(X_train, selected_channels)
        X_val_selected = apply_channel_selection(X_val, selected_channels)

        # Train model with selected channels
        train_dataset = EEGDataset(X_train_selected, y_train)
        val_dataset = EEGDataset(X_val_selected, y_val)
        train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False)

        # Modify config for reduced channels (shallow copy; CONFIG is untouched)
        temp_config = CONFIG.copy()
        temp_config['n_channels'] = n_selected

        model = EEGNet(n_channels=n_selected, n_classes=CONFIG['n_classes'],
                       n_timepoints=CONFIG['n_timepoints'])
        best_state, val_acc = train_pytorch_model(model, train_loader, val_loader, temp_config, f"Retention-k{k}")

        fold_accuracies.append(val_acc)

        # Release the model and cached GPU memory before the next fold.
        del model
        torch.cuda.empty_cache()
        gc.collect()

    mean_acc = np.mean(fold_accuracies)
    std_acc = np.std(fold_accuracies)
    print(f"{mean_acc:.4f} +/- {std_acc:.4f}")

    retention_results.append({
        'model': RETENTION_MODEL,
        'k': k,
        'mean_accuracy': mean_acc,
        'std_accuracy': std_acc,
        'fold_accuracies': fold_accuracies
    })

In [None]:
# Save retention results
# One row per k with the mean/std accuracy and the raw per-fold accuracies.
retention_df = pd.DataFrame(retention_results)
retention_df.to_csv(os.path.join(CONFIG['results_dir'], 'baseline_retention_analysis.csv'), index=False)

# Leading newlines are written as escapes; a raw line break inside a
# single-quoted literal is a syntax error.
print("\nRetention Analysis Results:")
print(retention_df[['k', 'mean_accuracy', 'std_accuracy']])

print("\nResults saved to: results/baseline_retention_analysis.csv")

In [None]:
# Prepare detailed results
# Flatten the nested per-model summaries into one row per (model, fold).
detailed_results = [
    {
        'model': result['model'],
        'fold': fold_result['fold'],
        'accuracy': fold_result['accuracy'],
    }
    for result in all_results
    for fold_result in result['fold_results']
]

detailed_df = pd.DataFrame(detailed_results)
detailed_path = os.path.join(CONFIG['results_dir'], 'baseline_methods_results.csv')
detailed_df.to_csv(detailed_path, index=False)

# Prepare summary
# Rank the models by mean accuracy, best first, with the rank as first column.
summary_df = (
    pd.DataFrame(all_results)[['model', 'mean_accuracy', 'std_accuracy']]
    .sort_values('mean_accuracy', ascending=False)
    .reset_index(drop=True)
)
summary_df['rank'] = range(1, len(summary_df) + 1)
summary_df = summary_df[['rank', 'model', 'mean_accuracy', 'std_accuracy']]
summary_path = os.path.join(CONFIG['results_dir'], 'baseline_methods_summary.csv')
summary_df.to_csv(summary_path, index=False)

print("\nResults saved to:")
print("  - results/baseline_methods_results.csv")
print("  - results/baseline_methods_summary.csv")

print("\nBaseline Methods Summary:")
print(summary_df.to_string(index=False))

banner = '=' * 60
print(f"\n{banner}")
print("PIPELINE 1 COMPLETE!")
print(f"{banner}")