# AFIB Detection Using 1D CNN on Cleaned ECG Dataset

This notebook trains a binary atrial fibrillation (AFIB) detector — AFIB vs. non-AFIB — on 12-lead ECG records, using the cleaned, class-balanced AFIB/SR (sinus rhythm) dataset generated by `data_handling.ipynb`.

Workflow:
1. Setup paths and load cleaned mapping
2. Verify data and prepare labels
3. Split dataset (train/val/test)
4. Compute global channel statistics
5. Define ECG dataset class
6. Create data loaders
7. Define and train 1D CNN model
8. Evaluate on test set


## 1. Imports


In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix, classification_report

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns

print("✓ All imports successful")


✓ All imports successful


## 2. Configuration & Path Setup


In [None]:
# Hyperparameters
DOWNSAMPLE = 2  # Downsample factor (e.g., 2 = keep every 2nd sample)
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCHS = 30
SEED = 42  # Seeds numpy (stats subsampling) and torch (weight init, DataLoader shuffling)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook draws from so Restart & Run All reproduces the same
# channel statistics, initial weights and batch order. torch.manual_seed also seeds CUDA.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Paths (resolved from the current working directory, which is assumed to be the project root)
PROJECT_ROOT = Path.cwd()
DATA_DIR = PROJECT_ROOT / "data"
CLEANED_ROOT = DATA_DIR / "cleaned_balanced_AFIB_SR"
CLEANED_WFDB_DIR = CLEANED_ROOT / "WFDBRecords"
MAPPING_CSV = CLEANED_ROOT / "file_mapping_cleaned.csv"

# Model save path (directory is created if missing)
MODEL_DIR = PROJECT_ROOT / "models"
MODEL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_PATH = MODEL_DIR / "ecg_1dcnn_afib_balanced.pth"

print(f"Device: {DEVICE}")
print(f"Seed: {SEED}")
print(f"MAPPING_CSV: {MAPPING_CSV}")
print(f"MODEL_PATH: {MODEL_PATH}")
print("\n✓ Configuration complete")


## 3. Load Cleaned Mapping & Verify Data


In [None]:
# Read the cleaned, balanced mapping produced by data_handling.ipynb
cleaned_mapping = pd.read_csv(MAPPING_CSV)
print(f"Loaded {len(cleaned_mapping):,} records from {MAPPING_CSV.name}")
print(f"\nColumns: {list(cleaned_mapping.columns)}")

# Fail fast if any column the rest of the pipeline reads is absent
required_cols = ['record_id', 'ecg_path', '_AFIB', '_SR']
present_cols = set(cleaned_mapping.columns)
missing_cols = [name for name in required_cols if name not in present_cols]
assert not missing_cols, f"Missing required columns: {missing_cols}"
print(f"✓ All required columns present: {required_cols}")

# Per-flag record counts
print("\n--- Dataset Statistics ---")
afib_flags, sr_flags = cleaned_mapping['_AFIB'], cleaned_mapping['_SR']
dataset_summary = (
    ("Total records", len(cleaned_mapping)),
    ("AFIB records (_AFIB=1)", afib_flags.sum()),
    ("Non-AFIB records (_AFIB=0)", (afib_flags == 0).sum()),
    ("SR records (_SR=1)", sr_flags.sum()),
    ("Non-SR records (_SR=0)", (sr_flags == 0).sum()),
)
for caption, count in dataset_summary:
    print(f"{caption}: {count:,}")

# Display sample rows
print("\nSample rows:")
cleaned_mapping.head()


## 4. Prepare Labels for Binary AFIB Detection


In [None]:
# Use _AFIB as the primary target column
target_col = "_AFIB"

# Coerce to numeric. Unparseable or missing values become NaN and are treated as
# negative (0), so report how many rows that affects instead of relabeling silently.
numeric_target = pd.to_numeric(cleaned_mapping[target_col], errors='coerce')
n_unparsed = int(numeric_target.isna().sum())
if n_unparsed:
    print(f"⚠ {n_unparsed:,} rows had missing/non-numeric {target_col} values; treating them as 0 (Non-AFIB)")

# Binarize: any positive value counts as AFIB (1), everything else as 0
cleaned_mapping[target_col] = (numeric_target.fillna(0) > 0).astype(int)

print(f"Target column: {target_col}")
print("Target distribution:")
print(cleaned_mapping[target_col].value_counts().sort_index())
print("\n✓ Labels prepared for binary AFIB detection")


## 5. Split Dataset (Train/Val/Test)


In [None]:
# Stratified 80/10/10 split over row positions of cleaned_mapping
indices = np.arange(len(cleaned_mapping))
labels = cleaned_mapping[target_col].values

# Stage 1: hold out 20% of rows as a temporary pool
idx_train, idx_temp, y_train, y_temp = train_test_split(
    indices, labels, test_size=0.2, stratify=labels, random_state=42,
)

# Stage 2: split the held-out pool evenly into validation and test
idx_val, idx_test, y_val, y_test = train_test_split(
    idx_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42,
)

print("--- Dataset Splits ---")
for split_prefix, split_idx, split_y in (
    ("Train:", idx_train, y_train),
    ("Val:  ", idx_val, y_val),
    ("Test: ", idx_test, y_test),
):
    print(f"{split_prefix} {len(split_idx):,} samples "
          f"(AFIB: {split_y.sum():,}, Non-AFIB: {(split_y == 0).sum():,})")
print("\n✓ Dataset split complete")


## 6. Compute Global Channel Statistics


In [None]:
def compute_channel_stats(mapping, indices, downsample=None, max_samples=512, rng=None,
                          n_channels=12):
    """
    Estimate global per-channel mean and std from a random subset of records.

    Parameters
    ----------
    mapping : pd.DataFrame
        Table with an ``ecg_path`` column pointing to ``.npy`` files.
    indices : array-like of int
        Row positions (``iloc``) in ``mapping`` to draw from. Pass training indices
        only, so validation/test data do not influence the normalization.
    downsample : int or None
        If > 1, keep every ``downsample``-th time sample (as ECGDataset does).
    max_samples : int
        Maximum number of records loaded (default 512).
    rng : np.random.Generator, int, or None
        Randomness for choosing the subset. ``None`` uses the global NumPy RNG,
        which is reproducible once ``np.random.seed`` has been called.
    n_channels : int
        Expected number of leads; records with a different count are skipped.

    Returns
    -------
    (mean, std) : tuple of float32 np.ndarray, each of shape (n_channels,)
        ``std`` includes a +1e-6 term to avoid division by zero downstream.

    Raises
    ------
    ValueError
        If no record in the sampled subset could be loaded.
    """
    indices = np.asarray(indices)
    if len(indices) > max_samples:
        if rng is None:
            sample_indices = np.random.choice(indices, size=max_samples, replace=False)
        else:
            sample_indices = np.random.default_rng(rng).choice(indices, size=max_samples, replace=False)
    else:
        sample_indices = indices

    data_list = []
    n_skipped = 0
    for idx in sample_indices:
        npy_file = Path(mapping.iloc[int(idx)]["ecg_path"])
        if not npy_file.exists():
            n_skipped += 1
            continue
        try:
            data = np.load(npy_file)
        except Exception:
            # Best-effort: an unreadable file is skipped, but counted and reported below
            n_skipped += 1
            continue
        if data.ndim != 2:
            n_skipped += 1
            continue
        # Heuristic: records stored time-major (L, C) are transposed to lead-major (C, L)
        if data.shape[0] > data.shape[1]:
            data = data.T
        if data.shape[0] != n_channels:
            n_skipped += 1
            continue
        if downsample is not None and downsample > 1:
            data = data[:, ::downsample]
        data_list.append(data)

    if n_skipped:
        print(f"compute_channel_stats: skipped {n_skipped} of {len(sample_indices)} sampled "
              f"records (missing, unreadable, or wrong shape)")
    if not data_list:
        raise ValueError("No valid data loaded for computing statistics")

    # Concatenate along time rather than truncating every record to the shortest one,
    # so all loaded samples contribute and variable-length records are handled.
    all_samples = np.concatenate(data_list, axis=1)  # (n_channels, total_T)
    mean = all_samples.mean(axis=1, dtype=np.float64).astype(np.float32)
    std = (all_samples.std(axis=1, dtype=np.float64) + 1e-6).astype(np.float32)
    return mean, std

print("✓ Function defined: compute_channel_stats")


In [None]:
# Normalization statistics are estimated from training records only,
# so validation/test data never influence preprocessing.
print("Computing global channel statistics from training data...")
global_mean, global_std = compute_channel_stats(cleaned_mapping, idx_train, downsample=DOWNSAMPLE)

print(f"\nGlobal mean (per channel): {global_mean}")
print(f"Global std (per channel):  {global_std}")
print("\n✓ Channel statistics computed")


## 7. Define ECG Dataset Class


In [None]:
class ECGDataset(Dataset):
    """PyTorch Dataset for 12-lead ECG records stored as ``.npy`` files.

    Each item is ``(signal, label)``:
    - ``signal``: float32 tensor of shape (12, L), per-channel normalized as
      ``(x - channel_mean) / channel_std`` after optional downsampling.
    - ``label``: float32 scalar tensor, 1.0 if ``row[target_col] > 0`` else 0.0
      (0.0 if the value cannot be parsed as a number).

    Parameters
    ----------
    mapping : pd.DataFrame
        Table with ``record_id``, ``ecg_path`` and ``target_col`` columns.
    indices : array-like of int
        Row positions (``iloc``) of ``mapping`` belonging to this split.
    downsample : int or None
        If > 1, keep every ``downsample``-th time sample.
    target_col : str
        Column holding the label.
    channel_mean, channel_std : array-like with 12 entries, or None
        Normalization statistics; default to 0 and 1 (identity).
    fallback_length : int
        Time length of the all-zero signal returned when a record fails to load.
        It should match the (downsampled) length of real records so batches collate.
    """

    N_CHANNELS = 12

    def __init__(self, mapping, indices, downsample=None, target_col="_AFIB",
                 channel_mean=None, channel_std=None, fallback_length=2500):
        self.mapping = mapping.iloc[indices].reset_index(drop=True)
        self.downsample = downsample
        self.target_col = target_col
        self.fallback_length = int(fallback_length)
        # Convert and validate eagerly: a list or wrong-shaped statistic would otherwise
        # fail inside __getitem__'s broad except and silently zero out every sample.
        self.channel_mean = self._as_channel_vector(channel_mean, 0.0, "channel_mean")
        self.channel_std = self._as_channel_vector(channel_std, 1.0, "channel_std")

    @classmethod
    def _as_channel_vector(cls, values, default, name):
        """Return ``values`` as a float32 vector with N_CHANNELS entries (or a constant default)."""
        if values is None:
            return np.full(cls.N_CHANNELS, default, dtype=np.float32)
        vec = np.asarray(values, dtype=np.float32).reshape(-1)
        if vec.shape != (cls.N_CHANNELS,):
            raise ValueError(f"{name} must have {cls.N_CHANNELS} entries, got shape {vec.shape}")
        return vec

    def __len__(self):
        return len(self.mapping)

    def _load_signal(self, npy_file):
        """Load one record as a normalized (12, L) array; raises ValueError on a bad shape."""
        data = np.load(npy_file)
        if data.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {data.shape}")
        # Heuristic: records stored time-major (L, 12) are transposed to lead-major (12, L)
        if data.shape[0] > data.shape[1]:
            data = data.T
        if data.shape[0] != self.N_CHANNELS:
            raise ValueError(f"Expected 12 channels, got {data.shape[0]}")
        if self.downsample is not None and self.downsample > 1:
            data = data[:, ::self.downsample]
        return (data - self.channel_mean[:, None]) / self.channel_std[:, None]

    def __getitem__(self, idx):
        row = self.mapping.iloc[idx]
        record_id = str(row["record_id"])
        try:
            data = self._load_signal(Path(row["ecg_path"]))
        except Exception as e:
            # Best-effort: log and substitute zeros so one bad file does not abort an epoch
            print(f"Error loading {record_id}: {e}")
            data = np.zeros((self.N_CHANNELS, self.fallback_length), dtype=np.float32)
        try:
            label = 1.0 if float(row[self.target_col]) > 0 else 0.0
        except Exception:
            label = 0.0
        data_tensor = torch.from_numpy(data.astype(np.float32))
        label_tensor = torch.tensor(label, dtype=torch.float32)
        return data_tensor, label_tensor

print("✓ ECGDataset class defined")


## 8. Create DataLoaders


In [None]:
# All splits share the training-set normalization statistics and the label column
# chosen in section 4 (previously hard-coded as "_AFIB" here, which would silently
# diverge if target_col were changed).
shared_ds_kwargs = dict(
    downsample=DOWNSAMPLE,
    target_col=target_col,
    channel_mean=global_mean,
    channel_std=global_std,
)
train_ds = ECGDataset(cleaned_mapping, idx_train, **shared_ds_kwargs)
val_ds = ECGDataset(cleaned_mapping, idx_val, **shared_ds_kwargs)
test_ds = ECGDataset(cleaned_mapping, idx_test, **shared_ds_kwargs)

# Only the training loader shuffles; evaluation order stays fixed
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"Train loader: {len(train_loader)} batches")
print(f"Val loader:   {len(val_loader)} batches")
print(f"Test loader:  {len(test_loader)} batches")
print("\n✓ DataLoaders created")


## 9. Define 1D CNN Model


In [None]:
class ECG1DCNN(nn.Module):
    """Four-stage 1D CNN producing raw logits for multi-lead ECG classification.

    Each stage is Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2). Channel widths grow
    64 -> 128 -> 256 -> 512 with kernel sizes 7, 5, 3, 3 and padding k // 2, so each
    convolution preserves length. Global average pooling collapses time, then
    dropout (p=0.5) and a linear layer map to ``num_classes`` logits.

    Input: (batch, in_channels, length). Output: (batch, num_classes).
    """

    def __init__(self, in_channels=12, num_classes=1):
        super().__init__()
        widths = (in_channels, 64, 128, 256, 512)
        kernel_sizes = (7, 5, 3, 3)
        # Register conv1, bn1, conv2, bn2, ... in that order so attribute names,
        # state_dict keys and parameter-initialization order stay stable.
        for stage, k in enumerate(kernel_sizes, start=1):
            setattr(self, f"conv{stage}",
                    nn.Conv1d(widths[stage - 1], widths[stage], kernel_size=k, padding=k // 2))
            setattr(self, f"bn{stage}", nn.BatchNorm1d(widths[stage]))
        self.pool = nn.MaxPool1d(kernel_size=2)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, bn in stages:
            x = self.pool(torch.relu(bn(conv(x))))
        # (batch, 512, T) -> (batch, 512)
        pooled = self.global_pool(x).squeeze(-1)
        return self.fc(self.dropout(pooled))

# Single-logit head: downstream code applies a sigmoid to obtain P(AFIB)
model = ECG1DCNN(in_channels=12, num_classes=1)
model = model.to(DEVICE)
print(f"Model architecture:\n{model}")
print(f"\n✓ Model initialized on {DEVICE}")


## 10. Training Setup


In [None]:
# The model emits raw logits, so pair it with the logits-aware (sigmoid-fused) BCE loss
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

for setup_line in (
    "Loss function: BCEWithLogitsLoss",
    f"Optimizer: Adam (lr={LEARNING_RATE})",
    "\n✓ Training setup complete",
):
    print(setup_line)


## 11. Training Loop


In [None]:
history = {'train_loss': [], 'val_loss': [], 'val_auc': [], 'val_acc': []}
# Start below any attainable AUC so the first epoch always writes a checkpoint;
# with 0.0, an epoch-1 AUC of exactly 0.0 would leave MODEL_PATH missing for section 13.
best_val_auc = float('-inf')

print(f"Starting training for {NUM_EPOCHS} epochs...\n")
for epoch in range(NUM_EPOCHS):
    # --- Train one epoch ---
    model.train()
    train_loss = 0.0
    for data, target in train_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        # squeeze(-1) drops only the logit dim: (B, 1) -> (B,). A bare squeeze() turns a
        # size-1 final batch into a 0-d tensor whose shape no longer matches target's (1,).
        output = model(data).squeeze(-1)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    avg_train_loss = train_loss / max(1, len(train_loader))

    # --- Validate ---
    model.eval()
    val_loss = 0.0
    all_preds, all_targets = [], []
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data).squeeze(-1)
            val_loss += criterion(output, target).item()
            all_preds.extend(torch.sigmoid(output).cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    avg_val_loss = val_loss / max(1, len(val_loader))
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)
    # AUC is undefined when only one class is present; fall back to chance level
    val_auc = roc_auc_score(all_targets, all_preds) if len(np.unique(all_targets)) > 1 else 0.5
    val_acc = accuracy_score(all_targets, (all_preds > 0.5).astype(int))

    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['val_auc'].append(val_auc)
    history['val_acc'].append(val_acc)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val AUC: {val_auc:.4f} | Val Acc: {val_acc:.4f}")

    # Checkpoint on best validation AUC
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"  → Best model saved (AUC: {val_auc:.4f})")

print(f"\n✓ Training complete!")
print(f"Best validation AUC: {best_val_auc:.4f}")


## 12. Plot Training History


In [None]:
# Training prints 1-based epochs ("Epoch [1/N]"), so label the x-axis the same way
epochs = range(1, len(history['train_loss']) + 1)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
axes[0].plot(epochs, history['train_loss'], label='Train Loss')
axes[0].plot(epochs, history['val_loss'], label='Val Loss')
axes[1].plot(epochs, history['val_auc'], label='Val AUC', color='green')
axes[2].plot(epochs, history['val_acc'], label='Val Accuracy', color='orange')

panel_labels = (
    ('Loss', 'Training & Validation Loss'),
    ('AUC', 'Validation AUC'),
    ('Accuracy', 'Validation Accuracy'),
)
for ax, (ylabel, title) in zip(axes, panel_labels):
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.show()


## 13. Load Best Model & Evaluate on Test Set


In [None]:
# map_location lets a checkpoint written on a GPU be reloaded on a CPU-only kernel
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

print(f"Loaded best model from {MODEL_PATH}")
print("\nEvaluating on test set...\n")

test_loss = 0.0
all_test_preds, all_test_targets = [], []
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)
        # squeeze(-1), not squeeze(): keeps a size-1 final batch 1-D so loss and extend() work
        output = model(data).squeeze(-1)
        test_loss += criterion(output, target).item()
        all_test_preds.extend(torch.sigmoid(output).cpu().numpy())
        all_test_targets.extend(target.cpu().numpy())
avg_test_loss = test_loss / max(1, len(test_loader))
all_test_preds = np.array(all_test_preds)
# Targets arrive as 0.0/1.0 floats; use int labels for sklearn's class-based metrics
all_test_targets = np.array(all_test_targets).astype(int)

test_auc = roc_auc_score(all_test_targets, all_test_preds) if len(np.unique(all_test_targets)) > 1 else 0.5
test_pred_labels = (all_test_preds > 0.5).astype(int)
test_acc = accuracy_score(all_test_targets, test_pred_labels)
# labels=[0, 1] keeps the matrix 2x2 (and the report aligned with target_names)
# even when one class is absent from the test split or from the predictions
conf_matrix = confusion_matrix(all_test_targets, test_pred_labels, labels=[0, 1])

print("--- Test Set Results ---")
print(f"Test Loss: {avg_test_loss:.4f}")
print(f"Test AUC:  {test_auc:.4f}")
print(f"Test Acc:  {test_acc:.4f}")
print("\nConfusion Matrix:")
print(conf_matrix)
print("\nClassification Report:")
print(classification_report(all_test_targets, test_pred_labels, labels=[0, 1],
                            target_names=['Non-AFIB', 'AFIB'], zero_division=0))


## 14. Visualize Confusion Matrix


In [None]:
# Heatmap of test-set confusion counts (rows = actual, columns = predicted)
class_names = ['Non-AFIB', 'AFIB']
cm_fig, cm_ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    ax=cm_ax,
)
cm_ax.set_xlabel('Predicted')
cm_ax.set_ylabel('Actual')
cm_ax.set_title(f'Confusion Matrix (Test Set)\nAUC: {test_auc:.4f}, Acc: {test_acc:.4f}')
cm_fig.tight_layout()
plt.show()

print("\n============================================================")
print("✓ AFIB detection pipeline complete!")
print("============================================================")
