# ECG Data Exploration

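This notebook explores the preprocessed PTB-XL dataset used for ECG lead reconstruction. It visualizes a sample 12-lead ECG and its input leads (I, II, V4), checks that the limb leads agree with those derived from leads I and II, and summarizes per-lead statistics of the training data.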

In [None]:
import sys
sys.path.append('..')

import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.stats import pearsonr

from data.data_modules import ECGReconstructionDataset

## Load Dataset

In [None]:
# Directory holding the preprocessed .npy arrays (edit to match your setup)
DATA_DIR = "../data/processed"

# Each split is built from an input-lead array and a target-lead array
train_input_path = os.path.join(DATA_DIR, 'train_input.npy')
train_target_path = os.path.join(DATA_DIR, 'train_target.npy')
train_dataset = ECGReconstructionDataset(train_input_path, train_target_path)

val_input_path = os.path.join(DATA_DIR, 'val_input.npy')
val_target_path = os.path.join(DATA_DIR, 'val_target.npy')
val_dataset = ECGReconstructionDataset(val_input_path, val_target_path)

## Visualize Sample ECGs

In [None]:
def plot_ecg(signals, title="12-Lead ECG", lead_names=None, *, fs=500, n_cols=3):
    """Plot a multi-lead ECG with one subplot per lead.

    The subplot grid is sized to the number of leads: the default 12-lead
    case produces the same 4x3 layout as before, while fewer leads (e.g. a
    3-lead input) no longer leave empty axes on screen, and more than 12
    leads no longer raise an IndexError.

    Parameters
    ----------
    signals : array-like
        2-D array of shape (n_leads, n_samples); row ``i`` is plotted in
        subplot ``i``.
    title : str
        Figure-level title.
    lead_names : sequence of str, optional
        Subplot titles, one per lead. Defaults to the standard 12-lead order
        I, II, III, aVR, aVL, aVF, V1-V6.
    fs : float, keyword-only
        Sampling frequency in Hz, used to build the time axis (default 500).
    n_cols : int, keyword-only
        Number of subplot columns (default 3).

    Raises
    ------
    ValueError
        If ``signals`` is not 2-D, ``n_cols`` < 1, or fewer lead names than
        leads are supplied.
    """
    if lead_names is None:
        lead_names = [
            'I', 'II', 'III', 'aVR', 'aVL', 'aVF',
            'V1', 'V2', 'V3', 'V4', 'V5', 'V6'
        ]

    signals = np.asarray(signals)
    if signals.ndim != 2:
        raise ValueError(
            f"signals must be 2-D (n_leads, n_samples), got shape {signals.shape}"
        )
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")

    n_leads, n_samples = signals.shape
    if len(lead_names) < n_leads:
        raise ValueError(
            f"got {len(lead_names)} lead names for {n_leads} leads"
        )

    # Time axis in seconds
    time = np.arange(n_samples) / fs

    # Enough rows to hold every lead; 3 inches per row keeps the original
    # 15x12 figure for the 12-lead / 3-column case.
    n_rows = max(1, (n_leads + n_cols - 1) // n_cols)
    fig_height = 3 * n_rows
    # squeeze=False guarantees a 2-D array of axes even for a 1x1 grid
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, fig_height), squeeze=False)
    axes = axes.flatten()

    for ax, lead_signal, lead_name in zip(axes, signals, lead_names):
        ax.plot(time, lead_signal, 'b-')
        ax.set_title(lead_name)
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude')
        ax.grid(True, linestyle='--', alpha=0.7)

    # Hide grid cells that have no lead assigned
    for ax in axes[n_leads:]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.suptitle(title, fontsize=16)
    # Reserve a fixed ~0.96 in band for the title; equals top=0.92 at the
    # original 12 in height.
    fig.subplots_adjust(top=1 - 0.96 / fig_height)
    plt.show()

In [None]:
# Fetch one (inputs, targets) pair from the training split
sample_idx = 0
inputs, targets = train_dataset[sample_idx]
full_ecg = targets.numpy()

# Show every target lead of that record
plot_ecg(full_ecg, title="Full 12-Lead ECG")

## Visualize Input Leads Only (I, II, V4)

In [None]:
# Show only the three leads stored as the dataset's inputs
input_lead_names = ['I', 'II', 'V4']
plot_ecg(
    inputs.numpy(),
    title="Input Leads (I, II, V4)",
    lead_names=input_lead_names,
)

## Verify Limb Lead Calculations

In [None]:
import sys
# The first code cell already appends '..'; guard so this cell (and re-runs
# of it) does not keep adding duplicate sys.path entries.
if '..' not in sys.path:
    sys.path.append('..')
from src.physics import calculate_limb_leads_numpy

# Leads I and II are rows 0 and 1 of the target array (12-lead order
# I, II, III, aVR, aVL, aVF, V1-V6 as used elsewhere in this notebook)
lead_I = targets[0].numpy()
lead_II = targets[1].numpy()

# Derive the remaining limb leads from I and II; the result is indexed by
# lead name ('III', 'aVR', 'aVL', 'aVF')
limb_leads = calculate_limb_leads_numpy(lead_I, lead_II)

# Recorded limb leads from the targets
original_III = targets[2].numpy()
original_aVR = targets[3].numpy()
original_aVL = targets[4].numpy()
original_aVF = targets[5].numpy()

# Pearson r between derived and recorded signals, one entry per limb lead
# (insertion order controls the print order below)
correlations = {
    name: pearsonr(limb_leads[name], original)[0]
    for name, original in (
        ('III', original_III),
        ('aVR', original_aVR),
        ('aVL', original_aVL),
        ('aVF', original_aVF),
    )
}
# Individual names kept because later cells use them (e.g. corr_III)
corr_III = correlations['III']
corr_aVR = correlations['aVR']
corr_aVL = correlations['aVL']
corr_aVF = correlations['aVF']

print("Correlation between calculated and original limb leads:")
for name, r in correlations.items():
    print(f"{name}: {r:.6f}")

## Compare Original vs Calculated Limb Leads

In [None]:
# Time axis in seconds (assumes 500 Hz sampling, as in plot_ecg's default)
time = np.arange(lead_I.shape[0]) / 500

# Overlay the recorded Lead III with the one derived from leads I and II
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(time, original_III, 'b-', label='Original III')
ax.plot(time, limb_leads['III'], 'r--', label=f'Calculated III (r={corr_III:.4f})')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Amplitude')
ax.set_title('Lead III: Original vs Calculated from Einthoven\'s Law')
ax.grid(True, linestyle='--', alpha=0.7)
ax.legend()
plt.show()

## Check Data Statistics

In [None]:
# Full training arrays. The per-lead reductions below assume they are laid
# out as (n_records, n_leads, n_samples) — TODO confirm against the
# dataset's preprocessing.
all_inputs = train_dataset.inputs
all_targets = train_dataset.targets

# Per-lead mean/std: reduce over records (axis 0) and time (axis 2)
input_mean = np.mean(all_inputs, axis=(0, 2))
input_std = np.std(all_inputs, axis=(0, 2))
target_mean = np.mean(all_targets, axis=(0, 2))
target_std = np.std(all_targets, axis=(0, 2))

# Lead names for each array, hoisted so they are built once
input_lead_names = ['I', 'II', 'V4']
lead_names = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

# Fail with a clear message (instead of an IndexError mid-print) if the data
# has more channels than we have names for.
if len(input_mean) > len(input_lead_names):
    raise ValueError(
        f"inputs have {len(input_mean)} leads but only "
        f"{len(input_lead_names)} names are defined"
    )
if len(target_mean) > len(lead_names):
    raise ValueError(
        f"targets have {len(target_mean)} leads but only "
        f"{len(lead_names)} names are defined"
    )

print("Input leads statistics:")
for lead_name, mean, std in zip(input_lead_names, input_mean, input_std):
    print(f"{lead_name}: Mean = {mean:.4f}, Std = {std:.4f}")

print("\nTarget leads statistics:")
for lead_name, mean, std in zip(lead_names, target_mean, target_std):
    print(f"{lead_name}: Mean = {mean:.4f}, Std = {std:.4f}")