# This Expermental take CNN-adrian with LSTM model
>CNN Learn feature of ECG morphology
>LSTM deccide the AFIB, NORM, AFLAT, OTHER 

In [33]:
#Dataset loader & preparation
import os
import pandas as pd
import ast
import random
from collections import Counter
import wfdb
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix
from pathlib import Path
from IPython.display import Markdown, display


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


print(torch.__file__)
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

Using device: cuda
c:\Users\arjan\Documents\GitHub\SEARCH_AF_detection_OsloMet_BachelorGroup\venv\Lib\site-packages\torch\__init__.py
2.5.1+cu121
12.1
True


In [None]:
dff = pd.read_csv("../../../data/ptbxl_database.csv")


ecg_ids = []
labels = []
patient_ids = []


norm_ids  = []
afib_ids  = []
aflt_ids  = []
other_ids = []

def get_label(scp_codes):
    if "AFIB" in scp_codes:
        return 1
    if "NORM" in scp_codes:
        return 0
    if "AFLT" in scp_codes:
        return 2
    
    if "NDT" in scp_codes:
        return 4
    if "NST_" in scp_codes:
        return 5
    if "SVARR" in scp_codes:
        return 6
    if "SVTAC" in scp_codes:
        return 7
    if "PAC" in scp_codes:
        return 8
    return None


LABEL_MAP = {
    "NORM": 0,
    "AFIB": 1,
    "OTHER": 0
}




In [35]:
#FASTER for loading ECG files
def build_ecg_index(root="../../../data/records500", suffix="_hr"):
    index = {}
    for p in Path(root).rglob(f"*{suffix}.hea"):
        ecg_id = int(p.stem.replace(suffix, ""))
        index[ecg_id] = str(p.with_suffix(""))
    return index

ECG_INDEX = build_ecg_index()

def load_ecg_fast(ecg_id):
    record = wfdb.rdrecord(ECG_INDEX[ecg_id])
    return record.p_signal


In [36]:
def extract_main_label(scp_codes_str: str) -> str:
    codes = ast.literal_eval(scp_codes_str)  # dict
    if "AFIB" in codes:
        return "AFIB"
    elif "NORM" in codes:
        return "NORM"
    else:
        return "OTHER"

dff["ecg_label"] = dff["scp_codes"].apply(extract_main_label)

In [37]:
def patient_label(ecg_labels) -> str:
    s = set(ecg_labels)
    if "AFIB" in s:
        return "AFIB"
    elif "NORM" in s:
        return "NORM"
    else:
        return "OTHER"

patient_df = (
    dff.groupby("patient_id")["ecg_label"]
       .apply(patient_label)
       .reset_index()
       .rename(columns={"ecg_label": "patient_label"})
)


In [38]:
norm_patients  = patient_df.loc[patient_df.patient_label == "NORM",  "patient_id"].values
afib_patients  = patient_df.loc[patient_df.patient_label == "AFIB",  "patient_id"].values
other_patients = patient_df.loc[patient_df.patient_label == "OTHER", "patient_id"].values

print("Patient counts:")
print("NORM :", len(norm_patients))
print("AFIB :", len(afib_patients))
print("OTHER:", len(other_patients))


Patient counts:
NORM : 8831
AFIB : 1245
OTHER: 8809


In [39]:
np.random.seed(42)

# ---- Train patients ----
train_norm  = np.random.choice(norm_patients, 500, replace=False)
train_afib  = np.random.choice(afib_patients, 500, replace=False)
train_other = np.random.choice(other_patients, 500, replace=False)

# remaining pools
rem_norm  = np.setdiff1d(norm_patients, train_norm)
rem_afib  = np.setdiff1d(afib_patients, train_afib)
rem_other = np.setdiff1d(other_patients, train_other)

# ---- Test patients ----
test_norm  = np.random.choice(rem_norm, 100, replace=False)
test_afib  = np.random.choice(rem_afib, 100, replace=False)
test_other = np.random.choice(rem_other, 100, replace=False)

train_patients = np.concatenate([train_norm, train_afib, train_other])
test_patients  = np.concatenate([test_norm,  test_afib,  test_other])


In [40]:

AF = 1 
NONAF = 0
train_labels = np.concatenate([
    np.full(len(train_norm),  NONAF,  dtype=int),
    np.full(len(train_afib),  AF,  dtype=int),
    np.full(len(train_other), NONAF, dtype=int),
])

test_labels = np.concatenate([
    np.full(len(test_norm),  NONAF,  dtype=int),
    np.full(len(test_afib),  AF,  dtype=int),
    np.full(len(test_other), NONAF, dtype=int),
])


In [41]:
print("Train patients:", len(train_patients))
print("Test patients :", len(test_patients))
print("Overlap (must be 0):", len(np.intersect1d(train_patients, test_patients)))

print("Train label counts:", np.bincount(train_labels))
print("Test label counts :", np.bincount(test_labels))


Train patients: 1500
Test patients : 300
Overlap (must be 0): 0
Train label counts: [1000  500]
Test label counts : [200 100]


In [42]:


dff["label_id"] = dff["ecg_label"].map(LABEL_MAP)


In [43]:
train_df = dff[dff["patient_id"].isin(train_patients)].copy()
test_df  = dff[dff["patient_id"].isin(test_patients)].copy()


In [44]:
assert (
    set(train_df.patient_id).isdisjoint(set(test_df.patient_id))
), "❌ Patient leakage detected!"


In [45]:
print("TRAIN ECG distribution:")
print(train_df["ecg_label"].value_counts())

print("\nTEST ECG distribution:")
print(test_df["ecg_label"].value_counts())


TRAIN ECG distribution:
ecg_label
OTHER    667
AFIB     593
NORM     556
Name: count, dtype: int64

TEST ECG distribution:
ecg_label
OTHER    139
AFIB     123
NORM     109
Name: count, dtype: int64


In [46]:
def normalize_ecg(ecg):
    mean = ecg.mean(axis=0)
    std = ecg.std(axis=0) + 1e-8
    return (ecg - mean) / std

In [47]:


class ECGDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        ecg = load_ecg_fast(row.ecg_id)
        ecg = normalize_ecg(ecg)

        label = row.label_id

        return (
            torch.tensor(ecg, dtype=torch.float32),
            torch.tensor(label, dtype=torch.long)
        )


In [48]:
BATCH_SIZE = 1

train_dataset = ECGDataset(train_df)
test_dataset  = ECGDataset(test_df)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

In [49]:
X, y = next(iter(train_loader))

print(X.shape)
print(y.shape)


torch.Size([1, 5000, 12])
torch.Size([1])


In [50]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()

    total_loss = 0
    correct = 0
    total = 0

    for X, y in loader:
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()

        outputs = model(X)
        loss = criterion(outputs, y)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)

        correct += (preds == y).sum().item()
        total += y.size(0)

    return total_loss / len(loader), correct / total


In [51]:
from sklearn.metrics import f1_score, recall_score

@torch.no_grad()
def validate_with_metrics(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    y_true, y_pred = [], []

    for X, y in loader:
        X, y = X.to(device), y.to(device)
        out = model(X)
        loss = criterion(out, y)
        total_loss += loss.item()

        pred = out.argmax(dim=1)
        y_true.extend(y.cpu().numpy().tolist())
        y_pred.extend(pred.cpu().numpy().tolist())

    val_loss = total_loss / len(loader)
    val_acc  = (np.array(y_true) == np.array(y_pred)).mean()
    val_f1   = f1_score(y_true, y_pred, zero_division=0)
    val_rec_af = recall_score(y_true, y_pred, pos_label=1, zero_division=0)  # AF sensitivity

    return val_loss, val_acc, val_f1, val_rec_af


In [52]:

class CNNFeatureExtractor(nn.Module):
    """
    Learns ECG morphology:
    QRS complex shape, amplitude, local waveform patterns
    """

    def __init__(self):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv1d(12, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),

            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),

            nn.MaxPool1d(2)
        )

    def forward(self, x):
        # x: (batch, time, leads)
        x = x.permute(0, 2, 1)   # → (batch, leads, time)
        x = self.cnn(x)
        x = x.permute(0, 2, 1)   # → (batch, time, features)
        return x


In [53]:
class ECG_CNN_LSTM(nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()

        self.cnn = CNNFeatureExtractor()   # from main_cnn
        self.lstm = nn.LSTM(
            input_size=64,
            hidden_size=128,
            num_layers=2,
            batch_first=True,
            dropout=0.3
        )

        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        # morphology learning
        x = self.cnn(x)

        # rhythm learning
        out, _ = self.lstm(x)

        # temporal pooling (VERY IMPORTANT)
        out = out.mean(dim=1)

        return self.fc(out)


In [54]:
skf = StratifiedKFold(
    n_splits=5,
    shuffle=True,
    random_state=42
)

In [55]:


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

EPOCHS = 4
BATCH_SIZE = 16

fold_results = []


In [56]:
for fold, (train_idx, val_idx) in enumerate(
    skf.split(train_patients, train_labels)
):

    print(f"\n================ FOLD {fold+1} =================")

    fold_train_patients = train_patients[train_idx]
    fold_val_patients   = train_patients[val_idx]

    fold_train_df = dff[dff.patient_id.isin(fold_train_patients)]
    fold_val_df   = dff[dff.patient_id.isin(fold_val_patients)]

    train_ds = ECGDataset(fold_train_df)
    val_ds   = ECGDataset(fold_val_df)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
    val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # ✅ binary model (labels are 0/1)
    model = ECG_CNN_LSTM(num_classes=2).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # ✅ HERE: compute class weights from THIS fold's training labels
    counts = np.bincount(train_labels[train_idx], minlength=2)   # [nonAF, AF]
    w0 = 1.0
    w1 = counts[0] / max(counts[1], 1)  # e.g. 2.0 if 2:1
    class_weights = torch.tensor([w0, w1], dtype=torch.float32).to(device)

    criterion = nn.CrossEntropyLoss(weight=class_weights)

    print(f"Fold {fold+1} class counts [nonAF, AF]: {counts} | weights: {[w0, float(w1)]:}")

    best_val_acc = -1.0
    best_epoch = -1

    for epoch in range(EPOCHS):

        train_loss, train_acc = train_one_epoch(
            model, train_loader, optimizer, criterion, device
        )

        val_loss, val_acc, val_f1, val_rec_af = validate_with_metrics(
            model, val_loader, criterion, device
        )

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch + 1
            torch.save(model.state_dict(), f"best_model_fold_{fold}.pt")

        print(
            f"Fold {fold+1} | Epoch {epoch+1:02d} | "
            f"Train Loss {train_loss:.4f} | Train Acc {train_acc:.3f} | "
            f"Val Loss {val_loss:.4f} | Val Acc {val_acc:.3f} | "
            f"Val F1 {val_f1:.3f} | Val Rec AF {val_rec_af:.3f} | "
            f"Best Val {best_val_acc:.3f} (ep {best_epoch})"
        )

    fold_results.append(best_val_acc)  # usually better than last epoch



Fold 1 class counts [nonAF, AF]: [800 400] | weights: [1.0, 2.0]
Fold 1 | Epoch 01 | Train Loss 0.6518 | Train Acc 0.617 | Val Loss 0.5814 | Val Acc 0.704 | Val F1 0.576 | Val Rec AF 0.626 | Best Val 0.704 (ep 1)
Fold 1 | Epoch 02 | Train Loss 0.5932 | Train Acc 0.699 | Val Loss 0.5722 | Val Acc 0.684 | Val F1 0.614 | Val Rec AF 0.783 | Best Val 0.704 (ep 1)
Fold 1 | Epoch 03 | Train Loss 0.5741 | Train Acc 0.702 | Val Loss 0.5214 | Val Acc 0.726 | Val F1 0.642 | Val Rec AF 0.765 | Best Val 0.726 (ep 3)
Fold 1 | Epoch 04 | Train Loss 0.5306 | Train Acc 0.735 | Val Loss 0.5060 | Val Acc 0.785 | Val F1 0.624 | Val Rec AF 0.557 | Best Val 0.785 (ep 4)

Fold 2 class counts [nonAF, AF]: [800 400] | weights: [1.0, 2.0]
Fold 2 | Epoch 01 | Train Loss 0.6330 | Train Acc 0.670 | Val Loss 0.6058 | Val Acc 0.680 | Val F1 0.547 | Val Rec AF 0.603 | Best Val 0.680 (ep 1)
Fold 2 | Epoch 02 | Train Loss 0.5851 | Train Acc 0.706 | Val Loss 0.5820 | Val Acc 0.701 | Val F1 0.609 | Val Rec AF 0.727 | Be

In [57]:
print("\n===== Cross-Validation Result =====")
print("Fold accuracies:", fold_results)
print("Mean accuracy:", np.mean(fold_results))
print("Std:", np.std(fold_results))



===== Cross-Validation Result =====
Fold accuracies: [np.float64(0.7849162011173184), np.float64(0.8201058201058201), np.float64(0.7563739376770539), np.float64(0.7282913165266106), np.float64(0.7567567567567568)]
Mean accuracy: 0.7692888064367119
Std: 0.031084480455489973


In [63]:
def confusion_matrix_to_markdown_binary(cm):
    cm = np.asarray(cm, dtype=int)
    TN, FP = cm[0]
    FN, TP = cm[1]

    total = cm.sum()
    accuracy = (TP + TN) / total if total > 0 else 0.0
    sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0.0
    specificity = TN / (TN + FP) if (TN + FP) > 0 else 0.0
    fp_rate = FP / (FP + TN) if (FP + TN) > 0 else 0.0

    md = []
    md.append("## Model Evaluation Results (Binary AF Detection)\n")
    md.append("### Overall Performance")
    md.append(f"- **Total test samples:** {total}")
    md.append(f"- **Overall accuracy:** **{accuracy*100:.1f}%**\n")

    md.append("### Binary Classification Metrics")
    md.append(f"- **AF sensitivity (recall):** **{sensitivity*100:.1f}%**")
    md.append(f"- **Non-AF specificity:** **{specificity*100:.1f}%**")
    md.append(f"- **False AF alarm rate:** **{fp_rate*100:.1f}%**\n")

    md.append("### Confusion Matrix Interpretation")
    md.append(f"- **True Positives (AF correctly detected):** {TP}")
    md.append(f"- **False Negatives (missed AF):** {FN}")
    md.append(f"- **False Positives (non-AF classified as AF):** {FP}")
    md.append(f"- **True Negatives (correct non-AF rejection):** {TN}")

    return "\n".join(md)


In [65]:
from sklearn.metrics import confusion_matrix, f1_score, recall_score

# --- Build final dataframes (hold-out by patient) ---
final_train_df = dff[dff.patient_id.isin(train_patients)].copy()
final_test_df  = dff[dff.patient_id.isin(test_patients)].copy()

assert set(final_train_df.patient_id).isdisjoint(set(final_test_df.patient_id)), "Leakage in final split!"

# --- Dedicated final loaders (do NOT reuse train_loader from CV) ---
final_train_loader = DataLoader(
    ECGDataset(final_train_df),
    batch_size=16,
    shuffle=True,
    num_workers=0
)

final_test_loader = DataLoader(
    ECGDataset(final_test_df),
    batch_size=16,
    shuffle=False,
    num_workers=0
)

# --- Fresh model ---
model = ECG_CNN_LSTM(num_classes=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# --- Class weights from FINAL TRAIN RECORD labels ---
counts = np.bincount(final_train_df["label_id"].values, minlength=2)
w0 = 1.0
w1 = counts[0] / max(counts[1], 1)
class_weights = torch.tensor([w0, w1], dtype=torch.float32).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

print("Final train record counts [nonAF, AF]:", counts, "| weights:", [float(w0), float(w1)])

# --- Train ---
EPOCHS_FINAL = 4
for epoch in range(EPOCHS_FINAL):
    train_loss, train_acc = train_one_epoch(model, final_train_loader, optimizer, criterion, device)
    print(f"Epoch {epoch+1:02d} | Train Loss {train_loss:.4f} | Train Acc {train_acc:.3f}")

# --- Test ---
model.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for X, y in final_test_loader:
        X = X.to(device)
        out = model(X)
        preds = out.argmax(dim=1).cpu().numpy()
        y_pred.extend(preds.tolist())
        y_true.extend(y.numpy().tolist())

y_true = np.array(y_true)
y_pred = np.array(y_pred)

cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
test_acc = (y_true == y_pred).mean()
test_f1 = f1_score(y_true, y_pred, zero_division=0)
test_rec_af = recall_score(y_true, y_pred, pos_label=1, zero_division=0)

print("\n===== FINAL TEST RESULT =====")
print("Test accuracy :", float(test_acc))
print("Test F1       :", float(test_f1))
print("AF Recall     :", float(test_rec_af))
print("Confusion matrix [[TN FP],[FN TP]]:\n", cm)

display(Markdown(confusion_matrix_to_markdown_binary(cm)))


Final train record counts [nonAF, AF]: [1223  593] | weights: [1.0, 2.0623946037099494]
Epoch 01 | Train Loss 0.6219 | Train Acc 0.660
Epoch 02 | Train Loss 0.5569 | Train Acc 0.719
Epoch 03 | Train Loss 0.5080 | Train Acc 0.784
Epoch 04 | Train Loss 0.6532 | Train Acc 0.675

===== FINAL TEST RESULT =====
Test accuracy : 0.7223719676549866
Test F1       : 0.5795918367346938
AF Recall     : 0.5772357723577236
Confusion matrix [[TN FP],[FN TP]]:
 [[197  51]
 [ 52  71]]


## Model Evaluation Results (Binary AF Detection)

### Overall Performance
- **Total test samples:** 371
- **Overall accuracy:** **72.2%**

### Binary Classification Metrics
- **AF sensitivity (recall):** **57.7%**
- **Non-AF specificity:** **79.4%**
- **False AF alarm rate:** **20.6%**

### Confusion Matrix Interpretation
- **True Positives (AF correctly detected):** 71
- **False Negatives (missed AF):** 52
- **False Positives (non-AF classified as AF):** 51
- **True Negatives (correct non-AF rejection):** 197