# ECG-Level Mortality by LOS Class

Compute the ECG-level mortality rate (each ECG counts once), grouped by ICU length-of-stay (LOS) class.


In [None]:
# Setup
import sys
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Put the directory two levels above the working directory at the front of
# sys.path so the project-local `src` imports in the next cell can resolve.
project_root = Path().resolve().parent.parent
sys.path.insert(0, str(project_root))

# Plot defaults: seaborn "whitegrid" theme plus figure-size / font-size overrides.
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (14, 8),
    'font.size': 11,
})


In [None]:
# Import required modules
from src.data.ecg.ecg_loader import build_npy_index
from src.data.ecg.ecg_dataset import extract_subject_id_from_path
from src.data.labeling import load_icustays, ICUStayMapper, los_to_bin
from typing import Dict, List
import pandas as pd
import numpy as np

# Define functions
def analyze_ecg_mortality_by_los_class(
    dataset_name: str,
    data_dir: Path,
    icustays_df: pd.DataFrame,
    admissions_df: pd.DataFrame
) -> Dict:
    """Count died/survived ECGs per LOS class for one ECG dataset.

    Every ECG file indexed under ``data_dir`` is mapped to an ICU stay via
    ``ICUStayMapper``; the stay's LOS is binned with ``los_to_bin`` and the
    ECG is counted as "died" when the admission's ``deathtime`` lies within
    the stay's ``[intime, outtime]`` interval, otherwise as "survived".

    Args:
        dataset_name: Human-readable label, echoed in progress output and in
            the returned dict.
        data_dir: Directory passed to ``build_npy_index`` to find ECG files.
        icustays_df: ICU stays; the columns ``stay_id``, ``hadm_id``,
            ``intime`` and ``outtime`` are read here (the mapper may need
            more).
        admissions_df: Admissions; the columns ``hadm_id`` and ``deathtime``
            are read.

    Returns:
        A dict with ``dataset_name``, ``total_ecgs`` (number of matched ECGs)
        and ``results_by_class``: one dict per class 0-9 holding ``class``,
        ``los_range``, ``total_ecgs``, ``died_ecgs``, ``survived_ecgs`` and
        ``mortality_rate`` (percent; 0.0 for an empty class).
    """
    n_classes = 10

    print(f"\n{'='*80}")
    print(f"Analyzing dataset: {dataset_name}")
    print(f"{'='*80}")

    print(f"Scanning ECG files in: {data_dir}")
    records = build_npy_index(data_dir=str(data_dir))
    print(f"Found {len(records):,} ECG files")

    icu_mapper = ICUStayMapper(icustays_df)

    # hadm_id -> deathtime (NaT when missing or unparseable).
    hadm_to_deathtime = dict(zip(
        admissions_df['hadm_id'],
        pd.to_datetime(admissions_df['deathtime'], errors='coerce')
    ))

    # stay_id -> (hadm_id, intime, outtime), built column-wise instead of
    # row-by-row with iterrows().
    stay_to_info = dict(zip(
        icustays_df['stay_id'],
        zip(
            icustays_df['hadm_id'],
            pd.to_datetime(icustays_df['intime']),
            pd.to_datetime(icustays_df['outtime'])
        )
    ))

    # subject_id -> intime of the subject's first row in the mapper's table.
    # Built once so the loop below does a dict lookup per ECG instead of a
    # full boolean-mask scan of the stays frame. keep='first' reproduces the
    # previous `.iloc[0]` choice (including a possibly missing intime).
    first_rows = icu_mapper.icustays_df.drop_duplicates(
        subset='subject_id', keep='first'
    )
    subject_to_first_intime = dict(zip(
        first_rows['subject_id'], first_rows['intime']
    ))
    no_stay = object()  # sentinel: distinguishes "no such subject" from a missing intime

    class_stats = {i: {'died': 0, 'survived': 0} for i in range(n_classes)}
    matched_count = 0
    unmatched_count = 0
    error_count = 0
    first_error = None

    print("\nProcessing ECGs...")
    for processed, record in enumerate(records, start=1):
        if processed % 10000 == 0:
            print(f"  Processed {processed:,}/{len(records):,} ECGs...")

        base_path = record["base_path"]
        try:
            subject_id = extract_subject_id_from_path(base_path)
            first_intime = subject_to_first_intime.get(subject_id, no_stay)
            if first_intime is no_stay:
                unmatched_count += 1
                continue

            # NOTE(review): the ECG's own acquisition time is not used here;
            # the subject's first listed ICU intime stands in for it, so the
            # stay chosen presumably depends only on the subject. Confirm
            # this is intended.
            ecg_time = pd.to_datetime(first_intime)
            stay_id = icu_mapper.map_ecg_to_stay(subject_id, ecg_time)
            if stay_id is None:
                unmatched_count += 1
                continue

            los_days = icu_mapper.get_los(stay_id)
            if los_days is None:
                unmatched_count += 1
                continue

            los_class = los_to_bin(los_days)

            if stay_id not in stay_to_info:
                unmatched_count += 1
                continue

            hadm_id, intime, outtime = stay_to_info[stay_id]
            deathtime = hadm_to_deathtime.get(hadm_id)
            died_in_this_stay = (
                deathtime is not None and
                pd.notna(deathtime) and
                intime <= deathtime <= outtime
            )

            outcome = 'died' if died_in_this_stay else 'survived'
            class_stats[los_class][outcome] += 1
            matched_count += 1
        except Exception as exc:
            # Best-effort: a failing ECG is still counted as unmatched, but
            # errors are tallied so a systematic failure is not mistaken for
            # ordinary unmatched records.
            unmatched_count += 1
            error_count += 1
            if first_error is None:
                first_error = f"{type(exc).__name__}: {exc} (ECG: {base_path})"

    print(f"\n  Matched: {matched_count:,} ECGs")
    print(f"  Unmatched: {unmatched_count:,} ECGs")
    if error_count:
        print(f"  Of which failed with an error: {error_count:,} ECGs")
        print(f"  First error: {first_error}")

    results = []
    for class_idx in range(n_classes):
        died = class_stats[class_idx]['died']
        survived = class_stats[class_idx]['survived']
        total = died + survived
        mortality_rate = (died / total * 100) if total > 0 else 0.0

        # The last class is open-ended; all others span one day.
        if class_idx == n_classes - 1:
            los_range = f"[{class_idx}, +inf) days"
        else:
            los_range = f"[{class_idx}, {class_idx+1}) days"

        results.append({
            'class': class_idx,
            'los_range': los_range,
            'total_ecgs': total,
            'died_ecgs': died,
            'survived_ecgs': survived,
            'mortality_rate': mortality_rate
        })

        if total > 0:
            print(f"\n  Class {class_idx} ({los_range}):")
            print(f"    Total ECGs: {total:,}")
            print(f"    Died: {died:,} ({mortality_rate:.2f}%)")
            print(f"    Survived: {survived:,} ({100 - mortality_rate:.2f}%)")

    return {
        'dataset_name': dataset_name,
        'total_ecgs': matched_count,
        'results_by_class': results
    }

def create_mortality_chart(results_list: List[Dict], return_figure: bool = True):
    """Plot mortality rate by LOS class as a line chart and a grouped bar chart.

    Args:
        results_list: One dict per dataset, each with ``dataset_name`` and
            ``results_by_class`` (a list of per-class dicts providing
            ``class`` and ``mortality_rate``). The bar chart expects ten
            classes per dataset.
        return_figure: Accepted for backward compatibility; it currently has
            no effect and the figure is always returned.

    Returns:
        The matplotlib figure, or ``None`` (after printing a notice) when
        ``results_list`` is empty.
    """
    if not results_list:
        print("No results to visualize.")
        return None

    n_classes = 10
    x = np.arange(n_classes)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))

    # Left panel: one line per dataset.
    for result in results_list:
        results_by_class = result['results_by_class']
        classes = [r['class'] for r in results_by_class]
        mortality_rates = [r['mortality_rate'] for r in results_by_class]
        ax1.plot(classes, mortality_rates, marker='o', linewidth=2, markersize=8,
                 label=result['dataset_name'], alpha=0.8)

    # Right panel: grouped bars. Bars are centred on each class tick and
    # narrowed when needed so any number of datasets fits side by side
    # without overlapping (one or two datasets keep the 0.35 width).
    n_datasets = len(results_list)
    width = min(0.35, 0.8 / n_datasets)
    for i, result in enumerate(results_list):
        mortality_rates = [r['mortality_rate'] for r in result['results_by_class']]
        offset = width * (i - (n_datasets - 1) / 2)
        bars = ax2.bar(x + offset, mortality_rates, width,
                       label=result['dataset_name'], alpha=0.8)
        for bar, rate in zip(bars, mortality_rates):
            if rate > 0:
                # Annotate non-empty bars with their rate, just above the bar.
                ax2.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                         f'{rate:.1f}%',
                         ha='center', va='bottom', fontsize=8)

    # Shared axis styling for both panels.
    titles = (
        (ax1, 'ECG Mortality Rate by LOS Class'),
        (ax2, 'ECG Mortality Rate by LOS Class (Grouped)'),
    )
    for ax, title in titles:
        ax.set_xlabel('LOS Class', fontsize=12, fontweight='bold')
        ax.set_ylabel('Mortality Rate (%)', fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels([f'{i}' for i in range(n_classes)])
        ax.legend(loc='upper left')
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        ax.set_ylim([0, None])

    plt.tight_layout()
    return fig


In [None]:
# Read the ICU-stay and admission label tables from their CSV files.
icustays_path = "data/labeling/labels_csv/icustays.csv"
admissions_path = "data/labeling/labels_csv/admissions.csv"

icustays_df = load_icustays(icustays_path)
admissions_df = pd.read_csv(admissions_path)

# Report how many rows each table contains.
for loaded_df, row_label in ((icustays_df, "ICU stays"), (admissions_df, "admissions")):
    print(f"Loaded {len(loaded_df):,} {row_label}")


In [None]:
# Run the per-class mortality analysis on every dataset directory that exists.
datasets = {
    'all_icu_ecgs': dict(name='All ICU ECGs', data_dir=Path('data/all_icu_ecgs_P1')),
    'icu_24h': dict(name='ICU 24h', data_dir=Path('data/icu_ecgs_24h')),
}

results_list = []
for dataset_key, dataset_info in datasets.items():
    # Missing directories are reported and skipped rather than treated as errors.
    if not dataset_info['data_dir'].exists():
        print(f"Dataset {dataset_info['name']} not found at {dataset_info['data_dir']}")
        continue

    result = analyze_ecg_mortality_by_los_class(
        dataset_info['name'],
        dataset_info['data_dir'],
        icustays_df,
        admissions_df,
    )
    if result:
        results_list.append(result)


In [None]:
# Show the chart and a per-dataset results table (nothing happens without results).
if results_list:
    # The figure renders inline in the notebook.
    fig = create_mortality_chart(results_list, return_figure=True)
    plt.show()

    # Banner followed by one table per dataset.
    print("\n" + "="*80, "DETAILED RESULTS BY LOS CLASS", "="*80, sep="\n")

    for result in results_list:
        print(f"\n{result['dataset_name']}:")
        df = pd.DataFrame(result['results_by_class'])
        print(df.to_string(index=False))
