# 01 - Exploratory Data Analysis (EDA)
Analyze the PTB-XL dataset structure and raw signals.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import wfdb
import ast
import os
import sys

# Make the parent directory importable so this notebook can reach `src`
sys.path.append('..')

# Dataset locations, relative to the notebook's working directory
DATA_DIR = '../data/ptb-xl'
CSV_PATH = os.path.join(DATA_DIR, 'ptbxl_database.csv')

if not os.path.exists(CSV_PATH):
    print(f"File not found: {CSV_PATH}. Please download the dataset first.")
else:
    print(f"Reading metadata from {CSV_PATH}...")
    # Metadata table indexed by `ecg_id`
    df = pd.read_csv(CSV_PATH, index_col='ecg_id')
    display(df.head())  # `display` is provided by IPython inside notebooks

In [None]:
# 1. Label Distribution
# `scp_codes` holds each record's SCP-ECG statements as a stringified dict
# ({statement_code: likelihood}). Its keys are statement codes such as 'IMI',
# 'ASMI' or 'LVH' -- not superclass names -- so they must be mapped onto the
# 5 diagnostic superclasses (NORM, MI, STTC, CD, HYP) via the
# `diagnostic_class` column of scp_statements.csv.

SUPERCLASSES = ('NORM', 'MI', 'STTC', 'CD', 'HYP')


def count_superclasses(scp_codes, code_to_class, superclasses=SUPERCLASSES):
    """Count how many records carry each diagnostic superclass.

    A record adds at most 1 to a superclass, even when several of its
    statements map to that same superclass.

    Args:
        scp_codes: iterable of ``scp_codes`` entries, either the stringified
            dicts stored in the CSV or already-parsed dicts.
        code_to_class: mapping from SCP statement code to superclass name.
            Codes absent from it (e.g. non-diagnostic statements) are ignored.
        superclasses: superclass names to report, in display order.

    Returns:
        ``(counts, n_skipped)``: a dict mapping superclass -> number of
        records, and the number of entries that could not be read as a dict.
    """
    counts = dict.fromkeys(superclasses, 0)
    n_skipped = 0
    for raw in scp_codes:
        codes = raw
        if isinstance(raw, str):
            try:
                codes = ast.literal_eval(raw)
            except (ValueError, SyntaxError, TypeError):
                n_skipped += 1
                continue
        if not isinstance(codes, dict):
            # NaN, lists, numbers, ...: nothing meaningful to count.
            n_skipped += 1
            continue
        # A set, so a record with e.g. both IMI and ASMI counts once for MI.
        for superclass in {code_to_class.get(code) for code in codes}:
            if superclass in counts:
                counts[superclass] += 1
    return counts, n_skipped


classes = dict.fromkeys(SUPERCLASSES, 0)

if 'df' in locals():
    statements_path = os.path.join(DATA_DIR, 'scp_statements.csv')
    if os.path.exists(statements_path):
        statements = pd.read_csv(statements_path, index_col=0)
        # Only diagnostic statements carry a superclass in `diagnostic_class`.
        diagnostic = statements[statements['diagnostic'] == 1]
        code_to_class = diagnostic['diagnostic_class'].dropna().to_dict()

        classes, n_skipped = count_superclasses(df['scp_codes'], code_to_class)
        if n_skipped:
            print(f"Skipped {n_skipped} records with unparseable scp_codes")

        plt.figure(figsize=(10, 5))
        plt.bar(list(classes.keys()), list(classes.values()), color='skyblue')
        plt.title("Distribution of Superclasses")
        plt.ylabel("Count")
        plt.show()
    else:
        print(f"File not found: {statements_path}. "
              "It is needed to map SCP codes to superclasses.")

In [None]:
# 2. Visualize Random ECGs
STANDARD_12_LEADS = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF',
                     'V1', 'V2', 'V3', 'V4', 'V5', 'V6']


def record_filename(row, prefer_hr=True):
    """Return the WFDB record path stored in a metadata row, or None.

    PTB-XL keeps two versions of each record: ``filename_hr`` (500 Hz) and
    ``filename_lr`` (100 Hz). The preferred column is tried first. The other
    column is used when the preferred value is missing, NaN or empty.

    Args:
        row: a metadata row (``pandas.Series`` or any object with ``.get``).
        prefer_hr: try ``filename_hr`` before ``filename_lr`` when True.
    """
    columns = ('filename_hr', 'filename_lr')
    if not prefer_hr:
        columns = columns[::-1]
    for column in columns:
        value = row.get(column)
        # Missing cells come back as NaN (a float), so require a real string.
        if isinstance(value, str) and value:
            return value
    return None


def load_and_plot(index, hr=True):
    """Load one record and plot each lead on its own horizontal axis.

    Relies on the notebook globals ``df`` (metadata indexed by ``ecg_id``) and
    ``DATA_DIR`` (dataset root) from the loading cell.

    Args:
        index: ``ecg_id`` of the record to plot (a label in ``df.index``).
        hr: prefer the ``filename_hr`` record over ``filename_lr``. The other
            one is used if the preferred path is missing.

    Prints a message and returns None if the index is unknown, the row has no
    record path, or the record cannot be read.
    """
    if index not in df.index:
        print(f"Index {index} not found")
        return

    row = df.loc[index]
    filename = record_filename(row, prefer_hr=hr)
    if filename is None:
        print(f"No record filename for index {index}")
        return
    path = os.path.join(DATA_DIR, filename)

    try:
        data, meta = wfdb.rdsamp(path)
    except Exception as e:
        # Best-effort: report and keep the notebook running on a bad record.
        print(f"Error loading {path}: {e}")
        return

    n_samples, n_leads = data.shape
    if n_leads == 0:
        print(f"Record {path} has no signals")
        return

    # Label axes with the lead names from the record header when present.
    lead_names = list(meta.get('sig_name') or STANDARD_12_LEADS)
    # Plot against time in seconds when the sampling frequency is known.
    fs = meta.get('fs')
    if fs:
        x = np.arange(n_samples) / fs
        xlabel = "Time (s)"
    else:
        x = np.arange(n_samples)
        xlabel = "Sample"

    # squeeze=False keeps `axes` 2-D, so a single-lead record still works.
    fig, axes = plt.subplots(n_leads, 1, figsize=(10, max(3, 20 * n_leads / 12)),
                             sharex=True, squeeze=False)
    for i, ax in enumerate(axes[:, 0]):
        ax.plot(x, data[:, i], color='black', linewidth=0.8)
        ax.set_ylabel(lead_names[i] if i < len(lead_names) else f"Lead {i + 1}")
        ax.grid(True, alpha=0.3)
    axes[-1, 0].set_xlabel(xlabel)

    fig.suptitle(f"ECG: {index} | Labels: {row['scp_codes']}")
    # Reserve a strip at the top so the suptitle does not overlap the first axis.
    fig.tight_layout(rect=(0, 0, 1, 0.98))
    plt.show()

# Choose one ecg_id uniformly at random and plot that record
import random

if 'df' in locals() and len(df):
    rand_id = random.choice(df.index)
    load_and_plot(rand_id)