# Class Distribution Analysis

Analyze how ECG samples are distributed across the length-of-stay (LOS) bin classes (0-9) in ECG datasets.


In [None]:
# Setup
import sys
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Put the directory two levels above the current working directory at the
# front of sys.path so the project's `src.*` packages can be imported.
project_root = Path().resolve().parent.parent
sys.path.insert(0, str(project_root))

# Shared plot styling for every figure produced below.
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 11,
})


In [None]:
# Import required modules
from src.data.ecg.ecg_loader import build_npy_index
from src.data.ecg.ecg_dataset import extract_subject_id_from_path
from src.data.labeling import load_icustays, ICUStayMapper, los_to_bin
from collections import Counter

# Define the analysis function and its private helpers.

# Sentinel cached for subjects whose ECGs cannot be labeled; kept distinct
# from any value los_to_bin might return.
_NO_LABEL = object()


def _resolve_icustays_path(data_dir, icustays_path):
    """Return the path to icustays.csv, trying default locations when none is given.

    When ``icustays_path`` is None, first try
    ``<grandparent of data_dir>/labeling/labels_csv/icustays.csv`` and then
    ``data/labeling/labels_csv/icustays.csv`` relative to the working directory.

    Raises:
        FileNotFoundError: if the resolved file does not exist.
    """
    if icustays_path is None:
        candidate = Path(data_dir).parent.parent / "labeling" / "labels_csv" / "icustays.csv"
        if candidate.exists():
            icustays_path = candidate
        else:
            icustays_path = Path("data/labeling/labels_csv/icustays.csv")

    icustays_path = Path(icustays_path)
    if not icustays_path.exists():
        raise FileNotFoundError(f"icustays.csv not found at: {icustays_path}")
    return icustays_path


def _label_for_subject(icu_mapper, subject_id):
    """Return the LOS bin for ``subject_id``, or ``_NO_LABEL`` if none can be derived.

    NOTE(review): the label comes from the subject's *first listed* ICU stay
    (row order of ``icu_mapper.icustays_df``, not sorted by time). That stay's
    own ``intime`` is used as the ECG time. As a result every ECG of a subject
    receives the same label, regardless of when it was recorded. TODO confirm
    this is intended. Mapping by the ECG's real timestamp would need a time
    field on the index records.
    """
    stays = icu_mapper.icustays_df
    subject_stays = stays[stays['subject_id'] == subject_id]
    if len(subject_stays) == 0:
        return _NO_LABEL

    first_stay = subject_stays.iloc[0]
    ecg_time = pd.to_datetime(first_stay['intime'])
    stay_id = icu_mapper.map_ecg_to_stay(subject_id, ecg_time)
    if stay_id is None:
        return _NO_LABEL

    los_days = icu_mapper.get_los(stay_id)
    if los_days is None:
        return _NO_LABEL

    return los_to_bin(los_days)


def _pct(part, whole):
    """Return ``part`` as a percentage of ``whole`` (0.0 when ``whole`` is 0)."""
    return part / whole * 100 if whole else 0.0


def analyze_class_distribution(
    data_dir: str,
    icustays_path: str = None,
    dataset_name: str = "dataset"
):
    """Count and report how a dataset's ECG files fall into LOS bin classes.

    Each record returned by ``build_npy_index(data_dir=...)`` is mapped to a
    subject ID, then to an ICU stay, and finally to an LOS bin via
    ``los_to_bin``. A report is printed to stdout with per-class counts,
    percentages, cumulative percentages and summary statistics.

    Args:
        data_dir: Directory passed to ``build_npy_index`` to enumerate ECG records.
        icustays_path: Path to icustays.csv. When None, default locations are tried.
        dataset_name: Label shown in the report header.

    Returns:
        A ``Counter`` mapping LOS bin -> number of matched ECG files, or None
        when no file could be matched to an ICU stay.

    Raises:
        FileNotFoundError: if icustays.csv cannot be located.
    """
    print("="*80)
    print(f"Class Distribution Analysis: {dataset_name}")
    print("="*80)

    # Load ICU stays
    icustays_path = _resolve_icustays_path(data_dir, icustays_path)
    print(f"\nLoading ICU stays from: {icustays_path}")
    icustays_df = load_icustays(str(icustays_path))
    icu_mapper = ICUStayMapper(icustays_df)
    print(f"Loaded {len(icustays_df)} ICU stays")

    # Find all ECG files
    print(f"\nScanning ECG files in: {data_dir}")
    records = build_npy_index(data_dir=data_dir)
    n_files = len(records)
    print(f"Found {n_files} ECG files")

    class_counts = Counter()
    matched_count = 0
    unmatched_count = 0
    # The label depends only on subject_id, so derive it once per subject
    # instead of re-filtering the whole ICU-stay table for every ECG file.
    label_cache = {}

    print("\nAnalyzing class labels...")
    for i, record in enumerate(records, start=1):
        if i % 10000 == 0:
            print(f"  Processed {i}/{n_files} files...")

        base_path = record["base_path"]
        try:
            subject_id = extract_subject_id_from_path(base_path)
        except (ValueError, KeyError):
            unmatched_count += 1
            continue

        if subject_id not in label_cache:
            try:
                label_cache[subject_id] = _label_for_subject(icu_mapper, subject_id)
            except (ValueError, KeyError):
                # Best-effort: a failed lookup counts the subject's ECGs as unmatched.
                label_cache[subject_id] = _NO_LABEL

        class_label = label_cache[subject_id]
        if class_label is _NO_LABEL:
            unmatched_count += 1
        else:
            class_counts[class_label] += 1
            matched_count += 1

    print("\n" + "="*80)
    print("CLASS DISTRIBUTION RESULTS")
    print("="*80)
    print(f"\nTotal ECG files: {n_files:,}")
    # _pct guards against an empty index, which would otherwise divide by zero.
    print(f"Matched: {matched_count:,} ({_pct(matched_count, n_files):.2f}%)")
    print(f"Unmatched: {unmatched_count:,} ({_pct(unmatched_count, n_files):.2f}%)")

    if matched_count == 0:
        print("\n⚠️  WARNING: No matched samples found!")
        return None

    print("\n" + "-"*80)
    print("Class Distribution (LOS Bins):")
    print("-"*80)
    print(f"{'Class':<8} {'LOS Range':<20} {'Count':<12} {'Percentage':<12} {'Cumulative %':<12}")
    print("-"*80)

    total = matched_count
    cumulative = 0.0
    for class_idx in range(10):
        count = class_counts.get(class_idx, 0)
        percentage = _pct(count, total)
        cumulative += percentage
        los_range = "[9, +inf) days" if class_idx == 9 else f"[{class_idx}, {class_idx+1}) days"
        # Attach '%' before padding so the sign stays next to the number.
        pct_str = f"{percentage:.2f}%"
        cum_str = f"{cumulative:.2f}%"
        print(f"{class_idx:<8} {los_range:<20} {count:<12,} {pct_str:<12} {cum_str:<12}")

    print("-"*80)
    print(f"{'Total':<8} {'':<20} {total:<12,} {'100.00%':<12} {'100.00%':<12}")

    # Labels outside 0-9 are counted in `total` but have no table row; surface them.
    unexpected = {label: n for label, n in class_counts.items() if label not in range(10)}
    if unexpected:
        print(f"⚠️  Labels outside 0-9 (not shown above): {unexpected}")

    print("\n" + "-"*80)
    print("Statistics:")
    print("-"*80)

    # matched_count > 0 here, so class_counts is non-empty and holds only positive counts.
    most_label, most_n = class_counts.most_common(1)[0]
    least_label, least_n = min(class_counts.items(), key=lambda kv: kv[1])
    print(f"Most frequent class: {most_label} ({most_n:,} samples, {_pct(most_n, total):.2f}%)")
    print(f"Least frequent class: {least_label} ({least_n:,} samples, {_pct(least_n, total):.2f}%)")
    # Ratio over non-empty classes only; empty bins are listed separately below.
    print(f"Class imbalance ratio (max/min): {most_n / least_n:.2f}x")
    empty_classes = [c for c in range(10) if class_counts.get(c, 0) == 0]
    if empty_classes:
        print(f"Classes with no samples: {empty_classes}")

    print("\n" + "="*80)
    return class_counts


In [None]:
# Inputs for the analysis below. The relative paths resolve against the
# notebook's current working directory; adjust them for your layout.
dataset_name = "All ICU ECGs"
data_dir = "data/all_icu_ecgs_P1"
icustays_path = "data/labeling/labels_csv/icustays.csv"


In [None]:
# Run the analysis. `class_counts` is a Counter of LOS bin -> matched ECG
# count, or None when nothing could be matched to an ICU stay.
class_counts = analyze_class_distribution(
    dataset_name=dataset_name,
    data_dir=data_dir,
    icustays_path=icustays_path,
)


In [None]:
# Visualize class distribution. The bar chart covers every LOS bin, with empty
# bins drawn as zero-height bars to match the printed table. The pie chart
# covers only the non-empty bins.
total = sum(class_counts.values()) if class_counts else 0
if total > 0:
    # All ten bins from the table, plus any unexpected labels that were counted.
    classes = sorted(set(range(10)) | set(class_counts))
    counts = [class_counts.get(c, 0) for c in classes]
    percentages = [c / total * 100 for c in counts]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # Bar chart
    bars = ax1.bar(classes, counts, color='steelblue', alpha=0.7)
    ax1.set_xticks(classes)  # one integer tick per bin; auto ticks may skip bins
    ax1.margins(y=0.15)  # headroom so labels above the tallest bar stay inside the axes
    ax1.set_xlabel("LOS Class", fontsize=12)
    ax1.set_ylabel("Count", fontsize=12)
    ax1.set_title(f"Class Distribution - {dataset_name}\nTotal: {total:,} samples", fontsize=14, fontweight='bold')
    ax1.grid(axis='y', alpha=0.3)

    # Add value labels on bars
    for bar, count, pct in zip(bars, counts, percentages):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height,
                f'{count:,}\n({pct:.1f}%)',
                ha='center', va='bottom', fontsize=9)

    # Pie chart: zero-count bins would only add overlapping "0.0%" wedges, so skip them.
    pie_classes = [c for c, n in zip(classes, counts) if n > 0]
    pie_counts = [n for n in counts if n > 0]
    ax2.pie(pie_counts, labels=[f"Class {c}" for c in pie_classes], autopct='%1.1f%%', startangle=90)
    ax2.set_title("Class Distribution (Pie Chart)", fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.show()


In [None]:
# Calculate and display summary statistics for `class_counts`.
if class_counts:
    total = sum(class_counts.values())
    # Most/least/ratio all use the same non-empty classes, so the three figures agree.
    nonempty = [(label, n) for label, n in class_counts.items() if n > 0]
    if nonempty:
        most_label, most_n = max(nonempty, key=lambda kv: kv[1])
        least_label, least_n = min(nonempty, key=lambda kv: kv[1])
        imbalance_ratio = most_n / least_n
        empty_classes = [c for c in range(10) if class_counts.get(c, 0) == 0]

        print("="*60)
        print("STATISTICS")
        print("="*60)
        print(f"Total samples: {total:,}")
        print(f"Most frequent class: {most_label} ({most_n:,} samples)")
        print(f"Least frequent class: {least_label} ({least_n:,} samples)")
        print(f"Class imbalance ratio: {imbalance_ratio:.2f}x")
        print(f"Classes with no samples: {empty_classes if empty_classes else 'none'}")
        print("="*60)
