# ASZED EEG Classification Training (v2.3.0)

**Dataset:** African Schizophrenia EEG Dataset (ASZED-153)  
**Paper DOI:** 10.1016/j.dib.2025.111934  
**Data DOI:** 10.5281/zenodo.14178398

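This notebook extracts 264 handcrafted features per EEG recording (band power, ERP-window, coherence/PLI connectivity, summary statistics, sample entropy and Higuchi fractal dimension), estimates performance with subject-level 5-fold cross-validation, and then trains a final Random Forest classifier for schizophrenia detection.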

## 1. Setup & Mount Google Drive

In [2]:
# Mount Google Drive so the dataset and results folders under MyDrive are reachable.
# If mounting fails or hangs, re-run this cell and finish the authentication popup.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

KeyboardInterrupt: 

In [None]:
# Install required packages
!pip install -q mne scipy scikit-learn pandas numpy joblib tqdm

In [None]:
# Check GPU (optional - not required for Random Forest)
# Informational only: the pipeline below trains scikit-learn models and makes no
# GPU calls. On a runtime without a GPU this command may fail; that is harmless here.
!nvidia-smi

## 2. Configure Paths

**IMPORTANT:** Update these paths to match your Google Drive structure!

In [None]:
# ============================================================
# PATHS FOR YOUR GOOGLE DRIVE (edit to match your folder layout)
# ============================================================
import os

# Root folder of the ASZED EDF/BDF recordings. It is searched recursively, and the
# subject ID is taken from a "subject_*" / "sub_*" component of each file's path.
ASZED_DATA_PATH = '/content/drive/MyDrive/For Project Schiz/ASZED-153/ASZED/version_1.1'

# CSV spreadsheet with the labels. It needs a subject column (sn/subject/subject_id/id)
# and a diagnosis column (category/group/diagnosis/label).
CSV_PATH = '/content/drive/MyDrive/For Project Schiz/ASZED-153/ASZED_SpreadSheet.csv'

# Output folder for results and the trained model (created by training if missing)
OUTPUT_PATH = '/content/drive/MyDrive/For Project Schiz/ASZED-153/results_v230'

# Verify the inputs exist before starting the (long) training run
print(f"ASZED data path exists: {os.path.exists(ASZED_DATA_PATH)}")
print(f"CSV path exists: {os.path.exists(CSV_PATH)}")

# List up to 20 entries of the data folder so the layout can be eyeballed
if os.path.exists(ASZED_DATA_PATH):
    print(f"\nContents of {ASZED_DATA_PATH}:")
    for entry in sorted(os.listdir(ASZED_DATA_PATH))[:20]:
        print(f"  {entry}")

# Fail here, with an actionable message, rather than deep inside the training cell
missing_inputs = [p for p in (ASZED_DATA_PATH, CSV_PATH) if not os.path.exists(p)]
if missing_inputs:
    raise FileNotFoundError(
        "Input path(s) not found - check that Drive is mounted and update the paths above: "
        + ", ".join(missing_inputs)
    )

## 3. Training Pipeline Code (v2.3.0)

In [None]:
%%writefile /content/complete_v230.py
"""
ASZED EEG Classification Pipeline (v2.3.0) - Colab Version

Authoritative Reference:
  Mosaku et al. (2025). "An open-access EEG dataset from indigenous African
  populations for schizophrenia research." Data in Brief, 62, 111934.
  DOI: 10.1016/j.dib.2025.111934

Channel Montage (per Data in Brief paper):
  Fp1, Fp2, F3, F4, F7, F8, C3, C4, Cz, T3, T4, T5, T6, P3, P4, Pz
"""

import os
os.environ["NUMBA_CACHE_DIR"] = "/tmp"
os.environ["NUMBA_DISABLE_CACHING"] = "1"

import re
import sys
import json
import time
import warnings
import argparse
import traceback
from pathlib import Path
from datetime import datetime
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
from scipy import signal, stats

try:
    from scipy.integrate import simpson
except ImportError:
    from scipy.integrate import simps as simpson

try:
    from numpy import trapezoid as np_trapz
except ImportError:
    from numpy import trapz as np_trapz

try:
    from scipy.stats import fisher_exact
    FISHER_AVAILABLE = True
except ImportError:
    FISHER_AVAILABLE = False

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix,
    roc_auc_score, f1_score,
)

from joblib import Parallel, delayed, dump
from tqdm import tqdm

warnings.filterwarnings("ignore")

# Number of joblib worker processes used for per-file feature extraction.
N_JOBS = 2  # Colab has 2 CPU cores
# Label for the evaluated classifier. NOTE(review): not referenced by the code in this module.
PRIMARY_MODEL_NAME = "Random Forest"

# Correct channel montage per Data in Brief paper
# Every recording is re-ordered into exactly this 16-channel order; channels that
# cannot be found are zero-filled. Connectivity features index rows by position
# (e.g. row 0 = Fp1, row 1 = Fp2), so this order must not change.
EXPECTED_CHANNELS = [
    "Fp1", "Fp2", "F3", "F4", "F7", "F8", "C3", "C4",
    "Cz", "T3", "T4", "T5", "T6", "P3", "P4", "Pz"
]

# Alternate electrode labels mapped onto EXPECTED_CHANNELS names (matched
# case-insensitively after prefix/reference-suffix stripping).
CHANNEL_ALIASES = {
    "T7": "T3", "T8": "T4", "P7": "T5", "P8": "T6",
    "FP1": "Fp1", "FP2": "Fp2",
}

# Recordings with fewer matched montage channels than this are rejected.
MIN_CHANNELS_REQUIRED = 10
# Leading samples per channel used for sample entropy (pairwise template matching
# is quadratic in this length); about 1 s at the pipeline's 250 Hz sampling rate.
ENTROPY_SAMPLE_SIZE = 250


def normalize_subject_id(sid) -> str:
    """Reduce a subject identifier to a canonical string, usually its number.

    Missing values (None/NaN, or the strings "", "nan", "none") become "".
    Integral numbers become their decimal form ("12", 12, 12.0 -> "12").
    Strings lose a leading "subject_" prefix. A purely numeric remainder is
    converted via float (so "007" -> "7"). Otherwise the last run of digits
    is used ("sub_12" -> "12"). Text with no digits is returned lower-cased.
    """
    try:
        is_missing = bool(pd.isna(sid))
    except Exception:
        # pd.isna on list-like input yields an array whose truth value is ambiguous.
        is_missing = False
    if is_missing:
        return ""

    if isinstance(sid, (int, np.integer)):
        return str(int(sid))
    if isinstance(sid, (float, np.floating)) and float(sid).is_integer():
        return str(int(sid))

    text = str(sid).strip().lower()
    if text in {"", "nan", "none"}:
        return ""
    text = re.sub(r"^subject_", "", text)
    if re.fullmatch(r"\d+(\.0+)?", text):
        return str(int(float(text)))
    digit_runs = re.findall(r"\d+", text)
    return str(int(digit_runs[-1])) if digit_runs else text


def normalize_column_name(col: str) -> str:
    """Return *col* as a string, trimmed of surrounding whitespace and lower-cased."""
    text = str(col)
    return text.strip().lower()


def canonicalize_channel_name(ch: str) -> str:
    """Strip recording-system decorations from a channel label.

    Removes, in order: a leading "EEG" prefix (optionally followed by -, _ or
    space), a trailing "[n]" index, and a trailing reference suffix such as
    "-REF", "-A1" or "_LE". Separator characters left at either end are then
    trimmed. If nothing remains, the whitespace-stripped input is returned.
    """
    raw_label = str(ch).strip()
    cleanup_steps = (
        (r"^EEG[-_ ]?", re.IGNORECASE),
        (r"\[\d+\]$", 0),
        (r"[-_ ]+(REF|A1|A2|M1|M2|LE|AVG|CZ)$", re.IGNORECASE),
    )
    cleaned = raw_label
    for pattern, flags in cleanup_steps:
        cleaned = re.sub(pattern, "", cleaned, flags=flags)
    cleaned = cleaned.strip("-_ ")
    return cleaned or raw_label.strip()


class Config:
    """Settings shared by the preprocessing and feature-extraction steps.

    The three paths are stored as ``pathlib.Path``. ``ERP_WINDOWS`` values are
    (start, stop) sample indices into the channel-averaged signal; at the
    250 Hz ``SAMPLING_RATE``, (20, 30) spans 80-120 ms. Dict insertion order of
    ``BANDS`` and ``ERP_WINDOWS`` determines feature order.
    """

    def __init__(self, aszed_dir, csv_path, output_path):
        # Locations of the recording tree, the label spreadsheet and the results folder.
        self.ASZED_DIR = Path(aszed_dir)
        self.CSV_PATH = Path(csv_path)
        self.OUTPUT_PATH = Path(output_path)

        # Signal layout every recording is brought to.
        self.SAMPLING_RATE, self.N_CHANNELS = 250, 16

        # Frequency bands in Hz.
        self.BANDS = dict(
            delta=(0.5, 4),
            theta=(4, 8),
            alpha=(8, 13),
            beta=(13, 30),
            gamma=(30, 45),
        )

        # Peak/shape windows in samples.
        self.ERP_WINDOWS = dict(N100=(20, 30), P200=(37, 62), MMN=(25, 62), P300=(62, 125))

        # Band-pass edges and notch frequency, all in Hz.
        self.FILTER_LOW, self.FILTER_HIGH, self.NOTCH_FREQ = 0.5, 45, 50


# Feature extraction functions
def extract_spectral_power(data, fs, bands):
    """Absolute band power per channel from a Welch PSD.

    For every channel (row of ``data``) and every band in ``bands`` (Hz,
    inclusive edges), the PSD is integrated with Simpson's rule. If that
    raises, the trapezoidal rule is used instead. A band with no frequency
    bin contributes 0.

    Returns:
        np.ndarray of at most 80 values (16 channels x 5 bands), channel-major.
    """
    features = []
    for ch in range(data.shape[0]):
        freqs, psd = signal.welch(data[ch], fs=fs, nperseg=min(256, data.shape[1]))
        for low, high in bands.values():
            idx = (freqs >= low) & (freqs <= high)
            if not idx.any():
                features.append(0)
                continue
            try:
                features.append(simpson(psd[idx], x=freqs[idx]))
            except Exception:  # best-effort fallback; no longer swallows KeyboardInterrupt
                features.append(np_trapz(psd[idx], freqs[idx]))
    return np.array(features[:80])


def extract_erp_components(data, fs, windows):
    """Peak and shape features of the channel-averaged signal in fixed windows.

    ``windows`` maps a component name to (start, stop) sample indices. For each
    window that ends before the last sample, it takes the peak amplitude and
    the peak latency in ms. "N100" and "MMN" use the minimum; all other
    components use the maximum. It then takes the mean, standard deviation
    and trapezoidal area of each window. Windows that do not fit contribute
    zeros.

    Returns:
        np.ndarray of at most 20 values: all (amplitude, latency) pairs first,
        then all (mean, std, area) triples.
    """
    grand_average = np.mean(data, axis=0)
    negative_components = ("N100", "MMN")

    peak_features = []
    for name, (start, stop) in windows.items():
        if not stop < len(grand_average):
            peak_features += [0, 0]
            continue
        window = grand_average[start:stop]
        if len(window) == 0:
            amplitude, offset = 0, 0
        elif name in negative_components:
            amplitude, offset = np.min(window), int(np.argmin(window))
        else:
            amplitude, offset = np.max(window), int(np.argmax(window))
        latency_ms = (start + offset) / fs * 1000 if fs else 0
        peak_features += [amplitude, latency_ms]

    shape_features = []
    for start, stop in windows.values():
        if stop < len(grand_average):
            segment = grand_average[start:stop]
            shape_features += [float(np.mean(segment)), float(np.std(segment)), float(np_trapz(segment))]
        else:
            shape_features += [0, 0, 0]

    return np.array((peak_features + shape_features)[:20])


def compute_coherence(data, fs, bands):
    """Band-averaged magnitude-squared coherence for six fixed channel pairs.

    Pairs are row indices in EXPECTED_CHANNELS order: Fp1-Fp2, F3-F4, C3-C4,
    T3-T4, P3-P4, T5-T6. For each pair, the Welch coherence spectrum is
    averaged over the frequency bins inside every band of ``bands`` (Hz).

    Returns:
        np.ndarray of at most 30 values (6 pairs x 5 bands), pair-major.
        A pair with an all-zero (missing) channel, or whose coherence cannot
        be computed, contributes zeros. A band with no frequency bin gives 0.0.
    """
    pairs = [(0, 1), (2, 3), (6, 7), (9, 10), (13, 14), (11, 12)]
    n_bands = len(bands)
    features = []
    for c1, c2 in pairs:
        if np.allclose(data[c1], 0) or np.allclose(data[c2], 0):
            features.extend([0.0] * n_bands)
            continue
        try:
            f, coh = signal.coherence(data[c1], data[c2], fs=fs, nperseg=min(256, data.shape[1]))
        except Exception:  # best-effort: a failing pair contributes zeros
            features.extend([0.0] * n_bands)
            continue
        # Outside the try: a failure can no longer leave a partially appended
        # pair followed by a full row of zeros, which would shift later features.
        for low, high in bands.values():
            idx = (f >= low) & (f <= high)
            features.append(float(np.mean(coh[idx])) if idx.any() else 0.0)
    return np.array(features[:30])


def compute_pli(data):
    """Phase Lag Index for six fixed channel pairs.

    PLI = |mean(sign(sin(phi1 - phi2)))|, with instantaneous phases taken from
    the Hilbert transform of each broadband channel. 0 means no consistent
    phase lead/lag; 1 means one channel consistently leads the other.

    Pairs are row indices in EXPECTED_CHANNELS order: Fp1-Fp2, F3-F4, C3-C4,
    T3-T4, P3-P4, T5-T6.

    Returns:
        np.ndarray of at most 6 values. A pair with an all-zero (missing)
        channel, or whose phases cannot be computed, gets 0.0.
    """
    pairs = [(0, 1), (2, 3), (6, 7), (9, 10), (13, 14), (11, 12)]
    features = []
    for c1, c2 in pairs:
        if np.allclose(data[c1], 0) or np.allclose(data[c2], 0):
            features.append(0.0)
            continue
        try:
            phase_diff = np.angle(signal.hilbert(data[c1])) - np.angle(signal.hilbert(data[c2]))
        except Exception:  # best-effort: a failing pair contributes 0.0
            features.append(0.0)
            continue
        # sin() is 2*pi-periodic, so the unwrapped phase difference is fine here.
        features.append(float(np.abs(np.mean(np.sign(np.sin(phase_diff))))))
    return np.array(features[:6])


def extract_stats(data):
    """Six summary statistics per channel.

    For each row: mean, standard deviation, skewness, excess kurtosis, RMS and
    peak-to-peak range. Skewness and kurtosis are 0.0 for a constant channel.

    Returns:
        np.ndarray of at most 96 values (16 channels x 6), channel-major.
    """
    rows = []
    for channel in data:
        sigma = np.std(channel)
        has_spread = sigma > 0
        rows.append([
            float(np.mean(channel)),
            float(sigma),
            float(stats.skew(channel)) if has_spread else 0.0,
            float(stats.kurtosis(channel)) if has_spread else 0.0,
            float(np.sqrt(np.mean(channel ** 2))),
            float(np.ptp(channel)),
        ])
    flat = [value for row in rows for value in row]
    return np.array(flat[:96])


def _count_template_matches(x, tlen, r):
    """Count pairs i < j of templates x[i:i+tlen] with Chebyshev distance < r.

    The first ``len(x) - tlen`` templates are used. All pairs are compared at
    once with broadcasting, which needs O(n^2 * tlen) memory. That is small
    for the few hundred samples used per channel.
    """
    n_templates = len(x) - tlen
    if n_templates < 2:
        return 0
    templates = np.array([x[i:i + tlen] for i in range(n_templates)])
    distances = np.max(np.abs(templates[:, None, :] - templates[None, :, :]), axis=2)
    rows, cols = np.triu_indices(n_templates, k=1)
    return int(np.count_nonzero(distances[rows, cols] < r))


def compute_entropy(data, m=2, r=0.2, sample_size=None):
    """Sample entropy of each channel.

    Each channel is truncated to its first ``sample_size`` samples and
    z-scored, so the tolerance ``r`` is in units of the channel's standard
    deviation. SampEn = -ln(A / B), where B (A) is the number of matching
    template pairs of length m (m + 1). Two templates match when every sample
    differs by less than r.

    Args:
        data: 2-D array, one row per channel.
        m: template length.
        r: match tolerance after z-scoring.
        sample_size: samples per channel to use; None means ENTROPY_SAMPLE_SIZE.

    Returns:
        np.ndarray with one value per channel (at most 16). The value is 0.0
        for constant channels or when A or B is zero.
    """
    if sample_size is None:
        sample_size = ENTROPY_SAMPLE_SIZE
    features = []
    for ch in range(data.shape[0]):
        d = data[ch][:sample_size]
        sd = np.std(d)
        if not sd > 0:
            features.append(0.0)
            continue
        d = (d - np.mean(d)) / sd
        try:
            B = _count_template_matches(d, m, r)
            A = _count_template_matches(d, m + 1, r)
        except ValueError:  # e.g. zero-length templates when m == 0
            features.append(0.0)
            continue
        features.append(float(-np.log(A / B)) if A > 0 and B > 0 else 0.0)
    return np.array(features[:16])


def compute_fd(data, kmax=10):
    """Higuchi fractal dimension of each channel.

    For k = 1 .. min(kmax, N // 2 - 1), the mean normalized curve length L(k)
    is computed from the k sub-series x[m], x[m+k], x[m+2k], ...
    (m = 0 .. k-1). The FD is the slope of log L(k) against log(1/k).

    Args:
        data: 2-D array, one row per channel.
        kmax: largest delay k.

    Returns:
        np.ndarray with one value per channel (at most 16). The value is 0.0
        for channels that are all zero, have at most 2 * kmax samples, or
        yield fewer than two valid (k, L(k)) points.
    """
    features = []
    for ch in range(data.shape[0]):
        d = data[ch]
        N = len(d)
        if np.allclose(d, 0) or N <= kmax * 2:
            features.append(0.0)
            continue
        log_L, log_inv_k = [], []
        for k in range(1, min(kmax + 1, N // 2)):
            Lk = 0.0
            for m in range(k):
                # d[m::k] has mx + 1 points, where mx = floor((N - m - 1) / k)
                # is the number of increments; vectorized instead of a per-sample loop.
                increments = np.abs(np.diff(d[m::k]))
                mx = increments.size
                if mx > 0:
                    Lk += increments.sum() * (N - 1) / (mx * k * k)
            if Lk > 0:
                # Lk / k is the average curve length over the k offsets m.
                log_L.append(np.log(Lk / k))
                log_inv_k.append(np.log(1.0 / k))
        if len(log_inv_k) > 1:
            try:
                features.append(float(np.polyfit(log_inv_k, log_L, 1)[0]))
            except (ValueError, np.linalg.LinAlgError):
                features.append(0.0)
        else:
            features.append(0.0)
    return np.array(features[:16])


def extract_all_features(data, fs, config):
    """Concatenate every feature family into one flat float vector.

    Order: spectral band power (80), ERP-window features (20), band coherence
    (30), PLI (6), per-channel statistics (96), sample entropy (16) and Higuchi
    FD (16). That gives 264 values for a 16-channel input with the default
    Config bands and windows. The order must stay fixed because the trained
    pipeline is fitted on exactly this column layout.
    """
    families = (
        extract_spectral_power(data, fs, config.BANDS),
        extract_erp_components(data, fs, config.ERP_WINDOWS),
        compute_coherence(data, fs, config.BANDS),
        compute_pli(data),
        extract_stats(data),
        compute_entropy(data),
        compute_fd(data),
    )
    return np.array([value for family in families for value in family], dtype=float)


def load_labels(csv_path: Path):
    """Build the subject-ID -> label map from the ASZED spreadsheet.

    The CSV is read with every column as a string, trying several encodings
    in turn. Column names are matched case-insensitively. The subject column
    is the first of sn/subject/subject_id/id, and the diagnosis column the
    first of category/group/diagnosis/label. Categories containing "control",
    "hc" or "healthy" map to 0. Otherwise, categories containing "patient",
    "schiz" or "sz" map to 1. Any other category is skipped and reported in
    the printed summary.

    Args:
        csv_path: path to the spreadsheet.

    Returns:
        dict mapping normalized subject IDs (see normalize_subject_id) to 0/1.
        If a subject appears on several rows, the last row wins; conflicting
        labels are reported.

    Raises:
        FileNotFoundError: if csv_path does not exist.
        ValueError: if the file cannot be decoded, cannot be parsed, or lacks
            a required column.
    """
    print(f"\nLoading labels from: {csv_path}")
    df = None
    for encoding in ["utf-8-sig", "utf-8", "latin-1", "cp1252"]:
        try:
            df = pd.read_csv(csv_path, encoding=encoding, dtype=str)
            break
        except UnicodeDecodeError:
            # Only a decoding failure means "try the next encoding". A missing
            # file or malformed CSV now propagates with its real error message.
            continue
    if df is None:
        raise ValueError(f"Could not read CSV: {csv_path}")

    df.columns = [normalize_column_name(c) for c in df.columns]
    print(f"  Columns: {list(df.columns)}")

    sn_col = next((c for c in ["sn", "subject", "subject_id", "id"] if c in df.columns), None)
    cat_col = next((c for c in ["category", "group", "diagnosis", "label"] if c in df.columns), None)
    if not sn_col or not cat_col:
        raise ValueError(f"Missing columns. Found: {list(df.columns)}")

    label_map = {}
    unrecognized = Counter()
    conflicting = set()
    for _, row in df.iterrows():
        sid = normalize_subject_id(row.get(sn_col, ""))
        if not sid:
            continue
        cat = str(row.get(cat_col, "")).lower().strip()
        if "control" in cat or "hc" in cat or "healthy" in cat:
            label = 0
        elif "patient" in cat or "schiz" in cat or "sz" in cat:
            label = 1
        else:
            unrecognized[cat] += 1
            continue
        if label_map.get(sid, label) != label:
            conflicting.add(sid)
        label_map[sid] = label

    n_ctrl = sum(1 for v in label_map.values() if v == 0)
    n_pat = sum(1 for v in label_map.values() if v == 1)
    print(f"  Mapped {len(label_map)} subjects ({n_ctrl} controls, {n_pat} patients)")
    if unrecognized:
        print(f"  WARNING: skipped {sum(unrecognized.values())} rows with unrecognized category: {dict(unrecognized)}")
    if conflicting:
        print(f"  WARNING: {len(conflicting)} subjects have conflicting labels (last row kept): {sorted(conflicting)}")
    return label_map


def find_files(aszed_dir: Path):
    """Find EDF/BDF recordings under *aszed_dir* and attach a subject ID to each.

    The subject ID comes from the first path component below *aszed_dir*
    (the file name included) that starts with "subject_" or "sub_",
    case-insensitively. It is normalized with normalize_subject_id. Files
    without such a component are skipped, and the number skipped is printed.

    Returns:
        Sorted list of (file_path, subject_id) tuples.
    """
    aszed_dir = Path(aszed_dir)
    print(f"\nScanning: {aszed_dir}")
    # A single case-insensitive suffix test: separate "*.edf"/"*.EDF" globs return
    # every file twice on case-insensitive filesystems. Sorting makes the order
    # (and therefore any max_files slice) reproducible.
    files = sorted(
        p for p in aszed_dir.rglob("*")
        if p.is_file() and p.suffix.lower() in {".edf", ".bdf"}
    )
    print(f"  Found {len(files)} EEG files")

    pairs = []
    n_unassigned = 0
    for f in files:
        sid = None
        # Only components below aszed_dir are inspected, so a parent folder that
        # happens to be named "subject_*" cannot be mistaken for the subject.
        for part in f.relative_to(aszed_dir).parts:
            if part.lower().startswith(("subject_", "sub_")):
                sid = normalize_subject_id(part)
                break
        if sid:
            pairs.append((f, sid))
        else:
            n_unassigned += 1

    if n_unassigned:
        print(f"  Skipped {n_unassigned} files with no subject_*/sub_* folder in their path")
    print(f"  Unique subjects: {len(set(s for _, s in pairs))}")
    return pairs


def standardize_to_16ch_matrix(raw, expected_channels, aliases):
    """Arrange a recording's channels into the fixed ``expected_channels`` order.

    Each raw channel name is cleaned with canonicalize_channel_name. It is then
    matched case-insensitively, first against ``aliases`` and then against
    ``expected_channels``. The first raw channel to claim a canonical name
    wins. Expected channels with no match are filled with zeros.

    Returns:
        (matrix, n_found): ``matrix`` has one row per expected channel and
        ``raw.n_times`` columns; ``n_found`` counts the rows taken from real data.
    """
    # First occurrence wins, mirroring a first-match scan of each mapping.
    alias_lookup = {}
    for alias, canon in aliases.items():
        alias_lookup.setdefault(alias.upper(), canon)
    expected_lookup = {}
    for exp in expected_channels:
        expected_lookup.setdefault(exp.lower(), exp)

    canonical_to_raw = {}
    for ch_name in raw.ch_names:
        base = canonicalize_channel_name(ch_name)
        canonical = alias_lookup.get(base.upper())
        if canonical is None:
            canonical = expected_lookup.get(base.lower())
        if canonical and canonical not in canonical_to_raw:
            canonical_to_raw[canonical] = ch_name

    rows = []
    n_found = 0
    for exp_ch in expected_channels:
        source = canonical_to_raw.get(exp_ch)
        if source is None:
            rows.append(np.zeros(raw.n_times, dtype=float))
        else:
            rows.append(raw.get_data(picks=[source])[0])
            n_found += 1

    return np.vstack(rows), n_found


def load_eeg(file_path: Path, target_fs=250):
    """Read one EDF/BDF recording and return it on the fixed 16-channel montage.

    Args:
        file_path: path (str or Path) to the recording. A ".bdf" suffix (any
            case) is read as BDF; anything else is read as EDF.
        target_fs: sampling rate in Hz; the recording is resampled if it differs.

    Returns:
        (data, fs, n_channels). ``data`` has one row per EXPECTED_CHANNELS
        entry, with missing channels zero-filled. ``fs`` equals ``target_fs``.
        ``n_channels`` is how many montage channels were actually found.

    Raises:
        ValueError: if the recording has no channels. Reader errors from MNE propagate.
    """
    import mne
    mne.set_log_level("ERROR")

    file_path = Path(file_path)
    if file_path.suffix.lower() == ".bdf":
        raw = mne.io.read_raw_bdf(str(file_path), preload=True, verbose="ERROR")
    else:
        raw = mne.io.read_raw_edf(str(file_path), preload=True, verbose="ERROR")

    # Prefer EEG-typed, non-bad channels, but keep all channels if that
    # selection fails or would leave nothing.
    try:
        raw_eeg = raw.copy()
        raw_eeg.pick_types(eeg=True, exclude="bads")
        if len(raw_eeg.ch_names) > 0:
            raw = raw_eeg
    except Exception:  # best-effort; no longer swallows KeyboardInterrupt/SystemExit
        pass

    if len(raw.ch_names) == 0:
        raise ValueError("No EEG channels found")

    fs = float(raw.info["sfreq"])
    if fs != target_fs:
        raw.resample(target_fs)

    data, n_channels = standardize_to_16ch_matrix(raw, EXPECTED_CHANNELS, CHANNEL_ALIASES)
    return data, target_fs, n_channels


def preprocess(data, fs, config):
    """Remove the mean from each channel, then band-pass and (optionally) notch filter it.

    Args:
        data: 2-D array-like, one row per channel. All-zero rows (placeholders
            for missing channels) are passed through unchanged.
        fs: sampling rate in Hz.
        config: object providing FILTER_LOW, FILTER_HIGH and NOTCH_FREQ (Hz).

    Returns:
        np.ndarray with one filtered row per input row.

    Filtering is best-effort. If a filter cannot be designed, or cannot be
    applied to a channel (e.g. a signal shorter than the filter's padding),
    that step is skipped.
    """
    nyq = fs / 2.0

    # Design the filters once instead of once per channel. The band-pass uses
    # second-order sections: an order-4 Butterworth band-pass (order 8 overall)
    # with a 0.5 Hz edge puts poles close to z = 1, where the (b, a) polynomial
    # form is numerically fragile.
    try:
        band_sos = signal.butter(
            4,
            [config.FILTER_LOW / nyq, min(config.FILTER_HIGH / nyq, 0.99)],
            btype="band",
            output="sos",
        )
    except (ValueError, ZeroDivisionError):
        band_sos = None

    # The notch is applied only when its frequency lies inside the pass-band.
    # With the default 50 Hz notch and 45 Hz upper edge it is skipped, because
    # 50 Hz already falls above the band-pass edge.
    notch_ba = None
    if config.NOTCH_FREQ < min(config.FILTER_HIGH, nyq):
        try:
            notch_ba = signal.iirnotch(config.NOTCH_FREQ, 30, fs=fs)
        except ValueError:
            notch_ba = None

    out = []
    for ch in data:
        if np.allclose(ch, 0):
            out.append(ch)
            continue
        ch = ch - np.mean(ch)
        if band_sos is not None:
            try:
                ch = signal.sosfiltfilt(band_sos, ch)
            except ValueError:  # e.g. channel shorter than the default padding
                pass
        if notch_ba is not None:
            try:
                ch = signal.filtfilt(notch_ba[0], notch_ba[1], ch)
            except ValueError:
                pass
        out.append(ch)
    return np.array(out)


def process_single_file(fp, sid, label_map, config):
    """Load, validate, preprocess and featurize one recording without raising.

    Intended to run in a joblib worker, so logging and warnings are silenced
    locally. Every outcome is a dict with "status", "subject_id" and "label".
    The statuses are "no_label", "too_short" (< 500 samples),
    "low_channels" (< MIN_CHANNELS_REQUIRED montage channels) and "error"
    (with the message under "error"). On success the status is "ok", and the
    dict also carries "features" (non-finite values replaced by 0.0) and
    "n_channels".
    """
    import mne
    mne.set_log_level("ERROR")
    warnings.filterwarnings("ignore")

    label = label_map.get(sid, -1)

    def _reject(status, **extra):
        return {"status": status, "subject_id": sid, "label": label, **extra}

    try:
        if sid not in label_map:
            return _reject("no_label")

        data, fs, n_ch = load_eeg(fp, config.SAMPLING_RATE)
        if data.shape[1] < 500:
            return _reject("too_short")
        if n_ch < MIN_CHANNELS_REQUIRED:
            return _reject("low_channels")

        cleaned = preprocess(data, fs, config)
        feat = np.nan_to_num(
            extract_all_features(cleaned, fs, config), nan=0.0, posinf=0.0, neginf=0.0
        )
        return {"status": "ok", "features": feat, "label": label, "subject_id": sid, "n_channels": n_ch}
    except Exception as e:
        return _reject("error", error=str(e))


def process_dataset(config, max_files=None):
    """Extract features for every labelled recording under ``config.ASZED_DIR``.

    Args:
        config: Config with ASZED_DIR, CSV_PATH and the signal settings.
        max_files: if truthy, only the first ``max_files`` (file, subject)
            pairs are processed.

    Returns:
        (X, y, subject_ids): feature matrix, integer labels and subject IDs,
        with one row per accepted recording.

    Raises:
        RuntimeError: if no EEG files are found or no recording is accepted.
    """
    print("\n" + "=" * 60)
    print("ASZED-153 PREPROCESSING (v2.3.0)")
    print("=" * 60)

    label_map = load_labels(config.CSV_PATH)
    pairs = find_files(config.ASZED_DIR)

    if not pairs:
        raise RuntimeError("No EEG files found")

    if max_files:
        pairs = pairs[:max_files]

    print(f"\nProcessing {len(pairs)} files...")

    results = Parallel(n_jobs=N_JOBS, backend="loky")(
        delayed(process_single_file)(fp, sid, label_map, config)
        for fp, sid in tqdm(pairs, desc="Processing")
    )

    results = [r for r in results if isinstance(r, dict)]
    valid = [r for r in results if r.get("status") == "ok"]
    rejected = [r for r in results if r.get("status") != "ok"]

    print(f"\n  Accepted: {len(valid)}")
    print(f"  Rejected: {len(rejected)}")
    # Break rejections down by reason, and show the most frequent error messages,
    # so a systematic failure (missing labels, unreadable files) is visible.
    for status, count in Counter(r.get("status") for r in rejected).most_common():
        print(f"    {status}: {count}")
    error_messages = Counter(r.get("error", "") for r in rejected if r.get("status") == "error")
    for message, count in error_messages.most_common(3):
        print(f"    [{count}x] {message}")

    if not valid:
        raise RuntimeError("No samples processed successfully")

    X = np.array([r["features"] for r in valid], dtype=float)
    y = np.array([r["label"] for r in valid], dtype=int)
    subject_ids = np.array([r["subject_id"] for r in valid], dtype=str)

    n_subjects = len(set(subject_ids))
    print(f"\n  Recordings: {len(y)} ({(y == 0).sum()} ctrl, {(y == 1).sum()} patient)")
    print(f"  Subjects: {n_subjects}")
    print(f"  Features: {X.shape[1]}")

    return X, y, subject_ids


def subject_level_cv(X, y, groups, n_splits=5, random_state=42):
    """Estimate generalization with stratified K-fold cross-validation over subjects.

    Folds are built over subjects, so every recording of a subject is in the
    same fold and no subject is in both the training and test data. The
    scaler + Random Forest pipeline is refit on each training split.
    StratifiedKFold needs at least ``n_splits`` subjects per class.

    Args:
        X: (n_recordings, n_features) feature matrix.
        y: per-recording labels (0 = control, 1 = patient).
        groups: per-recording subject IDs, aligned with X and y.
        n_splits: number of folds.
        random_state: seed for the fold shuffling and the forest.

    Returns:
        (subj_acc, rec_acc, agg):
            subj_acc: subject-level accuracy; each subject's mean recording
                probability is thresholded at 0.5.
            rec_acc: recording-level accuracy.
            agg: the per-subject DataFrame, with columns subject, y_true,
                y_prob_mean, y_pred.
    """
    print("\n" + "=" * 60)
    print("SUBJECT-LEVEL CROSS-VALIDATION")
    print("=" * 60)

    X = np.asarray(X)
    y = np.asarray(y)
    groups = np.asarray(groups)

    # One entry per subject, labelled with the label of its first recording.
    unique_subjects = np.array(sorted(set(groups.tolist())))
    first_label = {}
    for subj, label in zip(groups.tolist(), y.tolist()):
        first_label.setdefault(subj, label)
    subject_y = np.array([first_label[s] for s in unique_subjects.tolist()])

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    # Split subjects, then expand each subject fold to its recording indices.
    folds = [
        (np.flatnonzero(np.isin(groups, unique_subjects[train_s])),
         np.flatnonzero(np.isin(groups, unique_subjects[test_s])))
        for train_s, test_s in skf.split(unique_subjects, subject_y)
    ]

    model = Pipeline([
        ("scaler", StandardScaler()),
        ("clf", RandomForestClassifier(n_estimators=300, max_depth=20, min_samples_split=5, random_state=random_state, n_jobs=-1)),
    ])

    all_y_true, all_y_prob, all_subjects = [], [], []

    for fold_idx, (train_idx, test_idx) in enumerate(folds):
        model.fit(X[train_idx], y[train_idx])
        y_prob = model.predict_proba(X[test_idx])[:, 1]
        # Same 0.5 threshold as the pooled recording-level accuracy below.
        y_pred = (y_prob >= 0.5).astype(int)
        y_test = y[test_idx]

        all_y_true.extend(y_test.tolist())
        all_y_prob.extend(y_prob.tolist())
        all_subjects.extend(groups[test_idx].tolist())

        print(f"  Fold {fold_idx + 1}: acc={accuracy_score(y_test, y_pred):.3f}")

    # Subject-level aggregation
    df = pd.DataFrame({"subject": all_subjects, "y_true": all_y_true, "y_prob": all_y_prob})
    agg = df.groupby("subject").agg(y_true=("y_true", "first"), y_prob_mean=("y_prob", "mean")).reset_index()
    agg["y_pred"] = (agg["y_prob_mean"] >= 0.5).astype(int)

    subj_acc = accuracy_score(agg["y_true"], agg["y_pred"])
    rec_acc = accuracy_score(all_y_true, [1 if p >= 0.5 else 0 for p in all_y_prob])

    print(f"\n  Recording-level accuracy: {rec_acc*100:.1f}%")
    print(f"  Subject-level accuracy: {subj_acc*100:.1f}%")

    # Accuracy alone hides class imbalance; also report a threshold-free metric
    # and per-class rates when both classes are present.
    if agg["y_true"].nunique() == 2:
        subj_auc = roc_auc_score(agg["y_true"], agg["y_prob_mean"])
        tn, fp, fn, tp = confusion_matrix(agg["y_true"], agg["y_pred"], labels=[0, 1]).ravel()
        sensitivity = tp / (tp + fn)
        specificity = tn / (tn + fp)
        print(f"  Subject-level AUC: {subj_auc:.3f}")
        print(f"  Subject-level sensitivity: {sensitivity*100:.1f}%, specificity: {specificity*100:.1f}%")

    return subj_acc, rec_acc, agg


def train_final_model(X, y, output_path, random_state=42):
    """Fit the scaler + Random Forest pipeline on all data and save it with joblib.

    The model is written to
    ``<output_path>/schizophrenia_model_v230_<YYYYmmdd_HHMMSS>.pkl``.
    The directory is created if needed.

    Returns:
        pathlib.Path of the saved model.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("TRAINING FINAL MODEL")
    print(banner)

    forest = RandomForestClassifier(
        n_estimators=300,
        max_depth=20,
        min_samples_split=5,
        random_state=random_state,
        n_jobs=-1,
    )
    model = Pipeline([("scaler", StandardScaler()), ("clf", forest)])
    model.fit(X, y)

    out_dir = Path(output_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_path = out_dir / f"schizophrenia_model_v230_{stamp}.pkl"
    dump(model, model_path)

    print(f"\n  Model saved: {model_path}")
    return model_path


def main(aszed_dir, csv_path, output_path):
    """Run the whole pipeline: feature extraction, subject-level CV, final training.

    Returns:
        (model_path, subj_acc): path of the saved model and the
        cross-validated subject-level accuracy.
    """
    header_lines = (
        "=" * 60,
        "ASZED EEG Classification Pipeline v2.3.0",
        "Dataset: African Schizophrenia EEG Dataset",
        "  Paper DOI: 10.1016/j.dib.2025.111934",
        "  Data DOI: 10.5281/zenodo.14178398",
        "=" * 60,
    )
    for line in header_lines:
        print(line)

    config = Config(aszed_dir, csv_path, output_path)

    X, y, subject_ids = process_dataset(config)
    subj_acc, _rec_acc, _agg = subject_level_cv(X, y, subject_ids)
    model_path = train_final_model(X, y, output_path)

    print("\n" + "=" * 60)
    print("COMPLETE")
    print("=" * 60)
    print(f"\nSubject-level accuracy: {subj_acc*100:.1f}%")
    print(f"Model saved to: {model_path}")

    return model_path, subj_acc


if __name__ == "__main__":
    # CLI entry point: expects exactly the three paths; extra arguments are ignored.
    cli_args = sys.argv[1:]
    if len(cli_args) < 3:
        print("Usage: python complete_v230.py <aszed_dir> <csv_path> <output_path>")
    else:
        main(*cli_args[:3])

## 4. Run Training

In [None]:
# Import (or re-import) the pipeline module written by the %%writefile cell and run training.
import importlib
import sys

if '/content' not in sys.path:
    sys.path.insert(0, '/content')

import complete_v230

# Reload so that re-running the %%writefile cell takes effect without restarting
# the runtime; a plain import would keep serving the cached module.
complete_v230 = importlib.reload(complete_v230)
main = complete_v230.main

# Run training
model_path, accuracy = main(
    aszed_dir=ASZED_DATA_PATH,
    csv_path=CSV_PATH,
    output_path=OUTPUT_PATH
)

## 5. Copy Model to Backend Location

In [None]:
# Copy the timestamped model produced by training to the fixed file name the backend loads.
import shutil

BACKEND_MODEL_FILENAME = "schizophrenia_backend_model.pkl"
backend_model_path = f"{OUTPUT_PATH}/{BACKEND_MODEL_FILENAME}"
shutil.copy(src=model_path, dst=backend_model_path)

print(f"\nModel ready for backend: {backend_model_path}")
print("\nDownload this file and place it in: mind-bloom/backend/schizophrenia_backend_model.pkl")

## 6. Verify Model

In [None]:
# Quick smoke test: the saved backend model loads and returns class probabilities.
from joblib import load
import numpy as np

model = load(backend_model_path)

# Use the feature count the pipeline was fitted with (264 for v2.3.0) instead of hard-coding it.
n_features = getattr(model, "n_features_in_", 264)
rng = np.random.default_rng(0)  # seeded so the printed output is reproducible
test_input = rng.standard_normal((1, n_features))
prediction = model.predict_proba(test_input)

# Binary classifier: one row with P(control), P(patient) summing to 1.
assert prediction.shape == (1, 2), f"Unexpected output shape: {prediction.shape}"
assert np.isclose(prediction.sum(), 1.0), "Predicted probabilities do not sum to 1"

print("Model loaded successfully!")
print(f"Expected feature count: {n_features}")
print(f"Test prediction shape: {prediction.shape}")
print(f"Test output: {prediction}")