# ECG 12-Lead Reconstruction from 3 Leads

**Paper**: [AI-enhanced reconstruction of the 12-lead electrocardiogram via 3-leads with accurate clinical assessment](https://www.nature.com/articles/s41746-024-01193-7)

Mason, F., Pandey, A.C., Gadaleta, M. et al. npj Digit. Med. 7, 201 (2024)

## Key Findings from the Paper
- **Input**: three leads (I, II, V3) are sufficient to reconstruct the full 12-lead ECG
- **Output**: Reconstructs precordial leads (V1-V6) with high correlation
- **Performance**: AUC = 0.95 for acute MI detection
- **Clinical validation**: 81.4% accuracy in identifying STEMI features

## Pipeline Overview
1. **Setup** - Import libraries and configure settings
2. **Data Processing** - Prepare and split the dataset
3. **Training** - Train the reconstruction model
4. **Testing** - Evaluate model performance
5. **Visualization** - Plot results and examples

## 1. Setup and Configuration

In [1]:
import os
import sys
import torch

# Check GPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    DEVICE = 'cuda:0'
else:
    print("Using CPU")
    DEVICE = 'cpu'

PyTorch version: 2.9.0+cpu
CUDA available: False
Using CPU


In [2]:
# ============================================================
# CONFIGURATION - Based on Paper (Nature s41746-024-01193-7)
# ============================================================

CONFIG = {
    # Device
    'device': DEVICE,
    
    # ---------------------------------------------------------
    # INPUT/OUTPUT LEADS (Paper: I + II + V3 → V1-V6)
    # ---------------------------------------------------------
    'input_leads': 'limb+v3',     # Paper uses: I, II, V3 (3 leads)
    'output_leads': 'precordial', # Paper reconstructs: V1, V2, V3, V4, V5, V6 (6 leads)
    
    # Dataset
    'dataset': 'infarct+noninfarct',  # Paper focuses on MI detection
    'data_size': 'max',               # Paper used ~600,000 ECGs
    
    # ---------------------------------------------------------
    # NETWORK ARCHITECTURE (Paper: ResCNN blocks)
    # ---------------------------------------------------------
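    # Input network: 3 ResCNN blocks (one per input lead)
    # Middle network: 1 ResCNN block (aggregates features)
    # Output network: 6 ResCNN blocks (one per output lead)
    # Note: the *_depth values below (3, 2, 3) do not match these counts
    # one-to-one; their exact meaning is defined by ReconstructionManager.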
    'input_channel': 32,
    'middle_channel': 32,
    'output_channel': 32,
    'input_depth': 3,      # 3 ResCNN blocks for input
    'middle_depth': 2,     # Middle processing blocks
    'output_depth': 3,     # 3 ResCNN blocks for output
    'input_kernel': 17,
    'middle_kernel': 17,
    'output_kernel': 17,
    'use_residual': 'true',  # Paper uses residual connections
    
    # ---------------------------------------------------------
    # TRAINING PARAMETERS
    # ---------------------------------------------------------
    'epoch_num': 200,
    'batch_size': 16,
    'optimizer': 'adam',
    'learning_rate': 0.000003,  # 3e-6
    'weight_decay': 0.001,
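    'momentum': 0.9,       # SGD-style setting; Adam in PyTorch has no momentum argument
    'nesterov': 'true',    # SGD-style setting; presumably unused with 'adam'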
    'prioritize_percent': 0,
    'prioritize_size': 0,
}

print("=" * 60)
print("CONFIGURATION (Based on Paper)")
print("=" * 60)
print(f"\nInput leads:  {CONFIG['input_leads']} → I, II, V3 (3 leads)")
print(f"Output leads: {CONFIG['output_leads']} → V1, V2, V3, V4, V5, V6 (6 leads)")
print("\nNetwork: ResCNN with residual connections")
print(f"  - Input depth:  {CONFIG['input_depth']} blocks")
print(f"  - Middle depth: {CONFIG['middle_depth']} blocks")
print(f"  - Output depth: {CONFIG['output_depth']} blocks")
print(f"  - Channels: {CONFIG['input_channel']}")
print(f"  - Kernel size: {CONFIG['input_kernel']}")
print("\nTraining:")
print(f"  - Epochs: {CONFIG['epoch_num']}")
print(f"  - Batch size: {CONFIG['batch_size']}")
print(f"  - Learning rate: {CONFIG['learning_rate']}")
print(f"  - Optimizer: {CONFIG['optimizer']}")
print("=" * 60)

CONFIGURATION (Based on Paper)

Input leads:  limb+v3 → I, II, V3 (3 leads)
Output leads: precordial → V1, V2, V3, V4, V5, V6 (6 leads)

Network: ResCNN with residual connections
  - Input depth:  3 blocks
  - Middle depth: 2 blocks
  - Output depth: 3 blocks
  - Channels: 32
  - Kernel size: 17

Training:
  - Epochs: 200
  - Batch size: 16
  - Learning rate: 3e-06
  - Optimizer: adam
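
The configuration above names residual connections and a kernel size of 17, but the ResCNN block itself lives in the repository and is not shown in this notebook. As a rough illustration only, the pure-Python sketch below shows the general idea of a residual 1-D convolution: zero padding of `kernel_size // 2` samples on each side keeps the signal length unchanged, which is what allows the input to be added back to the convolved output. The names `conv1d_same` and `residual_block` are hypothetical, and the real blocks presumably use PyTorch layers with learned multi-channel weights.

```python
def conv1d_same(signal, kernel):
    """Cross-correlate `signal` with an odd-length `kernel`, keeping the length."""
    if len(kernel) % 2 == 0:
        raise ValueError("kernel length must be odd for symmetric padding")
    half = len(kernel) // 2                      # 8 for a kernel of 17
    padded = [0.0] * half + list(signal) + [0.0] * half
    return [
        sum(w * padded[i + j] for j, w in enumerate(kernel))
        for i in range(len(signal))
    ]


def residual_block(signal, kernel):
    """Return signal + ReLU(conv(signal)); the skip path adds the input back."""
    convolved = conv1d_same(signal, kernel)
    return [x + max(0.0, c) for x, c in zip(signal, convolved)]


# Toy single-channel example using the kernel size from CONFIG (17).
toy_signal = [0.0, 1.0, 0.0, -1.0, 0.5]
averaging_kernel = [1.0 / 17] * 17               # illustrative weights, not learned
toy_output = residual_block(toy_signal, averaging_kernel)
```

With an all-zero kernel the block reduces to the identity, which is the property that makes deep residual stacks easier to optimize.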


## 2. Initialize the Reconstruction Manager

In [3]:
from util_functions.general import get_parent_folder, get_data_classes, get_lead_keys
from training_functions.single_reconstruction_manager import ReconstructionManager

# Get settings
parent_folder = get_parent_folder()
data_classes = get_data_classes(CONFIG['dataset'])
sub_classes = []

# Show lead configuration
input_keys = get_lead_keys(CONFIG['input_leads'])
output_keys = get_lead_keys(CONFIG['output_leads'])
print(f"Input leads ({len(input_keys)}): {input_keys}")
print(f"Output leads ({len(output_keys)}): {output_keys}")
print(f"Data classes: {data_classes}")
print(f"Data folder: {parent_folder}")

ModuleNotFoundError: No module named 'util_functions'
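
The `ModuleNotFoundError` above usually means that the folder containing the `util_functions` and `training_functions` packages is not on Python's import path, for example because the notebook was started outside the repository root. A minimal sketch of one possible fix follows. `REPO_ROOT` is a placeholder, assumed here to be the current working directory, and should point at wherever the repository was actually cloned; `add_repo_to_path` is a hypothetical helper.

```python
import sys
from pathlib import Path


def add_repo_to_path(repo_root):
    """Prepend `repo_root` to sys.path (once) so its packages can be imported."""
    root = str(Path(repo_root).expanduser().resolve())
    if root not in sys.path:
        sys.path.insert(0, root)
    return root


# Placeholder: replace with the folder that contains util_functions/
REPO_ROOT = Path.cwd()
add_repo_to_path(REPO_ROOT)
```

After this cell runs with the correct folder, re-running the import cell should find `util_functions`.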

In [None]:
# Create the Reconstruction Manager
manager = ReconstructionManager(
    parent_folder=parent_folder,
    device=CONFIG['device'],
    sub_classes=sub_classes,
    input_leads=CONFIG['input_leads'],
    output_leads=CONFIG['output_leads'],
    data_classes=data_classes,
    data_size=CONFIG['data_size'],
    input_channel=CONFIG['input_channel'],
    middle_channel=CONFIG['middle_channel'],
    output_channel=CONFIG['output_channel'],
    input_depth=CONFIG['input_depth'],
    middle_depth=CONFIG['middle_depth'],
    output_depth=CONFIG['output_depth'],
    input_kernel=CONFIG['input_kernel'],
    middle_kernel=CONFIG['middle_kernel'],
    output_kernel=CONFIG['output_kernel'],
    use_residual=CONFIG['use_residual'],
    epoch_num=CONFIG['epoch_num'],
    batch_size=CONFIG['batch_size'],
    prioritize_percent=CONFIG['prioritize_percent'],
    prioritize_size=CONFIG['prioritize_size'],
    optimizer=CONFIG['optimizer'],
    learning_rate=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
    momentum=CONFIG['momentum'],
    nesterov=CONFIG['nesterov']
)

print("Reconstruction Manager initialized successfully!")
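
Every key in `CONFIG` except `'dataset'` is passed to `ReconstructionManager` under the same name, so the long argument list above could be built from the dictionary instead of being typed out. The sketch below shows one way to do that. `build_manager_kwargs` is a hypothetical helper, and the example path and class names are made-up stand-ins for the values returned by `get_parent_folder()` and `get_data_classes()`.

```python
def build_manager_kwargs(config, parent_folder, data_classes, sub_classes=()):
    """Turn a CONFIG-style dict into keyword arguments for the manager.

    Every key is forwarded unchanged except 'dataset', which the notebook
    converts to `data_classes` before constructing the manager.
    """
    kwargs = {key: value for key, value in config.items() if key != 'dataset'}
    kwargs.update(
        parent_folder=parent_folder,
        data_classes=data_classes,
        sub_classes=list(sub_classes),
    )
    return kwargs


# Stand-in values for illustration only.
example_config = {'device': 'cpu', 'dataset': 'infarct+noninfarct', 'epoch_num': 200}
example_kwargs = build_manager_kwargs(
    example_config,
    parent_folder='/path/to/data',           # hypothetical path
    data_classes=['infarct', 'noninfarct'],  # hypothetical class names
)
# manager = ReconstructionManager(**example_kwargs)
```

Keeping the mapping in one helper avoids the two argument lists in this notebook drifting apart.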

## 3. Training

Train the deep learning model to reconstruct the configured output leads (here V1-V6) from the input leads (here I, II, and V3).

In [None]:
# Reset/Initialize the model
print("Initializing model...")
manager.reset_model()
print("Model initialized!")

In [None]:
# Load training and validation datasets
print("Loading training and validation datasets...")
manager.load_dataset(train=True, valid=True)
print("Datasets loaded!")

In [None]:
# Train the model
print(f"Starting training for {CONFIG['epoch_num']} epochs...")
print("="*50)
manager.train()
print("="*50)
print("Training completed!")

In [None]:
# Release dataset memory
manager.release_dataset()
print("Dataset memory released.")

In [None]:
# Plot training statistics
print("Plotting training statistics...")
manager.plot_train_stats()
manager.plot_valid_stats()
print("Training plots saved!")

## 4. Testing

Evaluate the trained model on the test set.

In [None]:
# Load the trained model
print("Loading trained model...")
manager.load_model()
print("Model loaded!")

In [None]:
# Load test dataset
print("Loading test dataset...")
manager.load_dataset(test=True)
print("Test dataset loaded!")

In [None]:
# Run testing
print("Running tests...")
manager.test()
print("Testing completed!")

In [None]:
# Release test dataset memory
manager.release_dataset()
print("Test dataset memory released.")

In [None]:
# Plot test statistics
print("Plotting test statistics...")
manager.plot_test_stats(plot_sub_classes=sub_classes)
print("Test plots saved!")
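
The metrics computed inside `manager.test()` are not shown in this notebook. Since the paper summary at the top describes reconstruction quality in terms of correlation, the sketch below shows how a per-lead Pearson correlation and root-mean-square error between an original and a reconstructed signal could be computed. Both function names are hypothetical, and these are not necessarily the metrics the repository reports.

```python
import math


def pearson_correlation(original, reconstructed):
    """Pearson correlation coefficient between two equal-length signals."""
    n = len(original)
    if n != len(reconstructed) or n < 2:
        raise ValueError("signals must have the same length of at least 2")
    mean_o = sum(original) / n
    mean_r = sum(reconstructed) / n
    cov = sum((o - mean_o) * (r - mean_r) for o, r in zip(original, reconstructed))
    var_o = sum((o - mean_o) ** 2 for o in original)
    var_r = sum((r - mean_r) ** 2 for r in reconstructed)
    if var_o == 0 or var_r == 0:
        raise ValueError("correlation is undefined for a constant signal")
    return cov / math.sqrt(var_o * var_r)


def rmse(original, reconstructed):
    """Root-mean-square error between two equal-length signals."""
    if len(original) != len(reconstructed) or not original:
        raise ValueError("signals must be non-empty and of equal length")
    squared = sum((o - r) ** 2 for o, r in zip(original, reconstructed))
    return math.sqrt(squared / len(original))
```

Correlation measures agreement in waveform shape and ignores amplitude scaling, while RMSE is sensitive to amplitude, so the two are usually read together.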

## 5. Visualization

Visualize reconstruction examples and model performance.

In [None]:
# Load model for visualization
manager.load_model()

# Plot random reconstruction examples
print("Plotting random reconstruction examples...")
manager.plot_random_example(plot_format='png')
print("Random examples saved!")

In [None]:
# Plot error examples (worst reconstructions)
print("Plotting error examples...")
manager.load_test_stats()
manager.plot_error_example(plot_format='png')
print("Error examples saved!")

In [None]:
# Evaluate and plot model statistics
print("Computing model statistics...")
manager.compute_model_stats()
manager.plot_model_stats()
print("Model statistics saved!")

## 6. Quick Run (All in One)

Run the complete pipeline in one cell.

In [None]:
def run_full_pipeline(config):
    """Run the complete training and testing pipeline."""
    from util_functions.general import get_parent_folder, get_data_classes
    from training_functions.single_reconstruction_manager import ReconstructionManager
    
    # Initialize
    parent_folder = get_parent_folder()
    data_classes = get_data_classes(config['dataset'])
    
    manager = ReconstructionManager(
        parent_folder=parent_folder,
        device=config['device'],
        sub_classes=[],
        input_leads=config['input_leads'],
        output_leads=config['output_leads'],
        data_classes=data_classes,
        data_size=config['data_size'],
        input_channel=config['input_channel'],
        middle_channel=config['middle_channel'],
        output_channel=config['output_channel'],
        input_depth=config['input_depth'],
        middle_depth=config['middle_depth'],
        output_depth=config['output_depth'],
        input_kernel=config['input_kernel'],
        middle_kernel=config['middle_kernel'],
        output_kernel=config['output_kernel'],
        use_residual=config['use_residual'],
        epoch_num=config['epoch_num'],
        batch_size=config['batch_size'],
        prioritize_percent=config['prioritize_percent'],
        prioritize_size=config['prioritize_size'],
        optimizer=config['optimizer'],
        learning_rate=config['learning_rate'],
        weight_decay=config['weight_decay'],
        momentum=config['momentum'],
        nesterov=config['nesterov']
    )
    
    # Train
    print("=" * 50)
    print("TRAINING")
    print("=" * 50)
    manager.reset_model()
    manager.load_dataset(train=True, valid=True)
    manager.train()
    manager.release_dataset()
    manager.plot_train_stats()
    manager.plot_valid_stats()
    
    # Test
    print("\n" + "=" * 50)
    print("TESTING")
    print("=" * 50)
    manager.load_model()
    manager.load_dataset(test=True)
    manager.test()
    manager.release_dataset()
    manager.plot_test_stats(plot_sub_classes=[])
    
    # Visualize
    print("\n" + "=" * 50)
    print("VISUALIZATION")
    print("=" * 50)
    manager.plot_random_example(plot_format='png')
    manager.load_test_stats()
    manager.plot_error_example(plot_format='png')
    
    print("\n" + "=" * 50)
    print("PIPELINE COMPLETED!")
    print("=" * 50)
    
    return manager

# Uncomment to run the full pipeline:
# manager = run_full_pipeline(CONFIG)
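
Before committing to a full 200-epoch run, it can help to check that the pipeline executes end to end with a much shorter setting. The sketch below copies a configuration and overrides selected keys, rejecting keys that do not already exist so that a typo does not pass silently. `with_overrides` is a hypothetical helper, and `base_config` is a small stand-in for `CONFIG`.

```python
def with_overrides(config, **overrides):
    """Return a copy of `config` with the given existing keys replaced."""
    unknown = set(overrides) - set(config)
    if unknown:
        raise KeyError(f"unknown config keys: {sorted(unknown)}")
    return {**config, **overrides}


base_config = {'epoch_num': 200, 'batch_size': 16}   # stand-in for CONFIG
smoke_config = with_overrides(base_config, epoch_num=1)
# manager = run_full_pipeline(with_overrides(CONFIG, epoch_num=1))
```

The original dictionary is left untouched, so the full configuration remains available for the real run.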

## 7. Alternative: Using Command Line Scripts

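You can also run the scripts directly from the command line. Note that the examples below pass `-input limb` (leads I and II only), not the `limb+v3` configuration used in the rest of this notebook: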

In [None]:
# Run reconstruction training and testing via command line
# Uncomment to execute:

# !python single_reconstruction.py -device cuda:0 -input limb -output precordial -train -test -plot

In [None]:
# Run classification training
# Uncomment to execute:

# !python single_classification.py -device cuda:0 -input limb -train -test -plot

In [None]:
# Run combined reconstruction + classification
# Uncomment to execute:

# !python single_recon_classif.py -device cuda:0 -input limb -output precordial -train -test -plot
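
When these scripts are launched from Python rather than typed by hand, the command can be assembled as an argument list. The sketch below uses only the flags that appear in the cells above; `build_command` is a hypothetical helper, and whether the scripts accept other values for `-input` (such as `limb+v3`) is not confirmed here.

```python
import shlex


def build_command(script, device, input_leads, output_leads=None):
    """Assemble an argument list using only the flags shown in the cells above."""
    args = ['python', script, '-device', device, '-input', input_leads]
    if output_leads is not None:
        args += ['-output', output_leads]
    args += ['-train', '-test', '-plot']
    return args


command = build_command('single_reconstruction.py', 'cuda:0', 'limb', 'precordial')
print(shlex.join(command))
# To execute: import subprocess; subprocess.run(command, check=True)
```

Passing a list to `subprocess.run` avoids shell quoting problems with values that contain special characters.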

## Input/Output Lead Options Reference

| Configuration | Leads | Count |
|---------------|-------|-------|
| `limb` | I, II | 2 |
| `limb+v1` | I, II, V1 | 3 |
| `limb+v2` | I, II, V2 | 3 |
| `limb+v3` | I, II, V3 | 3 |
| `limb+v4` | I, II, V4 | 3 |
| `limb+v5` | I, II, V5 | 3 |
| `limb+v6` | I, II, V6 | 3 |
| `full_limb` | I, II, III, aVL, aVR, aVF | 6 |
| `precordial` | V1, V2, V3, V4, V5, V6 | 6 |
| `full` | I, II, V1, V2, V3, V4, V5, V6 | 8 |
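
The `full` row lists eight leads rather than twelve, which is consistent with the fact that the four remaining limb leads carry no independent information: by the standard Einthoven and Goldberger relationships, III, aVR, aVL, and aVF are fixed linear combinations of leads I and II. The sketch below applies those textbook formulas to plain Python lists. `derive_limb_leads` is a hypothetical helper, and whether the repository derives these leads in the same way is an assumption.

```python
def derive_limb_leads(lead_i, lead_ii):
    """Compute III, aVR, aVL and aVF sample by sample from leads I and II."""
    if len(lead_i) != len(lead_ii):
        raise ValueError("lead I and lead II must have the same length")
    pairs = list(zip(lead_i, lead_ii))
    return {
        'III': [ii - i for i, ii in pairs],          # Einthoven: III = II - I
        'aVR': [-(i + ii) / 2 for i, ii in pairs],   # aVR = -(I + II) / 2
        'aVL': [i - ii / 2 for i, ii in pairs],      # aVL = I - II / 2
        'aVF': [ii - i / 2 for i, ii in pairs],      # aVF = II - I / 2
    }


# Two toy samples (arbitrary units) for leads I and II.
derived = derive_limb_leads([1.0, 0.5], [2.0, 1.0])
```

A quick sanity check on any result is that aVR + aVL + aVF equals zero at every sample.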