In [1]:
import numpy as np
import pandas as pd
import json
import os
import mne
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import time # To measure time

from scipy.fft import fft
from scipy.signal import detrend, butter, filtfilt
import pywt


from scipy.signal import resample
from collections import Counter


from skimage.transform import resize
from skimage import img_as_float, img_as_ubyte

# Import

In [2]:
# Load subject labels
with open("model-data/Labels_epochs.json", "r") as f:
    subject_labels = json.load(f)

def load_data(directory):
    """
    Loads and returns augmented EEG data (X) and corresponding labels (y) from the specified directory.
    """
    X, y = [], []
    for file in os.listdir(directory):
        if file.endswith(".npy"):
            # Load the features (EEG data)
            X_data = np.load(os.path.join(directory, file))
            # Load corresponding label (subject ID matches file naming convention)
            subject_id = file.split("_")[0]
            label = subject_labels.get(subject_id, None)
            if label is not None:
                X.append(X_data)
                y.append(label)

    return np.array(X), np.array(y)


# Preprocessing

## model 1

In [11]:
def calculate_td_psd_features(epoch_data, fs, power_lambda=0.1, epsilon=1e-9):
    """
    Calculates the 7 TD-PSD features for a single EEG epoch (single channel).
    Based on equations in Amini et al., 2021.

    Args:
        epoch_data (np.ndarray): 1D numpy array for a single channel epoch.
        fs (float): Sampling frequency of the epoch data.
        power_lambda (float): Lambda for power transform normalization.
        epsilon (float): Small value to prevent log(0) or division by zero.

    Returns:
        np.ndarray: Array containing the 7 log-transformed TD-PSD features.
                    Returns NaNs if calculation fails.
    """
    n_samples = len(epoch_data)
    if n_samples == 0:
        return np.full(7, np.nan)

    # Detrend the signal (optional but often good practice)
    signal = detrend(epoch_data)

    # 1. Calculate Power Spectrum and Moments
    try:
        # FFT
        X = fft(signal)
        # Power Spectrum (One-sided, ignoring DC for moments perhaps?)
        # Frequencies for moments k: corresponds to frequency bins
        freqs = np.fft.fftfreq(n_samples, 1/fs)
        # Power spectrum P[k] = |X[k]|^2 / N
        P = np.abs(X)**2 / n_samples

        # Calculate moments m0, m2, m4
        # m_n = sum(f^n * P(f)) df - approximated by sum(k^n * P[k])
        # We use the magnitude of frequencies for k, ignore negative freqs?
        # Let's use Hjorth parameters definition based on time-domain variance
        # m0 = variance(signal) = total power (approx)
        m0_bar = np.sum(signal**2) / n_samples # Variance = mean square if mean is zero
        if m0_bar < epsilon: m0_bar = epsilon

        # m2 = variance of first derivative (activity)
        delta_x = np.diff(signal, n=1) * fs # Scale by fs? Hjorth doesn't explicitly scale by fs
        m2_bar = np.sum(delta_x**2) / (n_samples -1) # Use n_samples-1?
        if m2_bar < epsilon: m2_bar = epsilon


        # m4 = variance of second derivative (mobility)
        delta2_x = np.diff(signal, n=2) * (fs**2) # Scale by fs^2?
        m4_bar = np.sum(delta2_x**2) / (n_samples -2)
        if m4_bar < epsilon: m4_bar = epsilon


        # Apply power transform (Box-Cox with lambda=0 is log, this is slightly different)
        m0 = (m0_bar**power_lambda - 1) / power_lambda if power_lambda != 0 else np.log(m0_bar)
        m2 = (m2_bar**power_lambda - 1) / power_lambda if power_lambda != 0 else np.log(m2_bar)
        m4 = (m4_bar**power_lambda - 1) / power_lambda if power_lambda != 0 else np.log(m4_bar)

        # Ensure moments are positive after transform for log
        m0 = max(m0, epsilon)
        m2 = max(m2, epsilon)
        m4 = max(m4, epsilon)


    except Exception as e:
        print(f"Error calculating moments: {e}")
        return np.full(7, np.nan)

    features = np.zeros(7)

    # 2. Calculate Features f1, f2, f3
    try:
        features[0] = np.log(m0) # f1 = log(m0)
        # Check for valid subtractions
        if m0 <= m2: m0 = m2 + epsilon
        if m0 <= m4: m0 = m4 + epsilon
        features[1] = np.log(m0 - m2) # f2 = log(m0 - m2)
        features[2] = np.log(m0 - m4) # f3 = log(m0 - m4)

    except Exception as e:
         print(f"Error calculating f1, f2, f3: {e}")
         features[:3] = np.nan


    # 3. Calculate Feature f4 (Sparseness)
    try:
        denominator_sqrt = np.sqrt(max(m0 - m2, epsilon)) * np.sqrt(max(m0 - m4, epsilon))
        if denominator_sqrt < epsilon: denominator_sqrt = epsilon
        features[3] = np.log(m0 / denominator_sqrt) # f4 = log(S) = log(m0 / sqrt((m0-m2)(m0-m4)))
    except Exception as e:
         print(f"Error calculating f4 (Sparseness): {e}")
         features[3] = np.nan

    # 4. Calculate Feature f5 (Irregularity Factor - IF)
    # IF = (m4/m2) / (m2/m0) based on Hjorth parameters 'complexity'
    # Paper formula: sqrt(m4/m2) / sqrt(m2/m0) => m0*m4 / m2^2
    try:
        if m2 < epsilon: m2 = epsilon
        if_val = (m0 * m4) / (m2**2)
        features[4] = np.log(max(if_val, epsilon)) # f5 = log(IF)
    except Exception as e:
        print(f"Error calculating f5 (IF): {e}")
        features[4] = np.nan

    # 5. Calculate Feature f6 (Covariance - COV)
    # COV = std_dev / mean
    try:
        mean_val = np.mean(signal)
        std_dev_val = np.std(signal)
        if abs(mean_val) < epsilon: mean_val = np.sign(mean_val) * epsilon if mean_val != 0 else epsilon
        cov_val = std_dev_val / mean_val
        features[5] = np.log(max(abs(cov_val), epsilon)) # Log of magnitude? Paper isn't explicit if COV can be negative. Let's take abs.
    except Exception as e:
        print(f"Error calculating f6 (COV): {e}")
        features[5] = np.nan


    # 6. Calculate Feature f7 (Teager Energy Operator - TEO)
    try:
        # TEO(x[j]) = x[j]^2 - x[j-1]x[j+1]
        # Need to handle boundaries (pad or slice)
        teo_vals = signal[1:-1]**2 - signal[:-2] * signal[2:]
        sum_teo = np.sum(teo_vals)
        features[6] = np.log(max(abs(sum_teo), epsilon)) # Log of magnitude? Sum can be negative. Paper isn't explicit. Taking abs.
    except Exception as e:
        print(f"Error calculating f7 (TEO): {e}")
        features[6] = np.nan

    return features


def preprocess_amini(eeg_data, fs, target_fs=256, target_duration_sec=180, epoch_len=256):
    """
    Preprocesses EEG data according to Amini et al. (2021).
    """
    # 1. Select time window
    n_channels, n_timesteps = eeg_data.shape
    start_sample = int(60 * fs)
    end_sample = start_sample + int(target_duration_sec * fs)
    eeg_segment = eeg_data[:, start_sample:end_sample]

    # 2. Downsample
    if fs != target_fs:
        num_samples_resampled = int(eeg_segment.shape[1] * (target_fs / fs))
        eeg_resampled = resample(eeg_segment, num_samples_resampled, axis=1)
    else:
        eeg_resampled = eeg_segment

    # 3. Segment into epochs and calculate features
    n_channels_res, n_timesteps_res = eeg_resampled.shape
    num_epochs = n_timesteps_res // epoch_len
    all_channel_features = []

    for i_ch in range(n_channels_res):
        channel_data = eeg_resampled[i_ch, :]
        epoch_features_list = []
        for i_epoch in range(num_epochs):
            start = i_epoch * epoch_len
            end = start + epoch_len
            epoch = channel_data[start:end]
            features = calculate_td_psd_features(epoch, fs)
            if not np.isnan(features).all():
                epoch_features_list.append(features)

        # Average features across valid epochs for the channel
        avg_features = np.nanmean(np.array(epoch_features_list), axis=0) if epoch_features_list else np.full(7, np.nan)
        all_channel_features.append(avg_features)

    return np.array(all_channel_features)  # shape: (n_channels, 7)

## Model 2

In [8]:
def get_cwt_scales(fs, f_min=1, f_max=100, num_scales=128, wavelet='morl'):
    """Helper to get CWT scales corresponding to a frequency range."""
    wname = wavelet
    central_freq = pywt.central_frequency(wname)
    # Formula: scale = central_frequency * sampling_period / desired_frequency
    sampling_period = 1.0 / fs
    scales = central_freq * sampling_period / (np.logspace(np.log10(f_min), np.log10(f_max), num_scales))[::-1]
    # Frequencies corresponding to these scales (for verification)
    # frequencies = pywt.scale2frequency(wname, scales) / sampling_period
    # print(f"Generated {len(scales)} scales for freqs approx {frequencies.min()}-{frequencies.max()} Hz")
    return scales


def preprocess_acharya(eeg_data, fs, epoch_sec=5, target_size=(224, 224), wavelet='morl', scales=None):
    """
    Preprocesses raw EEG data according to Acharya et al. (2025).
    Segments data, performs CWT for each channel in each epoch, averages CWT magnitudes
    across channels, resizes to target image size, and converts to 3-channel image.

    Args:
        eeg_data (np.ndarray): Raw EEG data (n_channels, n_timesteps).
        fs (float): Original sampling frequency.
        epoch_sec (int): Duration of each epoch in seconds (default: 5s).
        target_size (tuple): Target image size (height, width) (default: (224, 224)).
        wavelet (str): Wavelet to use for CWT (default: 'morl').
        scales (np.ndarray, optional): Scales to use for CWT. If None, calculated for 1-100 Hz.

    Returns:
        list[np.ndarray]: A list of processed images (one for each epoch).
                         Each image is a numpy array of shape (height, width, 3).
                         Returns empty list if error or no full epochs.
    """
    n_channels, n_timesteps = eeg_data.shape
    samples_per_epoch = int(epoch_sec * fs)
    num_epochs = n_timesteps // samples_per_epoch

    if num_epochs == 0:
        print(f"Error: Data duration ({n_timesteps/fs:.2f}s) is less than epoch duration ({epoch_sec}s). Cannot create epochs.")
        return []

    if scales is None:
        scales = get_cwt_scales(fs, f_min=1, f_max=100, num_scales=128, wavelet=wavelet) # Example scale selection

    processed_images = []

    for i_epoch in range(num_epochs):
        start = i_epoch * samples_per_epoch
        end = start + samples_per_epoch
        epoch_data = eeg_data[:, start:end] # Shape: (n_channels, samples_per_epoch)

        all_channel_coeffs = []
        valid_channel_count = 0
        for i_ch in range(n_channels):
            try:
                # Perform CWT for the channel
                coeffs, _ = pywt.cwt(epoch_data[i_ch, :], scales, wavelet)
                # Take the magnitude (absolute value) of coefficients
                all_channel_coeffs.append(np.abs(coeffs)) # Shape: (num_scales, samples_per_epoch)
                valid_channel_count += 1
            except Exception as e:
                print(f"Warning: CWT failed for channel {i_ch}, epoch {i_epoch}: {e}")
                # Optionally append NaNs or zeros if needed, here we just skip

        if valid_channel_count == 0:
            print(f"Warning: CWT failed for all channels in epoch {i_epoch}. Skipping epoch.")
            continue

        # Average the CWT magnitudes across valid channels
        # Shape: (num_scales, samples_per_epoch)
        avg_coeffs_mag = np.mean(np.array(all_channel_coeffs), axis=0)

        # Normalize the averaged magnitudes (e.g., to 0-1 range for image representation)
        min_val = np.min(avg_coeffs_mag)
        max_val = np.max(avg_coeffs_mag)
        if max_val > min_val:
            normalized_coeffs = (avg_coeffs_mag - min_val) / (max_val - min_val)
        else:
            normalized_coeffs = np.zeros_like(avg_coeffs_mag) # Handle case of flat input

        # Convert to float image format [0, 1]
        image_gray = img_as_float(normalized_coeffs)

        # Resize the grayscale image to the target size
        try:
            # Anti-aliasing is recommended for downsampling
            image_resized_gray = resize(image_gray, target_size, anti_aliasing=True)
        except Exception as e:
            print(f"Error resizing image for epoch {i_epoch}: {e}")
            continue

        # Convert grayscale to 3-channel image (e.g., by repeating the channel)
        image_rgb = np.stack([image_resized_gray]*3, axis=-1) # Shape: (H, W, 3)

        # Convert back to uint8 if necessary for some libraries, but float is often fine for PyTorch
        # image_rgb_uint8 = img_as_ubyte(image_rgb)

        processed_images.append(image_rgb) # Add the processed image for this epoch

    return processed_images

## Model 3

In [7]:

def preprocess_eegnet_minimal(eeg_data, fs, lowcut=1.0, highcut=40.0, order=5):
    """
    Applies minimal preprocessing suitable for models like EEGNet:
    Bandpass filtering and channel-wise standardization.

    Args:
        eeg_data (np.ndarray): Raw EEG data (n_channels, n_timesteps).
        fs (float): Original sampling frequency.
        lowcut (float): Lower cutoff frequency for bandpass filter (Hz).
        highcut (float): Upper cutoff frequency for bandpass filter (Hz).
        order (int): Order of the Butterworth filter.

    Returns:
        np.ndarray: Preprocessed EEG data (n_channels, n_timesteps).
    """
    n_channels, n_timesteps = eeg_data.shape
    processed_data = np.zeros_like(eeg_data)

    # 1. Bandpass Filter Design
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    # Ensure frequency bounds are valid
    if low <= 0 or high >= 1:
         print(f"Warning: Filter frequencies ({lowcut}Hz, {highcut}Hz) are invalid for Nyquist freq {nyq}Hz. Adjusting or skipping filter.")
         # Option: Skip filtering or adjust bounds
         b, a = None, None # Indicate filter skip
    else:
        try:
            b, a = butter(order, [low, high], btype='band')
        except ValueError as e:
            print(f"Warning: Could not design Butterworth filter (order={order}, freqs=[{low}, {high}]). Skipping filter. Error: {e}")
            b, a = None, None


    # 2. Apply Filter and Standardize Channel by Channel
    for i_ch in range(n_channels):
        channel_data = eeg_data[i_ch, :]

        # Apply filtering if filter design was successful
        if b is not None and a is not None:
             try:
                 filtered_data = filtfilt(b, a, channel_data)
             except Exception as e:
                 print(f"Warning: Filtering failed for channel {i_ch}. Using original data for this channel. Error: {e}")
                 filtered_data = channel_data # Use original if filtering fails
        else:
             filtered_data = channel_data # Use original if filter wasn't designed

        # Standardize (z-score normalization)
        mean = np.mean(filtered_data)
        std = np.std(filtered_data)
        if std > 1e-9: # Avoid division by zero
            processed_data[i_ch, :] = (filtered_data - mean) / std
        else:
            processed_data[i_ch, :] = filtered_data - mean # Only center if std is zero

    return processed_data

# Models

## Model 1： Amini et al., 2021

In [12]:
import torch.nn as nn
import torch.nn.functional as F

class Amini_Adapted_CNN(nn.Module):
    def __init__(self, n_channels, n_timesteps, num_classes, dropout_rate=0.5):
        super(Amini_Adapted_CNN, self).__init__()
        self.n_channels = n_channels
        self.n_timesteps = n_timesteps
        self.num_classes = num_classes

        # Convolutional Layer(s)
        self.conv1_out_channels = 16
        self.conv1_kernel_size = 64
        self.conv1_stride = 16
        self.conv1 = nn.Conv1d(in_channels=n_channels,
                               out_channels=self.conv1_out_channels,
                               kernel_size=self.conv1_kernel_size,
                               stride=self.conv1_stride)
        self.bn1 = nn.BatchNorm1d(self.conv1_out_channels)

        # Fully Connected Layers
        conv1_out_timesteps = (n_timesteps - self.conv1_kernel_size) // self.conv1_stride + 1
        self.fc1_input_features = self.conv1_out_channels * conv1_out_timesteps
        self.fc1_hidden_units = 128
        self.fc2_hidden_units = 64

        self.fc1 = nn.Linear(self.fc1_input_features, self.fc1_hidden_units)
        self.fc2 = nn.Linear(self.fc1_hidden_units, self.fc2_hidden_units)
        self.fc3 = nn.Linear(self.fc2_hidden_units, num_classes)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))  # (batch_size, conv1_out_channels, conv1_out_timesteps)
        x = x.view(x.size(0), -1)  # Flatten the output
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x


## Model 2: charya et al., 2025

In [7]:
class ConvNeXt1DBlock(nn.Module):
    """
    1D adaptation of the ConvNeXt block described in Acharya et al. (2025).
    Uses Depthwise and Pointwise convolutions adapted for 1D.
    r: expansion ratio for inverted bottleneck (typically 4)
    """
    def __init__(self, dim, drop_p=0.):
        super().__init__()
        # Depthwise convolution (applied independently to each channel)
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6) # LayerNorm applied on the channel dimension
        # Pointwise convolutions (expand and contract channels)
        self.pwconv1 = nn.Linear(dim, 4 * dim) # Equivalent to 1x1 Conv for channel expansion
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim) # Equivalent to 1x1 Conv for channel contraction
        self.drop_p = drop_p
        if self.drop_p > 0.0 :
             # LayerScale and DropPath would typically be here in original ConvNeXt,
             # Simplified for this adaptation. Using basic dropout.
             self.dropout = nn.Dropout(drop_p)


    def forward(self, x):
        # x shape: (batch_size, channels, timesteps)
        input = x
        x = self.dwconv(x)

        # LayerNorm needs input (batch_size, seq_len, features)
        # Here, seq_len=timesteps, features=channels. So, permute.
        x = x.permute(0, 2, 1) # (batch_size, timesteps, channels)
        x = self.norm(x)

        # Pointwise Convs (Linear layers operate on the last dimension - channels)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        # Permute back
        x = x.permute(0, 2, 1) # (batch_size, channels, timesteps)

        # Add dropout if applicable
        if self.drop_p > 0.0:
             x = self.dropout(x)

        # Residual Connection
        x = input + x
        return x

class EEGConvNeXt_1D(nn.Module):
    """
    1D adaptation of the EEGConvNeXt model from Acharya et al. (2025).
    Processes raw EEG (n_channels, n_timesteps) directly using 1D ConvNeXt blocks.
    This differs significantly from the original paper which uses 2D CWT images.
    """
    def __init__(self, n_channels, n_timesteps, num_classes,
                 depths=[1, 1, 2, 1], dims=[96, 192, 384, 768], # Based on Table 3 (R={1,1,2,1}, F starts at 96)
                 dropout_rate=0.5):
        super().__init__()

        self.n_timesteps = n_timesteps

        # --- Stem ---
        # Patchify: Use Conv1d to embed channels and reduce timesteps
        # Kernel size and stride control the initial downsampling
        stem_kernel_size = 4
        stem_stride = 4
        self.stem = nn.Sequential(
            nn.Conv1d(n_channels, dims[0], kernel_size=stem_kernel_size, stride=stem_stride),
            # LayerNorm operates on the channel dimension, needs permutation
            PermuteLayerNorm(dims[0])
        )
        current_timesteps = (n_timesteps - stem_kernel_size) // stem_stride + 1

        # --- Main Stages ---
        self.stages = nn.ModuleList()
        for i in range(4): # 4 stages
            # Downsampling before stages 2, 3, 4
            if i > 0:
                downsample_layer = nn.Sequential(
                    # LayerNorm before downsampling
                    PermuteLayerNorm(dims[i-1]),
                    # Conv1d for downsampling (stride=2) and increasing channels
                    nn.Conv1d(dims[i-1], dims[i], kernel_size=2, stride=2)
                )
                self.stages.append(downsample_layer)
                current_timesteps = (current_timesteps - 2) // 2 + 1 # Update timesteps after downsampling

            # ConvNeXt1D Blocks for the current stage
            stage_blocks = nn.Sequential(
                *[ConvNeXt1DBlock(dim=dims[i], drop_p=dropout_rate) for _ in range(depths[i])]
            )
            self.stages.append(stage_blocks)


        # --- Output Head ---
        self.norm_out = nn.LayerNorm(dims[-1], eps=1e-6) # Final LayerNorm
        # Global Average Pooling equivalent for 1D: average over the time dimension
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(dims[-1], num_classes)


    def forward(self, x):
        # x shape: (batch_size, n_channels, n_timesteps)
        x = self.stem(x)
        # x shape: (batch_size, dims[0], current_timesteps after stem)

        for stage_module in self.stages:
            x = stage_module(x)
            # Shape changes after downsampling layers within the loop

        # Output head processing
        # LayerNorm needs (batch_size, seq_len, features)
        x = x.permute(0, 2, 1) # (batch_size, timesteps, channels=dims[-1])
        x = self.norm_out(x)
        x = x.permute(0, 2, 1) # (batch_size, channels=dims[-1], timesteps)

        x = self.avgpool(x) # (batch_size, dims[-1], 1)
        x = torch.flatten(x, 1) # (batch_size, dims[-1])
        x = self.head(x)      # (batch_size, num_classes)

        return x

# Helper module for LayerNorm after Conv1d
class PermuteLayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, x):
        # x shape: (batch_size, channels, timesteps)
        x = x.permute(0, 2, 1) # (batch_size, timesteps, channels)
        x = self.norm(x)
        x = x.permute(0, 2, 1) # (batch_size, channels, timesteps)
        return x

## Model 3: Rakhmatulin et al., 2024

In [None]:
class EEGNet(nn.Module):
    """
    Implementation of the EEGNet architecture described in Rakhmatulin et al. (2024)
    and originally proposed by Lawhern et al. (2018).
    Processes raw EEG (n_channels, n_timesteps).
    """
    def __init__(self, n_channels, n_timesteps, num_classes,
                 F1=8, D=2, F2=16, kernel_length=64, dropout_rate=0.5):
        super(EEGNet, self).__init__()
        self.n_channels = n_channels
        self.n_timesteps = n_timesteps
        self.num_classes = num_classes
        self.F1 = F1
        self.D = D
        self.F2 = F2 # Original paper uses F2 = F1 * D
        self.kernel_length = kernel_length
        self.dropout_rate = dropout_rate

        # Block 1: Temporal Convolution + Depthwise Spatial Convolution
        # Temporal Conv: kernel (1, kernel_length), output F1 feature maps
        # Input shape: (batch_size, 1, n_channels, n_timesteps) - Add a dummy channel dim
        self.conv1 = nn.Conv2d(1, self.F1, (1, self.kernel_length), padding=(0, self.kernel_length // 2), bias=False)
        self.bn1 = nn.BatchNorm2d(self.F1)
        # Depthwise Conv: kernel (n_channels, 1), output F1*D feature maps
        # Applied to each F1 map spatially (across channels)
        self.depthwise_conv = nn.Conv2d(self.F1, self.F1 * self.D, (self.n_channels, 1), groups=self.F1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.F1 * self.D)
        # Pooling
        self.pool1 = nn.AvgPool2d((1, 4)) # Downsample time dimension

        # Block 2: Separable Convolution
        # SeparableConv = Depthwise Conv + Pointwise Conv
        # Input shape: (batch_size, F1*D, 1, n_timesteps//4)
        separable_kernel_length = 16 # Example kernel size for separable conv
        self.separable_conv = nn.Sequential(
            # Depthwise part: applies one filter per input channel (F1*D)
             nn.Conv2d(self.F1 * self.D, self.F1 * self.D, (1, separable_kernel_length),
                       padding=(0, separable_kernel_length // 2), groups=self.F1 * self.D, bias=False),
            # Pointwise part: 1x1 conv to mix channels and change depth to F2
             nn.Conv2d(self.F1 * self.D, self.F2, (1, 1), bias=False)
        )
        self.bn3 = nn.BatchNorm2d(self.F2)
        # Pooling
        self.pool2 = nn.AvgPool2d((1, 8)) # Further downsample time dimension

        # Calculate Flatten layer input size
        # After pool1: T' = T // 4
        # After pool2: T'' = T' // 8 = T // 32
        # Input to FC: F2 * 1 * (T // 32)
        self.flatten_size = self.F2 * (n_timesteps // 32)

        # Fully Connected Layer for Classification
        self.fc_out = nn.Linear(self.flatten_size, self.num_classes)

        self.dropout = nn.Dropout(self.dropout_rate)

    def forward(self, x):
        # x shape: (batch_size, n_channels, n_timesteps)

        # Add dummy channel dimension for Conv2D layers
        x = x.unsqueeze(1)
        # x shape: (batch_size, 1, n_channels, n_timesteps)

        # Block 1
        x = self.conv1(x)
        x = self.bn1(x)
        # Depthwise Conv expects (batch_size, F1, n_channels, time)
        x = self.depthwise_conv(x)
        x = F.elu(self.bn2(x)) # Activation after BN2
        x = self.pool1(x)
        x = self.dropout(x)
        # x shape: (batch_size, F1*D, 1, n_timesteps//4)

        # Block 2
        x = self.separable_conv(x)
        x = F.elu(self.bn3(x)) # Activation after BN3
        x = self.pool2(x)
        x = self.dropout(x)
        # x shape: (batch_size, F2, 1, n_timesteps//32)

        # Flatten for FC layer
        x = x.view(x.size(0), -1) # Or x = torch.flatten(x, 1)
        # x shape: (batch_size, flatten_size)

        # Classification layer
        x = self.fc_out(x)
        # x shape: (batch_size, num_classes)
        # Softmax is applied in the loss function

        return x

# Training

In [9]:
# --- Modified Generic Training Function with Validation ---
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device):
    """
    Generic function to train and validate a PyTorch model.

    Args:
        model (nn.Module): The PyTorch model to train.
        train_loader (DataLoader): DataLoader for the training data.
        val_loader (DataLoader or None): DataLoader for the validation data. If None, validation is skipped.
        criterion (nn.Module): The loss function (e.g., nn.CrossEntropyLoss).
        optimizer (Optimizer): The optimizer (e.g., optim.Adam).
        num_epochs (int): Number of epochs to train for.
        device (torch.device): The device to train on (CPU or CUDA).

    Returns:
        None: Prints training and validation progress information directly.
    """
    model.to(device) # Move model to the designated device
    total_train_steps = len(train_loader)
    if val_loader:
        total_val_steps = len(val_loader)
    start_time = time.time()

    print(f"\n--- Training {model.__class__.__name__} ---")

    for epoch in range(num_epochs):
        # --- Training Phase ---
        model.train() # Set the model to training mode
        epoch_train_loss = 0.0
        train_correct_predictions = 0
        train_total_samples = 0

        for i, (inputs, labels) in enumerate(train_loader):
            # Move data to the designated device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate training statistics
            epoch_train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            train_total_samples += labels.size(0)
            train_correct_predictions += (predicted == labels).sum().item()

        # Calculate average training loss and accuracy for the epoch
        avg_epoch_train_loss = epoch_train_loss / total_train_steps
        epoch_train_accuracy = 100 * train_correct_predictions / train_total_samples

        # --- Validation Phase ---
        if val_loader is not None:
            model.eval() # Set the model to evaluation mode
            epoch_val_loss = 0.0
            val_correct_predictions = 0
            val_total_samples = 0

            with torch.no_grad(): # Disable gradient calculations during validation
                for val_inputs, val_labels in val_loader:
                    # Move data to the designated device
                    val_inputs = val_inputs.to(device)
                    val_labels = val_labels.to(device)

                    # Forward pass
                    val_outputs = model(val_inputs)
                    val_loss_batch = criterion(val_outputs, val_labels)

                    # Accumulate validation statistics
                    epoch_val_loss += val_loss_batch.item()
                    _, val_predicted = torch.max(val_outputs.data, 1)
                    val_total_samples += val_labels.size(0)
                    val_correct_predictions += (val_predicted == val_labels).sum().item()

            # Calculate average validation loss and accuracy for the epoch
            avg_epoch_val_loss = epoch_val_loss / total_val_steps
            epoch_val_accuracy = 100 * val_correct_predictions / val_total_samples

            # Print combined epoch results
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Train Loss: {avg_epoch_train_loss:.4f}, Train Acc: {epoch_train_accuracy:.2f}%, '
                  f'Val Loss: {avg_epoch_val_loss:.4f}, Val Acc: {epoch_val_accuracy:.2f}%')
        else:
            # Print only training results if no validation loader is provided
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Train Loss: {avg_epoch_train_loss:.4f}, Train Acc: {epoch_train_accuracy:.2f}%')

        # Note: model is already set back to train() mode at the start of the next epoch loop iteration

    end_time = time.time()
    print(f"Finished Training {model.__class__.__name__}. Total time: {end_time - start_time:.2f} seconds")
    # --- Consider saving the best model based on validation performance ---
    # (Logic for tracking best val_accuracy/lowest val_loss and saving model state_dict would go here)

# Running

## Parameters

## Import the data

## Preprocessing

In [None]:
train_data, train_labels = load_data("model-data/epochs-numpy/train")
test_data, test_labels = load_data("model-data/epochs-numpy/test")

# Check shapes of the loaded data
print(f"Training data shape: {train_data.shape}")
print(f"Test data shape: {test_data.shape}")




Train files: ['sub-058_epochs.npy', 'sub-073_X_shift.npy', 'sub-065_X_scale.npy', 'sub-003_X_shift.npy', 'sub-015_X_scale.npy', 'sub-043_X_noise.npy', 'sub-009_X_scale.npy', 'sub-079_X_scale.npy', 'sub-084_X_scale.npy', 'sub-031_X_shift.npy', 'sub-073_epochs.npy', 'sub-044_epochs.npy', 'sub-026_epochs.npy', 'sub-057_X_scale.npy', 'sub-001_X_noise.npy', 'sub-037_epochs.npy', 'sub-024_X_shift.npy', 'sub-055_epochs.npy', 'sub-054_X_shift.npy', 'sub-062_epochs.npy', 'sub-014_X_noise.npy', 'sub-087_X_shift.npy', 'sub-066_X_shift.npy', 'sub-026_X_noise.npy', 'sub-070_X_scale.npy', 'sub-056_X_noise.npy', 'sub-081_epochs.npy', 'sub-048_X_shift.npy', 'sub-078_X_noise.npy', 'sub-085_X_noise.npy', 'sub-069_X_scale.npy', 'sub-082_X_shift.npy', 'sub-011_X_noise.npy', 'sub-047_X_scale.npy', 'sub-051_X_shift.npy', 'sub-061_X_noise.npy', 'sub-021_X_shift.npy', 'sub-050_epochs.npy', 'sub-080_X_noise.npy', 'sub-084_epochs.npy', 'sub-005_X_scale.npy', 'sub-019_epochs.npy', 'sub-013_X_shift.npy', 'sub-023

In [23]:
# Convert data to PyTorch tensors
train_data_tensor = torch.tensor(train_data, dtype=torch.float32)
test_data_tensor = torch.tensor(test_data, dtype=torch.float32)

# Convert labels to PyTorch tensors
train_labels_tensor = torch.tensor(train_labels, dtype=torch.long)
test_labels_tensor = torch.tensor(test_labels, dtype=torch.long)


# Create DataLoader for training and testing
train_dataset = TensorDataset(train_data_tensor, train_labels_tensor)
test_dataset = TensorDataset(test_data_tensor, test_labels_tensor)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)



ValueError: num_samples should be a positive integer value, but got num_samples=0

## Model 1

In [14]:
# Model initialization
n_channels = 19  # Fixed for your data
n_timesteps = train_data.shape[2]  # Number of time points in each epoch
num_classes = len(np.unique(train_labels))  # Assuming labels are categorical, e.g., F, A, C

model = Amini_Adapted_CNN(n_channels=n_channels, n_timesteps=n_timesteps, num_classes=num_classes, dropout_rate=0.5)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


NameError: name 'train_data' is not defined

## Model 2

In [29]:

# --- Train Model 2: EEGConvNeXt_1D ---
if len(train_loader) > 0:
    try:
        print("\nInitializing Model 2: EEGConvNeXt_1D...")
        model2 = EEGConvNeXt_1D(n_channels=actual_n_channels,
                                n_timesteps=actual_n_timesteps,
                                num_classes=num_classes_actual).to(device)
        criterion2 = nn.CrossEntropyLoss()
        optimizer2 = optim.Adam(model2.parameters(), lr=learning_rate)
        train_model(model2, train_loader, val_loader, criterion2, optimizer2, num_epochs, device)
        # Add evaluation calls here if needed
    except NameError:
        print("Error: EEGConvNeXt_1D or PermuteLayerNorm class not defined.")
    except Exception as e:
        print(f"An error occurred during Model 2 training: {e}")
else:
    print("Skipping Model 2 training: Train loader is empty.")



Initializing Model 2: EEGConvNeXt_1D...

--- Training EEGConvNeXt_1D ---
Epoch [1/100], Train Loss: 1.2604, Train Acc: 40.91%, Val Loss: 1.1359, Val Acc: 43.17%
Epoch [2/100], Train Loss: 1.0956, Train Acc: 41.27%, Val Loss: 1.1550, Val Acc: 34.16%
Epoch [3/100], Train Loss: 1.0957, Train Acc: 39.62%, Val Loss: 1.0769, Val Acc: 43.79%


KeyboardInterrupt: 

## Model 3

In [30]:

# --- Train Model 3: EEGNet ---
if len(train_loader) > 0:
    try:
        print("\nInitializing Model 3: EEGNet...")
        if actual_n_timesteps // 32 <= 0:
                print(f"Warning: n_timesteps ({actual_n_timesteps}) might be too small for EEGNet pooling. Skipping EEGNet.")
        else:
            # EEGNet parameters
            F1=8; D=2; F2=F1*D
            kernel_length = min(64, actual_n_timesteps // 4) # Adjust kernel based on actual timesteps, e.g. fs/4
            dropout_rate=0.25

            model3 = EEGNet(n_channels=actual_n_channels,
                            n_timesteps=actual_n_timesteps,
                            num_classes=num_classes_actual,
                            F1=F1, D=D, F2=F2,
                            kernel_length=kernel_length,
                            dropout_rate=dropout_rate).to(device)
            criterion3 = nn.CrossEntropyLoss()
            optimizer3 = optim.Adam(model3.parameters(), lr=learning_rate)
            train_model(model3, train_loader, val_loader, criterion3, optimizer3, num_epochs, device)
            # Add evaluation calls here if needed
    except NameError:
        print("Error: EEGNet class not defined.")
    except Exception as e:
        print(f"An error occurred during Model 3 training: {e}")
else:
    print("Skipping Model 3 training: Train loader is empty.")


Initializing Model 3: EEGNet...

--- Training EEGNet ---
Epoch [1/100], Train Loss: 1.0656, Train Acc: 44.59%, Val Loss: 0.9740, Val Acc: 58.54%
Epoch [2/100], Train Loss: 0.8774, Train Acc: 60.64%, Val Loss: 0.7874, Val Acc: 64.91%
Epoch [3/100], Train Loss: 0.7665, Train Acc: 65.56%, Val Loss: 0.6953, Val Acc: 68.94%
Epoch [4/100], Train Loss: 0.6856, Train Acc: 69.60%, Val Loss: 0.6231, Val Acc: 71.43%
Epoch [5/100], Train Loss: 0.6241, Train Acc: 73.07%, Val Loss: 0.5845, Val Acc: 74.22%
Epoch [6/100], Train Loss: 0.5807, Train Acc: 75.45%, Val Loss: 0.5478, Val Acc: 74.38%
Epoch [7/100], Train Loss: 0.5234, Train Acc: 78.30%, Val Loss: 0.5306, Val Acc: 78.73%
Epoch [8/100], Train Loss: 0.4840, Train Acc: 80.27%, Val Loss: 0.4608, Val Acc: 79.97%
Epoch [9/100], Train Loss: 0.4642, Train Acc: 81.93%, Val Loss: 0.4112, Val Acc: 83.23%
Epoch [10/100], Train Loss: 0.4013, Train Acc: 83.22%, Val Loss: 0.3705, Val Acc: 86.18%
Epoch [11/100], Train Loss: 0.3884, Train Acc: 84.41%, Val Lo

KeyboardInterrupt: 

# Evaluation

In [31]:
def evaluate_model(model, data_loader, device):
    """
    Evaluates the model's accuracy on the provided data loader.

    Args:
        model (nn.Module): The trained PyTorch model to evaluate.
        data_loader (DataLoader): DataLoader for the dataset to evaluate on (e.g., test_loader or val_loader).
        device (torch.device): The device to run evaluation on (CPU or CUDA).

    Returns:
        float: The accuracy of the model on the dataset (in percentage).
               Returns 0.0 if the data_loader is empty or None.
    """
    if data_loader is None or len(data_loader) == 0:
        print("Warning: Evaluation data loader is empty or None. Returning 0.0 accuracy.")
        return 0.0

    model.to(device) # Ensure model is on the correct device
    model.eval()     # Set the model to evaluation mode (disables dropout, uses running means/vars for BatchNorm)

    correct_predictions = 0
    total_samples = 0

    with torch.no_grad(): # Disable gradient calculations - crucial for evaluation efficiency and correctness
        for inputs, labels in data_loader:
            # Move data to the designated device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(inputs)

            # Get predictions from the maximum value
            _, predicted = torch.max(outputs.data, 1) # Get the index of the max log-probability/logit

            # Update total samples and correct predictions count
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

    # Calculate final accuracy
    accuracy = 100 * correct_predictions / total_samples
    return accuracy


In [34]:
evaluate_model(model3, test_loader, device)

95.3416149068323