# CNN Model Training Notebook

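This notebook mirrors the experimental workflow of the RNN notebook for the CNN architecture: a random hyper-parameter search evaluated with stratified 5-fold cross-validation, training of the final model with the selected hyper-parameters, and evaluation on the held-out MIT-BIH test set.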

In [None]:
import random
import time

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchmetrics import AUROC, F1Score, Precision, Recall

from sklearn.metrics import classification_report
from sklearn.model_selection import ParameterSampler, StratifiedKFold, train_test_split

from utils.cnn_models import ECG_CNN_Classifier
from utils.data import calculate_class_weights, split_x_y
from utils.logging import log_to_csv, log_to_json
from utils.preprocessing import Preprocessing
from utils.torch_classes import ECG_Dataset, EarlyStopping
from utils.train import test_loop, train_and_eval_model, val_loop


In [None]:
# Ensure reproducibility: seed every RNG the notebook relies on (Python,
# NumPy, PyTorch) and force deterministic cuDNN behaviour.
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA device generators

torch.backends.cudnn.deterministic = True
# cuDNN autotuning benchmarks candidate kernels and may pick different ones
# between runs; disable it so `deterministic` actually gives repeatable runs.
torch.backends.cudnn.benchmark = False


In [None]:
# Load preprocessed training and validation data.
# np.load on an .npz archive returns an NpzFile that keeps the underlying
# file open until closed; indexing it reads the array into memory, so both
# arrays are extracted inside a `with` block and the handle is released.
train_val_path = "data/ecg_preprocessed_train_val.npz"

with np.load(train_val_path) as train_val_data:
    X = train_val_data["X"]  # model inputs
    y = train_val_data["y"]  # class labels (used as integer indices below)


In [None]:
# Cross-validation splitter shared by every experiment: K stratified folds,
# shuffled with a fixed seed so each configuration sees the same splits.
K = 5
kfold = StratifiedKFold(
    n_splits=K,
    shuffle=True,
    random_state=42,
)


In [None]:
# Hyper-parameter search space for the CNN experiments.
# Every entry is a list of discrete candidates, so ParameterSampler draws
# EXPERIMENTS distinct configurations from the full grid (sampling without
# replacement).
EXPERIMENTS = 20

param_grid = {
    "conv_channels": [
        (32, 64, 128),
        (64, 128, 256),
        (32, 64, 64),
        (64, 64, 128),
    ],
    "kernel_sizes": [
        (7, 5, 3),
        (9, 7, 5),
        (5, 5, 3),
    ],
    "pool_kernel_sizes": [
        (2, 2, 2),
        (2, 2, 1),
    ],
    "dropout": [0.1, 0.2, 0.3, 0.4],
    # None is passed through to ECG_CNN_Classifier as-is; presumably it means
    # "no hidden FC layer" — confirm in utils.cnn_models.
    "fc_hidden_dim": [128, 256, None],
    "use_batch_norm": [True, False],
    "optimizer": ["Adam", "AdamW", "SGD"],
    # Only consumed when the sampled optimizer is SGD.
    "momentum": np.linspace(0.85, 0.95, 3).tolist(),
    "batch_size": [64, 128, 256],
    "learning_rate": np.logspace(-4, -3, num=5).tolist(),
    "weight_decay": np.logspace(-5, -3, num=5).tolist(),
}

# ParameterSampler's first parameter is `param_distributions`; `param_grid=`
# belongs to ParameterGrid and would raise a TypeError here.
configs = list(
    ParameterSampler(
        param_distributions=param_grid,
        n_iter=EXPERIMENTS,
        random_state=42,
    )
)


In [None]:
# Display the first sampled configuration (as a one-element list)
configs[0:1]


In [None]:
# Random hyper-parameter search. Each sampled configuration is trained and
# validated on every stratified fold. Macro precision/recall/F1/AUROC are
# collected per fold and averaged per experiment. The accumulated results are
# rewritten to CSV and JSON after every finished experiment.
import traceback

LOG_FOLDER = "cnn_random_search"
EPOCHS = 30

# Early-stopping settings used for every fold.
PATIENCE = 6
DELTA = 1e-4

NUM_CLASSES = 5

device = "cuda" if torch.cuda.is_available() else "cpu"

results_summary_json = []  # sampled config + per-fold metrics, per experiment
results_summary_csv = []  # fold-averaged metrics, per experiment

for i, params in enumerate(configs, start=1):
    print(f"-------------- Experiment {i}/{len(configs)} ----------------")
    try:
        params_copy = dict(params)

        batch_size = params_copy["batch_size"]
        optimizer_name = params_copy["optimizer"]
        conv_channels = params_copy["conv_channels"]
        kernel_sizes = params_copy["kernel_sizes"]
        pool_kernel_sizes = params_copy["pool_kernel_sizes"]
        dropout = params_copy["dropout"]
        fc_hidden_dim = params_copy["fc_hidden_dim"]
        use_batch_norm = params_copy["use_batch_norm"]

        momentum = params_copy["momentum"]
        learning_rate = params_copy["learning_rate"]
        weight_decay = params_copy["weight_decay"]

        fold_metrics = []

        for fold, (train_index, val_index) in enumerate(kfold.split(X, y), start=1):
            print(f"\n--------- Fold {fold}/{K} ---------\n")

            X_train_fold, y_train_fold = X[train_index], y[train_index]
            X_val_fold, y_val_fold = X[val_index], y[val_index]

            train_dataset = ECG_Dataset(X_train_fold, y_train_fold)
            val_dataset = ECG_Dataset(X_val_fold, y_val_fold)

            # Each training sample is drawn with probability proportional to
            # the weight of its class (presumably higher for rarer classes —
            # see calculate_class_weights).
            _, class_weights = calculate_class_weights(y_train_fold)
            sample_weights = np.array(class_weights)[y_train_fold]
            weighted_sampler = WeightedRandomSampler(
                weights=sample_weights,
                num_samples=len(sample_weights),
                replacement=True,
            )

            train_dataloader = DataLoader(
                dataset=train_dataset,
                batch_size=batch_size,
                sampler=weighted_sampler,
                shuffle=False,  # a sampler and shuffle=True are mutually exclusive
                num_workers=2,
                persistent_workers=True,
                pin_memory=True,
            )

            val_dataloader = DataLoader(
                dataset=val_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=2,
                persistent_workers=True,
                pin_memory=True,
            )

            model = ECG_CNN_Classifier(
                num_classes=NUM_CLASSES,
                in_channels=1,
                conv_channels=conv_channels,
                kernel_sizes=kernel_sizes,
                pool_kernel_sizes=pool_kernel_sizes,
                dropout=dropout,
                fc_hidden_dim=fc_hidden_dim,
                use_batch_norm=use_batch_norm,
            )
            model.to(device)

            if optimizer_name == "Adam":
                optimizer = torch.optim.Adam(
                    params=model.parameters(),
                    lr=learning_rate,
                    weight_decay=weight_decay,
                )
            elif optimizer_name == "AdamW":
                optimizer = torch.optim.AdamW(
                    params=model.parameters(),
                    lr=learning_rate,
                    weight_decay=weight_decay,
                )
            elif optimizer_name == "SGD":
                optimizer = torch.optim.SGD(
                    params=model.parameters(),
                    lr=learning_rate,
                    momentum=momentum,
                    weight_decay=weight_decay,
                )
            else:
                raise ValueError(f"Unknown optimizer: {optimizer_name}")

            loss_fn = nn.CrossEntropyLoss()
            early_stopper = EarlyStopping(
                patience=PATIENCE,
                delta=DELTA,
                checkpoint_path=f"{LOG_FOLDER}/checkpoints/experiment_{i}/fold_{fold}.pt",
                verbose=True,
            )

            # Fresh metric objects per fold so no state leaks between folds.
            precision_metric = Precision(task="multiclass", num_classes=NUM_CLASSES, average="macro").to(device)
            recall_metric = Recall(task="multiclass", num_classes=NUM_CLASSES, average="macro").to(device)
            f1_metric = F1Score(task="multiclass", num_classes=NUM_CLASSES, average="macro").to(device)
            auc_metric = AUROC(task="multiclass", num_classes=NUM_CLASSES, average="macro").to(device)

            # perf_counter is monotonic and meant for measuring durations.
            start = time.perf_counter()
            history = train_and_eval_model(
                model=model,
                loss_fn=loss_fn,
                optimizer=optimizer,
                train_dataloader=train_dataloader,
                val_dataloader=val_dataloader,
                epochs=EPOCHS,
                device=device,
                early_stopper=early_stopper,
                debug=True,
                verbose=True,
                grad_clip=True,
                max_norm=1.0,
            )
            total_time = time.perf_counter() - start

            epochs_run = len(history["train_loss"])
            time_per_epoch = total_time / epochs_run if epochs_run > 0 else 0.0

            # Score the weights stored by the early stopper (presumably the
            # best epoch) rather than those of the last epoch.
            checkpoint = torch.load(early_stopper.checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint["model_state_dict"])

            val_data = val_loop(
                model=model,
                val_dataloader=val_dataloader,
                loss_fn=loss_fn,
                device=device,
            )

            # val_loop returns per-batch lists; concatenate once and move the
            # results to the metrics' device.
            val_pred = torch.cat(val_data["y_pred"]).to(device)
            val_true = torch.cat(val_data["y_true"]).to(device)
            val_logits = torch.cat(val_data["y_pred_logits"]).to(device)

            fold_precision = precision_metric(val_pred, val_true).item()
            fold_recall = recall_metric(val_pred, val_true).item()
            fold_f1 = f1_metric(val_pred, val_true).item()
            fold_auc = auc_metric(val_logits, val_true).item()

            fold_metrics.append(
                {
                    "fold": fold,
                    "precision": fold_precision,
                    "recall": fold_recall,
                    "f1": fold_f1,
                    "auc": fold_auc,
                    "time_per_epoch": time_per_epoch,
                    "epochs_run": epochs_run,
                    "total_epochs": EPOCHS,
                }
            )

            print(f"Fold {fold}: F1={fold_f1:.3f} | AUC={fold_auc:.3f}")

        avg_precision = np.mean([m["precision"] for m in fold_metrics])
        avg_recall = np.mean([m["recall"] for m in fold_metrics])
        avg_f1 = np.mean([m["f1"] for m in fold_metrics])
        avg_auc = np.mean([m["auc"] for m in fold_metrics])
        avg_time_per_epoch = np.mean([m["time_per_epoch"] for m in fold_metrics])
        avg_epochs_run = np.mean([m["epochs_run"] for m in fold_metrics])

        results_summary_csv.append(
            {
                "experiment": i,
                "avg_precision": avg_precision,
                "avg_recall": avg_recall,
                "avg_f1": avg_f1,
                "avg_auc": avg_auc,
                "avg_time_per_epoch": avg_time_per_epoch,
                "avg_epochs_run": avg_epochs_run,
                "total_epochs": EPOCHS,
            }
        )

        # Momentum is only used by SGD; blank it in the log otherwise.
        if optimizer_name != "SGD":
            params_copy["momentum"] = None

        results_summary_json.append(
            {
                "experiment": i,
                **params_copy,
                "fold_metrics": fold_metrics,
            }
        )

        log_to_csv(f"{LOG_FOLDER}/results.csv", results_summary_csv)
        log_to_json(f"{LOG_FOLDER}/results.json", results_summary_json)

        print(f"\nExperiment {i} Done: Avg F1={avg_f1:.3f}, Avg AUC={avg_auc:.3f}\n")

    except Exception as exc:
        # Best-effort search: one failing configuration must not abort the
        # remaining experiments. The traceback is kept so the failure can be
        # diagnosed.
        print(f"Experiment {i} failed: {exc}")
        traceback.print_exc()

print("\n---------------- All Experiments Completed ----------------")


## Train Final CNN Model

In [None]:
# Hold out a small stratified slice of the training data (fraction given by
# train_val_split) as the validation set for the final training run; the
# remainder is used for training.
train_val_split = 0.05

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=train_val_split, random_state=42, stratify=y
)

train_dataset = ECG_Dataset(X_train, y_train)
val_dataset = ECG_Dataset(X_val, y_val)

# Per-sample weights: every training sample inherits the weight of its class,
# so the sampler draws each sample proportionally to that class weight.
_, class_weights = calculate_class_weights(y_train)
per_class = np.array(class_weights)
train_sample_weights = per_class[y_train]

weighted_sampler = WeightedRandomSampler(
    train_sample_weights,
    len(train_sample_weights),
    replacement=True,
)


In [None]:
# Hyper-parameters chosen from the random search above, grouped by purpose.

# --- Data loading and optimiser ---
BATCH_SIZE = 128
LEARNING_RATE, WEIGHT_DECAY = 5e-4, 1e-4

# --- Architecture ---
NUM_CLASSES = 5
CONV_CHANNELS = (64, 128, 256)
KERNEL_SIZES = (7, 5, 3)
POOL_KERNEL_SIZES = (2, 2, 2)
FC_HIDDEN_DIM = 256
DROPOUT = 0.2
USE_BATCH_NORM = True

# --- Early stopping (checkpoint of the final model) ---
PATIENCE, DELTA = 12, 1e-4
CHECKPOINT_PATH = "models/best_CNN.pt"

# --- ReduceLROnPlateau schedule ---
LR_PATIENCE, MIN_LR, FACTOR = 6, 1e-4, 0.5

# --- Gradient clipping ---
GRAD_CLIP, MAX_NORM = True, 1.0

EPOCHS = 100


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Final CNN built with the selected hyper-parameters; nn.Module.to returns
# the module itself, so `model` is the moved instance.
model = ECG_CNN_Classifier(
    num_classes=NUM_CLASSES,
    in_channels=1,
    conv_channels=CONV_CHANNELS,
    kernel_sizes=KERNEL_SIZES,
    pool_kernel_sizes=POOL_KERNEL_SIZES,
    dropout=DROPOUT,
    fc_hidden_dim=FC_HIDDEN_DIM,
    use_batch_norm=USE_BATCH_NORM,
).to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
loss_fn = nn.CrossEntropyLoss()

# Multiply the LR by FACTOR after LR_PATIENCE epochs without improvement of
# the monitored value (mode="min" — presumably the validation loss passed in
# by train_and_eval_model), never dropping below MIN_LR.
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="min",
    factor=FACTOR,
    patience=LR_PATIENCE,
    min_lr=MIN_LR,
)

early_stopper = EarlyStopping(
    patience=PATIENCE,
    delta=DELTA,
    checkpoint_path=CHECKPOINT_PATH,
    verbose=True,
)

# Worker/memory options shared by both loaders.
_loader_opts = {"num_workers": 2, "persistent_workers": True, "pin_memory": True}

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    sampler=weighted_sampler,
    shuffle=False,  # a sampler and shuffle=True are mutually exclusive
    **_loader_opts,
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    **_loader_opts,
)

train_history = train_and_eval_model(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=EPOCHS,
    device=device,
    early_stopper=early_stopper,
    scheduler=lr_scheduler,
    debug=True,
    verbose=True,
    grad_clip=GRAD_CLIP,
    max_norm=MAX_NORM,
)


In [None]:
import matplotlib.pyplot as plt

# Training and validation loss of the final run, one point per history entry.
plt.figure(figsize=(8, 5))
for history_key, curve_label, curve_marker in (
    ("train_loss", "Train Loss", "o"),
    ("val_loss", "Validation Loss", "s"),
):
    plt.plot(train_history[history_key], label=curve_label, marker=curve_marker)

plt.title("Training vs Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
# Build the test DataLoader from the raw MIT-BIH test CSV.
# NOTE(review): the Preprocessing settings below should match the ones used
# to produce data/ecg_preprocessed_train_val.npz — confirm.
TEST_DATA_PATH = "data/mitbih_test.csv"
test_df = pd.read_csv(TEST_DATA_PATH)
X_test, y_test = split_x_y(test_df)

preprocess = Preprocessing(
    sample_freq=125,  # presumably the signal sampling rate in Hz
    cutoff_freq=25,  # presumably the filter cutoff in Hz
    order=3,
    target_r_peak_index=94,
    method="neurokit",
)
X_test_preprocessed = preprocess.transform(X_test)

test_dataset = ECG_Dataset(X_test_preprocessed, y_test)
# No sampler and no shuffling: the test set is read in its original order.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    persistent_workers=True,
    pin_memory=True,
)


In [None]:
# Restore the weights the early stopper wrote to CHECKPOINT_PATH during final
# training, then run inference on the test set.
best_state = torch.load(CHECKPOINT_PATH, map_location=device)["model_state_dict"]
model.load_state_dict(best_state)
model.to(device)

test_results = test_loop(model=model, test_dataloader=test_dataloader, device=device)

# NOTE(review): val_loop returns per-batch lists that the search cell joins
# with torch.cat; confirm test_loop already returns flat sequences that
# classification_report accepts.
y_pred, y_true = test_results["y_pred"], test_results["y_true"]


In [None]:
# Display names for the five classes, in label-index order 0..4.
labels_list = ["N", "S", "V", "F", "Q"]

# Per-class precision/recall/F1 on the test set, shown to four decimals.
report = classification_report(y_true, y_pred, target_names=labels_list, digits=4)
print(report)
