In [None]:
# Confusion-matrix heatmap of test-set predictions (rows = actual, columns = predicted)
class_names = ['Non-AFIB', 'AFIB']
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title(f'Confusion Matrix (Test Set)\nAUC: {test_auc:.4f}, Acc: {test_acc:.4f}')
fig.tight_layout()
plt.show()

banner = '=' * 60
print(f"\n{banner}")
print("✓ AFIB detection pipeline complete!")
print(banner)

## 14. Visualize Confusion Matrix

In [None]:
# Load the best checkpoint (highest validation AUC) written by the training loop.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

print(f"Loaded best model from {MODEL_PATH}")
print(f"\nEvaluating on test set...\n")

# Evaluate on test set
test_loss_sum = 0.0
all_test_preds = []
all_test_targets = []

with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)

        # squeeze(-1), not squeeze(): keeps shape (batch,) even for a final batch of size 1.
        # A bare squeeze() would yield a 0-d tensor that mismatches `target` in the loss
        # and cannot be iterated by list.extend below.
        output = model(data).squeeze(-1)
        loss = criterion(output, target)
        # criterion averages over the batch (default reduction='mean'); weight by batch size
        # so the reported loss is a per-sample mean even when the last batch is short.
        test_loss_sum += loss.item() * target.size(0)

        # Collect predictions as probabilities (the model emits logits)
        probs = torch.sigmoid(output).cpu().numpy()
        all_test_preds.extend(probs)
        all_test_targets.extend(target.cpu().numpy())

avg_test_loss = test_loss_sum / len(test_loader.dataset)

# Convert to arrays
all_test_preds = np.array(all_test_preds)
all_test_targets = np.array(all_test_targets)

# Compute metrics (hard labels use a fixed 0.5 probability threshold)
test_auc = roc_auc_score(all_test_targets, all_test_preds)
test_pred_labels = (all_test_preds > 0.5).astype(int)
test_acc = accuracy_score(all_test_targets, test_pred_labels)
conf_matrix = confusion_matrix(all_test_targets, test_pred_labels)

print(f"--- Test Set Results ---")
print(f"Test Loss: {avg_test_loss:.4f}")
print(f"Test AUC:  {test_auc:.4f}")
print(f"Test Acc:  {test_acc:.4f}")
print(f"\nConfusion Matrix:")
print(conf_matrix)
print(f"\nClassification Report:")
print(classification_report(all_test_targets, test_pred_labels, target_names=['Non-AFIB', 'AFIB']))

## 13. Load Best Model & Evaluate on Test Set

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# One entry per panel: (y-axis label, title, [(history key, legend label, extra plot kwargs), ...])
panel_specs = [
    ('Loss', 'Training & Validation Loss',
     [('train_loss', 'Train Loss', {}), ('val_loss', 'Val Loss', {})]),
    ('AUC', 'Validation AUC',
     [('val_auc', 'Val AUC', {'color': 'green'})]),
    ('Accuracy', 'Validation Accuracy',
     [('val_acc', 'Val Accuracy', {'color': 'orange'})]),
]

for ax, (ylabel, title, series) in zip(axes, panel_specs):
    for key, legend_label, line_style in series:
        ax.plot(history[key], label=legend_label, **line_style)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.show()

## 12. Plot Training History

In [None]:
# Per-epoch metrics, plotted in the "Plot Training History" section
history = {
    'train_loss': [],
    'val_loss': [],
    'val_auc': [],
    'val_acc': []
}

# Best validation AUC so far; the matching weights are checkpointed to MODEL_PATH
best_val_auc = 0.0

print(f"Starting training for {NUM_EPOCHS} epochs...\n")

for epoch in range(NUM_EPOCHS):
    # ========== Training phase ==========
    model.train()
    train_loss_sum = 0.0

    for data, target in train_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)

        # Forward pass. squeeze(-1), not squeeze(): keeps shape (batch,) even when the
        # last shuffled batch holds a single sample. A bare squeeze() would give a 0-d
        # tensor, and BCEWithLogitsLoss rejects the size mismatch with `target`.
        optimizer.zero_grad()
        output = model(data).squeeze(-1)  # Shape: (batch,)

        loss = criterion(output, target)

        # Backward pass
        loss.backward()
        optimizer.step()

        # criterion averages over the batch (default reduction='mean'); weight by batch
        # size so the epoch loss is a per-sample mean even with a short last batch.
        train_loss_sum += loss.item() * target.size(0)

    avg_train_loss = train_loss_sum / len(train_loader.dataset)

    # ========== Validation phase ==========
    model.eval()
    val_loss_sum = 0.0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)

            output = model(data).squeeze(-1)
            loss = criterion(output, target)
            val_loss_sum += loss.item() * target.size(0)

            # Collect predictions (apply sigmoid to get probabilities)
            probs = torch.sigmoid(output).cpu().numpy()
            all_preds.extend(probs)
            all_targets.extend(target.cpu().numpy())

    avg_val_loss = val_loss_sum / len(val_loader.dataset)

    # Compute metrics (accuracy uses a fixed 0.5 probability threshold)
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)

    val_auc = roc_auc_score(all_targets, all_preds)
    val_acc = accuracy_score(all_targets, (all_preds > 0.5).astype(int))

    # Store history
    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)
    history['val_auc'].append(val_auc)
    history['val_acc'].append(val_acc)

    # Print progress
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] "
          f"Train Loss: {avg_train_loss:.4f} | "
          f"Val Loss: {avg_val_loss:.4f} | "
          f"Val AUC: {val_auc:.4f} | "
          f"Val Acc: {val_acc:.4f}")

    # Save best model (only on a strict improvement in validation AUC)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"  → Best model saved (AUC: {val_auc:.4f})")

print(f"\n✓ Training complete!")
print(f"Best validation AUC: {best_val_auc:.4f}")

## 11. Training Loop

In [None]:
# The model emits raw logits, so use the logit-aware BCE loss (it applies the sigmoid internally)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

for setup_line in (
    "Loss function: BCEWithLogitsLoss",
    f"Optimizer: Adam (lr={LEARNING_RATE})",
    "\n✓ Training setup complete",
):
    print(setup_line)

## 10. Training Setup

In [None]:
class ECG1DCNN(nn.Module):
    """
    Simple 1D CNN for AFIB detection from 12-lead ECG.

    Four Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) stages widen the feature
    channels 64 -> 128 -> 256 -> 512 while halving the time axis at each stage.
    Global average pooling then collapses time, and dropout plus a linear layer
    produce ``num_classes`` raw logits (no output activation is applied).

    Submodules are named conv{i}/bn{i}/relu{i}/pool{i} (i = 1..4), global_pool,
    dropout and fc; these names form the state_dict keys.
    """

    # (out_channels, kernel_size) for each convolutional stage, first to last
    _STAGE_SPECS = ((64, 7), (128, 5), (256, 3), (512, 3))

    def __init__(self, in_channels=12, num_classes=1):
        super(ECG1DCNN, self).__init__()

        # Register every stage's layers in order (conv, bn, relu, pool) under fixed names.
        prev_channels = in_channels
        for stage, (out_channels, kernel) in enumerate(self._STAGE_SPECS, start=1):
            # padding = kernel // 2 keeps the length unchanged through these odd-sized kernels
            setattr(self, f"conv{stage}",
                    nn.Conv1d(prev_channels, out_channels, kernel_size=kernel, padding=kernel // 2))
            setattr(self, f"bn{stage}", nn.BatchNorm1d(out_channels))
            setattr(self, f"relu{stage}", nn.ReLU())
            setattr(self, f"pool{stage}", nn.MaxPool1d(kernel_size=2))
            prev_channels = out_channels

        # Global average pooling over time, followed by the classification head
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(prev_channels, num_classes)

    def _stages(self):
        """Yield the (conv, bn, relu, pool) modules of each stage, first to last."""
        for stage in range(1, len(self._STAGE_SPECS) + 1):
            yield tuple(getattr(self, f"{kind}{stage}") for kind in ("conv", "bn", "relu", "pool"))

    def forward(self, x):
        # Input shape: (batch, in_channels, T)
        for conv, bn, relu, pool in self._stages():
            x = pool(relu(bn(conv(x))))

        # (batch, 512, T') -> (batch, 512, 1) -> (batch, 512)
        x = self.global_pool(x).squeeze(-1)
        return self.fc(self.dropout(x))  # (batch, num_classes) logits

# Build the 12-lead, single-logit model and move its parameters to the configured device
model = ECG1DCNN(in_channels=12, num_classes=1).to(DEVICE)

print(f"Model architecture:\n{model}")
print(f"\n✓ Model initialized on {DEVICE}")

## 9. Define 1D CNN Model

In [None]:
def _make_split_dataset(split_indices):
    """Build an ECGDataset for one split; all splits share the training-set normalization stats."""
    return ECGDataset(
        cleaned_mapping,
        split_indices,
        downsample=DOWNSAMPLE,
        target_col=TARGET_COL,
        channel_mean=global_mean,
        channel_std=global_std,
    )


# Create datasets (built in train, val, test order)
train_ds, val_ds, test_ds = (_make_split_dataset(split) for split in (idx_train, idx_val, idx_test))

# Create data loaders; only the training loader shuffles
train_loader, val_loader, test_loader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE, shuffle=is_train, num_workers=0)
    for split_ds, is_train in ((train_ds, True), (val_ds, False), (test_ds, False))
)

for loader_label, loader in (
    ("Train loader:", train_loader),
    ("Val loader:  ", val_loader),
    ("Test loader: ", test_loader),
):
    print(f"{loader_label} {len(loader)} batches")
print(f"\n✓ DataLoaders created")

## 8. Create DataLoaders

In [None]:
class ECGDataset(Dataset):
    """
    PyTorch Dataset for ECG data that loads from .npy files.

    Each item is a ``(signal, label)`` pair:
      * ``signal``: float32 tensor of shape (12, T), after optional downsampling
        and per-channel ``(x - mean) / std`` normalization.
      * ``label``: 0-d float32 tensor, 1.0 if the target value is > 0, else 0.0
        (missing or unparseable targets map to 0.0).
    """

    def __init__(self, mapping, indices, downsample=None, target_col="_AFIB",
                 channel_mean=None, channel_std=None, fallback_length=2500):
        """
        Args:
            mapping: DataFrame with columns ['record_id', 'ecg_path', target_col]
            indices: Array of positional indices into ``mapping`` to include in this dataset
            downsample: Downsample factor (None or int > 1)
            target_col: Name of the target column
            channel_mean: Per-channel mean for normalization (array-like, shape: 12,)
            channel_std: Per-channel std for normalization (array-like, shape: 12,)
            fallback_length: Length T of the all-zero signal returned when a record
                fails to load. NOTE(review): the default 2500 presumably matches the
                usual downsampled record length — confirm, because the default
                collate function needs every item in a batch to share a shape.
        """
        self.mapping = mapping.iloc[indices].reset_index(drop=True)
        self.downsample = downsample
        self.target_col = target_col
        self.fallback_length = fallback_length

        # Missing stats default to an identity normalization. Coerce to float32 arrays so
        # the [:, None] broadcasting in __getitem__ also works for lists/tuples.
        if channel_mean is None:
            channel_mean = np.zeros(12, dtype=np.float32)
        if channel_std is None:
            channel_std = np.ones(12, dtype=np.float32)
        self.channel_mean = np.asarray(channel_mean, dtype=np.float32)
        self.channel_std = np.asarray(channel_std, dtype=np.float32)

    def __len__(self):
        return len(self.mapping)

    def __getitem__(self, idx):
        row = self.mapping.iloc[idx]
        record_id = str(row["record_id"])
        npy_file = Path(row["ecg_path"])

        # Load ECG data
        try:
            data = np.load(npy_file)

            # Ensure 2D array
            if data.ndim != 2:
                raise ValueError(f"Expected 2D array, got shape {data.shape}")

            # Heuristic: the longer axis is time, so (L, 12) is transposed to (12, L)
            if data.shape[0] > data.shape[1]:
                data = data.T

            # Should now be (12, L)
            if data.shape[0] != 12:
                raise ValueError(f"Expected 12 channels, got {data.shape[0]}")

            # Apply downsampling if specified (plain decimation, no anti-alias filter)
            if self.downsample is not None and self.downsample > 1:
                data = data[:, ::self.downsample]

            # Normalize: (data - mean) / std, broadcasting the (12,) stats over time
            data = (data - self.channel_mean[:, None]) / self.channel_std[:, None]

            # Convert to float32
            data = data.astype(np.float32)

        except Exception as e:
            # Deliberate best-effort: log and substitute zeros so one bad file does not
            # abort the whole epoch.
            print(f"Error loading {record_id}: {e}")
            data = np.zeros((12, self.fallback_length), dtype=np.float32)

        # Build label: any positive numeric value is the positive class. Only the errors
        # float()/row lookup can raise are caught (no bare except, so Ctrl-C still works).
        try:
            tval = float(row[self.target_col])
            label = 1.0 if tval > 0 else 0.0
        except (TypeError, ValueError, KeyError):
            label = 0.0

        # Convert to tensors
        data_tensor = torch.from_numpy(data).float()
        label_tensor = torch.tensor(label, dtype=torch.float32)

        return data_tensor, label_tensor

print("✓ ECGDataset class defined")

## 7. Define ECG Dataset Class

In [None]:
# Normalization statistics come from the training split only, so val/test data never leak in
print("Computing global channel statistics from training data...")
global_mean, global_std = compute_channel_stats(
    mapping=cleaned_mapping,
    indices=idx_train,
    downsample=DOWNSAMPLE,
    max_samples=512,  # number of training records sampled to estimate the statistics
)

print(f"\nGlobal mean (per channel): {global_mean}")
print(f"Global std (per channel):  {global_std}")
print(f"\n✓ Channel statistics computed")

In [None]:
def compute_channel_stats(mapping, indices, downsample=None, max_samples=512, random_state=None):
    """
    Compute global per-channel mean and std from a subset of training data.

    The statistics are pooled over every time step of every loaded record, and
    records may have different lengths. ``std`` is therefore the true standard
    deviation of all values in a channel. An average of per-record stds would
    ignore between-record offsets and underestimate the spread.

    Args:
        mapping: DataFrame with 'ecg_path' column
        indices: Array of positional indices into ``mapping`` to use
        downsample: Downsample factor (None or int > 1)
        max_samples: Maximum number of records to load for statistics
        random_state: Optional seed for choosing the record subset when
            ``len(indices) > max_samples``. None draws from NumPy's global RNG
            (the previous behaviour).

    Returns:
        mean, std: float32 arrays of shape (12,) with per-channel statistics.
        ``std`` has 1e-6 added so it is safe to divide by.

    Raises:
        ValueError: if no record could be loaded.
    """
    # Randomly sample up to max_samples indices
    if len(indices) > max_samples:
        if random_state is None:
            sample_indices = np.random.choice(indices, size=max_samples, replace=False)
        else:
            rng = np.random.default_rng(random_state)
            sample_indices = rng.choice(indices, size=max_samples, replace=False)
    else:
        sample_indices = indices

    # Running pooled moments, merged record by record (Chan et al. parallel update) in
    # float64. This is numerically stable and avoids keeping every record in memory.
    count = 0
    mean = np.zeros(12, dtype=np.float64)
    m2 = np.zeros(12, dtype=np.float64)  # sum of squared deviations from the running mean

    for idx in sample_indices:
        row = mapping.iloc[int(idx)]
        npy_file = Path(row["ecg_path"])

        if not npy_file.exists():
            continue

        try:
            # Load ECG data
            data = np.load(npy_file)

            # Ensure 2D array
            if data.ndim != 2:
                continue

            # If shape is (L, 12), transpose to (12, L)
            if data.shape[0] > data.shape[1]:
                data = data.T

            # Should now be (12, L)
            if data.shape[0] != 12:
                continue

            # Apply downsampling if specified
            if downsample is not None and downsample > 1:
                data = data[:, ::downsample]

            data = data.astype(np.float64)

        except Exception as e:
            print(f"  Warning: Failed to load {npy_file.name}: {e}")
            continue

        n_rec = data.shape[1]
        if n_rec == 0:
            continue

        # Merge this record's per-channel moments into the running totals
        rec_mean = data.mean(axis=1)
        rec_m2 = ((data - rec_mean[:, None]) ** 2).sum(axis=1)
        total = count + n_rec
        delta = rec_mean - mean
        mean = mean + delta * (n_rec / total)
        m2 = m2 + rec_m2 + delta ** 2 * (count * n_rec / total)
        count = total

    if count == 0:
        raise ValueError("No valid data loaded for computing statistics")

    # Population std (ddof=0), matching np.std's default
    std = np.sqrt(m2 / count)

    return mean.astype(np.float32), std.astype(np.float32) + 1e-6

print("✓ Function defined: compute_channel_stats")

## 6. Compute Global Channel Statistics

In [None]:
# Positional row indices plus the label vector used for stratification
indices = np.arange(len(cleaned_mapping))
labels = cleaned_mapping[TARGET_COL].values

# Two stratified splits: 80% train / 20% holdout, then the holdout halved into val and test
# (80/10/10 overall). A fixed random_state keeps the partition reproducible.
idx_train, idx_temp, y_train, y_temp = train_test_split(
    indices, labels, test_size=0.2, stratify=labels, random_state=42
)
idx_val, idx_test, y_val, y_test = train_test_split(
    idx_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42
)

print(f"--- Dataset Splits ---")
for split_prefix, split_idx, split_y in (
    ("Train: ", idx_train, y_train),
    ("Val:   ", idx_val, y_val),
    ("Test:  ", idx_test, y_test),
):
    n_afib = split_y.sum()
    n_non_afib = (split_y == 0).sum()
    print(f"{split_prefix}{len(split_idx):,} samples (AFIB: {n_afib:,}, Non-AFIB: {n_non_afib:,})")
print(f"\n✓ Dataset split complete")

## 5. Split Dataset (Train/Val/Test)

In [None]:
# Use _AFIB as the primary target column
TARGET_COL = "_AFIB"

# Force a strict 0/1 target: unparseable values -> NaN -> 0, then any positive value -> 1
cleaned_mapping[TARGET_COL] = (
    pd.to_numeric(cleaned_mapping[TARGET_COL], errors='coerce')
    .fillna(0)
    .gt(0)
    .astype(int)
)

target_counts = cleaned_mapping[TARGET_COL].value_counts().sort_index()
print(f"Target column: {TARGET_COL}")
print("Target distribution:")
print(target_counts)
print("\n✓ Labels prepared for binary AFIB detection")

## 4. Prepare Labels for Binary AFIB Detection

In [None]:
# Load the cleaned mapping produced upstream
cleaned_mapping = pd.read_csv(MAPPING_CSV)

print(f"Loaded {len(cleaned_mapping):,} records from {MAPPING_CSV.name}")
print(f"\nColumns: {list(cleaned_mapping.columns)}")

# Fail fast if any column the rest of the notebook relies on is absent
required_cols = ['record_id', 'ecg_path', '_AFIB', '_SR']
present_cols = set(cleaned_mapping.columns)
missing_cols = [col for col in required_cols if col not in present_cols]
assert not missing_cols, f"Missing required columns: {missing_cols}"
print(f"✓ All required columns present: {required_cols}")

# Basic statistics. The "=1" counts are column sums, so they assume 0/1 flags here
# (labels are only coerced to strict 0/1 in the next section).
afib_flags = cleaned_mapping['_AFIB']
sr_flags = cleaned_mapping['_SR']
print(f"\n--- Dataset Statistics ---")
print(f"Total records: {len(cleaned_mapping):,}")
print(f"AFIB records (_AFIB=1): {afib_flags.sum():,}")
print(f"Non-AFIB records (_AFIB=0): {(afib_flags == 0).sum():,}")
print(f"SR records (_SR=1): {sr_flags.sum():,}")
print(f"Non-SR records (_SR=0): {(sr_flags == 0).sum():,}")

# Display sample rows (rich display via the cell's final expression)
print(f"\nSample rows:")
cleaned_mapping.head()

## 3. Load Cleaned Mapping & Verify Data

In [None]:
# Hyperparameters
DOWNSAMPLE = 2  # Downsample factor (e.g., 2 = keep every 2nd sample)
BATCH_SIZE = 32
LEARNING_RATE = 0.001
NUM_EPOCHS = 30
SEED = 42  # Global seed so a Restart & Run All reproduces the same run
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs the notebook draws from later: NumPy's (np.random.choice when
# subsampling records for channel statistics) and torch's (weight init, DataLoader
# shuffling, dropout). NOTE(review): GPU runs may still differ slightly unless cuDNN
# determinism is also enforced.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Paths (all relative to project root)
PROJECT_ROOT = Path.cwd()
DATA_DIR = PROJECT_ROOT / "data"
CLEANED_ROOT = DATA_DIR / "cleaned_balanced_AFIB_SR"
CLEANED_WFDB_DIR = CLEANED_ROOT / "WFDBRecords"
MAPPING_CSV = CLEANED_ROOT / "file_mapping_cleaned.csv"

# Model save path
MODEL_DIR = PROJECT_ROOT / "models"
MODEL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_PATH = MODEL_DIR / "ecg_1dcnn_afib_balanced.pth"

print(f"Device: {DEVICE}")
print(f"MAPPING_CSV: {MAPPING_CSV}")
print(f"MODEL_PATH: {MODEL_PATH}")
print(f"\n✓ Configuration complete")

## 2. Configuration & Path Setup

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix, classification_report

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import seaborn as sns

print("✓ All imports successful")

## 1. Imports

# AFIB Detection Using 1D CNN on Cleaned ECG Dataset

This notebook trains a binary atrial fibrillation (AFIB) detector on the cleaned, class-balanced dataset produced by `data_handling.ipynb`.

**Workflow:**
1. Set up paths and load the cleaned mapping
2. Verify data and prepare labels
3. Split dataset (train/val/test)
4. Compute global channel statistics
5. Define ECG dataset class
6. Create data loaders
7. Define and train 1D CNN model
8. Evaluate on the test set and visualize results (training curves, confusion matrix)