# Seizure Prediction LSTM Model

Simple LSTM model for predicting seizures 3 minutes in advance using HRV features.

Labels:
- 0: Normal periods
- 1: Pre-seizure (3 minutes before seizure onset)
- 2: During seizure

In [6]:
import numpy as np
import h5py
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from sklearn.metrics import classification_report, confusion_matrix

In [20]:
# Load training data
with h5py.File('/Volumes/Seizury/HRV/sequences/train_sequences.h5', 'r') as f:
    X_train = f['X'][:]
    y_train = f['y'][:]

# Load validation data
with h5py.File('/Volumes/Seizury/HRV/sequences/val_sequences.h5', 'r') as f:
    X_val = f['X'][:]
    y_val = f['y'][:]

# Load test data
with h5py.File('/Volumes/Seizury/HRV/sequences/test_sequences.h5', 'r') as f:
    X_test = f['X'][:]
    y_test = f['y'][:]

print(f"Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")
print(f"Label distribution - Train: {np.bincount(y_train)}")
print(f"Label distribution - Val: {np.bincount(y_val)}")
print(f"Label distribution - Test: {np.bincount(y_test)}")

Train: (626855, 36, 22), Val: (84207, 36, 22), Test: (220623, 36, 22)
Label distribution - Train: [626453    132    270]
Label distribution - Val: [83667   114   426]
Label distribution - Test: [220483     36    104]


In [26]:
# Convert to binary classification for seizure PREDICTION (3 minutes in advance)
# 0=Normal/During seizure, 1=Pre-seizure only (label 1)
# We want to predict BEFORE seizures happen, not during them
y_train_binary = (y_train == 1).astype(int)
y_val_binary = (y_val == 1).astype(int) 
y_test_binary = (y_test == 1).astype(int)

print(f"Prediction labels - Train: {np.bincount(y_train_binary)}")
print(f"Prediction labels - Val: {np.bincount(y_val_binary)}")
print(f"Prediction labels - Test: {np.bincount(y_test_binary)}")
print(f"Target: Predict label 1 (pre-seizure) vs everything else")

Prediction labels - Train: [626723    132]
Prediction labels - Val: [84093   114]
Prediction labels - Test: [220587     36]
Target: Predict label 1 (pre-seizure) vs everything else


In [27]:
# Verify our labeling strategy is correct for seizure prediction
print("\nOriginal 3-class distribution:")
print("Label 0 (Normal):", np.sum(y_train == 0), "sequences")
print("Label 1 (Pre-seizure - 3min before):", np.sum(y_train == 1), "sequences") 
print("Label 2 (During seizure):", np.sum(y_train == 2), "sequences")

print("\nOur prediction task:")
print("Predict: Label 1 (pre-seizure) = 1")
print("Everything else (normal + during seizure) = 0")
print("This gives 3-minute advance warning before seizures")


Original 3-class distribution:
Label 0 (Normal): 626453 sequences
Label 1 (Pre-seizure - 3min before): 132 sequences
Label 2 (During seizure): 270 sequences

Our prediction task:
Predict: Label 1 (pre-seizure) = 1
Everything else (normal + during seizure) = 0


In [28]:
# Calculate class weights for imbalanced data
from sklearn.utils.class_weight import compute_class_weight

# For binary classification
classes = np.unique(y_train_binary)
class_weights_binary = compute_class_weight('balanced', classes=classes, y=y_train_binary)
class_weight_dict = {i: weight for i, weight in enumerate(class_weights_binary)}

print(f"Class weights for binary classification:")
print(f"Normal (0): {class_weight_dict[0]:.4f}")
print(f"Alert (1): {class_weight_dict[1]:.4f}")
print(f"Weight ratio: {class_weight_dict[1]/class_weight_dict[0]:.1f}:1")

Class weights for binary classification:
Normal (0): 0.5001
Alert (1): 2374.4508
Weight ratio: 4747.9:1


In [29]:
from tensorflow.keras import metrics
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Build improved LSTM model for seizure prediction
model = Sequential([
    LSTM(128, return_sequences=True, input_shape=(X_train.shape[1], X_train.shape[2])),
    Dropout(0.4),
    LSTM(64, return_sequences=True),
    Dropout(0.4),
    LSTM(32),
    Dropout(0.3),
    Dense(32, activation='relu'),
    Dropout(0.2),
    Dense(16, activation='relu'),
    Dense(1, activation='sigmoid')
])

# Compile without F1Score to avoid shape issues
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='binary_crossentropy',
    metrics=['accuracy', 
             metrics.Precision(name="precision"),
             metrics.Recall(name="recall")]
)

model.summary()

  super().__init__(**kwargs)


In [31]:
# Train with callbacks for better convergence
callbacks = [
    EarlyStopping(monitor='val_recall', patience=5, restore_best_weights=True, mode='max'),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6)
]

history = model.fit(
    X_train, y_train_binary,
    epochs=2,
    batch_size=64,
    validation_data=(X_val, y_val_binary),
    class_weight=class_weight_dict,
    callbacks=callbacks,
    verbose=1
)

Epoch 1/2
[1m1054/9795[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m18:45[0m 129ms/step - accuracy: 0.9982 - loss: 1.4980 - precision: 0.0000e+00 - recall: 0.0000e+00

KeyboardInterrupt: 

In [None]:
# Evaluate on test set with optimized threshold
y_pred_prob = model.predict(X_test, verbose=0).flatten()

# Find optimal threshold based on F1 score
from sklearn.metrics import f1_score
thresholds = np.arange(0.1, 0.9, 0.05)
f1_scores = [f1_score(y_test_binary, (y_pred_prob > t).astype(int)) for t in thresholds]
optimal_threshold = thresholds[np.argmax(f1_scores)]

print(f"Optimal threshold: {optimal_threshold:.3f}")

# Make predictions with optimal threshold
y_pred = (y_pred_prob > optimal_threshold).astype(int)

# Evaluate
test_results = model.evaluate(X_test, y_test_binary, verbose=0)
print(f"\nTest Results:")
for i, metric in enumerate(model.metrics_names):
    print(f"{metric}: {test_results[i]:.4f}")

print(f"\nWith optimal threshold ({optimal_threshold:.3f}):")
print(classification_report(y_test_binary, y_pred, target_names=['Normal', 'Pre-seizure']))

print("\nConfusion Matrix:")
cm = confusion_matrix(y_test_binary, y_pred)
print(cm)
print(f"True Negatives: {cm[0,0]}, False Positives: {cm[0,1]}")
print(f"False Negatives: {cm[1,0]}, True Positives: {cm[1,1]}")

# Save the model
model.save('seizure_prediction_lstm.h5')
print(f"\nModel saved as seizure_prediction_lstm.h5")
print(f"Use threshold {optimal_threshold:.3f} for predictions")