# ECG CNN Training

This notebook implements an advanced 1D CNN for ECG age group classification with the following features:

- **Class imbalance handling**: Focal loss, class weighting, data augmentation
- **Advanced architecture**: ResNet blocks, attention mechanisms
- **Comprehensive regularization**: Dropout, batch normalization, gradient clipping
- **Direct WFDB parsing**: From .dat/.hea files

## Table of Contents
1. [Data Loading](#data-loading)
2. [Data Preprocessing](#data-preprocessing)
3. [Model Architecture](#model-architecture)
4. [Hyperparameter Configuration](#hyperparameter-configuration)
5. [Model Training](#model-training)
6. [Model Evaluation](#model-evaluation)
7. [Visualization and Analysis](#visualization-and-analysis)


## 1. Data Loading

Load ECG data using the WFDB parser with advanced preprocessing.


In [8]:
# Import required libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler, Subset
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
from sklearn.utils.class_weight import compute_class_weight
import random
from collections import Counter
import warnings
import sys
import os
import time
from sklearn.model_selection import GroupShuffleSplit
# Add parent directory to path to import wfdb_parser
sys.path.append('..')
from wfdb_parser import create_wfdb_dataset
warnings.filterwarnings('ignore')

# Configuration
# Root folder of the WFDB records; the subject metadata CSV lives inside it, so
# its path is derived from DATA_PATH instead of repeating the long literal.
DATA_PATH = "../input/autonomic-aging-a-dataset-to-quantify-changes-of-cardiovascular-autonomic-function-during-healthy-aging-1.0.0"
SUBJECT_INFO_CSV = f"{DATA_PATH}/subject-info.csv"
# All figures and the best-model checkpoint are written here.
OUTPUT_DIR = "./ecg_cnn_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Dataset parameters (passed to create_wfdb_dataset)
WINDOW_SIZE_SEC = 10  # window length in seconds
WINDOW_STEP_SEC = 5   # step between consecutive windows in seconds
PRELOAD_DATA = True  # Optimized for 16GB RAM

# Training parameters
RANDOM_STATE = 42
TEST_SPLIT = 0.2  # fraction handed to the group-wise test split
BATCH_SIZE = 128  # Optimized for M4 Pro GPU
EPOCHS = 50
LR = 1e-3
# Device detection with M4 Pro GPU support
def get_device():
    """Pick the compute device for this session.

    Preference order: Apple MPS, then CUDA, then CPU.

    Returns:
        torch.device: the first backend in that order reported as available.
    """
    if torch.backends.mps.is_available():
        backend = "mps"  # M4 Pro GPU
    elif torch.cuda.is_available():
        backend = "cuda"  # NVIDIA GPU
    else:
        backend = "cpu"  # CPU fallback
    return torch.device(backend)

# Resolve the compute device once; later cells move the model and batches to it.
DEVICE = get_device()
print(f"Using device: {DEVICE}")

# The earlier revision assigned `torch.backends.mps.allow_tf32` and
# `torch.backends.mps.allow_fp16`. The MPS backend module exposes no such
# switches, so those assignments merely attached unused attributes to the module
# while the message claimed optimizations were enabled. Report the backend
# truthfully instead of pretending to tune it.
if DEVICE.type == "mps":
    print("Apple Silicon GPU (MPS backend) in use")

# Seed every random number generator the notebook draws from.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


Using device: mps
M4 Pro GPU optimizations enabled


In [4]:
# Load dataset from WFDB files
print("Loading dataset from WFDB files...")
dataset = create_wfdb_dataset(
    data_path=DATA_PATH,
    subject_info_csv=SUBJECT_INFO_CSV,
    window_size_sec=WINDOW_SIZE_SEC,
    window_step_sec=WINDOW_STEP_SEC,
    augment=False,  # We'll handle augmentation in the training loop
    preload=PRELOAD_DATA
)

# Train/test split: one group label (the record name) per window.
groups = [sample['record_name'] for sample in dataset.samples]
# Array view of the same labels, built once for the fancy indexing below.
group_array = np.array(groups)

# Use GroupShuffleSplit instead of random_split: every window of a given
# record_name ends up EITHER in train OR in test, never in both.
# The seed comes from the RANDOM_STATE config constant rather than a literal.
gss = GroupShuffleSplit(n_splits=1, test_size=TEST_SPLIT, random_state=RANDOM_STATE)
train_idx, test_idx = next(gss.split(X=range(len(dataset)), groups=groups))

# Build the Subsets from the index arrays returned by the splitter.
train_dataset = Subset(dataset, train_idx)
test_dataset = Subset(dataset, test_idx)

# Leakage check: the sets of record names on each side must be disjoint.
# NOTE(review): grouping is by record_name; the messages below speak of
# "patient" -- confirm each subject has exactly one record.
train_patients = set(group_array[train_idx])
test_patients = set(group_array[test_idx])
intersection = train_patients.intersection(test_patients)

print(f"--- LEAKAGE CHECK ---")
print(f"Unique records in Train: {len(train_patients)}")
print(f"Unique records in Test: {len(test_patients)}")
print(f"Overlapping records (MUST BE 0): {len(intersection)}")
if len(intersection) > 0:
    raise ValueError("CRITICAL ERROR: Data leakage detected! Same patient in Train and Test.")
else:
    print("SUCCESS: Data split is clean based on Record ID.")
print("---------------------")

# Create a wrapper for the training dataset that enables augmentation
class AugmentedDataset:
    """View over selected indices of a base dataset with augmentation switched on.

    The earlier version set ``base_dataset.augment = True`` in the constructor.
    Because the same base object also backs the test ``Subset``, that turned
    augmentation on for evaluation samples too. The flag is now raised only for
    the duration of each item fetch and restored afterwards.

    Assumes the base dataset consults its ``augment`` attribute at item-fetch
    time -- TODO confirm against wfdb_parser.

    Args:
        base_dataset: indexable dataset exposing a writable ``augment`` attribute.
        indices: positions in ``base_dataset`` that make up this view.
    """

    def __init__(self, base_dataset, indices):
        self.base_dataset = base_dataset
        self.indices = indices

    def __len__(self):
        """Number of samples in this view."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Fetch sample ``idx`` of the view with augmentation enabled."""
        previous = getattr(self.base_dataset, "augment", False)
        self.base_dataset.augment = True
        try:
            return self.base_dataset[self.indices[idx]]
        finally:
            # Restore the flag so other views of the same base stay unaugmented.
            self.base_dataset.augment = previous

# Wrap the training indices so training samples are served through AugmentedDataset
train_dataset_aug = AugmentedDataset(dataset, train_dataset.indices)

# Model dimensions are read from the dataset object
n_classes = len(dataset.classes_)
n_channels = dataset.max_channels
print(f"Number of classes: {n_classes}, Channels: {n_channels}")
print(f"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")
print(f"Augmented train samples: {len(train_dataset_aug)}")
print(f"Class distribution: {dataset.get_class_distribution()}")

# Debug: Check if indices are valid
print(f"Dataset total samples: {len(dataset)}")
for split_name, split_subset in (("Train", train_dataset), ("Test", test_dataset)):
    lowest, highest = min(split_subset.indices), max(split_subset.indices)
    print(f"{split_name} indices range: {lowest} to {highest}")


Loading dataset from WFDB files...
Error processing record 0400: [Errno 2] No such file or directory: '/Users/dmytro/Diploma/ecg_ml_analysis/v2/input/autonomic-aging-a-dataset-to-quantify-changes-of-cardiovascular-autonomic-function-during-healthy-aging-1.0.0/0400.dat'
Preloading data...
Preloaded 0/234943 samples
Preloaded 1000/234943 samples
Preloaded 2000/234943 samples
Preloaded 3000/234943 samples
Preloaded 4000/234943 samples
Preloaded 5000/234943 samples
Preloaded 6000/234943 samples
Preloaded 7000/234943 samples
Preloaded 8000/234943 samples
Preloaded 9000/234943 samples
Preloaded 10000/234943 samples
Preloaded 11000/234943 samples
Preloaded 12000/234943 samples
Preloaded 13000/234943 samples
Preloaded 14000/234943 samples
Preloaded 15000/234943 samples
Preloaded 16000/234943 samples
Preloaded 17000/234943 samples
Preloaded 18000/234943 samples
Preloaded 19000/234943 samples
Preloaded 20000/234943 samples
Preloaded 21000/234943 samples
Preloaded 22000/234943 samples
Preloaded 2

## 2. Data Preprocessing

Handle class imbalance and create weighted sampling.


In [5]:
# Compute class weights for handling imbalance.
# Encode every window's age group with a single transform call instead of one
# call per sample; the result is kept as a list, as before.
all_labels = list(dataset.le.transform([sample['age_group'] for sample in dataset.samples]))

class_counts = Counter(all_labels)
class_weights = compute_class_weight('balanced', classes=np.unique(all_labels), y=all_labels)
class_weights = torch.FloatTensor(class_weights).to(DEVICE)

print(f"Class distribution: {dict(class_counts)}")
print(f"Class weights: {class_weights}")

# Use weighted sampler for training (only for training indices).
# The per-sample weights are gathered on the CPU in one indexing operation;
# previously each sample indexed the device tensor separately, producing one
# tiny device tensor per training window. Values are unchanged.
# Indexing by label assumes the encoded labels are 0..K-1 (as the original
# `class_weights[label]` lookup did).
train_labels = [all_labels[i] for i in train_dataset.indices]
class_weights_cpu = class_weights.cpu().to(torch.double)
sample_weights = class_weights_cpu[torch.as_tensor(np.asarray(train_labels), dtype=torch.long)]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Create data loaders (use num_workers=0 to avoid multiprocessing issues on macOS)
train_loader = DataLoader(train_dataset_aug, batch_size=BATCH_SIZE, sampler=sampler, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"‚úÖ Data preprocessing completed")
print(f"  ‚Ä¢ Training samples: {len(train_dataset_aug)}")
print(f"  ‚Ä¢ Test samples: {len(test_dataset)}")
print(f"  ‚Ä¢ Class weights computed: {len(class_weights)} classes")


Class distribution: {np.int64(1): 87555, np.int64(6): 11694, np.int64(3): 22515, np.int64(2): 49820, np.int64(0): 11248, np.int64(8): 3483, np.int64(7): 10842, np.int64(11): 2328, np.int64(4): 10506, np.int64(9): 5329, np.int64(10): 3689, np.int64(5): 11474, np.int64(12): 1169, np.int64(14): 1253, np.int64(13): 2038}
Class weights: tensor([ 1.3925,  0.1789,  0.3144,  0.6957,  1.4908,  1.3651,  1.3394,  1.4446,
         4.4969,  2.9392,  4.2458,  6.7280, 13.3985,  7.6854, 12.5003],
       device='mps:0')
‚úÖ Data preprocessing completed
  ‚Ä¢ Training samples: 187591
  ‚Ä¢ Test samples: 47352
  ‚Ä¢ Class weights computed: 15 classes


## 3. Model Architecture

Define the advanced CNN architecture with ResNet blocks and attention mechanisms.


In [11]:
# Advanced CNN Architecture Components

class ResidualBlock(nn.Module):
    """Two-convolution 1D residual block with batch norm and channel dropout.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        kernel_size: convolution kernel width; must be odd so that
            ``padding=kernel_size//2`` keeps the main path the same length as
            the skip path.
        stride: stride of the first convolution and of the skip projection.
        dropout_prob: drop probability for the ``nn.Dropout1d`` layer.

    Raises:
        ValueError: if ``kernel_size`` is even. Each convolution would then
            lengthen the sequence by one sample and the residual addition
            would fail with a shape mismatch at forward time.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dropout_prob=0.3):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be odd to keep the residual shapes aligned, got {kernel_size}"
            )
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding=kernel_size//2)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, 1, padding=kernel_size//2)
        self.bn2 = nn.BatchNorm1d(out_channels)

        # Spatial dropout (Dropout1d) is more effective for time series
        self.dropout = nn.Dropout1d(p=dropout_prob)

        # Skip connection: identity unless the shape changes, in which case a
        # strided 1x1 convolution projects the input to the new shape.
        self.skip = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.skip = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 1, stride),
                nn.BatchNorm1d(out_channels)
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + skip(x))``."""
        residual = self.skip(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.dropout(out)  # dropout after the activation

        out = self.conv2(out)
        out = self.bn2(out)

        out += residual
        out = F.relu(out)
        return out

class AttentionBlock(nn.Module):
    """Channel gating: rescales each channel by a sigmoid weight in (0, 1).

    The gate is computed from the time-averaged input through a two-layer 1x1
    convolution bottleneck.

    Args:
        channels: number of input (and output) channels.
        reduction: bottleneck ratio; the hidden width is
            ``max(1, channels // reduction)``. The default of 4 matches the
            previously hard-coded ``channels//4``, which evaluated to 0 for
            fewer than four channels.
    """

    def __init__(self, channels, reduction=4):
        super().__init__()
        self.channels = channels
        hidden = max(1, channels // reduction)  # never collapse to zero channels
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, hidden, 1),
            nn.ReLU(),
            nn.Conv1d(hidden, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel gate (broadcast over time)."""
        att = self.attention(x)
        return x * att

class AdvancedCNN1D(nn.Module):
    """1D CNN classifier: stem conv, three residual+attention stages, linear head.

    Args:
        in_channels: channels of the input signal.
        n_classes: number of output logits.
        base_filters: width of the stem and first stage; stages two and three
            use 2x and 4x this value. Defaults to 32 (thinned down from the
            earlier 64 to reduce the parameter count).
    """

    def __init__(self, in_channels, n_classes, base_filters=32):
        super().__init__()

        self.initial_conv = nn.Sequential(
            nn.Conv1d(in_channels, base_filters, kernel_size=7, padding=3),
            nn.BatchNorm1d(base_filters),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        # Stage 1
        self.res_block1 = ResidualBlock(base_filters, base_filters, dropout_prob=0.2)
        self.attention1 = AttentionBlock(base_filters)
        self.pool1 = nn.MaxPool1d(2)

        # Stage 2 (filters: base -> 2x base)
        self.res_block2 = ResidualBlock(base_filters, base_filters*2, stride=2, dropout_prob=0.3)
        self.attention2 = AttentionBlock(base_filters*2)
        self.pool2 = nn.MaxPool1d(2)

        # Stage 3 (filters: 2x base -> 4x base); with the default this caps the
        # network at 128 channels instead of the earlier 256
        self.res_block3 = ResidualBlock(base_filters*2, base_filters*4, stride=2, dropout_prob=0.4)
        self.attention3 = AttentionBlock(base_filters*4)
        self.pool3 = nn.MaxPool1d(2)

        # Global pooling and classification
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # Final dropout raised to 0.6
        self.dropout = nn.Dropout(0.6)
        self.fc = nn.Linear(base_filters*4, n_classes)

    def forward(self, x):
        """Map a batch of signals to class logits of shape (batch, n_classes)."""
        x = self.initial_conv(x)

        x = self.res_block1(x)
        x = self.attention1(x)
        x = self.pool1(x)

        x = self.res_block2(x)
        x = self.attention2(x)
        x = self.pool2(x)

        x = self.res_block3(x)
        x = self.attention3(x)
        x = self.pool3(x)

        x = self.global_pool(x)
        x = x.view(x.size(0), -1)  # drop the length-1 time axis left by the pool
        x = self.dropout(x)
        x = self.fc(x)
        return x

print("‚úÖ Advanced CNN architecture defined")


‚úÖ Advanced CNN architecture defined


## 4. Hyperparameter Configuration

Set up loss functions, optimizer, and training parameters.


In [None]:
# Focal Loss for handling class imbalance
class FocalLoss(nn.Module):
    """Focal loss: cross-entropy scaled by ``alpha * (1 - p_t) ** gamma``.

    Args:
        alpha: either a scalar factor applied to every sample (original
            behaviour), or a per-class sequence/tensor of weights, in which
            case each sample is scaled by the weight of its target class.
        gamma: focusing exponent; ``gamma=0`` reduces to (alpha-weighted)
            cross-entropy.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-sample loss.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        if isinstance(alpha, (int, float)):
            self.alpha = alpha
        else:
            # Registered as a buffer so `.to(device)` on the criterion moves it.
            self.register_buffer('alpha', torch.as_tensor(alpha, dtype=torch.float32).detach().clone())
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the loss from raw logits ``inputs`` and class-index ``targets``."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        alpha = self.alpha
        if torch.is_tensor(alpha) and alpha.dim() > 0:
            # Per-class weights: pick the weight of each sample's target class.
            alpha = alpha.to(device=ce_loss.device, dtype=ce_loss.dtype)[targets]
        focal_loss = alpha * (1-pt)**self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Initialize model
model = AdvancedCNN1D(n_channels, n_classes).to(DEVICE)

# Aggressive regularization in the optimizer:
# weight decay raised from 1e-4 to 0.01 (1e-2). This keeps the weights small,
# which reduces overfitting. Held in one constant so the optimizer and the
# printed configuration cannot drift apart.
WEIGHT_DECAY = 0.01
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

# The scheduler is kept as is, though patience could be increased.
# mode='max' because it is stepped with the validation accuracy.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

criterion = FocalLoss(alpha=1, gamma=2)

print("üìã Hyperparameter Configuration:")
print(f"  ‚Ä¢ Model: Advanced CNN with ResNet blocks")
print(f"  ‚Ä¢ Regularization: Spatial Dropout inside blocks + High Weight Decay")
print(f"  ‚Ä¢ Loss function: Focal Loss (alpha=1, gamma=2)")
print(f"  ‚Ä¢ Optimizer: AdamW (lr={LR}, weight_decay={WEIGHT_DECAY})")
print(f"  ‚Ä¢ Scheduler: ReduceLROnPlateau")
print(f"  ‚Ä¢ Batch size: {BATCH_SIZE}")
print(f"  ‚Ä¢ Epochs: {EPOCHS}")
print(f"  ‚Ä¢ Classes: {n_classes}")
print(f"  ‚Ä¢ Channels: {n_channels}")

# Count parameters; the size estimate assumes 4 bytes per parameter
total_params = sum(p.numel() for p in model.parameters())
print(f"  ‚Ä¢ Total parameters: {total_params:,}")
print(f"  ‚Ä¢ Model size: {total_params * 4 / 1024 / 1024:.2f} MB")


üìã Hyperparameter Configuration:
  ‚Ä¢ Model: Advanced CNN with ResNet blocks
  ‚Ä¢ Loss function: Focal Loss (alpha=1, gamma=2)
  ‚Ä¢ Optimizer: AdamW (lr=0.001, weight_decay=1e-4)
  ‚Ä¢ Scheduler: ReduceLROnPlateau
  ‚Ä¢ Batch size: 128
  ‚Ä¢ Epochs: 50
  ‚Ä¢ Classes: 15
  ‚Ä¢ Channels: 3
  ‚Ä¢ Total parameters: 486,975
  ‚Ä¢ Model size: 1.86 MB


## 5. Model Training

Train the CNN with advanced techniques including early stopping and gradient clipping.


In [10]:
# Advanced training with early stopping.
# NOTE(review): `test_loader` serves as the validation set here, so LR
# scheduling, early stopping and checkpoint selection are all driven by the
# test data; the "best validation accuracy" is therefore an optimistic
# estimate. Consider carving a validation split out of the training records.
train_losses, val_losses, train_accs, val_accs = [], [], [], []
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint; with 0 a run that never exceeded 0 accuracy would silently load a
# stale best_model.pth left by an earlier run.
best_val_acc = float("-inf")
patience = 20
patience_counter = 0
best_model_path = os.path.join(OUTPUT_DIR, 'best_model.pth')
checkpoint_saved = False  # True once THIS run has written best_model_path

print(f"\nüéØ Training for {EPOCHS} epochs...")
start_time = time.time()

for epoch in range(1, EPOCHS+1):
    # Training
    print(f"Training epoch {epoch}/{EPOCHS}")
    if DEVICE.type == "mps":
        print(f"M4 Pro GPU memory allocated: {torch.mps.current_allocated_memory() / 1024**3:.2f} GB")
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for xb, yb in train_loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        optimizer.zero_grad()
        out = model(xb)
        loss = criterion(out, yb)
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        # Weight the batch-mean loss by batch size to get a per-sample average
        running_loss += loss.item() * xb.size(0)
        preds = out.argmax(dim=1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)

    train_losses.append(running_loss/total)
    train_accs.append(correct/total)

    # Validation
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            out = model(xb)
            loss = criterion(out, yb)
            running_loss += loss.item() * xb.size(0)
            preds = out.argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += yb.size(0)

    val_losses.append(running_loss/total)
    val_accs.append(correct/total)

    # Learning rate scheduling
    scheduler.step(val_accs[-1])

    # Early stopping
    if val_accs[-1] > best_val_acc:
        best_val_acc = val_accs[-1]
        patience_counter = 0
        # Save best model
        torch.save(model.state_dict(), best_model_path)
        checkpoint_saved = True
    else:
        patience_counter += 1

    print(f"Epoch {epoch}/{EPOCHS} | Train Acc: {train_accs[-1]:.3f} | Val Acc: {val_accs[-1]:.3f} | Best: {best_val_acc:.3f}")

    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch}")
        break

# Load best model, but only a checkpoint produced by this run; map_location
# keeps the load valid even if the file was written on another device.
if checkpoint_saved:
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
else:
    print("No checkpoint was written in this run; keeping the current model weights")

training_time = time.time() - start_time
print(f"\n‚è±Ô∏è Training completed in {training_time:.2f} seconds")
print(f"üèÜ Best validation accuracy: {best_val_acc:.4f}")



üéØ Training for 50 epochs...
Training epoch 1/50
M4 Pro GPU memory allocated: 0.00 GB
Epoch 1/50 | Train Acc: 0.592 | Val Acc: 0.221 | Best: 0.221
Training epoch 2/50
M4 Pro GPU memory allocated: 0.02 GB
Epoch 2/50 | Train Acc: 0.833 | Val Acc: 0.239 | Best: 0.239
Training epoch 3/50
M4 Pro GPU memory allocated: 0.02 GB
Epoch 3/50 | Train Acc: 0.902 | Val Acc: 0.280 | Best: 0.280
Training epoch 4/50
M4 Pro GPU memory allocated: 0.02 GB
Epoch 4/50 | Train Acc: 0.933 | Val Acc: 0.301 | Best: 0.301
Training epoch 5/50
M4 Pro GPU memory allocated: 0.02 GB
Epoch 5/50 | Train Acc: 0.951 | Val Acc: 0.301 | Best: 0.301
Training epoch 6/50
M4 Pro GPU memory allocated: 0.02 GB
Epoch 6/50 | Train Acc: 0.963 | Val Acc: 0.289 | Best: 0.301
Training epoch 7/50
M4 Pro GPU memory allocated: 0.02 GB
Epoch 7/50 | Train Acc: 0.971 | Val Acc: 0.291 | Best: 0.301
Training epoch 8/50
M4 Pro GPU memory allocated: 0.02 GB
Epoch 8/50 | Train Acc: 0.977 | Val Acc: 0.276 | Best: 0.301
Training epoch 9/50
M4 P

KeyboardInterrupt: 

## 6. Model Evaluation

Comprehensive evaluation with detailed metrics and analysis.


In [9]:
# Comprehensive evaluation
model.eval()
y_true, y_pred = [], []
probs_list = []

with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        out = model(xb)
        # Keep only the winning-class probability as a confidence score
        probs = F.softmax(out, dim=1)
        max_probs, _ = probs.max(1)
        probs_list.append(max_probs.cpu().numpy())

        preds = out.argmax(dim=1)
        y_true.append(yb.cpu().numpy())
        y_pred.append(preds.cpu().numpy())

y_true = np.concatenate(y_true)
y_pred = np.concatenate(y_pred)
probs_all = np.concatenate(probs_list)

# Comprehensive metrics
acc = accuracy_score(y_true, y_pred)
f1_macro = f1_score(y_true, y_pred, average='macro')
f1_weighted = f1_score(y_true, y_pred, average='weighted')

# Pin the full label set: with a group-wise split a rare class may be missing
# from the test data, and without `labels=` the report would reject the
# target_names list while the confusion matrix would shrink below the number
# of tick labels used when plotting it.
class_ids = np.arange(len(dataset.le.classes_))
class_names = [f'Age_{c}' for c in dataset.le.classes_]

print(f"\n=== COMPREHENSIVE EVALUATION ===")
print(f"Test Accuracy: {acc:.4f}")
print(f"F1-Score (Macro): {f1_macro:.4f}")
print(f"F1-Score (Weighted): {f1_weighted:.4f}")
print("\nDetailed Classification Report:\n")
print(classification_report(y_true, y_pred, labels=class_ids, digits=4,
                            target_names=class_names, zero_division=0))
cm = confusion_matrix(y_true, y_pred, labels=class_ids)

print("‚úÖ Model evaluation completed")



=== COMPREHENSIVE EVALUATION ===
Test Accuracy: 0.9994
F1-Score (Macro): 0.9995
F1-Score (Weighted): 0.9994

Detailed Classification Report:

              precision    recall  f1-score   support

     Age_1.0     0.9991    1.0000    0.9996      2242
     Age_2.0     0.9997    0.9993    0.9995     17409
     Age_3.0     0.9991    0.9989    0.9990     10026
     Age_4.0     0.9991    0.9998    0.9994      4453
     Age_5.0     1.0000    0.9995    0.9998      2108
     Age_6.0     0.9983    0.9991    0.9987      2315
     Age_7.0     0.9987    1.0000    0.9994      2367
     Age_8.0     0.9991    1.0000    0.9995      2167
     Age_9.0     1.0000    1.0000    1.0000       697
    Age_10.0     1.0000    1.0000    1.0000      1098
    Age_11.0     1.0000    1.0000    1.0000       748
    Age_12.0     1.0000    0.9958    0.9979       478
    Age_13.0     1.0000    1.0000    1.0000       230
    Age_14.0     1.0000    1.0000    1.0000       412
    Age_15.0     1.0000    1.0000    1.0000   

## 7. Visualization and Analysis

Generate comprehensive visualizations for training analysis and model performance.


In [10]:
# Generate comprehensive plots
sns.set(style="whitegrid")

# 1) Training curves (loss and accuracy), test class distribution, confidence
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes[0,0].plot(train_losses, label="Train Loss")
axes[0,0].plot(val_losses, label="Val Loss")
axes[0,0].set_xlabel("Epoch"); axes[0,0].set_ylabel("Loss")
axes[0,0].set_title("Loss Curves"); axes[0,0].legend(); axes[0,0].grid(True)

axes[0,1].plot(train_accs, label="Train Acc")
axes[0,1].plot(val_accs, label="Val Acc")
axes[0,1].set_xlabel("Epoch"); axes[0,1].set_ylabel("Accuracy")
axes[0,1].set_title("Accuracy Curves"); axes[0,1].legend(); axes[0,1].grid(True)

# Class distribution.
# Counter preserves first-seen order, so the counts must be read out in the
# same sorted order as the tick labels; previously the bar heights followed
# insertion order while the labels were sorted, mislabeling the bars.
class_dist = Counter(y_true)
present_classes = sorted(class_dist.keys())
sorted_counts = [class_dist[c] for c in present_classes]
axes[1,0].bar(range(len(present_classes)), sorted_counts)
axes[1,0].set_xlabel("Age Group"); axes[1,0].set_ylabel("Count")
axes[1,0].set_title("Test Set Class Distribution")
axes[1,0].set_xticks(range(len(present_classes)))
# Label with the original class names, consistent with the other figures
axes[1,0].set_xticklabels([f'Age_{dataset.le.classes_[c]}' for c in present_classes], rotation=45)

# Prediction confidence
axes[1,1].hist(probs_all, bins=20, color="green", alpha=0.7)
axes[1,1].set_xlabel("Max Predicted Probability")
axes[1,1].set_ylabel("Count")
axes[1,1].set_title("Prediction Confidence Distribution")
axes[1,1].grid(True)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "training_analysis.png"))
plt.close()

print("‚úÖ Training analysis plot saved")


‚úÖ Training analysis plot saved


In [11]:
# 2) Enhanced confusion matrix
# Rows are true classes, columns are predicted classes; both axes share labels.
age_labels = [f'Age_{c}' for c in dataset.le.classes_]
cm_fig, cm_ax = plt.subplots(figsize=(12, 10))
sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=age_labels,
    yticklabels=age_labels,
    ax=cm_ax,
)
cm_ax.set_xlabel("Predicted Age Group")
cm_ax.set_ylabel("True Age Group")
cm_ax.set_title(f"Confusion Matrix\nAccuracy: {acc:.3f} | F1-Macro: {f1_macro:.3f} | F1-Weighted: {f1_weighted:.3f}")
cm_fig.tight_layout()
cm_fig.savefig(os.path.join(OUTPUT_DIR, "confusion_matrix.png"))
plt.close(cm_fig)

print("‚úÖ Confusion matrix saved")


‚úÖ Confusion matrix saved


In [12]:
# 3) Per-class performance analysis
# For each class: the share of its true samples predicted correctly, and a
# one-vs-rest F1 score. Classes absent from y_true get 0.0 for both.
per_class_acc, per_class_f1 = [], []
for class_id in range(n_classes):
    is_class = (y_true == class_id)
    if not is_class.any():
        per_class_acc.append(0.0)
        per_class_f1.append(0.0)
        continue
    hit_rate = (y_pred[is_class] == y_true[is_class]).mean()
    # Binarize both vectors (this class vs. all others) for the F1 computation
    one_vs_rest_f1 = f1_score(
        (y_true == class_id).astype(int),
        (y_pred == class_id).astype(int),
        zero_division=0,
    )
    per_class_acc.append(hit_rate)
    per_class_f1.append(one_vs_rest_f1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
x_pos = range(len(dataset.le.classes_))
tick_names = [f'Age_{c}' for c in dataset.le.classes_]
# Both panels share the same layout; only data, colour and captions differ.
panel_specs = (
    (ax1, per_class_acc, "skyblue", "Accuracy", "Per-class Accuracy"),
    (ax2, per_class_f1, "lightcoral", "F1-Score", "Per-class F1-Score"),
)
for panel_ax, heights, bar_color, y_caption, heading in panel_specs:
    panel_ax.bar(x_pos, heights, color=bar_color, alpha=0.7)
    panel_ax.set_xlabel("Age Group")
    panel_ax.set_ylabel(y_caption)
    panel_ax.set_title(heading)
    panel_ax.set_xticks(x_pos)
    panel_ax.set_xticklabels(tick_names, rotation=45)
    panel_ax.grid(axis='y')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "per_class_performance.png"))
plt.close()

print("‚úÖ Per-class performance analysis saved")


‚úÖ Per-class performance analysis saved


In [13]:
# 4) Class imbalance analysis
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# True vs Predicted distribution.
# Use explicit bin edges centred on each integer class id. With `bins=n` the
# edges are derived from each array's own min/max, so the two histograms end
# up on different bins whenever y_pred spans a different range than y_true.
n_groups = len(dataset.le.classes_)
bin_edges = np.arange(n_groups + 1) - 0.5
ax1.hist(y_true, bins=bin_edges, alpha=0.7, label="True", color="blue")
ax1.hist(y_pred, bins=bin_edges, alpha=0.7, label="Predicted", color="red")
ax1.set_xlabel("Age Group"); ax1.set_ylabel("Count")
ax1.set_title("True vs Predicted Distribution")
ax1.legend(); ax1.grid(True)

# Class weights visualization
class_weights_np = class_weights.cpu().numpy()
ax2.bar(range(len(class_weights_np)), class_weights_np, color="orange", alpha=0.7)
ax2.set_xlabel("Age Group"); ax2.set_ylabel("Class Weight")
ax2.set_title("Computed Class Weights")
ax2.set_xticks(range(n_groups))
ax2.set_xticklabels([f'Age_{c}' for c in dataset.le.classes_], rotation=45)
ax2.grid(axis='y')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "class_imbalance_analysis.png"))
plt.close()

print("‚úÖ Class imbalance analysis saved")


‚úÖ Class imbalance analysis saved


In [14]:
# 5) Sample ECG signals with predictions
# Put the model in inference mode here rather than relying on an earlier cell,
# so dropout and batch-norm behave deterministically if cells run out of order.
model.eval()
plt.figure(figsize=(15, 10))
# Draw at most 6 examples; sampling without replacement would raise if the
# test set held fewer than the requested count.
n_examples = min(6, len(test_dataset))
sample_indices = np.random.choice(len(test_dataset), n_examples, replace=False)
for i, idx in enumerate(sample_indices):
    xb, yb = test_dataset[idx]
    xb_numpy = xb.cpu().numpy()  # Keep numpy version for plotting
    y_true_sample = dataset.le.inverse_transform([yb.item()])[0]
    with torch.no_grad():
        # Use the original tensor for model inference (add a batch dimension)
        out = model(xb.unsqueeze(0).to(DEVICE))
        pred_label = dataset.le.inverse_transform([out.argmax(1).item()])[0]
        confidence = torch.softmax(out, dim=1).max().item()

    plt.subplot(3, 2, i+1)
    plt.plot(xb_numpy.T, alpha=0.7)  # plot all channels
    plt.title(f"True: Age_{y_true_sample} | Pred: Age_{pred_label} | Conf: {confidence:.3f}")
    plt.xlabel("Time"); plt.ylabel("Amplitude")
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "sample_ecg_predictions.png"))
plt.close()

print("‚úÖ Sample ECG predictions saved")


‚úÖ Sample ECG predictions saved


In [15]:
# 6) Model architecture visualization
# A text-only figure: each entry is (vertical position, text, text styling).
heading_style = dict(fontsize=16, fontweight='bold')
subheading_style = dict(fontsize=14, fontweight='bold')
body_style = dict(fontsize=12)
summary_rows = [
    (0.9, "Advanced CNN Architecture:", heading_style),
    (0.8, "‚Ä¢ Residual Blocks with Skip Connections", body_style),
    (0.75, "‚Ä¢ Attention Mechanisms for Feature Selection", body_style),
    (0.7, "‚Ä¢ Batch Normalization and Dropout", body_style),
    (0.65, "‚Ä¢ Focal Loss for Class Imbalance", body_style),
    (0.6, "‚Ä¢ Weighted Random Sampling", body_style),
    (0.55, "‚Ä¢ Learning Rate Scheduling", body_style),
    (0.5, "‚Ä¢ Early Stopping", body_style),
    (0.4, "Final Performance:", subheading_style),
    (0.35, f"‚Ä¢ Accuracy: {acc:.4f}", body_style),
    (0.3, f"‚Ä¢ F1-Macro: {f1_macro:.4f}", body_style),
    (0.25, f"‚Ä¢ F1-Weighted: {f1_weighted:.4f}", body_style),
    (0.2, f"‚Ä¢ Classes: {n_classes}", body_style),
    (0.15, f"‚Ä¢ Channels: {n_channels}", body_style),
]
plt.figure(figsize=(12, 8))
for row_y, row_text, row_style in summary_rows:
    plt.text(0.1, row_y, row_text, **row_style)
plt.axis('off')
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "model_summary.png"))
plt.close()

print("‚úÖ Model summary saved")

# Final summary: list the artifacts written by this notebook
print("\nüéâ CNN training completed!")
print(f"üìÅ Results saved to: {OUTPUT_DIR}/")
print("üìä Generated files:")
generated_files = (
    "training_analysis.png",
    "confusion_matrix.png",
    "per_class_performance.png",
    "class_imbalance_analysis.png",
    "sample_ecg_predictions.png",
    "model_summary.png",
    "best_model.pth",
)
for artifact_name in generated_files:
    print(f"  ‚Ä¢ {artifact_name}")

# Headline numbers gathered from the evaluation and training cells
print(f"\nüìà Final Metrics:")
metric_lines = (
    f"  ‚Ä¢ Test Accuracy: {acc:.4f}",
    f"  ‚Ä¢ F1-Score (Macro): {f1_macro:.4f}",
    f"  ‚Ä¢ F1-Score (Weighted): {f1_weighted:.4f}",
    f"  ‚Ä¢ Training time: {training_time:.2f} seconds",
    f"  ‚Ä¢ Total parameters: {total_params:,}",
    f"  ‚Ä¢ Best validation accuracy: {best_val_acc:.4f}",
)
for metric_line in metric_lines:
    print(metric_line)


‚úÖ Model summary saved

üéâ CNN training completed!
üìÅ Results saved to: ./ecg_cnn_outputs/
üìä Generated files:
  ‚Ä¢ training_analysis.png
  ‚Ä¢ confusion_matrix.png
  ‚Ä¢ per_class_performance.png
  ‚Ä¢ class_imbalance_analysis.png
  ‚Ä¢ sample_ecg_predictions.png
  ‚Ä¢ model_summary.png
  ‚Ä¢ best_model.pth

üìà Final Metrics:
  ‚Ä¢ Test Accuracy: 0.9994
  ‚Ä¢ F1-Score (Macro): 0.9995
  ‚Ä¢ F1-Score (Weighted): 0.9994
  ‚Ä¢ Training time: 122635.35 seconds
  ‚Ä¢ Total parameters: 486,975
  ‚Ä¢ Best validation accuracy: 0.9996
