In [None]:
import numpy as np
import tensorflow as tf
import shap
import matplotlib.pyplot as plt

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    roc_auc_score, confusion_matrix, classification_report
)
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras import layers, models, regularizers, callbacks


# Single seed reused below for NumPy, TensorFlow, dataset shuffling and the
# train/validation split.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)
# Passed to Dataset.prefetch() so tf.data picks the buffer size itself.
AUTOTUNE = tf.data.AUTOTUNE

# Input archives, resolved relative to the notebook's working directory.
TRAIN_NPZ = "merged_eeg_dataset.npz"
TEST_NPZ  = "test_eeg_dataset.npz"

# Training hyperparameters (all consumed by the cells below).
VAL_FRAC_OF_TRAIN = 0.10   # share of the training archive held out for validation
EPOCHS = 50                # upper bound on epochs; early stopping may end training sooner
BATCH = 64                 # batch size for the train/val/test tf.data pipelines
LEARNING_RATE = 3e-4       # initial Adam learning rate
L2W = 5e-4                 # L2 penalty applied to Conv1D and hidden Dense kernels
DROPOUT = 0.45             # dropout rate after the last conv block and in the dense head
LABEL_SMOOTH = 0.05        # label smoothing for the binary cross-entropy loss




In [None]:
def load_npz(path):
    """Load an EEG dataset archive written with ``np.savez``.

    Parameters
    ----------
    path : str or path-like
        Location of the ``.npz`` archive. It must contain the arrays ``X`` and
        ``y``; the arrays ``channels`` and ``session_ids`` are optional.

    Returns
    -------
    tuple
        ``(X, y, channels, session_ids)``. ``y`` is cast to ``int32``;
        ``channels`` and ``session_ids`` are ``None`` when the archive does
        not contain them.

    Raises
    ------
    KeyError
        If ``X`` or ``y`` is missing from the archive.
    """
    # SECURITY: allow_pickle=True lets a crafted archive execute arbitrary code
    # while loading. Only open archives from a trusted source.
    # The context manager closes the underlying zip file handle, which the
    # plain `d = np.load(...)` form left open for the life of the kernel.
    with np.load(path, allow_pickle=True) as d:
        X, y = d["X"], d["y"]
        channels = d["channels"] if "channels" in d.files else None
        session_ids = d["session_ids"] if "session_ids" in d.files else None
    print(channels)
    return X, y.astype(np.int32), channels, session_ids

def ensure_time_feature_shape(X, n_times=1280, n_channels=20):
    """Return ``X`` as float32 with layout ``(N, n_times, n_channels)``.

    Parameters
    ----------
    X : array-like
        3-D batch shaped either ``(N, n_times, n_channels)`` or
        ``(N, n_channels, n_times)``.
    n_times : int, default 1280
        Expected length of the time axis.
    n_channels : int, default 20
        Expected number of channels.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(N, n_times, n_channels)``. Channel-major
        input is transposed; time-major input is only cast.

    Raises
    ------
    ValueError
        If ``X`` is not 3-D or its trailing axes match neither layout.
    """
    X = np.asarray(X)
    if X.ndim != 3:
        raise ValueError(f"Expected 3D, got {X.ndim}D")
    trailing = X.shape[1:]
    # Check the target layout first so that an ambiguous square case
    # (n_times == n_channels) is returned untouched rather than transposed.
    if trailing == (n_times, n_channels):
        return X.astype(np.float32)
    if trailing == (n_channels, n_times):
        return np.transpose(X, (0, 2, 1)).astype(np.float32)
    raise ValueError(
        f"Unrecognized shape {X.shape}; expected (N, {n_times}, {n_channels}) "
        f"or (N, {n_channels}, {n_times})"
    )

def fit_channelwise_zscore(X):
    """Compute z-score statistics reduced over the first two axes of ``X``.

    Both statistics keep their reduced dimensions, so for a 3-D input they
    have shape ``(1, 1, X.shape[2])`` and broadcast against ``X``.

    Returns
    -------
    tuple
        ``(mean, std)`` where ``std`` already includes a ``1e-8`` offset that
        guards against division by zero.
    """
    reduce_axes = (0, 1)
    channel_mean = np.mean(X, axis=reduce_axes, keepdims=True)
    channel_std = np.std(X, axis=reduce_axes, keepdims=True)
    return channel_mean, channel_std + 1e-8

def apply_channelwise_zscore(X, mean, std):
    """Centre ``X`` with ``mean`` and scale it by ``std`` (both broadcast against ``X``)."""
    centered = X - mean
    return centered / std

def per_segment_standardize(X, eps=1e-8):
    """Standardise every sample of ``X`` along axis 1.

    The mean and standard deviation are taken over axis 1 with kept
    dimensions; ``eps`` is added to the standard deviation before dividing so
    constant slices do not divide by zero.
    """
    segment_axis = 1
    segment_mean = np.mean(X, axis=segment_axis, keepdims=True)
    segment_scale = np.std(X, axis=segment_axis, keepdims=True) + eps
    return (X - segment_mean) / segment_scale

def to_tf_dataset(X, y, batch_size=64, shuffle=False):
    """Wrap in-memory arrays in a batched, prefetched ``tf.data.Dataset``.

    ``X`` is cast to float32 and ``y`` to int32. When ``shuffle`` is true the
    elements are shuffled before batching with a buffer of at most 10000
    elements, seeded with ``SEED`` and reshuffled on every iteration.
    """
    features = X.astype(np.float32)
    targets = y.astype(np.int32)
    dataset = tf.data.Dataset.from_tensor_slices((features, targets))

    if shuffle:
        # The buffer never needs to be larger than the dataset itself.
        buffer = min(len(X), 10000)
        dataset = dataset.shuffle(
            buffer_size=buffer,
            seed=SEED,
            reshuffle_each_iteration=True,
        )

    batched = dataset.batch(batch_size)
    return batched.prefetch(AUTOTUNE)


def build_cnn_lstm(input_shape=(1280, 20),
                   l2w=5e-4,
                   dropout=0.45,
                   label_smooth=0.05,
                   lr=3e-4):
    """Build and compile the Conv1D + bidirectional-LSTM binary classifier.

    Parameters
    ----------
    input_shape : tuple
        Shape of one sample, without the batch axis.
    l2w : float
        L2 penalty applied to the Conv1D kernels and the hidden Dense kernel.
    dropout : float
        Dropout rate used after the last conv block and in the dense head.
    label_smooth : float
        Label smoothing passed to the binary cross-entropy loss.
    lr : float
        Learning rate of the Adam optimizer.

    Returns
    -------
    A compiled Keras model with one sigmoid output, tracking ``accuracy`` and
    an AUC metric named ``auc``.
    """

    def conv_bn_relu(tensor, filters, kernel_size, strides=1):
        # One Conv1D -> BatchNormalization -> ReLU stage with 'same' padding.
        tensor = layers.Conv1D(filters, kernel_size=kernel_size, strides=strides,
                               padding='same',
                               kernel_regularizer=regularizers.l2(l2w))(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.ReLU()(tensor)

    inputs = layers.Input(shape=input_shape)

    # Conv block 1: strided convolution plus pooling
    # (with the default 1280-step input the sequence goes 1280 -> 640 -> 320).
    features = conv_bn_relu(inputs, 64, 7, strides=2)
    features = layers.MaxPooling1D(pool_size=2)(features)

    # Conv block 2: wider filters, sequence length halved again by pooling.
    features = conv_bn_relu(features, 128, 5)
    features = layers.MaxPooling1D(pool_size=2)(features)

    # Conv block 3: no pooling, regularised with dropout instead.
    features = conv_bn_relu(features, 256, 3)
    features = layers.Dropout(dropout)(features)

    # Summarise the sequence into one vector (both directions, final state only).
    recurrent = layers.Bidirectional(
        layers.LSTM(64, return_sequences=False)
    )
    features = recurrent(features)

    # Dense head
    hidden = layers.Dense(64, activation='relu',
                          kernel_regularizer=regularizers.l2(l2w))(features)
    hidden = layers.Dropout(dropout)(hidden)
    probability = layers.Dense(1, activation='sigmoid')(hidden)

    model = models.Model(inputs, probability)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=label_smooth),
        metrics=['accuracy', tf.keras.metrics.AUC(name="auc")],
    )
    return model




In [None]:
# Load both archives; load_npz returns (X, y, channels, session_ids).
X_train_all, y_train_all, train_channels, train_sessions = load_npz(TRAIN_NPZ)
X_test,      y_test,      test_channels,  test_sessions  = load_npz(TEST_NPZ)
# NOTE(review): train_channels and test_channels are never compared, so a
# different channel order between the two archives would go unnoticed.

print("Train raw:", X_train_all.shape, y_train_all.shape)
print("Test  raw:", X_test.shape,      y_test.shape)

# Bring both sets into the (N, 1280, 20) float32 layout expected by the model.
X_train_all = ensure_time_feature_shape(X_train_all)
X_test      = ensure_time_feature_shape(X_test)
print("Fixed shapes -> Train:", X_train_all.shape, "| Test:", X_test.shape)


# Train / validation split: one stratified shuffle split that holds out
# VAL_FRAC_OF_TRAIN of the training archive, stratified on the labels only.
# NOTE(review): train_sessions is loaded but not used here. If it identifies
# recording sessions, segments of one session can land in both train and
# validation - confirm whether a group-aware split is required.

sss = StratifiedShuffleSplit(
    n_splits=1,
    test_size=VAL_FRAC_OF_TRAIN,
    random_state=SEED
)
tr_idx, val_idx = next(sss.split(X_train_all, y_train_all))

X_tr,  y_tr  = X_train_all[tr_idx],  y_train_all[tr_idx]
X_val, y_val = X_train_all[val_idx], y_train_all[val_idx]

print("Train:", X_tr.shape, "Val:", X_val.shape, "Test:", X_test.shape)



In [None]:
# Normalization
# Channel statistics are fitted on the training split only and then applied to
# every split, followed by per-segment standardisation along the time axis.
# NOTE(review): standardising each segment per channel afterwards cancels the
# global per-channel shift/scale algebraically, so mu/sigma only influence the
# result through the epsilon term - confirm the two-stage scheme is intended.

mu, sigma = fit_channelwise_zscore(X_tr)

X_tr_n, X_val_n, X_te_n = (
    per_segment_standardize(apply_channelwise_zscore(split, mu, sigma))
    for split in (X_tr, X_val, X_test)
)

# Class Weights
# "balanced" weights computed from the training labels, keyed by plain ints
# as expected by model.fit(class_weight=...).

classes = np.unique(y_tr)
cw = compute_class_weight(class_weight="balanced", classes=classes, y=y_tr)
class_weight = dict(zip(map(int, classes), map(float, cw)))
print("Class weights:", class_weight)

# tf.data Datasets
# Only the training pipeline is shuffled; validation and test keep their order
# so predictions line up with y_val / y_test.

train_ds = to_tf_dataset(X_tr_n, y_tr, batch_size=BATCH, shuffle=True)
val_ds, test_ds = (
    to_tf_dataset(features, targets, batch_size=BATCH, shuffle=False)
    for features, targets in ((X_val_n, y_val), (X_te_n, y_test))
)



# Build the model for the per-sample shape of the normalised training data.
model = build_cnn_lstm(
    input_shape=X_tr_n.shape[1:],
    l2w=L2W,
    dropout=DROPOUT,
    label_smooth=LABEL_SMOOTH,
    lr=LEARNING_RATE
)
model.summary()

# All three callbacks track validation AUC ('val_auc' comes from the metric
# named "auc" in build_cnn_lstm) and treat higher as better.

# Stop after 8 epochs without improvement.
es  = callbacks.EarlyStopping(
    patience=8,
    restore_best_weights=True,
    monitor='val_auc',
    mode='max'
)
# Halve the learning rate after 4 stagnant epochs, never going below 1e-6.
rlr = callbacks.ReduceLROnPlateau(
    patience=4,
    factor=0.5,
    monitor='val_auc',
    mode='max',
    min_lr=1e-6
)
# Persist the best model seen so far to disk.
# NOTE(review): this file is written but never loaded again in this notebook;
# later cells evaluate the in-memory `model`. Whether that holds the best
# weights when training runs all EPOCHS without early stopping depends on the
# Keras version's restore_best_weights behaviour - confirm, or reload the file.
ckp = callbacks.ModelCheckpoint(
    "best_cnn_lstm.keras",
    monitor='val_auc',
    mode='max',
    save_best_only=True
)

# class_weight reweights the loss per class using the balanced weights above.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    class_weight=class_weight,
    callbacks=[es, rlr, ckp],
    verbose=1
)

# Threshold tuning on the validation dataset

val_probs = model.predict(val_ds, verbose=0).ravel()
# Labels are read back from the (unshuffled) dataset so they are guaranteed to
# be in the same order as the predictions.
y_val_all = np.concatenate([labels.numpy() for _, labels in val_ds], axis=0)

# Score 19 evenly spaced thresholds (0.05, 0.10, ..., 0.95) by validation F1.
candidate_thresholds = np.linspace(0.05, 0.95, 19)
f1_per_threshold = [
    f1_score(y_val_all, (val_probs >= candidate).astype(int))
    for candidate in candidate_thresholds
]
# argmax returns the first maximum, i.e. the lowest threshold among ties.
best_pos = int(np.argmax(f1_per_threshold))
best_thr, best_f1 = candidate_thresholds[best_pos], f1_per_threshold[best_pos]

print(f"\n[VAL] Best Threshold = {best_thr:.2f} | F1 = {best_f1:.4f}")


# Apply the validation-tuned threshold to the held-out test probabilities.

test_probs = model.predict(test_ds, verbose=0).ravel()
test_preds = (test_probs >= best_thr).astype(int)




In [None]:

# Hard-label metrics use the thresholded predictions; ROC-AUC is computed from
# the raw probabilities and is therefore independent of the chosen threshold.
acc, f1, prec, rec = (
    scorer(y_test, test_preds)
    for scorer in (accuracy_score, f1_score, precision_score, recall_score)
)
auc = roc_auc_score(y_test, test_probs)
cm = confusion_matrix(y_test, test_preds)

print("\n=== CNN + LSTM TEST METRICS ===")
for metric_label, metric_value in (
    ("Accuracy ", acc),
    ("Precision", prec),
    ("Recall   ", rec),
    ("F1-score ", f1),
    ("ROC-AUC  ", auc),
):
    print(f"{metric_label}: {metric_value:.4f}")
print("\nConfusion Matrix:\n", cm)
print("\nClassification report:\n", classification_report(y_test, test_preds, digits=4))


# np.savez("norm_stats_cnn_lstm.npz", mu=mu, sigma=sigma)
# print("\nSaved: best_cnn_lstm.keras and norm_stats_cnn_lstm.npz")




In [None]:
import shap
import numpy as np
import matplotlib.pyplot as plt

print("SHAP for global explaination")

# Pool of candidate segments to explain: training rows followed by validation rows.
X_all = np.concatenate([X_tr_n, X_val_n], axis=0)

# Only the first `exp_size` rows of the pool are explained below, so this number
# is the size of the pool, not the number of samples actually used.
print("Total samples available for SHAP:", X_all.shape[0])   # (N_total, 1280, 20)


# Background set for the explainer: up to 100 random training segments
# (drawn from the global NumPy RNG seeded in the config cell).
bg_size = min(100, X_tr_n.shape[0])
idx_bg = np.random.choice(X_tr_n.shape[0], size=bg_size, replace=False)
background = X_tr_n[idx_bg]
print("Background shape:", background.shape)


explainer = shap.GradientExplainer(model, background)


# Explain at most 500 segments, taken from the start of the pool.
exp_size = min(500, X_all.shape[0])
X_exp = X_all[:exp_size]
print("Explain samples shape:", X_exp.shape)

print(f"Computing SHAP values for {exp_size} segments...")
shap_values = explainer.shap_values(X_exp)


# Depending on the SHAP version the result is a list with one array per model
# output, or a single array with a trailing output axis of size 1.
if isinstance(shap_values, list):
    shap_values = shap_values[0]

if shap_values.ndim == 4 and shap_values.shape[-1] == 1:
    shap_values = shap_values[..., 0]   # (N, 1280, 20)


# Global importance of a channel = mean absolute SHAP value over samples and time.
channel_importance = np.mean(np.abs(shap_values), axis=(0, 1))   # shape: (20,)
channel_importance = np.asarray(channel_importance).reshape(-1)  # ensure 1D


# Resolve channel names once and share them between the printout and the plot.
# `train_channels` is defined by the cell that also creates X_tr, so it always
# exists here; it may still be None when the archive carries no channel names.
# Positions without a known name fall back to a generic "Ch<i>" label.
known_names = (
    []
    if train_channels is None
    else [str(name) for name in np.atleast_1d(train_channels)]
)
labels = [
    known_names[ch_idx] if ch_idx < len(known_names) else f"Ch{ch_idx}"
    for ch_idx in range(len(channel_importance))
]

print("\nChannel-wise global importance (mean |SHAP| per channel):")
for ch_name, imp in zip(labels, channel_importance):
    imp_val = float(imp)   # numpy scalar -> normal float
    print(f"{ch_name:>6}: {imp_val:.6f}")


plt.figure(figsize=(8, 4))
x_pos = np.arange(len(channel_importance))

plt.bar(x_pos, channel_importance)
plt.xticks(x_pos, labels, rotation=45, ha="right")
plt.ylabel("Mean |SHAP value|")
plt.title("Global EEG Channel Importance (CNN + LSTM + SHAP)")
plt.tight_layout()
plt.show()


In [None]:
# local explainer 

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

idx = 0   # index of the test segment to explain


# Slice (rather than index) so the batch axis is kept for the model call.
x_sample = X_te_n[idx:idx+1]

print(f"\n=== Single-sample explanation for test index {idx} ===")


x_tensor = tf.convert_to_tensor(x_sample)

# Gradient of the predicted probability with respect to the input segment.
with tf.GradientTape() as tape:
    # A plain tensor is not tracked automatically, so watch it explicitly.
    tape.watch(x_tensor)
    # Inference mode made explicit: dropout is off and batch-norm uses its
    # moving statistics.
    y_pred = model(x_tensor, training=False)

grads = tape.gradient(y_pred, x_tensor).numpy()


# Per-channel importance = mean absolute gradient over the time axis.
chan_importance = np.mean(np.abs(grads[0]), axis=0)


# Resolve channel names once and share them between the printout and the plot.
# `train_channels` is defined by the cell that loads the data, so it always
# exists here; it may still be None when the archive carries no channel names.
# Positions without a known name fall back to a generic "Ch<i>" label.
known_names = (
    []
    if train_channels is None
    else [str(name) for name in np.atleast_1d(train_channels)]
)
labels = [
    known_names[ch_idx] if ch_idx < len(known_names) else f"Ch{ch_idx}"
    for ch_idx in range(len(chan_importance))
]

print("Per-channel importance for a single segment")
for ch_name, imp in zip(labels, chan_importance):
    print(f"{ch_name:>6}: {imp:.6f}")


plt.figure(figsize=(8, 4))
x_pos = np.arange(len(chan_importance))

plt.bar(x_pos, chan_importance)
plt.xticks(x_pos, labels, rotation=45, ha="right")
plt.ylabel("Mean |Gradient| over time")
plt.title(f"Channel importance for test sample #{idx}")
plt.tight_layout()
plt.show()
