# 1D CNN Training Template

This notebook mirrors the RNN workflow provided elsewhere in the project while swapping in a convolutional architecture defined in `utils.cnn_models`. Configure the `TrainingConfig` cell and execute the pipeline cells sequentially to train, validate, and optionally test a 1D CNN on the ECG heartbeat dataset.

In [None]:
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader, WeightedRandomSampler

from utils.cnn_models import ECG_CNN_Classifier
from utils.data import calculate_class_weights, split_x_y
from utils.preprocessing import Preprocessing
from utils.torch_classes import ECG_Dataset, EarlyStopping
from utils.train import test_loop, train_and_eval_model


In [None]:
@dataclass
class TrainingConfig:
    '''Settings bundle read by the CNN training, evaluation and pipeline cells.

    Every field has a default, so ``TrainingConfig()`` is directly usable.
    ``preprocessing_params`` is filled with a default parameter set after
    construction whenever it is left as ``None``.
    '''

    # --- model architecture (forwarded to ECG_CNN_Classifier) ---
    num_classes: int = 5
    batch_size: int = 256
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    dropout: float = 0.3
    conv_channels: Tuple[int, ...] = (32, 64, 128)
    kernel_sizes: Tuple[int, ...] = (7, 5, 3)
    pool_kernel_sizes: Tuple[int, ...] = (2, 2, 2)
    fc_hidden_dim: int = 128
    use_batch_norm: bool = True

    # --- training loop, early stopping and gradient clipping ---
    max_epochs: int = 30
    patience: int = 6
    delta: float = 1e-4
    grad_clip: bool = True
    max_norm: float = 5.0

    # --- ReduceLROnPlateau scheduler ---
    scheduler_factor: float = 0.1
    scheduler_patience: int = 3
    min_lr: float = 1e-6

    # --- data loading ---
    num_workers: int = 2
    use_weighted_sampler: bool = True
    debug: bool = True

    # --- file locations ---
    checkpoint_path: str = "models/best_CNN.pt"
    train_val_path: str = "data/ecg_preprocessed_train_val.npz"
    test_csv_path: str = "data/mitbih_test.csv"

    # --- test-time preprocessing ---
    preprocess_test: bool = True
    preprocessing_params: Optional[Dict[str, object]] = None

    # ``None`` lets the training/evaluation code pick CUDA when available.
    device: Optional[str] = None

    def __post_init__(self) -> None:
        '''Install the default preprocessing parameters when none were given.'''
        if self.preprocessing_params is not None:
            # Caller supplied a mapping (even an empty one): keep it untouched.
            return
        # A fresh dict per instance, so instances never share the defaults.
        self.preprocessing_params = dict(
            sample_freq=125,
            cutoff_freq=25,
            order=3,
            target_r_peak_index=94,
            method="neurokit",
        )


In [None]:
def build_sampler(y: np.ndarray) -> WeightedRandomSampler:
    '''Build a with-replacement sampler that weights each sample by its label.

    The second value returned by ``calculate_class_weights(y)`` is turned into
    a float64 lookup array and indexed with ``y``, giving one weight per
    sample. The sampler draws ``len(y)`` indices per pass.

    NOTE(review): the lookup assumes ``y`` holds integer labels usable as
    indices into the class-weight sequence -- confirm against
    ``calculate_class_weights``.
    '''

    _, per_class_weight = calculate_class_weights(y)
    weight_lookup = np.array(per_class_weight, dtype=np.float64)
    per_sample_weight = weight_lookup[y]
    return WeightedRandomSampler(
        weights=torch.as_tensor(per_sample_weight, dtype=torch.double),
        num_samples=len(per_sample_weight),
        replacement=True,
    )


def create_dataloader(
    X: np.ndarray,
    y: np.ndarray,
    batch_size: int,
    *,
    sampler: Optional[WeightedRandomSampler] = None,
    shuffle: bool = True,
    num_workers: int = 0,
) -> DataLoader:
    '''Build a ``DataLoader`` over an ``ECG_Dataset`` made from ``X`` and ``y``.

    Args:
        X: Feature array handed to ``ECG_Dataset``.
        y: Label array handed to ``ECG_Dataset``.
        batch_size: Number of samples per batch.
        sampler: Optional sampler; when given, ``shuffle`` is ignored.
        shuffle: Whether to shuffle when no sampler is supplied.
        num_workers: Worker process count; workers are kept alive between
            epochs whenever this is greater than zero.

    Returns:
        The configured ``DataLoader``. Memory pinning is enabled whenever
        CUDA is available.
    '''

    # DataLoader does not accept a sampler together with shuffle=True, so the
    # shuffle flag is forced off as soon as a sampler is present.
    effective_shuffle = False if sampler is not None else shuffle

    loader_options = {
        "dataset": ECG_Dataset(X, y),
        "batch_size": batch_size,
        "sampler": sampler,
        "shuffle": effective_shuffle,
        "num_workers": num_workers,
        "persistent_workers": num_workers > 0,
        "pin_memory": torch.cuda.is_available(),
    }
    return DataLoader(**loader_options)


def _to_numpy(array: torch.Tensor | np.ndarray) -> np.ndarray:
    if isinstance(array, torch.Tensor):
        return array.detach().cpu().numpy()
    return array


def compute_metrics(
    y_true: torch.Tensor | np.ndarray,
    y_pred: torch.Tensor | np.ndarray,
    y_logits: torch.Tensor | np.ndarray,
    average: str = "macro",
) -> Dict[str, object]:
    '''Compute a suite of evaluation metrics given model outputs.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.
        y_logits: Raw class scores. Softmax is taken along axis 1, so the
            input must be two-dimensional.
        average: Averaging mode forwarded to the scikit-learn scorers.

    Returns:
        A dict with the keys ``accuracy``, ``precision_macro``,
        ``recall_macro``, ``f1_macro``, ``roc_auc_ovr`` and
        ``classification_report`` (a formatted string). The ``*_macro`` key
        names are fixed and do not change when another ``average`` is passed.
        ``roc_auc_ovr`` is NaN when scikit-learn cannot compute it.
    '''

    y_true_np = _to_numpy(y_true)
    y_pred_np = _to_numpy(y_pred)
    y_logits_np = _to_numpy(y_logits)

    # Turn logits into per-class probabilities for the ROC-AUC computation.
    logits_tensor = torch.from_numpy(y_logits_np).float()
    probabilities = torch.softmax(logits_tensor, dim=1).numpy()

    metrics = {
        "accuracy": accuracy_score(y_true_np, y_pred_np),
        "precision_macro": precision_score(
            y_true_np, y_pred_np, average=average, zero_division=0
        ),
        "recall_macro": recall_score(
            y_true_np, y_pred_np, average=average, zero_division=0
        ),
        "f1_macro": f1_score(y_true_np, y_pred_np, average=average, zero_division=0),
    }

    try:
        metrics["roc_auc_ovr"] = roc_auc_score(
            y_true_np, probabilities, multi_class="ovr", average=average
        )
    except ValueError:
        # Best effort: report NaN instead of aborting when the score cannot
        # be computed (e.g. a class is missing from ``y_true``).
        metrics["roc_auc_ovr"] = float("nan")

    # ``zero_division=0`` matches the scalar scorers above: classes without
    # predictions or support score 0 without an UndefinedMetricWarning.
    metrics["classification_report"] = classification_report(
        y_true_np, y_pred_np, digits=4, zero_division=0
    )
    return metrics


In [None]:
def train_single_run(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    config: TrainingConfig,
):
    '''Fit one CNN on a fixed train/validation split and score its best epoch.

    Args:
        X_train: Training features.
        y_train: Training labels.
        X_val: Validation features.
        y_val: Validation labels.
        config: Hyperparameters, paths and switches for this run.

    Returns:
        A ``(model, summary)`` tuple. ``summary`` has the keys ``history``
        (whatever ``train_and_eval_model`` returned), ``val_metrics`` (the
        output of ``compute_metrics`` for the epoch with the lowest
        validation loss) and ``best_epoch`` (that epoch's index into the
        history lists).
    '''

    # An explicit device in the config wins; otherwise prefer CUDA.
    device = config.device
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- model ---
    architecture = {
        "num_classes": config.num_classes,
        "conv_channels": config.conv_channels,
        "kernel_sizes": config.kernel_sizes,
        "pool_kernel_sizes": config.pool_kernel_sizes,
        "dropout": config.dropout,
        "fc_hidden_dim": config.fc_hidden_dim,
        "use_batch_norm": config.use_batch_norm,
    }
    model = ECG_CNN_Classifier(**architecture)
    model.to(device)

    # --- optimisation objects ---
    optimizer = torch.optim.AdamW(
        params=model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    loss_fn = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer,
        mode="min",
        factor=config.scheduler_factor,
        patience=config.scheduler_patience,
        min_lr=config.min_lr,
    )
    early_stopper = EarlyStopping(
        patience=config.patience,
        delta=config.delta,
        checkpoint_path=config.checkpoint_path,
        verbose=True,
    )

    # --- data loaders ---
    if config.use_weighted_sampler:
        sampler = build_sampler(y_train)
    else:
        sampler = None
    # Plain shuffling is requested only when the weighted sampler is off.
    shuffle_train = config.use_weighted_sampler is False

    train_loader = create_dataloader(
        X_train,
        y_train,
        batch_size=config.batch_size,
        sampler=sampler,
        shuffle=shuffle_train,
        num_workers=config.num_workers,
    )
    val_loader = create_dataloader(
        X_val,
        y_val,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
    )

    # --- training ---
    history = train_and_eval_model(
        model=model,
        loss_fn=loss_fn,
        optimizer=optimizer,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=config.max_epochs,
        device=device,
        early_stopper=early_stopper,
        debug=config.debug,
        verbose=True,
        grad_clip=config.grad_clip,
        max_norm=config.max_norm,
        scheduler=scheduler,
    )

    # --- validation metrics for the epoch with the lowest validation loss ---
    best_epoch = int(np.argmin(history["val_loss"]))
    best_true = history["val_true"][best_epoch]
    best_pred = history["val_pred"][best_epoch]
    best_logits = history["val_pred_logits"][best_epoch]
    val_metrics = compute_metrics(best_true, best_pred, best_logits)

    summary = {
        "history": history,
        "val_metrics": val_metrics,
        "best_epoch": best_epoch,
    }
    return model, summary


In [None]:
def evaluate_on_test(
    model: ECG_CNN_Classifier,
    config: TrainingConfig,
):
    '''Score ``model`` on the held-out test CSV, restoring saved weights first.

    If ``config.checkpoint_path`` exists, its ``model_state_dict`` entry is
    loaded into ``model`` before anything else; otherwise the model is used
    as passed in. The test CSV is then read, optionally run through
    ``Preprocessing``, and evaluated with ``test_loop``.

    Returns:
        ``None`` when the test CSV is missing; otherwise a dict with
        ``metrics`` (output of ``compute_metrics``) and ``raw_outputs``
        (the dict returned by ``test_loop``).
    '''

    # An explicit device in the config wins; otherwise prefer CUDA.
    device = config.device
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    checkpoint_file = Path(config.checkpoint_path)
    if checkpoint_file.exists():
        checkpoint = torch.load(checkpoint_file, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        saved_epoch = checkpoint.get('epoch', 'N/A')
        print(f"Loaded checkpoint from {checkpoint_file} (epoch {saved_epoch}).")

    test_path = Path(config.test_csv_path)
    if not test_path.exists():
        # Testing is optional: report the missing file and bail out quietly.
        print(f"Test file '{test_path}' not found. Skipping test evaluation.")
        return None

    X_test, y_test = split_x_y(pd.read_csv(test_path))

    if config.preprocess_test:
        X_test = Preprocessing(**config.preprocessing_params).transform(X_test)

    test_loader = create_dataloader(
        X_test,
        y_test,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
    )

    results = test_loop(model=model, test_dataloader=test_loader, device=device)
    metrics = compute_metrics(
        results["y_true"],
        results["y_pred"],
        results["y_pred_logits"],
    )
    return {"metrics": metrics, "raw_outputs": results}


In [None]:
def _print_metrics(title: str, metrics: Dict[str, object]) -> None:
    '''Print ``title`` followed by each metric; scalars use four decimals.'''
    print(title)
    for key, value in metrics.items():
        if key == "classification_report":
            # The report is a pre-formatted multi-line string.
            print("\nClassification report:\n", value)
        else:
            print(f"{key}: {value:.4f}")


def run_training_pipeline(config: TrainingConfig) -> None:
    '''Execute the full train/validation/test workflow.

    Loads the preprocessed arrays ``X`` and ``y`` from
    ``config.train_val_path``, holds out a stratified 5% validation split
    (fixed seed 42), trains one model, prints the best-epoch validation
    metrics, and finally prints test metrics when ``evaluate_on_test``
    returns a result.

    If the train/validation file does not exist, a message is printed and
    the function returns without training.
    '''

    train_val_path = Path(config.train_val_path)
    if not train_val_path.exists():
        # One string, so print() does not inject extra separator spaces.
        print(
            f"Preprocessed train/validation dataset not found at '{train_val_path}'. "
            "Please generate it before running the CNN training template."
        )
        return

    # The context manager closes the underlying .npz archive once both
    # arrays have been read into memory.
    with np.load(train_val_path) as data:
        X = data["X"]
        y = data["y"]

    X_train, X_val, y_train, y_val = train_test_split(
        X,
        y,
        test_size=0.05,
        random_state=42,
        stratify=y,
    )

    model, train_summary = train_single_run(
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        config=config,
    )
    _print_metrics("\nValidation metrics (best epoch):", train_summary["val_metrics"])

    test_summary = evaluate_on_test(model=model, config=config)
    if test_summary is not None:
        _print_metrics("\nTest metrics:", test_summary["metrics"])


## Usage

Instantiate a configuration, adjust any hyperparameters as needed, and execute the training pipeline:

```python
config = TrainingConfig()
run_training_pipeline(config)
```

Running the final cell below will execute these steps.

In [None]:
# Build a configuration with every default setting and run the full
# train/validation/test workflow. Edit fields on ``config`` (or pass keyword
# arguments to ``TrainingConfig``) before the call to change hyperparameters.
config = TrainingConfig()
run_training_pipeline(config)
