# 🧠 EEG Classification Pipeline – Step-by-Step Guide

## 🧠 Step 1: Setup and Dataset Loading
- Set logging level for MNE to avoid cluttered output.
- Set dataset path (`tuab_path`) pointing to TUAB `.edf` folder.
- Load train and eval datasets from the TUAB abnormal dataset using `TUHAbnormal`.

---

## 🏷 Step 2: Assign Labels (Normal = 0, Abnormal = 1)
- Iterate over each dataset in train and eval sets.
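- Check whether the file path contains the `normal` folder (e.g. `train\normal`) to assign `target = 0`; otherwise assign `target = 1`. A bare test for `"normal"` would also match `abnormal`.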
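
A minimal sketch of the path check, assuming the TUAB layout in which each recording sits under a `normal` or `abnormal` folder. The helper `label_from_path` and the example paths are hypothetical and not part of the notebook; the helper compares whole folder names rather than substrings.

```python
from pathlib import PureWindowsPath


def label_from_path(path: str) -> int:
    """Return 0 for recordings under a 'normal' folder, 1 otherwise."""
    # PureWindowsPath splits on both backslashes and forward slashes.
    parts = [p.lower() for p in PureWindowsPath(path).parts]
    return 0 if "normal" in parts else 1


# Hypothetical paths, for illustration only.
print(label_from_path(r"C:\data\edf\train\normal\01_tcp_ar\rec_a.edf"))
print(label_from_path(r"C:\data\edf\train\abnormal\01_tcp_ar\rec_b.edf"))
```

Matching on path components also works when the same code runs on a path written with forward slashes.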

---

## 🎧 Step 3: Select Common EEG Channels
- Define a list of 21 standard EEG channels.
- Filter each dataset to keep only those common channels.
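
The filtering step can be sketched in plain Python. The channel lists below are shortened and hypothetical, and `split_channels` is an illustrative helper rather than an MNE or Braindecode function. A recording that lacks some of the wanted channels ends up with fewer rows, which is why a later step pads the channel axis.

```python
# Shortened, hypothetical channel lists for illustration.
fixed_channels = ["EEG FP1-REF", "EEG FP2-REF", "EEG A1-REF"]
recording_channels = ["EEG FP2-REF", "EEG FP1-REF", "EEG EKG1-REF"]


def split_channels(wanted, available):
    """Return (kept, missing) channel names, both in the order of `wanted`."""
    available_set = set(available)
    kept = [ch for ch in wanted if ch in available_set]
    missing = [ch for ch in wanted if ch not in available_set]
    return kept, missing


kept, missing = split_channels(fixed_channels, recording_channels)
print(kept)     # channels to pick from this recording
print(missing)  # wanted channels that this recording does not have
```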

---

## 💥 Step 4: Handle Artifacts in Normal Train Samples
- Split train dataset into `train_normal` and `train_abnormal`.
- For `train_normal`:
  - Use `annotate_muscle_zscore()` to detect muscle artifacts.
  - Delete the resulting `BAD_MUSCLE` annotations from each recording's annotations.

---

## 🧹 Step 5: Apply Preprocessing
- Define preprocessors:
  - Bandpass filtering (0.5–40 Hz)
  - Rescale signal amplitude (Volts → µV)
- Load EEG into memory for both `train_normal` and `train_abnormal`.
- Apply preprocessing to each dataset.
- Also preprocess `eval_dataset`.

---

## 🧠 Step 6: Combine and Window the Data
- Recombine cleaned `train_normal` and `train_abnormal`.
- Use `create_fixed_length_windows` with:
  - Window size: 1000 samples
  - No overlap
- Apply to both train and eval datasets.
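
As a rough illustration of what "no overlap" means here, the sketch below computes window start indices when the stride equals the window size and a trailing partial window is dropped. `window_starts` is a hypothetical helper, not the Braindecode implementation.

```python
def window_starts(n_samples, size, stride):
    """Start indices of all full windows; a trailing partial window is dropped."""
    if n_samples < size:
        return []
    n_windows = (n_samples - size) // stride + 1
    return [i * stride for i in range(n_windows)]


# 2500 samples hold two full 1000-sample windows; the last 500 samples are dropped.
print(window_starts(n_samples=2500, size=1000, stride=1000))  # [0, 1000]
```

With a stride smaller than the window size, the same formula yields overlapping windows.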

---

## 🏷 Step 7: Extract Labels and Pad Windows
- Extract `target` labels from training windows.
- Pad windows to have the same number of channels using PyTorch.
- Convert data to tensors: `X`, `y`.
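
The notebook pads with `torch.nn.functional.pad`, whose pad tuple lists the last dimension first, so `(0, 0, 0, k)` leaves the time axis alone and appends `k` zero rows to the channel axis. The sketch below shows the same zero-padding with NumPy's `np.pad`; the window shape is made up for illustration.

```python
import numpy as np

max_chans = 21
window = np.ones((19, 1000), dtype=np.float32)  # hypothetical window with 19 channels

# Append zero-filled rows after the existing channels; leave the time axis untouched.
padded = np.pad(window, ((0, max_chans - window.shape[0]), (0, 0)))

print(padded.shape)
```

Padded rows are constant zeros, so any per-channel feature computed later should treat them as flat channels.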

---

## ⚖️ Step 8: Handle Class Imbalance
- Flatten `X` to 2-D and apply random oversampling to balance the class distribution.
- Reshape oversampled data to `[samples, channels, time]`.
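
The flatten, oversample, reshape sequence can be sketched with NumPy alone. `random_oversample` below is a simplified stand-in written for this guide, not imblearn's `RandomOverSampler`; it duplicates randomly chosen minority-class rows until every class matches the largest one. The toy shapes are assumptions.

```python
import numpy as np


def random_oversample(X, y, seed=0):
    """Duplicate random minority-class rows until all classes match the largest."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    keep = [np.arange(len(y))]  # every original row is kept
    for cls, count in zip(classes, counts):
        if count < target:
            pool = np.flatnonzero(y == cls)
            keep.append(rng.choice(pool, size=target - count, replace=True))
    idx = np.concatenate(keep)
    return X[idx], y[idx]


X = np.arange(5 * 2 * 3, dtype=np.float32).reshape(5, 2, 3)  # 5 windows, 2 channels, 3 samples
y = np.array([0, 0, 0, 0, 1])

X_flat = X.reshape(len(X), -1)            # [samples, channels * time]
X_bal, y_bal = random_oversample(X_flat, y)
X_bal = X_bal.reshape(-1, 2, 3)           # back to [samples, channels, time]
print(X_bal.shape, np.bincount(y_bal))
```

Because oversampling adds exact copies, splitting into training and validation sets afterwards can place the same window in both.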

---

## 🧪 Step 9: Prepare Eval Dataset
- Extract and pad windows from eval dataset similarly.
- Save `X_eval.pt` and `y_eval.pt` for later testing.

---

## 📁 Step 10: Save Final Balanced Train Data
- Save the oversampled and padded tensors: `X_resampled_final.pt`, `y_resampled_final.pt`.

---

## 📐 Step 11: Feature Extraction
- Load selected feature names from text file (e.g., `"ch_2_time_mean"`).
- Determine feature types to extract:
  - Time-domain stats
  - Power Spectral Density (PSD) bands
  - Hjorth parameters
  - Wavelet energies
  - Catch22 features
- For each window in `X_resampled`:
  - Loop through each channel and compute only selected features.
- Stack features into final tensor `X_feat`.
- Save as `X_feat.pt`.
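
A minimal sketch of name-driven extraction, restricted to time-domain statistics and using only the standard library. The selected names and the `TIME_STATS` table are hypothetical. The point it illustrates is that the output order comes from the fixed channel-then-statistic loop, not from the order of names in the file, so training and inference need the same loop to produce aligned columns.

```python
import statistics

# Hypothetical selection; the notebook reads these names from a text file.
selected = {"ch_0_time_mean", "ch_1_time_max", "ch_1_time_mean"}

# Statistic name -> function, in the fixed order used to build the feature vector.
TIME_STATS = {
    "mean": statistics.fmean,
    "std": statistics.pstdev,  # population standard deviation, like NumPy's default
    "max": max,
    "min": min,
}


def extract_time_features(window, selected_names):
    """Compute only the selected time-domain statistics, channel by channel."""
    features = []
    for ch_idx, channel in enumerate(window):
        for stat, fn in TIME_STATS.items():
            if f"ch_{ch_idx}_time_{stat}" in selected_names:
                features.append(fn(channel))
    return features


window = [[1.0, 2.0, 3.0], [4.0, 6.0, 8.0]]  # 2 channels, 3 samples
print(extract_time_features(window, selected))  # [2.0, 6.0, 8.0]
```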

---

## 🧼 Step 12: Preprocess Features
- Load `X_feat` and `y_resampled`.
- Apply `RobustScaler` (median centering, interquartile-range scaling) to normalize features.
- Save the scaler as `scaler.pkl` for future inference.
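
For intuition, the sketch below reproduces what `RobustScaler` does with its default settings, as far as I understand them: each column is centered on its median and divided by its 25th to 75th percentile range. `robust_scale` is an illustrative helper, not the scikit-learn code, and the data are made up.

```python
import numpy as np


def robust_scale(X):
    """Center each column on its median and divide by its interquartile range."""
    median = np.median(X, axis=0)
    q25, q75 = np.percentile(X, [25, 75], axis=0)
    iqr = q75 - q25
    iqr[iqr == 0] = 1.0  # avoid dividing by zero for constant columns
    return (X - median) / iqr


X = np.array([[1.0], [2.0], [3.0], [4.0], [100.0]])  # one feature with an outlier
print(robust_scale(X).ravel())
```

The outlier does not affect the median or the interquartile range, so the remaining values keep a sensible spread after scaling.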

---

## 🔀 Step 13: Split for Training and Validation
- Use `train_test_split` with 80/20 stratified split.
- Create PyTorch `DataLoaders` for training and validation sets.

---

## ⚖️ Step 14: Compute Class Weights
- Compute balanced class weights for `CrossEntropyLoss`.
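
The `balanced` heuristic assigns each class the weight `n_samples / (n_classes * count)`. The plain-Python sketch below shows the arithmetic; `balanced_class_weights` is a hypothetical helper and the label lists are toy data.

```python
from collections import Counter


def balanced_class_weights(labels):
    """weight[c] = n_samples / (n_classes * count[c])"""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {cls: n_samples / (n_classes * n) for cls, n in sorted(counts.items())}


print(balanced_class_weights([0, 0, 0, 1]))  # the rarer class gets the larger weight
print(balanced_class_weights([0, 1, 0, 1]))  # equal counts give equal weights of 1.0
```

Since the training data were already balanced by oversampling, the weights passed to the loss should come out close to 1.0 for both classes.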

---

## 🧠 Step 15: Define the MLP Model
- Define `MLPOnly` class:
  - Layers: Input → 256 → 128 → 64 → 2 (output)
  - Each hidden layer: BatchNorm → ReLU → Dropout
- Initialize:
  - Model
  - Loss: `CrossEntropyLoss` with computed weights
  - Optimizer: `Adam`
  - Scheduler: `ReduceLROnPlateau`

---

## 🏋️ Step 16: Train the MLP
- Loop through epochs:
  - Train on batches with `.backward()` and `optimizer.step()`
  - Validate after each epoch:
    - Compute accuracy, precision, recall, F1
    - Use scheduler to adjust learning rate
    - Save model if it achieves best F1
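
F1, the metric used to pick the best model, is the harmonic mean of precision and recall. The short sketch below checks that relationship against the rounded precision and recall printed for epoch 001 in the training log; `f1_from` is a hypothetical helper, not a scikit-learn function.

```python
def f1_from(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0


# Precision 0.798 and recall 0.746, as printed for epoch 001.
print(round(f1_from(0.798, 0.746), 3))  # 0.771
```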

---

## 💾 Step 17: Save the Best Model
- Save `state_dict` of the best model as `mlp_best_model.pth`.

---

## ✅ Final Outputs
- `X_eval.pt`, `y_eval.pt` – evaluation data
- `X_resampled_final.pt`, `y_resampled_final.pt` – balanced training data
- `X_feat.pt` – handcrafted features
- `scaler.pkl` – saved feature scaler
- `mlp_best_model.pth` – trained MLP model


## Preprocessing for Both MLP and Deep4Net

In [None]:
import os
import mne
import torch
import numpy as np
from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.utils.class_weight import compute_class_weight
from imblearn.over_sampling import RandomOverSampler
from braindecode import EEGClassifier
from braindecode.datasets.tuh import TUHAbnormal
from braindecode.preprocessing import Preprocessor, preprocess, create_fixed_length_windows
from braindecode.datasets.base import BaseConcatDataset
from mne.preprocessing import annotate_muscle_zscore
from tqdm import tqdm

# Suppress verbose MNE logs
os.environ["MNE_LOGGING_LEVEL"] = "CRITICAL"
mne.set_log_level('CRITICAL')

# Base path to TUAB dataset
tuab_path = r"C:\Users\obass\Desktop\linux_shared\v3.0.1\edf"

# Load training and evaluation datasets
train_dataset = TUHAbnormal(path=os.path.join(tuab_path, "train"))
eval_dataset = TUHAbnormal(path=os.path.join(tuab_path, "eval"))

# Assign binary target labels based on file path
for ds in train_dataset.datasets:
    ds.description["target"] = 0 if "train\\normal" in str(ds.raw.filenames[0]).lower() else 1
for ds in eval_dataset.datasets:
    ds.description["target"] = 0 if "eval\\normal" in str(ds.raw.filenames[0]).lower() else 1

# Define a fixed set of EEG channels to keep
fixed_channels = [
    'EEG FP1-REF', 'EEG FP2-REF', 'EEG F3-REF', 'EEG F4-REF',
    'EEG C3-REF', 'EEG C4-REF', 'EEG P3-REF', 'EEG P4-REF',
    'EEG O1-REF', 'EEG O2-REF', 'EEG F7-REF', 'EEG F8-REF',
    'EEG T3-REF', 'EEG T4-REF', 'EEG T5-REF', 'EEG T6-REF',
    'EEG A1-REF', 'EEG A2-REF', 'EEG FZ-REF', 'EEG CZ-REF', 'EEG PZ-REF'
]
for ds in train_dataset.datasets + eval_dataset.datasets:
    ds.raw.pick_channels([ch for ch in fixed_channels if ch in ds.raw.ch_names])

# Split train dataset into normal and abnormal groups
train_normal = [ds for ds in train_dataset.datasets if ds.description['target'] == 0]
train_abnormal = [ds for ds in train_dataset.datasets if ds.description['target'] == 1]

# Remove muscle artifacts from normal training data only
for ds in tqdm(train_normal, desc="Removing muscle artifacts"):
    ds.raw.load_data()
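    # NOTE: annotate_muscle_zscore returns (annotations, scores) and leaves ds.raw
    # unchanged. The result is discarded here, so this call adds no BAD_MUSCLE
    # annotations to ds.raw for the deletion below to act on.
    annotate_muscle_zscore(ds.raw, threshold=4.0, filter_freq=(30, 90))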
    ds.raw.set_annotations(ds.raw.annotations.delete(
        np.where(ds.raw.annotations.description == 'BAD_MUSCLE')[0]
    ))

# Define common preprocessing steps
preprocessors = [
    Preprocessor('filter', l_freq=0.5, h_freq=40.),
    Preprocessor(lambda x: x * 1e6, picks='eeg')  # Convert to µV
]

# Apply preprocessing to normal training data
for ds in tqdm(train_normal, desc="Preprocessing train_normal"):
    ds.raw.load_data()
    for p in preprocessors:
        p.apply(ds.raw)

# Apply preprocessing to abnormal training data
for ds in tqdm(train_abnormal, desc="Preprocessing train_abnormal"):
    ds.raw.load_data()
    for p in preprocessors:
        p.apply(ds.raw)

# Apply preprocessing to evaluation dataset
for ds in tqdm(eval_dataset.datasets, desc="Preprocessing eval"):
    ds.raw.load_data()
    for p in preprocessors:
        p.apply(ds.raw)

# Combine processed normal and abnormal training sets
train_dataset.datasets = train_normal + train_abnormal

# Create fixed-length, non-overlapping windows from training and evaluation datasets
train_windows = create_fixed_length_windows(
    train_dataset,
    start_offset_samples=0,
    window_size_samples=1000,
    window_stride_samples=1000,
    drop_last_window=True,
    preload=True
)
eval_windows = create_fixed_length_windows(
    eval_dataset,
    start_offset_samples=0,
    window_size_samples=1000,
    window_stride_samples=1000,
    drop_last_window=True,
    preload=True
)

# Extract class labels from training windows
y_train = [ds.description.get("target", -1) for ds in train_windows.datasets]
print("Label distribution before balancing:", Counter(y_train))

# Pad all windowed EEG data to have consistent channel size
max_chans = max(ds[0].shape[0] for ds in train_windows)
X = torch.stack([
    torch.nn.functional.pad(torch.tensor(ds[0]), (0, 0, 0, max_chans - ds[0].shape[0]))
    for ds in train_windows
])
y = torch.tensor([ds[1] for ds in train_windows])

# Use random oversampling to balance class distribution
X_np = X.numpy().reshape(X.shape[0], -1)
ros = RandomOverSampler()
X_resampled_np, y_resampled_np = ros.fit_resample(X_np, y.numpy())
X_resampled = torch.tensor(X_resampled_np).reshape(-1, max_chans, X.shape[2])
y_resampled = torch.tensor(y_resampled_np)

# Extract and pad evaluation data
X_eval = []
y_eval = []
for windows_ds in tqdm(eval_windows.datasets, desc="Extracting eval windows"):
    for i in range(len(windows_ds)):
        window_data, label, _ = windows_ds[i]
        padded_data = torch.nn.functional.pad(
            torch.tensor(window_data, dtype=torch.float32),
            (0, 0, 0, max_chans - window_data.shape[0])
        )
        X_eval.append(padded_data)
        y_eval.append(label)

# Convert to tensors
X_eval = torch.stack(X_eval)
y_eval = torch.tensor(y_eval)

# Save processed datasets
torch.save(X_eval, "D:/Models_Data/X_eval.pt")
torch.save(y_eval, "D:/Models_Data/y_eval.pt")
torch.save(X_resampled, "D:/Models_Data/X_resampled_final(artifact from normal,no tmax).pt") 
torch.save(y_resampled, "D:/Models_Data/y_resampled_final.pt")


💥 Removing muscle artifacts: 100%|████████████████████████████████████████████████| 1371/1371 [33:15<00:00,  1.46s/it]
  warn('Preprocessing choices with lambda functions cannot be saved.')
⚙️ Preprocessing train_normal: 100%|███████████████████████████████████████████████| 1371/1371 [12:29<00:00,  1.83it/s]
⚙️ Preprocessing train_abnormal: 100%|█████████████████████████████████████████████| 1346/1346 [09:09<00:00,  2.45it/s]
⚙️ Preprocessing eval safely: 100%|██████████████████████████████████████████████████| 276/276 [01:45<00:00,  2.62it/s]


✔️ Label distribution before balancing: Counter({0: 1371, 1: 1346})


📥 Extracting eval windows: 100%|████████████████████████████████████████████████████| 276/276 [03:11<00:00,  1.44it/s]


### MLP Feature Extraction

In [None]:
import numpy as np
from scipy.stats import skew, kurtosis
from scipy.signal import welch
import pywt
from pycatch22 import catch22_all
import torch
from tqdm import tqdm

# Load the names of the features to be computed
with open("selected_feature_names.txt", "r") as f:
    selected_feature_names = set(line.strip() for line in f.readlines())

# Determine which feature types are required
need_time = any("_time_" in f for f in selected_feature_names)
need_psd = any("_psd_" in f for f in selected_feature_names)
need_hjorth = any("_hjorth_" in f for f in selected_feature_names)
need_wavelet = any("_wavelet_" in f for f in selected_feature_names)
need_catch22 = any("_catch22_" in f for f in selected_feature_names)


def extract_features(window):
    features = []
    for ch_idx, ch in enumerate(window):
        ch = np.nan_to_num(ch, nan=0.0, posinf=0.0, neginf=0.0)
        if np.all(ch == ch[0]) or np.std(ch) < 1e-6:
            features.extend([0.0 for name in selected_feature_names if name.startswith(f"ch_{ch_idx}_")])
            continue

        # Time-domain features
        if need_time:
            if f"ch_{ch_idx}_time_mean" in selected_feature_names:
                features.append(ch.mean())
            if f"ch_{ch_idx}_time_std" in selected_feature_names:
                features.append(ch.std())
            if f"ch_{ch_idx}_time_max" in selected_feature_names:
                features.append(ch.max())
            if f"ch_{ch_idx}_time_min" in selected_feature_names:
                features.append(ch.min())
            if f"ch_{ch_idx}_time_skew" in selected_feature_names:
                features.append(skew(ch))
            if f"ch_{ch_idx}_time_kurtosis" in selected_feature_names:
                features.append(kurtosis(ch))

        # Power Spectral Density features
        if need_psd:
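            # fs is hard-coded; the band edges below are in true Hz only if the windows
            # are sampled at 100 Hz (the preprocessing cell has no resampling step)
            freqs, psd = welch(ch, fs=100, nperseg=256)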
            if f"ch_{ch_idx}_psd_delta" in selected_feature_names:
                features.append(psd[(freqs >= 0.5) & (freqs < 4)].mean())
            if f"ch_{ch_idx}_psd_theta" in selected_feature_names:
                features.append(psd[(freqs >= 4) & (freqs < 8)].mean())
            if f"ch_{ch_idx}_psd_alpha" in selected_feature_names:
                features.append(psd[(freqs >= 8) & (freqs < 13)].mean())
            if f"ch_{ch_idx}_psd_beta" in selected_feature_names:
                features.append(psd[(freqs >= 13) & (freqs < 30)].mean())
            if f"ch_{ch_idx}_psd_gamma" in selected_feature_names:
                features.append(psd[(freqs >= 30) & (freqs < 40)].mean())

        # Hjorth parameters
        if need_hjorth:
            d1, d2 = np.diff(ch), np.diff(np.diff(ch))
            if f"ch_{ch_idx}_hjorth_var" in selected_feature_names:
                features.append(np.var(ch))
            if f"ch_{ch_idx}_hjorth_mob" in selected_feature_names:
                features.append(np.std(d1) / (np.std(ch) + 1e-8))
            if f"ch_{ch_idx}_hjorth_comp" in selected_feature_names:
                features.append(np.std(d2) / (np.std(d1) + 1e-8))

        # Wavelet-based features
        if need_wavelet:
            coeffs = pywt.wavedec(ch, 'db4', level=3)
            for i, c in enumerate(coeffs):
                key = f"ch_{ch_idx}_wavelet_cD{i}"
                if key in selected_feature_names:
                    features.append(np.sqrt(np.sum(c ** 2)))

        # Catch22 features
        if need_catch22:
            c22 = catch22_all(ch)["values"]
            for i, val in enumerate(c22):
                key = f"ch_{ch_idx}_catch22_{i}"
                if key in selected_feature_names:
                    features.append(np.nan_to_num(val, nan=0.0, posinf=0.0, neginf=0.0))

    return np.array(features, dtype=np.float32)


# Load preprocessed data from disk
X_resampled = torch.load("D:/Models_Data/X_resampled_final(artifact from normal,no tmax).pt")
y_resampled = torch.load("D:/Models_Data/y_resampled_final.pt")

X_feat = []

# Extract features from each signal window
for signal in tqdm(X_resampled, desc="Extracting features"):
    signal_np = signal.numpy()
    feats = extract_features(signal_np)
    X_feat.append(torch.tensor(feats, dtype=torch.float32))

X_feat = torch.stack(X_feat)

# Save the extracted features
torch.save(X_feat, "D:/Models_Data/X_feat.pt")

print("Saved: X_feat.pt, y_resampled.pt, X_resampled.pt, max_chans.pt")


🔍 Extracting features: 100%|███████████████████████████████████████████████| 952714/952714 [48:33:39<00:00,  5.45it/s]


✅ Saved: X_feat.pt, y_resampled.pt, X_resampled.pt, max_chans.pt


## MLP Model Training

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.preprocessing import RobustScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import joblib

# Load feature and label tensors
X_feat = torch.load("D:/Models_Data/X_feat.pt")
y_resampled = torch.load("D:/Models_Data/y_resampled_final.pt")

# Align lengths if needed
X_feat = X_feat[:len(y_resampled)]
y_resampled = y_resampled.long()

# Normalize features
scaler = RobustScaler()
X_feat_scaled_np = scaler.fit_transform(X_feat.numpy())
X_feat_scaled = torch.tensor(X_feat_scaled_np, dtype=torch.float32)

# Save the scaler for future use
joblib.dump(scaler, "scaler.pkl")
print("Scaler saved to scaler.pkl")

X_selected = X_feat_scaled
print(f"Using {X_selected.shape[1]} features extracted directly.")

# Split into training and validation sets
X_train_feat, X_val_feat, y_train, y_val = train_test_split(
    X_selected, y_resampled, test_size=0.2, stratify=y_resampled, random_state=42
)

# Create dataloaders
train_ds = TensorDataset(X_train_feat, y_train)
val_ds = TensorDataset(X_val_feat, y_val)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=64)

# Compute class weights for imbalanced data
weights = compute_class_weight('balanced', classes=torch.unique(y_train).numpy(), y=y_train.numpy())
weights_tensor = torch.tensor(weights, dtype=torch.float32)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
best_f1 = 0.0
best_model_state = None

# Define the MLP model
class MLPOnly(nn.Module):
    def __init__(self, feat_dim=200, num_classes=2):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        return self.model(x)

# Initialize model, loss, optimizer, and scheduler
model = MLPOnly(feat_dim=X_selected.shape[1]).to(device)
criterion = nn.CrossEntropyLoss(weight=weights_tensor.to(device))
optimizer = Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5, verbose=True)

# Training loop
for epoch in range(1, 100):
    model.train()
    for feat_batch, labels in train_loader:
        feat_batch, labels = feat_batch.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(feat_batch)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Evaluation
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for feat_batch, labels in val_loader:
            outputs = model(feat_batch.to(device))
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, zero_division=0)
    rec = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    scheduler.step(f1)

    print(f"Epoch {epoch:03d} | Acc: {acc:.3f} | Prec: {prec:.3f} | Rec: {rec:.3f} | F1: {f1:.3f}")
    if f1 > best_f1:
        best_f1 = f1
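        # Clone the tensors: state_dict() returns references that later epochs overwrite
        best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}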
        print("New best model saved.")

# Save the best-performing model
torch.save(best_model_state, "D:/Models_Data/mlp_best_model.pth")
print("Best model saved to mlp_best_model.pth")


✅ Scaler saved to scaler.pkl
✅ Using 168 features extracted directly.




📊 Epoch 001 | Acc: 0.779 | Prec: 0.798 | Rec: 0.746 | F1: 0.771
✅ New best model saved.
📊 Epoch 002 | Acc: 0.788 | Prec: 0.825 | Rec: 0.731 | F1: 0.775
✅ New best model saved.
📊 Epoch 003 | Acc: 0.792 | Prec: 0.849 | Rec: 0.711 | F1: 0.774
📊 Epoch 004 | Acc: 0.798 | Prec: 0.845 | Rec: 0.729 | F1: 0.783
✅ New best model saved.
📊 Epoch 005 | Acc: 0.802 | Prec: 0.838 | Rec: 0.748 | F1: 0.790
✅ New best model saved.
📊 Epoch 006 | Acc: 0.806 | Prec: 0.828 | Rec: 0.772 | F1: 0.799
✅ New best model saved.
📊 Epoch 007 | Acc: 0.806 | Prec: 0.856 | Rec: 0.735 | F1: 0.791
📊 Epoch 008 | Acc: 0.809 | Prec: 0.855 | Rec: 0.744 | F1: 0.796
📊 Epoch 009 | Acc: 0.811 | Prec: 0.857 | Rec: 0.745 | F1: 0.797
📊 Epoch 010 | Acc: 0.812 | Prec: 0.853 | Rec: 0.753 | F1: 0.800
✅ New best model saved.
📊 Epoch 011 | Acc: 0.811 | Prec: 0.835 | Rec: 0.776 | F1: 0.804
✅ New best model saved.
📊 Epoch 012 | Acc: 0.814 | Prec: 0.859 | Rec: 0.752 | F1: 0.802
📊 Epoch 013 | Acc: 0.815 | Prec: 0.869 | Rec: 0.742 | F1: 0.801


## MLP Model Test (Evaluation on New, Unseen Data)

In [None]:
import os
import mne
import torch
import torch.nn as nn
import joblib
from tqdm import tqdm
import numpy as np
from braindecode.datasets.tuh import TUHAbnormal
from braindecode.preprocessing import Preprocessor, preprocess, create_fixed_length_windows
from braindecode.datasets.base import BaseConcatDataset
from scipy.stats import skew, kurtosis
from scipy.signal import welch
import pywt
from pycatch22 import catch22_all
from sklearn.metrics import classification_report

# Load preprocessed evaluation data
X_eval = torch.load("D:/Models_Data/X_eval.pt")
y_eval = torch.load("D:/Models_Data/y_eval.pt")

print(f"Loaded X_eval with shape {X_eval.shape}")
print(f"Loaded y_eval with shape {y_eval.shape}")

# Load selected feature names
with open("selected_feature_names.txt", "r") as f:
    selected_feature_names = set(line.strip() for line in f.readlines())

need_time = any("_time_" in f for f in selected_feature_names)
need_psd = any("_psd_" in f for f in selected_feature_names)
need_hjorth = any("_hjorth_" in f for f in selected_feature_names)
need_wavelet = any("_wavelet_" in f for f in selected_feature_names)
need_catch22 = any("_catch22_" in f for f in selected_feature_names)

# Feature extraction function
def extract_features(window):
    features = []
    for ch_idx, ch in enumerate(window):
        ch = np.nan_to_num(ch, nan=0.0, posinf=0.0, neginf=0.0)
        if np.all(ch == ch[0]) or np.std(ch) < 1e-6:
            features.extend([0.0 for name in selected_feature_names if name.startswith(f"ch_{ch_idx}_")])
            continue

        if need_time:
            if f"ch_{ch_idx}_time_mean" in selected_feature_names:
                features.append(ch.mean())
            if f"ch_{ch_idx}_time_std" in selected_feature_names:
                features.append(ch.std())
            if f"ch_{ch_idx}_time_max" in selected_feature_names:
                features.append(ch.max())
            if f"ch_{ch_idx}_time_min" in selected_feature_names:
                features.append(ch.min())
            if f"ch_{ch_idx}_time_skew" in selected_feature_names:
                features.append(skew(ch))
            if f"ch_{ch_idx}_time_kurtosis" in selected_feature_names:
                features.append(kurtosis(ch))

        if need_psd:
            freqs, psd = welch(ch, fs=100, nperseg=256)
            if f"ch_{ch_idx}_psd_delta" in selected_feature_names:
                features.append(psd[(freqs >= 0.5) & (freqs < 4)].mean())
            if f"ch_{ch_idx}_psd_theta" in selected_feature_names:
                features.append(psd[(freqs >= 4) & (freqs < 8)].mean())
            if f"ch_{ch_idx}_psd_alpha" in selected_feature_names:
                features.append(psd[(freqs >= 8) & (freqs < 13)].mean())
            if f"ch_{ch_idx}_psd_beta" in selected_feature_names:
                features.append(psd[(freqs >= 13) & (freqs < 30)].mean())
            if f"ch_{ch_idx}_psd_gamma" in selected_feature_names:
                features.append(psd[(freqs >= 30) & (freqs < 40)].mean())

        if need_hjorth:
            d1, d2 = np.diff(ch), np.diff(np.diff(ch))
            if f"ch_{ch_idx}_hjorth_var" in selected_feature_names:
                features.append(np.var(ch))
            if f"ch_{ch_idx}_hjorth_mob" in selected_feature_names:
                features.append(np.std(d1) / (np.std(ch) + 1e-8))
            if f"ch_{ch_idx}_hjorth_comp" in selected_feature_names:
                features.append(np.std(d2) / (np.std(d1) + 1e-8))

        if need_wavelet:
            coeffs = pywt.wavedec(ch, 'db4', level=3)
            for i, c in enumerate(coeffs):
                key = f"ch_{ch_idx}_wavelet_cD{i}"
                if key in selected_feature_names:
                    features.append(np.sqrt(np.sum(c ** 2)))

        if need_catch22:
            c22 = catch22_all(ch)["values"]
            for i, val in enumerate(c22):
                key = f"ch_{ch_idx}_catch22_{i}"
                if key in selected_feature_names:
                    features.append(np.nan_to_num(val, nan=0.0, posinf=0.0, neginf=0.0))

    return np.array(features, dtype=np.float32)

# Uncomment to extract features (if not yet saved)
# X_eval_feat = [torch.tensor(extract_features(sig.numpy())) for sig in tqdm(X_eval, desc="Extracting eval features")]
# X_eval_feat = torch.stack(X_eval_feat)
# torch.save(X_eval_feat, "X_eval_feat.pt")
# torch.save(y_eval, "y_eval.pt")

# Load previously saved feature data
X_eval_feat = torch.load("X_eval_feat.pt")
y_eval = torch.load("y_eval.pt")

# Scale features using the saved scaler
scaler = joblib.load("scaler.pkl")
X_eval_scaled_np = scaler.transform(X_eval_feat.numpy())
X_eval_scaled = torch.tensor(X_eval_scaled_np, dtype=torch.float32)

# Define the MLP model
class MLPOnly(nn.Module):
    def __init__(self, feat_dim=200, num_classes=2):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(feat_dim, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        return self.model(x)

# Load and evaluate the model
model = MLPOnly(feat_dim=X_eval_scaled.shape[1])
model.load_state_dict(torch.load("D:/Models_Data/mlp_best_model.pth"))
model.eval()

with torch.no_grad():
    preds = model(X_eval_scaled).argmax(dim=1)

# Print evaluation report
print("\nEvaluation Report on Evaluation Dataset")
print(classification_report(y_eval, preds))




✅ Loaded X_eval with shape torch.Size([92599, 21, 1000])
✅ Loaded y_eval with shape torch.Size([92599])

📊 Evaluation Report on Real Eval Dataset
              precision    recall  f1-score   support

       False       0.75      0.86      0.80     49872
        True       0.80      0.67      0.73     42727

    accuracy                           0.77     92599
   macro avg       0.78      0.76      0.76     92599
weighted avg       0.77      0.77      0.77     92599



### Pipeline

In [None]:
import os
import mne
import torch
import torch.nn as nn
import joblib
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.stats import skew, kurtosis
from scipy.signal import welch
import pywt
from pycatch22 import catch22_all
from sklearn.metrics import classification_report
from torch.nn.functional import softmax

# Define the MLP model architecture
class MLPOnly(nn.Module):
    def __init__(self, feat_dim=168, num_classes=2):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(feat_dim, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(64, num_classes)
        )
    def forward(self, x):
        return self.model(x)

# Load pre-trained model and feature scaler
model = MLPOnly()
model.load_state_dict(torch.load(r"D:/Models_Data/mlp_best_model.pth"))
model.eval()
scaler = joblib.load("scaler.pkl")

# Load EEG data from EDF file
edf_path = r"C:/Users/obass/Desktop/linux_shared/v3.0.1/edf/eval/abnormal/01_tcp_ar/aaaaabdo_s003_t000.edf"
raw = mne.io.read_raw_edf(edf_path, preload=True)
raw.pick_types(eeg=True)
raw.filter(0.5, 40.)
raw._data *= 1e6  # Convert signal to microvolts

# Select a fixed set of EEG channels
fixed_channels = [
    'EEG FP1-REF', 'EEG FP2-REF', 'EEG F3-REF', 'EEG F4-REF',
    'EEG C3-REF', 'EEG C4-REF', 'EEG P3-REF', 'EEG P4-REF',
    'EEG O1-REF', 'EEG O2-REF', 'EEG F7-REF', 'EEG F8-REF',
    'EEG T3-REF', 'EEG T4-REF', 'EEG T5-REF', 'EEG T6-REF',
    'EEG A1-REF', 'EEG A2-REF', 'EEG FZ-REF', 'EEG CZ-REF', 'EEG PZ-REF'
]
raw.pick_channels([ch for ch in fixed_channels if ch in raw.ch_names])

# Warn if any expected channels are missing
if len(raw.ch_names) < len(fixed_channels):
    print("Warning: Some expected EEG channels are missing from the EDF file.")

# Create fixed-length non-overlapping windows from the EEG data
X = raw.get_data()
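# Kept at 100 to match the fs used when the training features were extracted.
# The MNE log for this file reports 304250 samples over ~1217 s (about 250 Hz),
# so the Welch band edges below are nominal rather than true Hz.
sampling_rate = 100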
window_size = 1000
stride = 1000
num_windows = (X.shape[1] - window_size) // stride

X_windows = []
for i in range(num_windows):
    start = i * stride
    end = start + window_size
    win = X[:, start:end]
    if win.shape[1] == window_size:
        X_windows.append(torch.tensor(win, dtype=torch.float32))
X_eval = torch.stack(X_windows)

# Load names of selected features to extract
with open("selected_feature_names.txt", "r") as f:
    selected_feature_names = set(line.strip() for line in f.readlines())

# Determine which types of features are needed
need_time = any("_time_" in f for f in selected_feature_names)
need_psd = any("_psd_" in f for f in selected_feature_names)
need_hjorth = any("_hjorth_" in f for f in selected_feature_names)
need_wavelet = any("_wavelet_" in f for f in selected_feature_names)
need_catch22 = any("_catch22_" in f for f in selected_feature_names)

feature_freq_bands = []

# Feature extraction function
def extract_features(window):
    features = []
    for ch_idx, ch in enumerate(window):
        ch = np.nan_to_num(ch, nan=0.0, posinf=0.0, neginf=0.0)
        if np.all(ch == ch[0]) or np.std(ch) < 1e-6:
            features.extend([0.0 for name in selected_feature_names if name.startswith(f"ch_{ch_idx}_")])
            continue
        if need_time:
            features += [
                ch.mean() if f"ch_{ch_idx}_time_mean" in selected_feature_names else None,
                ch.std() if f"ch_{ch_idx}_time_std" in selected_feature_names else None,
                ch.max() if f"ch_{ch_idx}_time_max" in selected_feature_names else None,
                ch.min() if f"ch_{ch_idx}_time_min" in selected_feature_names else None,
                skew(ch) if f"ch_{ch_idx}_time_skew" in selected_feature_names else None,
                kurtosis(ch) if f"ch_{ch_idx}_time_kurtosis" in selected_feature_names else None
            ]
        if need_psd:
            freqs, psd = welch(ch, fs=sampling_rate, nperseg=256)
            band_power = {
                "delta": psd[(freqs >= 0.5) & (freqs < 4)].mean(),
                "theta": psd[(freqs >= 4) & (freqs < 8)].mean(),
                "alpha": psd[(freqs >= 8) & (freqs < 13)].mean(),
                "beta": psd[(freqs >= 13) & (freqs < 30)].mean(),
                "gamma": psd[(freqs >= 30) & (freqs < 40)].mean()
            }
            feature_freq_bands.append(band_power)
            features += [
                band_power['delta'] if f"ch_{ch_idx}_psd_delta" in selected_feature_names else None,
                band_power['theta'] if f"ch_{ch_idx}_psd_theta" in selected_feature_names else None,
                band_power['alpha'] if f"ch_{ch_idx}_psd_alpha" in selected_feature_names else None,
                band_power['beta'] if f"ch_{ch_idx}_psd_beta" in selected_feature_names else None,
                band_power['gamma'] if f"ch_{ch_idx}_psd_gamma" in selected_feature_names else None
            ]
        if need_hjorth:
            d1, d2 = np.diff(ch), np.diff(np.diff(ch))
            features += [
                np.var(ch) if f"ch_{ch_idx}_hjorth_var" in selected_feature_names else None,
                np.std(d1) / (np.std(ch) + 1e-8) if f"ch_{ch_idx}_hjorth_mob" in selected_feature_names else None,
                np.std(d2) / (np.std(d1) + 1e-8) if f"ch_{ch_idx}_hjorth_comp" in selected_feature_names else None
            ]
        if need_wavelet:
            coeffs = pywt.wavedec(ch, 'db4', level=3)
            for i, c in enumerate(coeffs):
                key = f"ch_{ch_idx}_wavelet_cD{i}"
                if key in selected_feature_names:
                    features.append(np.sqrt(np.sum(c ** 2)))
        if need_catch22:
            c22 = catch22_all(ch)["values"]
            for i, val in enumerate(c22):
                key = f"ch_{ch_idx}_catch22_{i}"
                if key in selected_feature_names:
                    features.append(np.nan_to_num(val, nan=0.0, posinf=0.0, neginf=0.0))
    return np.array([f for f in features if f is not None], dtype=np.float32)

# Extract features for each window
X_feat = [extract_features(w.numpy()) for w in tqdm(X_eval, desc="Extracting features")]
X_feat = torch.tensor(np.stack(X_feat), dtype=torch.float32)

# Normalize features using the pre-fitted scaler
X_scaled = torch.tensor(scaler.transform(X_feat.numpy()), dtype=torch.float32)

# Run model inference
with torch.no_grad():
    logits = model(X_scaled)
    probs = softmax(logits, dim=1)
    preds = probs.argmax(dim=1)
    confidences = probs.max(dim=1).values

# Identify windows classified as abnormal with high confidence
abnormal_windows = (preds == 1) & (confidences > 0.9)
# nonzero() returns shape (n, 1); flatten it to a 1-D array of window indices
abnormal_indices = torch.nonzero(abnormal_windows).flatten().cpu().numpy()

# Report number of abnormal windows detected
print("Abnormal segments with confidence > 0.9:", len(abnormal_indices))


Extracting EDF parameters from C:\Users\obass\Desktop\linux_shared\v3.0.1\edf\eval\abnormal\01_tcp_ar\aaaaabdo_s003_t000.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 304249  =      0.000 ...  1216.996 secs...
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 1651 samples (6.604 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
🧪 Extracting features: 100%|████████████████████████████████████████████████████████| 303/303 [00:43<00:00,  7.04it/s]


📊 Abnormal segments (conf > 0.9): 153



