# Experimental model: CNN-adrian feature extractor combined with an LSTM
>The CNN learns features of the ECG morphology
>The LSTM decides between AFIB, NORM, AFLT and OTHER

In [58]:
#Dataset loader & preparation
import os
import pandas as pd
import ast
import random
from collections import Counter
import wfdb
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from pathlib import Path
from IPython.display import Markdown, display

# Load the PTB-XL metadata table (ptbxl_database.csv); later cells read its
# ecg_id, patient_id and scp_codes columns.
dff = pd.read_csv("../../../data/ptbxl_database.csv")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


# Environment diagnostics: where torch is installed, its version, the CUDA
# version it was built against and whether CUDA is usable.
print(torch.__file__)
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

Using device: cuda
c:\Users\arjan\Documents\GitHub\SEARCH_AF_detection_OsloMet_BachelorGroup\venv\Lib\site-packages\torch\__init__.py
2.5.1+cu121
12.1
True


In [59]:
# SCP codes that extract_labels() keeps when filtering a record's scp_codes.
TARGET_LABELS = {"NORM", "AFIB", "AFLT"}


# Per-class lists of ECG IDs; they are filled by the pool-building cell that
# iterates over the metadata table.
norm_ids  = []
afib_ids  = []
aflt_ids  = []
other_ids = []


# Class name -> class index. The indices match the keys of ecg_pool,
# TRAIN_COUNTS and TEST_COUNTS. Note that get_label() returns 4-8 (not 3) for
# the non-target rhythm codes; those records are all collected in other_ids,
# which becomes pool key 3.
# NOTE(review): LABEL_MAP itself is not referenced by the visible cells.
LABEL_MAP = {
    "NORM": 0,
    "AFIB": 1,
    "AFLT": 2,
    "OTHER": 3
}




In [60]:
def has_repeated_label(label_sets):
    """Tell whether any label occurs at least twice across the recordings.

    Args:
        label_sets: Iterable holding one collection of labels per ECG.

    Returns:
        bool: True if some label is counted two or more times over all
        collections, False otherwise (also for empty input).
    """
    occurrences = Counter(
        label
        for labels in label_sets
        for label in labels
    )
    return any(count >= 2 for count in occurrences.values())



def extract_labels(scp_codes):
    """Parse a raw scp_codes string and keep only the target codes.

    Args:
        scp_codes: String containing a Python dict literal; it is parsed
            with ast.literal_eval.

    Returns:
        set: The keys of the parsed dict that are also in TARGET_LABELS.
    """
    parsed = ast.literal_eval(scp_codes)
    return {code for code in parsed.keys() if code in TARGET_LABELS}

def select_from_pool(pool, label, n, seed=42):
    """
    Select n ECG IDs from pool[label] and REMOVE them.

    Args:
        pool: Dict mapping a label to a set of ECG IDs. It is mutated in
            place: the selected IDs are removed from pool[label].
        label: Key of the pool to draw from.
        n: Number of IDs to draw without replacement.
        seed: Seed of the private random generator used for this draw.

    Returns:
        list: The n selected IDs.

    Raises:
        ValueError: If pool[label] holds fewer than n IDs.
    """
    # Sort before sampling: the iteration order of a set depends on its
    # insertion/removal history, so the same IDs could otherwise give a
    # different sample for the same seed.
    available = sorted(pool[label])

    if len(available) < n:
        raise ValueError(
            f"Not enough ECGs for label {label}. "
            f"Requested {n}, available {len(available)}"
        )

    # A private generator keeps the draw reproducible without reseeding the
    # global `random` state as a side effect of every call.
    rng = random.Random(seed)
    selected = rng.sample(available, n)

    # remove selected from pool
    pool[label] -= set(selected)

    return selected

In [61]:
#FASTER for loading ECG files
def build_ecg_index(root="../../../data/records500", suffix="_hr"):
    """Map every ECG ID found under `root` to the path of its record.

    Header files matching ``*<suffix>.hea`` are searched recursively. The ID
    is the file stem with `suffix` removed, converted to int; the stored
    value is the header path without its ``.hea`` extension, as a string.

    Args:
        root: Directory that is searched recursively.
        suffix: Filename suffix that precedes the ``.hea`` extension.

    Returns:
        dict: ``{ecg_id: record path without extension}``.
    """
    header_files = Path(root).rglob(f"*{suffix}.hea")
    return {
        int(header.stem.replace(suffix, "")): str(header.with_suffix(""))
        for header in header_files
    }

# Build the ecg_id -> record path lookup once; load_ecg_fast() reads from it
# instead of searching the directory tree on every call.
ECG_INDEX = build_ecg_index()

def load_ecg_fast(ecg_id):
    """Read one ECG record and return its signal.

    Args:
        ecg_id: Key into ECG_INDEX.

    Returns:
        The ``p_signal`` attribute of the record read by ``wfdb.rdrecord``.

    Raises:
        KeyError: If `ecg_id` is not present in ECG_INDEX.
    """
    record_path = ECG_INDEX[ecg_id]
    return wfdb.rdrecord(record_path).p_signal


In [62]:
def build_arrays(id_label_list):
    """Load the signals of ``(ecg_id, label)`` pairs into two NumPy arrays.

    Args:
        id_label_list: Iterable of ``(ecg_id, label)`` pairs.

    Returns:
        tuple: ``(X, y)`` where X is a float32 array built from the loaded
        signals and y is an int64 array of the labels, in input order.
    """
    signals = []
    targets = []
    for record_id, target in id_label_list:
        signals.append(load_ecg_fast(record_id))
        targets.append(target)

    X = np.array(signals, dtype=np.float32)
    y = np.array(targets, dtype=np.int64)
    return X, y

In [63]:
# Patient-level overview. Recordings of one patient should never be split
# between the training and test sets, so first count how many patients have
# more than one recording. This cell only prints statistics.
ecg_records = dff["ecg_id"].count() 
print("Total ECG records:", ecg_records)
# Number of recordings per patient_id.
patient_ids = dff["patient_id"].value_counts()
multiple_recordings = patient_ids[patient_ids > 1]
print("Total patients:", dff["patient_id"].nunique())
print("Patients with multiple recordings:", len(multiple_recordings))
print("Max recordings for one patient:", multiple_recordings.max())

def get_label(scp_codes):
    """Map the SCP codes of one recording to an integer label.

    The codes are checked in a fixed priority order and the first one that
    is present decides the result.

    Args:
        scp_codes: Container supporting ``in``, e.g. the parsed scp_codes
            dict of a recording.

    Returns:
        int or None: 1 for AFIB, 0 for NORM, 2 for AFLT, 4-8 for NDT, NST_,
        SVARR, SVTAC and PAC respectively, or None if none of these codes
        is present.
    """
    # Ordered by priority: AFIB wins over NORM, NORM over AFLT, and so on.
    priority = (
        ("AFIB", 1),
        ("NORM", 0),
        ("AFLT", 2),
        ("NDT", 4),
        ("NST_", 5),
        ("SVARR", 6),
        ("SVTAC", 7),
        ("PAC", 8),
    )
    for code, label in priority:
        if code in scp_codes:
            return label
    return None

# Label every recording with get_label() and count, per patient, how many
# distinct labels occur (nunique() skips missing labels by default).
dff["label"] = dff["scp_codes"].apply(lambda x: get_label(ast.literal_eval(x)))
label_counts_per_patient = dff.groupby("patient_id")["label"].nunique()
patients_with_label_change = label_counts_per_patient[label_counts_per_patient > 1]

print("Patients with multiple labels:", len(patients_with_label_change))
print(
    "Percentage:",
    f"{round(100 * len(patients_with_label_change) / dff['patient_id'].nunique(), 1)}%"
)


# Keep only recordings that carry at least one of the TARGET_LABELS codes.
dff["target_labels"] = dff["scp_codes"].apply(extract_labels)
dff_target = dff[dff["target_labels"].apply(len) > 0]


patient_groups = dff_target.groupby("patient_id")["target_labels"]


# True for patients whose recordings share at least one target label.
# NOTE(review): this result is not used by any of the visible cells.
patients_with_repeated_label = patient_groups.apply(has_repeated_label)
# Number of target-labelled ECGs per patient.
patient_target_counts = dff_target["patient_id"].value_counts()

patients_with_multiple_target_ecgs = patient_target_counts[
    patient_target_counts > 1
].index
# Rows of patients with more than one target-labelled ECG. The pool-building
# cell skips every recording of these patients, which is how train/test
# leakage between recordings of the same patient is avoided.
dff_final = dff_target[
    dff_target["patient_id"].isin(patients_with_multiple_target_ecgs)
]




Total ECG records: 21837
Total patients: 18885
Patients with multiple recordings: 2127
Max recordings for one patient: 10
Patients with multiple labels: 269
Percentage: 1.4%


In [64]:
dataset_root = "../../../data/"
df = pd.read_csv(os.path.join(dataset_root, "ptbxl_database.csv"))

# Patients in dff_final have more than one target-labelled ECG; all of their
# recordings are left out of the pools. The lookup is built once as a set
# instead of scanning a NumPy array for every row of the table.
excluded_patients = set(dff_final["patient_id"])

# Start from empty lists so that re-running this cell does not append the
# same ECG IDs a second time.
norm_ids, afib_ids, aflt_ids, other_ids = [], [], [], []
# Labels 0-2 have their own list; every other non-None label counts as OTHER.
ids_by_label = {0: norm_ids, 1: afib_ids, 2: aflt_ids}

# Iterate over the three needed columns directly; this avoids building a
# Series object per row as DataFrame.iterrows() does.
for ecg_id, p_id, raw_codes in zip(df["ecg_id"], df["patient_id"], df["scp_codes"]):
    if p_id in excluded_patients:
        continue

    scp_codes = ast.literal_eval(raw_codes)
    label = get_label(scp_codes)
    if label is None:
        continue

    ids_by_label.get(label, other_ids).append(ecg_id)

print("NORM :", len(norm_ids))
print("AFIB :", len(afib_ids))
print("AFLT :", len(aflt_ids))
print("OTHER:", len(other_ids))
print("TOTAL:", len(norm_ids) + len(afib_ids) + len(aflt_ids) + len(other_ids))


NORM : 8329
AFIB : 1027
AFLT : 32
OTHER: 2394
TOTAL: 11782


In [34]:

class CNNFeatureExtractor(nn.Module):
    """
    Convolutional front end intended to learn ECG morphology
    (QRS complex shape, amplitude, local waveform patterns).

    Two Conv1d -> BatchNorm1d -> ReLU stages (12 -> 32 -> 64 channels) are
    followed by MaxPool1d(2), which halves the time axis.
    """

    def __init__(self):
        super().__init__()

        # Layers are listed in execution order; nn.Sequential numbers them
        # 0..6 in exactly this order.
        layers = [
            nn.Conv1d(12, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(2),
        ]
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        """Run the convolution stack on a (batch, time, leads) tensor.

        Returns:
            Tensor laid out as (batch, time, features).
        """
        # Conv1d expects channels before time: (batch, leads, time)
        channels_first = x.permute(0, 2, 1)
        feature_maps = self.cnn(channels_first)
        # back to (batch, time, features)
        return feature_maps.permute(0, 2, 1)


In [35]:
class ECG_CNN_LSTM(nn.Module):
    """CNNFeatureExtractor followed by a 2-layer LSTM and a linear classifier.

    Args:
        num_classes: Number of output logits of the final linear layer.
    """

    def __init__(self, num_classes=4):
        super().__init__()

        self.cnn = CNNFeatureExtractor()   # from main_cnn

        # input_size matches the 64 feature channels produced by the CNN
        lstm_config = dict(
            input_size=64,
            hidden_size=128,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
        )
        self.lstm = nn.LSTM(**lstm_config)

        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of inputs."""
        # morphology learning
        features = self.cnn(x)

        # rhythm learning
        sequence_out, _ = self.lstm(features)

        # temporal pooling: average the LSTM outputs over dim 1 (time)
        pooled = sequence_out.mean(dim=1)

        return self.fc(pooled)


In [37]:


def sample_ids(ids, n, seed=42):
    """Draw `n` distinct entries from `ids` with a seeded NumPy generator.

    Args:
        ids: Sequence of IDs to sample from.
        n: Number of entries to draw without replacement.
        seed: Seed passed to ``np.random.default_rng``.

    Returns:
        numpy.ndarray: The `n` sampled entries.

    Raises:
        ValueError: If fewer than `n` entries are available.
    """
    candidates = np.array(ids)
    if n > len(candidates):
        raise ValueError(f"Requested {n}, but only {len(candidates)} available")
    return np.random.default_rng(seed).choice(candidates, size=n, replace=False)


In [65]:
# Pool of still-available ECG IDs per class index. select_from_pool()
# removes IDs from these sets as they are drawn.
ecg_pool = {
    0: set(norm_ids),
    1: set(afib_ids),
    2: set(aflt_ids),
    3: set(other_ids)
}

print("Initial pool sizes:")
for k, v in ecg_pool.items():
    print(k, len(v))

# Number of recordings to draw per class index for the training set.
TRAIN_COUNTS = {
    0: 1500,  # NORM
    1: 500,   # AFIB
    2: 15,  # AFLT
    3: 500    # OTHER
}
# Number of recordings to draw per class index for the test set. AFLT uses
# 15 + 15 = 30 of the 32 IDs reported in the pool sizes.
TEST_COUNTS = {
    0: 300,  # NORM
    1: 100,  # AFIB
    2: 15,   # AFLT
    3: 100   # OTHER
}


Initial pool sizes:
0 8329
1 1027
2 32
3 2394


In [66]:
# Draw the training IDs class by class; select_from_pool() removes every
# drawn ID from ecg_pool, so it cannot be drawn again later.
train_ids = []

for label, n in TRAIN_COUNTS.items():
    ids = select_from_pool(ecg_pool, label, n)
    for eid in ids:
        train_ids.append((eid, label))


In [67]:
# Draw the test IDs from what is left in ecg_pool after the training draw.
test_ids = []

for label, n in TEST_COUNTS.items():
    ids = select_from_pool(ecg_pool, label, n)
    for eid in ids:
        test_ids.append((eid, label))


### Verification that the training and test sets do not overlap

In [68]:
# Sanity check: no ECG ID may appear in both splits.
train_set = {eid for eid, _ in train_ids}
test_set = {eid for eid, _ in test_ids}

print("Overlap:", len(train_set.intersection(test_set)))


Overlap: 0


### Loading the ECG arrays

In [69]:
def build_arrays(id_label_list):
    """Turn ``(ecg_id, label)`` pairs into a signal array and a label array.

    NOTE(review): a function with the same name and behavior is already
    defined in an earlier cell; this definition shadows it. Consider keeping
    only one of the two.

    Args:
        id_label_list: Iterable of ``(ecg_id, label)`` pairs.

    Returns:
        tuple: ``(X, y)`` where X is a float32 array of the signals returned
        by load_ecg_fast() and y is an int64 array of the labels.
    """
    pairs = list(id_label_list)
    X = np.array([load_ecg_fast(ecg_id) for ecg_id, _ in pairs], dtype=np.float32)
    y = np.array([label for _, label in pairs], dtype=np.int64)
    return X, y


In [70]:
# Read every selected record from disk and hold both splits in memory as
# NumPy arrays (X: float32 signals, y: int64 class indices).
X_train, y_train = build_arrays(train_ids)
X_test,  y_test  = build_arrays(test_ids)


In [71]:


class ECGDataset(Dataset):
    """In-memory dataset of signal/label pairs backed by two tensors.

    Args:
        X: Array-like of signals; copied into a float32 tensor.
        y: Array-like of labels; copied into an int64 (long) tensor.
    """

    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        """Return the number of samples (length of the signal tensor)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(signal, label)`` pair stored at `idx`."""
        signal = self.X[idx]
        target = self.y[idx]
        return signal, target


In [72]:
# Wrap the in-memory arrays as torch datasets.
train_dataset = ECGDataset(X_train, y_train)
test_dataset  = ECGDataset(X_test, y_test)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True
)

# Evaluation keeps the stored sample order.
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False
)

In [73]:
# Same device selection as in the first cell, repeated here before the model
# is created.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Four output classes, matching the four keys of ecg_pool / TRAIN_COUNTS.
model = ECG_CNN_LSTM(num_classes=4).to(device)


In [74]:
# Adam over all model parameters with a learning rate of 3e-4.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Unweighted multi-class cross-entropy on the raw logits.
# NOTE(review): TRAIN_COUNTS is strongly imbalanced (1500 / 500 / 15 / 500);
# consider the `weight` argument of CrossEntropyLoss if the AFLT class is
# under-predicted — to be confirmed against evaluation results.
criterion = torch.nn.CrossEntropyLoss()
