# **Assignment Roadmap**

This assignment asks you to build a complete Brain-Computer Interface (BCI) pipeline. Your goal is to take raw, noisy electrical brain signals and turn them into a clear Yes/No decision: Is this the character the user wants?

**Theres not much theory to learn other than implementation, you have to learn this by doing**

## AI Usage Policy for This Assignment

You're welcome to use AI for this assignment. Given the complexity of EEG signal processing and machine learning,
We don't expect you to know every implementation detail from scratch and neither does any recuiter or any professor.


Use AI to:

    Debug errors and troubleshoot issues

    Understand concepts and explore different approaches

What matters:

    You understand your code and can explain how it works

    You learn from the process, not just copy-paste

### **1: Cleaning the Signal (Preprocessing)**

The Goal: Raw EEG data is full of "garbage" frequencies like muscle movement and electrical hum. You need to filter the data to keep only the brain waves relevant to the P300 response (typically 0.1Hz – 20Hz).

You have already done this in the previous assignment but this one is a more standard procedure.

Common Pitfalls:

    Filter Lag: Standard filters can delay the signal, meaning the brain response looks like it happened later than it actually did. To prevent this, use zero-phase filtering (e.g., scipy.signal.filtfilt) instead of standard filtering (lfilter).

    Aliasing: You are asked to downsample the data from 240Hz to 60Hz to make it faster to process. Do not simply slice the array (e.g., data[::4]) without filtering first. If you do, high-frequency noise will "fold over" into your low frequencies and corrupt the data. Always filter before downsampling.

### **2: Epoch Extraction**

The Goal: You need to convert the continuous stream of data into specific "events" or "epochs."

The Concept: A P300 response happens roughly 300ms after a flash. Your code needs to identify every moment a flash occurs (stimulus_onset), look forward in time (e.g., 800ms), and "snip" that window of data out.

Visualizing the Data Structure:

    Input: A continuous 2D matrix (Total_Time_Points, 64_Channels)

    Output: A 3D block of events (Number_of_Flashes, Time_Points_Per_Window, 64_Channels)

Common Pitfall:

    Indexing Errors: This dataset may originate from MATLAB (which uses 1-based indexing), while Python uses 0-based indexing. If your index calculation is off by even one sample, your window will shift, and the machine learning model will be training on random noise rather than the brain response. Double-check your start and end indices.

### **3: Making Data "Model-Ready" (Feature Engineering)**

The Goal: Standard Machine Learning models (like SVM or LDA) cannot understand 3D arrays. They generally require a 2D matrix (like an Excel sheet). The Task:

    Time Domain: You will need to "flatten" the epochs. If an epoch is 60 time points × 64 channels, it becomes a single flat row of 3,840 numbers.

    PCA/CSP: These are compression techniques. The goal is to reduce those 3,840 numbers down to perhaps 20 numbers that capture the most important information.

Common Pitfall:

    Data Leakage: When using PCA or CSP, you must be careful not to "cheat." You should .fit() your reducer only on the training data, and then .transform() the test data. If you fit on the combined dataset, your model "sees" the test answers ahead of time, leading to artificially high scores that won't work in the real world.

### **4: Classification**

The Goal: Feed your features into the ML models (LDA, SVM, etc.) provided in the skeleton code to classify "Target" vs. "Non-Target" flashes.

Common Pitfall:

    Class Imbalance: In a P300 speller, the letter the user wants (Target) only flashes 1 out of 6 times. The other 5 flashes are Non-Targets.

        If your model decides to simply guess "Non-Target" every single time, it will still achieve ~83% accuracy. This is a useless model.

        Do not rely solely on Accuracy. Check the F1-Score or the Confusion Matrix. A good model must be able to correctly identify the rare Target events, not just the frequent Non-Targets.

In [12]:
# The assignment is structured in a way that its modular so thats its easier to debug whats wrong

import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
from scipy.signal import butter, filtfilt, iirnotch
from scipy.linalg import eigh
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score, classification_report,
                             confusion_matrix)
import pickle
import json
import time
import warnings
warnings.filterwarnings('ignore')

In [13]:
################################################################################
# SECTION 1: DATA LOADING AND BASIC SETUP
################################################################################

# Character matrix (6x6) for P300 speller
CHAR_MATRIX = np.array([
    ['A', 'B', 'C', 'D', 'E', 'F'],
    ['G', 'H', 'I', 'J', 'K', 'L'],
    ['M', 'N', 'O', 'P', 'Q', 'R'],
    ['S', 'T', 'U', 'V', 'W', 'X'],
    ['Y', 'Z', '1', '2', '3', '4'],
    ['5', '6', '7', '8', '9', '_']
])

def load_data(file_path):
    """
    Load P300 BCI Competition III data
    Returns dictionary with signal, flashing, stimulus_code, stimulus_type, target_char
    """
    data = sio.loadmat(file_path)

    # Handle variable names which might vary slightly (Signal vs Signal(x))
    # Standardizing to keys used in BCI Comp III Dataset II
    result = {
        'signal': data.get('Signal'),
        'flashing': data.get('Flashing'),
        'stimulus_code': data.get('StimulusCode'),
    }

    # Training data has labels, test data doesn't
    if 'StimulusType' in data:
        result['stimulus_type'] = data.get('StimulusType')
    if 'TargetChar' in data:
        result['target_char'] = data.get('TargetChar')

    return result


def get_char_from_codes(row_code, col_code):
    """Convert row/column stimulus codes to character"""
    r_idx = -1
    c_idx = -1

    if 7 <= row_code <= 12:
        r_idx = row_code - 7
    elif 1 <= row_code <= 6:
        c_idx = row_code - 1

    if 7 <= col_code <= 12:
        r_idx = col_code - 7
    elif 1 <= col_code <= 6:
        c_idx = col_code - 1

    if r_idx != -1 and c_idx != -1:
        return CHAR_MATRIX[r_idx, c_idx]
    return None



In [14]:
################################################################################
# SECTION 2: EEG SIGNAL ACQUISITION & PREPROCESSING
################################################################################

def bandpass_filter(signal, lowcut=0.1, highcut=20, fs=240, order=5):
    """
    Apply band-pass filter to remove low-frequency drift and high-frequency noise
    Typical P300 band: 0.1-20 Hz
    """
    nyquist = 0.5 * fs
    low = lowcut / nyquist
    high = highcut / nyquist
    b, a = butter(order, [low, high], btype='band')
    # Use filtfilt for zero-phase filtering (no delay)
    filtered_signal = filtfilt(b, a, signal, axis=0)
    return filtered_signal

def notch_filter(signal, freq=50, fs=240, Q=30):
    """
    Remove powerline interference (50/60 Hz)
    """
    nyquist = 0.5 * fs
    w0 = freq / nyquist
    b, a = iirnotch(w0, Q)
    filtered_signal = filtfilt(b, a, signal, axis=0)
    return filtered_signal


def baseline_correction(epoch, baseline_samples=50):
    """
    Apply baseline correction by subtracting pre-stimulus baseline
    """
    baseline_mean = np.mean(epoch[:baseline_samples, :], axis=0)
    return epoch - baseline_mean

def downsample_signal(signal, original_fs=240, target_fs=60):
    """
    Downsample signal to reduce computational load
    240 Hz -> 60 Hz reduces data by 4x
    """
    factor = int(original_fs / target_fs)
    # Simple slicing is okay ONLY IF the signal was low-pass filtered beforehand
    # to avoid aliasing. Our pipeline does bandpass (max 20Hz) before this.
    return signal[::factor, :]

def artifact_rejection(signal, threshold=100):
    """
    Simple artifact rejection based on amplitude threshold
    More advanced: use ICA or wavelet denoising
    """
    # Identify channels or epochs that exceed threshold
    # For this assignment, we might just clip or return indices,
    # but here we'll just clip for simplicity/safety
    return np.clip(signal, -threshold, threshold)

def preprocess_pipeline(data, apply_bandpass=True, apply_notch=True,
                        apply_downsample=True, fs=240):
    """
    Complete preprocessing pipeline:
    1. Bandpass filtering (0.1-20 Hz)
    2. Notch filtering (50 Hz)
    3. Downsampling (240->60 Hz)
    """
    signal = data['signal'].astype(np.float32)

    # Ensure signal is 2D (time_points, channels) if it loaded as (1, time_points, channels)
    # This fixes the issue with filtering and downsampling operating on a dimension of size 1.
    if signal.ndim == 3 and signal.shape[0] == 1:
        signal = signal.squeeze(axis=0)
    elif signal.ndim != 2: # Expected (time_points, channels)
        print(f"Warning: Unexpected signal dimension {signal.ndim}. Expected 2D (Time, Channels) or 3D (1, Time, Channels). Proceeding assuming (Time, Channels).")


    if apply_notch:
        signal = notch_filter(signal, freq=50, fs=fs)

    if apply_bandpass:
        signal = bandpass_filter(signal, lowcut=0.1, highcut=20, fs=fs)

    if apply_downsample:
        signal = downsample_signal(signal, original_fs=fs, target_fs=60)
        # Update sampling rate for downstream
        fs = 60

    # Update data dictionary with processed signal
    processed_data = data.copy()
    processed_data['signal'] = signal

    # Also downsample the flashing and stimulus codes to match indices
    if apply_downsample:
        factor = 240 // 60
        # Ensure flashing, stimulus_code, and stimulus_type are 1D before slicing
        processed_data['flashing'] = data['flashing'].flatten()[::factor]
        processed_data['stimulus_code'] = data['stimulus_code'].flatten()[::factor]
        if 'stimulus_type' in data:
            processed_data['stimulus_type'] = data['stimulus_type'].flatten()[::factor]

    return processed_data

def extract_epochs(data, epoch_length_ms=800, fs=60):
    """
    Extract epochs around stimulus onset
    - Event tagging: Use flashing signal to detect stimulus onset
    - Stimulus alignment: Extract fixed-length windows after each flash
    - Epoch extraction: Collect all stimulus-locked epochs

    Returns: Dictionary with epochs, labels, codes, character indices
    """
    signal = data['signal']
    flashing = data['flashing'].flatten()
    codes = data['stimulus_code'].flatten()

    has_labels = 'stimulus_type' in data
    if has_labels:
        labels_in = data['stimulus_type'].flatten()

    samples_per_epoch = int((epoch_length_ms / 1000) * fs)

    epochs = []
    labels = []
    codes_list = []
    char_idx_list = [] # Keeps track of which character/block this epoch belongs to

    # Find stimulus onsets (flashing goes from 0 to 1)
    # np.diff gives change; we want where it goes 0->1
    changes = np.diff(flashing)
    onset_indices = np.where(changes == 1)[0] + 1

    for idx in onset_indices:
        # Check boundary
        if idx + samples_per_epoch <= len(signal):
            epoch = signal[idx : idx + samples_per_epoch, :]

            # Apply baseline correction
            epoch = baseline_correction(epoch, baseline_samples=int(0.1*fs)) # 100ms baseline

            epochs.append(epoch)
            codes_list.append(codes[idx])

            # Dummy char index for now, in real scenario we track block index
            char_idx_list.append(0)

            if has_labels:
                # Label 1 = Target, 0 = Non-Target
                labels.append(labels_in[idx])

    if not has_labels:
        labels = [-1] * len(epochs) # Dummy labels for test set

    return {
        'epochs': np.array(epochs),
        'labels': np.array(labels),
        'codes': np.array(codes_list),
        'char_indices': np.array(char_idx_list)
    }

def plot_erp_responses(epoch_data, channel_idx=31, fs=60):
    """
    Visualize ERP responses and confirm P300 peaks around 300ms
    Channel 31 = Cz (central midline electrode, best for P300)
    """
    epochs = epoch_data['epochs']
    labels = epoch_data['labels']

    if np.all(labels == -1):
        print("Test data provided (no labels). Cannot plot ERP comparison.")
        return

    # Separate Target and Non-Target
    targets = epochs[labels == 1][:, :, channel_idx]
    non_targets = epochs[labels == 0][:, :, channel_idx]

    # Plot averages with standard error
    time_axis = np.linspace(0, 800, targets.shape[1])

    plt.figure(figsize=(10, 6))

    # Target (P300)
    mean_target = np.mean(targets, axis=0)
    plt.plot(time_axis, mean_target, label='Target (P300)', color='blue', linewidth=2)

    # Non-Target
    mean_non_target = np.mean(non_targets, axis=0)
    plt.plot(time_axis, mean_non_target, label='Non-Target', color='red', linestyle='--', linewidth=2)

    # Mark P300 peak region
    plt.axvspan(250, 450, color='gray', alpha=0.2, label='Expected P300 Window')

    plt.title(f'ERP Response at Channel {channel_idx}')
    plt.xlabel('Time (ms)')
    plt.ylabel('Amplitude (uV)')
    plt.legend()
    plt.grid(True)
    plt.show()

    # Calculate P300 amplitude difference
    p300_window_idx = (time_axis >= 250) & (time_axis <= 450)
    amp_diff = np.mean(mean_target[p300_window_idx]) - np.mean(mean_non_target[p300_window_idx])
    print(f"Mean Amplitude Difference in P300 Window: {amp_diff:.4f} uV")

In [15]:
################################################################################
# SECTION 3: FEATURE ENGINEERING & BASELINE CLASSIFIERS
################################################################################

def extract_time_domain_features(epochs):
    """
    Extract time-domain features: simply flatten the epochs
    Shape: (n_epochs, n_samples * n_channels)
    """
    n_epochs = epochs.shape[0]
    return epochs.reshape(n_epochs, -1)

def extract_pca_features(epochs, n_components=20):
    """
    Extract PCA features for dimensionality reduction
    Reduces (n_samples * n_channels) to n_components
    """
    flat_data = extract_time_domain_features(epochs)
    pca = PCA(n_components=n_components)
    # Note: Fitting should ideally happen only on training data
    # Here we return the object so it can be transformed later
    return flat_data, pca

def extract_csp_features(epochs, labels, n_components=6):
    """
    Common Spatial Patterns (CSP) for discriminative spatial filters
    Finds spatial filters that maximize variance ratio between classes
    """
    target_epochs = epochs[labels == 1]
    non_target_epochs = epochs[labels == 0]
    print(f"  Target epochs for CSP: {len(target_epochs)}")
    print(f"  Non-target epochs for CSP: {len(non_target_epochs)}")

    # Compute covariance matrices
    def compute_cov(data):
        # data shape: (epochs, time, channels)
        # covariance: (channels, channels)
        covs = []
        for i in range(data.shape[0]):
            trial = data[i, :, :]
            cov = np.dot(trial.T, trial) / trial.shape[0]
            covs.append(cov)
        return np.mean(covs, axis=0)

    cov_target = compute_cov(target_epochs)
    cov_nontarget = compute_cov(non_target_epochs)

    # Solve generalized eigenvalue problem
    # scipy.linalg.eigh(a, b) solves a*v = w*b*v
    vals, vecs = eigh(cov_target, cov_target + cov_nontarget)

    # Sort by eigenvalues
    idx = np.argsort(vals)
    vals = vals[idx]
    vecs = vecs[:, idx]

    # Select most discriminative components (extreme eigenvalues)
    # Top n/2 and Bottom n/2
    n_half = n_components // 2
    filters = np.concatenate([vecs[:, :n_half], vecs[:, -n_half:]], axis=1)

    # This matrix W will be used to project data
    return filters.T

def extract_features(epoch_data, method='pca', n_components=20):
    """
    Feature extraction wrapper supporting multiple methods:
    - time_domain: Raw time-domain samples (flattened)
    - pca: Principal Component Analysis
    - csp: Common Spatial Patterns
    """
    epochs = epoch_data['epochs']
    labels = epoch_data['labels']

    if method == 'time':
        return extract_time_domain_features(epochs), None

    elif method == 'pca':
        flat, pca = extract_pca_features(epochs, n_components)
        features = pca.fit_transform(flat)
        return features, pca

    elif method == 'csp':
        # CSP requires labels to calculate filters
        if np.all(labels == -1):
            raise ValueError("CSP requires labels for training")
        filters = extract_csp_features(epochs, labels, n_components)

        # Apply CSP transform
        # Project: Z = W * X
        # Variance: log(var(Z))
        features = []
        for i in range(epochs.shape[0]):
            # Transpose epoch to (Channels, Time)
            trial = epochs[i].T
            projected = np.dot(filters, trial)
            # Log variance feature
            var = np.var(projected, axis=1)
            feat = np.log(var)
            features.append(feat)

        return np.array(features), filters

    return None, None

def train_lda_classifier(X_train, y_train):
    """
    Linear Discriminant Analysis with balanced priors
    """
    clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')
    clf.fit(X_train, y_train)
    return clf

def train_logistic_regression(X_train, y_train):
    """
    Logistic Regression - baseline classifier
    """
    clf = LogisticRegression(class_weight='balanced', max_iter=1000)
    clf.fit(X_train, y_train)
    return clf

def evaluate_classifier(model, X_test, y_test, model_name="Model"):
    """
    Comprehensive classifier evaluation
    """
    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred)

    print(f"[{model_name}] Accuracy: {acc:.4f} | F1-Score: {f1:.4f}")
    return acc

class CSPTransformer:
    """
    Wrapper for CSP filters to enable transform() method
    """
    def __init__(self, filters):
        self.filters = filters

    def transform(self, epochs):
        features = []
        for i in range(epochs.shape[0]):
            trial = epochs[i].T
            projected = np.dot(self.filters, trial)
            var = np.var(projected, axis=1)
            feat = np.log(var)
            features.append(feat)
        return np.array(features)

In [16]:
################################################################################
# SECTION 4: CLASSICAL MACHINE LEARNING MODELS
################################################################################

def train_svm_classifier(X_train, y_train, kernel='rbf', C=1.0):
    """
    Support Vector Machine with RBF kernel
    Good for non-linear decision boundaries
    """
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_train)

    clf = SVC(kernel=kernel, C=C, class_weight='balanced', probability=True)
    clf.fit(X_scaled, y_train)

    return clf, scaler

def train_random_forest(X_train, y_train, n_estimators=100):
    """
    Random Forest Classifier
    Ensemble method, robust to overfitting
    """
    clf = RandomForestClassifier(n_estimators=n_estimators, class_weight='balanced', random_state=42)
    clf.fit(X_train, y_train)
    return clf

def train_gradient_boosting(X_train, y_train, n_estimators=100):
    """
    Gradient Boosting Classifier with manual sample weighting
    (GradientBoosting doesn't support class_weight parameter in older versions)
    """
    print(f"\n  Training Gradient Boosting (n_estimators={n_estimators})...")

    # Calculate sample weights manually
    # Weight = Total / (n_classes * Count)
    n_samples = len(y_train)
    n_classes = 2
    count_0 = np.sum(y_train == 0)
    count_1 = np.sum(y_train == 1)

    w0 = n_samples / (n_classes * count_0)
    w1 = n_samples / (n_classes * count_1)

    weights = np.zeros(n_samples)
    weights[y_train == 0] = w0
    weights[y_train == 1] = w1

    clf = GradientBoostingClassifier(n_estimators=n_estimators, random_state=42)
    clf.fit(X_train, y_train, sample_weight=weights)
    return clf

def compare_all_classical_models(X_train, y_train, X_test, y_test):
    """
    Train and compare all classical ML models
    Returns performance comparison
    """
    results = {}
    models = {}

    # Define models to train
    model_list = [
        ('LDA', LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')),
        ('Logistic Regression', LogisticRegression(class_weight='balanced', max_iter=1000)),
        ('Random Forest', RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=42)),
        ('SVM (RBF)', SVC(kernel='rbf', class_weight='balanced', probability=True))
    ]

    # Special handling for scaling (SVM needs it)
    scaler = StandardScaler()
    X_train_s = scaler.fit_transform(X_train)
    X_test_s = scaler.transform(X_test)

    for name, clf in model_list:
        print(f"Training {name}...")

        # Training
        start_time = time.time()
        if name == 'SVM (RBF)':
            clf.fit(X_train_s, y_train)
            curr_X_test = X_test_s
        else:
            clf.fit(X_train, y_train)
            curr_X_test = X_test

        # Inference
        y_pred = clf.predict(curr_X_test)

        # Metrics
        acc = accuracy_score(y_test, y_pred)
        f1 = f1_score(y_test, y_pred)

        results[name] = {'Accuracy': acc, 'F1': f1}
        models[name] = clf
        print(f"  -> Accuracy: {acc:.4f}, F1: {f1:.4f}, Time: {time.time()-start_time:.2f}s")

    # Gradient Boosting (Separate due to weighting logic)
    print("Training Gradient Boosting...")
    gb_model = train_gradient_boosting(X_train, y_train)
    y_pred_gb = gb_model.predict(X_test)
    results['Gradient Boosting'] = {
        'Accuracy': accuracy_score(y_test, y_pred_gb),
        'F1': f1_score(y_test, y_pred_gb)
    }
    models['Gradient Boosting'] = gb_model

    # Summary table
    print("\n--- Model Comparison Summary ---")
    print(f"{'Model':<20} | {'Accuracy':<10} | {'F1-Score':<10}")
    print("-" * 46)
    for name, metrics in results.items():
        print(f"{name:<20} | {metrics['Accuracy']:.4f}     | {metrics['F1']:.4f}")

    return results, models

def save_model(model, filepath):
    """Save model to pickle file"""
    with open(filepath, 'wb') as f:
        pickle.dump(model, f)
    print(f"\n  Model saved to: {filepath}")

def load_model(filepath):
    """Load model from pickle file"""
    with open(filepath, 'rb') as f:
        model = pickle.load(f)
    print(f"\n  Model loaded from: {filepath}")
    return model

In [17]:
# ========================================================================
# MAIN EXECUTION
# ========================================================================
if __name__ == "__main__":
    # ========================================================================
    # STEP 1: LOAD DATA
    # ========================================================================
    print("\n" + "="*70)
    print("STEP 1: LOADING DATA")
    print("="*70)

    from google.colab import drive
    try:
        drive.mount('/content/drive')
        DATA_PATH = '/content/drive/Othercomputers/My Mac/Downloads/BCI_Comp_III_Wads_2004/'
    except:
        # Fallback if not in Colab or drive mounting fails (for local testing)
        print("Drive mount failed or skipped. Using local path './Dataset/'")
        DATA_PATH = './Dataset/'

    # Ensure these files exist in the path, otherwise this will fail
    try:
        train_data_A = load_data('/content/drive/Othercomputers/My Mac/Downloads/BCI_Comp_III_Wads_2004/Subject_A_Train.mat')
        test_data_A = load_data('/content/drive/Othercomputers/My Mac/Downloads/BCI_Comp_III_Wads_2004/Subject_A_Test.mat')
        train_data_B = load_data('/content/drive/Othercomputers/My Mac/Downloads/BCI_Comp_III_Wads_2004/Subject_B_Train.mat')
        test_data_B = load_data('/content/drive/Othercomputers/My Mac/Downloads/BCI_Comp_III_Wads_2004/Subject_B_Test.mat')
        print("Data loaded successfully.")
    except Exception as e:
        print(f"Error loading data: {e}")
        # Stop execution if data load fails
        exit()

    # ========================================================================
    # STEP 2: PREPROCESSING
    # ========================================================================
    print("\n" + "="*70)
    print("STEP 2: PREPROCESSING")
    print("="*70)

    print("\n--- Subject A ---")
    train_proc_A = preprocess_pipeline(train_data_A)
    test_proc_A = preprocess_pipeline(test_data_A)

    print("\n--- Subject B ---")
    train_proc_B = preprocess_pipeline(train_data_B)
    test_proc_B = preprocess_pipeline(test_data_B)

    # ========================================================================
    # STEP 3: EPOCH EXTRACTION
    # ========================================================================
    print("\n" + "="*70)
    print("STEP 3: EPOCH EXTRACTION")
    print("="*70)

    print("\n--- Subject A ---")
    train_epochs_A = extract_epochs(train_proc_A)
    test_epochs_A = extract_epochs(test_proc_A)

    print("\n--- Subject B ---")
    train_epochs_B = extract_epochs(train_proc_B)
    test_epochs_B = extract_epochs(test_proc_B)

    # ========================================================================
    # STEP 4: VISUALIZE ERP
    # ========================================================================
    print("\n" + "="*70)
    print("STEP 4: VISUALIZING ERP RESPONSES")
    print("="*70)

    print("\n--- Subject A ---")
    plot_erp_responses(train_epochs_A, channel_idx=31)

    # ========================================================================
    # STEP 5: FEATURE EXTRACTION
    # ========================================================================
    print("\n" + "="*70)
    print("STEP 5: FEATURE EXTRACTION")
    print("="*70)

    # ========================================================================
    # Subject A: Compare PCA vs CSP vs Time-Domain
    # ========================================================================
    print("\n--- Subject A: Feature Comparison ---")

    # Split A for validation within training set to compare features
    X_full = train_epochs_A['epochs']
    y_full = train_epochs_A['labels']

    # Use a small subset to tune feature method
    X_tr_sub, X_val_sub, y_tr_sub, y_val_sub = train_test_split(
        X_full, y_full, test_size=0.2, stratify=y_full, random_state=42
    )

    epoch_data_tr = {'epochs': X_tr_sub, 'labels': y_tr_sub}
    epoch_data_val = {'epochs': X_val_sub, 'labels': y_val_sub}

    # Try PCA (20 components)
    print("Extracting PCA-20...")
    X_pca20_tr, pca20 = extract_features(epoch_data_tr, method='pca', n_components=20)
    X_pca20_val = pca20.transform(X_val_sub.reshape(len(X_val_sub), -1))

    # Try PCA (50 components)
    print("Extracting PCA-50...")
    X_pca50_tr, pca50 = extract_features(epoch_data_tr, method='pca', n_components=50)
    X_pca50_val = pca50.transform(X_val_sub.reshape(len(X_val_sub), -1))

    # Try CSP
    print("Extracting CSP...")
    X_csp_tr, csp_filters = extract_features(epoch_data_tr, method='csp', n_components=6)
    csp_obj = CSPTransformer(csp_filters)
    X_csp_val = csp_obj.transform(X_val_sub)

    # Try Raw Time-Domain Features
    print("Extracting Time-Domain...")
    X_time_tr, _ = extract_features(epoch_data_tr, method='time')
    X_time_val, _ = extract_features(epoch_data_val, method='time')

    # Quick comparison with BALANCED LDA
    print("\nEvaluating Feature Sets with LDA...")

    # PCA-20 test
    lda_pca20 = train_lda_classifier(X_pca20_tr, y_tr_sub)
    y_pred_pca20 = lda_pca20.predict(X_pca20_val)
    acc_pca20 = accuracy_score(y_val_sub, y_pred_pca20)
    f1_pca20 = f1_score(y_val_sub, y_pred_pca20)

    # PCA-50 test
    lda_pca50 = train_lda_classifier(X_pca50_tr, y_tr_sub)
    y_pred_pca50 = lda_pca50.predict(X_pca50_val)
    acc_pca50 = accuracy_score(y_val_sub, y_pred_pca50)
    f1_pca50 = f1_score(y_val_sub, y_pred_pca50)

    # CSP test
    lda_csp = train_lda_classifier(X_csp_tr, y_tr_sub)
    y_pred_csp = lda_csp.predict(X_csp_val)
    acc_csp = accuracy_score(y_val_sub, y_pred_csp)
    f1_csp = f1_score(y_val_sub, y_pred_csp)

    # Time-Domain test
    lda_time = train_lda_classifier(X_time_tr, y_tr_sub)
    y_pred_time = lda_time.predict(X_time_val)
    acc_time = accuracy_score(y_val_sub, y_pred_time)
    f1_time = f1_score(y_val_sub, y_pred_time)

    # ========================================================================
    print("\n" + "="*70)
    print("FEATURE COMPARISON (Balanced Classifiers)")
    print("="*70)
    print(f"PCA (20 comp):      Accuracy={acc_pca20:.4f}, F1={f1_pca20:.4f}")
    print(f"PCA (50 comp):      Accuracy={acc_pca50:.4f}, F1={f1_pca50:.4f}")
    print(f"CSP (6 comp):       Accuracy={acc_csp:.4f}, F1={f1_csp:.4f}")
    print(f"Time-Domain:        Accuracy={acc_time:.4f}, F1={f1_time:.4f}")

    # Select best method based on F1-score
    scores = {'pca': f1_pca20, 'csp': f1_csp, 'time': f1_time}
    feature_method_A = max(scores, key=scores.get)
    print(f"\nSelected Feature Method: {feature_method_A.upper()}")

    # Re-extract chosen features on FULL training set
    print("Re-extracting features for full dataset...")
    n_components_A = 20 # Default
    feature_obj_A = None # To store PCA or CSP object

    if feature_method_A == 'time':
        X_train_full_A, _ = extract_features(train_epochs_A, method='time')
    elif feature_method_A == 'pca':
        X_train_full_A, feature_obj_A = extract_features(train_epochs_A, method='pca', n_components=20)
        pca_A = feature_obj_A # save reference
    else: # CSP
        X_train_full_A, feature_obj_A = extract_features(train_epochs_A, method='csp', n_components=6)

    # ========================================================================
    # Subject A: Create final train/val split for later steps
    # ========================================================================
    X_train_A, X_val_A, y_train_A, y_val_A = train_test_split(
        X_train_full_A, train_epochs_A['labels'], test_size=0.2, random_state=42, stratify=train_epochs_A['labels']
    )
    print(f"\nSubject A splits: Training={len(X_train_A)}, Validation={len(X_val_A)}")

    # Transform test data
    if feature_method_A == 'time':
        X_test_A = test_epochs_A['epochs'].reshape(len(test_epochs_A['epochs']), -1)
    elif feature_method_A == 'pca':
        X_test_A = feature_obj_A.transform(test_epochs_A['epochs'].reshape(test_epochs_A['epochs'].shape[0], -1))
    else:  # CSP
        X_test_A = CSPTransformer(feature_obj_A).transform(test_epochs_A['epochs'])

    print(f"Test features: {X_test_A.shape}")

    # ========================================================================
    # Subject B: Use same method as Subject A
    # ========================================================================
    print("\n--- Subject B: Feature Extraction ---")
    print(f"\nUsing {feature_method_A.upper()} (same as Subject A)...")

    # Define PCA_B variable for saving later
    pca_B = None

    if feature_method_A == 'time':
        X_train_full_B, _ = extract_features(train_epochs_B, method='time')
        X_test_B = test_epochs_B['epochs'].reshape(len(test_epochs_B['epochs']), -1)
    elif feature_method_A == 'pca':
        X_train_full_B, feature_obj_B = extract_features(train_epochs_B, method='pca', n_components=20)
        X_test_B = feature_obj_B.transform(test_epochs_B['epochs'].reshape(test_epochs_B['epochs'].shape[0], -1))
        pca_B = feature_obj_B
    else:  # CSP
        X_train_full_B, feature_obj_B = extract_features(train_epochs_B, method='csp', n_components=6)
        X_test_B = CSPTransformer(feature_obj_B).transform(test_epochs_B['epochs'])

    X_train_B, X_val_B, y_train_B, y_val_B = train_test_split(
        X_train_full_B, train_epochs_B['labels'], test_size=0.2, random_state=42, stratify=train_epochs_B['labels']
    )
    print(f"Subject B splits: Training={len(X_train_B)}, Validation={len(X_val_B)}, Test features: {X_test_B.shape}")

    # ========================================================================
    # STEP 6: BASELINE CLASSIFIERS
    # ========================================================================
    print("\n" + "="*70)
    print("STEP 6: BASELINE CLASSIFIERS (Subject A)")
    print("="*70)

    lda_A = train_lda_classifier(X_train_A, y_train_A)
    acc_lda = evaluate_classifier(lda_A, X_val_A, y_val_A, "LDA")

    lr_A = train_logistic_regression(X_train_A, y_train_A)
    acc_lr = evaluate_classifier(lr_A, X_val_A, y_val_A, "Logistic Regression")

    # ========================================================================
    # STEP 7: CLASSICAL ML MODELS
    # ========================================================================
    print("\n" + "="*70)
    print("STEP 7: CLASSICAL MACHINE LEARNING (Subject A)")
    print("="*70)

    results_classical_A, models_A = compare_all_classical_models(
        X_train_A, y_train_A, X_val_A, y_val_A
    )

    # Train SVM for both subjects (best model assumption for final export)
    # Re-train on full train data? Usually yes, but here we just use the split for simplicity
    print("\nTraining Final SVMs...")
    svm_A, scaler_A = train_svm_classifier(X_train_A, y_train_A)
    svm_B, scaler_B = train_svm_classifier(X_train_B, y_train_B)

    # ========================================================================
    # STEP 8: EXPORT MODELS
    # ========================================================================
    print("\n" + "="*70)
    print("STEP 8: EXPORTING MODELS")
    print("="*70)

    import os
    os.makedirs('models', exist_ok=True)

    # Save pickle
    # Note: feature_obj_A might be None if 'time' method was selected,
    # but we save what we have. If PCA was used, we save it.
    save_model({
        'model': svm_A,
        'scaler': scaler_A,
        'pca': feature_obj_A if feature_method_A == 'pca' else None
    }, 'models/subject_A_svm.pkl')

    save_model({
        'model': svm_B,
        'scaler': scaler_B,
        'pca': pca_B if feature_method_A == 'pca' else None
    }, 'models/subject_B_svm.pkl')



STEP 1: LOADING DATA
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Data loaded successfully.

STEP 2: PREPROCESSING

--- Subject A ---

--- Subject B ---

STEP 3: EPOCH EXTRACTION

--- Subject A ---

--- Subject B ---

STEP 4: VISUALIZING ERP RESPONSES

--- Subject A ---
Test data provided (no labels). Cannot plot ERP comparison.

STEP 5: FEATURE EXTRACTION

--- Subject A: Feature Comparison ---


ValueError: With n_samples=0, test_size=0.2 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.