## This file contains the learning step of LSTM implementation.
### Structure of LSTM
- Linear Part
- Short-term memory (hidden state)
- Long-term memory (cell state)
- Non-linear part


In [134]:
import os
import pandas as pd
import ast
import random
from collections import Counter
import wfdb
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
dff = pd.read_csv("../data/ptbxl_database.csv")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


print(torch.__file__)
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

Using device: cuda
c:\Users\arjan\Documents\GitHub\SEARCH_AF_detection_OsloMet_BachelorGroup\venv\Lib\site-packages\torch\__init__.py
2.5.1+cu121
12.1
True


In [135]:
#Ensure if patient_id have different ECG records, all records are in the same set
patient_ids = dff["patient_id"].value_counts()
multiple_recordings = patient_ids[patient_ids > 1]
print("Total patients:", dff["patient_id"].nunique())
print("Patients with multiple recordings:", len(multiple_recordings))
print("Max recordings for one patient:", multiple_recordings.max())

def get_label(scp_codes):
    if "AFIB" in scp_codes:
        return 1
    if "NORM" in scp_codes:
        return 0
    if "AFLT" in scp_codes:
        return 2
    
    if "NDT" in scp_codes:
        return 4
    if "NST_" in scp_codes:
        return 5
    if "SVARR" in scp_codes:
        return 6
    if "SVTAC" in scp_codes:
        return 7
    if "PAC" in scp_codes:
        return 8
    return None

dff["label"] = dff["scp_codes"].apply(lambda x: get_label(ast.literal_eval(x)))
label_counts_per_patient = dff.groupby("patient_id")["label"].nunique()
patients_with_label_change = label_counts_per_patient[label_counts_per_patient > 1]

print("Patients with multiple labels:", len(patients_with_label_change))
print(
    "Percentage:",
    f"{round(100 * len(patients_with_label_change) / dff['patient_id'].nunique(), 1)}%"
)
TARGET_LABELS = {"NORM", "AFIB", "AFLT"}
def extract_labels(scp_codes):
    codes = ast.literal_eval(scp_codes)
    return set(codes.keys()) & TARGET_LABELS

dff["target_labels"] = dff["scp_codes"].apply(extract_labels)
dff_target = dff[dff["target_labels"].apply(len) > 0]


patient_groups = dff_target.groupby("patient_id")["target_labels"]

def has_repeated_label(label_sets):
    # label_sets = list of sets, one per ECG
    all_labels = []
    for s in label_sets:
        all_labels.extend(list(s))

    counts = Counter(all_labels)

    # label appears in 2 or more recordings
    return any(v >= 2 for v in counts.values())
patients_with_repeated_label = patient_groups.apply(has_repeated_label)
# patients with more than one ECG
patient_target_counts = dff_target["patient_id"].value_counts()

patients_with_multiple_target_ecgs = patient_target_counts[
    patient_target_counts > 1
].index
dff_final = dff_target[
    dff_target["patient_id"].isin(patients_with_multiple_target_ecgs)
]




Total patients: 18885
Patients with multiple recordings: 2127
Max recordings for one patient: 10
Patients with multiple labels: 269
Percentage: 1.4%


In [138]:
output_path = "ptbxl_target_labels_patients_multiple_ecgs.csv"

dff_final.to_csv(output_path, index=False)

print("Saved file:", output_path)
print("Rows saved:", len(dff_final))
print("Patients saved:", dff_final["patient_id"].nunique())




Saved file: ptbxl_target_labels_patients_multiple_ecgs.csv
Rows saved: 1673
Patients saved: 722


In [139]:
patient_record_counts = dff_final["patient_id"].value_counts().reset_index()
patient_record_counts.columns = ["patient_id", "num_recordings"]
print(patient_record_counts.head)

<bound method NDFrame.head of      patient_id  num_recordings
0        8304.0               8
1       15765.0               7
2       17542.0               7
3       12743.0               6
4       13619.0               5
..          ...             ...
717       449.0               2
718     21409.0               2
719     14496.0               2
720      8840.0               2
721     14330.0               2

[722 rows x 2 columns]>


In [140]:
speicific_patient_id = 15765.0
specific_patient_records = dff_final[dff_final["patient_id"] == speicific_patient_id]
print( "Patient ID:", speicific_patient_id)
print("Number of recordings:", len(specific_patient_records))
print(specific_patient_records)


Patient ID: 15765.0
Number of recordings: 7
       ecg_id  patient_id   age  sex  height  weight  nurse  site      device  \
20434   20435     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   
20448   20449     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   
20452   20453     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   
20457   20458     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   
20464   20465     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   
20480   20481     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   
20494   20495     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   

            recording_date  ... static_noise burst_noise electrodes_problems  \
20434  1999-08-06 11:25:08  ...          NaN         NaN                 NaN   
20448  1999-08-07 13:29:28  ...          NaN         NaN                 NaN   
20452  1999-08-08 18:07:02  ...          NaN         NaN           

In [141]:
norm_ids  = []
afib_ids  = []
aflt_ids  = []
other_ids = []


In [142]:
dataset_root = "../data/"
df = pd.read_csv(os.path.join(dataset_root, "ptbxl_database.csv"))
for _, row in df.iterrows():
    
    scp_codes = ast.literal_eval(row["scp_codes"])
    label = get_label(scp_codes)
    ecg_id = row["ecg_id"]
    p_id = row["patient_id"]
    if p_id  in dff_final["patient_id"].values:
        continue

    if label is None:
        continue


    if label == 0:
        norm_ids.append(ecg_id)

    elif label == 1:
        afib_ids.append(ecg_id)

    elif label == 2:
        aflt_ids.append(ecg_id)

    else:
        other_ids.append(ecg_id)


In [143]:
print("NORM :", len(norm_ids))
print("AFIB :", len(afib_ids))
print("AFLT :", len(aflt_ids))
print("OTHER:", len(other_ids))
print("TOTAL:", len(norm_ids) + len(afib_ids) + len(aflt_ids) + len(other_ids))

'''
before skip unknown labels:
NORM : 9491
AFIB : 1514
AFLT : 56
OTHER: 10776
TOTAL: 21837
'''

NORM : 8329
AFIB : 1027
AFLT : 32
OTHER: 2394
TOTAL: 11782


'\nbefore skip unknown labels:\nNORM : 9491\nAFIB : 1514\nAFLT : 56\nOTHER: 10776\nTOTAL: 21837\n'

In [144]:
assert set(norm_ids).isdisjoint(afib_ids)
assert set(norm_ids).isdisjoint(aflt_ids)
assert set(afib_ids).isdisjoint(aflt_ids)


### Lighweight test on 100 record per each type

In [146]:
def sample_ids(ids, n, seed=42):
    random.seed(seed)
    if len(ids) <= n:
        return ids.copy()
    return random.sample(ids, n)

In [147]:
norm_sampled_ids = sample_ids(norm_ids, 100)
afib_sampled_ids = sample_ids(afib_ids, 101)
aflt_sampled_ids = aflt_ids.copy()
other_sampled_ids = sample_ids(other_ids, 103)
print("Sampled NORM :", len(norm_sampled_ids))
print("Sampled AFIB :", len(afib_sampled_ids))
print("Sampled AFLT :", len(aflt_sampled_ids))
print("Sampled OTHER:", len(other_sampled_ids))
print("Normal example IDs:", norm_sampled_ids, "\nAFBI example IDs:", afib_sampled_ids[:10], "\nAFLT example IDs:", aflt_sampled_ids[:10], "\nOTHER example IDs:", other_sampled_ids[:10])

Sampled NORM : 100
Sampled AFIB : 101
Sampled AFLT : 32
Sampled OTHER: 103
Normal example IDs: [3782, 845, 10650, 9489, 8494, 4936, 3491, 2933, 17212, 1092, 1008, 3149, 8303, 8934, 21667, 897, 7555, 17109, 8394, 18398, 10787, 193, 5896, 17246, 13631, 10770, 5723, 8174, 13463, 3481, 3114, 15282, 3261, 14306, 13756, 10256, 1459, 18992, 4285, 15173, 2652, 11528, 14416, 7205, 2351, 1538, 8702, 11331, 2689, 8938, 3430, 15291, 10776, 18623, 14558, 5980, 14775, 14176, 7978, 10335, 2398, 6274, 9483, 6006, 19144, 15248, 10434, 8341, 12877, 1891, 8757, 1098, 12485, 16237, 10363, 2263, 8040, 12466, 8087, 21290, 16009, 18968, 5091, 10267, 4939, 9584, 10188, 17570, 16159, 14438, 8330, 4873, 20879, 3053, 1589, 3740, 5587, 5904, 17216, 2166] 
AFBI example IDs: [5390, 1419, 12872, 11755, 21804, 3423, 16756, 2527, 15586, 16818] 
AFLT example IDs: [449, 706, 858, 1773, 2430, 2739, 3505, 4874, 4885, 6179] 
OTHER example IDs: [3809, 907, 9749, 8593, 7936, 4665, 3502, 20169, 3069, 15103]


In [148]:
train_ids = (
    norm_sampled_ids + afib_sampled_ids + aflt_sampled_ids + other_sampled_ids
)


afib_set  = set(afib_ids)
aflt_set  = set(aflt_ids)
norm_set  = set(norm_ids)
other_set = set(other_ids)


def get_class(ecg_id):
    if ecg_id in afib_set:
        return 1
    if ecg_id in aflt_set:
        return 2
    if ecg_id in norm_set:
        return 0
    if ecg_id in other_set:
        return 3

label = [get_class(ecg_id) for ecg_id in train_ids]
print("Class distribution in training set:", Counter(label))

Class distribution in training set: Counter({3: 103, 1: 101, 0: 100, 2: 32})


In [149]:
def load_ecg(ecg_id, root="../data/records500", suffix="_hr"):
    """
    Load PTB-XL ECG by ecg_id when files are named like:
    03462_hr.hea / 03462_hr.dat and stored in subdirectories.
    """
    target = f"{ecg_id:05d}{suffix}.hea"

    for dirpath, _, filenames in os.walk(root):
        if target in filenames:
            record_path = os.path.join(dirpath, target[:-4])  # remove .hea
            record = wfdb.rdrecord(record_path)
            return record.p_signal  # (5000, 12)

    raise FileNotFoundError(f"ECG ID {ecg_id} not found under {root}")

In [152]:
def normalize_ecg(ecg):
    mean = ecg.mean(axis=0)
    std = ecg.std(axis=0) + 1e-8
    return (ecg - mean) / std


In [153]:
#Training array
X = []
y = []

for ecg_id in train_ids:
    ecg = load_ecg(ecg_id)
    ecg = normalize_ecg(ecg)
    X.append(ecg)
    y.append(get_class(ecg_id))
    
for i in range(5):
    print(train_ids[i], y[i])


X = np.array(X)   # (N, 5000, 12)
y = np.array(y)   # (N,)
print("Unique labels:", sorted(set(y)))
print("dtype:", y.dtype)


3782 0
845 0
10650 0
9489 0
8494 0
Unique labels: [np.int64(0), np.int64(1), np.int64(2), np.int64(3)]
dtype: int64


In [154]:
class ECGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [155]:
class ECG_LSTM(nn.Module):
    def __init__(self, input_size=12, hidden_size=64, num_classes=4):
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True
        )

        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out, _ = self.lstm(x)
        last_out = out[:, -1, :]
        logits = self.fc(last_out)
        return logits

In [156]:
counts = Counter(y)
weights = [1.0 / counts[i] for i in range(4)]
weights = torch.tensor(weights, dtype=torch.float32)

criterion = nn.CrossEntropyLoss(weight=weights)

### Basic model LSTM with only 100 recording of each NROM, AFIB, AFLT, Others

In [None]:
dataset = ECGDataset(X, y)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

model = ECG_LSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1):
    model.train()
    total_loss = 0.0
    n_batches = 0

    for X_batch, y_batch in loader:
        optimizer.zero_grad()

        logits = model(X_batch)
        loss = criterion(logits, y_batch)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1


    
    avg_loss = total_loss / n_batches
    print(f"Epoch {epoch+1} | Avg loss: {avg_loss:.4f} Loss : {total_loss:.4f}")
# training is DONE here
torch.save(model.state_dict(), "final_model.pt")



Epoch 1 | Loss: 29.1686


In [158]:
model = ECG_LSTM().to(device)
model.load_state_dict(torch.load("final_model.pt", map_location=device))
model.eval()
print("Model device:", next(model.parameters()).device)


Model device: cuda:0


  model.load_state_dict(torch.load("final_model.pt", map_location=device))


In [159]:
def normalize(ecg):
    mean = ecg.mean(axis=0)
    std = ecg.std(axis=0) + 1e-8
    return (ecg - mean) / std


In [160]:
  # example ECG ID
ecg_id = 763

if ecg_id not in norm_set:
    print(f"{ecg_id} is NOT in norm_set")
else:
    print(f"{ecg_id} IS in norm_set")


763 is NOT in norm_set


In [161]:

ecg = load_ecg(ecg_id)
ecg = normalize(ecg)

X = torch.tensor(ecg, dtype=torch.float32)
X = X.unsqueeze(0)   # (1, 5000, 12)
X = X.to(device)


In [162]:
with torch.no_grad():
    logits = model(X)
    probs = torch.softmax(logits, dim=1)
    pred = torch.argmax(probs, dim=1).item()


In [163]:
label_map = {
    0: "NORM",
    1: "AFIB",
    2: "AFLT",
    3: "OTHER"
}

print("Predicted rhythm:", label_map[pred])
print("Class probabilities:", probs.cpu().numpy())


Predicted rhythm: OTHER
Class probabilities: [[0.26019028 0.22515209 0.24621683 0.2684408 ]]
