# This Expermental take CNN-adrian with LSTM model
>CNN Learn feature of ECG morphology
>LSTM deccide the AFIB, NORM, AFLAT, OTHER 

In [81]:
#Dataset loader & preparation
import os
import pandas as pd
import ast
import random
from collections import Counter
import wfdb
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from pathlib import Path
from IPython.display import Markdown, display

dff = pd.read_csv("../../../data/ptbxl_database.csv")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


print(torch.__file__)
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

Using device: cuda
c:\Users\arjan\Documents\GitHub\SEARCH_AF_detection_OsloMet_BachelorGroup\venv\Lib\site-packages\torch\__init__.py
2.5.1+cu121
12.1
True


In [82]:
TARGET_LABELS = {"NORM", "AFIB", "AFLT"}


norm_ids  = []
afib_ids  = []
aflt_ids  = []
other_ids = []


LABEL_MAP = {
    "NORM": 0,
    "AFIB": 1,
    "AFLT": 2,
    "OTHER": 3
}




In [83]:
def has_repeated_label(label_sets):
    # label_sets = list of sets, one per ECG
    all_labels = []
    for s in label_sets:
        all_labels.extend(list(s))

    counts = Counter(all_labels)

    # label appears in 2 or more recordings
    return any(v >= 2 for v in counts.values())



def extract_labels(scp_codes):
    codes = ast.literal_eval(scp_codes)
    return set(codes.keys()) & TARGET_LABELS

def select_from_pool(pool, label, n, seed=42):
    """
    Select n ECG IDs from pool[label] and REMOVE them.
    """
    random.seed(seed)

    available = list(pool[label])

    if len(available) < n:
        raise ValueError(
            f"Not enough ECGs for label {label}. "
            f"Requested {n}, available {len(available)}"
        )

    selected = random.sample(available, n)

    # remove selected from pool
    pool[label] -= set(selected)

    return selected

In [84]:
#FASTER for loading ECG files
def build_ecg_index(root="../../../data/records500", suffix="_hr"):
    index = {}
    for p in Path(root).rglob(f"*{suffix}.hea"):
        ecg_id = int(p.stem.replace(suffix, ""))
        index[ecg_id] = str(p.with_suffix(""))
    return index

ECG_INDEX = build_ecg_index()

def load_ecg_fast(ecg_id):
    record = wfdb.rdrecord(ECG_INDEX[ecg_id])
    return record.p_signal


In [85]:
def build_arrays(id_label_list):
    X, y = [], []
    for ecg_id, label in id_label_list:
        X.append(load_ecg_fast(ecg_id))
        y.append(label)
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.int64)

In [86]:
#Ensure if patient_id have different ECG records, all records are in the same set
ecg_records = dff["ecg_id"].count() 
print("Total ECG records:", ecg_records)
patient_ids = dff["patient_id"].value_counts()
multiple_recordings = patient_ids[patient_ids > 1]
print("Total patients:", dff["patient_id"].nunique())
print("Patients with multiple recordings:", len(multiple_recordings))
print("Max recordings for one patient:", multiple_recordings.max())

def get_label(scp_codes):
    if "AFIB" in scp_codes:
        return 1
    if "NORM" in scp_codes:
        return 0
    if "AFLT" in scp_codes:
        return 2
    
    if "NDT" in scp_codes:
        return 4
    if "NST_" in scp_codes:
        return 5
    if "SVARR" in scp_codes:
        return 6
    if "SVTAC" in scp_codes:
        return 7
    if "PAC" in scp_codes:
        return 8
    return None

dff["label"] = dff["scp_codes"].apply(lambda x: get_label(ast.literal_eval(x)))
label_counts_per_patient = dff.groupby("patient_id")["label"].nunique()
patients_with_label_change = label_counts_per_patient[label_counts_per_patient > 1]

print("Patients with multiple labels:", len(patients_with_label_change))
print(
    "Percentage:",
    f"{round(100 * len(patients_with_label_change) / dff['patient_id'].nunique(), 1)}%"
)


dff["target_labels"] = dff["scp_codes"].apply(extract_labels)
dff_target = dff[dff["target_labels"].apply(len) > 0]


patient_groups = dff_target.groupby("patient_id")["target_labels"]


patients_with_repeated_label = patient_groups.apply(has_repeated_label)
# patients with more than one ECG
patient_target_counts = dff_target["patient_id"].value_counts()

patients_with_multiple_target_ecgs = patient_target_counts[
    patient_target_counts > 1
].index
dff_final = dff_target[
    dff_target["patient_id"].isin(patients_with_multiple_target_ecgs)
]




Total ECG records: 21837
Total patients: 18885
Patients with multiple recordings: 2127
Max recordings for one patient: 10
Patients with multiple labels: 269
Percentage: 1.4%


In [87]:
dataset_root = "../../../data/"
df = pd.read_csv(os.path.join(dataset_root, "ptbxl_database.csv"))
for _, row in df.iterrows():
    
    scp_codes = ast.literal_eval(row["scp_codes"])
    label = get_label(scp_codes)
    ecg_id = row["ecg_id"]
    p_id = row["patient_id"]
    if p_id  in dff_final["patient_id"].values:
        continue

    if label is None:
        continue


    if label == 0:
        norm_ids.append(ecg_id)

    elif label == 1:
        afib_ids.append(ecg_id)

    elif label == 2:
        aflt_ids.append(ecg_id)

    else:
        other_ids.append(ecg_id)
print("NORM :", len(norm_ids))
print("AFIB :", len(afib_ids))
print("AFLT :", len(aflt_ids))
print("OTHER:", len(other_ids))
print("TOTAL:", len(norm_ids) + len(afib_ids) + len(aflt_ids) + len(other_ids))


NORM : 8329
AFIB : 1027
AFLT : 32
OTHER: 2394
TOTAL: 11782


In [88]:

class CNNFeatureExtractor(nn.Module):
    """
    Learns ECG morphology:
    QRS complex shape, amplitude, local waveform patterns
    """

    def __init__(self):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv1d(12, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),

            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),

            nn.MaxPool1d(2)
        )

    def forward(self, x):
        # x: (batch, time, leads)
        x = x.permute(0, 2, 1)   # → (batch, leads, time)
        x = self.cnn(x)
        x = x.permute(0, 2, 1)   # → (batch, time, features)
        return x


In [89]:
class ECG_CNN_LSTM(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()

        self.cnn = CNNFeatureExtractor()   # from main_cnn
        self.lstm = nn.LSTM(
            input_size=64,
            hidden_size=128,
            num_layers=2,
            batch_first=True,
            dropout=0.3
        )

        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        # morphology learning
        x = self.cnn(x)

        # rhythm learning
        out, _ = self.lstm(x)

        # temporal pooling (VERY IMPORTANT)
        out = out.mean(dim=1)

        return self.fc(out)


In [90]:


def sample_ids(ids, n, seed=42):
    ids = np.array(ids)
    if len(ids) < n:
        raise ValueError(f"Requested {n}, but only {len(ids)} available")
    rng = np.random.default_rng(seed)
    return rng.choice(ids, size=n, replace=False)


In [91]:
ecg_pool = {
    0: set(norm_ids),
    1: set(afib_ids),
    2: set(aflt_ids),
    3: set(other_ids)
}

print("Initial pool sizes:")
for k, v in ecg_pool.items():
    print(k, len(v))

TRAIN_COUNTS = {
    0: 1500,  # NORM
    1: 500,   # AFIB
    
}
TEST_COUNTS = {
    0: 300,
    1: 200,
    2: 25,   
    3: 300
}


Initial pool sizes:
0 8329
1 1027
2 32
3 2394


In [92]:
train_ids = []

for label, n in TRAIN_COUNTS.items():
    ids = select_from_pool(ecg_pool, label, n)
    train_ids.extend([(eid, label) for eid in ids])


In [93]:
test_ids = []

for label, n in TEST_COUNTS.items():
    ids = select_from_pool(ecg_pool, label, n)
    test_ids.extend([(eid, label) for eid in ids])


### Verifiction of no mixing of traingin and testing

In [94]:
train_set = set(eid for eid, _ in train_ids)
test_set  = set(eid for eid, _ in test_ids)

print("Overlap:", len(train_set & test_set))


Overlap: 0


### Loader of ECG array

In [95]:
def build_arrays(id_label_list):
    X, y = [], []
    for ecg_id, label in id_label_list:
        X.append(load_ecg_fast(ecg_id))
        y.append(label)
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.int64)


In [96]:
X_train, y_train = build_arrays(train_ids)
X_test,  y_test  = build_arrays(test_ids)


In [97]:


class ECGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


In [98]:
train_dataset = ECGDataset(X_train, y_train)
test_dataset  = ECGDataset(X_test, y_test)

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False
)

In [99]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ECG_CNN_LSTM(num_classes=2).to(device)


In [100]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

criterion = torch.nn.CrossEntropyLoss()


### Training

In [101]:
train_model(
    model=model,
    loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=20
)


Epoch 01 | Loss: 0.5438 | Train Acc: 0.724
Epoch 02 | Loss: 0.3626 | Train Acc: 0.842
Epoch 03 | Loss: 0.3135 | Train Acc: 0.876
Epoch 04 | Loss: 0.2762 | Train Acc: 0.894
Epoch 05 | Loss: 0.3181 | Train Acc: 0.874
Epoch 06 | Loss: 0.2537 | Train Acc: 0.899
Epoch 07 | Loss: 0.2582 | Train Acc: 0.902
Epoch 08 | Loss: 0.2516 | Train Acc: 0.906
Epoch 09 | Loss: 0.2338 | Train Acc: 0.909
Epoch 10 | Loss: 0.2399 | Train Acc: 0.912
Epoch 11 | Loss: 0.2502 | Train Acc: 0.897
Epoch 12 | Loss: 0.2329 | Train Acc: 0.909
Epoch 13 | Loss: 0.2597 | Train Acc: 0.901
Epoch 14 | Loss: 0.3391 | Train Acc: 0.847
Epoch 15 | Loss: 0.3639 | Train Acc: 0.859
Epoch 16 | Loss: 0.2593 | Train Acc: 0.907
Epoch 17 | Loss: 0.2573 | Train Acc: 0.905
Epoch 18 | Loss: 0.2561 | Train Acc: 0.914
Epoch 19 | Loss: 0.2520 | Train Acc: 0.912
Epoch 20 | Loss: 0.3027 | Train Acc: 0.892


### Testing

In [102]:
def test_model(model, loader, device):
    model.eval()  

    y_true = []
    y_pred = []

    with torch.no_grad(): 
        for X_batch, y_batch in loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            logits = model(X_batch)
            preds = logits.argmax(dim=1)

            y_true.extend(y_batch.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    return y_true, y_pred


In [103]:
y_true, y_pred = test_model(model, test_loader, device)

y_true_af = np.array([1 if y == 1 else 0 for y in y_true])
y_pred_af = np.array(y_pred)


In [104]:


labels = [0, 1]
#label_names = ["NORM", "AFIB", "AFLT", "OTHER"]

cm = confusion_matrix(y_true_af, y_pred_af, labels=[0,1])

print("Confusion Matrix:")
print(cm)

print("\nClassification Report:")
print(classification_report(
    y_true_af,
    y_pred_af,
    target_names=["NOT_AF", "AFIB"]
))


Confusion Matrix:
[[448 177]
 [ 51 149]]

Classification Report:
              precision    recall  f1-score   support

      NOT_AF       0.90      0.72      0.80       625
        AFIB       0.46      0.74      0.57       200

    accuracy                           0.72       825
   macro avg       0.68      0.73      0.68       825
weighted avg       0.79      0.72      0.74       825



In [106]:
def confusion_matrix_to_markdown_binary(cm):
    """
    cm format:
        [[TN, FP],
         [FN, TP]]
    """

    cm = np.asarray(cm, dtype=int)

    TN, FP = cm[0]
    FN, TP = cm[1]

    total = cm.sum()
    accuracy = (TP + TN) / total if total > 0 else 0.0

    sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0.0   # AF recall
    specificity = TN / (TN + FP) if (TN + FP) > 0 else 0.0   # non-AF rejection
    fp_rate = FP / (FP + TN) if (FP + TN) > 0 else 0.0

    md = []

    md.append("## Model Evaluation Results (Binary AF Detection)\n")
    md.append(
        "The CNN–LSTM model was evaluated as a binary atrial fibrillation detector. "
        "Although the test set contained multiple rhythm types, model predictions "
        "were limited to AF versus non-AF, reflecting a realistic clinical screening scenario.\n"
    )

    md.append("### Overall Performance")
    md.append(f"- **Total test samples:** {total}")
    md.append(f"- **Overall accuracy:** **{accuracy*100:.1f}%**\n")

    md.append("### Binary Classification Metrics")
    md.append(f"- **AF sensitivity (recall):** **{sensitivity*100:.1f}%**")
    md.append(f"- **Non-AF specificity:** **{specificity*100:.1f}%**")
    md.append(f"- **False AF alarm rate:** **{fp_rate*100:.1f}%**\n")

    md.append("### Confusion Matrix Interpretation")
    md.append(f"- **True Positives (AF correctly detected):** {TP}")
    md.append(f"- **False Negatives (missed AF):** {FN}")
    md.append(f"- **False Positives (non-AF classified as AF):** {FP}")
    md.append(f"- **True Negatives (correct non-AF rejection):** {TN}")

    return "\n".join(md)


In [107]:
cm = confusion_matrix(y_true_af, y_pred_af, labels=[0,1])

markdown_report = confusion_matrix_to_markdown_binary(cm)
display(Markdown(markdown_report))


## Model Evaluation Results (Binary AF Detection)

The CNN–LSTM model was evaluated as a binary atrial fibrillation detector. Although the test set contained multiple rhythm types, model predictions were limited to AF versus non-AF, reflecting a realistic clinical screening scenario.

### Overall Performance
- **Total test samples:** 825
- **Overall accuracy:** **72.4%**

### Binary Classification Metrics
- **AF sensitivity (recall):** **74.5%**
- **Non-AF specificity:** **71.7%**
- **False AF alarm rate:** **28.3%**

### Confusion Matrix Interpretation
- **True Positives (AF correctly detected):** 149
- **False Negatives (missed AF):** 51
- **False Positives (non-AF classified as AF):** 177
- **True Negatives (correct non-AF rejection):** 448