# Pipeline 2: EEG-ARNN - ULTRA LIGHTWEIGHT

**ULTRA MEMORY OPTIMIZED:** Reduced model size + no model checkpoint files written to disk

## Models
1. **Baseline-EEG-ARNN** - Lightweight version
2. **Adaptive-Gating-EEG-ARNN** - Lightweight with gating

## Memory Optimizations
- **hidden_dim: 64** (halved from 128 to reduce parameter and activation memory)
- **batch_size: 64** (matching Pipeline 1)
- **NO model checkpoints** saved
- Aggressive memory cleanup
- Gradient checkpointing is **not** used (the models do not call `torch.utils.checkpoint`)

## Configuration
- **Dataset:** `/kaggle/input/eeg-preprocessed-data/derived`
- **Epochs:** 30 (no early stopping; the weights with the best validation accuracy are restored after training)
- **Cross-validation:** 2-fold
- **Hidden dim:** 64 (lightweight)

## Expected Runtime: ~4-5 hours (faster due to smaller model)

## 1. Force GPU Cleanup

In [None]:
# FORCE CLEANUP FIRST
import gc
import torch

# Free unreachable Python objects first so any CUDA tensors they hold are released.
print("Forcing GPU cleanup...")
gc.collect()

if not torch.cuda.is_available():
    print("No GPU available")
else:
    # Return cached blocks to the driver, wait for pending kernels, and
    # restart peak-memory tracking so later peak readings start from here.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    bytes_per_gb = 1024 ** 3
    allocated = torch.cuda.memory_allocated() / bytes_per_gb
    reserved = torch.cuda.memory_reserved() / bytes_per_gb
    print(f"GPU Memory: Allocated={allocated:.2f}GB, Reserved={reserved:.2f}GB")
    print("GPU CLEANED!")

## 2. Setup

In [None]:
import os
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import mne
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix
)
import pickle
from copy import deepcopy
import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): this filter is global, so it also hides deprecation and
# runtime warnings from torch, sklearn and mne; consider narrowing it.
warnings.filterwarnings('ignore')
# Only report MNE messages at ERROR level or above.
mne.set_log_level('ERROR')

print("Imports successful!")

In [None]:
# ULTRA LIGHTWEIGHT Configuration
# Single source of truth for paths, hyper-parameters and data shape. The models
# are constructed from 'n_channels' / 'n_timepoints', so these must agree with
# the arrays returned by the data loader.
CONFIG = {
    'data_path': '/kaggle/input/eeg-preprocessed-data/derived',
    'results_dir': './results',

    'n_folds': 2,
    'random_seed': 42,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',

    # Training hyperparameters
    'batch_size': 64,
    'epochs': 30,
    'learning_rate': 0.002,
    'weight_decay': 1e-4,
    'scheduler_patience': 3,   # ReduceLROnPlateau patience (epochs) on validation loss
    'scheduler_factor': 0.5,
    'min_lr': 1e-6,

    # Data parameters
    'n_channels': 64,
    'n_classes': 2,
    'sfreq': 128,
    'tmin': 0.0,
    'tmax': 4.0,
    'hidden_dim': 64,  # reduced from 128 (the original, heavier setting)
    'mi_runs': [7, 8, 11, 12],

    # Gating parameters
    'gating': {
        'gate_init': 0.9,    # initial sigmoid output targeted for the channel gates
        'l1_lambda': 1e-3,   # weight of the gate penalty added to the training loss
    },

    # Channel selection k-values
    'k_values': [10, 15, 20, 25, 30],
}

# Samples per epoch window, inclusive of both tmin and tmax (4.0 s at 128 Hz -> 513).
# Derived rather than hard-coded so that changing the window or sampling rate
# keeps the model's expected input length consistent.
CONFIG['n_timepoints'] = int(round((CONFIG['tmax'] - CONFIG['tmin']) * CONFIG['sfreq'])) + 1

os.makedirs(CONFIG['results_dir'], exist_ok=True)

# Seed NumPy, the torch CPU generator and every visible CUDA device.
np.random.seed(CONFIG['random_seed'])
torch.manual_seed(CONFIG['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG['random_seed'])

_REFERENCE_HIDDEN_DIM = 128  # hidden size of the original (non-lightweight) pipeline

print(f"Device: {CONFIG['device']}")
print(f"Hidden dim: {CONFIG['hidden_dim']} "
      f"(LIGHTWEIGHT - {CONFIG['hidden_dim'] / _REFERENCE_HIDDEN_DIM:.0%} of original)")
print(f"Batch size: {CONFIG['batch_size']}")
print(f"Epochs: {CONFIG['epochs']}")

In [None]:
def cleanup():
    """Run the Python garbage collector and, on CUDA, drop cached blocks and sync."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()


def print_mem():
    """Print currently allocated CUDA memory in GB; does nothing without a GPU."""
    if not torch.cuda.is_available():
        return
    used_gb = torch.cuda.memory_allocated() / 1024**3
    print(f"GPU: {used_gb:.2f}GB")


print("Memory utils ready")

## 3. Data Loading

In [None]:
def load_physionet_data(data_path):
    """Load motor-imagery epochs from per-subject preprocessed FIF files.

    Searches ``data_path`` (or its ``preprocessed`` sub-directory when present)
    for sub-directories whose name starts with ``S``/``s``, and in each for the
    files ``<subject>R<run:02d>_preproc_raw.fif`` for every run listed in
    ``CONFIG['mi_runs']``. Each run is epoched on the ``T1``/``T2`` annotations
    over ``[CONFIG['tmin'], CONFIG['tmax']]`` using EEG channels only.
    Runs that are missing, have no EEG channels or no events are skipped;
    runs that fail to load are skipped with a printed message.

    Args:
        data_path: Root directory of the derived dataset.

    Returns:
        ``(X, y, subjects)``: ``X`` stacks ``mne.Epochs.get_data()`` of all
        runs; ``y`` holds labels (``T1`` -> 0, ``T2`` -> 1); ``subjects`` holds
        the numeric subject id per trial (-1 if the directory name has no
        numeric suffix).

    Raises:
        FileNotFoundError: ``data_path`` is not a directory.
        ValueError: no subject directories exist, or no trial could be loaded.
    """
    data_root = os.path.abspath(data_path)
    if not os.path.isdir(data_root):
        raise FileNotFoundError(f"Data path not found: {data_root}")

    # Prefer the 'preprocessed' sub-folder when the dataset ships one.
    preprocessed_dir = os.path.join(data_root, 'preprocessed')
    if os.path.isdir(preprocessed_dir):
        data_root = preprocessed_dir

    subject_dirs = [d for d in sorted(os.listdir(data_root))
                    if os.path.isdir(os.path.join(data_root, d)) and d.upper().startswith('S')]
    if not subject_dirs:
        raise ValueError(f"No subject directories (S*) found in {data_root}")

    tmin, tmax = CONFIG['tmin'], CONFIG['tmax']
    mi_runs = CONFIG['mi_runs']
    event_id = {'T1': 1, 'T2': 2}
    label_map = {1: 0, 2: 1}

    all_X, all_y, all_subjects = [], [], []
    n_failed = 0

    for subject_dir in subject_dirs:
        # 'S001' -> 1. A non-numeric suffix (e.g. 'Scripts') maps to -1 instead
        # of crashing the whole load.
        try:
            subject_num = int(subject_dir[1:])
        except ValueError:
            subject_num = -1
        subject_path = os.path.join(data_root, subject_dir)

        for run_id in mi_runs:
            run_path = os.path.join(subject_path, f"{subject_dir}R{run_id:02d}_preproc_raw.fif")
            if not os.path.exists(run_path):
                continue

            # Best-effort loading: a broken file is reported and skipped, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            try:
                raw = mne.io.read_raw_fif(run_path, preload=True, verbose=False)
                picks = mne.pick_types(raw.info, eeg=True, meg=False, stim=False, eog=False)
                if len(picks) == 0:
                    continue

                events, _ = mne.events_from_annotations(raw, event_id=event_id, verbose=False)
                if len(events) == 0:
                    continue

                epochs = mne.Epochs(raw, events, event_id=event_id, tmin=tmin, tmax=tmax,
                                    baseline=None, preload=True, picks=picks, verbose=False)
                data = epochs.get_data()
            except Exception as exc:
                n_failed += 1
                print(f"  Skipping {run_path}: {type(exc).__name__}: {exc}")
                continue

            # epochs.events only lists the epochs MNE kept, so it aligns with data.
            labels = np.array([label_map.get(code, -1) for code in epochs.events[:, 2]])
            valid = labels >= 0

            if np.any(valid):
                all_X.append(data[valid])
                all_y.append(labels[valid])
                all_subjects.append(np.full(np.sum(valid), subject_num))

    if not all_X:
        raise ValueError(
            f"No trials loaded from {data_root} ({n_failed} run file(s) failed to load)")

    X = np.concatenate(all_X, axis=0)
    y = np.concatenate(all_y, axis=0)
    subjects = np.concatenate(all_subjects, axis=0)

    print(f"Loaded: {len(X)} trials, {X.shape}, Labels: {np.bincount(y)}")
    return X, y, subjects


class EEGDataset(Dataset):
    """In-memory dataset of EEG trials for use with ``DataLoader``.

    Args:
        X: Array-like of trials; copied into a float32 tensor.
        y: Array-like of integer labels, one per trial; copied into an int64 tensor.

    Raises:
        ValueError: ``X`` and ``y`` contain a different number of trials.
    """

    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
        # Without this check a mismatch only surfaces mid-epoch as an
        # IndexError (or silently ignores surplus labels).
        if len(self.X) != len(self.y):
            raise ValueError(f"X has {len(self.X)} trials but y has {len(self.y)} labels")

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

## 4. Lightweight Models

In [None]:
class GraphConvLayer(nn.Module):
    """Graph convolution across EEG channels with a learnable adjacency.

    Input and output are (B, H, C, T) tensors. The raw parameter ``A`` is
    squashed by a sigmoid, symmetrised, given self-loops and normalised as
    D^-1/2 (A + I) D^-1/2 (D = row sums of A + I). The result mixes channels
    at every time step, then the shared linear map ``theta`` transforms the
    hidden features, followed by BatchNorm and ELU.
    """

    def __init__(self, num_channels, hidden_dim):
        super().__init__()
        self.A = nn.Parameter(torch.randn(num_channels, num_channels) * 0.01)
        self.theta = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.bn = nn.BatchNorm2d(hidden_dim)

    def _symmetric_adjacency(self):
        """Return sigmoid(A) averaged with its transpose (entries in (0, 1))."""
        squashed = torch.sigmoid(self.A)
        return 0.5 * (squashed + squashed.t())

    def forward(self, x):
        batch, hidden, chans, steps = x.shape

        with_loops = self._symmetric_adjacency() + torch.eye(chans, device=self.A.device)
        # Symmetric degree normalisation; clamp_min keeps the inverse sqrt finite.
        inv_sqrt_deg = torch.pow(with_loops.sum(1).clamp_min(1e-6), -0.5)
        d_mat = torch.diag(inv_sqrt_deg)
        a_norm = d_mat @ with_loops @ d_mat

        # (B, H, C, T) -> (B*T, C, H): each time step becomes one graph signal.
        signals = x.permute(0, 3, 2, 1).contiguous().view(batch * steps, chans, hidden)
        mixed = self.theta(a_norm @ signals)
        # (B*T, C, H) -> (B, H, C, T)
        mixed = mixed.view(batch, steps, chans, hidden).permute(0, 3, 2, 1)
        return F.elu(self.bn(mixed))

    def get_adjacency(self):
        """Return the symmetrised adjacency (without self-loops) as a NumPy array."""
        with torch.no_grad():
            return self._symmetric_adjacency().cpu().numpy()


class TemporalConv(nn.Module):
    """Temporal block: Conv2d along time -> BatchNorm -> ELU -> optional AvgPool.

    The (1, kernel_size) kernel only mixes along the last (time) axis. With an
    even kernel_size, the kernel_size // 2 padding makes the time axis one sample
    longer. The optional (1, 2) average pool then halves it (rounding down).
    """

    def __init__(self, in_channels, out_channels, kernel_size=16, pool=True):
        super().__init__()
        half = kernel_size // 2
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=(1, kernel_size), padding=(0, half), bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.pool_layer = nn.AvgPool2d((1, 2)) if pool else None

    def forward(self, x):
        activated = F.elu(self.bn(self.conv(x)))
        return activated if self.pool_layer is None else self.pool_layer(activated)


class BaselineEEGARNN(nn.Module):
    """EEG-ARNN baseline: three (temporal conv -> graph conv) stages + MLP head.

    Accepts input of shape (B, C, T) or (B, 1, C, T) and returns class logits
    of shape (B, n_classes).
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=64):
        super().__init__()
        self.n_channels = n_channels
        self.hidden_dim = hidden_dim
        self.use_gate_regularizer = False  # read by train_model; the gated subclass sets True

        self.t1 = TemporalConv(1, hidden_dim, 16, pool=False)
        self.g1 = GraphConvLayer(n_channels, hidden_dim)
        self.t2 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g2 = GraphConvLayer(n_channels, hidden_dim)
        self.t3 = TemporalConv(hidden_dim, hidden_dim, 16, pool=True)
        self.g3 = GraphConvLayer(n_channels, hidden_dim)

        self.feature_dim = self._infer_feature_dim(n_channels, n_timepoints)

        self.fc1 = nn.Linear(self.feature_dim, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, n_classes)

    def _infer_feature_dim(self, n_channels, n_timepoints):
        """Return the flattened feature size using a dummy forward pass.

        The probe runs in eval mode. In train mode, BatchNorm would fold the
        all-zero dummy batch into its running statistics (pulling running_var
        towards 0 and incrementing num_batches_tracked) before any real data
        is seen.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, n_channels, n_timepoints)
                feat = self._forward_features(self._prepare_input(dummy))
        finally:
            self.train(was_training)
        return torch.flatten(feat, 1).size(1)

    def _prepare_input(self, x):
        """Add the singleton feature axis: (B, C, T) -> (B, 1, C, T); 4-D input is unchanged."""
        if x.dim() == 3:
            x = x.unsqueeze(1)
        return x

    def _forward_features(self, x):
        """Apply the three temporal/graph stages; output is (B, hidden_dim, C, T')."""
        x = self.g1(self.t1(x))
        x = self.g2(self.t2(x))
        x = self.g3(self.t3(x))
        return x

    def forward(self, x):
        features = self._forward_features(self._prepare_input(x))
        # flatten (unlike .view) also works if the feature map is non-contiguous.
        x = torch.flatten(features, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def get_final_adjacency(self):
        """Return the last graph layer's symmetrised adjacency as a NumPy array."""
        return self.g3.get_adjacency()

    def get_channel_importance_edge(self):
        """Score each channel by the summed absolute edge weights of its adjacency row."""
        adjacency = self.get_final_adjacency()
        return np.sum(np.abs(adjacency), axis=1)


class AdaptiveGatingEEGARNN(BaselineEEGARNN):
    """Baseline EEG-ARNN with an input-dependent gate on every EEG channel.

    Per trial, a small MLP maps each channel's mean and standard deviation over
    time to a sigmoid gate. The input is scaled channel-wise by these gates
    before the baseline feature extractor. The gates from the latest forward
    pass are kept for the L1 penalty and for channel ranking.
    """

    def __init__(self, n_channels=64, n_classes=2, n_timepoints=513, hidden_dim=64, gate_init=0.9):
        super().__init__(n_channels, n_classes, n_timepoints, hidden_dim)
        self.use_gate_regularizer = True

        stats_width = 2 * n_channels  # [per-channel means | per-channel stds]
        self.gate_net = nn.Sequential(
            nn.Linear(stats_width, n_channels),
            nn.ReLU(),
            nn.Linear(n_channels, n_channels),
            nn.Sigmoid(),
        )

        # Output bias = logit(gate_init), so sigmoid(bias) == gate_init; clipping
        # keeps the logit finite for gate_init at or beyond 0/1.
        target = float(np.clip(gate_init, 1e-3, 1 - 1e-3))
        output_linear = self.gate_net[2]
        with torch.no_grad():
            output_linear.bias.fill_(math.log(target / (1.0 - target)))

        self.latest_gate_values = None   # detached gates of the last forward pass
        self.gate_penalty_tensor = None  # same gates with autograd history (for the L1 term)

    def compute_gates(self, x):
        """Return (B, C) gates from per-channel mean/std over time of a (B, 1, C, T) input."""
        signal = x.squeeze(1)
        summary = torch.cat((signal.mean(dim=2), signal.std(dim=2)), dim=1)
        return self.gate_net(summary)

    def forward(self, x):
        prepared = self._prepare_input(x)
        gates = self.compute_gates(prepared)
        self.gate_penalty_tensor = gates
        self.latest_gate_values = gates.detach()

        n_batch, n_chan = gates.shape
        features = self._forward_features(prepared * gates.view(n_batch, 1, n_chan, 1))
        flat = features.view(features.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

    def get_channel_importance_gate(self):
        """Return the batch-mean gates of the last forward pass, or None if none has run."""
        gates = self.latest_gate_values
        if gates is None:
            return None
        return gates.mean(dim=0).cpu().numpy()


print("Lightweight models defined!")

## 5. Training

In [None]:
def calc_metrics(model, loader, device):
    """Evaluate a binary classifier on ``loader`` and return a metrics dict.

    Keys: accuracy, precision, recall, f1_score, auc_roc (0.0 when the labels
    contain only one class), specificity and sensitivity. Leaves the model in
    eval mode.
    """
    model.eval()
    preds, labels, probs = [], [], []
    with torch.no_grad():
        for X, y in loader:
            out = model(X.to(device))
            preds.extend(out.argmax(dim=1).cpu().numpy())
            labels.extend(y.numpy())
            probs.extend(F.softmax(out, dim=1).cpu().numpy())

    preds, labels, probs = np.array(preds), np.array(labels), np.array(probs)
    metrics = {
        'accuracy': accuracy_score(labels, preds),
        'precision': precision_score(labels, preds, average='binary', zero_division=0),
        'recall': recall_score(labels, preds, average='binary', zero_division=0),
        'f1_score': f1_score(labels, preds, average='binary', zero_division=0),
        'auc_roc': roc_auc_score(labels, probs[:, 1]) if len(np.unique(labels)) == 2 else 0.0,
    }
    # labels=[0, 1] forces a 2x2 matrix. Without it, a fold whose labels and
    # predictions contain a single class yields a 1x1 matrix, and both rates
    # were reported as 0 even when every prediction was correct.
    tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()
    metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    metrics['sensitivity'] = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    return metrics


def _cpu_state_dict(model):
    """Snapshot ``model.state_dict()`` as CPU tensor copies, keeping GPU memory free."""
    state = model.state_dict()
    snapshot = type(state)((k, v.detach().to('cpu', copy=True)) for k, v in state.items())
    if hasattr(state, '_metadata'):
        # Preserve version metadata used by load_state_dict (as deepcopy would).
        snapshot._metadata = deepcopy(state._metadata)
    return snapshot


def _evaluate(model, loader, criterion, device):
    """Return (per-sample mean loss, accuracy) of ``model`` over ``loader`` in one pass.

    Raises:
        ValueError: ``loader`` yields no samples.
    """
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            out = model(X)
            n = y.size(0)
            # criterion averages over the batch; re-weight by size so a short
            # final batch does not count as much as a full one.
            total_loss += criterion(out, y).item() * n
            correct += (out.argmax(dim=1) == y).sum().item()
            seen += n
    if seen == 0:
        raise ValueError("validation loader yielded no samples")
    return total_loss / seen, correct / seen


def train_model(model, train_loader, val_loader, config, name=''):
    """Train ``model`` and restore the weights with the best validation accuracy.

    Uses Adam plus ReduceLROnPlateau on the validation loss for
    ``config['epochs']`` epochs (no early stopping). When the model sets
    ``use_gate_regularizer``, ``config['gating']['l1_lambda']`` times the mean
    absolute value of ``model.gate_penalty_tensor`` is added to the loss.

    Args:
        model: Module mapping an input batch to class logits.
        train_loader, val_loader: Iterables of ``(X, y)`` batches.
        config: Dict with 'device', 'learning_rate', 'weight_decay',
            'scheduler_factor', 'scheduler_patience', 'min_lr', 'epochs' and
            'gating' -> 'l1_lambda'.
        name: Tag used in progress messages.

    Returns:
        ``(best_state, best_acc)``: a CPU copy of the best state dict and its
        validation accuracy. The model (moved to ``config['device']``) holds
        ``best_state`` on return.

    Raises:
        ValueError: ``val_loader`` yields no samples.
    """
    device = config['device']
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'],
                           weight_decay=config['weight_decay'])
    # 'verbose' is not passed: it is deprecated in recent PyTorch and was off anyway.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=config['scheduler_factor'],
        patience=config['scheduler_patience'], min_lr=config['min_lr'])
    l1_lambda = config['gating']['l1_lambda'] if getattr(model, 'use_gate_regularizer', False) else 0.0

    best_state = _cpu_state_dict(model)
    best_acc = 0.0

    print(f"[{name}] Training {config['epochs']} epochs")
    for epoch in range(config['epochs']):
        model.train()
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            penalty = getattr(model, 'gate_penalty_tensor', None)
            if l1_lambda > 0 and penalty is not None:
                loss = loss + l1_lambda * penalty.abs().mean()
            loss.backward()
            optimizer.step()

        # Single validation pass gives both the scheduler signal and the accuracy.
        val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
        scheduler.step(val_loss)

        if val_acc > best_acc:
            best_state = _cpu_state_dict(model)
            best_acc = val_acc

        if (epoch + 1) % 10 == 0:
            print(f"  Epoch {epoch+1}: ValAcc={val_acc:.4f}, Best={best_acc:.4f}")

    model.load_state_dict(best_state)
    return best_state, best_acc

print("Training utils ready")

## 6. Channel Selection Utils

In [None]:
def get_aggregation_importance(model, loader, device):
    """Score channels by their mean absolute feature activation.

    For every batch, the output of ``model._forward_features`` is averaged in
    absolute value over dims 1 and 3, giving one score per trial and channel.
    Those scores are then averaged over all trials. Returns zeros of length
    ``model.n_channels`` when ``loader`` yields nothing.
    """
    model.eval()
    per_trial = []
    with torch.no_grad():
        for batch, _ in loader:
            feats = model._forward_features(model._prepare_input(batch.to(device)))
            per_trial.append(feats.abs().mean(dim=(1, 3)).cpu())
    if len(per_trial) == 0:
        return np.zeros(model.n_channels)
    return torch.cat(per_trial, dim=0).mean(dim=0).numpy()


def get_gate_importance(model, loader, device):
    """Score channels by their average gate value over every trial in ``loader``.

    Runs ``model`` on each batch and collects ``model.latest_gate_values``
    (one row per trial). Returns the uniform vector 1 / n_channels when no
    gates were collected.
    """
    model.eval()
    collected = []
    with torch.no_grad():
        for batch, _ in loader:
            model(batch.to(device))
            latest = model.latest_gate_values
            if latest is not None:
                collected.append(latest.cpu())
    if not collected:
        n = model.n_channels
        return np.ones(n) / n
    return torch.cat(collected, dim=0).mean(dim=0).numpy()


def select_top_k(scores, k):
    """Return the indices of the ``k`` highest ``scores``, sorted ascending.

    A non-positive ``k`` selects nothing. A ``k`` larger than ``len(scores)``
    selects every index.
    """
    if k <= 0:
        # argsort(...)[-0:] is the whole array, i.e. it would select ALL channels.
        return []
    return sorted(np.argsort(scores)[-k:])


def apply_selection(X, channels):
    """Keep only ``channels`` along axis 1 (the channel axis) of the 3-D array ``X``."""
    return X[:, channels, :]

print("Channel selection ready")

## 7. Load Data

In [None]:
print("Loading data...")
X, y, subjects = load_physionet_data(CONFIG['data_path'])

# The initial models are built from CONFIG['n_channels'] / CONFIG['n_timepoints'];
# fail fast with a clear message instead of a shape error inside the model.
expected_shape = (CONFIG['n_channels'], CONFIG['n_timepoints'])
if tuple(X.shape[1:]) != expected_shape:
    raise ValueError(
        f"Loaded trials have per-trial shape {tuple(X.shape[1:])} but CONFIG expects "
        f"{expected_shape}; adjust CONFIG['n_channels'] or sfreq/tmin/tmax."
    )

cleanup()
print_mem()
print("Data ready!")

## 8. Train Initial Models

In [None]:
skf = StratifiedKFold(n_splits=CONFIG['n_folds'], shuffle=True, random_state=CONFIG['random_seed'])
# NOTE(review): folds are split per trial, so trials from the same subject can
# appear in both training and validation ('subjects' is unused here). Consider
# a subject-grouped split if cross-subject accuracy is what should be reported.

models_to_train = [
    {'name': 'Baseline-EEG-ARNN', 'class': BaselineEEGARNN},
    {'name': 'Adaptive-Gating-EEG-ARNN', 'class': AdaptiveGatingEEGARNN},
]
metric_names = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity', 'sensitivity']

importance_store = {}  # model name -> fold index -> {method: per-channel scores}
all_results = {}       # model name -> list of per-fold metric dicts

print("\nTRAINING INITIAL MODELS\n")

for model_info in models_to_train:
    model_name = model_info['name']
    model_class = model_info['class']

    print(f"\n{'='*60}")
    print(f"{model_name}")
    print(f"{'='*60}")

    fold_results = []
    importance_store[model_name] = {}

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        print(f"\nFold {fold + 1}/{CONFIG['n_folds']}")
        cleanup()

        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        train_ds = EEGDataset(X_train, y_train)
        val_ds = EEGDataset(X_val, y_val)
        train_loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=CONFIG['batch_size'], shuffle=False)

        model_kwargs = dict(n_channels=CONFIG['n_channels'], n_classes=CONFIG['n_classes'],
                            n_timepoints=CONFIG['n_timepoints'], hidden_dim=CONFIG['hidden_dim'])
        if model_class is AdaptiveGatingEEGARNN:
            model_kwargs['gate_init'] = CONFIG['gating']['gate_init']
        model = model_class(**model_kwargs)

        # train_model moves the model (in place) to CONFIG['device'] and loads
        # best_state before returning, so no extra reload is needed here.
        best_state, _ = train_model(model, train_loader, val_loader, CONFIG, f"{model_name}-F{fold+1}")
        metrics = calc_metrics(model, val_loader, CONFIG['device'])

        print(f"Results: Acc={metrics['accuracy']:.4f}, F1={metrics['f1_score']:.4f}")

        # Channel-importance scores from this fold's training data; reused by the
        # channel-selection experiments for the same fold.
        imp = {
            'edge': model.get_channel_importance_edge(),
            'aggregation': get_aggregation_importance(model, train_loader, CONFIG['device']),
        }
        if hasattr(model, 'get_channel_importance_gate'):
            imp['gate'] = get_gate_importance(model, train_loader, CONFIG['device'])
        importance_store[model_name][fold] = imp

        fold_results.append({'fold': fold + 1, **{m: metrics[m] for m in metric_names}})

        del model, best_state, train_loader, val_loader, train_ds, val_ds
        cleanup()

    all_results[model_name] = fold_results
    df = pd.DataFrame(fold_results)
    print(f"\nSummary: Acc={df['accuracy'].mean():.4f}±{df['accuracy'].std():.4f}")

print("\nINITIAL TRAINING COMPLETE!")
print_mem()

## 9. Save Initial Results

In [None]:
metric_names = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity', 'sensitivity']

summary_data = []
for model_name, fold_results in all_results.items():
    df = pd.DataFrame(fold_results)
    df['model'] = model_name
    filename = model_name.lower().replace('-', '_').replace(' ', '_')
    csv_name = f"eegarnn_{filename}_results.csv"
    df.to_csv(os.path.join(CONFIG['results_dir'], csv_name), index=False)
    # Report the name actually written (it previously omitted the 'eegarnn_' prefix).
    print(f"Saved: {csv_name}")

    summary = {'model': model_name}
    for metric in metric_names:
        summary[f'mean_{metric}'] = df[metric].mean()
        summary[f'std_{metric}'] = df[metric].std()  # pandas default: sample std (ddof=1)
    summary_data.append(summary)

summary_name = "eegarnn_initial_summary.csv"
pd.DataFrame(summary_data).to_csv(os.path.join(CONFIG['results_dir'], summary_name), index=False)
print(f"Saved: {summary_name}")

## 10. Channel Selection

In [None]:
cs_experiments = [
    {'model': 'Baseline-EEG-ARNN', 'methods': ['edge', 'aggregation']},
    {'model': 'Adaptive-Gating-EEG-ARNN', 'methods': ['edge', 'aggregation', 'gate']},
]
metric_names = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'specificity', 'sensitivity']

print("\nCHANNEL SELECTION\n")

cs_results = []

for exp in cs_experiments:
    model_name = exp['model']
    methods = exp['methods']
    model_class = BaselineEEGARNN if 'Baseline' in model_name else AdaptiveGatingEEGARNN

    print(f"\n{model_name}")

    for method in methods:
        print(f"  {method.upper()}:", end='')

        for k in CONFIG['k_values']:
            fold_metrics = []

            # skf is seeded, so these folds match the ones the importance
            # scores were computed on; a fresh k-channel model is trained per fold.
            for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
                cleanup()

                scores = importance_store[model_name][fold][method]
                channels = select_top_k(scores, k)
                X_train_sel = apply_selection(X[train_idx], channels)
                X_val_sel = apply_selection(X[val_idx], channels)
                y_train, y_val = y[train_idx], y[val_idx]

                model_kwargs = dict(n_channels=k, n_classes=CONFIG['n_classes'],
                                    n_timepoints=CONFIG['n_timepoints'], hidden_dim=CONFIG['hidden_dim'])
                if model_class is AdaptiveGatingEEGARNN:
                    model_kwargs['gate_init'] = CONFIG['gating']['gate_init']
                model = model_class(**model_kwargs)

                train_ds = EEGDataset(X_train_sel, y_train)
                val_ds = EEGDataset(X_val_sel, y_val)
                train_loader = DataLoader(train_ds, batch_size=CONFIG['batch_size'], shuffle=True)
                val_loader = DataLoader(val_ds, batch_size=CONFIG['batch_size'], shuffle=False)

                best_state, _ = train_model(model, train_loader, val_loader, CONFIG, f"{method}-k{k}-F{fold+1}")
                model.load_state_dict(best_state)
                model = model.to(CONFIG['device'])
                fold_metrics.append(calc_metrics(model, val_loader, CONFIG['device']))

                del model, best_state, train_loader, val_loader, train_ds, val_ds
                cleanup()

            mean_metrics = {}
            for m in metric_names:
                vals = [fm[m] for fm in fold_metrics]
                mean_metrics[f'mean_{m}'] = np.mean(vals)
                # ddof=1 (sample std) to match pandas .std() used for the initial summary.
                mean_metrics[f'std_{m}'] = np.std(vals, ddof=1)

            print(f" k={k}:{mean_metrics['mean_accuracy']:.3f}", end='')

            result = {'model': model_name, 'method': method, 'k': k}
            result.update(mean_metrics)
            cs_results.append(result)

        print()

print("\nCHANNEL SELECTION COMPLETE!")
print_mem()

## 11. Save Channel Selection Results

In [None]:
cs_df = pd.DataFrame(cs_results)
cs_df.to_csv(os.path.join(CONFIG['results_dir'], "channel_selection_results.csv"), index=False)
print("Saved: channel_selection_results.csv")

print("\nBest results:")
if cs_df.empty:
    # Without rows there is no 'model' column and idxmax would raise.
    print("  (no channel-selection results)")
else:
    for model_name in ['Baseline-EEG-ARNN', 'Adaptive-Gating-EEG-ARNN']:
        accs = cs_df.loc[cs_df['model'] == model_name, 'mean_accuracy'].dropna()
        if accs.empty:
            print(f"{model_name}: no results")
            continue
        best = cs_df.loc[accs.idxmax()]
        print(f"{model_name}: {best['method'].upper()} k={int(best['k'])} Acc={best['mean_accuracy']:.4f}")

print("\nPIPELINE 2 COMPLETE!")