## This file contains the learning step of LSTM implementation.
### Structure of LSTM
- Linear Part
- Short-term memory (hidden state)
- Long-term memory (cell state)
- Non-linear part


In [1]:
import os
import pandas as pd
import ast
import random
from collections import Counter
import wfdb
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
dff = pd.read_csv("../data/ptbxl_database.csv")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


print(torch.__file__)
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

Using device: cuda
c:\Users\arjan\Documents\GitHub\SEARCH_AF_detection_OsloMet_BachelorGroup\venv\Lib\site-packages\torch\__init__.py
2.5.1+cu121
12.1
True


In [None]:
#Ensure if patient_id have different ECG records, all records are in the same set
patient_ids = dff["patient_id"].value_counts()
multiple_recordings = patient_ids[patient_ids > 1]
print("Total patients:", df["patient_id"].nunique())
print("Patients with multiple recordings:", len(multiple_recordings))
print("Max recordings for one patient:", multiple_recordings.max())

def get_label(scp_codes):
    if "AFIB" in scp_codes:
        return 1
    if "NORM" in scp_codes:
        return 0
    if "AFLT" in scp_codes:
        return 2
    
    if "NDT" in scp_codes:
        return 4
    if "NST_" in scp_codes:
        return 5
    if "SVARR" in scp_codes:
        return 6
    if "SVTAC" in scp_codes:
        return 7
    if "PAC" in scp_codes:
        return 8
    return None

dff["label"] = dff["scp_codes"].apply(lambda x: get_label(ast.literal_eval(x)))
label_counts_per_patient = dff.groupby("patient_id")["label"].nunique()
patients_with_label_change = label_counts_per_patient[label_counts_per_patient > 1]
print("Total patients:", df["patient_id"].nunique())
print("Patients with multiple labels:", len(patients_with_label_change))
print(
    "Percentage:",
    str(round(100 * len(patients_with_label_change) / dff["patient_id"].nunique(), 1)) +"%"
)
TARGET_LABELS = {"NORM", "AFIB", "AFLT"}
def extract_labels(scp_codes):
    codes = ast.literal_eval(scp_codes)
    return set(codes.keys() & TARGET_LABELS)

dff["labels"] = dff["scp_codes"].apply(extract_labels)
patient_groups = dff.groupby("patient_id")["labels"]

def has_repeated_label(label_sets):
    # label_sets = list of sets, one per ECG
    all_labels = []
    for s in label_sets:
        all_labels.extend(list(s))

    counts = Counter(all_labels)

    # label appears in 2 or more recordings
    return any(v >= 2 for v in counts.values())
patients_with_repeated_label = patient_groups.apply(has_repeated_label)
# patients with more than one ECG
multi_record_patients = dff["patient_id"].value_counts()
multi_record_patients = multi_record_patients[multi_record_patients > 1].index

# final selection
selected_patients = patients_with_repeated_label[
    patients_with_repeated_label.index.isin(multi_record_patients)
]

selected_patients = selected_patients[selected_patients]




Total patients: 18885
Patients with multiple recordings: 2127
Max recordings for one patient: 10
Total patients: 18885
Patients with multiple labels: 269
Percentage: 1.4%


In [60]:
print("Patients with >1 ECG:", len(multi_record_patients))
print("Patients with repeated label across recordings:", len(selected_patients))

print(
    "Percentage:",
    str(round(
        100 * len(selected_patients) / len(multi_record_patients), 1
    )) + "%"
)
example_patient = selected_patients.index[2]

dff[dff["patient_id"] == example_patient][
    ["ecg_id", "scp_codes"]
]



Patients with >1 ECG: 2127
Patients with repeated label across recordings: 1957
Percentage: 92.0%


Unnamed: 0,ecg_id,scp_codes
6317,6318,"{'IMI': 50.0, 'IRBBB': 100.0, 'LAFB': 100.0, '..."
6324,6325,"{'AMI': 50.0, 'IRBBB': 100.0, 'LAFB': 100.0, '..."


In [3]:
norm_ids  = []
afib_ids  = []
aflt_ids  = []
other_ids = []


In [None]:
dataset_root = "../data/"
df = pd.read_csv(os.path.join(dataset_root, "ptbxl_database.csv"))
for _, row in df.iterrows():
    
    scp_codes = ast.literal_eval(row["scp_codes"])
    label = get_label(scp_codes)
    ecg_id = row["ecg_id"]

    if label is None:
        continue

    if label == 0:
        norm_ids.append(ecg_id)

    elif label == 1:
        afib_ids.append(ecg_id)

    elif label == 2:
        aflt_ids.append(ecg_id)

    else:
        other_ids.append(ecg_id)


In [5]:
print("NORM :", len(norm_ids))
print("AFIB :", len(afib_ids))
print("AFLT :", len(aflt_ids))
print("OTHER:", len(other_ids))
print("TOTAL:", len(norm_ids) + len(afib_ids) + len(aflt_ids) + len(other_ids))

'''
before skip unknown labels:
NORM : 9491
AFIB : 1514
AFLT : 56
OTHER: 10776
TOTAL: 21837
'''

NORM : 9491
AFIB : 1514
AFLT : 56
OTHER: 2442
TOTAL: 13503


'\nbefore skip unknown labels:\nNORM : 9491\nAFIB : 1514\nAFLT : 56\nOTHER: 10776\nTOTAL: 21837\n'

In [6]:
assert set(norm_ids).isdisjoint(afib_ids)
assert set(norm_ids).isdisjoint(aflt_ids)
assert set(afib_ids).isdisjoint(aflt_ids)


### Lighweight test on 100 record per each type

In [7]:
def sample_ids(ids, n, seed=42):
    random.seed(seed)
    if len(ids) <= n:
        return ids.copy()
    return random.sample(ids, n)

In [19]:
norm_sampled_ids = sample_ids(norm_ids, 100)
afib_sampled_ids = sample_ids(afib_ids, 101)
aflt_sampled_ids = aflt_ids.copy()
other_sampled_ids = sample_ids(other_ids, 103)
print("Sampled NORM :", len(norm_sampled_ids))
print("Sampled AFIB :", len(afib_sampled_ids))
print("Sampled AFLT :", len(aflt_sampled_ids))
print("Sampled OTHER:", len(other_sampled_ids))
print("Normal example IDs:", norm_sampled_ids, "\nAFBI example IDs:", afib_sampled_ids[:10], "\nAFLT example IDs:", aflt_sampled_ids[:10], "\nOTHER example IDs:", other_sampled_ids[:10])

Sampled NORM : 100
Sampled AFIB : 101
Sampled AFLT : 56
Sampled OTHER: 103
Normal example IDs: [3462, 772, 9521, 8318, 7583, 4376, 3149, 20114, 2678, 15049, 979, 916, 2890, 7395, 7928, 18367, 809, 20892, 6597, 20094, 14940, 7465, 16155, 9667, 151, 5220, 15083, 11996, 9646, 5013, 7258, 11861, 3140, 2855, 13551, 2973, 12698, 12123, 9120, 1326, 16542, 19724, 3887, 13506, 2424, 20411, 10163, 12789, 21755, 6348, 2168, 1389, 7754, 10020, 2449, 7932, 3102, 13560, 9658, 16310, 12947, 5344, 13149, 12541, 6999, 9229, 2224, 5622, 19627, 8313, 5380, 16648, 13540, 9339, 20666, 7430, 11381, 1718, 7793, 989, 11000, 14231, 9258, 2059, 7070, 21241, 10974, 7126, 18125, 14067, 16516, 4504, 9132, 4378, 8397, 20897, 19844, 9057, 15360, 14179] 
AFBI example IDs: [19774, 4451, 1026, 9241, 8276, 7492, 5286, 21779, 4154, 20488] 
AFLT example IDs: [18, 20, 23, 34, 336, 347, 449, 706, 858, 1173] 
OTHER example IDs: [3739, 880, 9570, 8423, 7755, 4548, 3464, 19767, 3028, 21606]


In [9]:
train_ids = (
    norm_sampled_ids + afib_sampled_ids + aflt_sampled_ids + other_sampled_ids
)


afib_set  = set(afib_ids)
aflt_set  = set(aflt_ids)
norm_set  = set(norm_ids)
other_set = set(other_ids)


def get_class(ecg_id):
    if ecg_id in afib_set:
        return 1
    if ecg_id in aflt_set:
        return 2
    if ecg_id in norm_set:
        return 0
    if ecg_id in other_set:
        return 3

label = [get_class(ecg_id) for ecg_id in train_ids]
print("Class distribution in training set:", Counter(label))

Class distribution in training set: Counter({3: 103, 1: 101, 0: 100, 2: 56})


In [10]:
def load_ecg(ecg_id, root="../data/records500", suffix="_hr"):
    """
    Load PTB-XL ECG by ecg_id when files are named like:
    03462_hr.hea / 03462_hr.dat and stored in subdirectories.
    """
    target = f"{ecg_id:05d}{suffix}.hea"

    for dirpath, _, filenames in os.walk(root):
        if target in filenames:
            record_path = os.path.join(dirpath, target[:-4])  # remove .hea
            record = wfdb.rdrecord(record_path)
            return record.p_signal  # (5000, 12)

    raise FileNotFoundError(f"ECG ID {ecg_id} not found under {root}")

In [11]:
#Training array
X = []
y = []

for ecg_id in train_ids:
    ecg = load_ecg(ecg_id)
    X.append(ecg)
    y.append(get_class(ecg_id))

X = np.array(X)   # (N, 5000, 12)
y = np.array(y)   # (N,)


In [12]:
class ECGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [13]:
class ECG_LSTM(nn.Module):
    def __init__(self, input_size=12, hidden_size=64, num_classes=4):
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True
        )

        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out, _ = self.lstm(x)
        last_out = out[:, -1, :]
        logits = self.fc(last_out)
        return logits

In [14]:
counts = Counter(y)
weights = [1.0 / counts[i] for i in range(4)]
weights = torch.tensor(weights, dtype=torch.float32)

criterion = nn.CrossEntropyLoss(weight=weights)

In [15]:
dataset = ECGDataset(X, y)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

model = ECG_LSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    model.train()
    total_loss = 0

    for X_batch, y_batch in loader:
        optimizer.zero_grad()

        logits = model(X_batch)
        loss = criterion(logits, y_batch)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()


    print(f"Epoch {epoch+1} | Loss: {total_loss:.4f}")
# training is DONE here
torch.save(model.state_dict(), "final_model.pt")



Epoch 1 | Loss: 32.0717
Epoch 2 | Loss: 31.8723
Epoch 3 | Loss: 31.8126
Epoch 4 | Loss: 31.7534
Epoch 5 | Loss: 31.6539
Epoch 6 | Loss: 31.5467
Epoch 7 | Loss: 31.3019
Epoch 8 | Loss: 31.3382
Epoch 9 | Loss: 31.3360
Epoch 10 | Loss: 31.1882
Epoch 11 | Loss: 31.0542
Epoch 12 | Loss: 30.9654
Epoch 13 | Loss: 30.8390
Epoch 14 | Loss: 30.7120
Epoch 15 | Loss: 30.7056
Epoch 16 | Loss: 30.4555
Epoch 17 | Loss: 30.3701
Epoch 18 | Loss: 29.8953
Epoch 19 | Loss: 29.9741
Epoch 20 | Loss: 29.9096


In [17]:
model = ECG_LSTM().to(device)
model.load_state_dict(torch.load("final_model.pt", map_location=device))
model.eval()
print("Model device:", next(model.parameters()).device)


Model device: cuda:0


  model.load_state_dict(torch.load("final_model.pt", map_location=device))


In [18]:
def normalize(ecg):
    mean = ecg.mean(axis=0)
    std = ecg.std(axis=0) + 1e-8
    return (ecg - mean) / std


In [44]:
  # example ECG ID
ecg_id = 636

if ecg_id not in norm_set:
    print(f"{ecg_id} is NOT in norm_set")
else:
    print(f"{ecg_id} IS in norm_set")


636 is NOT in norm_set


In [45]:

ecg = load_ecg(ecg_id)
ecg = normalize(ecg)

X = torch.tensor(ecg, dtype=torch.float32)
X = X.unsqueeze(0)   # (1, 5000, 12)
X = X.to(device)


In [46]:
with torch.no_grad():
    logits = model(X)
    probs = torch.softmax(logits, dim=1)
    pred = torch.argmax(probs, dim=1).item()


In [47]:
label_map = {
    0: "NORM",
    1: "AFIB",
    2: "AFLT",
    3: "OTHER"
}

print("Predicted rhythm:", label_map[pred])
print("Class probabilities:", probs.cpu().numpy())


Predicted rhythm: AFLT
Class probabilities: [[0.37144315 0.0950818  0.4175679  0.11590715]]
