# ASZED EEG Classification Training (v2.3.0) - Kaggle

**Dataset:** African Schizophrenia EEG Dataset (ASZED-153)  
**Paper DOI:** 10.1016/j.dib.2025.111934  
**Data DOI:** 10.5281/zenodo.14178398

This notebook downloads the ASZED dataset from Zenodo automatically, so no manual upload is needed.

## 1. Install Dependencies

In [None]:
# Install MNE (used below to read EDF/BDF recordings) and zenodo_get (used to
# download the dataset record from Zenodo).
!pip install -q mne zenodo_get

In [None]:
# Check available resources (GPU, CPU model, RAM). Note: the pipeline below is
# CPU-only (MNE/SciPy feature extraction via joblib, scikit-learn random forest).
!nvidia-smi
!cat /proc/cpuinfo | grep 'model name' | head -1
!free -h

## 2. Download ASZED Dataset from Zenodo

This downloads roughly 2 GB of data and takes about 5–10 minutes.

In [None]:
import os
import subprocess  # NOTE(review): not used in this cell

# Download ASZED from Zenodo (DOI: 10.5281/zenodo.14178398)
ZENODO_RECORD = "14178398"
DOWNLOAD_DIR = "/kaggle/working/aszed_download"

os.makedirs(DOWNLOAD_DIR, exist_ok=True)
# Changes the notebook's working directory for every later cell as well.
# Presumably zenodo_get saves into the current directory, hence the chdir first.
os.chdir(DOWNLOAD_DIR)

print("Downloading ASZED-153 dataset from Zenodo...")
print("This may take 5-10 minutes for ~2GB of data...\n")

# IPython shell escape; {ZENODO_RECORD} is interpolated from the Python variable.
!zenodo_get {ZENODO_RECORD}

print("\nDownload complete! Checking files...")
!ls -la

In [None]:
# Extract zip files if needed
import zipfile
import glob

os.chdir(DOWNLOAD_DIR)

# Extract any zip files found directly in DOWNLOAD_DIR (the glob is not
# recursive). Archives are extracted again, over existing files, on every run.
for zf in glob.glob("*.zip"):
    print(f"Extracting {zf}...")
    with zipfile.ZipFile(zf, 'r') as z:
        z.extractall(".")

# Find the paths (diagnostic listing only; the path variables are set below)
print("\nSearching for EDF files and CSV...")
!find . -name "*.edf" | head -5
!find . -name "*.csv"

# Set paths (adjust based on actual structure). These are the expected
# locations for the Zenodo v1.1 layout; fallbacks below search DOWNLOAD_DIR.
ASZED_DATA_PATH = "/kaggle/working/aszed_download/ASZED/version_1.1"
CSV_PATH = "/kaggle/working/aszed_download/ASZED_SpreadSheet.csv"
OUTPUT_PATH = "/kaggle/working"

# EEG root fallback: prefer a "version_1.1" folder anywhere under DOWNLOAD_DIR;
# otherwise use the deepest directory that contains *every* folder holding EEG
# files. (Stopping at the first folder with an .edf would select one subject's
# folder, and the recursive scan in find_files() would then see only that
# subject.) Extensions are matched case-insensitively, like find_files().
if not os.path.exists(ASZED_DATA_PATH):
    eeg_dirs = []
    for root, dirs, files in os.walk(DOWNLOAD_DIR):
        if "version_1.1" in dirs:
            eeg_dirs = [os.path.join(root, "version_1.1")]
            break
        if any(f.lower().endswith((".edf", ".bdf")) for f in files):
            eeg_dirs.append(root)
    if eeg_dirs:
        ASZED_DATA_PATH = os.path.commonpath(eeg_dirs)

# Label spreadsheet fallback: prefer "*SpreadSheet*.csv", else any CSV.
# Sorted so the choice does not depend on filesystem enumeration order.
if not os.path.exists(CSV_PATH):
    csv_files = sorted(glob.glob(f"{DOWNLOAD_DIR}/**/*SpreadSheet*.csv", recursive=True))
    if not csv_files:
        csv_files = sorted(glob.glob(f"{DOWNLOAD_DIR}/**/*.csv", recursive=True))
    if csv_files:
        CSV_PATH = csv_files[0]

print(f"\nASZED_DATA_PATH: {ASZED_DATA_PATH}")
print(f"CSV_PATH: {CSV_PATH}")
print(f"OUTPUT_PATH: {OUTPUT_PATH}")
print(f"\nData path exists: {os.path.exists(ASZED_DATA_PATH)}")
print(f"CSV exists: {os.path.exists(CSV_PATH)}")

## 3. Training Pipeline (v2.3.0)

Run this cell to define the loading, preprocessing, and feature-extraction functions used below (it does not start processing or training).

In [None]:
"""
ASZED EEG Classification Pipeline (v2.3.0) - Kaggle Version

Authoritative Reference:
  Mosaku et al. (2025). "An open-access EEG dataset from indigenous African
  populations for schizophrenia research." Data in Brief, 62, 111934.
  DOI: 10.1016/j.dib.2025.111934

Channel Montage (per Data in Brief paper):
  Fp1, Fp2, F3, F4, F7, F8, C3, C4, Cz, T3, T4, T5, T6, P3, P4, Pz
"""

import os
# Set before mne (and anything that may use numba) is imported below.
# Presumably this keeps numba from writing its on-disk cache to a read-only
# location on Kaggle - TODO confirm.
os.environ["NUMBA_CACHE_DIR"] = "/tmp"
os.environ["NUMBA_DISABLE_CACHING"] = "1"

import re
import warnings
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
from scipy import signal, stats
import mne

try:
    from scipy.integrate import simpson
except ImportError:
    from scipy.integrate import simps as simpson

try:
    from numpy import trapezoid as np_trapz
except ImportError:
    from numpy import trapz as np_trapz

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

from joblib import Parallel, delayed, dump
from tqdm import tqdm

# Silence warnings and MNE logging in this (parent) process. Worker processes
# re-apply the warning filter themselves in process_single_file.
warnings.filterwarnings("ignore")
mne.set_log_level("ERROR")

# Number of joblib worker processes used for per-file feature extraction.
N_JOBS = 4  # Kaggle typically has 4 CPU cores

# Correct channel montage per Data in Brief paper. This order fixes the row
# order of every standardized recording matrix, hence the layout of all
# per-channel features; the pair indices in compute_coherence/compute_pli
# assume exactly this order.
EXPECTED_CHANNELS = [
    "Fp1", "Fp2", "F3", "F4", "F7", "F8", "C3", "C4",
    "Cz", "T3", "T4", "T5", "T6", "P3", "P4", "Pz"
]

# Alternative channel names -> names used in EXPECTED_CHANNELS (e.g. T7/T8/P7/P8
# -> T3/T4/T5/T6). Compared case-insensitively in standardize_to_16ch_matrix.
CHANNEL_ALIASES = {
    "T7": "T3", "T8": "T4", "P7": "T5", "P8": "T6",
    "FP1": "Fp1", "FP2": "Fp2",
}

# Recordings with fewer matched montage channels than this are rejected.
MIN_CHANNELS_REQUIRED = 10
# Sample entropy only uses the first N samples of each channel (its template
# comparison is quadratic in N).
ENTROPY_SAMPLE_SIZE = 250


def normalize_subject_id(sid) -> str:
    """Reduce a subject identifier to a canonical string key.

    * Missing values (None, NaN, "", "nan", "none") become "".
    * Integers, integral floats and integral numeric strings become their
      integer text: 12, 12.0, "12.0" -> "12".
    * Other strings lose a leading "subject_" prefix and, if they contain
      digits, reduce to the integer value of the LAST digit run:
      "Subject_007" -> "7", "sub_3_run2" -> "2".
    * Strings without digits are returned stripped and lowercased.
    """
    try:
        missing = pd.isna(sid)
        if missing:
            return ""
    except Exception:
        # pd.isna on array-likes yields an array whose truth value is
        # ambiguous; such inputs simply fall through to string handling.
        pass

    if isinstance(sid, (int, np.integer)):
        return str(int(sid))
    if isinstance(sid, (float, np.floating)) and float(sid).is_integer():
        return str(int(sid))

    text = str(sid).strip().lower()
    if text in {"", "nan", "none"}:
        return ""
    text = re.sub(r"^subject_", "", text)
    if re.fullmatch(r"\d+(\.0+)?", text):
        return str(int(float(text)))
    digit_runs = re.findall(r"\d+", text)
    return str(int(digit_runs[-1])) if digit_runs else text


def normalize_column_name(col: str) -> str:
    """Return *col* as text with surrounding whitespace removed, lowercased."""
    cleaned = str(col).strip()
    return cleaned.lower()


def canonicalize_channel_name(ch: str) -> str:
    """Strip recording decorations from a channel label, e.g. "EEG Fp1-REF" -> "Fp1".

    Removes, in order: a leading "EEG" prefix (optionally followed by -, _ or
    space; case-insensitive), a trailing "[n]" index, and a trailing reference
    suffix (REF, A1, A2, M1, M2, LE, AVG, CZ after a -, _ or space separator;
    case-insensitive). If nothing is left, the stripped input is returned.
    """
    raw_name = str(ch).strip()
    name = re.sub(r"^EEG[-_ ]?", "", raw_name, flags=re.IGNORECASE)
    name = re.sub(r"\[\d+\]$", "", name)
    name = re.sub(r"[-_ ]+(REF|A1|A2|M1|M2|LE|AVG|CZ)$", "", name, flags=re.IGNORECASE)
    name = name.strip("-_ ")
    return name if name else raw_name.strip()


class Config:
    """Input/output paths plus fixed signal-processing settings for one run.

    Frequencies are in Hz. BANDS and ERP_WINDOWS are iterated in insertion
    order by the feature extractors, so their order fixes the feature layout.
    ERP_WINDOWS values are (start, end) sample indices, not milliseconds; at
    SAMPLING_RATE = 250 Hz one sample is 4 ms (e.g. P300 = samples 62-125,
    roughly 248-500 ms).
    """

    def __init__(self, aszed_dir, csv_path, output_path):
        # Input/output locations.
        self.ASZED_DIR = Path(aszed_dir)
        self.CSV_PATH = Path(csv_path)
        self.OUTPUT_PATH = Path(output_path)

        # Target sampling rate (recordings are resampled to it) and montage size.
        self.SAMPLING_RATE = 250
        self.N_CHANNELS = 16

        # Frequency bands (Hz) for band power and coherence.
        self.BANDS = dict(
            delta=(0.5, 4),
            theta=(4, 8),
            alpha=(8, 13),
            beta=(13, 30),
            gamma=(30, 45),
        )
        # ERP windows as (start, end) sample indices.
        self.ERP_WINDOWS = dict(
            N100=(20, 30),
            P200=(37, 62),
            MMN=(25, 62),
            P300=(62, 125),
        )

        # Band-pass edges and mains notch frequency (Hz).
        self.FILTER_LOW, self.FILTER_HIGH, self.NOTCH_FREQ = 0.5, 45, 50


# Feature extraction functions
def extract_spectral_power(data, fs, bands):
    """Absolute band power of every channel, integrated from a Welch PSD.

    Args:
        data: 2-D array, one row per channel.
        fs: sampling rate in Hz.
        bands: mapping name -> (low, high) Hz, inclusive at both edges.

    Returns:
        1-D array, channel-major (all bands of channel 0, then channel 1, ...),
        truncated to 80 values (16 channels x 5 bands). A band with no PSD
        bins contributes 0. Simpson integration is used, falling back to the
        trapezoidal rule if it raises.
    """
    seg_len = min(256, data.shape[1])
    powers = []
    for row in data:
        freqs, psd = signal.welch(row, fs=fs, nperseg=seg_len)
        for low, high in bands.values():
            in_band = (freqs >= low) & (freqs <= high)
            if not in_band.any():
                powers.append(0)
                continue
            try:
                powers.append(simpson(psd[in_band], x=freqs[in_band]))
            except Exception:
                powers.append(np_trapz(psd[in_band], freqs[in_band]))
    return np.array(powers[:80])


def extract_erp_components(data, fs, windows):
    """Peak and window-summary features from the channel-averaged signal.

    `windows` maps component name -> (start, end) sample indices into the
    signal. Output layout:
      1. per window, in dict order: peak amplitude (minimum for N100/MMN,
         maximum otherwise) and its latency in ms from sample 0;
      2. per window again: mean, std and trapezoidal area of the segment.
    A window whose `end` is not < the signal length contributes zeros.
    The result is truncated to 20 values.

    NOTE(review): recordings are processed whole (no stimulus-locked epochs
    are visible in this pipeline), so these windows cover the first ~0.5 s
    after recording start rather than event-related responses.
    """
    grand = np.mean(data, axis=0)
    n_samples = len(grand)
    peaks, summaries = [], []
    for name, (start, end) in windows.items():
        if end >= n_samples:
            peaks.extend([0, 0])
            summaries.extend([0, 0, 0])
            continue
        seg = grand[start:end]
        if not len(seg):
            amp, offset = 0, 0
        elif name in ("N100", "MMN"):
            amp, offset = np.min(seg), int(np.argmin(seg))
        else:
            amp, offset = np.max(seg), int(np.argmax(seg))
        latency_ms = (start + offset) / fs * 1000 if fs else 0
        peaks.extend([amp, latency_ms])
        summaries.extend([float(np.mean(seg)), float(np.std(seg)), float(np_trapz(seg))])
    return np.array((peaks + summaries)[:20])


def compute_coherence(data, fs, bands):
    """Mean magnitude-squared coherence per band for six fixed channel pairs.

    Pair indices follow EXPECTED_CHANNELS order: Fp1-Fp2, F3-F4, C3-C4,
    T3-T4, P3-P4, T5-T6. A pair containing an all-zero (missing) channel, or
    whose coherence computation raises, contributes zeros for every band.
    Output is pair-major, truncated to 30 values (6 pairs x 5 bands).
    """
    seg_len = min(256, data.shape[1])
    n_bands = len(bands)
    result = []
    for left, right in ((0, 1), (2, 3), (6, 7), (9, 10), (13, 14), (11, 12)):
        if np.allclose(data[left], 0) or np.allclose(data[right], 0):
            result += [0.0] * n_bands
            continue
        try:
            freqs, coh = signal.coherence(data[left], data[right], fs=fs, nperseg=seg_len)
            for low, high in bands.values():
                mask = (freqs >= low) & (freqs <= high)
                result.append(float(np.mean(coh[mask])) if mask.any() else 0.0)
        except Exception:
            result += [0.0] * n_bands
    return np.array(result[:30])


def compute_pli(data):
    """Phase-lag index for the six fixed channel pairs, from broadband Hilbert phase.

    PLI = |mean(sign(sin(phase_a - phase_b)))|. Pairs (EXPECTED_CHANNELS
    order): Fp1-Fp2, F3-F4, C3-C4, T3-T4, P3-P4, T5-T6. A pair with an
    all-zero channel, or whose computation raises, yields 0.0.
    """
    result = []
    for a_idx, b_idx in ((0, 1), (2, 3), (6, 7), (9, 10), (13, 14), (11, 12)):
        if np.allclose(data[a_idx], 0) or np.allclose(data[b_idx], 0):
            result.append(0.0)
            continue
        try:
            delta_phase = np.angle(signal.hilbert(data[a_idx])) - np.angle(signal.hilbert(data[b_idx]))
            result.append(float(np.abs(np.mean(np.sign(np.sin(delta_phase))))))
        except Exception:
            result.append(0.0)
    return np.array(result[:6])


def extract_stats(data):
    """Six time-domain statistics per channel.

    For each row: mean, standard deviation, skewness, kurtosis (scipy default,
    i.e. Fisher/excess), RMS and peak-to-peak range. Skewness and kurtosis are
    0.0 for zero-variance channels. Channel-major, truncated to 96 values
    (16 channels x 6).
    """
    values = []
    for x in data:
        sd = np.std(x)
        has_variance = sd > 0
        values.extend([
            float(np.mean(x)),
            float(sd),
            float(stats.skew(x)) if has_variance else 0.0,
            float(stats.kurtosis(x)) if has_variance else 0.0,
            float(np.sqrt(np.mean(x ** 2))),
            float(np.ptp(x)),
        ])
    return np.array(values[:96])


def _count_template_matches(x, tlen, r):
    """Count index pairs i < j whose length-`tlen` templates of `x` are within `r`.

    Templates start at 0 .. len(x) - tlen - 1 (len(x) - tlen of them); a pair
    matches when the Chebyshev distance max_k |x[i+k] - x[j+k]| is strictly
    below `r`. Each template is compared against all later ones in a single
    vectorized step, so the Python loop is O(n) instead of O(n^2) and memory
    stays O(n * tlen).
    """
    n_templates = len(x) - tlen
    if n_templates < 2:
        return 0
    templates = np.array([x[i:i + tlen] for i in range(n_templates)])
    matches = 0
    for i in range(n_templates - 1):
        # Chebyshev distance from template i to every template j > i.
        cheb = np.max(np.abs(templates[i + 1:] - templates[i]), axis=1)
        matches += int(np.count_nonzero(cheb < r))
    return matches


def compute_entropy(data, m=2, r=0.2, sample_size=None):
    """Sample entropy (SampEn) of each channel.

    Each channel is cut to its first `sample_size` samples (default: the
    module-level ENTROPY_SAMPLE_SIZE) and z-scored, then
    SampEn = -ln(A / B), where B / A count template pairs of length m / m + 1
    with Chebyshev distance < r. Channels with zero (or NaN) variance, with
    A == 0 or B == 0, or whose computation raises, yield 0.0.

    NOTE(review): the textbook definition uses N - m templates for both
    lengths; here the m + 1 count uses N - m - 1 templates. This is kept
    unchanged so feature values stay identical to what existing models and
    inference code expect - confirm before "fixing" it.

    Args:
        data: 2-D array, one row per channel.
        m: template (embedding) length.
        r: tolerance in units of the channel's standard deviation (channels
            are z-scored first).
        sample_size: optional override for the number of leading samples used.

    Returns:
        1-D float array of per-channel SampEn values, truncated to 16.
    """
    limit = ENTROPY_SAMPLE_SIZE if sample_size is None else sample_size
    features = []
    for row in data:
        d = row[:limit]
        sd = np.std(d)
        if not sd > 0:
            features.append(0.0)
            continue
        d = (d - np.mean(d)) / sd
        try:
            B = _count_template_matches(d, m, r)
            A = _count_template_matches(d, m + 1, r)
            features.append(float(-np.log(A / B)) if A > 0 and B > 0 else 0.0)
        except Exception:  # best effort: one bad channel must not drop the file
            features.append(0.0)
    return np.array(features[:16])


def compute_fd(data, kmax=10):
    """Higuchi fractal dimension of each channel.

    For k = 1 .. min(kmax, N // 2 - 1) and offsets m = 0 .. k - 1:
        L_m(k) = sum_i |x[m + i*k] - x[m + (i-1)*k]| * (N - 1) / (n_m * k * k),
    where n_m = floor((N - m - 1) / k) is the number of increments. L(k) is the
    mean of L_m(k) over the k offsets, and the FD is the least-squares slope
    of log L(k) against log(1/k).

    Channels that are all (near-)zero or have N <= 2 * kmax samples, or whose
    fit has fewer than two points or fails, yield 0.0.

    The per-offset sums use vectorized strided differences instead of a
    Python loop over every sample. This avoids ~kmax * N interpreter
    iterations per channel on full-length recordings. Results match the loop
    form up to floating-point summation order.

    Args:
        data: 2-D array, one row per channel.
        kmax: largest interval k considered.

    Returns:
        1-D float array of per-channel FD values, truncated to 16.
    """
    features = []
    for row in data:
        d = np.asarray(row, dtype=float)
        N = len(d)
        if np.allclose(d, 0) or N <= kmax * 2:
            features.append(0.0)
            continue
        log_lengths, log_inv_k = [], []
        for k in range(1, min(kmax + 1, N // 2)):
            total = 0.0
            for m in range(k):
                # d[m::k] = x[m], x[m+k], ...; its first differences are the
                # n_m increments of the Higuchi curve-length sum.
                steps = np.abs(np.diff(d[m::k]))
                n_m = steps.size
                if n_m > 0:
                    total += float(steps.sum()) * (N - 1) / (n_m * k * k)
            if total > 0:
                # total / k is the mean of L_m(k) over the k offsets.
                log_lengths.append(np.log(total / k))
                log_inv_k.append(np.log(1.0 / k))
        if len(log_inv_k) > 1:
            try:
                features.append(float(np.polyfit(log_inv_k, log_lengths, 1)[0]))
            except Exception:  # degenerate fit (e.g. non-finite values)
                features.append(0.0)
        else:
            features.append(0.0)
    return np.array(features[:16])


def extract_all_features(data, fs, config):
    """Concatenate every feature family into one flat float vector.

    Order: spectral band power, ERP components, coherence, PLI, time-domain
    statistics, sample entropy, Higuchi FD. With 16 channels and the default
    Config this is 80 + 20 + 30 + 6 + 96 + 16 + 16 = 264 values. The order and
    width must match whatever consumes the trained model.
    """
    families = (
        extract_spectral_power(data, fs, config.BANDS),
        extract_erp_components(data, fs, config.ERP_WINDOWS),
        compute_coherence(data, fs, config.BANDS),
        compute_pli(data),
        extract_stats(data),
        compute_entropy(data),
        compute_fd(data),
    )
    return np.array([value for family in families for value in family], dtype=float)


def load_labels(csv_path: Path):
    """Read the ASZED spreadsheet and map normalized subject IDs to binary labels.

    The file is tried with several encodings in turn (all cells read as text).
    Headers are lowercased/stripped; the subject column is the first of
    sn/subject/subject_id/id present and the category column the first of
    category/group/diagnosis/label. Categories containing "control", "hc" or
    "healthy" map to 0; otherwise ones containing "patient", "schiz" or "sz"
    map to 1. Rows with other categories are skipped and reported. If a
    subject appears with both labels, the last row wins and this is reported.

    NOTE(review): matching is by substring, so unexpected category text could
    be misclassified - check the printed warnings.

    Returns:
        dict mapping normalized subject ID (str) -> 0 (control) or 1 (patient).

    Raises:
        ValueError: the CSV cannot be read with any encoding (original error
            chained as __cause__), or the required columns are missing.
    """
    print(f"\nLoading labels from: {csv_path}")
    df = None
    last_error = None
    for encoding in ["utf-8-sig", "utf-8", "latin-1", "cp1252"]:
        try:
            df = pd.read_csv(csv_path, encoding=encoding, dtype=str)
            break
        except Exception as err:  # wrong encoding, missing file, parse error...
            last_error = err
    if df is None:
        raise ValueError(f"Could not read CSV: {csv_path}") from last_error

    df.columns = [normalize_column_name(c) for c in df.columns]
    print(f"  Columns: {list(df.columns)}")

    sn_col = next((c for c in ["sn", "subject", "subject_id", "id"] if c in df.columns), None)
    cat_col = next((c for c in ["category", "group", "diagnosis", "label"] if c in df.columns), None)
    if not sn_col or not cat_col:
        raise ValueError(f"Missing columns. Found: {list(df.columns)}")

    label_map = {}
    unrecognized = {}    # category text -> number of rows skipped
    conflicting = set()  # subject IDs seen with both labels
    for _, row in df.iterrows():
        sid = normalize_subject_id(row.get(sn_col, ""))
        if not sid:
            continue
        cat = str(row.get(cat_col, "")).lower().strip()
        # Control keywords are checked first.
        if "control" in cat or "hc" in cat or "healthy" in cat:
            label = 0
        elif "patient" in cat or "schiz" in cat or "sz" in cat:
            label = 1
        else:
            unrecognized[cat] = unrecognized.get(cat, 0) + 1
            continue
        if sid in label_map and label_map[sid] != label:
            conflicting.add(sid)
        label_map[sid] = label  # last matching row wins

    n_ctrl = sum(1 for v in label_map.values() if v == 0)
    n_pat = sum(1 for v in label_map.values() if v == 1)
    print(f"  Mapped {len(label_map)} subjects ({n_ctrl} controls, {n_pat} patients)")
    if unrecognized:
        print(f"  WARNING: skipped {sum(unrecognized.values())} row(s) with unrecognized category: {sorted(unrecognized)}")
    if conflicting:
        print(f"  WARNING: {len(conflicting)} subject(s) listed with both labels (last row kept): {sorted(conflicting)}")
    return label_map


def find_files(aszed_dir: Path):
    """Find EDF/BDF recordings under *aszed_dir* and pair each with its subject ID.

    The subject ID is taken from the first path component that starts with
    "subject_" or "sub_" (case-insensitive) and normalized with
    normalize_subject_id. Files without such a component are skipped and
    counted in the log.

    Returns:
        list of (Path, subject_id) tuples, sorted by path so the order, and
        therefore the downstream CV folds and model, is reproducible.
    """
    print(f"\nScanning: {aszed_dir}")
    # One pass with a case-insensitive suffix check. Separate "*.edf"/"*.EDF"
    # globs would return each file twice where pathlib matches patterns
    # case-insensitively (Windows), duplicating recordings.
    files = sorted(
        p for p in Path(aszed_dir).rglob("*")
        if p.suffix.lower() in (".edf", ".bdf") and p.is_file()
    )
    print(f"  Found {len(files)} EEG files")

    pairs = []
    n_skipped = 0
    for f in files:
        sid = None
        for part in f.parts:
            if str(part).lower().startswith(("subject_", "sub_")):
                sid = normalize_subject_id(part)
                break
        if sid:
            pairs.append((f, sid))
        else:
            n_skipped += 1

    if n_skipped:
        print(f"  Skipped {n_skipped} file(s) with no subject_/sub_ folder in their path")
    print(f"  Unique subjects: {len(set(s for _, s in pairs))}")
    return pairs


def standardize_to_16ch_matrix(raw, expected_channels, aliases):
    """Lay a recording out as a fixed-montage matrix.

    Each raw channel name is canonicalized (canonicalize_channel_name), then
    matched case-insensitively first against `aliases` and then against
    `expected_channels`. The first raw channel that claims a canonical name
    wins; later duplicates are ignored.

    Returns:
        (matrix, n_found): one row per entry of `expected_channels`, in that
        order, each raw.n_times samples long. Expected channels with no match
        are all-zero rows; `n_found` counts the matched ones.
    """
    # Case-folded lookup tables; setdefault keeps the FIRST entry on a
    # collision, matching a first-match linear scan.
    alias_lookup = {}
    for alias_name, target in aliases.items():
        alias_lookup.setdefault(alias_name.upper(), target)
    expected_lookup = {}
    for exp_name in expected_channels:
        expected_lookup.setdefault(exp_name.lower(), exp_name)

    source_for = {}  # canonical name -> raw channel name
    for raw_name in raw.ch_names:
        base = canonicalize_channel_name(raw_name)
        canonical = alias_lookup.get(base.upper())
        if canonical is None:
            canonical = expected_lookup.get(base.lower())
        if canonical and canonical not in source_for:
            source_for[canonical] = raw_name

    rows = []
    n_found = 0
    for name in expected_channels:
        if name in source_for:
            rows.append(raw.get_data(picks=[source_for[name]])[0])
            n_found += 1
        else:
            rows.append(np.zeros(raw.n_times, dtype=float))
    return np.vstack(rows), n_found


def load_eeg(file_path: Path, target_fs=250):
    """Read one EDF/BDF recording and return (matrix, sampling_rate, n_found).

    ".bdf" files (any case) go through mne.io.read_raw_bdf and everything else
    through read_raw_edf. If picking EEG-typed, non-bad channels leaves at
    least one channel, only those are kept; otherwise (or if picking raises)
    all channels are kept. The data are resampled to `target_fs` when the file
    rate differs, then mapped onto EXPECTED_CHANNELS via
    standardize_to_16ch_matrix. The returned rate is always `target_fs`.

    Raises:
        ValueError: the file has no channels at all.
    """
    is_bdf = file_path.suffix.lower() == ".bdf"
    reader = mne.io.read_raw_bdf if is_bdf else mne.io.read_raw_edf
    raw = reader(str(file_path), preload=True, verbose="ERROR")

    try:
        eeg_only = raw.copy()
        eeg_only.pick_types(eeg=True, exclude="bads")
    except Exception:
        eeg_only = None
    if eeg_only is not None and len(eeg_only.ch_names) > 0:
        raw = eeg_only

    if not raw.ch_names:
        raise ValueError("No EEG channels found")

    if float(raw.info["sfreq"]) != target_fs:
        raw.resample(target_fs)

    matrix, n_found = standardize_to_16ch_matrix(raw, EXPECTED_CHANNELS, CHANNEL_ALIASES)
    return matrix, target_fs, n_found


def preprocess(data, fs, config):
    """Demean, band-pass and (conditionally) notch-filter every channel.

    All-zero rows (missing electrodes) pass through untouched. The band-pass
    is a 4th-order Butterworth from config.FILTER_LOW to config.FILTER_HIGH
    (upper edge capped at 0.99 x Nyquist), applied zero-phase with filtfilt.
    Any filter step that raises is skipped for that channel.

    NOTE(review): the notch is applied only when
    NOTCH_FREQ < min(FILTER_HIGH, fs / 2). With the default Config (50 Hz
    notch, 45 Hz upper edge) it never runs, so 50 Hz mains noise is reduced
    only by the band-pass roll-off. Confirm this is intended before changing
    it, because changing it alters every feature.
    """
    nyquist = fs / 2.0
    use_notch = config.NOTCH_FREQ < min(config.FILTER_HIGH, nyquist)

    # Filter coefficients depend only on fs/config, so design them once.
    try:
        band_ba = signal.butter(
            4,
            [config.FILTER_LOW / nyquist, min(config.FILTER_HIGH / nyquist, 0.99)],
            "band",
        )
    except Exception:
        band_ba = None
    notch_ba = None
    if use_notch:
        try:
            notch_ba = signal.iirnotch(config.NOTCH_FREQ, 30, fs=fs)
        except Exception:
            notch_ba = None

    cleaned = []
    for channel in data:
        if np.allclose(channel, 0):
            cleaned.append(channel)
            continue
        x = channel - np.mean(channel)
        if band_ba is not None:
            try:
                x = signal.filtfilt(band_ba[0], band_ba[1], x)
            except Exception:
                pass
        if notch_ba is not None:
            try:
                x = signal.filtfilt(notch_ba[0], notch_ba[1], x)
            except Exception:
                pass
        cleaned.append(x)
    return np.array(cleaned)


def process_single_file(fp, sid, label_map, config):
    """Turn one recording into a feature vector and report the outcome as a dict.

    Used as the joblib task in the processing cell. Every returned dict has
    "status", "subject_id" and "label" (-1 when the subject is unlabeled):
      * "no_label"     - subject not in label_map (the file is not read)
      * "too_short"    - fewer than 500 samples after resampling
      * "low_channels" - fewer than MIN_CHANNELS_REQUIRED montage channels matched
      * "error"        - an exception during loading/processing; text under "error"
      * "ok"           - adds "features" (NaN/inf replaced by 0.0) and "n_channels"
    """
    # Re-silence warnings: this runs in joblib worker processes, which may not
    # share the parent notebook's warning filters.
    warnings.filterwarnings("ignore")

    label = label_map.get(sid, -1)
    try:
        if sid not in label_map:
            return {"status": "no_label", "subject_id": sid, "label": -1}

        data, fs, n_ch = load_eeg(fp, config.SAMPLING_RATE)
        if data.shape[1] < 500:
            return {"status": "too_short", "subject_id": sid, "label": label}
        if n_ch < MIN_CHANNELS_REQUIRED:
            return {"status": "low_channels", "subject_id": sid, "label": label}

        cleaned = preprocess(data, fs, config)
        raw_features = extract_all_features(cleaned, fs, config)
        features = np.nan_to_num(raw_features, nan=0.0, posinf=0.0, neginf=0.0)
        return {
            "status": "ok",
            "features": features,
            "label": label,
            "subject_id": sid,
            "n_channels": n_ch,
        }
    except Exception as exc:
        return {"status": "error", "subject_id": sid, "label": label, "error": str(exc)}


# Confirms this cell executed (definitions only); processing starts in the next cell.
print("Pipeline code loaded successfully!")

## 4. Process Dataset

In [None]:
print("=" * 60)
print("ASZED-153 PREPROCESSING (v2.3.0)")
print("=" * 60)

config = Config(ASZED_DATA_PATH, CSV_PATH, OUTPUT_PATH)

label_map = load_labels(config.CSV_PATH)
pairs = find_files(config.ASZED_DIR)

if not pairs:
    raise RuntimeError("No EEG files found - check your paths!")

print(f"\nProcessing {len(pairs)} files...")

# One joblib task per recording. process_single_file reports failures as
# status dicts rather than raising, so one bad file cannot abort the run.
# NOTE: the tqdm bar advances as tasks are handed to workers, so it runs
# slightly ahead of actual completion.
results = Parallel(n_jobs=N_JOBS, backend="loky")(
    delayed(process_single_file)(fp, sid, label_map, config)
    for fp, sid in tqdm(pairs, desc="Processing")
)

results = [r for r in results if isinstance(r, dict)]
valid = [r for r in results if r.get("status") == "ok"]

print(f"\n  Accepted: {len(valid)}")
print(f"  Rejected: {len(results) - len(valid)}")

# Break rejections down by reason, and show a few error messages, so that
# dropped recordings are visible instead of silently discarded.
from collections import Counter

rejected_by_status = Counter(r.get("status") for r in results if r.get("status") != "ok")
for status, count in sorted(rejected_by_status.items(), key=lambda kv: str(kv[0])):
    print(f"    {status}: {count}")
failed = [r for r in results if r.get("status") == "error"]
for r in failed[:5]:
    print(f"    error for subject {r.get('subject_id')}: {r.get('error')}")

if not valid:
    raise RuntimeError("No samples processed successfully")

X = np.array([r["features"] for r in valid], dtype=float)
y = np.array([r["label"] for r in valid], dtype=int)
subject_ids = np.array([r["subject_id"] for r in valid], dtype=str)

# Cross-validation and training below need both controls (0) and patients (1).
if len(np.unique(y)) < 2:
    raise RuntimeError(f"Accepted recordings contain only label(s) {np.unique(y).tolist()}; need both 0 and 1")

n_subjects = len(set(subject_ids))
print(f"\n  Recordings: {len(y)} ({(y == 0).sum()} ctrl, {(y == 1).sum()} patient)")
print(f"  Subjects: {n_subjects}")
print(f"  Features: {X.shape[1]}")

## 5. Subject-Level Cross-Validation

In [None]:
print("\n" + "=" * 60)
print("SUBJECT-LEVEL CROSS-VALIDATION")
print("=" * 60)

n_splits = 5
random_state = 42

# Build subject table: one label per subject, taken from its first recording
# (every recording of a subject gets the same label from label_map).
unique_subjects = sorted(set(subject_ids))
subject_labels = {}
for i, subj in enumerate(subject_ids):
    if subj not in subject_labels:
        subject_labels[subj] = y[i]

subject_y = np.array([subject_labels[s] for s in unique_subjects])

# Stratified split on subjects (not recordings), so no subject's recordings
# appear on both the train and test side of a fold. StratifiedKFold needs at
# least n_splits subjects in each class.
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

# Convert subject folds to recording indices
folds = []
for train_subj_idx, test_subj_idx in skf.split(unique_subjects, subject_y):
    train_subjects = set(unique_subjects[i] for i in train_subj_idx)
    test_subjects = set(unique_subjects[i] for i in test_subj_idx)
    train_idx = np.array([i for i, g in enumerate(subject_ids) if g in train_subjects])
    test_idx = np.array([i for i, g in enumerate(subject_ids) if g in test_subjects])
    folds.append((train_idx, test_idx))

# Out-of-fold predictions collected across all folds.
all_y_true, all_y_pred, all_y_prob, all_subjects = [], [], [], []

for fold_idx, (train_idx, test_idx) in enumerate(folds):
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
    
    # The scaler lives inside the pipeline, so it is fit on the training fold only.
    model = Pipeline([
        ("scaler", StandardScaler()),
        ("clf", RandomForestClassifier(n_estimators=300, max_depth=20, min_samples_split=5, 
                                        random_state=random_state, n_jobs=-1)),
    ])
    
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    # Column 1 = probability of label 1 (patient), assuming both classes occur
    # in the training fold.
    y_prob = model.predict_proba(X_test)[:, 1]
    
    all_y_true.extend(y_test.tolist())
    all_y_pred.extend(y_pred.tolist())
    all_y_prob.extend(y_prob.tolist())
    all_subjects.extend(subject_ids[test_idx].tolist())
    
    # Recording-level accuracy of this fold.
    acc = accuracy_score(y_test, y_pred)
    print(f"  Fold {fold_idx + 1}: acc={acc:.3f}")

# Subject-level aggregation: average each subject's out-of-fold recording
# probabilities and threshold the mean at 0.5.
df = pd.DataFrame({"subject": all_subjects, "y_true": all_y_true, "y_prob": all_y_prob})
agg = df.groupby("subject").agg(y_true=("y_true", "first"), y_prob_mean=("y_prob", "mean")).reset_index()
agg["y_pred"] = (agg["y_prob_mean"] >= 0.5).astype(int)

subj_acc = accuracy_score(agg["y_true"], agg["y_pred"])
# Recording-level accuracy from thresholding probabilities at 0.5 (the per-fold
# numbers above use model.predict instead).
rec_acc = accuracy_score(all_y_true, [1 if p >= 0.5 else 0 for p in all_y_prob])

print(f"\n  Recording-level accuracy: {rec_acc*100:.1f}%")
print(f"  Subject-level accuracy: {subj_acc*100:.1f}%")

## 6. Train Final Model

In [None]:
print("\n" + "=" * 60)
print("TRAINING FINAL MODEL")
print("=" * 60)

# Same pipeline and hyperparameters as the CV folds, refit on every accepted
# recording. Inputs at inference time must be feature vectors built exactly
# like X (extract_all_features: same order and width).
final_model = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", RandomForestClassifier(n_estimators=300, max_depth=20, min_samples_split=5, 
                                    random_state=42, n_jobs=-1)),
])

final_model.fit(X, y)

# Save model (timestamped copy, so earlier runs are not overwritten)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = f"{OUTPUT_PATH}/schizophrenia_model_v230_{timestamp}.pkl"
dump(final_model, model_path)

# Also save with standard backend name (overwritten on every run).
# joblib files are pickle-based: loading one can execute code, so only load
# copies from a trusted source.
backend_model_path = f"{OUTPUT_PATH}/schizophrenia_backend_model.pkl"
dump(final_model, backend_model_path)

print(f"\n  Model saved: {model_path}")
print(f"  Backend model: {backend_model_path}")

## 7. Verify & Download

In [None]:
# Verify model works: reload the saved file and check that it (a) accepts
# inputs of the training feature width and (b) reproduces the in-memory
# model's predictions on real feature rows.
from joblib import load

test_model = load(backend_model_path)

# Use the real feature width instead of a hard-coded 264, so this check stays
# valid if the feature set changes.
test_input = np.random.randn(1, X.shape[1])
test_pred = test_model.predict_proba(test_input)
assert test_pred.shape == (1, len(test_model.classes_)), f"Unexpected output shape: {test_pred.shape}"

# Round-trip check. allclose rather than exact equality: the forest predicts
# with n_jobs=-1, and parallel accumulation of tree probabilities may not be
# bit-identical between calls.
sample = X[:5]
assert np.allclose(test_model.predict_proba(sample), final_model.predict_proba(sample)), \
    "Reloaded model disagrees with the in-memory model"

print("Model verification:")
print(f"  Input shape: {test_input.shape}")
print(f"  Output shape: {test_pred.shape}")
print(f"  Test prediction: {test_pred}")
print("\n✓ Model reloads and reproduces the in-memory predictions!")

In [None]:
# List output files (the timestamped model copy and schizophrenia_backend_model.pkl)
print("\nOutput files (available in 'Output' tab on right):")
!ls -la /kaggle/working/*.pkl

## Done!

Your model is saved. To download:

1. Click **"Save Version"** (top right) → **"Save & Run All"**
2. After it completes, go to your notebook page
3. Click the **"Output"** tab on the right
4. Download `schizophrenia_backend_model.pkl`
5. Place it in `mind-bloom/backend/`

Or use Kaggle CLI:
```
kaggle kernels output milkandcoding/aszed-training-kaggle -p ./
```