## This file contains the learning step of LSTM implementation.
### Structure of LSTM
- Linear Part
- Short-term memory (hidden state)
- Long-term memory (cell state)
- Non-linear part


In [54]:
import os
import pandas as pd
import ast
import random
from collections import Counter
import wfdb
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from pathlib import Path

from IPython.display import Markdown, display
dff = pd.read_csv("../../../data/ptbxl_database.csv")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


print(torch.__file__)
print(torch.__version__)
print(torch.version.cuda)
print(torch.cuda.is_available())

Using device: cuda
c:\Users\arjan\Documents\GitHub\SEARCH_AF_detection_OsloMet_BachelorGroup\venv\Lib\site-packages\torch\__init__.py
2.5.1+cu121
12.1
True


In [4]:
def build_ecg_index(root="../../../data/records500", suffix="_hr"):
    index = {}
    for p in Path(root).rglob(f"*{suffix}.hea"):
        ecg_id = int(p.stem.replace(suffix, ""))
        index[ecg_id] = str(p.with_suffix(""))
    return index

In [5]:
#Ensure if patient_id have different ECG records, all records are in the same set
patient_ids = dff["patient_id"].value_counts()
multiple_recordings = patient_ids[patient_ids > 1]
print("Total patients:", dff["patient_id"].nunique())
print("Patients with multiple recordings:", len(multiple_recordings))
print("Max recordings for one patient:", multiple_recordings.max())

def get_label(scp_codes):
    if "AFIB" in scp_codes:
        return 1
    if "NORM" in scp_codes:
        return 0
    if "AFLT" in scp_codes:
        return 2
    
    if "NDT" in scp_codes:
        return 4
    if "NST_" in scp_codes:
        return 5
    if "SVARR" in scp_codes:
        return 6
    if "SVTAC" in scp_codes:
        return 7
    if "PAC" in scp_codes:
        return 8
    return None

dff["label"] = dff["scp_codes"].apply(lambda x: get_label(ast.literal_eval(x)))
label_counts_per_patient = dff.groupby("patient_id")["label"].nunique()
patients_with_label_change = label_counts_per_patient[label_counts_per_patient > 1]

print("Patients with multiple labels:", len(patients_with_label_change))
print(
    "Percentage:",
    f"{round(100 * len(patients_with_label_change) / dff['patient_id'].nunique(), 1)}%"
)
TARGET_LABELS = {"NORM", "AFIB", "AFLT"}
def extract_labels(scp_codes):
    codes = ast.literal_eval(scp_codes)
    return set(codes.keys()) & TARGET_LABELS

dff["target_labels"] = dff["scp_codes"].apply(extract_labels)
dff_target = dff[dff["target_labels"].apply(len) > 0]


patient_groups = dff_target.groupby("patient_id")["target_labels"]

def has_repeated_label(label_sets):
    # label_sets = list of sets, one per ECG
    all_labels = []
    for s in label_sets:
        all_labels.extend(list(s))

    counts = Counter(all_labels)

    # label appears in 2 or more recordings
    return any(v >= 2 for v in counts.values())
patients_with_repeated_label = patient_groups.apply(has_repeated_label)
# patients with more than one ECG
patient_target_counts = dff_target["patient_id"].value_counts()

patients_with_multiple_target_ecgs = patient_target_counts[
    patient_target_counts > 1
].index
dff_final = dff_target[
    dff_target["patient_id"].isin(patients_with_multiple_target_ecgs)
]




Total patients: 18885
Patients with multiple recordings: 2127
Max recordings for one patient: 10
Patients with multiple labels: 269
Percentage: 1.4%


In [6]:
output_path = "ptbxl_target_labels_patients_multiple_ecgs.csv"

dff_final.to_csv(output_path, index=False)

print("Saved file:", output_path)
print("Rows saved:", len(dff_final))
print("Patients saved:", dff_final["patient_id"].nunique())




Saved file: ptbxl_target_labels_patients_multiple_ecgs.csv
Rows saved: 1673
Patients saved: 722


In [7]:
patient_record_counts = dff_final["patient_id"].value_counts().reset_index()
patient_record_counts.columns = ["patient_id", "num_recordings"]
print(patient_record_counts.head)

<bound method NDFrame.head of      patient_id  num_recordings
0        8304.0               8
1       15765.0               7
2       17542.0               7
3       12743.0               6
4       13619.0               5
..          ...             ...
717       449.0               2
718     21409.0               2
719     14496.0               2
720      8840.0               2
721     14330.0               2

[722 rows x 2 columns]>


In [8]:
speicific_patient_id = 15765.0
specific_patient_records = dff_final[dff_final["patient_id"] == speicific_patient_id]
print( "Patient ID:", speicific_patient_id)
print("Number of recordings:", len(specific_patient_records))
print(specific_patient_records)


Patient ID: 15765.0
Number of recordings: 7
       ecg_id  patient_id   age  sex  height  weight  nurse  site      device  \
20434   20435     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   
20448   20449     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   
20452   20453     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   
20457   20458     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   
20464   20465     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   
20480   20481     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   
20494   20495     15765.0  15.0    0     NaN     NaN    0.0   0.0  CS100    3   

            recording_date  ... static_noise burst_noise electrodes_problems  \
20434  1999-08-06 11:25:08  ...          NaN         NaN                 NaN   
20448  1999-08-07 13:29:28  ...          NaN         NaN                 NaN   
20452  1999-08-08 18:07:02  ...          NaN         NaN           

In [9]:
norm_ids  = []
afib_ids  = []
aflt_ids  = []
other_ids = []


In [10]:
dataset_root = "../../../data/"
df = pd.read_csv(os.path.join(dataset_root, "ptbxl_database.csv"))
for _, row in df.iterrows():
    
    scp_codes = ast.literal_eval(row["scp_codes"])
    label = get_label(scp_codes)
    ecg_id = row["ecg_id"]
    p_id = row["patient_id"]
    if p_id  in dff_final["patient_id"].values:
        continue

    if label is None:
        continue


    if label == 0:
        norm_ids.append(ecg_id)

    elif label == 1:
        afib_ids.append(ecg_id)

    elif label == 2:
        aflt_ids.append(ecg_id)

    else:
        other_ids.append(ecg_id)


In [11]:
print("NORM :", len(norm_ids))
print("AFIB :", len(afib_ids))
print("AFLT :", len(aflt_ids))
print("OTHER:", len(other_ids))
print("TOTAL:", len(norm_ids) + len(afib_ids) + len(aflt_ids) + len(other_ids))

'''
before skip unknown labels:
NORM : 9491
AFIB : 1514
AFLT : 56
OTHER: 10776
TOTAL: 21837
'''

NORM : 8329
AFIB : 1027
AFLT : 32
OTHER: 2394
TOTAL: 11782


'\nbefore skip unknown labels:\nNORM : 9491\nAFIB : 1514\nAFLT : 56\nOTHER: 10776\nTOTAL: 21837\n'

In [12]:
assert set(norm_ids).isdisjoint(afib_ids)
assert set(norm_ids).isdisjoint(aflt_ids)
assert set(afib_ids).isdisjoint(aflt_ids)


### Lighweight test on 100 record per each type

In [14]:
def sample_ids(ids, n, seed=42):
    random.seed(seed)
    if len(ids) <= n:
        return ids.copy()
    return random.sample(ids, n)

In [16]:
norm_sampled_ids    = sample_ids(norm_ids, 1500)
afib_sampled_ids    = sample_ids(afib_ids, 501)
aflt_sampled_ids    = sample_ids(aflt_ids, 15)
other_sampled_ids   = sample_ids(other_ids, 1503)
print("Sampled NORM :", len(norm_sampled_ids))
print("Sampled AFIB :", len(afib_sampled_ids))
print("Sampled AFLT :", len(aflt_sampled_ids))
print("Sampled OTHER:", len(other_sampled_ids))
print("Normal example IDs:", norm_sampled_ids[:10], "\nAFBI example IDs:", afib_sampled_ids[:10], "\nAFLT example IDs:", aflt_sampled_ids[:10], "\nOTHER example IDs:", other_sampled_ids[:10])

Sampled NORM : 1500
Sampled AFIB : 501
Sampled AFLT : 15
Sampled OTHER: 1503
Normal example IDs: [3782, 845, 10650, 9489, 8494, 4936, 3491, 2933, 17212, 1092] 
AFBI example IDs: [5390, 1419, 12872, 11755, 21804, 3423, 16756, 2527, 15586, 16818] 
AFLT example IDs: [4874, 449, 12986, 4885, 19961, 15106, 2430, 16401, 1773, 10966] 
OTHER example IDs: [3809, 907, 9749, 8593, 7936, 4665, 3502, 20169, 3069, 15103]


In [17]:
train_ids = (
    norm_sampled_ids + afib_sampled_ids + aflt_sampled_ids + other_sampled_ids
)


afib_set  = set(afib_ids)
aflt_set  = set(aflt_ids)
norm_set  = set(norm_ids)
other_set = set(other_ids)


def get_class(ecg_id):
    if ecg_id in afib_set:  return 1
    if ecg_id in aflt_set:  return 2
    if ecg_id in norm_set:  return 0
    if ecg_id in other_set: return 3
    raise ValueError(f"ecg_id {ecg_id} not found in any class set")


label = [get_class(ecg_id) for ecg_id in train_ids]
print("Class distribution in training set:", Counter(label))

Class distribution in training set: Counter({3: 1503, 0: 1500, 1: 501, 2: 15})


In [18]:
ECG_INDEX = build_ecg_index()


In [395]:
def load_ecg(ecg_id, root="../data/records500", suffix="_hr"):
    """
    Load PTB-XL ECG by ecg_id when files are named like:
    03462_hr.hea / 03462_hr.dat and stored in subdirectories.
    """
    target = f"{ecg_id:05d}{suffix}.hea"

    for dirpath, _, filenames in os.walk(root):
        if target in filenames:
            record_path = os.path.join(dirpath, target[:-4])  # remove .hea
            record = wfdb.rdrecord(record_path)
            return record.p_signal  # (5000, 12)

    raise FileNotFoundError(f"ECG ID {ecg_id} not found under {root}")

In [19]:
def load_ecg_fast(ecg_id):
    record = wfdb.rdrecord(ECG_INDEX[ecg_id])
    return record.p_signal


In [20]:
def normalize_ecg(ecg):
    mean = ecg.mean(axis=0)
    std = ecg.std(axis=0) + 1e-8
    return (ecg - mean) / std


In [21]:
#Training array
X = []
y = []

for ecg_id in train_ids:
    ecg = load_ecg_fast(ecg_id)
    ecg = normalize_ecg(ecg)
    X.append(ecg)
    y.append(get_class(ecg_id))
    
for i in range(len(train_ids)):
    print(train_ids[i], y[i])


X = np.array(X)   # (N, 5000, 12)
y = np.array(y)   # (N,)
print("Unique labels:", sorted(set(y)))
print("dtype:", y.dtype)


3782 0
845 0
10650 0
9489 0
8494 0
4936 0
3491 0
2933 0
17212 0
1092 0
1008 0
3149 0
8303 0
8934 0
21667 0
897 0
7555 0
17109 0
8394 0
18398 0
10787 0
193 0
5896 0
17246 0
13631 0
10770 0
5723 0
8174 0
13463 0
3481 0
3114 0
15282 0
3261 0
14306 0
13756 0
10256 0
1459 0
18992 0
4285 0
15173 0
2652 0
11528 0
14416 0
7205 0
2351 0
1538 0
8702 0
11331 0
2689 0
8938 0
3430 0
15291 0
10776 0
18623 0
14558 0
5980 0
14775 0
14176 0
7978 0
10335 0
2398 0
6274 0
9483 0
6006 0
19144 0
15248 0
10434 0
8341 0
12877 0
1891 0
8757 0
1098 0
12485 0
16237 0
10363 0
2263 0
8040 0
12466 0
8087 0
21290 0
16009 0
18968 0
5091 0
10267 0
4939 0
9584 0
10188 0
17570 0
16159 0
14438 0
8330 0
4873 0
20879 0
3053 0
1589 0
3740 0
5587 0
5904 0
17216 0
2166 0
15525 0
15372 0
19416 0
9783 0
359 0
3909 0
10329 0
13630 0
3806 0
11534 0
17788 0
5824 0
18634 0
80 0
10213 0
21385 0
6583 0
3646 0
11791 0
7556 0
21492 0
14957 0
5954 0
11 0
12862 0
20565 0
639 0
3813 0
14484 0
12158 0
9274 0
1974 0
9330 0
2651 0
2884 0
203

In [22]:
class ECGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [23]:
class ECG_LSTM(nn.Module):
    def __init__(self, input_size=12, hidden_size=128, num_classes=4):
        super().__init__()

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=2,
            dropout=0.3,
            batch_first=True
        )

        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out, _ = self.lstm(x)
        last_out = out[:, -1, :]
        logits = self.fc(last_out)
        return logits

In [24]:
counts = Counter(y)
weights = torch.tensor([1.0 / counts[i] for i in range(4)], dtype=torch.float32).to(device)
weights = torch.tensor(weights, dtype=torch.float32)

criterion = nn.CrossEntropyLoss(weight=weights)

  weights = torch.tensor(weights, dtype=torch.float32)


### Basic model LSTM with only 100 recording of each NROM, AFIB, AFLT, Others

In [25]:
dataset = ECGDataset(X, y)
loader = DataLoader(dataset, batch_size=16, shuffle=True)
model = ECG_LSTM().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for X_batch, y_batch in loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        optimizer.zero_grad()
        logits = model(X_batch)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        correct += (preds == y_batch).sum().item()
        total += y_batch.size(0)

    avg_loss = total_loss / len(loader)
    acc = correct / total
    print(f"Epoch {epoch+1} | Avg loss: {avg_loss:.4f} | Train acc: {acc:.3f}")

# training is DONE here
torch.save(model.state_dict(), "final_model.pt")



Epoch 1 | Avg loss: 1.2769 | Train acc: 0.373
Epoch 2 | Avg loss: 1.2550 | Train acc: 0.393
Epoch 3 | Avg loss: 1.2631 | Train acc: 0.372
Epoch 4 | Avg loss: 1.2417 | Train acc: 0.388
Epoch 5 | Avg loss: 1.2599 | Train acc: 0.403
Epoch 6 | Avg loss: 1.2531 | Train acc: 0.406
Epoch 7 | Avg loss: 1.2522 | Train acc: 0.378
Epoch 8 | Avg loss: 1.2433 | Train acc: 0.369
Epoch 9 | Avg loss: 1.2463 | Train acc: 0.399
Epoch 10 | Avg loss: 1.2402 | Train acc: 0.414
Epoch 11 | Avg loss: 1.2210 | Train acc: 0.438
Epoch 12 | Avg loss: 1.2180 | Train acc: 0.423
Epoch 13 | Avg loss: 1.2116 | Train acc: 0.427
Epoch 14 | Avg loss: 1.2021 | Train acc: 0.420
Epoch 15 | Avg loss: 1.1992 | Train acc: 0.415
Epoch 16 | Avg loss: 1.1756 | Train acc: 0.429
Epoch 17 | Avg loss: 1.1782 | Train acc: 0.444
Epoch 18 | Avg loss: 1.2178 | Train acc: 0.406
Epoch 19 | Avg loss: 1.2264 | Train acc: 0.407
Epoch 20 | Avg loss: 1.2067 | Train acc: 0.407


In [26]:
model = ECG_LSTM().to(device)
model.load_state_dict(torch.load("final_model.pt", map_location=device))
model.eval()
print("Model device:", next(model.parameters()).device)


Model device: cuda:0


  model.load_state_dict(torch.load("final_model.pt", map_location=device))


In [27]:
def normalize(ecg):
    mean = ecg.mean(axis=0)
    std = ecg.std(axis=0) + 1e-8
    return (ecg - mean) / std


In [32]:
  # example ECG ID
ecg_id = 203

if ecg_id not in norm_set:
    print(f"{ecg_id} is NOT in norm_set")
else:
    print(f"{ecg_id} IS in norm_set")


203 is NOT in norm_set


In [33]:

ecg = load_ecg_fast(ecg_id)
ecg = normalize(ecg)

X = torch.tensor(ecg, dtype=torch.float32)
X = X.unsqueeze(0)   # (1, 5000, 12)
X = X.to(device)


In [34]:
with torch.no_grad():
    logits = model(X)
    probs = torch.softmax(logits, dim=1)
    pred = torch.argmax(probs, dim=1).item()


In [35]:
label_map = {
    0: "NORM",
    1: "AFIB",
    2: "AFLT",
    3: "OTHER"
}

print("Predicted rhythm:", label_map[pred])
print("Class probabilities:", probs.cpu().numpy())


Predicted rhythm: OTHER
Class probabilities: [[0.27078995 0.30481562 0.09419108 0.33020332]]


## Testing

In [36]:
# As previously defined distinguishing patients with multiple ECG records.
exclude_patients = set(dff_final["patient_id"])
dff_excluded = dff[~dff["patient_id"].isin(exclude_patients)]
print("Orginal records:", len(dff))
print("Total patients in dff_final:", len(dff_final["patient_id"].unique()))


patients = dff_excluded["patient_id"].unique()
print("Total patients after exclusion:", len(patients))
print("Total ECG records after exclusion:", len(dff_excluded["ecg_id"].unique()))
#Making the trained data into combined set
train_ecg_set = set(train_ids)
print("Training ECGs:", len(train_ecg_set))
#exctract the traning data from the testing data
dff_test = dff_excluded[~dff_excluded["ecg_id"].isin(train_ecg_set)].copy()
patients_test = dff_test["patient_id"].unique()
print("Remaining patients for testing:", len(patients_test))
print("Remaining ECGs for testing:", len(dff_test))


Orginal records: 21837
Total patients in dff_final: 722
Total patients after exclusion: 18163
Total ECG records after exclusion: 20032
Training ECGs: 3519
Remaining patients for testing: 15060
Remaining ECGs for testing: 16513


In [37]:
test_norm, test_afib, test_aflt, test_other = [], [], [], []

for _, row in dff_test.iterrows():
    scp_codes = ast.literal_eval(row["scp_codes"])
    label = get_label(scp_codes)
    ecg_id = int(row["ecg_id"])

    if label is None:
        continue

    if label == 0:
        test_norm.append(ecg_id)
    elif label == 1:
        test_afib.append(ecg_id)
    elif label == 2:
        test_aflt.append(ecg_id)
    else:
        test_other.append(ecg_id)
print("Test NORM :", len(test_norm))
print("Test AFIB :", len(test_afib))
print("Test AFLT :", len(test_aflt))
print("Test OTHER:", len(test_other))

Test NORM : 6829
Test AFIB : 526
Test AFLT : 17
Test OTHER: 891


In [38]:
test_ids = (
    sample_ids(test_norm, 1500) +
    sample_ids(test_afib, 500) +
    sample_ids(test_aflt, 15)  +                 
    sample_ids(test_other, 500)
)


In [39]:
train_patients, test_patients = train_test_split(
    patients_test,
    test_size=0.2,
    random_state=42
)

In [40]:
assert set(test_ids).isdisjoint(set(train_ids))
print("Final test ECGs:", len(test_ids))

Final test ECGs: 2515


In [41]:
train_df = dff_test[dff_test["patient_id"].isin(train_patients)]
test_df  = dff_test[dff_test["patient_id"].isin(test_patients)]
print("Train records:", len(train_df))
print("Test records :", len(test_df))

Train records: 13236
Test records : 3277


In [42]:
# Safety checks (must pass)
assert set(train_df["patient_id"]).isdisjoint(set(test_df["patient_id"]))

In [43]:
model.eval()

y_true, y_pred = [], []

with torch.no_grad():
    for ecg_id in test_ids:
        ecg = normalize_ecg(load_ecg_fast(ecg_id))
        X = torch.from_numpy(ecg).float().unsqueeze(0).to(device)

        pred = model(X).argmax(dim=1).item()
        true = get_class(ecg_id)

        y_true.append(true)
        y_pred.append(pred)

# Convert to numpy
y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Accuracy
accuracy = (y_true == y_pred).mean()
print(f"Test Accuracy: {accuracy:.4f} ({(y_true == y_pred).sum()}/{len(y_true)})")

Test Accuracy: 0.3129 (787/2515)


### Single testing

In [49]:
ecg_id = test_ids[500]

ecg = normalize_ecg(load_ecg_fast(ecg_id))
X = torch.tensor(ecg, dtype=torch.float32).unsqueeze(0).to(device)

with torch.no_grad():
    pred = model(X).argmax(dim=1).item()

print("True label:", label_map[get_class(ecg_id)])
print("Predicted:", label_map[pred])


True label: NORM
Predicted: NORM


### Testing accuracy

In [50]:
label_names = ["NORM", "AFIB", "AFLT", "OTHER"]

print(classification_report(
    y_true,
    y_pred,
    labels=[0, 1, 2, 3],
    target_names=label_names,
    digits=4
))


              precision    recall  f1-score   support

        NORM     0.7564    0.1760    0.2856      1500
        AFIB     0.3280    0.3260    0.3270       500
        AFLT     0.0000    0.0000    0.0000        15
       OTHER     0.2194    0.7200    0.3363       500

    accuracy                         0.3129      2515
   macro avg     0.3259    0.3055    0.2372      2515
weighted avg     0.5600    0.3129    0.3022      2515



In [51]:
labels = [0, 1, 2, 3]
label_names = ["NORM", "AFIB", "AFLT", "OTHER"]

cm = confusion_matrix(y_true, y_pred, labels=labels)

cm_df = pd.DataFrame(
    cm,
    index=label_names,
    columns=label_names
)

print("Confusion Matrix:")
print(cm_df)

Confusion Matrix:
       NORM  AFIB  AFLT  OTHER
NORM    264   235    19    982
AFIB     44   163     4    289
AFLT      1     4     0     10
OTHER    40    95     5    360


In [52]:

print("Test label distribution:", Counter(y_true))


Test label distribution: Counter({np.int64(0): 1500, np.int64(1): 500, np.int64(3): 500, np.int64(2): 15})


In [55]:

def confusion_matrix_to_markdown(cm, label_names):
    cm = np.asarray(cm, dtype=int)

    total = cm.sum()
    correct = np.trace(cm)
    accuracy = correct / total if total > 0 else 0.0

    supports = cm.sum(axis=1)

    md = []

    md.append("## Model Evaluation Results (Test Set)\n")
    md.append(
        "The confusion matrix shows how the trained model classifies unseen ECG recordings "
        "across four rhythm classes.\n"
    )

    # Overall performance
    md.append("### Overall Performance")
    md.append(f"- **Total test samples:** {total}")
    md.append(f"- **Overall accuracy:** **{accuracy*100:.1f}%**\n")

    # Dataset composition
    md.append("### Test Set Composition")
    for i, name in enumerate(label_names):
        md.append(f"- **{name}:** {supports[i]} samples")
    md.append("")

    # Per-class metrics
    md.append("### Per-class Results\n")

    for i, name in enumerate(label_names):
        support = supports[i]
        tp = cm[i, i]
        recall = tp / support if support > 0 else 0.0

        md.append(f"- **{name}**")
        md.append(f"  - Correctly classified: **{tp} / {support}**")
        md.append(f"  - **Recall:** **{recall*100:.1f}%**")

        # Misclassifications
        row = cm[i].copy()
        row[i] = 0

        if row.sum() > 0:
            top_idx = np.argsort(row)[::-1]
            errors = []
            for j in top_idx:
                if row[j] > 0:
                    pct = row[j] / support * 100
                    errors.append(f"{label_names[j]} ({pct:.1f}%)")
                if len(errors) == 2:
                    break
            md.append(f"  - Most errors were misclassified as **{', '.join(errors)}**.")
        md.append("")

    return "\n".join(md)


# =======================
# Generate markdown
# =======================

markdown_report = confusion_matrix_to_markdown(cm, label_names)

display(Markdown(markdown_report))


## Model Evaluation Results (Test Set)

The confusion matrix shows how the trained model classifies unseen ECG recordings across four rhythm classes.

### Overall Performance
- **Total test samples:** 2515
- **Overall accuracy:** **31.3%**

### Test Set Composition
- **NORM:** 1500 samples
- **AFIB:** 500 samples
- **AFLT:** 15 samples
- **OTHER:** 500 samples

### Per-class Results

- **NORM**
  - Correctly classified: **264 / 1500**
  - **Recall:** **17.6%**
  - Most errors were misclassified as **OTHER (65.5%), AFIB (15.7%)**.

- **AFIB**
  - Correctly classified: **163 / 500**
  - **Recall:** **32.6%**
  - Most errors were misclassified as **OTHER (57.8%), NORM (8.8%)**.

- **AFLT**
  - Correctly classified: **0 / 15**
  - **Recall:** **0.0%**
  - Most errors were misclassified as **OTHER (66.7%), AFIB (26.7%)**.

- **OTHER**
  - Correctly classified: **360 / 500**
  - **Recall:** **72.0%**
  - Most errors were misclassified as **AFIB (19.0%), NORM (8.0%)**.
