# ECG-based Age Group Prediction with RNN (LSTM)
This notebook demonstrates how to train an LSTM-based recurrent neural network (RNN) to classify ECG signals into age groups.

In [None]:
# ============================================================
# 1. Imports
# ============================================================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping

In [None]:
# ============================================================
# 2. Load metadata and one example record
# ============================================================
import os
import wfdb

# Root folder of the dataset; every path below is built relative to it.
DATA_PATH = "autonomic-aging-cardiovascular-1.0.0"

# Subject metadata table; its "ID" and "Age" columns are read below.
meta_csv_path = os.path.join(DATA_PATH, "SubjectInformation.csv")
meta = pd.read_csv(meta_csv_path)
print(meta.head())

# Sanity check on a single record: read subject_001 and keep channel 0,
# which this notebook treats as the ECG channel.
record_name = os.path.join(DATA_PATH, "ecg", "subject_001")
record = wfdb.rdrecord(record_name)
signal = record.p_signal[:, 0]
print("ECG shape:", signal.shape)

# Look up the label (Age) of that same subject in the metadata table.
age = meta.loc[meta["ID"].eq("subject_001"), "Age"].values[0]
print("Age:", age)


In [None]:
# ============================================================
# 3. Preprocessing
# ============================================================
# Build one fixed-length ECG segment per subject (features) together with
# the subject's age label, encode the labels, split into
# train/validation/test sets and standardise the features.
timesteps = 500  # number of samples kept from the start of every recording

X = []
y = []

for _, row in meta.iterrows():
    subj_id = row["ID"]
    age = row["Age"]

    # A subject without a label cannot be used for supervised training;
    # keeping it would turn "missing" into a class of its own.
    if pd.isna(age):
        continue

    # read ECG signal
    record_name = os.path.join(DATA_PATH, "ecg", subj_id)
    record = wfdb.rdrecord(record_name)
    sig = record.p_signal[:, 0]  # first channel

    # Recordings shorter than the segment length are skipped.
    if len(sig) >= timesteps:
        X.append(sig[:timesteps])
        y.append(age)

X = np.array(X)
y = np.array(y)

# Fail early with a clear message instead of an obscure split error.
assert X.ndim == 2 and X.shape[1] == timesteps, "no usable ECG segments were loaded"

# Encode labels: integer class ids, then one-hot rows for the softmax output.
encoder = LabelEncoder()
y_encoded = encoder.fit_transform(y)
y_categorical = to_categorical(y_encoded)

# Train/validation/test split (70% / 15% / 15%), done on the *raw* segments
# so that the scaler below never sees validation or test data.
X_train_raw, X_temp_raw, y_train, y_temp = train_test_split(
    X, y_categorical, test_size=0.3, random_state=42, stratify=y_categorical)

X_val_raw, X_test_raw, y_val, y_test = train_test_split(
    X_temp_raw, y_temp, test_size=0.5, random_state=42, stratify=y_temp)

# Normalize features: the statistics come from the training split only.
# Fitting on the full data set would leak test information into training.
scaler = StandardScaler()
scaler.fit(X_train_raw)


def _to_sequences(segments):
    """Scale segments with the fitted scaler and shape them for the RNN.

    Returns an array of shape (samples, timesteps, 1).
    """
    return scaler.transform(segments).reshape(-1, timesteps, 1)


# Reshape for RNN: (samples, timesteps, features)
X_train = _to_sequences(X_train_raw)
X_val = _to_sequences(X_val_raw)
X_test = _to_sequences(X_test_raw)

# Full scaled data set, kept under its original names for any later cell
# that refers to them.
X_scaled = scaler.transform(X)
X_reshaped = X_scaled.reshape(-1, timesteps, 1)

In [None]:
# ============================================================
# 4. Build LSTM model
# ============================================================
# Seed the Python, NumPy and TensorFlow random generators before any layer
# is created, so weight initialisation and dropout are repeatable across
# runs. NOTE(review): bit-exact results on GPU may additionally require
# deterministic ops to be enabled - confirm if that matters here.
tf.keras.utils.set_random_seed(42)

# Two stacked LSTM layers followed by a small dense classifier head.
model = Sequential([
    # Explicit input spec: sequences of `timesteps` samples with 1 feature.
    tf.keras.Input(shape=(timesteps, 1)),
    # return_sequences=True so the second LSTM receives the full sequence.
    LSTM(128, return_sequences=True),
    Dropout(0.3),
    LSTM(64),
    Dropout(0.3),
    Dense(32, activation="relu"),
    # One output unit per one-hot encoded class.
    Dense(y_categorical.shape[1], activation="softmax")
])

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.summary()

In [None]:
# ============================================================
# 5. Train model
# ============================================================
# Stop once the validation loss has gone 5 epochs without improving, and
# restore the weights of the best epoch.
early_stopping = EarlyStopping(
    monitor="val_loss",
    patience=5,
    restore_best_weights=True,
)

# Training settings gathered in one place so they are easy to tune.
fit_options = dict(
    validation_data=(X_val, y_val),
    epochs=30,
    batch_size=32,
    callbacks=[early_stopping],
    verbose=1,
)

history = model.fit(X_train, y_train, **fit_options)

In [None]:
# ============================================================
# 6. Plot training history
# ============================================================
# One panel per metric; each panel overlays the training and validation
# curves stored in `history.history`.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# (history key, panel title / y-label, training legend, validation legend)
panels = [
    ("loss", "Loss", "Train Loss", "Val Loss"),
    ("accuracy", "Accuracy", "Train Acc", "Val Acc"),
]

for ax, (metric, title, train_label, val_label) in zip(axes, panels):
    ax.plot(history.history[metric], label=train_label)
    ax.plot(history.history[f"val_{metric}"], label=val_label)
    # Axis labels let each panel be read on its own.
    ax.set_xlabel("Epoch")
    ax.set_ylabel(title)
    ax.set_title(title)
    ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# ============================================================
# 7. Evaluation on test set
# ============================================================
# The model was compiled with one loss and one metric ("accuracy"), so
# `evaluate` yields exactly two values: the loss, then the accuracy.
test_metrics = model.evaluate(X_test, y_test, verbose=0)
test_loss, test_acc = test_metrics
print(f"Test Accuracy: {test_acc:.4f}")

In [None]:
# ============================================================
# 8. Predictions & Confusion Matrix
# ============================================================
# Turn the softmax outputs and the one-hot targets into integer class ids.
y_pred_probs = model.predict(X_test, batch_size=32)
y_pred = np.argmax(y_pred_probs, axis=1)
y_true = np.argmax(y_test, axis=1)

# Pin the label set so the matrix always has one row/column per encoded
# class. Without it, a class absent from both the test split and the
# predictions is dropped and the tick labels below no longer line up.
class_ids = np.arange(len(encoder.classes_))
cm = confusion_matrix(y_true, y_pred, labels=class_ids)

plt.figure(figsize=(10,8))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=encoder.classes_,
            yticklabels=encoder.classes_)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix (Age Groups)")
plt.show()

In [None]:
# ============================================================
# 9. Classification Report
# ============================================================
# Per-class precision / recall / F1 on the test set.
# - `labels` pins the report to every encoded class, so the number of rows
#   always matches the number of names even if a class is missing from the
#   test split.
# - `target_names` must be strings; the encoder's classes are converted
#   because they are not strings when the Age column is numeric.
# - `zero_division=0` reports 0 (without a warning) for classes that were
#   never predicted.
class_ids = np.arange(len(encoder.classes_))
class_names = [str(c) for c in encoder.classes_]

report = classification_report(
    y_true,
    y_pred,
    labels=class_ids,
    target_names=class_names,
    zero_division=0,
)
print(report)