In [None]:
# Notebook-level authorship and release metadata (plain module-level strings;
# nothing else in the visible notebook reads them).
__author__ = "Himon Thakur"
__credits__ = ["Himon Thakur"]
__license__ = "Apache 2.0"
__version__ = "1.0.1"
__maintainer__ = "Himon Thakur"
__email__ = "hthakur@uccs.edu"
__status__ = "Prototype"

In [None]:
import mne
import numpy as np
import pandas as pd
from mne.io import read_raw_eeglab
from mne.preprocessing import ICA
from autoreject import AutoReject
import os
from mne.decoding import CSP
from scipy.fft import fft
import matplotlib.pyplot as plt
from mne.decoding import Scaler

import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score, f1_score
import copy
from tqdm import tqdm
import json
from datetime import datetime
import glob

In [None]:
# Configuration
# These module-level names are read as globals by the functions defined below.
# Relative paths are resolved against the notebook's working directory.
root_dir = 'dataset'  # input root: load_raw_data reads <root_dir>/sub-<subject>/ses-<session>/eeg/<task>.set
output_dir = 'processed_data'  # output root for cleaned raw files, epochs, features and the summary
subjects = [f"{i:02d}" for i in range(1, 30)]  # zero-padded subject IDs '01' .. '29'
sessions = ['S1', 'S2', 'S3']  # session labels, used as ses-<label> in directory names
montage = 'standard_1005'  # montage name handed to raw.set_montage()
task_type = 'nback'  # 'nback' or 'matb'; selects default task files and the event handling in create_epochs

In [None]:
def load_raw_data(subject, session, task_files=None):
    """Load and concatenate the EEGLAB recordings of one subject/session.

    Files are looked up as ``<root_dir>/sub-<subject>/ses-<session>/eeg/<name>.set``.
    Each recording gets the configured montage and one extra annotation named
    ``FILE_<NAME>`` spanning the whole recording, so the source file of every
    sample can still be identified after concatenation.

    Parameters
    ----------
    subject, session : str
        Identifiers used to build the directory names.
    task_files : list of str | None
        File stems (without ``.set``) to load. When None, the defaults for the
        global ``task_type`` are used.

    Returns
    -------
    The object returned by ``mne.concatenate_raws`` for all loaded recordings.

    Raises
    ------
    ValueError
        If ``task_files`` is None and ``task_type`` has no defaults.
    FileNotFoundError
        If none of the requested files could be loaded.
    """
    default_task_files = {
        'nback': ['zeroBACK', 'oneBACK', 'twoBACK'],
        'matb': ['MATBdiff'],
    }
    eeg_dir = os.path.join(root_dir, f'sub-{subject}', f'ses-{session}', 'eeg')

    if task_files is None:
        if task_type not in default_task_files:
            raise ValueError(f"Unknown task type: {task_type}")
        task_files = default_task_files[task_type]

    print(f"Loading data for subject {subject}, session {session}")

    recordings = []
    for name in task_files:
        set_path = os.path.join(eeg_dir, f"{name}.set")
        if not os.path.exists(set_path):
            print(f"Missing file: {set_path}")
            continue

        # Best effort: a file that fails to load is reported and skipped so
        # the remaining files of the session can still be used.
        try:
            print(f"Loading {name}.set")
            recording = read_raw_eeglab(set_path, preload=True)
            recording.set_montage(montage)
            recording.annotations.append(
                onset=0,
                duration=recording.times[-1],
                description=f"FILE_{name.upper()}",
            )
            recordings.append(recording)
        except Exception as e:
            print(f"Error loading {name}: {str(e)}")

    if len(recordings) == 0:
        raise FileNotFoundError(f"No valid EEG files found for sub-{subject}, ses-{session}")

    print(f"Concatenating {len(recordings)} files")
    return mne.concatenate_raws(recordings)

In [None]:
def preprocess_raw_data(raw):
    """Filter, ICA-clean and re-reference a copy of ``raw``.

    Steps, in order: 1-40 Hz band-pass, notch at 50 and 100 Hz, a 20-component
    ICA fitted on the filtered copy, exclusion of the components that
    ``ica.find_bads_eog`` reports, then an average EEG reference.

    Channels whose name contains "EOG" are used for the ocular search; when
    there are none, whichever of Fp1/Fp2/F7/F8 are present are used instead.

    NOTE(review): when no component is flagged, component 0 is excluded
    unconditionally - confirm this fallback is intended.

    Parameters
    ----------
    raw : object providing the MNE Raw methods used below; it is not modified.

    Returns
    -------
    The cleaned copy.
    """
    print("Preprocessing raw data")
    cleaned = raw.copy()

    print("Applying filters")
    cleaned.filter(1.0, 40.0, fir_design='firwin')
    cleaned.notch_filter(np.arange(50, 125, 50))  # -> [50, 100] Hz

    print("Applying ICA for artifact removal")
    ica = ICA(n_components=20, random_state=97)
    ica.fit(cleaned)

    # Choose the channels used to look for ocular components.
    reference_channels = [name for name in cleaned.ch_names if 'EOG' in name.upper()]
    if reference_channels:
        print(f"Using EOG channels: {reference_channels}")
    else:
        reference_channels = [name for name in ['Fp1', 'Fp2', 'F7', 'F8']
                              if name in cleaned.ch_names]
        if reference_channels:
            print(f"No EOG channels found. Using frontal channels: {reference_channels}")

    ocular_components = []
    for name in reference_channels:
        flagged, _scores = ica.find_bads_eog(cleaned, ch_name=name)
        ocular_components.extend(flagged)

    if not ocular_components and ica.n_components_ > 0:
        ocular_components = [0]
        print("No EOG components found automatically. Excluding first component.")

    # The same component can be flagged by several channels; drop duplicates.
    ocular_components = list(set(ocular_components))

    print(f"Excluding ICA components: {ocular_components}")
    ica.exclude = ocular_components
    ica.apply(cleaned)

    print("Setting average reference")
    cleaned.set_eeg_reference('average')

    return cleaned

In [None]:
def _is_bookkeeping_event(name):
    """Return True for annotation names that mark file structure, not task events."""
    text = str(name)
    # Substring tests so that 'boundary', 'EDGE boundary' and any 'FILE_<NAME>'
    # marker are all recognised.
    return 'boundary' in text or 'FILE_' in text


def _select_nback_event_ids(event_id):
    """Map n-back levels ('0back', '1back', '2back') to one event code each.

    NOTE(review): several annotation names can match the same level (for
    example three numeric markers per level); the last match overwrites the
    earlier ones, so only one code per level is epoched. This mirrors the
    original behaviour - confirm against the dataset's trigger list whether
    the codes should be merged instead.
    """
    # (level, name keywords, numeric marker names) - values kept from the
    # original implementation, which attributes the numbers to triggerlist.txt.
    level_rules = (
        ('0back', ('zeroback', '0back'), ('6021', '6022', '6023')),
        ('1back', ('oneback', '1back'), ('6121', '6122', '6123')),
        ('2back', ('twoback', '2back'), ('6221', '6222', '6223')),
    )

    selected = {}
    for event_name, event_code in event_id.items():
        lowered = str(event_name).lower()

        # Descriptive names such as "...oneBACK...trial...".
        if 'trial' in lowered or 'onset' in lowered:
            for level, keywords, _markers in level_rules:
                if any(keyword in lowered for keyword in keywords):
                    selected[level] = event_code
                    break

        # Exact numeric marker names.
        for level, _keywords, markers in level_rules:
            if event_name in markers:
                selected[level] = event_code
                break

    if selected:
        return selected

    print("No specific task events found. Using trial onset markers.")
    for event_name, event_code in event_id.items():
        text = str(event_name)
        if not text.isdigit():
            continue
        try:
            code = int(text)
        except ValueError:
            # str.isdigit() accepts some characters (e.g. superscript digits)
            # that int() rejects; such names are simply not trial markers.
            continue
        for level, _keywords, markers in level_rules:
            if code in [int(marker) for marker in markers]:
                selected[level] = event_code
                break

    return selected


def create_epochs(raw, tmin=-0.2, tmax=1.0, baseline=(None, 0)):
    """Cut ``raw`` into task epochs and clean them with AutoReject.

    Events come from ``mne.events_from_annotations``. For ``task_type ==
    'nback'`` one event code per n-back level is selected; if none can be
    identified, every event that is not a boundary or ``FILE_`` marker is
    used. For any other task type the same non-bookkeeping filter is applied
    directly.

    Parameters
    ----------
    raw : preprocessed recording with annotations.
    tmin, tmax : float
        Epoch window in seconds relative to each event.
    baseline : tuple
        Baseline interval passed to ``mne.Epochs``.

    Returns
    -------
    The epochs returned by ``AutoReject.fit_transform``.

    Raises
    ------
    ValueError
        If no epochs could be created from the selected events.
    """
    print("Creating epochs")

    # Extract events from annotations
    events, event_id = mne.events_from_annotations(raw)
    print(f"Found event types: {event_id}")

    if task_type == 'nback':
        nback_files = [desc for desc in raw.annotations.description if 'FILE_' in desc]
        print(f"Found n-back files: {nback_files}")

        task_event_id = _select_nback_event_ids(event_id)

        if not task_event_id:
            print("Still no task events found. Filtering out non-trial events.")
            # Same filter as the generic branch: an exact-name comparison would
            # let 'EDGE boundary' and custom FILE_<NAME> markers through.
            task_event_id = {name: code for name, code in event_id.items()
                             if not _is_bookkeeping_event(name)}

        print(f"Using event dictionary: {task_event_id}")
    else:
        task_event_id = {name: code for name, code in event_id.items()
                         if not _is_bookkeeping_event(name)}

    # Create epochs - handle repeated events
    epochs = mne.Epochs(
        raw, events, event_id=task_event_id,
        tmin=tmin, tmax=tmax, baseline=baseline,
        preload=True, on_missing='warn',
        event_repeated='drop'  # Handle repeated events by dropping them
    )

    print(f"Created {len(epochs)} epochs")

    if len(epochs) == 0:
        # Fail with a clear message instead of an obscure error from AutoReject.
        raise ValueError(
            f"No epochs could be created from events {task_event_id}; "
            "check the annotation names of the recording"
        )

    # Apply AutoReject to clean epochs
    print("Applying AutoReject for automatic artifact rejection")
    ar = AutoReject(n_interpolate=[1, 2, 4], random_state=42)
    epochs_clean, _reject_log = ar.fit_transform(epochs, return_log=True)

    print(f"Kept {len(epochs_clean)}/{len(epochs)} epochs after rejection")

    return epochs_clean

In [None]:
def extract_features(epochs):
    """Compute CSP, band-power and FFT-peak features for every epoch.

    Feature groups (columns of the returned frame, in this order):

    * ``label`` - event code of the epoch (last column of ``epochs.events``).
    * ``csp_0`` .. ``csp_3`` - CSP components fitted on the epochs of the
      lowest and highest event code only; other epochs, or all epochs when CSP
      cannot be computed, keep NaN.
      NOTE(review): CSP is fitted with the labels of every epoch passed in, so
      these columns are label-informed - confirm this is acceptable for the
      downstream train/test split.
    * ``psd_<fmin>_<fmax>_ch<k>`` - log10 mean Welch power per band
      (4-8, 8-13, 13-30, 30-40 Hz) of the per-epoch z-scored channel.
    * ``fft_peak_freq_ch<k>``, ``fft_peak_amp_ch<k>``, ``fft_mean_amp_ch<k>`` -
      peak frequency, peak amplitude and mean amplitude of the Hann-windowed
      spectrum between 0 and 40 Hz.
    * ``class`` - event name looked up from ``epochs.event_id``.

    Parameters
    ----------
    epochs : epochs object providing ``events``, ``event_id``, ``info``,
        ``get_data()`` and ``copy()``.

    Returns
    -------
    pandas.DataFrame with one row per epoch.

    Raises
    ------
    ValueError
        If ``epochs`` is empty.
    """
    print("Extracting features")

    n_epochs = len(epochs)
    if n_epochs == 0:
        raise ValueError("Cannot extract features: no epochs were provided")

    labels = epochs.events[:, -1]
    sfreq = epochs.info['sfreq']

    # 1. CSP Features
    print("Extracting CSP features")
    n_csp = 4
    csp_features = np.full((n_epochs, n_csp), np.nan)
    unique_labels = np.unique(labels)

    if len(unique_labels) >= 2:
        try:
            epochs_csp = epochs.copy().pick_types(eeg=True)
            csp_labels = epochs_csp.events[:, -1]

            # CSP is a two-class method: contrast the lowest and highest code.
            sel = np.concatenate([
                np.where(csp_labels == unique_labels[0])[0],
                np.where(csp_labels == unique_labels[-1])[0]
            ])

            if len(sel) > 5:
                subset = epochs_csp[sel]
                csp = CSP(n_components=n_csp, reg=None, log=True)
                csp_features[sel] = csp.fit_transform(subset.get_data(),
                                                      subset.events[:, -1])
                print(f"Extracted CSP features from {len(sel)} epochs")
        except Exception as e:
            # Best effort: CSP columns stay NaN and are handled downstream.
            print(f"Error computing CSP features: {str(e)}")

    # 2. PSD Features
    print("Extracting PSD features")
    freq_bands = [(4, 8), (8, 13), (13, 30), (30, 40)]  # theta, alpha, beta, gamma

    X = epochs.get_data()  # shape: (n_epochs, n_channels, n_times)
    n_times = X.shape[-1]

    # Welch segments cannot be longer than the epoch; 256/128 is kept whenever
    # the epoch is long enough.
    n_fft = min(256, n_times)
    n_overlap = n_fft // 2

    psd_features = []
    band_masks = None
    for epoch_idx, epoch in enumerate(X):
        print(f"Processing epoch {epoch_idx+1}/{n_epochs}", end='\r')

        # Z-score every channel; flat channels are only mean-centred so that
        # no division by zero occurs.
        ch_mean = epoch.mean(axis=1, keepdims=True)
        ch_std = epoch.std(axis=1, keepdims=True)
        epoch_norm = (epoch - ch_mean) / np.where(ch_std > 0, ch_std, 1.0)

        psd, freqs = mne.time_frequency.psd_array_welch(
            epoch_norm, sfreq=sfreq,
            fmin=1, fmax=40, n_fft=n_fft, n_overlap=n_overlap,
            verbose=False  # one call per epoch would otherwise flood the output
        )

        if band_masks is None:
            # The frequency grid is identical for every epoch.
            band_masks = [(freqs >= fmin) & (freqs <= fmax) for fmin, fmax in freq_bands]

        psd_features.append([
            np.log10(psd[:, band_mask].mean(axis=1) + 1e-10)
            for band_mask in band_masks
        ])

    # 3. FFT Features
    print("\nExtracting FFT features")
    # Window, frequency axis and mask depend only on the epoch length.
    window = np.hanning(n_times)
    fft_freqs = np.fft.fftfreq(n_times, 1 / sfreq)
    keep = (fft_freqs >= 0) & (fft_freqs <= 40)  # Keep only 0-40 Hz
    kept_freqs = fft_freqs[keep]

    fft_features = []
    for epoch_idx, epoch in enumerate(X):
        print(f"Processing epoch {epoch_idx+1}/{n_epochs}", end='\r')
        spectrum = np.abs(fft(epoch * window[np.newaxis, :], axis=1))[:, keep]

        peak_idx = spectrum.argmax(axis=1)
        peak_freq = kept_freqs[peak_idx]
        peak_amp = spectrum[np.arange(spectrum.shape[0]), peak_idx]
        mean_amp = spectrum.mean(axis=1)

        # (n_channels, 3): peak frequency, peak amplitude, mean amplitude
        fft_features.append(np.column_stack([peak_freq, peak_amp, mean_amp]))

    print("\nConverting features to DataFrame")

    feature_dict = {'label': labels}

    for i in range(n_csp):
        feature_dict[f'csp_{i}'] = csp_features[:, i]

    # One column for each frequency band and channel
    psd_array = np.array(psd_features)  # (n_epochs, n_bands, n_channels)
    for band_idx, (fmin, fmax) in enumerate(freq_bands):
        for ch in range(psd_array.shape[2]):
            feature_dict[f'psd_{fmin}_{fmax}_ch{ch}'] = psd_array[:, band_idx, ch]

    # One column for each channel and FFT summary
    fft_array = np.array(fft_features)  # (n_epochs, n_channels, 3)
    for ch in range(fft_array.shape[1]):
        feature_dict[f'fft_peak_freq_ch{ch}'] = fft_array[:, ch, 0]
        feature_dict[f'fft_peak_amp_ch{ch}'] = fft_array[:, ch, 1]
        feature_dict[f'fft_mean_amp_ch{ch}'] = fft_array[:, ch, 2]

    df = pd.DataFrame(feature_dict)

    # Map event codes back to their names
    label_mapping = {code: name for name, code in epochs.event_id.items()}
    df['class'] = df['label'].map(label_mapping)

    return df

In [None]:
def validate_features(df):
    """Repair NaN/inf values in the numeric columns and report class balance.

    The frame is modified in place and also returned.

    * NaN in columns whose name contains ``'csp'`` becomes 0.
    * NaN in any other numeric column becomes that column's median.
    * ``+inf`` / ``-inf`` become the largest / smallest finite value of the
      column, or 0 when the column has no finite value at all.

    Parameters
    ----------
    df : pandas.DataFrame with a ``class`` column.

    Returns
    -------
    The same DataFrame object.
    """
    print("Validating features")

    numeric_cols = list(df.select_dtypes(include=['number']).columns)

    # --- missing values -------------------------------------------------
    nan_cols = [name for name in numeric_cols if df[name].isna().any()]
    if nan_cols:
        print(f"Found NaN values in columns: {nan_cols}")

        csp_cols = [name for name in nan_cols if 'csp' in name]
        other_cols = [name for name in nan_cols if name not in csp_cols]

        if csp_cols:
            print("Filling NaN values in CSP columns with zeros")
            df[csp_cols] = df[csp_cols].fillna(0)

        if other_cols:
            print("Filling NaN values in other columns with column median")
            for name in other_cols:
                df[name] = df[name].fillna(df[name].median())

    # --- infinite values ------------------------------------------------
    for name in numeric_cols:
        is_inf = np.isinf(df[name].values)
        if not is_inf.any():
            continue

        print(f"Found infinite values in column: {name}")
        finite = df[name][~is_inf]

        if len(finite) == 0:
            # Nothing to clamp to: fall back to zero.
            df[name] = df[name].replace([np.inf, -np.inf], 0)
            continue

        upper = finite.max()
        lower = finite.min()
        df.loc[df[name] == np.inf, name] = upper
        df.loc[df[name] == -np.inf, name] = lower

    # --- class balance --------------------------------------------------
    class_counts = df['class'].value_counts()
    print(f"Class distribution:\n{class_counts}")

    if len(class_counts) < 2:
        print("Warning: Only one class found in the data")

    return df

In [None]:
def plot_feature_distributions(df, output_path):
    """Save per-class histograms of the first ten PSD feature columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Feature table with a ``class`` column; columns whose name contains
        ``'psd'`` are plotted (at most the first ten, on a 2 x 5 grid).
    output_path : str
        Image file to write. Its parent directory is created when needed.
    """
    print(f"Plotting feature distributions to {output_path}")

    # os.makedirs('') raises, so only create a directory when the path has one.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    psd_cols = [col for col in df.columns if 'psd' in col][:10]  # First 10 PSD features
    class_names = df['class'].unique()

    fig = plt.figure(figsize=(12, 8))
    try:
        for i, col in enumerate(psd_cols):
            ax = fig.add_subplot(2, 5, i + 1)
            for class_name in class_names:
                ax.hist(df.loc[df['class'] == class_name, col],
                        bins=20, alpha=0.5, label=class_name)
            ax.set_title(col, fontsize=8)
            ax.legend(fontsize=6)

        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even when plotting or saving fails;
        # this function is called once per subject/session.
        plt.close(fig)

In [None]:
def process_subject_session(subject, session):
    """Run the full pipeline for one subject/session and persist every stage.

    Stages: load -> preprocess -> epoch -> extract features -> validate ->
    plot. Outputs are written below ``<output_dir>/sub-<subject>/ses-<session>``:
    ``raw_clean.fif``, ``epochs.fif``, ``features/features.csv`` and
    ``features/feature_distributions.png``.

    Any exception raised by a stage is printed together with its traceback and
    turned into a ``(None, None)`` result so that a batch run can continue.

    Returns
    -------
    tuple
        ``(epochs, features_df)`` on success, ``(None, None)`` on failure.
    """
    print(f"\n=== Processing subject {subject}, session {session} ===")

    session_dir = os.path.join(output_dir, f"sub-{subject}", f"ses-{session}")
    os.makedirs(session_dir, exist_ok=True)

    try:
        # 1-2. Load and preprocess, then keep a copy of the cleaned recording
        cleaned = preprocess_raw_data(load_raw_data(subject, session))

        cleaned_path = os.path.join(session_dir, "raw_clean.fif")
        cleaned.save(cleaned_path, overwrite=True)
        print(f"Saved preprocessed raw data to {cleaned_path}")

        # 3. Epoching
        epochs = create_epochs(cleaned)

        epochs_path = os.path.join(session_dir, "epochs.fif")
        epochs.save(epochs_path, overwrite=True)
        print(f"Saved epochs to {epochs_path}")

        # 4-5. Feature extraction and validation
        features_df = validate_features(extract_features(epochs))

        features_dir = os.path.join(session_dir, "features")
        os.makedirs(features_dir, exist_ok=True)

        csv_path = os.path.join(features_dir, "features.csv")
        features_df.to_csv(csv_path, index=False)
        print(f"Saved features to {csv_path}")

        # 6. Diagnostic plot
        plot_feature_distributions(
            features_df,
            os.path.join(features_dir, "feature_distributions.png"),
        )
    except Exception as e:
        print(f"Error processing subject {subject}, session {session}: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None

    return epochs, features_df


In [None]:
# Batch driver: runs the pipeline for every configured subject/session pair
# and collects per-session counts in `results`, which the summary cell reads.
print(f"Processing {len(subjects)} subjects, {len(sessions)} sessions")

# Create main output directory
os.makedirs(output_dir, exist_ok=True)

# results[subject][session] -> {"epochs_count", "features_count", "class_counts"}
# Sessions that failed are simply absent from the inner dict.
results = {}

for subject in subjects:
    results[subject] = {}
    
    for session in sessions:
        print(f"\nProcessing subject {subject}, session {session}")
        
        # Returns (None, None) when any stage raised; the error is printed there.
        epochs, features_df = process_subject_session(subject, session)
        
        if epochs is not None and features_df is not None:
            results[subject][session] = {
                "epochs_count": len(epochs),
                "features_count": len(features_df),
                "class_counts": features_df['class'].value_counts().to_dict()
            }

In [None]:
# Write a plain-text report of the batch run: one section per subject, one
# entry per session that was processed successfully (taken from `results`).
summary_file = os.path.join(output_dir, "processing_summary.txt")
with open(summary_file, 'w') as f:
    f.write("COG-BCI Processing Summary\n")
    f.write("========================\n\n")
    
    for subject, sessions_data in results.items():
        f.write(f"Subject {subject}:\n")
        
        # A subject with no successful session gets a header and no entries.
        for session, data in sessions_data.items():
            f.write(f"  Session {session}:\n")
            f.write(f"    Epochs: {data['epochs_count']}\n")
            f.write(f"    Features: {data['features_count']}\n")
            f.write(f"    Class distribution: {data['class_counts']}\n")
        
        f.write("\n")

print(f"Processing complete. Summary saved to {summary_file}")

In [None]:
# Transformer model for EEG classification
class TransformerEncoderBlock(nn.Module):
    """Post-norm transformer encoder layer: self-attention, then feed-forward.

    Both sub-layers are wrapped in a residual connection followed by
    ``LayerNorm``. Input and output are ``(batch, seq_len, embed_dim)``.

    Parameters
    ----------
    embed_dim : int
        Width of the token embeddings.
    num_heads : int
        Number of attention heads.
    ff_dim : int
        Hidden width of the feed-forward sub-layer.
    dropout : float
        Dropout probability used in attention, feed-forward and residual path.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        feed_forward_layers = [
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # The attention module is sequence-first: (seq_len, batch, features).
        seq_first = x.permute(1, 0, 2)
        attended = self.attention(seq_first, seq_first, seq_first)[0]
        attended = attended.permute(1, 0, 2)  # back to (batch, seq_len, features)

        # Residual + norm around attention, then around the feed-forward net.
        hidden = self.norm1(x + self.dropout(attended))
        return self.norm2(hidden + self.ff(hidden))

In [None]:
class EEGTransformer(nn.Module):
    """Transformer classifier over a flat feature vector.

    The ``input_dim`` features of a sample are zero-padded to
    ``seq_len * features_per_seq`` values and reshaped into ``seq_len`` tokens
    of ``features_per_seq`` values each. Tokens are linearly projected to
    ``embed_dim``, receive a learned positional embedding, pass through the
    encoder blocks, are mean-pooled over the sequence and classified.

    When ``input_dim < seq_len`` every feature becomes its own token
    (``seq_len = input_dim``, ``features_per_seq = 1``).

    Parameters
    ----------
    input_dim : int
        Number of features per sample.
    num_classes : int
        Number of output logits.
    seq_len : int
        Requested number of tokens.
    embed_dim, num_heads, ff_dim, num_transformer_blocks, dropout
        Encoder hyper-parameters.
    """

    def __init__(self, input_dim, num_classes, seq_len=10, embed_dim=64,
                 num_heads=4, ff_dim=128, num_transformer_blocks=2, dropout=0.2):
        super().__init__()

        self.input_dim = input_dim
        self.embed_dim = embed_dim

        if input_dim < seq_len:
            self.seq_len = input_dim
            self.features_per_seq = 1
        else:
            self.seq_len = seq_len
            full, remainder = divmod(input_dim, seq_len)
            # Round up so that every feature fits into some token.
            self.features_per_seq = full + (1 if remainder else 0)

        print(f"Creating model with input_dim={input_dim}, seq_len={self.seq_len}, "
              f"features_per_seq={self.features_per_seq}, embed_dim={embed_dim}")

        self.input_projection = nn.Linear(self.features_per_seq, embed_dim)

        # Positional embedding
        self.pos_embedding = nn.Parameter(torch.randn(1, self.seq_len, embed_dim))

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_transformer_blocks)
        ])

        # Global average pooling
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, num_classes)
        )

    def forward(self, x):
        """Classify a batch of shape ``(batch, n_features)``; returns logits.

        If ``n_features`` differs from ``input_dim`` the input is truncated or
        zero-padded to the expected width and a warning is printed once. The
        model's layers are never rebuilt here: replacing parameters during a
        forward pass would discard trained weights and create tensors the
        optimizer does not know about.
        """
        batch_size, n_features = x.shape[0], x.shape[1]

        if n_features != self.input_dim:
            if not getattr(self, '_warned_dim_mismatch', False):
                print(f"Warning: Input dimension mismatch. Expected {self.input_dim}, "
                      f"got {n_features}. Truncating/zero-padding the input.")
                self._warned_dim_mismatch = True
            # Too wide: drop the surplus. Too narrow: handled by the padding below.
            x = x[:, :self.input_dim]

        # Pad up to a whole number of tokens; new_zeros keeps dtype and device.
        missing = self.seq_len * self.features_per_seq - x.shape[1]
        if missing > 0:
            x = torch.cat([x, x.new_zeros(batch_size, missing)], dim=1)

        x = x.reshape(batch_size, self.seq_len, self.features_per_seq)

        x = self.input_projection(x)  # (batch, seq_len, embed_dim)

        # Add positional embeddings
        x = x + self.pos_embedding

        # Apply transformer blocks
        for block in self.transformer_blocks:
            x = block(x)

        # Global pooling across sequence dimension
        x = x.permute(0, 2, 1)  # (batch, embed_dim, seq_len)
        x = self.global_pool(x).squeeze(-1)  # (batch, embed_dim)

        # Classification
        return self.classifier(x)

In [None]:
class EEGDataset(Dataset):
    """
    Dataset for EEG features

    Parameters
    ----------
    X : array-like
        Feature matrix, converted to a float tensor; one row per sample.
    y : array-like | None
        Integer class labels, converted to a long tensor. When None the
        dataset yields features only.
    transform : callable | None
        Optional function applied to each feature tensor when it is fetched.
    """
    def __init__(self, X, y=None, transform=None):
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y) if y is not None else None
        self.transform = transform

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        # The transform was previously stored but never used.
        if self.transform is not None:
            sample = self.transform(sample)

        if self.y is None:
            return sample
        return sample, self.y[idx]

In [None]:
def load_features(feature_dir, subjects=None):
    """Load every ``features.csv`` below ``feature_dir`` into stacked arrays.

    Files are expected at
    ``<feature_dir>/sub-<subject>/ses-<session>/features/features.csv`` and
    are visited in sorted order so that repeated runs return the samples in
    the same order. Each file is read once.

    Files without a ``class`` column or without rows are skipped; files that
    cannot be read or parsed are reported and skipped. All columns except
    ``label`` and ``class`` are features. When files disagree on the number
    of feature columns, the most common count wins: wider files are truncated
    to their first columns and narrower files are zero-padded on the right.

    Parameters
    ----------
    feature_dir : str
        Root directory of the processed data.
    subjects : list of str | None
        Subject identifiers to load; None loads every ``sub-*`` directory.

    Returns
    -------
    dict | None
        ``{'X', 'y', 'subjects', 'sessions'}`` as NumPy arrays with one entry
        per sample, or None when nothing could be loaded.
    """
    session_pattern = os.path.join("ses-*", "features", "features.csv")
    if subjects is None:
        subject_dirs = ["sub-*"]
    else:
        subject_dirs = [f"sub-{subject}" for subject in subjects]

    feature_files = []
    for subject_dir in subject_dirs:
        pattern = os.path.join(feature_dir, subject_dir, session_pattern)
        # glob order is filesystem dependent; sort for reproducibility.
        feature_files.extend(sorted(glob.glob(pattern)))

    print(f"Found {len(feature_files)} feature files")

    loaded = []
    for file_path in feature_files:
        try:
            df = pd.read_csv(file_path)

            if 'class' not in df.columns or df.shape[0] == 0:
                print(f"Skipping {file_path}: No class column or empty data")
                continue

            # The layout is fixed by the glob pattern, so take subject and
            # session from the two directories above "features/" instead of
            # scanning the whole path (the root itself may contain a
            # component that starts with "sub-" or "ses-").
            session_path = os.path.dirname(os.path.dirname(os.path.normpath(file_path)))
            subject_path = os.path.dirname(session_path)
            session = os.path.basename(session_path).split('-', 1)[1]
            subject = os.path.basename(subject_path).split('-', 1)[1]

            print(f"Loading {file_path}: {df.shape[0]} samples, classes: {df['class'].unique()}")

            loaded.append({
                'path': file_path,
                'subject': subject,
                'session': session,
                'X': df.drop(['label', 'class'], axis=1).values,
                'y': df['class'].values,
            })
        except Exception as e:
            print(f"Error loading {file_path}: {str(e)}")

    if not loaded:
        print("No valid feature files found")
        return None

    feature_dim_counts = pd.Series([item['X'].shape[1] for item in loaded]).value_counts()
    most_common_dim = feature_dim_counts.index[0]

    print(f"Feature dimensions across files: {dict(feature_dim_counts)}")
    print(f"Using most common dimension: {most_common_dim} features")

    all_features = []
    all_labels = []
    all_subjects = []
    all_sessions = []

    for item in loaded:
        X = item['X']
        y = item['y']

        if X.shape[1] != most_common_dim:
            print(f"  Warning: Feature dimension mismatch in {item['path']}: "
                  f"{X.shape[1]} vs {most_common_dim}")

            if X.shape[1] > most_common_dim:
                # Too many features - select first most_common_dim features
                print(f"  Selecting first {most_common_dim} features")
                X = X[:, :most_common_dim]
            else:
                # Too few features - pad with zeros
                print(f"  Padding with {most_common_dim - X.shape[1]} zeros")
                padding = np.zeros((X.shape[0], most_common_dim - X.shape[1]))
                X = np.hstack((X, padding))

        all_features.append(X)
        all_labels.append(y)
        all_subjects.extend([item['subject']] * len(y))
        all_sessions.extend([item['session']] * len(y))

    features_data = {
        'X': np.vstack(all_features),
        'y': np.concatenate(all_labels),
        'subjects': np.array(all_subjects),
        'sessions': np.array(all_sessions)
    }

    print(f"Loaded {features_data['X'].shape[0]} samples with {features_data['X'].shape[1]} features")

    return features_data

In [None]:
def prepare_data_for_model(features_data):
    """Standardize the feature matrix and integer-encode the class labels.

    NOTE(review): the scaler is fitted on all samples passed in and is not
    returned; if the caller splits into train/test afterwards, test
    statistics leak into the scaling - confirm this is intended.

    Parameters
    ----------
    features_data : dict
        Mapping with the keys ``'X'``, ``'y'`` and ``'subjects'``.

    Returns
    -------
    tuple
        ``(X_scaled, y, subjects, label_encoder)``: the standardized features,
        the encoded labels, the subject array unchanged, and the fitted
        ``LabelEncoder``.
    """
    raw_features = features_data['X']
    class_names = features_data['y']
    subject_ids = features_data['subjects']

    print("Standardizing features")
    X_scaled = StandardScaler().fit_transform(raw_features)

    print("Encoding labels")
    label_encoder = LabelEncoder()
    encoded = label_encoder.fit_transform(class_names)

    print(f"Classes: {label_encoder.classes_}")

    return X_scaled, encoded, subject_ids, label_encoder

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

    Returns ``(mean loss per sample, accuracy)``. Raises ValueError when the
    loader yields no samples, because both metrics would be undefined.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    correct = 0
    seen = 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            # Weight by batch size so that a smaller last batch is averaged correctly.
            total_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            seen += labels.size(0)

    if seen == 0:
        phase = "train_loader" if training else "val_loader"
        raise ValueError(f"{phase} produced no samples")

    return total_loss / seen, correct / seen


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
               num_epochs=100, device='cuda', patience=15, verbose=True):
    """Train ``model`` with early stopping on the validation loss.

    After every epoch the scheduler is advanced (``ReduceLROnPlateau`` receives
    the validation loss, any other scheduler is stepped without arguments,
    ``None`` is allowed). Training stops once the validation loss has not
    improved for ``patience`` consecutive epochs, and the weights of the best
    epoch are restored before returning.

    Parameters
    ----------
    model : torch.nn.Module
        Model to train; the caller is responsible for placing it on ``device``.
    train_loader, val_loader : iterable of ``(inputs, labels)`` batches.
    criterion : callable
        Loss function ``criterion(outputs, labels)``.
    optimizer : torch.optim.Optimizer
    scheduler : learning-rate scheduler or None
    num_epochs : int
        Maximum number of epochs.
    device : str or torch.device
        Device the batches are moved to.
    patience : int
        Number of non-improving epochs tolerated before stopping.
    verbose : bool
        Print per-epoch metrics.

    Returns
    -------
    tuple
        ``(model, history)`` where ``history`` has the keys ``train_loss``,
        ``val_loss``, ``train_acc`` and ``val_acc`` (one value per epoch run).

    Raises
    ------
    ValueError
        If either loader yields no samples.
    """
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': []
    }

    # Early stopping state
    best_val_loss = float('inf')
    best_model_wts = copy.deepcopy(model.state_dict())
    no_improve = 0

    for epoch in range(num_epochs):
        if verbose:
            print(f'Epoch {epoch+1}/{num_epochs}')
            print('-' * 10)

        epoch_train_loss, epoch_train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer)
        history['train_loss'].append(epoch_train_loss)
        history['train_acc'].append(epoch_train_acc)

        epoch_val_loss, epoch_val_acc = _run_epoch(
            model, val_loader, criterion, device)
        history['val_loss'].append(epoch_val_loss)
        history['val_acc'].append(epoch_val_acc)

        if verbose:
            print(f'Train loss: {epoch_train_loss:.4f}, acc: {epoch_train_acc:.4f}')
            print(f'Val loss: {epoch_val_loss:.4f}, acc: {epoch_val_acc:.4f}')

        # Update learning rate. Only ReduceLROnPlateau takes a metric; passing
        # a value to other schedulers would be interpreted as an epoch index.
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(epoch_val_loss)
            else:
                scheduler.step()

        # Early stopping
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                if verbose:
                    print(f'Early stopping at epoch {epoch+1}')
                break

    # Load best model weights
    model.load_state_dict(best_model_wts)

    return model, history

In [None]:
def evaluate_model(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and compute classification metrics.

    Args:
        model: Network whose forward pass returns per-class scores with the
            class axis at dimension 1 (predictions are the argmax over it).
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        dict with keys ``'accuracy'``, ``'f1_score'`` (weighted average),
        ``'confusion_matrix'``, ``'y_true'`` and ``'y_pred'``. The confusion
        matrix is tabulated over ``0 .. n_outputs - 1`` whenever every observed
        label fits that range, so its shape does not depend on which classes
        happen to occur in this particular test set.
    """
    model.eval()
    true_batches = []
    pred_batches = []
    num_outputs = None

    with torch.no_grad():
        for inputs, labels in test_loader:
            # Only the inputs have to live on the model's device; labels are
            # collected on the CPU anyway.
            outputs = model(inputs.to(device))
            _, predicted = torch.max(outputs, 1)
            num_outputs = outputs.shape[1]

            true_batches.append(labels.cpu().numpy())
            pred_batches.append(predicted.cpu().numpy())

    # np.concatenate rejects an empty list, so fall back to an empty array.
    y_true = np.concatenate(true_batches) if true_batches else np.array([])
    y_pred = np.concatenate(pred_batches) if pred_batches else np.array([])

    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='weighted')

    # Without an explicit label set, confusion_matrix only covers labels seen
    # in y_true/y_pred, so a test set missing a class yields a smaller matrix
    # that cannot be averaged or plotted against the full class list.
    label_set = None
    if num_outputs is not None:
        observed = np.concatenate([np.ravel(y_true), np.ravel(y_pred)])
        if observed.size and observed.min() >= 0 and observed.max() < num_outputs:
            label_set = np.arange(num_outputs)
    conf_matrix = confusion_matrix(y_true, y_pred, labels=label_set)

    metrics = {
        'accuracy': accuracy,
        'f1_score': f1,
        'confusion_matrix': conf_matrix,
        'y_true': y_true,
        'y_pred': y_pred
    }

    return metrics

In [None]:
def plot_confusion_matrix(cm, class_names, title='Confusion Matrix', cmap=plt.cm.Blues):
    """Draw an annotated heatmap of a confusion matrix and return its figure.

    Args:
        cm: Confusion matrix as a NumPy array (its ``dtype`` selects the
            annotation format: two decimals for floats, integers otherwise).
        class_names: Tick labels used on both axes.
        title: Title placed above the heatmap.
        cmap: Colormap handed to seaborn.

    Returns:
        The matplotlib figure holding the heatmap.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    # Floating-point matrices are assumed to be normalized, so show fractions.
    if np.issubdtype(cm.dtype, np.floating):
        annotation_format = '.2f'
    else:
        annotation_format = 'd'

    sns.heatmap(
        cm,
        annot=True,
        fmt=annotation_format,
        cmap=cmap,
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()
    return fig

In [None]:
def plot_learning_curves(history):
    """Plot training vs. validation loss and accuracy side by side.

    Args:
        history: Mapping with per-epoch sequences under the keys
            ``'train_loss'``, ``'val_loss'``, ``'train_acc'`` and ``'val_acc'``.

    Returns:
        The matplotlib figure containing both panels.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # (axes, training key, validation key, panel title / y-axis label)
    panels = (
        (loss_ax, 'train_loss', 'val_loss', 'Loss'),
        (acc_ax, 'train_acc', 'val_acc', 'Accuracy'),
    )
    for ax, train_key, val_key, caption in panels:
        ax.plot(history[train_key], label='Train')
        ax.plot(history[val_key], label='Validation')
        ax.set_title(caption)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(caption)
        ax.legend()

    fig.tight_layout()
    return fig

In [None]:
def _append_log(log_file, text):
    """Append ``text`` verbatim to the plain-text run log at ``log_file``."""
    with open(log_file, 'a') as f:
        f.write(text)


def _mean_row_normalized_confusion(conf_matrices):
    """Average row-normalized confusion matrices over folds.

    Each matrix is normalized per true-label row. A row with no samples in a
    given fold is left out of that row's average (instead of being counted as
    zeros), so every row is averaged only over the folds where that class
    actually occurred. All matrices must share the same shape.
    """
    stacked = np.stack([np.asarray(cm, dtype=float) for cm in conf_matrices])
    row_totals = stacked.sum(axis=2, keepdims=True)
    has_samples = row_totals > 0
    normalized = np.divide(stacked, row_totals, out=np.zeros_like(stacked), where=has_samples)
    folds_per_row = has_samples.sum(axis=0)
    return np.divide(
        normalized.sum(axis=0),
        folds_per_row,
        out=np.zeros(stacked.shape[1:], dtype=float),
        where=folds_per_row > 0,
    )


def train_with_subject_transfer(X, y, subjects, label_encoder, output_dir, 
                               device='cuda', batch_size=32, num_epochs=100, patience=15):
    """Train and evaluate an EEGTransformer with leave-one-subject-out CV.

    For every unique subject, a fresh model is trained on all other subjects
    (with a small random validation split carved out of the training data) and
    tested on the held-out subject. Per-fold artifacts (model weights, learning
    curves, history, metrics, confusion matrix) go to a fold sub-directory of
    ``output_dir``; a text log, overall JSON summary and summary plots go to
    ``output_dir`` itself.

    Args:
        X: 2-D feature array, one row per sample (``X.shape[1]`` is used as the
            model input dimension).
        y: Per-sample class labels, indexable with integer index arrays.
        subjects: Per-sample subject identifiers used to form the folds.
        label_encoder: Fitted encoder; only ``classes_`` is read (for logging
            and as confusion-matrix tick labels).
        output_dir: Directory that receives all logs and artifacts.
        device: Device the model is trained and evaluated on.
        batch_size: Upper bound for the batch size (reduced for small splits).
        num_epochs: Maximum number of training epochs per fold.
        patience: Early-stopping patience forwarded to ``train_model``.

    Returns:
        dict with parallel lists under ``'accuracy'``, ``'f1_score'``,
        ``'subject'``, ``'confusion_matrix'`` and ``'history'``, one entry per
        successfully completed fold (all empty if every fold failed).
    """
    # The fold masks rely on element-wise comparison; a plain list would
    # compare as a whole and silently produce empty splits.
    subjects = np.asarray(subjects)

    # Initialize results dictionary
    results = {
        'accuracy': [],
        'f1_score': [],
        'subject': [],
        'confusion_matrix': [],
        'history': []
    }

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Create log file
    log_file = os.path.join(output_dir, 'training_log.txt')
    with open(log_file, 'w') as f:
        f.write(f"Training started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Features shape: {X.shape}\n")
        f.write(f"Classes: {label_encoder.classes_}\n")
        f.write(f"Subjects: {np.unique(subjects)}\n")
        f.write("-" * 50 + "\n\n")

    # Iterate over subject IDs explicitly so it is always clear which subject
    # is held out in each fold.
    unique_subjects = np.unique(subjects)

    # Fold-invariant values. The full label set is fixed up front so that every
    # fold's confusion matrix has the same shape even when the held-out subject
    # lacks some class.
    class_labels = np.unique(y)
    num_classes = len(class_labels)
    class_names = label_encoder.classes_

    # Run cross-validation
    for fold, test_subject in enumerate(unique_subjects):
        print(f"\n=== Fold {fold+1}: Testing on subject {test_subject} ===")

        # Manually create train/test split
        test_idx = np.where(subjects == test_subject)[0]
        train_idx = np.where(subjects != test_subject)[0]

        _append_log(log_file, f"Fold {fold+1}: Testing on subject {test_subject}\n")

        # Create fold directory
        fold_dir = os.path.join(output_dir, f"fold_{fold+1}_subject_{test_subject}")
        os.makedirs(fold_dir, exist_ok=True)

        # Split data
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        # Check if we have enough samples
        if len(X_train) < 10 or len(X_test) < 5:
            print(f"Warning: Not enough samples for subject {test_subject}. Skipping.")
            _append_log(
                log_file,
                f"  Warning: Not enough samples for subject {test_subject}. Skipping.\n\n"
            )
            continue

        _append_log(
            log_file,
            f"  Input dimensions - X_train: {X_train.shape}, X_test: {X_test.shape}\n"
        )

        # Split training data into train and validation.
        # Reseeding NumPy's global RNG makes the split depend only on the
        # training-set size, so it is reproducible across runs.
        np.random.seed(42)
        val_size = min(int(0.2 * len(X_train)), 50)  # Cap validation size
        val_idx = np.random.choice(len(X_train), val_size, replace=False)
        train_mask = np.ones(len(X_train), dtype=bool)
        train_mask[val_idx] = False

        X_val = X_train[val_idx]
        y_val = y_train[val_idx]
        X_train_final = X_train[train_mask]
        y_train_final = y_train[train_mask]

        _append_log(
            log_file,
            f"  Training samples: {len(X_train_final)}\n"
            f"  Validation samples: {len(X_val)}\n"
            f"  Test samples: {len(X_test)}\n"
        )

        # Create data loaders
        train_dataset = EEGDataset(X_train_final, y_train_final)
        val_dataset = EEGDataset(X_val, y_val)
        test_dataset = EEGDataset(X_test, y_test)

        # Adjust batch size if needed
        actual_batch_size = min(batch_size, len(X_train_final), len(X_val), len(X_test))
        if actual_batch_size < batch_size:
            print(f"Warning: Reducing batch size to {actual_batch_size} due to small dataset")

        train_loader = DataLoader(train_dataset, batch_size=actual_batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=actual_batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=actual_batch_size, shuffle=False)

        # Create model
        input_dim = X_train.shape[1]

        # Define sequence length based on input dimension
        seq_len = min(10, max(1, input_dim // 64))

        model = EEGTransformer(
            input_dim=input_dim, 
            num_classes=num_classes,
            seq_len=seq_len,
            embed_dim=64,
            num_heads=4,
            ff_dim=128,
            num_transformer_blocks=2,
            dropout=0.2
        ).to(device)

        _append_log(
            log_file,
            f"  Model input_dim: {input_dim}\n"
            f"  Model seq_len: {seq_len}\n"
            f"  Model num_classes: {num_classes}\n"
        )

        # Set up loss function, optimizer and scheduler
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases;
        # confirm it is still accepted by the installed version.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-5, verbose=True
        )

        try:
            # Train model
            model, history = train_model(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                criterion=criterion,
                optimizer=optimizer,
                scheduler=scheduler,
                num_epochs=num_epochs,
                device=device,
                patience=patience,
                verbose=True
            )

            # Save model
            model_path = os.path.join(fold_dir, "model.pt")
            torch.save(model.state_dict(), model_path)

            # Plot and save learning curves
            fig = plot_learning_curves(history)
            fig.savefig(os.path.join(fold_dir, "learning_curves.png"))
            plt.close(fig)

            # Save training history
            with open(os.path.join(fold_dir, "history.json"), 'w') as f:
                # Convert numpy scalars to plain floats for JSON
                history_json = {k: [float(val) for val in v] for k, v in history.items()}
                json.dump(history_json, f, indent=4)

            # Evaluate model
            metrics = evaluate_model(model, test_loader, device)

            # Re-tabulate against the full label set: a held-out subject that
            # lacks a class would otherwise yield a smaller matrix that neither
            # matches `class_names` nor can be averaged with the other folds.
            conf_matrix = confusion_matrix(
                metrics['y_true'], metrics['y_pred'], labels=class_labels
            )

            _append_log(
                log_file,
                f"  Test accuracy: {metrics['accuracy']:.4f}\n"
                f"  Test F1 score: {metrics['f1_score']:.4f}\n\n"
            )

            # Plot and save confusion matrix
            fig = plot_confusion_matrix(conf_matrix, class_names)
            fig.savefig(os.path.join(fold_dir, "confusion_matrix.png"))
            plt.close(fig)

            # Save metrics
            with open(os.path.join(fold_dir, "metrics.json"), 'w') as f:
                metrics_json = {
                    'accuracy': float(metrics['accuracy']),
                    'f1_score': float(metrics['f1_score']),
                    'confusion_matrix': conf_matrix.tolist(),
                    'class_names': class_names.tolist()
                }
                json.dump(metrics_json, f, indent=4)

            # Store results
            results['accuracy'].append(metrics['accuracy'])
            results['f1_score'].append(metrics['f1_score'])
            results['subject'].append(test_subject)
            results['confusion_matrix'].append(conf_matrix)
            results['history'].append(history)

        except Exception as e:
            # Best effort: one failing fold must not abort the whole run, but
            # the failure is reported on stdout and in the log.
            print(f"Error in fold {fold+1} (subject {test_subject}): {str(e)}")
            import traceback
            traceback.print_exc()

            _append_log(log_file, f"  Error: {str(e)}\n\n")

            # Continue with next fold
            continue

    # Check if we have any results
    if len(results['accuracy']) == 0:
        print("No successful folds completed. Check the logs for errors.")
        return results

    # Calculate average results
    mean_accuracy = np.mean(results['accuracy'])
    std_accuracy = np.std(results['accuracy'])
    mean_f1 = np.mean(results['f1_score'])
    std_f1 = np.std(results['f1_score'])

    # Log overall results
    with open(log_file, 'a') as f:
        f.write("\n=== Overall Results ===\n")
        f.write(f"Mean accuracy: {mean_accuracy:.4f} ± {std_accuracy:.4f}\n")
        f.write(f"Mean F1 score: {mean_f1:.4f} ± {std_f1:.4f}\n\n")

        f.write("Per-subject results:\n")
        for subj, acc, f1 in zip(results['subject'], results['accuracy'], results['f1_score']):
            f.write(f"  Subject {subj}: "
                    f"Accuracy={acc:.4f}, "
                    f"F1={f1:.4f}\n")

    # Save overall results
    with open(os.path.join(output_dir, "overall_results.json"), 'w') as f:
        overall_results = {
            'mean_accuracy': float(mean_accuracy),
            'std_accuracy': float(std_accuracy),
            'mean_f1': float(mean_f1),
            'std_f1': float(std_f1),
            'per_subject': [
                {
                    'subject': str(subj),
                    'accuracy': float(acc),
                    'f1_score': float(f1)
                }
                for subj, acc, f1 in zip(results['subject'], results['accuracy'], results['f1_score'])
            ]
        }
        json.dump(overall_results, f, indent=4)

    # Plot and save overall results
    plt.figure(figsize=(10, 6))

    # Bar plot of accuracy by subject
    plt.bar(range(len(results['subject'])), results['accuracy'])
    plt.axhline(y=mean_accuracy, color='r', linestyle='--', label=f'Mean accuracy: {mean_accuracy:.4f}')

    plt.title('Accuracy by Subject (Leave-One-Subject-Out CV)')
    plt.xlabel('Subject')
    plt.ylabel('Accuracy')
    plt.ylim(0, 1)
    plt.xticks(range(len(results['subject'])), results['subject'])
    plt.legend()
    plt.tight_layout()

    plt.savefig(os.path.join(output_dir, "overall_accuracy.png"))
    plt.close()

    # Average of the per-fold, row-normalized confusion matrices. Rows that are
    # empty in a fold are excluded from that row's mean instead of dragging it
    # toward zero.
    avg_conf_matrix = _mean_row_normalized_confusion(results['confusion_matrix'])

    fig = plot_confusion_matrix(avg_conf_matrix, class_names, title='Average Normalized Confusion Matrix')
    fig.savefig(os.path.join(output_dir, "average_confusion_matrix.png"))
    plt.close(fig)

    print("\n=== Overall Results ===")
    print(f"Mean accuracy: {mean_accuracy:.4f} ± {std_accuracy:.4f}")
    print(f"Mean F1 score: {mean_f1:.4f} ± {std_f1:.4f}")

    return results

In [None]:
def main(feature_dir, output_dir, subjects=None):
    """Load saved features, prepare them and run subject-transfer training.

    Args:
        feature_dir: Directory handed to ``load_features``.
        output_dir: Directory that receives the label-encoder dump and all
            training artifacts; created if missing.
        subjects: Optional subject filter forwarded to ``load_features``.

    Returns:
        None. Exits early (after printing a message) when no features load.
    """
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(backend)
    print(f"Using device: {device}")

    os.makedirs(output_dir, exist_ok=True)

    print("Loading features")
    loaded = load_features(feature_dir, subjects)
    nothing_loaded = loaded is None or len(loaded['X']) == 0
    if nothing_loaded:
        print("No features found. Exiting.")
        return

    print("Preparing data for model")
    X, y, sample_subjects, label_encoder = prepare_data_for_model(loaded)

    # Persist the class <-> index mapping next to the results.
    encoder_path = os.path.join(output_dir, "label_encoder.json")
    with open(encoder_path, 'w') as f:
        encoder_summary = {
            'classes': label_encoder.classes_.tolist(),
            'transform': label_encoder.transform(label_encoder.classes_).tolist()
        }
        json.dump(encoder_summary, f, indent=4)

    print("Training with subject transfer")
    training_options = dict(
        device=device,
        batch_size=32,
        num_epochs=100,
        patience=15,
    )
    train_with_subject_transfer(
        X=X,
        y=y,
        subjects=sample_subjects,
        label_encoder=label_encoder,
        output_dir=output_dir,
        **training_options
    )

    print("Training complete!")
    print(f"Results saved to {output_dir}")

In [None]:
# Entry-point cell: runs the full training pipeline through main().
# NOTE: `output_dir` and `subjects` rebind the module-level names assigned in
# the configuration cell near the top of the notebook (where `output_dir` was
# 'processed_data' and `subjects` was the list of subject IDs). Any cell run
# after this one sees the values below; `feature_dir` points at that earlier
# 'processed_data' location.
feature_dir = "processed_data"
output_dir = "transformer_results"
subjects = None  # Process all subjects


# Return value is None, so nothing is echoed as cell output.
main(feature_dir, output_dir, subjects)