In [1]:
# import 
import numpy as np
from scipy.stats import chi2
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.dates as mdates
import matplotlib.colors as mcolors
import polars as pl
import pandas as pd
import seaborn as sns
import torch
import lightning as L
from sklearn.calibration import calibration_curve
from sklearn.metrics import roc_auc_score, average_precision_score, confusion_matrix, roc_curve
from omegaconf import OmegaConf, DictConfig
import hydra
import wandb
from dataset import SupervisedDataset
from lightning_modules import SupervisedTask
from models.ecg_models import *
from run import interpolate
pl.Config.set_tbl_rows(50)
MY_NAVY = '#001F54'

  torch.utils._pytree._register_pytree_node(


In [2]:
WANDB_RUN='payalchandak/SILVER/nuzrc47q'
COHORTS = ['all','worsen']

In [None]:
# train data
for site in [
    {
        'name':'MGH',
        'split':'train',
        'line':'-',
    },
]:
    device = 'cuda:1'
    cfg = OmegaConf.create(wandb.Api().run(WANDB_RUN).config)
    L.seed_everything(cfg.utils.seed)
    train_pyd = hydra.utils.instantiate(cfg.dataset, split='train')
    cfg = interpolate(cfg, train_pyd)
    del train_pyd
    trainer = L.Trainer(devices=[int(device[-1])])
    LM = SupervisedTask.load_from_checkpoint(cfg.best_model_path, map_location=torch.device(device))
    model = LM.model
    model.to(device)
    model.eval()
    cfg.optimizer.batch_size = 2048
    cfg.dataset.config.label = 'future_1_365_any_below_40'

    if site['split'] == 'mimic': 
        cfg.dataset.config.datadir = '/storage/shared/mimic/'
        cfg.dataset.config.ecg.storedir = '/storage/shared/mimic/raw/ecg/'

    pyd = hydra.utils.instantiate(cfg.dataset, split=site['split'])
    pyd.data = pyd.data.reset_index(drop=1)
    assert len(pyd)
    site['pyd'] = pyd
    loader = torch.utils.data.DataLoader(
        dataset = pyd,
        batch_size = cfg.optimizer.batch_size,
        num_workers = 0, 
        collate_fn = pyd.collate,
        shuffle=False,
        pin_memory=True
    )

    out = trainer.predict(LM, loader)
    site['pred'] = np.array(torch.sigmoid(torch.cat([x[0] for x in out])).tolist())
    site['true'] = np.array(torch.cat([x[1] for x in out]).tolist())

    site['idx_all'] = site['pyd'].data.index.values
    site['idx_improve'] = site['pyd'].data.query('tag_hfref').index.values
    site['idx_worsen'] = site['pyd'].data.query('~tag_hfref').index.values
    # site['idx_worsen_50'] = site['pyd'].data.query('tag_50==True').index.values
    # site['idx_worsen_no_com'] = site['pyd'].data.query('~tag_hfref').query('hypertension==0').query('diabetes_mellitus==0').query('atheroscler==0').query('chronic_obstructive_pulmonary_disease==0').query('atrial_fibrillation==0').index.values
    # site['idx_worsen_no_med'] = site['pyd'].data.query('~tag_hfref').query('angio==0').query('betablocker==0').query('mra==0').query('diuretic==0').index.values
    # site['idx_worsen_healthy'] = site['pyd'].data.query('~tag_hfref').query('angio==0').query('betablocker==0').query('mra==0').query('diuretic==0').query('hypertension==0').query('diabetes_mellitus==0').query('atheroscler==0').query('chronic_obstructive_pulmonary_disease==0').query('atrial_fibrillation==0').index.values

    for cohort in COHORTS:
        if f"idx_{cohort}" not in site: continue 
        idx = site[f"idx_{cohort}"]
        true = site['true'][idx]
        pred = site['pred'][idx]
        if cohort=='improve': 
            true = 1 - true 
            pred = 1 - pred
        x = {}
        x['fpr'], x['tpr'], x['thresholds'] = roc_curve(true, pred)
        x['sens_to_thres'] = {}
        sensitivities = x['tpr']
        for i in np.concatenate([np.arange(0.1, 1, 0.1)]):
            i = round(i, 2)
            tol = 1e-10
            matches = np.array([])
            while not matches.any(): 
                matches = np.where(np.isclose(sensitivities, i, atol=tol, rtol=tol))[0]
                tol *= 10
            x['sens_to_thres'][i] = np.round(np.mean(np.unique(np.round(x['thresholds'][matches],2))),3)
        site[cohort] = x
    train_data = site

In [3]:
train_data = {'name': 'MGH', 'split': 'train', 'line': '-',  
'all': {'sens_to_thres': {0.1: 0.96, 0.2: 0.94, 0.3: 0.91, 0.4: 0.87, 0.5: 0.81, 0.6: 0.73, 0.7: 0.61, 0.8: 0.44, 0.9: 0.21}},
'worsen': { 'sens_to_thres': {0.1: 0.955, 0.2: 0.89, 0.3: 0.8, 0.4: 0.68, 0.5: 0.56, 0.6: 0.43, 0.7: 0.29, 0.8: 0.17, 0.9: 0.07}}, 
}

In [4]:
sites = []
for site in [
    {
        'name':'MGH',
        'split':'test',
        'line':'-',
        'color':'#0F548D',
    },
    {
        'name':'BWH',
        'split':'external',
        'line':'--',
        'color':'#941751',
    },
    {
        'name':'MIMIC',
        'split':'mimic',
        'line':'-.',
        'color':'#4F8F00',
    },
]:
    device = 'cuda:2'
    cfg = OmegaConf.create(wandb.Api().run(WANDB_RUN).config)
    L.seed_everything(cfg.utils.seed)
    train_pyd = hydra.utils.instantiate(cfg.dataset, split='train')
    cfg = interpolate(cfg, train_pyd)
    del train_pyd
    trainer = L.Trainer(devices=[int(device[-1])])
    LM = SupervisedTask.load_from_checkpoint(cfg.best_model_path, map_location=torch.device(device))
    model = LM.model
    model.to(device)
    model.eval()
    cfg.optimizer.batch_size = 2048
    cfg.dataset.config.label = 'future_1_365_any_below_40'

    if site['split'] == 'mimic': 
        cfg.dataset.config.datadir = '/storage/shared/mimic/'
        cfg.dataset.config.ecg.storedir = '/storage/shared/mimic/raw/ecg/'

    pyd = hydra.utils.instantiate(cfg.dataset, split=site['split'])
    pyd.data = pyd.data.reset_index(drop=1)
    assert len(pyd)
    site['pyd'] = pyd
    loader = torch.utils.data.DataLoader(
        dataset = pyd,
        batch_size = cfg.optimizer.batch_size,
        num_workers = 0, 
        collate_fn = pyd.collate,
        shuffle=False,
        pin_memory=True
    )

    out = trainer.predict(LM, loader)
    site['pred'] = np.array(torch.sigmoid(torch.cat([x[0] for x in out])).tolist())
    site['true'] = np.array(torch.cat([x[1] for x in out]).tolist())

    site['idx_all'] = site['pyd'].data.index.values
    # site['idx_improve'] = site['pyd'].data.query('tag_hfref').index.values
    site['idx_worsen'] = site['pyd'].data.query('~tag_hfref').index.values
    # site['idx_worsen_no_com'] = site['pyd'].data.query('~tag_hfref').query('hypertension==0').query('diabetes_mellitus==0').query('atheroscler==0').query('chronic_obstructive_pulmonary_disease==0').query('atrial_fibrillation==0').index.values
    # site['idx_worsen_no_med'] = site['pyd'].data.query('~tag_hfref').query('angio==0').query('betablocker==0').query('mra==0').query('diuretic==0').index.values
    # site['idx_worsen_healthy'] = site['pyd'].data.query('~tag_hfref').query('angio==0').query('betablocker==0').query('mra==0').query('diuretic==0').query('hypertension==0').query('diabetes_mellitus==0').query('atheroscler==0').query('chronic_obstructive_pulmonary_disease==0').query('atrial_fibrillation==0').index.values

    sites.append(site)

Seed set to 140799
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]


Predicting: |          | 0/? [00:00<?, ?it/s]

TypeError: 'NoneType' object is not iterable

In [6]:
# functions 

def read_lvef(split): 
    if split == 'mimic': 
        lvef = pl.read_csv(
            '/storage/shared/mimic/raw/lvef.csv'
        ).rename({
            'subject_id':'empi',
            'study_datetime':'lvef_date',
            'result':'lvef'
        }).drop(
            'measurement'
        ).filter(
            pl.col('lvef').is_not_null()
        ).with_columns(
            pl.col('lvef').cast(pl.Int64),
            pl.col("lvef_date").str.strptime(pl.Datetime, "%Y-%m-%dT%H:%M:%S").dt.date()
        ).filter(
            (pl.col('lvef')>0) & (pl.col('lvef')<100)
        ).sort('lvef_date')
    else: 
        lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet').sort('lvef_date') 
    return lvef

def create_bootstrap_idx(length, N=1000): 
    return [np.random.choice(length, length, replace=True) for _ in range(N)]

def bootstrap(metric_fn, true, pred, idx=None, conf=0.95, **kwargs): 
    score = metric_fn(true=true, pred=pred, **kwargs)
    if idx is None: 
        idx = create_bootstrap_idx(length=len(true), N=1000)
    scores = [metric_fn(true[x], pred[x], **kwargs) for x in idx]
    scores = sorted([x for x in scores if x is not None])
    N = len(scores)
    lower_idx = int(N * (1 - conf) / 2)
    lower = np.abs(scores[lower_idx] - score)
    upper_idx = int(N * (1 + conf) / 2)
    upper = np.abs(scores[upper_idx] - score)
    return score, (lower, upper)

def conf_matrix(true, pred, threshold): 
    pred_labels = (pred >= threshold).astype(int)
    tn = np.sum((pred_labels == 0) & (true == 0))
    fp = np.sum((pred_labels == 1) & (true == 0))
    fn = np.sum((pred_labels == 0) & (true == 1))
    tp = np.sum((pred_labels == 1) & (true == 1))
    return tn, fp, fn, tp

def get_auc(true, pred):
    prev = true.sum()/len(true)
    if prev==0 or prev==1: return None 
    return roc_auc_score(true, pred)

def get_auprc(true, pred): 
    return average_precision_score(true, pred)

def get_sens(true, pred, **kwargs): 
    tn, fp, fn, tp = conf_matrix(true, pred, kwargs['threshold'])
    sens = tp / (tp + fn)
    return sens

def get_spec(true, pred, **kwargs): 
    tn, fp, fn, tp = conf_matrix(true, pred, kwargs['threshold'])
    spec = tn / (tn + fp)
    return spec

def get_ppv(true, pred, **kwargs): 
    prev = kwargs['prevalence']
    tn, fp, fn, tp = conf_matrix(true, pred, kwargs['threshold'])
    sens = tp / (tp + fn)
    spec = tn / (tn + fp)
    ppv = (sens * prev) / (sens * prev + (1 - spec) * (1 - prev)) 
    return ppv 

def get_npv(true, pred, **kwargs): 
    prev = kwargs['prevalence']
    tn, fp, fn, tp = conf_matrix(true, pred, kwargs['threshold'])
    sens = tp / (tp + fn)
    spec = tn / (tn + fp)
    npv = (spec * (1 - prev)) / (spec * (1 - prev) + (1 - sens) * prev) 
    return npv

def get_metric(name, cohort=None, verbose=False): 
    my_sites = sites
    data = []
    for x in my_sites: 
        if cohort is not None: x = x[cohort]
        if name not in x: 
            continue 
        data.append(x[name])
    if data:
        if isinstance(data[0], (int)): data = [round(x) for x in data]
        if isinstance(data[0], (float)): data = [round(x,3) for x in data]
        if 'cf_' in name: name = name.replace('cf_','').split('_')[0]
        if verbose: print(name, " ".join([str(x) for x in data]))
        return data
    
def get_metric_ci(name, cohort=None): 
    return np.array(get_metric(name+'_ci', cohort)).T

def barplot(x, y, percent=False, decimal=0, ylim=None, yerr=None, colors=None):
    sns.set_theme(style="whitegrid")
    fig, ax = plt.subplots(figsize=(6, 4))
    if colors is None: colors = sns.color_palette("Blues_r", len(x))
    else: colors = colors[:len(x)] 
    ax = sns.barplot(x=x, y=y, palette=colors, ax=ax)
    if ylim is not None:
        ax.set_ylim(ylim)
    if yerr is not None:
        x_coords = [p.get_x() + 0.5 * p.get_width() for p in ax.patches]
        y_coords = [p.get_height() for p in ax.patches]
        plt.errorbar(x=x_coords, y=y_coords, yerr=yerr, fmt="none", c="k", capsize=5)
    ymin, ymax = ax.get_ylim() if ylim is None else ylim
    annotation_y = ymin + (ymax - ymin) * 0.025  
    for p in ax.patches:
        label = round(p.get_height(), decimal) if decimal else round(p.get_height())
        ax.annotate(f"{label}{'%' if percent else ''}", 
                    (p.get_x() + p.get_width() / 2, annotation_y), 
                    ha="center", va="bottom", fontsize=12, fontweight="bold")
    sns.despine()
    return fig

def generate_shades(hex_color, n=3, factor=0.15):
    rgb = mcolors.hex2color(hex_color)
    shades = [(min(1, c + factor * i)) for i in range(1, n+1) for c in rgb]
    shades = [mcolors.to_hex(shades[i:i+3]) for i in range(0, len(shades), 3)]
    return shades

In [7]:
for num in range(len(sites)): 
    site = sites[num]
    for cohort in COHORTS:
        if f"idx_{cohort}" not in site: continue 
        idx = site[f"idx_{cohort}"]
        true = site['true'][idx]
        pred = site['pred'][idx]
        bootstrap_idx = create_bootstrap_idx(len(true))

        if cohort=='improve': 
            true = 1 - true 
            pred = 1 - pred
    
        x = {}
        x['samples'] = len(true)
        x['prevalence'] = true.sum()/len(true)
        x['auc'], x['auc_ci'] = bootstrap(get_auc, true, pred, bootstrap_idx)
        x['auprc'], x['auprc_ci'] = bootstrap(get_auprc, true, pred, bootstrap_idx)
        x['fpr'], x['tpr'], x['thresholds'] = roc_curve(true, pred)
        
        s2t = train_data[cohort]['sens_to_thres']

        x['sensitivity'], x['sensitivity_ci'] = [], []
        x['specificity'], x['specificity_ci'] = [], []
        for _, threshold in s2t.items(): 
            sens, sens_ci = bootstrap(get_sens, true, pred, bootstrap_idx, threshold=threshold)
            x['sensitivity'].append(sens)
            x['sensitivity_ci'].append(sens_ci)
            spec, spec_ci = bootstrap(get_spec, true, pred, bootstrap_idx, threshold=threshold)
            x['specificity'].append(spec)
            x['specificity_ci'].append(spec_ci)

        for desired_sens in [.7,.8,.9]: 
            prefix = f'sens_{round(desired_sens*100)}'
            threshold = s2t[desired_sens]

            # x[f'{prefix}_spec'], x[f'{prefix}_spec_ci'] = bootstrap(get_spec, true, pred, bootstrap_idx, threshold=threshold)
            x[f'{prefix}_ppv'], x[f'{prefix}_ppv_ci'] = bootstrap(get_ppv, true, pred, bootstrap_idx, threshold=threshold, prevalence=x['prevalence'])
            x[f'{prefix}_npv'], x[f'{prefix}_npv_ci'] = bootstrap(get_npv, true, pred, bootstrap_idx, threshold=threshold, prevalence=x['prevalence'])

            for prev in np.arange(0.1,1,.05):
                prev = round(prev,2)
                x[f'{prefix}_prev_{prev}_ppv'], x[f'{prefix}_prev_{prev}_ppv_ci'] = bootstrap(get_ppv, true, pred, bootstrap_idx, threshold=threshold, prevalence=prev)
                x[f'{prefix}_prev_{prev}_npv'], x[f'{prefix}_prev_{prev}_npv_ci'] = bootstrap(get_npv, true, pred, bootstrap_idx, threshold=threshold, prevalence=prev)

        sites[num][cohort] = x

    print(f"Finished {sites[num]['name']}")

Finished MGH
Finished BWH


In [8]:
# confounder auc in all patients 
confounders = [
    # 'mean_annual_hospitalizations<1',
    # 'mean_annual_hospitalizations>=1',
    'paced==True',
    'paced==False',
    'afib==True',
    'afib==False',
    # 'transplant==0',
    # 'transplant==1',
    # 'diabetes_mellitus==0',
    # 'diabetes_mellitus==1',
    # 'hypertension==0',
    # 'hypertension==1',
    # 'atheroscler==0',
    # 'atheroscler==1',
    # 'chronic_obstructive_pulmonary_disease==0',
    # 'chronic_obstructive_pulmonary_disease==1',
    # 'atrial_fibrillation==0',
    # 'atrial_fibrillation==1',
    # 'num_meds==0',
    # 'num_meds>0',
    # 'angio==0',
    # 'angio==1',
    # 'mra==0',
    # 'mra==1',
    # 'betablocker==0',
    # 'betablocker==1',
    # 'diuretic==0',
    # 'diuretic==1',
]
for cohort in ['all','worsen']:
    print('-'*40)
    print(cohort)
    print('-'*40)
    for cf in confounders:
        result = [cf,] 
        for site in sites:
            idx = site[f"idx_{cohort}"]
            try: site['pyd'].data.query(cf)
            except: continue 
            cf_idx = np.intersect1d(idx, site['pyd'].data.query(cf).index.values)
            cf_true = site['true'][cf_idx]
            cf_pred = site['pred'][cf_idx]
            try: roc_auc_score(cf_true, cf_pred)
            except: continue 
            auc, ci = bootstrap(get_auc, cf_true, cf_pred)
            delta = np.mean(ci)
            result.append(f"\t {round(auc*100, 1)} ± {round(delta*100, 1)}")
        print(*result)

----------------------------------------
all
----------------------------------------
paced==True 	 92.7 ± 0.7 	 92.4 ± 0.5
paced==False 	 92.9 ± 0.2 	 91.9 ± 0.2
afib==True 	 92.0 ± 0.8 	 91.5 ± 0.4
afib==False 	 93.1 ± 0.2 	 92.3 ± 0.2
----------------------------------------
worsen
----------------------------------------
paced==True 	 87.3 ± 1.7 	 88.3 ± 0.8
paced==False 	 89.5 ± 0.5 	 89.2 ± 0.3
afib==True 	 88.8 ± 1.3 	 87.9 ± 0.6
afib==False 	 89.5 ± 0.4 	 89.5 ± 0.3


In [None]:
# bars
for cohort in COHORTS:
    print(cohort)
    fig = barplot(
        x=get_metric('name'), 
        y=get_metric('samples',cohort),
        ylim=(0,100000),
        colors=get_metric('color'), 
    )
    plt.show()
    fig = barplot(
        x=get_metric('name'), 
        y=[100*i for i in get_metric('prevalence',cohort)], 
        percent=True, 
        decimal=1, 
        ylim=(0,100),
        colors=get_metric('color'), 
    )
    plt.show()
    fig = barplot(
        x=get_metric('name'), 
        y=[100*i for i in get_metric('auc',cohort)], 
        yerr=100* get_metric_ci('auc',cohort), 
        percent=True, 
        decimal=1, 
        ylim=(50,100),
        colors=get_metric('color'), 
    )
    plt.show()
    fig = barplot(
        x=get_metric('name'), 
        y=[100*i for i in get_metric('auprc',cohort)], 
        yerr=100* get_metric_ci('auprc',cohort), 
        percent=True, 
        decimal=1, 
        ylim=(0,100),
        colors=get_metric('color'), 
    )
    plt.show()


In [None]:
# sens sens
for cohort in COHORTS:
    print(cohort)
    for site in sites:
        data = site[cohort]
        prev = np.round(np.arange(0.1, 1, 0.1), 1)
        sens_levels = [70, 80, 90]
        fig, ax = plt.subplots(figsize=(6, 4))
        x = data[f'sensitivity']
        x_err = np.array(data[f'sensitivity_ci']).T
        plt.errorbar(
            np.arange(0.1, 1, 0.1),
            x,
            yerr=x_err,
            capsize=3,
            marker="o",
            markersize=3,
            linestyle="-",
            linewidth=1,
            color=site['color'],
        )
        ax.set_xlabel("Sensitivity @ Training", fontsize=12, color='gray')
        ax.set_ylabel("Sensitivity @ Evaluation", fontsize=12, color='gray')
        ax.set_title(f"{site['name']}", fontsize=14)
        ax.set_ylim(0, 1.025)
        ax.set_xlim(0, 1)
        plt.gca().tick_params(axis='y', colors='gray')
        plt.gca().spines['left'].set_color('gray')
        plt.gca().spines['left'].set_linewidth(1)
        plt.gca().tick_params(axis='x', colors='gray', which='both')
        plt.gca().spines['bottom'].set_color('gray')
        sns.despine()
        plt.tight_layout()
        plt.show()

In [None]:
# sens spec
for cohort in COHORTS:
    fig, ax = plt.subplots(figsize=(6, 4))
    print(cohort)
    for site in sites:
        data = site[cohort]
        prev = np.round(np.arange(0.1, 1, 0.1), 1)
        sens_levels = [70, 80, 90]
        spec = data[f'specificity']
        spec_err = np.array(data[f'specificity_ci']).T
        plt.errorbar(
            np.arange(0.1, 1, 0.1),
            spec,
            yerr=spec_err,
            capsize=3,
            marker="o",
            markersize=3,
            linestyle="-",
            linewidth=1,
            color=site['color'],
        )
        ax.set_xlabel("Sensitivity", fontsize=12, color='gray')
        ax.set_ylabel("Specificity", fontsize=12, color='gray')
        # ax.set_title(f"{site['name']}", fontsize=14)
        ax.set_ylim(0, 1.025)
        ax.set_xlim(0, 1)
        plt.gca().tick_params(axis='y', colors='gray')
        plt.gca().spines['left'].set_color('gray')
        plt.gca().spines['left'].set_linewidth(1)
        plt.gca().tick_params(axis='x', colors='gray', which='both')
        plt.gca().spines['bottom'].set_color('gray')
        sns.despine()
        plt.tight_layout()
    plt.show()

In [None]:
# ppv
for cohort in COHORTS:
    print(cohort)
    for site in sites:
        data = site[cohort]
        prev = np.round(np.arange(0.1, 1, 0.1), 1)
        sens_levels = [70, 80, 90]
        fig, ax = plt.subplots(figsize=(6, 4))
        for sens in sens_levels:
            ppv = np.array([data[f'sens_{sens}_prev_{i}_ppv'] for i in prev])
            ppv_err = np.array([data.get(f'sens_{sens}_prev_{i}_ppv_ci') for i in prev])
            plt.errorbar(
                prev,
                ppv,
                yerr=ppv_err.T,
                label=f"Sensitivity {sens}%",
                capsize=3,
                marker="o",
                markersize=3,
                linestyle="-",
                linewidth=1,
                color=site['color'],
            )
            plt.text(
                0.15,
                ppv[0]+0.05,  
                f"{sens}%",
                fontsize=10,
                color=site['color'],
                ha="center",
                va="bottom",
                bbox=dict(boxstyle="round,pad=0.1", edgecolor="none", facecolor="white", alpha=0.7),
            )
        ax.set_xlabel("Prevalence", fontsize=12, color='gray')
        ax.set_ylabel("Positive Predictive Value", fontsize=12, color='gray')
        ax.set_title(f"{site['name']}", fontsize=14)
        ax.set_ylim(0, 1)
        ax.set_xlim(0, 1)
        plt.gca().tick_params(axis='y', colors='gray')
        plt.gca().spines['left'].set_color('gray')
        plt.gca().spines['left'].set_linewidth(1)
        plt.gca().tick_params(axis='x', colors='gray', which='both')
        plt.gca().spines['bottom'].set_color('gray')
        # handles, labels = ax.get_legend_handles_labels()
        # handles = [h[0] for h in handles]
        # ax.legend(handles, labels, markerscale=0.001)
        plt.grid(True)
        sns.despine()
        plt.tight_layout()
        plt.show()

In [None]:
# npv
for cohort in COHORTS:
    print(cohort)
    for site in sites:
        data = site[cohort]
        prev = np.round(np.arange(0.1, 1, 0.1), 1)
        sens_levels = [70, 80, 90]
        fig, ax = plt.subplots(figsize=(6, 4))
        for sens in sens_levels:
            npv = np.array([data[f'sens_{sens}_prev_{i}_npv'] for i in prev])
            npv_err = np.array([data.get(f'sens_{sens}_prev_{i}_npv_ci') for i in prev])
            plt.errorbar(
                prev,
                npv,
                yerr=npv_err.T,
                label=f"Sensitivity {sens}%",
                capsize=3,
                marker="o",
                markersize=3,
                linestyle="-",
                linewidth=1,
                color=site['color'],
            )
            plt.text(
                0.8,
                npv[-2],  
                f"{sens}%",
                fontsize=10,
                color=site['color'],
                ha="center",
                va="bottom",
                bbox=dict(boxstyle="round,pad=0.1", edgecolor="none", facecolor="white", alpha=0.7),
            )
        ax.set_xlabel("Prevalence", fontsize=12, color='gray')
        ax.set_ylabel("Negative Predictive Value", fontsize=12, color='gray')
        ax.set_title(f"{site['name']}", fontsize=14)
        ax.set_ylim(0, 1)
        ax.set_xlim(0, 1)
        plt.gca().tick_params(axis='y', colors='gray')
        plt.gca().spines['left'].set_color('gray')
        plt.gca().spines['left'].set_linewidth(1)
        plt.gca().tick_params(axis='x', colors='gray', which='both')
        plt.gca().spines['bottom'].set_color('gray')
        # handles, labels = ax.get_legend_handles_labels()
        # handles = [h[0] for h in handles]
        # ax.legend(handles, labels, markerscale=0.001)
        plt.grid(True)
        sns.despine()
        plt.tight_layout()
        plt.show()

In [None]:
# print specificity
for cohort in COHORTS:
    print(f"\nCohort: {cohort}")
    print("=" * 40)
    
    sensitivities = [10,20,30,40,50,60,70, 80, 90]
    
    for i, sens in enumerate(sensitivities):
        if sens<65: continue
        print(f"\nSensitivity: {sens}%")
        for site_name, site in zip(["MGH", "BWH", "MIMIC"], sites):
            specificity = site[cohort]['specificity'][i]
            lower_ci, upper_ci = site[cohort]['specificity_ci'][i]
            print(f"  {site_name}: Specificity = {specificity:.3f} (95% CI: [{lower_ci:.3f}, {upper_ci:.3f}])")


In [None]:
# print ppv npv
for cohort in COHORTS:
    print(f"\nCohort: {cohort}")
    print("=" * 40)
    model_names = get_metric('name')  # Retrieve model names

    for prev in [0.1]:
        print("=" * 40)
        print(f"\nPrevalence: {prev*100}%")
        for metric in ['ppv', 'npv']:
            for sens in [70,80,90]: 
                x = get_metric(f'sens_{sens}_prev_{prev}_{metric}', cohort)
                x_err = get_metric_ci(f'sens_{sens}_prev_{prev}_{metric}', cohort)

                print(f"\nSensitivity: {sens}%")
                print(f"{metric} {min(x)*100:.1f}%–{max(x)*100:.1f}%")
                # for name, val, ci_low, ci_high in zip(model_names, x, x_err[0], x_err[1]):
                #     print(f"  {name}: {metric} = {val:.3f} (95% CI: [{val - ci_low:.3f}, {val + ci_high:.3f}])")


In [None]:
# total positives in last N years

from datetime import timedelta

N_YEARS = 2

for cohort in COHORTS:
    print(cohort)
    for site in sites:
        print(site['name'])
        if f"idx_{cohort}" not in site: continue 
        idx = site[f"idx_{cohort}"]
        set_idx = set(list(idx))
        true = site['true'][idx]
        pred = site['pred'][idx]
        sns.set_theme(style="whitegrid")
        fig, ax = plt.subplots(1, 1, figsize=(6,4), sharex=False)

        shades = generate_shades(site['color'], n=3)
    
        for sens, color in zip([.7, .8, .9], reversed(shades)): 
            threshold = train_data[cohort]['sens_to_thres'][sens]
            # decision = pred>(1-threshold)
            decision = pred>threshold
            df = site['pyd'].data.loc[[x in set_idx for x in range(site['pyd'].data.shape[0])],:]
            df.loc[:,'true'] = true
            df.loc[:,'score'] = pred
            df.loc[:,'pred'] = decision
            df = pl.from_pandas(df)

            total_positives = []
            for group in df.group_by("empi"):
                ecg = group[1].sort("ecg_date")
                for idx in range(ecg.height): 
                    start_date = ecg[idx, "ecg_date"]
                    max_date = start_date + timedelta(days=N_YEARS*365)
                    
                    pos_count = (
                        ecg.filter(
                            (pl.col("ecg_date") >= start_date) & 
                            (pl.col("ecg_date") <= max_date) &
                            (pl.col("pred") == True)  # Only count positive predictions
                        )
                        .with_row_count(name='pos_count', offset=1)
                        .select(["pos_count", "empi", "true", "pred"])
                    )
                    total_positives.append(pos_count)

            total_positives = pl.concat(total_positives)

            PPVs = []
            bins = []
            step=1
            for st in range(1, 4, step): 
                end = st + step 
                bins.append((st,end))
            bins.append((3,20))
            for st, end in bins:
                x = total_positives.filter(pl.col('pos_count').is_between(st, end, closed='left'))
                if not x.height: 
                    continue 
                ppv = x.filter(pl.col('true')==1).height/x.height
                if end-st <= 1 : 
                    label = f"{st}"
                else: 
                    label = f"{st}+"
                PPVs.append((label, ppv))
                print(label, ppv)

            ax = sns.barplot(x=[x[0] for x in PPVs], y=[x[1] for x in PPVs], label=sens, color=color, ax=ax)

        ax.set_xlabel("Cumulative positives", fontsize=12, color='gray')
        ax.set_ylabel("Positive Predictive Value", fontsize=12, color='gray')
        ax.set_title(site['name'], fontsize=14)
        ax.set_ylim(0, 1)
        plt.gca().tick_params(axis='y', colors='gray')
        plt.gca().spines['left'].set_color('gray')
        plt.gca().spines['left'].set_linewidth(1)
        plt.gca().tick_params(axis='x', colors='gray', which='both')
        plt.gca().spines['bottom'].set_color('gray')
        ax.grid(axis='y', linestyle='--')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # plt.ylim(0.5, 1)
        plt.legend(loc='upper left')
        plt.tight_layout()
        plt.show()




In [None]:
# total negatives in last N years

from datetime import timedelta

N_YEARS = 2

for cohort in COHORTS:
    print(cohort)
    for site in sites:
        print(site['name'])
        if f"idx_{cohort}" not in site: continue 
        idx = site[f"idx_{cohort}"]
        set_idx = set(list(idx))
        true = site['true'][idx]
        pred = site['pred'][idx]
        sns.set_theme(style="whitegrid")
        fig, ax = plt.subplots(1, 1, figsize=(6,4), sharex=False)

        shades = generate_shades(site['color'], n=3)
        for sens, color in zip([.9, .8, .7], shades): 
            threshold = train_data[cohort]['sens_to_thres'][sens]
            # decision = pred>(1-threshold)
            decision = pred>threshold
            df = site['pyd'].data.loc[[x in set_idx for x in range(site['pyd'].data.shape[0])],:]
            df.loc[:,'true'] = true
            df.loc[:,'score'] = pred
            df.loc[:,'pred'] = decision
            df = pl.from_pandas(df)

            total_negatives = []
            for group in df.group_by("empi"):
                ecg = group[1].sort("ecg_date")
                for idx in range(ecg.height): 
                    start_date = ecg[idx, "ecg_date"]
                    max_date = start_date + timedelta(days=N_YEARS * 365)
                    
                    neg_count = (
                        ecg.filter(
                            (pl.col("ecg_date") >= start_date) & 
                            (pl.col("ecg_date") <= max_date) &
                            (pl.col("pred") == False)  # Only count positive predictions
                        )
                        .with_row_count(name='neg_count', offset=1)
                        .select(["neg_count", "empi", "true", "pred"])
                    )
                    total_negatives.append(neg_count)

            total_negatives = pl.concat(total_negatives)

            NPVs = []
            bins = []
            step=1
            for st in range(1, 4, step): 
                end = st + step 
                bins.append((st,end))
            bins.append((3,20))
            for st, end in bins:
                x = total_negatives.filter(pl.col('neg_count').is_between(st, end, closed='left'))
                if not x.height: 
                    continue 
                npv = x.filter(pl.col('true')==0).height/x.height
                if end-st <= 1 : 
                    label = f"{st}"
                else: 
                    label = f"{st}+"
                NPVs.append((label, npv))
                print(label, npv)

            ax = sns.barplot(x=[x[0] for x in NPVs], y=[x[1] for x in NPVs], label=sens, color=color, ax=ax)

        ax.set_xlabel("Cumulative negatives", fontsize=12, color='gray')
        ax.set_ylabel("Negative Predictive Value", fontsize=12, color='gray')
        ax.set_title(site['name'], fontsize=14)
        ax.set_ylim(0, 1)
        plt.gca().tick_params(axis='y', colors='gray')
        plt.gca().spines['left'].set_color('gray')
        plt.gca().spines['left'].set_linewidth(1)
        plt.gca().tick_params(axis='x', colors='gray', which='both')
        plt.gca().spines['bottom'].set_color('gray')
        ax.grid(axis='y', linestyle='--')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # plt.ylim(0.5, 1)
        plt.legend(loc='lower right')
        plt.tight_layout()
        plt.show()




In [None]:
# 5 years 

# total positives in last N years 


from datetime import timedelta


for cohort in ['all']:
    print(cohort)
    for site in sites:
        print(site['name'])
        if f"idx_{cohort}" not in site: continue 
        idx = site[f"idx_{cohort}"]
        set_idx = set(list(idx))
        true = site['true'][idx]
        pred = site['pred'][idx]
        sns.set_theme(style="whitegrid")
        fig, ax = plt.subplots(1, 1, figsize=(6,4), sharex=False)

        shades = generate_shades(site['color'], n=3)
    
        for sens, color in zip([.7, .8, .9], reversed(shades)): 
            threshold = train_data[cohort]['sens_to_thres'][sens]
            # decision = pred>(1-threshold)
            decision = pred>threshold
            df = site['pyd'].data.loc[[x in set_idx for x in range(site['pyd'].data.shape[0])],:]
            df.loc[:,'true'] = true
            df.loc[:,'score'] = pred
            df.loc[:,'pred'] = decision
            df = pl.from_pandas(df)

            total_positives = []
            for group in df.group_by("empi"):
                ecg = group[1].sort("ecg_date")
                for idx in range(ecg.height): 
                    start_date = ecg[idx, "ecg_date"]
                    max_date = start_date + timedelta(days=5 * 365)
                    
                    pos_count = (
                        ecg.filter(
                            (pl.col("ecg_date") >= start_date) & 
                            (pl.col("ecg_date") <= max_date) &
                            (pl.col("pred") == True)  # Only count positive predictions
                        )
                        .with_row_count(name='pos_count', offset=1)
                        .select(["pos_count", "empi", "true", "pred"])
                    )
                    total_positives.append(pos_count)

            total_positives = pl.concat(total_positives)

            PPVs = []
            step = 5
            for st in range(0, 50, step): 
                end = st + step 
                if st==0: st=1
                x = total_positives.filter(pl.col('pos_count').is_between(st, end, closed='left'))
                ppv = x.filter(pl.col('true')==1).height/x.height
                PPVs.append((f"{st}-{end}", ppv))

            ax = sns.barplot(x=[x[0] for x in PPVs], y=[x[1] for x in PPVs], label=sens, color=color, ax=ax)

        ax.set_xlabel("Cumulative positives", fontsize=12, color='gray')
        ax.set_ylabel("Positive Predictive Value", fontsize=12, color='gray')
        ax.set_title(site['name'], fontsize=14)
        ax.set_ylim(0, 1)
        plt.gca().tick_params(axis='y', colors='gray')
        plt.gca().spines['left'].set_color('gray')
        plt.gca().spines['left'].set_linewidth(1)
        plt.gca().tick_params(axis='x', colors='gray', which='both')
        plt.gca().spines['bottom'].set_color('gray')
        ax.grid(axis='y', linestyle='--', alpha=0.6)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # plt.ylim(0.5, 1)
        plt.legend(loc='lower right')
        plt.tight_layout()
        plt.show()




In [None]:
# repeat positives last 1 years 

from datetime import timedelta


for cohort in ['worsen']:
    print(cohort)
    for site in sites:
        print(site['name'])
        if f"idx_{cohort}" not in site: continue 
        idx = site[f"idx_{cohort}"]
        set_idx = set(list(idx))
        true = site['true'][idx]
        pred = site['pred'][idx]
        sns.set_theme(style="whitegrid")
        fig, ax = plt.subplots(1, 1, figsize=(6,4), sharex=False)

        shades = generate_shades(site['color'], n=3)
    
        for sens, color in zip([.7, .8, .9], reversed(shades)): 
            threshold = train_data[cohort]['sens_to_thres'][sens]
            # decision = pred>(1-threshold)
            decision = pred>threshold
            df = site['pyd'].data.loc[[x in set_idx for x in range(site['pyd'].data.shape[0])],:]
            df.loc[:,'true'] = true
            df.loc[:,'score'] = pred
            df.loc[:,'pred'] = decision
            df = pl.from_pandas(df)

            repeat_positives = []
            for group in df.group_by("empi"):
                ecg = group[1].sort("ecg_date")
                for idx in range(ecg.height): 
                    start_date = ecg[idx, "ecg_date"]
                    max_date = start_date + timedelta(days=1 * 365)
                    chunk = ecg.filter(
                        (pl.col("ecg_date") >= start_date) & 
                        (pl.col("ecg_date") <= max_date)
                    ).with_columns(
                        (pl.col('pred').eq(True).cum_sum()).alias("keep")
                    ).filter(
                        pl.col("keep") > 0
                    ).drop(
                        "keep"
                    ).with_columns(
                        (pl.col('pred').eq(False).cum_sum()).alias("keep")
                    ).filter(
                        pl.col("keep") == 0
                    ).drop(
                        "keep"
                    ).with_row_count(
                        name='pos_count',offset=1
                    ).select(
                        ['pos_count','empi','true','pred']
                    )
                    repeat_positives.append(chunk)
            repeat_positives = pl.concat(repeat_positives)

            PPVs = []
            step = 1
            for st in range(0, 10, step): 
                end = st + step 
                if st==0: st=1
                x = total_positives.filter(pl.col('pos_count').is_between(st, end, closed='left'))
                if not x.height: 
                    continue 
                ppv = x.filter(pl.col('true')==1).height/x.height
                if end-st <= 1 : 
                    label = f"{st}"
                else: 
                    label = f"{st}-{end}"
                PPVs.append((label, ppv))

            ax = sns.barplot(x=[x[0] for x in PPVs], y=[x[1] for x in PPVs], label=sens, color=color, ax=ax)

        ax.set_xlabel("Successive positives", fontsize=12, color='gray')
        ax.set_ylabel("Positive Predictive Value", fontsize=12, color='gray')
        ax.set_title(site['name'], fontsize=14)
        ax.set_ylim(0, 1)
        plt.gca().tick_params(axis='y', colors='gray')
        plt.gca().spines['left'].set_color('gray')
        plt.gca().spines['left'].set_linewidth(1)
        plt.gca().tick_params(axis='x', colors='gray', which='both')
        plt.gca().spines['bottom'].set_color('gray')
        ax.grid(axis='y', linestyle='--', alpha=0.6)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # plt.ylim(0.5, 1)
        plt.legend(loc='lower right')
        plt.tight_layout()
        plt.show()




In [None]:
# repeat positives 
for cohort in ['worsen']:
    print(cohort)
    for site in sites:
        print(site['name'])
        if f"idx_{cohort}" not in site: continue 
        idx = site[f"idx_{cohort}"]
        set_idx = set(list(idx))
        true = site['true'][idx]
        pred = site['pred'][idx]
        sns.set_theme(style="whitegrid")
        fig, ax = plt.subplots(1, 1, figsize=(6,4), sharex=False)

        shades = generate_shades(site['color'], n=3)
    
        for sens, color in zip([.7, .8, .9], reversed(shades)): 
            threshold = train_data[cohort]['sens_to_thres'][sens]
            # decision = pred>(1-threshold)
            decision = pred>threshold
            df = site['pyd'].data.loc[[x in set_idx for x in range(site['pyd'].data.shape[0])],:]
            df.loc[:,'true'] = true
            df.loc[:,'score'] = pred
            df.loc[:,'pred'] = decision
            df = pl.from_pandas(df)

            repeat_positives = []
            for group in df.group_by("empi"): 
                x = group[1].sort(
                    'ecg_date'
                ).with_columns(
                    (pl.col('pred').eq(True).cum_sum()).alias("keep")
                ).filter(
                    pl.col("keep") > 0
                ).drop(
                    "keep"
                ).with_columns(
                    (pl.col('pred').eq(False).cum_sum()).alias("keep")
                ).filter(
                    pl.col("keep") == 0
                ).drop(
                    "keep"
                ).with_row_count(
                    name='pos_count',offset=1
                ).select(
                    ['pos_count','empi','true','pred']
                )
                repeat_positives.append(x)
            repeat_positives = pl.concat(repeat_positives)

            PPVs = []
            step = 5
            for st in range(0, 50, step): 
                end = st + step 
                if st==0: st=1
                x = repeat_positives.filter(pl.col('pos_count').is_between(st, end, closed='left'))
                ppv = x.filter(pl.col('true')==1).height/x.height
                PPVs.append((f"{st}-{end}", ppv))

            ax = sns.barplot(x=[x[0] for x in PPVs], y=[x[1] for x in PPVs], label=sens, color=color, ax=ax)

        ax.set_xlabel("Successive positives", fontsize=12, color='gray')
        ax.set_ylabel("Positive Predictive Value", fontsize=12, color='gray')
        ax.set_title(site['name'], fontsize=14)
        ax.set_ylim(0, 1)
        plt.gca().tick_params(axis='y', colors='gray')
        plt.gca().spines['left'].set_color('gray')
        plt.gca().spines['left'].set_linewidth(1)
        plt.gca().tick_params(axis='x', colors='gray', which='both')
        plt.gca().spines['bottom'].set_color('gray')
        ax.grid(axis='y', linestyle='--', alpha=0.6)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # plt.ylim(0.5, 1)
        plt.legend(loc='lower right')
        plt.tight_layout()
        plt.show()


In [None]:
# repeat negatives 
for cohort in ['all']:
    print(cohort)
    for site in sites:
        print(site['name'])
        if f"idx_{cohort}" not in site: continue 
        idx = site[f"idx_{cohort}"]
        set_idx = set(list(idx))
        true = site['true'][idx]
        pred = site['pred'][idx]
        sns.set_theme(style="whitegrid")
        fig, ax = plt.subplots(1, 1, figsize=(6,4), sharex=False)
        shades = generate_shades(site['color'], n=3)
        for sens, color in zip([.9, .8, .7], shades): 
            threshold = train_data[cohort]['sens_to_thres'][sens]
            # decision = pred>(1-threshold)
            decision = pred>threshold
            df = site['pyd'].data.loc[[x in set_idx for x in range(site['pyd'].data.shape[0])],:]
            df.loc[:,'true'] = true
            df.loc[:,'score'] = pred
            df.loc[:,'pred'] = decision
            df = pl.from_pandas(df)

            repeat_negatives = []
            for group in df.group_by("empi"): 
                x = group[1].sort(
                    'ecg_date'
                ).with_columns(
                    (pl.col('pred').eq(False).cum_sum()).alias("keep")
                ).filter(
                    pl.col("keep") > 0
                ).drop(
                    "keep"
                ).with_columns(
                    (pl.col('pred').eq(True).cum_sum()).alias("keep")
                ).filter(
                    pl.col("keep") == 0
                ).drop(
                    "keep"
                ).with_row_count(
                    name='neg_count',offset=1
                ).select(
                    ['neg_count','empi','true','pred']
                )
                repeat_negatives.append(x)
            repeat_negatives = pl.concat(repeat_negatives)

            NPVs = []
            step = 5
            for st in range(0, 50, step): 
                end = st + step 
                if st==0: st=1
                x = repeat_negatives.filter(pl.col('neg_count').is_between(st, end, closed='left'))
                npv = x.filter(pl.col('true')==0).height/x.height
                NPVs.append((f"{st}-{end}", npv))

            ax = sns.barplot(x=[x[0] for x in NPVs], y=[x[1] for x in NPVs], label=sens, color=color, ax=ax)
        
        ax.set_xlabel("Successive negatives", fontsize=12, color='gray')
        ax.set_ylabel("Negative Predictive Value", fontsize=12, color='gray')
        ax.set_title(site['name'], fontsize=14)
        ax.set_ylim(0, 1)
        plt.gca().tick_params(axis='y', colors='gray')
        plt.gca().spines['left'].set_color('gray')
        plt.gca().spines['left'].set_linewidth(1)
        plt.gca().tick_params(axis='x', colors='gray', which='both')
        plt.gca().spines['bottom'].set_color('gray')
        ax.grid(axis='y', linestyle='--', alpha=0.6)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        # plt.ylim(0.5, 1)
        plt.legend(loc='lower right')
        plt.tight_layout()
        plt.show()


In [None]:
# repeat false positives 
for cohort in ['all']:
    print(cohort)
    for site in sites:
        print(site['name'])
        if f"idx_{cohort}" not in site: continue 
        idx = site[f"idx_{cohort}"]
        set_idx = set(list(idx))
        true = site['true'][idx]
        pred = site['pred'][idx]
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(24, 6), sharex=False)
        for (sens, threshold) in train_data[cohort]['sens_to_thres'].items(): 
            if sens!=.7: continue
            # decision = pred>(1-threshold)
            decision = pred>threshold
            df = site['pyd'].data.loc[[x in set_idx for x in range(site['pyd'].data.shape[0])],:]
            df.loc[:,'true'] = true
            df.loc[:,'score'] = pred
            df.loc[:,'pred'] = decision
            print(get_sens(true, pred, threshold=threshold))
            df = pl.from_pandas(df).with_columns(
                pl.when(
                    (pl.col('pred')!=pl.col('true')) # incorrect predictions 
                    &
                    (pl.col('pred')==1) # predictions are that EF will fall
                ).then(True).otherwise(False).alias('false_positive'),
                pl.when(
                    (pl.col('pred')==pl.col('true')) # correct predictions 
                    &
                    (pl.col('pred')==0) # predictions are that EF stay > 40
                ).then(True).otherwise(False).alias('true_negative'),
            )

            repeat_false_positives = []
            for group in df.group_by("empi"): 
                df = group[1].sort(['empi','ecg_date'])
                fps = 0
                first_ecg_date = None 
                for row in df.to_dicts(): 
                    if row['false_positive']: 
                        if fps == 0: 
                            first_ecg_date = row['ecg_date']
                        fps += 1
                        repeat_false_positives.append({
                            'empi':row['empi'], 
                            'ecg_date':row['ecg_date'],
                            'first_ecg_date':first_ecg_date,
                            'fp_count':fps,
                        })
                    else: 
                        fps = 0

            if site['split'] == 'mimic': 
                lvef = pl.read_csv(
                    '/storage/shared/mimic/raw/lvef.csv'
                ).rename({
                    'subject_id':'empi',
                    'study_datetime':'lvef_date',
                    'result':'lvef'
                }).drop(
                    'measurement'
                ).filter(
                    pl.col('lvef').is_not_null()
                ).with_columns(
                    pl.col('lvef').cast(pl.Int64),
                    pl.col("lvef_date").str.strptime(pl.Datetime, "%Y-%m-%dT%H:%M:%S").dt.date()
                ).filter(
                    (pl.col('lvef')>0) & (pl.col('lvef')<100)
                ).sort('lvef_date')
            else: 
                lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet').sort('lvef_date') 

            N_values = []
            examples = []
            any_bad_percent = []
            days_to_bad = []
            for N in range(1,20): 
                N_fps = [x for x in repeat_false_positives if x['fp_count']==N]
                if not len(N_fps): continue 
                any_bad_tally = 0
                times_to_first_bad = []
                for x in N_fps: 
                    date = x['first_ecg_date']
                    lvef_after_warning = lvef.filter(pl.col('empi')==x['empi']).filter(pl.col('lvef_date')>=date)
                    any_bad = not lvef_after_warning.filter(pl.col('lvef')<=40).is_empty()
                    any_bad_tally += any_bad
                    time_to_first_bad = None 
                    if any_bad: 
                        time_to_first_bad = lvef_after_warning.filter(pl.col('lvef')<=40).sort('lvef_date').head(1).select('lvef_date').item() - date.date()
                        times_to_first_bad.append(time_to_first_bad)
                percent = round(100*any_bad_tally/len(N_fps))
                days = np.mean(times_to_first_bad).days if times_to_first_bad else 0
                N_values.append(N)
                any_bad_percent.append(percent)
                days_to_bad.append(days)
                examples.append(len(N_fps))
            
            x_max = max(N_values)+1
            ax2.plot(N_values, any_bad_percent, marker='o', linestyle='-', label=sens, zorder=2)
            # ax2.set_ylim(0, 105)
            # ax2.set_xlim(0, x_max)
            # ax2.set_xticks(range(0, x_max, 5))
            # ax2.set_yticks(range(0, 101, 20))
            ax2.set_title("How often is low LVEF observed in the future?", fontsize=14)
            ax2.set_xlabel("Subsequent false positive predictions", fontsize=12)
            ax2.set_ylabel("Percentage", fontsize=12)
            ax2.grid(axis='y', linestyle='--', alpha=0.6)
            ax2.spines['top'].set_visible(False)
            ax2.spines['right'].set_visible(False)

            ax3.plot(N_values, days_to_bad, marker='o', linestyle='-', label=sens, zorder=2)
            ax3.set_xlim(0, 20)
            ax3.set_ylim(0, 1000)
            # ax3.set_xticks(range(0, x_max, 5))
            ax3.set_title(f"{site['name']}", fontsize=14)
            ax3.set_xlabel("Subsequent false positive predictions", fontsize=12)
            ax3.set_ylabel("Days", fontsize=12)
            ax3.grid(axis='y', linestyle='--', alpha=0.6)
            ax3.spines['top'].set_visible(False)
            ax3.spines['right'].set_visible(False)

            ax1.plot(N_values, examples, marker='o', linestyle='-', label=sens, zorder=2)
            # ax1.set_ylim(0, max(examples) + 50)
            # ax1.set_xlim(0, x_max)
            # ax1.set_xticks(range(0, x_max, 5))
            # ax1.set_yticks(range(0, max(examples) + 50, 100))
            ax1.set_title("Number of examples for each N-value", fontsize=14)
            ax1.set_xlabel("Subsequent false positive predictions", fontsize=12)
            ax1.set_ylabel("Number of ECGs", fontsize=12)
            ax1.grid(axis='y', linestyle='--', alpha=0.6)
            ax1.spines['top'].set_visible(False)
            ax1.spines['right'].set_visible(False)

        plt.legend([])
        plt.tight_layout()
        plt.show()


In [None]:
# ORs false positives 
import scipy.stats as stats

cohort = 'all'

for site in sites: 
    print(site['name'])

    idx = site[f"idx_{cohort}"]
    set_idx = set(list(idx))
    true = site['true'][idx]
    pred = site['pred'][idx]
    for (sens, threshold) in train_data[cohort]['sens_to_thres'].items(): 
        if sens==.9: 
            break
    # decision = pred>(1-threshold)
    decision = pred>threshold
    df = site['pyd'].data.loc[[x in set_idx for x in range(site['pyd'].data.shape[0])],:]
    df.loc[:,'true'] = true
    df.loc[:,'score'] = pred
    df.loc[:,'pred'] = decision
    df = pl.from_pandas(df).with_columns(
        pl.when(
            (pl.col('pred')!=pl.col('true')) # incorrect predictions 
            &
            (pl.col('pred')==1) # predictions are that EF will fall
        ).then(True).otherwise(False).alias('false_positive'),
        pl.when(
            (pl.col('pred')==pl.col('true')) # correct predictions 
            &
            (pl.col('pred')==0) # predictions are that EF stay > 40
        ).then(True).otherwise(False).alias('true_negative'),
        pl.when(
            (pl.col('pred')==pl.col('true')) # correct predictions 
            &
            (pl.col('pred')==1) # predictions are that EF stay > 40
        ).then(True).otherwise(False).alias('true_positive'),
    )
    
    pt_w_tp = df.filter(pl.col('true_positive')).select('empi').unique()['empi'].to_list()
    df = df.filter(
        ~ pl.col('empi').is_in(pt_w_tp)
    ).filter(
        pl.col('days_since_diagnosis') < 30
    )

    repeat_false_positives = []
    no_false_positives = []
    for group in df.group_by("empi"): 
        x = group[1].sort(['empi','ecg_date'])
        fps = 0
        first_ecg_date = None 
        for row in x.to_dicts(): 
            if row['false_positive']: 
                if fps == 0: 
                    first_ecg_date = row['ecg_date']
                fps += 1
                repeat_false_positives.append({
                    'empi':row['empi'], 
                    'ecg_date':row['ecg_date'],
                    'first_ecg_date':first_ecg_date,
                    'fp_count':fps,
                })
            else: 
                fps = 0
        if fps == 0: # no false positives ever 
            no_false_positives.append(row['empi'])
            

    if site['split'] == 'mimic': 
        lvef = pl.read_csv(
            '/storage/shared/mimic/raw/lvef.csv'
        ).rename({
            'subject_id':'empi',
            'study_datetime':'lvef_date',
            'result':'lvef'
        }).drop(
            'measurement'
        ).filter(
            pl.col('lvef').is_not_null()
        ).with_columns(
            pl.col('lvef').cast(pl.Int64),
            pl.col("lvef_date").str.strptime(pl.Datetime, "%Y-%m-%dT%H:%M:%S").dt.date()
        ).filter(
            (pl.col('lvef')>0) & (pl.col('lvef')<100)
        ).sort('lvef_date')
    else: 
        lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet').sort('lvef_date') 

    def tmp(df, DAYS): 
        return df.select(
            ['empi','ecg_date']
        ).unique(
        ).join(
            lvef, 
            on='empi',
            how='left'
        ).filter(
            pl.col('lvef_date') >= pl.col('ecg_date')
        ).with_columns(
            (pl.col('lvef_date') - pl.col('ecg_date')).dt.total_days().alias('ef_diff')
        ).filter(
            pl.col('ef_diff') <= DAYS
        ).sort(
            ['empi', 'lvef']
        ).unique(
            subset=['empi'], keep='first'
        )

    DAYS_LIST  = [30, 90, 180, 270, 365, 540, 720]
    odds_ratios = []
    odds_errors = []
    p_values = []
    for DAYS in DAYS_LIST:

        no = df.filter(
            pl.col('empi').is_in(no_false_positives)
        ).pipe(tmp, DAYS=DAYS)

        rfp = pl.DataFrame(
            repeat_false_positives
        ).filter(
            pl.col('fp_count') > 1
        ).pipe(tmp, DAYS=DAYS)

        a = rfp.filter(pl.col('lvef')<=40).height
        b = rfp.filter(pl.col('lvef')>40).height
        c = no.filter(pl.col('lvef')<=40).height
        d = no.filter(pl.col('lvef')>40).height

        odds_ratio = (a * d) / (b * c)

        se_log_or = np.sqrt(1/a + 1/b + 1/c + 1/d)
        ci_lower = np.exp(np.log(odds_ratio) - 1.96 * se_log_or)
        ci_upper = np.exp(np.log(odds_ratio) + 1.96 * se_log_or)

        odds_ratio, p_value = stats.fisher_exact([[a, b], [c, d]])

        odds_ratios.append(odds_ratio)
        odds_errors.append([odds_ratio-ci_lower, ci_upper-odds_ratio])
        p_values.append(p_value)

        print(f"{DAYS} \t p {p_value:.3f} \t Odds {odds_ratio:.2f} \t CI {ci_lower:.2f}–{ci_upper:.2f}")


    fig = barplot([int(x/30) for x in DAYS_LIST], odds_ratios, yerr=np.array(odds_errors).T, decimal=1)
    ax = fig.axes[0]  
    for bar, p in zip(ax.patches, p_values):
        if p < 0.05: 
            ax.text(bar.get_x() + bar.get_width() / 2, 
                    ax.get_ylim()[1] * .99,
                    '*', 
                    ha='center', va='bottom', fontsize=12, fontweight='bold', color='black')

    plt.xlabel('Months')
    plt.ylabel('Odds Ratio')
    plt.show()



In [22]:
def load_lvef_data(site):
    """Load and preprocess LVEF data based on site source."""
    if site['split'] == 'mimic':
        lvef = pl.read_csv('/storage/shared/mimic/raw/lvef.csv').rename({
            'subject_id': 'empi',
            'study_datetime': 'lvef_date',
            'result': 'lvef'
        }).drop('measurement').filter(
            pl.col('lvef').is_not_null()
        ).with_columns(
            pl.col('lvef').cast(pl.Int64),
            pl.col("lvef_date").str.strptime(pl.Datetime, "%Y-%m-%dT%H:%M:%S").dt.date()
        ).filter((pl.col('lvef') > 0) & (pl.col('lvef') < 100))
    else:
        lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet')

    return lvef.sort('lvef_date')

def plot_trajectory(pt, site):
    """Plot LVEF trajectory and ECG likelihood with TP/FP/TN/FN labels."""
    threshold = train_data['all']['sens_to_thres'][0.7]
    lvef = load_lvef_data(site)
    idx = site['pyd'].data.query('empi == @pt').sort_values('ecg_date').index.values

    fig, ax1 = plt.subplots(figsize=(10, 6))

    # Configure LVEF Axis (Left Y-Axis)
    ax1.set_ylim(0, 100)
    ax1.set_yticks(range(0, 101, 20))
    ax1.tick_params(axis='y', colors='gray')
    ax1.spines['left'].set_position(('outward', 10))
    ax1.spines['left'].set_color('gray')
    ax1.spines['left'].set_linewidth(1)
    
    # Extract LVEF data for patient
    lvef_data = lvef.filter(pl.col('empi') == pt)
    lvef_dates = pd.to_datetime(lvef_data['lvef_date'].to_numpy().flatten())
    lvef_values = lvef_data['lvef'].to_numpy().flatten()
    lvef_colors = ['silver' if val > 40 else 'firebrick' for val in lvef_values]

    # Plot LVEF trajectory
    ax1.scatter(lvef_dates, lvef_values, c=lvef_colors, marker='o', zorder=2)
    for i in range(len(lvef_dates)-1): 
        ax1.plot(lvef_dates[i:i+2], lvef_values[i:i+2], c=lvef_colors[i+1], zorder=1)  

    xlim_dates = lvef_dates
    ax1.set_xlim(left=xlim_dates.min().replace(month=1, day=1),
                right=xlim_dates.max().replace(month=12, day=31, year=max(xlim_dates).year + 1))
    ax1.xaxis.set_major_locator(mdates.YearLocator(1))
    ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))

    ax1.tick_params(axis='x', colors='gray', which='both')
    ax1.spines['bottom'].set_color('gray')

    ax1.grid(True, axis='x', which='major', linestyle='--', alpha=0.7)
    ax1.grid(True, axis='x', which='minor', linestyle=':', alpha=0.5)

    # Shared x-axis tick styling
    plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45, ha='right', color='gray')
    plt.setp(ax1.xaxis.get_minorticklabels(), rotation=45, ha='right', color='gray', fontsize='small')

    # Axes styling
    ax1.spines['bottom'].set_position(('data', 0))  
    ax1.spines['bottom'].set_linewidth(1)

    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    
    # Configure ECG Likelihood Axis (Right Y-Axis)
    ax2 = ax1.twinx()
    ax2.set_ylim(0, 1.005)
    ax2.set_yticks(np.linspace(0, 1, 11))
    ax2.tick_params(axis='y', colors='gray')
    ax2.spines['right'].set_position(('outward', 10))
    ax2.spines['right'].set_color('gray')
    ax2.spines['right'].set_linewidth(1)

    # Extract ECG data
    dates_ecg = np.array([site['pyd'].data.loc[i, 'ecg_date'].date() for i in idx])
    ecg_preds = site['pred'][idx]
    ecg_true = site['true'][idx]

    # Define labels for TP, FP, TN, FN
    texts, colors, labels = [], [], []
    for pred, true in zip(ecg_preds, ecg_true):
        if pred > threshold and true == 1:
            texts.append(r'$\checkmark$') # TP
            colors.append(MY_NAVY)
            labels.append('True Positive')
        elif pred > threshold and true == 0:
            texts.append(r'$\mathsf{FP}$') # FP
            colors.append(MY_NAVY)
            labels.append('False Positive')
        elif pred <= threshold and true == 0:
            texts.append(r'$\times$' ) # TN
            colors.append('silver')
            labels.append('True Negative')
        else:
            texts.append(r'$\mathsf{FN}$') 
            colors.append('silver')
            labels.append('False Negative')

    # Plot ECG predictions with TP/FP/TN/FN labels
    for date, pred, text, color, label in zip(dates_ecg, ecg_preds, texts, colors, labels):
        ax2.text(date, pred, text, fontsize=10, ha='center', va='bottom', color=color, label=label)
    
    # Configure X-Axis
    ax1.set_xlim(left=lvef_dates.min().replace(month=1, day=1),
                 right=lvef_dates.max().replace(month=12, day=31, year=lvef_dates.max().year ))
    ax1.xaxis.set_major_locator(mdates.YearLocator(1))
    ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax1.tick_params(axis='x', colors='gray', rotation=45)

    # LVEF threshold line
    ax1.axhline(y=40, color='firebrick', linestyle='--')
    ax1.text(pd.Timestamp(lvef_dates.min().replace(month=3, day=1)), 40+1, "40% LVEF", color='firebrick', fontsize=10, va='bottom', ha='left')

    # ECG threshold line
    ax2.axhline(y=threshold,  color=MY_NAVY, linestyle='--')
    ax2.text(pd.Timestamp(lvef_dates.max().replace(month=9, day=30, year=lvef_dates.max().year -1  )), 
             threshold+0.01, "70% Sensitivity", color=MY_NAVY, fontsize=10, va='bottom', ha='left')

    ax2.spines['bottom'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax2.spines['left'].set_visible(False)

    # Align ticks on both axes
    ax1.tick_params(axis='y', direction='in', pad=5)
    ax2.tick_params(axis='y', direction='in', pad=5)

    # Hide unnecessary spines
    for ax in [ax1, ax2]:
        ax.spines['top'].set_visible(False)

    ax1.set_ylabel('LVEF', color='gray', labelpad=-5)
    ax2.set_ylabel('Probability of Worsening LVEF for ECG', color='gray', rotation=270, labelpad=5, va='bottom')


    plt.tight_layout()
    plt.show()

In [None]:
for pt in [100223049,100202452,101091462, 103200474, 105812520, 100358399, 100615898, 101137610, 103401384]: 
    plot_trajectory(pt, sites[0])
    break

In [None]:
# single lead model comparisons
auc_mean = pl.read_csv('/storage2/payal/Dropbox (Partners HealthCare)/private/SILVER/src/auc.csv')
auc_upper = pl.read_csv('/storage2/payal/Dropbox (Partners HealthCare)/private/SILVER/src/auc_upper.csv')
auc_lower = pl.read_csv('/storage2/payal/Dropbox (Partners HealthCare)/private/SILVER/src/auc_lower.csv')

cols = [x for x in auc_mean.columns if 'AUROC' in x]
cols = [x for x in cols if not x.endswith('MIN') and not x.endswith('MAX')]
auc_mean = auc_mean.select(cols)

cols = [x for x in auc_upper.columns if 'AUROC' in x]
cols = [x for x in cols if not x.endswith('MIN') and not x.endswith('MAX')]
auc_upper = auc_upper.select(cols)

cols = [x for x in auc_lower.columns if 'AUROC' in x]
cols = [x for x in cols if not x.endswith('MIN') and not x.endswith('MAX')]
auc_lower = auc_lower.select(cols)

df = pl.concat([auc_mean,auc_lower,auc_upper], how='horizontal')
df = df.with_columns([pl.when(pl.col(col) == "").then(None).otherwise(pl.col(col)).alias(col) for col in df.columns])
df = df.with_columns([pl.col(col).cast(pl.Float64).alias(col) for col in df.columns])
df = df.rename({col: '_'.join([col.split('_')[0],col.split('_')[-1]]) for col in df.columns})
df = df.select([pl.col(col).drop_nulls().first().alias(col) for col in df.columns])
df = df.to_pandas()

lead_names = [col.replace("_mean", "") for col in df.columns if "_mean" in col]
means = df[[col for col in df.columns if "_mean" in col]].values.flatten()
lower = df[[col for col in df.columns if "_lower" in col]].values.flatten()
upper = df[[col for col in df.columns if "_upper" in col]].values.flatten()
errors = np.abs(np.array([*zip((lower-means),(upper-means))]))

plt.figure(figsize=(6,4))
plt.errorbar(
    lead_names,
    means,
    yerr=errors.T,
    label=f"Sensitivity {sens}%",
    capsize=3,
    marker="o",
    markersize=3,
    linestyle="",
    linewidth=1,
    color=MY_NAVY,
)
plt.gca().set_xlabel("Single Lead", fontsize=12, color='gray')
plt.gca().set_ylabel("AUROC", fontsize=12, color='gray')
plt.gca().tick_params(axis='y', colors='gray')
plt.gca().spines['left'].set_color('gray')
plt.gca().spines['left'].set_linewidth(1)
plt.gca().tick_params(axis='x', colors='gray', which='both')
plt.gca().spines['bottom'].set_color('gray')
plt.xticks(rotation=45)
plt.ylim(.75,1)
plt.gca().axhline(y=.926,  color=MY_NAVY, linestyle='--')
plt.grid(True)
sns.despine()
plt.tight_layout()
plt.show()