# ECG Arrhythmia 1D-CNN (binary AF detector)

This notebook loads the ECG dataset in `ecg_arrhythmia_dataset_CSV`, prepares a binary label for Atrial Fibrillation (AF / AFIB), builds a small 1D convolutional neural network (Conv1D) that accepts 12-lead ECG signals, and trains it. The notebook uses the configurable `MAX_SAMPLES` and `DOWNSAMPLE` parameters so you can run quickly on a local machine.

**PyTorch version**

The cells below implement the pipeline with PyTorch: data loading through a `torch.utils.data.DataLoader`, a small Conv1D model, a training loop, evaluation, and model saving.

In [None]:
%pip install --upgrade pandas torch torchvision scikit-learn

In [None]:
import torch
import sklearn

# Report library versions and CUDA availability up front so that environment
# problems are visible before any data is loaded.
for label, value in (
    ('torch', torch.__version__),
    ('cuda available:', torch.cuda.is_available()),
    ('scikit-learn', sklearn.__version__),
):
    print(label, value)

In [None]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

# Confirm which torch build is active before configuring the run
print('torch', torch.__version__)

# Parameters (adjust before running)
ROOT = Path('.')
DATA_DIR = ROOT / 'ecg_arrhythmia_dataset_CSV'
WFDB_DIR = DATA_DIR / 'WFDBRecords'
MAPPING_CSV = DATA_DIR / 'file_mapping.csv'
# Quick test: use one third of the dataset. Set to None to use entire mapping.
MAX_SAMPLES = 'third'  # 'third' => int(len(mapping)/3); 'quarter' => int(len(mapping)/4)
DOWNSAMPLE = None  # downsampling disabled. Set to int > 1 to enable
EPOCHS = 6          # number of training epochs (kept small for a quick test)
BATCH_SIZE = 8      # smaller batch for testing
RANDOM_STATE = 42   # seed for the index shuffle that precedes the train/val/test split
TARGET_COLUMNS = ['AF', 'AFIB']  # mapping columns combined into the binary `af_label`
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device', DEVICE)


In [None]:
# Load the record mapping and derive the binary `af_label` column that the
# streaming (on-disk) dataset uses as its training target.
mapping = pd.read_csv(MAPPING_CSV)


def _contains_af(series):
    """Return a 0/1 int Series: 1 where the stringified value contains 'AF' (case-insensitive)."""
    return series.astype(str).str.contains('AF', case=False, na=False).astype(int)


def _derive_af_label(df, target_columns):
    """Build the binary AF label for every row of *df*.

    Resolution order:
      1. Any of *target_columns* present in *df*: label is 1 when their row sum
         is > 0. (Previously ALL of them had to be present, so a mapping with
         only one of the columns fell through to the heuristics below.)
      2. The first present column from a list of common label names: numeric
         or boolean columns are thresholded at > 0, anything else is matched
         against the substring 'AF'.
      3. A 'diagnosis' column, matched against the substring 'AF'.
      4. Otherwise all labels are 0 and a warning is printed.

    Returns:
        An integer Series aligned with ``df.index``.
    """
    present = [col for col in target_columns if col in df.columns]
    if present:
        return (df[present].sum(axis=1) > 0).astype(int)

    # try a list of common candidate column names
    candidates = ['af_label', 'AFIB', 'afib', 'is_af', 'is_afib', 'label', 'arrhythmia']
    found = [c for c in candidates if c in df.columns]
    if found:
        col = found[0]
        # is_numeric_dtype also accepts bool and pandas extension dtypes, which
        # np.issubdtype(..., np.number) either rejects or cannot interpret.
        if pd.api.types.is_numeric_dtype(df[col]):
            return (df[col] > 0).astype(int)
        return _contains_af(df[col])

    if 'diagnosis' in df.columns:
        return _contains_af(df['diagnosis'])

    print("Warning: couldn't find AF label column; setting all labels to 0 (no AF samples)")
    return pd.Series(0, index=df.index, dtype=int)


mapping['af_label'] = _derive_af_label(mapping, TARGET_COLUMNS)

# Resolve MAX_SAMPLES into a concrete row count and truncate the mapping.
# Accepted values: None (use everything), an integer row count, or one of the
# string sentinels below, which select a fraction of the mapping.
if isinstance(MAX_SAMPLES, str):
    _sentinel_divisors = {'quarter': 4, 'third': 3}
    if MAX_SAMPLES not in _sentinel_divisors:
        # An unrecognised sentinel used to be ignored, silently training on the
        # whole dataset; fail loudly so a typo in the config is noticed.
        raise ValueError(
            f"Unknown MAX_SAMPLES sentinel {MAX_SAMPLES!r}; "
            f"expected one of {sorted(_sentinel_divisors)}, an int, or None"
        )
    MAX_SAMPLES = len(mapping) // _sentinel_divisors[MAX_SAMPLES]

# None fails the isinstance check, so the full mapping is kept in that case.
if isinstance(MAX_SAMPLES, (int, np.integer)):
    # Keeps the FIRST rows in file order; shuffling happens after truncation.
    mapping = mapping.iloc[:int(MAX_SAMPLES)].reset_index(drop=True)

# Shuffle row positions with a seeded generator, then carve them into
# 80% train / 10% validation / remainder test.
rng = np.random.default_rng(RANDOM_STATE)
all_idx = np.arange(len(mapping))
rng.shuffle(all_idx)

n = len(all_idx)
n_train, n_val = int(n * 0.8), int(n * 0.1)

# Cut points: [0, n_train) -> train, [n_train, n_train + n_val) -> val, rest -> test
idx_train, idx_val, idx_test = np.split(all_idx, [n_train, n_train + n_val])

print(f"train samples {len(idx_train)} val samples {len(idx_val)} test samples {len(idx_test)}")
print("Example record path: (columns 'subdir'/'filename' not present in mapping)")

# NOTE(review): ECGOnDiskDataset is not defined in any visible cell of this
# notebook; it must be defined or imported before this cell runs.
def _split_dataset(indices):
    """Wrap one index split of `mapping` in an on-disk ECG dataset."""
    return ECGOnDiskDataset(mapping, WFDB_DIR, indices, downsample=DOWNSAMPLE, target_col='af_label')


def _split_loader(dataset, shuffle):
    """Batch a split; num_workers=0 keeps loading in-process (needed on Windows)."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=0)


# Datasets first, then loaders; only the training loader reshuffles each epoch.
train_ds, val_ds, test_ds = (_split_dataset(idx) for idx in (idx_train, idx_val, idx_test))

train_loader = _split_loader(train_ds, True)
val_loader = _split_loader(val_ds, False)
test_loader = _split_loader(test_ds, False)

In [None]:
# Dry-run test: fetch a single batch from the `train_loader` to verify streaming loader and shapes
#try:
#    xb, yb = next(iter(train_loader))
#    print('Dry-run OK — batch shapes: x=', xb.shape, ' y=', yb.shape)
#    # show dtype and device info
#    print('x dtype:', xb.dtype, ' y dtype:', yb.dtype)
#    print('Sample label counts:', int((yb>0.5).sum()), 'positive out of', yb.size(0))
#except Exception as e:
#    print('Dry-run failed:', repr(e))
#    raise


In [None]:
# Define the PyTorch 1D-CNN model and instantiate it before training
class ECG1DCNN(nn.Module):
    """Small 1D convolutional classifier for multi-lead signals.

    The network is a stack of Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2)
    stages (one per entry in ``channels``), followed by global average pooling
    over time and a two-layer fully connected head that emits raw logits.

    Args:
        in_channels: Number of input channels (signal leads).
        num_classes: Number of output logits per sample.
        channels: Output width of each convolutional stage, in order. Any
            iterable of ints is accepted; the default is an immutable tuple so
            it cannot be shared and mutated across instances.
        kernel_size: Convolution kernel length; padding is ``kernel_size // 2``.
    """

    def __init__(self, in_channels=12, num_classes=1, channels=(32, 64, 128), kernel_size=7):
        super().__init__()
        layers = []
        prev = in_channels
        for ch in channels:
            layers += [
                nn.Conv1d(prev, ch, kernel_size=kernel_size, padding=kernel_size // 2),
                nn.BatchNorm1d(ch),
                nn.ReLU(inplace=True),
                nn.MaxPool1d(2),  # halves the time axis after every stage
            ]
            prev = ch
        self.features = nn.Sequential(*layers)
        # Collapse whatever time length remains to 1 so the head is length-agnostic
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(prev, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return logits for a batch.

        Args:
            x: Tensor of shape (batch, channels, timesteps).

        Returns:
            Logits of shape (batch, num_classes); when ``num_classes`` is 1 the
            trailing dimension is squeezed away, giving shape (batch,).
        """
        x = self.features(x)
        x = self.pool(x)  # (batch, channels, 1)
        x = self.head(x)
        return x.squeeze(-1)

# Build the 12-channel, single-logit network and place its weights on DEVICE
model = ECG1DCNN(in_channels=12, num_classes=1)
model = model.to(DEVICE)
total_params = sum(param.numel() for param in model.parameters())
print('Model instantiated. Parameters:', total_params)


In [None]:
# Sanity check: rebuild the model with fresh weights and push one training
# batch through it (no optimizer step) to confirm the logits are finite.
import traceback
try:
    # Fresh weights
    model = ECG1DCNN(in_channels=12, num_classes=1).to(DEVICE)
    param_total = sum(p.numel() for p in model.parameters())
    print('Model re-initialized. Parameters:', param_total)

    # One batch from the training loader
    xb, yb = next(iter(train_loader))
    print('Batch shapes:', xb.shape, yb.shape)

    xb_dev = xb.to(DEVICE)
    with torch.no_grad():
        logits = model(xb_dev)

    found_nan = torch.isnan(logits).any().item()
    found_inf = torch.isinf(logits).any().item()
    print('logits NaN:', found_nan, 'Inf:', found_inf)
    if found_nan or found_inf:
        print('Logits contain NaN/Inf — forward pass unstable on fresh model')
    else:
        # Summary statistics are only meaningful when every logit is finite
        logits_cpu = logits.detach().cpu()
        print('logits min/max/mean/std:', float(logits_cpu.min()), float(logits_cpu.max()), float(logits_cpu.mean()), float(logits_cpu.std()))
        print('sample logits (first 8):', logits_cpu[:8].numpy())
except Exception as err:
    # Diagnostic cell: report the failure with a traceback instead of aborting the notebook
    print('Re-init forward pass failed:', err)
    traceback.print_exc()


In [None]:
# Activation trace: run one batch through the model layer by layer and print
# per-layer statistics, stopping at the first layer whose output has NaN/Inf.
import torch
import math

xb, yb = next(iter(train_loader))
xb = xb.to(DEVICE)
print('Input batch shape:', xb.shape)

x = xb
print('Input NaN/Inf:', torch.isnan(x).any().item(), torch.isinf(x).any().item())

# Purely diagnostic pass: disable autograd so no graph is built, matching the
# single-batch forward check in the previous cell. no_grad() does not change
# the module's train/eval mode, so layers behave as they would during training
# if the model is currently in training mode.
with torch.no_grad():
    # Trace through feature layers
    for i, layer in enumerate(model.features):
        x = layer(x)
        has_nan = torch.isnan(x).any().item()
        has_inf = torch.isinf(x).any().item()
        if not has_nan and not has_inf:
            vmin = float(x.min()); vmax = float(x.max()); vmean = float(x.mean()); vstd = float(x.std())
        else:
            # Statistics are meaningless once a non-finite value is present
            vmin = vmax = vmean = vstd = float('nan')
        print(f'features[{i}] {layer.__class__.__name__} -> shape={x.shape} NaN={has_nan} Inf={has_inf} min={vmin} max={vmax} mean={vmean} std={vstd}')
        if has_nan or has_inf:
            print('NaN/Inf detected at features index', i, 'layer:', layer)
            break

    # If features completed without NaN, check pool and head
    if not (torch.isnan(x).any() or torch.isinf(x).any()):
        x = model.pool(x)
        print('After pool shape:', x.shape, 'NaN:', torch.isnan(x).any().item())
        # pass through head sequentially
        for j, layer in enumerate(model.head):
            x = layer(x)
            has_nan = torch.isnan(x).any().item()
            has_inf = torch.isinf(x).any().item()
            print(f'head[{j}] {layer.__class__.__name__} -> shape={x.shape} NaN={has_nan} Inf={has_inf}')
            if has_nan or has_inf:
                print('NaN/Inf detected at head index', j, 'layer:', layer)
                break
    else:
        print('Stopping trace because NaN/Inf already present in features output')


In [None]:
import time

# Training setup: Adam on all model parameters, binary cross-entropy on raw logits
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

# Training loop with wall-clock timing
start_time = time.time()
for epoch in range(1, EPOCHS+1):
    model.train()
    train_loss = 0.0
    for xb, yb in train_loader:
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight by batch size to get a per-sample average
        train_loss += loss.item() * xb.size(0)
    # Guard against an empty split, the same way the validation loss is guarded
    train_loss /= max(len(train_loader.dataset), 1)

    # Validation
    model.eval()
    val_loss = 0.0
    preds = []
    targets = []
    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)
            logits = model(xb)
            loss = criterion(logits, yb)
            val_loss += loss.item() * xb.size(0)
            probs = torch.sigmoid(logits).cpu().numpy()
            preds.extend(probs.tolist())
            targets.extend(yb.cpu().numpy().tolist())
    val_loss /= max(len(val_loader.dataset), 1)

    # ROC AUC is undefined unless both classes occur in the validation targets.
    # Check that explicitly instead of swallowing every exception; ValueError
    # still covers invalid inputs such as NaN predictions.
    val_auc = float('nan')
    if np.unique(targets).size > 1:
        try:
            val_auc = roc_auc_score(targets, preds)
        except ValueError:
            val_auc = float('nan')
    print(f'Epoch {epoch}/{EPOCHS}  train_loss={train_loss:.4f}  val_loss={val_loss:.4f}  val_auc={val_auc:.4f}')

end_time = time.time()
TRAIN_RUNTIME_SECONDS = int(end_time - start_time)
print('Training runtime (s):', TRAIN_RUNTIME_SECONDS)

# Save the final-epoch weights (state_dict only)
out_dir = Path('models')
out_dir.mkdir(exist_ok=True)
torch.save(model.state_dict(), out_dir / 'pytorch_ecg_1dcnn.pth')
print('Saved PyTorch model to', out_dir / 'pytorch_ecg_1dcnn.pth')


In [None]:
# === Test Metrics Summary ===
# Evaluates the trained model on the held-out split and prints loss, accuracy,
# F1, a classification report and a confusion matrix. The variables computed
# here (test_loss, f1, acc, sampling_freq, clf_report, cm, train_n, val_n,
# test_n, runtime_str) are reused by the cell that writes the summary to disk.
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support, classification_report, confusion_matrix
import math
import datetime

# Ensure test set exists
if len(test_ds) == 0:
    print('No separate test set available (test set length = 0). Using validation set as test set for summary.')
    eval_loader = val_loader
    eval_size = len(val_ds)
else:
    eval_loader = test_loader
    eval_size = len(test_ds)

# Evaluate on chosen test loader
model.eval()
test_loss = 0.0
all_probs = []
all_targets = []
with torch.no_grad():
    for xb, yb in eval_loader:
        xb = xb.to(DEVICE)
        yb = yb.to(DEVICE)
        logits = model(xb)
        loss = criterion(logits, yb)
        # criterion returns the batch mean; weight by batch size for a per-sample average
        test_loss += loss.item() * xb.size(0)
        probs = torch.sigmoid(logits).cpu().numpy()
        all_probs.extend(probs.tolist())
        all_targets.extend(yb.cpu().numpy().tolist())

test_loss = test_loss / eval_size if eval_size>0 else float('nan')
# Hard predictions at the 0.5 probability threshold; targets cast to plain 0/1 ints
y_pred = [1 if p>=0.5 else 0 for p in all_probs]
y_true = [1 if t>=0.5 else 0 for t in all_targets]

# Metrics. `labels=[0, 1]` pins both classes so that the report and the
# confusion matrix keep their two-class shape even when the evaluated split
# (or the predictions) contain only one class; without it classification_report
# raises because target_names has two entries.
CLASS_LABELS = [0, 1]
acc = accuracy_score(y_true, y_pred) if eval_size>0 else float('nan')
f1 = f1_score(y_true, y_pred, zero_division=0) if eval_size>0 else float('nan')
precision, recall, f1_per_label, _ = precision_recall_fscore_support(y_true, y_pred, labels=CLASS_LABELS, zero_division=0)
clf_report = classification_report(y_true, y_pred, labels=CLASS_LABELS, target_names=['Normal','AFIB'], zero_division=0)
cm = confusion_matrix(y_true, y_pred, labels=CLASS_LABELS)

# Sampling frequency: mode of the records' original sampling_frequency, divided
# by the downsample factor. `mapping` was already truncated to the selected
# records when it was loaded, so all of its rows are used. (The previous lookup
# went through `sel_idx`, which no visible cell defines, so it always fell
# back to N/A.)
try:
    sel_sampling = mapping['sampling_frequency']
    sampling_freq = float(sel_sampling.mode().iloc[0]) / (DOWNSAMPLE if DOWNSAMPLE and DOWNSAMPLE>0 else 1)
except (KeyError, IndexError, TypeError, ValueError):
    # Column missing, no rows, or a non-numeric value
    sampling_freq = float('nan')

# Runtime formatting; TRAIN_RUNTIME_SECONDS only exists once the training cell has run
try:
    seconds = TRAIN_RUNTIME_SECONDS
    runtime_str = str(datetime.timedelta(seconds=seconds))
except (NameError, TypeError):
    runtime_str = 'N/A'

# Dataset sizes
train_n = len(train_ds)
val_n = len(val_ds)
test_n = len(test_ds)

# Print summary
print('=== Test Metrics Summary ===')
print(f'  Test Loss:     {test_loss:.4f}')
print(f'  Test F1-score: {f1:.4f}')
print(f'  Test Accuracy: {acc:.4f}')
print(f'  Sampling frequency: {int(sampling_freq) if not math.isnan(sampling_freq) else "N/A"} Hz')
print(f'  Device used: {DEVICE}\n')
print('Label mapping:')
print('0: Normal Sinus Rhythm (NORM)')
print('1: Atrial Fibrillation (AFIB)\n')
print('Classification Report:')
print(clf_report)
print('Confusion Matrix:')
print(cm)

# Human-readable confusion breakdown (rows = true class, columns = predicted class)
if cm.size == 4:
    tn, fp, fn, tp = cm.ravel()
    total_norm = tn + fp
    total_af = fn + tp
    print(f"{tn} Normal ECGs correctly classified ({(tn/total_norm*100) if total_norm>0 else 0:.1f}%)")
    print(f"{fp} Normal ECGs wrongly predicted as AFIB ({(fp/total_norm*100) if total_norm>0 else 0:.1f}%) [False Positives]")
    print(f"{tp} AFIB ECGs correctly classified ({(tp/total_af*100) if total_af>0 else 0:.1f}%)")
    print(f"{fn} AFIB ECGs wrongly predicted as non-AFIB ({(fn/total_af*100) if total_af>0 else 0:.1f}%) [False Negatives]\n")

print('Dataset Sizes:')
print(f'  Training records:   {train_n}')
print(f'  Validation records: {val_n}')
print(f'  Test records:       {test_n}\n')
print('Runtime:', runtime_str)


In [None]:
# Persist the Test Metrics Summary as a timestamped text file under `results/`.
# Every value comes from the evaluation cell; anything that is missing or cannot
# be formatted is written as N/A so the file is produced regardless.
import os
import math
import datetime
from pathlib import Path


def _or_fallback(render, fallback):
    """Return render(); if it raises (e.g. the variable was never computed), return fallback."""
    try:
        return render()
    except Exception:
        return fallback


out_dir = Path('results')
out_dir.mkdir(exist_ok=True)

now = datetime.datetime.now()
filename = f"run_{now.strftime('%Y%m%d_%H%M%S')}.txt"
out_path = out_dir / filename

# Headline metrics
lines = [
    '=== Test Metrics Summary ===',
    _or_fallback(lambda: f'  Test Loss:     {test_loss:.4f}', '  Test Loss:     N/A'),
    _or_fallback(lambda: f'  Test F1-score: {f1:.4f}', '  Test F1-score: N/A'),
    _or_fallback(lambda: f'  Test Accuracy: {acc:.4f}', '  Test Accuracy: N/A'),
]

# Sampling frequency and device
sf_text = _or_fallback(
    lambda: int(sampling_freq) if not math.isnan(sampling_freq) else 'N/A',
    'N/A',
)
lines += [
    f'  Sampling frequency: {sf_text} Hz',
    f'  Device used: {DEVICE}\n',
    'Label mapping:',
    '0: Normal Sinus Rhythm (NORM)',
    '1: Atrial Fibrillation (AFIB)\n',
    'Classification Report:',
    _or_fallback(lambda: clf_report, 'N/A'),
    '\nConfusion Matrix:',
    _or_fallback(lambda: str(cm), 'N/A'),
]

# Human-readable confusion breakdown; skipped entirely if `cm` is unavailable
try:
    if cm.size == 4:
        tn, fp, fn, tp = cm.ravel()
        total_norm = tn + fp
        total_af = fn + tp

        def _pct(part, whole):
            # Percentage of `whole`, or 0 when the class has no samples
            return (part/whole*100) if whole>0 else 0

        lines.append(f"{tn} Normal ECGs correctly classified ({_pct(tn, total_norm):.1f}%)")
        lines.append(f"{fp} Normal ECGs wrongly predicted as AFIB ({_pct(fp, total_norm):.1f}%) [False Positives]")
        lines.append(f"{tp} AFIB ECGs correctly classified ({_pct(tp, total_af):.1f}%)")
        lines.append(f"{fn} AFIB ECGs wrongly predicted as non-AFIB ({_pct(fn, total_af):.1f}%) [False Negatives]\n")
except Exception:
    pass

# Split sizes; on any failure the three N/A rows are appended after whatever succeeded
lines.append('Dataset Sizes:')
try:
    for render_row in (
        lambda: f'  Training records:   {train_n}',
        lambda: f'  Validation records: {val_n}',
        lambda: f'  Test records:       {test_n}\n',
    ):
        lines.append(render_row())
except Exception:
    lines.extend([
        '  Training records:   N/A',
        '  Validation records: N/A',
        '  Test records:       N/A\n',
    ])

# Runtime
lines.append(_or_fallback(lambda: f'Runtime: {runtime_str}', 'Runtime: N/A'))

# Write to file
out_path.write_text('\n'.join(lines), encoding='utf-8')

print(f'Saved Test Metrics Summary to: {out_path}')