In [1]:
# import 
import numpy as np
from scipy.stats import chi2
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import polars as pl
import pandas as pd
import seaborn as sns
import torch
import lightning as L
from sklearn.calibration import calibration_curve
from sklearn.metrics import roc_auc_score, confusion_matrix, roc_curve
from omegaconf import OmegaConf, DictConfig
import hydra
import wandb
from dataset import SupervisedDataset
from lightning_modules import SupervisedTask
from models.ecg_models import *
from run import interpolate
pl.Config.set_tbl_rows(50)

# Shared plot styling used by the figure cells below.
MY_NAVY = '#001F54'  # main line / marker colour
COLOR_BELOW = 'firebrick'
COLOR_ABOVE = 'gold'
# Line styles that distinguish the two subgroups on a shared axis.
LINE_BELOW = '--'
LINE_ABOVE = '-.'
# Legend labels for the whole cohort and the two subgroups.
LABEL_ALL='Any LVEF'
LABEL_ABOVE='Last LVEF > 40%'
LABEL_BELOW='Last LVEF < 40%'

  torch.utils._pytree._register_pytree_node(


In [None]:
# demographics 
# Load the cohort table and keep every row whose split is not 'external'.
df = pl.read_parquet('/storage2/payal/Dropbox (Partners HealthCare)/private/SILVER/data/data.parquet')
df = df.filter(pl.col('split')!='external')

# Percentage of patients (unique 'empi') whose summed value of each flag
# column is greater than zero.
n_patients = df.unique('empi').height  # loop-invariant denominator, computed once
for c in [
 'diabetes_mellitus',
 'hypertension',
 'atherosclerosis',
 'chronic_obstructive_pulmonary_disease',
 'atrial_fibrillation',
 'slgt2',
 'angio',
 'betablocker',
 'mra',
 'diuretic',
]:
    # group_by('empi') already yields one row per patient, so no further
    # de-duplication is needed before counting.
    n_flagged = (
        df.select(['empi', c])
        .group_by('empi')
        .sum()
        .filter(pl.col(c) > 0)
        .height
    )
    print(c, round(100 * n_flagged / n_patients, 1), '%')


# Attach demographic columns to the cohort rows (inner join on 'empi') and
# derive 'age' as (ecg_date - date_of_birth) in days divided by 365.25.
dem = pl.read_parquet(
    '/storage2/payal/dropbox/private/data/processed/demographic.parquet'
).join(
    df, on='empi', how='inner'
).with_columns(
    ((pl.col('ecg_date')-pl.col('date_of_birth')).cast(pl.Duration).dt.total_days()/365.25).alias('age')
)
# 
# dem.select(['empi','age']).mean(), dem.select(['empi','age']).std()
# Number of unique patients per value of 'sex' (displayed as the cell output).
dem.unique('empi').group_by('sex').len()

In [2]:
# model setup
device = 'cuda:1'
# Rebuild the training configuration from the stored W&B run config.
cfg = OmegaConf.create(wandb.Api().run("payal-collabs/SILVER/f7apd99r").config)
L.seed_everything(cfg.utils.seed)
# The train split is instantiated only so `interpolate` can complete the config; it is deleted right after.
train_pyd = hydra.utils.instantiate(cfg.dataset, split='train')
cfg = interpolate(cfg, train_pyd)
del train_pyd
# The last character of `device` is used as the GPU index for the trainer.
trainer = L.Trainer(devices=[int(device[-1])]) 
LM = SupervisedTask.load_from_checkpoint(cfg.best_model_path, map_location=torch.device(device))
model = LM.model
model.to(device)
model.eval()
# Overrides applied for evaluation: batch size and the label column to evaluate against.
cfg.optimizer.batch_size = 128
cfg.dataset.config.label = 'future_1_365_any_below_40'

Seed set to 140799
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [3]:
# Evaluation split; when it is 'mimic' the dataset paths are redirected to the shared MIMIC directories.
# NOTE(review): the inline comment lists 'test, external' but the branch below checks for 'mimic' — confirm which split names are valid.
split = 'test' # test, external
if split == 'mimic': 
    cfg.dataset.config.datadir = '/storage/shared/mimic/'
    cfg.dataset.config.ecg.storedir = '/storage/shared/mimic/raw/ecg/'

In [4]:
# predictions 
pyd = hydra.utils.instantiate(cfg.dataset, split=split)
# Use a clean 0..N-1 index so that index values returned by `pyd.data.query(...)`
# can be used directly as positions into `pred` / `true`.
pyd.data = pyd.data.reset_index(drop=True)
assert len(pyd)
loader = torch.utils.data.DataLoader(
    # Predict on the whole dataset: `pred` / `true` are indexed below (and in later
    # cells) with index values taken from the full `pyd.data` table, so a truncated
    # dataset (previously `pyd[:5]`, apparently a debugging leftover) cannot work.
    dataset = pyd,
    batch_size = cfg.optimizer.batch_size,
    num_workers = 0, 
    collate_fn = pyd.collate,
    shuffle=False,  # keep predictions aligned with the row order of pyd.data
    pin_memory=True
)
out = trainer.predict(LM, loader)
# Each element of `out` is indexed as (logits, labels); logits are mapped to probabilities.
pred = np.array(torch.sigmoid(torch.cat([x[0] for x in out])).tolist())
true = np.array(torch.cat([x[1] for x in out]).tolist())
assert len(pred) == len(pyd.data), 'predictions are not aligned with pyd.data'
# Row positions of the subgroups used by the analysis cells.
idx_below = pyd.data.query('tag_hfref').index.values
idx_above = pyd.data.query('~tag_hfref').index.values
# Non-HFrEF rows with none of the listed comorbidities.
idx_no_com = (
    pyd.data
    .query('~tag_hfref')
    .query('hypertension==0')
    .query('diabetes_mellitus==0')
    .query('atherosclerosis==0')
    .query('chronic_obstructive_pulmonary_disease==0')
    .query('atrial_fibrillation==0')
    .index.values
)

TypeError: expected str, bytes or os.PathLike object, not dict

In [35]:
# Row positions of two further subgroups of the '~tag_hfref' rows:
# no listed medication, and no listed medication plus no listed comorbidity.
def _query_chain(frame, clauses):
    """Apply each pandas query clause in turn, narrowing the frame step by step."""
    for clause in clauses:
        frame = frame.query(clause)
    return frame

_no_med_clauses = ['~tag_hfref', 'angio==0', 'betablocker==0', 'mra==0', 'diuretic==0']
_no_com_clauses = [
    'hypertension==0',
    'diabetes_mellitus==0',
    'atherosclerosis==0',
    'chronic_obstructive_pulmonary_disease==0',
    'atrial_fibrillation==0',
]
idx_no_med = _query_chain(pyd.data, _no_med_clauses).index.values
idx_healthy = _query_chain(pyd.data, _no_med_clauses + _no_com_clauses).index.values

In [63]:
# shap test 
# Re-instantiate the dataset and build a loader with a batch size of 2 for the SHAP experiment.
pyd = hydra.utils.instantiate(cfg.dataset, split=split)
pyd.data = pyd.data.reset_index(drop=1)
loader = torch.utils.data.DataLoader(
    dataset = pyd,
    batch_size = 2,
    num_workers = 0, 
    collate_fn = pyd.collate,
    shuffle=False,
    pin_memory=True,
)
def move_to_device(batch, device):
    """Return `batch` with its tensors moved to `device`.

    A tensor is moved directly. For a dict, list or tuple, each top-level
    tensor element is moved and every other element is passed through
    unchanged (nested containers are not descended into). Lists and tuples
    are rebuilt with their own type; dicts come back as plain dicts.

    Raises:
        TypeError: if `batch` is not a tensor, dict, list or tuple.
    """
    def _relocate(item):
        # Only tensors have a device; anything else is returned untouched.
        if isinstance(item, torch.Tensor):
            return item.to(device)
        return item

    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: _relocate(value) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(_relocate(value) for value in batch)
    raise TypeError(f"Unsupported batch type: {type(batch)}")

# Take only the first batch from the loader, moved to `device`, then stop iterating.
for batch in loader: 
    batch = move_to_device(batch, device)
    break

import shap
class ShapModule(torch.nn.Module): 
    """Wrapper module that exposes the wrapped model's 'loss' output as a (1, 1) tensor."""

    def __init__(self, model, device): 
        super().__init__()
        self.model = model 
        self.device = device 

    def forward(self, x): 
        """Run the wrapped model on `x` and return its 'loss' entry with two leading singleton dims added."""
        loss = self.model.forward(x)['loss']
        return loss.unsqueeze(0).unsqueeze(1)
# Wrap the model and run DeepExplainer with the single batch as both background data and input.
# NOTE(review): DeepExplainer is handed a list holding the batch object — confirm it accepts this input structure.
shap_model = ShapModule(model, device)
explainer = shap.DeepExplainer(model=shap_model, data=[batch])
explainer.shap_values([batch])

In [None]:
# saliency 
def move_to_device(batch, device):
    """Return `batch` with its tensors moved to `device`.

    A tensor is moved directly. For a dict, list or tuple, each top-level
    tensor element is moved and every other element is passed through
    unchanged (nested containers are not descended into). Lists and tuples
    are rebuilt with their own type; dicts come back as plain dicts.

    Raises:
        TypeError: if `batch` is not a tensor, dict, list or tuple.
    """
    def _relocate(item):
        # Only tensors have a device; anything else is returned untouched.
        if isinstance(item, torch.Tensor):
            return item.to(device)
        return item

    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: _relocate(value) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(_relocate(value) for value in batch)
    raise TypeError(f"Unsupported batch type: {type(batch)}")

def plot_saliency(batch, idx=0):
    """Plot a gradient saliency map over every ECG lead of one sample.

    The batch is sliced to `idx:idx+2` (at most two consecutive samples), the
    model's 'loss' output for that slice is back-propagated to the ECG input,
    and the absolute input gradient of the first sample of the slice is
    min-max normalised and overlaid on that sample's trace, one subplot per lead.

    Parameters:
    batch (dict): mapping whose values can be sliced along the first axis; must
        contain an 'ecg' tensor. The first selected ECG is indexed as
        (lead, time step).
    idx (int): position of the sample to plot within `batch`.
    """
    batch = {k: v[idx:idx+2] for (k, v) in batch.items()}
    batch['ecg'].requires_grad_()
    output = model(batch)
    model.zero_grad()
    output['loss'].backward()

    ecg = batch['ecg'][0].detach().cpu().numpy()
    saliency = batch['ecg'].grad.abs()[0].detach().cpu().numpy()

    # Min-max normalise; a flat map would divide by zero, so it is shown as all zeros.
    span = saliency.max() - saliency.min()
    if span > 0:
        saliency = (saliency - saliency.min()) / span
    else:
        saliency = np.zeros_like(saliency)

    # Take the number of leads from the data instead of assuming 12.
    num_leads, time_steps = ecg.shape
    time = np.arange(time_steps)
    # squeeze=False keeps `axes` two-dimensional even for a single lead.
    fig, axes = plt.subplots(num_leads, 1, figsize=(15, 20), sharex=True, squeeze=False)
    axes = axes[:, 0]
    for lead_idx, ax in enumerate(axes):
        # One line per lead replaces the earlier one-plot-call-per-sample loop.
        ax.plot(time, ecg[lead_idx], color='gray', linewidth=1)
        # Stack markers for increasing saliency cut-offs so that higher-saliency
        # points end up more opaque.
        for val in [0, 0.25, 0.5, 0.75]:
            c_idx = np.flatnonzero(saliency[lead_idx] > val)
            ax.scatter(time[c_idx], ecg[lead_idx][c_idx], c=saliency[lead_idx][c_idx], cmap='Blues', s=100, alpha=0.2*val)
            ax.scatter(time[c_idx], ecg[lead_idx][c_idx], c=saliency[lead_idx][c_idx], cmap='Blues', s=50, alpha=0.7*val)
        ax.set_title(f"Lead {lead_idx + 1}")
        ax.set_ylabel("Amplitude")
    axes[-1].set_xlabel("Time Steps")
    plt.tight_layout()
    plt.show()

# Plot a saliency map for one randomly chosen positive example of every batch.
for batch in loader: 
    pos_mask = batch['label']==1
    n_pos = int(pos_mask.sum().item())
    if n_pos == 0:
        # np.random.choice(0) raises, so batches with no positive label are skipped.
        continue
    batch = {k:v[pos_mask] for (k,v) in batch.items()}
    batch = move_to_device(batch, device)
    plot_saliency(batch, np.random.choice(n_pos))


In [None]:
# calibration 

# Binning strategy passed to calibration_curve.
strategy = 'uniform' # 'quantile'

fig, ax = plt.subplots(figsize=(10, 6))
# Diagonal reference line from (0, 0) to (1, 1).
plt.plot([0, 1], [0, 1], linestyle='-', color='silver', alpha=0.5)

# One calibration curve per (index set, line style, legend label).
for idx, lst, lbl in [(None, '-', LABEL_ALL), (idx_above, LINE_ABOVE, LABEL_ABOVE), (idx_below, LINE_BELOW, LABEL_BELOW)]:
    
    # Indexing with None adds a leading axis (i.e. keeps every row); reshape(-1) flattens in both cases.
    true_here = true[idx].reshape(-1)
    pred_here = pred[idx].reshape(-1)

    # Bootstrap the fraction of positives per bin to obtain a percentile interval.
    # NOTE(review): calibration_curve may return fewer than n_bins points when a bin is empty,
    # in which case the replicates would not stack into a rectangular array — confirm.
    ci_percentile = 95
    bootstrap_results = []
    n_samples = len(true_here)
    for _ in range(1000):
        indices = np.random.choice(np.arange(n_samples), size=n_samples, replace=True)
        true_sample = true_here[indices]
        pred_sample = pred_here[indices]
        prob_true, prob_pred = calibration_curve(true_sample, pred_sample, strategy=strategy, n_bins=10)
        bootstrap_results.append(prob_true)
    lower_bound = np.percentile(bootstrap_results, (100 - ci_percentile) / 2, axis=0)
    upper_bound = np.percentile(bootstrap_results, 100 - (100 - ci_percentile) / 2, axis=0)

    # Point estimate computed on the un-resampled data.
    prob_true, prob_pred = calibration_curve(true_here, pred_here, strategy=strategy, n_bins=10)

    # plt.fill_between(prob_pred, lower_bound, upper_bound, color=MY_NAVY, alpha=0.4, linewidth=0.5, interpolate=False)
    # plt.plot(prob_pred, prob_true, marker='o', color=MY_NAVY, linestyle=lst, label='Calibration curve')
    # yerr rows are the absolute distances from the point estimate to the lower and upper bounds.
    plt.errorbar(
        prob_pred,
        prob_true,
        yerr=np.abs(np.vstack([lower_bound-prob_true, upper_bound-prob_true])),
        capsize=3,
        marker='o',
        markersize=3,
        linestyle=lst,
        linewidth=1,
        color=MY_NAVY,
        label=lbl,
    )
    
plt.xlabel('Mean predicted probability', color='gray')
plt.ylabel('Fraction of positives', color='gray')
sns.despine()
plt.tight_layout()
plt.xlim(0,1)
plt.ylim(0,1)
# Keep only the first artist of each errorbar handle for the legend entry.
handles, labels = ax.get_legend_handles_labels()
handles = [h[0] for h in handles]
ax.legend(handles, labels, markerscale=0.001)
ax.tick_params(axis='y', colors='gray')
ax.spines['left'].set_color('gray')
ax.spines['left'].set_linewidth(1)
ax.tick_params(axis='x', colors='gray', which='both')
ax.spines['bottom'].set_color('gray')
plt.show()


In [None]:
# hosmer lemeshow test
def hosmer_lemeshow_test(y_true, y_pred_probs, g=10):
    """
    Perform the Hosmer–Lemeshow test for goodness-of-fit.

    Observations are split into (up to) `g` quantile bins of predicted
    probability. Tied quantile edges are merged, so fewer than `g` groups may
    result; the degrees of freedom are derived from the number of groups that
    were actually formed, not from the requested `g`.

    Parameters:
    y_true (array-like): True binary outcomes (0 or 1).
    y_pred_probs (array-like): Predicted probabilities from the model.
    g (int): Number of groups to divide the data into (default is 10).

    Returns:
    hl_stat (float): Hosmer–Lemeshow test statistic.
    p_value (float): Corresponding p-value; NaN when fewer than three groups
        could be formed (non-positive degrees of freedom).
    """
    # Flatten so that column vectors of shape (n, 1) are accepted as well.
    data = pd.DataFrame({
        'y_true': np.asarray(y_true).reshape(-1),
        'y_pred_probs': np.asarray(y_pred_probs).reshape(-1),
    })

    # Quantile-based bins; duplicated edges (heavily tied predictions) are merged.
    data['bin'] = pd.qcut(data['y_pred_probs'], q=g, duplicates='drop')

    # observed=True keeps only bins that contain data, so an empty category
    # cannot contribute a 0/0 term.
    grouped = data.groupby('bin', observed=True)

    # Observed and expected number of events per group
    obs_freq = grouped['y_true'].sum()
    exp_freq = grouped['y_pred_probs'].sum()
    n = grouped.size()
    p = exp_freq / n

    # Avoid division by zero when a group's mean prediction is exactly 0 or 1
    temp = p * (1 - p)
    temp = temp.mask(temp == 0, 1e-10)

    # Hosmer–Lemeshow statistic
    hl_stat = float((((obs_freq - exp_freq) ** 2) / (n * temp)).sum())

    # Degrees of freedom from the number of groups actually formed
    dof = len(n) - 2

    # Survival function is numerically more accurate than 1 - cdf in the tail
    p_value = float(chi2.sf(hl_stat, dof)) if dof > 0 else float('nan')

    return hl_stat, p_value

# Run the test on all predictions and display the p-value as the cell output.
hl_stat, p_value = hosmer_lemeshow_test(true, pred)
p_value

In [None]:
# AUROC for: all rows, idx_above, idx_no_com, idx_no_med, idx_healthy (in that order).
roc_auc_score(true, pred), roc_auc_score(true[idx_above], pred[idx_above]), roc_auc_score(true[idx_no_com], pred[idx_no_com]), roc_auc_score(true[idx_no_med], pred[idx_no_med]), roc_auc_score(true[idx_healthy], pred[idx_healthy]),

In [None]:
# auroc 

# ROC curve over all rows; `spec` and `thresholds` are reused by the threshold-dictionary cell.
fpr, tpr, thresholds = roc_curve(true, pred)
sens = tpr
spec = 1-fpr

# ROC curves restricted to the two subgroups.
fpr_below, tpr_below, thresholds_below = roc_curve(true[idx_below], pred[idx_below])
sens_below = tpr_below
fpr_above, tpr_above, thresholds_above = roc_curve(true[idx_above], pred[idx_above])
sens_above = tpr_above

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color=MY_NAVY, label=LABEL_ALL)
plt.plot(fpr_above, tpr_above, color=MY_NAVY, linestyle=LINE_ABOVE, label=LABEL_ABOVE)
plt.plot(fpr_below, tpr_below, color=MY_NAVY, linestyle=LINE_BELOW, label=LABEL_BELOW)
# Diagonal reference line.
plt.plot([0, 1], [0, 1], linestyle="--", color="silver")
plt.xlabel("False Positive Rate", color="gray")
plt.ylabel("True Positive Rate", color="gray")
plt.xlim(0, 1)
plt.ylim(0, 1.0005)
plt.gca().tick_params(axis='y', colors='gray')
plt.gca().spines['left'].set_color('gray')
plt.gca().spines['left'].set_linewidth(1)
plt.gca().tick_params(axis='x', colors='gray', which='both')
plt.gca().spines['bottom'].set_color('gray')
plt.legend()
sns.despine()
plt.grid(visible=False)
plt.tight_layout()
plt.show()


In [None]:
# dictionary: target specificity -> ROC thresholds whose (rounded) specificity equals that target
dct = {}
for i in np.concatenate([np.arange(0.1, 1, 0.1), np.arange(0.8, 1.00001, 0.01)]):
    i = round(i, 2)
    # Try the finest rounding first and relax it until at least one ROC point matches.
    for j in range(6,0,-1):
        idx = np.where(np.round(spec, j) == i)[0]
        # Test for a non-empty match; `idx.any()` would be False when the only
        # matching position is index 0.
        if idx.size:
            break
    # Left empty when no ROC point matches even at one decimal.
    dct[i] = thresholds[idx]
dct

In [None]:
# compute specificity, ppv, npv with bootstrap 

def compute_metrics_bootstrap(true, pred_probs, prevalence, num_bootstrap=1000, confidence=0.95):
    """
    Bootstrap specificity, PPV and NPV at nine sensitivity operating points.

    For each target sensitivity 0.1, 0.2, ..., 0.9 a decision threshold is
    taken from the ROC curve of the full data. At every threshold the data are
    resampled with replacement `num_bootstrap` times; PPV and NPV are computed
    from the resampled sensitivity / specificity and the supplied prevalence.

    Parameters:
    true (np.ndarray): Binary labels (0 or 1).
    pred_probs (np.ndarray): Predicted scores aligned with `true`.
    prevalence (float): Assumed prevalence in percent (e.g. 10 for 10%).
    num_bootstrap (int): Number of bootstrap replicates per threshold.
    confidence (float): Central coverage of the percentile interval.

    Returns:
    tuple of np.ndarray: (sensitivity, specificity, ppv, npv, specificity_err,
    ppv_err, npv_err). The first four hold the bootstrap median per threshold;
    each *_err array has one (lower, upper) percentile pair per threshold.
    """
    _, sens_here, thresholds_here = roc_curve(true, pred_probs)
    select_thresholds = []
    for target in np.arange(0.1, 1, 0.1): # sensitivity
        target = round(target, 2)
        idx = np.array([], dtype=int)
        # Try the finest rounding first and relax it until some ROC point matches.
        for decimals in range(6, 0, -1):
            idx = np.where(np.round(sens_here, decimals) == target)[0]
            # Non-empty test; `idx.any()` would be False for a lone match at index 0.
            if idx.size:
                break
        if not idx.size:
            # No ROC point rounds to the target: use the closest one instead of failing.
            idx = np.array([np.argmin(np.abs(sens_here - target))])
        select_thresholds.append(thresholds_here[idx][0])

    prevalence = prevalence / 100
    lower_pct = 100 * (1 - confidence) / 2
    upper_pct = 100 * (confidence + (1 - confidence) / 2)

    sensitivity_list = []
    specificity_list = []
    ppv_list = []
    npv_list = []

    specificity_err = []
    ppv_err = []
    npv_err = []

    for threshold in select_thresholds:
        samples_sensitivity = []
        samples_specificity = []
        samples_ppv = []
        samples_npv = []

        for _ in range(num_bootstrap):
            idx = np.random.choice(len(true), len(true), replace=True)
            true_sample = true[idx]
            pred_probs_sample = pred_probs[idx]

            pred_labels = (pred_probs_sample >= threshold).astype(int)
            tn = np.sum((pred_labels == 0) & (true_sample == 0))
            fp = np.sum((pred_labels == 1) & (true_sample == 0))
            fn = np.sum((pred_labels == 0) & (true_sample == 1))
            tp = np.sum((pred_labels == 1) & (true_sample == 1))

            sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
            specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
            # Bayes' rule with the supplied prevalence, not the sample prevalence.
            ppv_den = sensitivity * prevalence + (1 - specificity) * (1 - prevalence)
            npv_den = specificity * (1 - prevalence) + (1 - sensitivity) * prevalence
            ppv = (sensitivity * prevalence) / ppv_den if ppv_den > 0 else 0
            npv = (specificity * (1 - prevalence)) / npv_den if npv_den > 0 else 0

            samples_sensitivity.append(sensitivity)
            samples_specificity.append(specificity)
            samples_ppv.append(ppv)
            samples_npv.append(npv)

        # Store the bootstrap median of every metric. Sensitivity used to be taken
        # from whichever replicate happened to run last.
        sensitivity_list.append(np.percentile(samples_sensitivity, 50))
        specificity_list.append(np.percentile(samples_specificity, 50))
        ppv_list.append(np.percentile(samples_ppv, 50))
        npv_list.append(np.percentile(samples_npv, 50))

        # Store percentile confidence intervals
        specificity_err.append((np.percentile(samples_specificity, lower_pct), np.percentile(samples_specificity, upper_pct)))
        ppv_err.append((np.percentile(samples_ppv, lower_pct), np.percentile(samples_ppv, upper_pct)))
        npv_err.append((np.percentile(samples_npv, lower_pct), np.percentile(samples_npv, upper_pct)))

    return (
        np.array(sensitivity_list),
        np.array(specificity_list),
        np.array(ppv_list),
        np.array(npv_list),
        np.array(specificity_err),
        np.array(ppv_err),
        np.array(npv_err),
    )

# Assumed prevalences (in percent) at which PPV / NPV are evaluated.
prevalences = [10,15,20] # [5, 10, 20, 30, 40, 50]
metrics = {}
metrics_below = {}
metrics_above = {}

# One bootstrap run per prevalence for the whole cohort and for each subgroup.
# NOTE(review): the idx_above call passes inverted labels and scores (1-true, 1-pred) while the
# idx_below call passes them unchanged — confirm this is the intended polarity for each subgroup.
for p in prevalences: 
    metrics[p] = compute_metrics_bootstrap(true, pred, p)
    metrics_below[p] = compute_metrics_bootstrap(true[idx_below].reshape(-1), pred[idx_below].reshape(-1), p)
    metrics_above[p] = compute_metrics_bootstrap(1-true[idx_above].reshape(-1), 1-pred[idx_above].reshape(-1), p)
    print(p)
    # sensi, speci, ppv, npv, speci_err, ppv_err, npv_err
    # sensi, speci, ppv, npv, speci_err, ppv_err, npv_err


In [None]:
# specificity

# Specificity does not use the prevalence, so the first prevalence key is taken.
k = list(metrics.keys())[0]

fig, ax = plt.subplots(figsize=(10, 6))

# One errorbar curve per cohort: (bootstrap results, line style, legend label).
_curves = (
    (metrics, '-', 'Any LVEF'),
    (metrics_above, LINE_ABOVE, 'Last LVEF > 40%'),
    (metrics_below, LINE_BELOW, 'Last LVEF < 40%'),
)
for _results, _linestyle, _label in _curves:
    sensi, speci, ppv, npv, speci_err, ppv_err, npv_err = _results[k]
    # Rows of yerr: absolute distance from the median to the lower / upper bound.
    _yerr = np.abs(np.array(speci_err).T - speci)
    plt.errorbar(
        sensi,
        speci,
        yerr=_yerr,
        capsize=3,
        marker='o',
        markersize=3,
        linestyle=_linestyle,
        linewidth=1,
        color=MY_NAVY,
        label=_label,
    )

ax.set_xlabel('Sensitivity', color='gray')
ax.set_ylabel('Specificity', color='gray')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1.005)
for _axis_name in ('y', 'x'):
    if _axis_name == 'y':
        ax.tick_params(axis='y', colors='gray')
        ax.spines['left'].set_color('gray')
        ax.spines['left'].set_linewidth(1)
    else:
        ax.tick_params(axis='x', colors='gray', which='both')
        ax.spines['bottom'].set_color('gray')

# Keep only the first artist of each errorbar handle for the legend entry.
handles, labels = ax.get_legend_handles_labels()
handles = [h[0] for h in handles]
ax.legend(handles, labels, markerscale=0.001)
sns.despine()
plt.tight_layout()
plt.show()

In [None]:
# ppv, npv

# For each cohort draw a PPV figure, then an NPV figure, each with one curve per prevalence.
for mtr, lst, lbl in [(metrics, '-', LABEL_ALL), (metrics_above, LINE_ABOVE, LABEL_ABOVE), (metrics_below, LINE_BELOW, LABEL_BELOW)]:

    # (y-axis label, x position of the prevalence tag, index of the point the tag is anchored to)
    for _ylabel, _tag_x, _tag_i in (('PPV', 0.75, 7), ('NPV', 0.25, 2)):
        _is_ppv = _ylabel == 'PPV'

        plt.figure(figsize=(10, 6))
        for prevalence in prevalences:
            sensi, speci, ppv, npv, speci_err, ppv_err, npv_err = mtr[prevalence]
            if _is_ppv:
                _values, _bounds = ppv, ppv_err
                # Only the PPV curves carry a legend label.
                _extra = {'label': f'{prevalence}% prevalence'}
            else:
                _values, _bounds = npv, npv_err
                _extra = {}
            plt.errorbar(
                sensi,
                _values,
                yerr=np.abs(np.array(_bounds).T - _values),
                capsize=3,
                marker='o',
                markersize=3,
                linestyle=lst,
                linewidth=1,
                color=MY_NAVY,
                **_extra,
            )
            # Tag the curve with its prevalence next to one of its points.
            plt.text(
                _tag_x,
                _values[_tag_i],
                f'{prevalence}%',
                fontsize=12,
                color=MY_NAVY,
                ha='center',
                va='bottom',
                bbox=dict(boxstyle='round,pad=0.1', edgecolor='none', facecolor='white', alpha=0.8)
            )

        _ax = plt.gca()
        plt.xlabel('Sensitivity', color='gray')
        plt.ylabel(_ylabel, color='gray')
        plt.xlim(0, 1)
        plt.ylim(0, 1.0005)
        _ax.tick_params(axis='y', colors='gray')
        _ax.spines['left'].set_color('gray')
        _ax.spines['left'].set_linewidth(1)
        _ax.tick_params(axis='x', colors='gray', which='both')
        _ax.spines['bottom'].set_color('gray')
        sns.despine()
        plt.tight_layout()
        plt.show()


In [None]:
# Display the PPV values of the whole cohort at 10% prevalence, rounded to two decimals.
sensi, speci, ppv, npv, speci_err, ppv_err, npv_err = metrics[10]
np.round(ppv,2)

In [None]:
def ppv(true_sample, pred_probs_sample, thresholds=(0.57, 0.4, 0.19)):
    """
    Print prevalence, sensitivity and PPV at each decision threshold.

    The prevalence is the observed fraction of positives in `true_sample`.
    One line "prevalence sensitivity ppv" (each rounded to two decimals) is
    printed per threshold, followed by a blank line.

    Parameters:
    true_sample (np.ndarray): Binary labels (0 or 1).
    pred_probs_sample (np.ndarray): Predicted scores aligned with `true_sample`.
    thresholds (iterable of float): Scores at or above a threshold are predicted
        positive; defaults to the three operating points used in this notebook.

    Returns:
    None
    """
    prevalence = true_sample.sum()/len(true_sample)
    for threshold in thresholds:
        pred_labels = (pred_probs_sample >= threshold).astype(int)
        tn = np.sum((pred_labels == 0) & (true_sample == 0))
        fp = np.sum((pred_labels == 1) & (true_sample == 0))
        fn = np.sum((pred_labels == 0) & (true_sample == 1))
        tp = np.sum((pred_labels == 1) & (true_sample == 1))

        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        # Bayes' rule; falls back to 0 when nothing would be predicted positive.
        denominator = sensitivity * prevalence + (1 - specificity) * (1 - prevalence)
        ppv_value = (sensitivity * prevalence) / denominator if denominator > 0 else 0

        print(round(prevalence,2), round(sensitivity,2), round(ppv_value,2))
    print()

# overall ppv
ppv(true, pred)

# worsening ppv
ppv(true[idx_above], pred[idx_above])

# improving 
# ppv(1-true[idx_below], 1-pred[idx_below])



# ppv within the idx_no_med subgroup
ppv(true[idx_no_med], pred[idx_no_med])

# ppv within the idx_healthy subgroup
ppv(true[idx_healthy], pred[idx_healthy])

In [9]:
def eval(pred, true, threshold=0.4): 
    """
    Summarise classification performance of `pred` against `true`.

    Note: this name shadows the built-in `eval` within the notebook.

    Parameters:
    pred (np.ndarray): Predicted scores.
    true (np.ndarray): Binary labels (0 or 1) aligned with `pred`.
    threshold (float): Scores at or above this value are predicted positive.

    Returns:
    dict: 'datapoints', 'prevalence', 'auc', 'auc_95ci', 'sensitivity',
    'specificity', 'ppv' and 'npv', each rounded to three decimals. 'auc' and
    'auc_95ci' are NaN when only one class is present, because the AUROC is
    undefined in that case.
    """
    datapoints = len(true)
    prevalence = sum(true)/len(true)
    # Defaults for the single-class case; previously these names stayed unbound
    # and building the result dict raised UnboundLocalError.
    auc = float('nan')
    auc_95ci = float('nan')
    if prevalence>0 and prevalence<1: 
        auc = roc_auc_score(true, pred)
        auc_95ci = bootstrap(pred=pred, true=true, metric_fn=roc_auc_score)
    pred_labels = (pred >= threshold).astype(int)
    tn = np.sum((pred_labels == 0) & (true == 0))
    fp = np.sum((pred_labels == 1) & (true == 0))
    fn = np.sum((pred_labels == 0) & (true == 1))
    tp = np.sum((pred_labels == 1) & (true == 1))
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    # PPV / NPV via Bayes' rule using the observed prevalence.
    ppv_den = sensitivity * prevalence + (1 - specificity) * (1 - prevalence)
    npv_den = specificity * (1 - prevalence) + (1 - sensitivity) * prevalence
    ppv = (sensitivity * prevalence) / ppv_den if ppv_den > 0 else 0
    npv = (specificity * (1 - prevalence)) / npv_den if npv_den > 0 else 0
    data = {
        'datapoints':datapoints,
        'prevalence':prevalence,
        'auc':auc,
        'auc_95ci':auc_95ci,
        'sensitivity':sensitivity,
        'specificity':specificity,
        'ppv':ppv,
        'npv':npv,
    }
    data = {k:round(v,3) for (k,v) in data.items()}
    return data

def bootstrap(pred, true, metric_fn, confidence=0.95, num_samples=1000):
    """
    Half-width of the bootstrap percentile interval of a metric.

    `pred` and `true` are resampled together with replacement `num_samples`
    times and `metric_fn(true_resampled, pred_resampled)` is evaluated on each
    replicate. The result is half the distance between the lower and upper
    percentiles bounding the central `confidence` mass, i.e. a value suitable
    for reporting as "metric ± value".

    Parameters:
    pred (np.ndarray): Predicted scores.
    true (np.ndarray): Labels aligned with `pred`.
    metric_fn (callable): Called as metric_fn(true, pred); returns a number.
    confidence (float): Central coverage of the interval (default 0.95).
    num_samples (int): Number of bootstrap replicates.

    Returns:
    float: half-width of the percentile interval.
    """
    n = len(true)
    samples = []
    for _ in range(num_samples): 
        idx = np.random.choice(n, n, replace=True)
        samples.append( metric_fn(true[idx], pred[idx]) )
    assert len(samples) == num_samples, f'failed to get {num_samples} samples for bootstrapping'
    # Symmetric tails: e.g. the 2.5th and 97.5th percentiles for confidence=0.95.
    # The upper bound used to be taken at index num_samples*confidence/2, which is
    # roughly the 47.5th percentile rather than the upper end of the interval.
    tail = 100 * (1 - confidence) / 2
    lower, upper = np.percentile(samples, [tail, 100 - tail])
    return (upper - lower) / 2

def eval_queries(queries):
    """Print an `eval` summary for each pandas query string over `pyd.data`.

    An empty query selects every row. Labels and scores are both inverted
    (1 - value) before being passed to `eval`.
    """
    frame = pyd.data
    for q in queries: 
        # Row positions selected by the query (all rows for an empty query).
        subset = frame.query(q) if q else frame
        idx = subset.index.values
        x = eval(pred=1-pred[idx], true=1-true[idx])
        print(f"{q}:  \n\t{x['datapoints']} samples  \n\t{x['prevalence']*100}% prevalence  \n\t{x['auc']} ± {x['auc_95ci']}  \n\t{x['sensitivity']} sens  \n\t{x['specificity']} spec  \n\t{x['ppv']} ppv  \n\t{x['npv']} npv ")

In [None]:
# Whole cohort (empty query), then subsets by pacing, transplant status and hospitalisation count.
eval_queries([
    '',
    'paced==True',
    'paced==False',
    'transplant==0',
    'transplant==1',
    'hospitalisations==0',
    'hospitalisations>=1',
    'hospitalisations>=3',
    'hospitalisations>=5',
    'hospitalisations>=10',
    'hospitalisations>=20',
])

In [None]:
# '~tag_hfref' rows: first the whole subgroup, then split by absence / presence of each comorbidity.
_not_hfref_all = ['~tag_hfref']
_not_hfref_by_comorbidity = [
    '~tag_hfref and diabetes_mellitus==0',
    '~tag_hfref and diabetes_mellitus==1',
    '~tag_hfref and hypertension==0',
    '~tag_hfref and hypertension==1',
    '~tag_hfref and atherosclerosis==0',
    '~tag_hfref and atherosclerosis==1',
    '~tag_hfref and chronic_obstructive_pulmonary_disease==0',
    '~tag_hfref and chronic_obstructive_pulmonary_disease==1',
    '~tag_hfref and atrial_fibrillation==0',
    '~tag_hfref and atrial_fibrillation==1',
]
eval_queries(_not_hfref_all + _not_hfref_by_comorbidity)

In [None]:
# 'tag_hfref' rows: the whole subgroup, then by medication count, then by each medication flag.
_hfref_all = ['tag_hfref']
_hfref_by_med_count = [
    'tag_hfref and num_meds==0',
    'tag_hfref and num_meds>0',
]
_hfref_by_medication = [
    'tag_hfref and angio==0',
    'tag_hfref and angio==1',
    'tag_hfref and mra==0',
    'tag_hfref and mra==1',
    'tag_hfref and betablocker==0',
    'tag_hfref and betablocker==1',
    'tag_hfref and diuretic==0',
    'tag_hfref and diuretic==1',
]
eval_queries(_hfref_all + _hfref_by_med_count + _hfref_by_medication)

In [4]:
# LVEF measurements, sorted by measurement date.
lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet')
lvef = lvef.sort('lvef_date')

In [None]:
# Hard-coded operating point; the trailing comment records the dictionary entry it was read from.
threshold = 0.6112226064031894 #dct[0.7][0]
# Store labels, scores and binarised predictions on the dataset table.
pyd.data.loc[:,'true'] = true
pyd.data.loc[:,'score'] = pred
# A row is predicted positive when its score exceeds 1 - threshold.
# NOTE(review): confirm that comparing against 1 - threshold (rather than threshold) is intended.
pyd.data.loc[:,'pred'] = (pred>(1-threshold))

In [None]:
# false positives - early warning sign

# When True, only the longest streak entry of each patient is kept (see below).
PER_PATIENT = False 

# Flag false positives / true negatives per ECG and keep only the '~tag_hfref' rows.
pyd_polars = pl.from_pandas(pyd.data).with_columns(
    pl.when(
        (pl.col('pred')!=pl.col('true')) # incorrect predictions 
        &
        (pl.col('pred')==1) # predictions are that EF will fall
    ).then(True).otherwise(False).alias('false_positive'),
    pl.when(
        (pl.col('pred')==pl.col('true')) # correct predictions 
        &
        (pl.col('pred')==0) # predictions are that EF stay > 40
    ).then(True).otherwise(False).alias('true_negative'),
).filter(
    ~pl.col('tag_hfref')
)

# Walk each patient's ECGs in date order and record every false positive together with
# its position in the current run of consecutive false positives ('fp_count') and the
# date of the first ECG of that run. Any non-false-positive row resets the run.
# NOTE: `df` is rebound here and no longer refers to the cohort table loaded in the demographics cell.
repeat_false_positives = []
for group in pyd_polars.group_by("empi"): 
    df = group[1].sort(['empi','ecg_date'])
    fps = 0
    first_ecg_date = None 
    for row in df.to_dicts(): 
        if row['false_positive']: 
            if fps == 0: 
                first_ecg_date = row['ecg_date']
            fps += 1
            repeat_false_positives.append({
                'empi':row['empi'], 
                'ecg_date':row['ecg_date'],
                'first_ecg_date':first_ecg_date,
                'fp_count':fps,
            })
        else: 
            fps = 0

# Optionally keep one entry per patient: the one with the highest fp_count.
if PER_PATIENT: 
    repeat_false_positives = pl.from_dicts(repeat_false_positives).sort('fp_count', descending=True).unique(subset='empi', keep='first').to_dicts()


# For every streak position N, measure how often an LVEF <= 40 is recorded on or after the
# first ECG of the streak, and how long after that ECG the first such measurement occurs.
N_values = []
examples = []
any_bad_percent = []
days_to_bad = []
for N in range(1,50): 
    N_fps = [x for x in repeat_false_positives if x['fp_count']==N]
    if not len(N_fps): continue 
    any_bad_tally = 0
    times_to_first_bad = []
    for x in N_fps: 
        date = x['first_ecg_date']
        lvef_after_warning = lvef.filter(pl.col('empi')==x['empi']).filter(pl.col('lvef_date')>=date)
        any_bad = not lvef_after_warning.filter(pl.col('lvef')<=40).is_empty()
        any_bad_tally += any_bad
        time_to_first_bad = None 
        if any_bad: 
            # Earliest low LVEF after the warning, minus the date of the streak's first ECG.
            time_to_first_bad = lvef_after_warning.filter(pl.col('lvef')<=40).sort('lvef_date').head(1).select('lvef_date').item() - date.date()
            # results.append((x['empi'], time_to_first_bad))
            times_to_first_bad.append(time_to_first_bad)
    
    percent = round(100*any_bad_tally/len(N_fps))
    # Mean delay in whole days; 0 when no low LVEF was observed for this N.
    days = np.mean(times_to_first_bad).days if times_to_first_bad else 0

    # print(f'When we have {N} repeated false positives, there is a low LVEF observed in the future {percent}% of the time.')
    # print(f'In these {percent}% cases, the low LVEF is observed within {days} days on average.')
    # print()

    N_values.append(N)
    any_bad_percent.append(percent)
    days_to_bad.append(days)
    examples.append(len(N_fps))

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(24, 6), sharex=False)

x_max = max(N_values)+1

# (axis, series plotted against N_values, title, y-axis label)
_panels = (
    (ax2, any_bad_percent, "How often is low LVEF observed in the future?", "Percentage"),
    (ax3, days_to_bad, "Average days until low LVEF is observed", "Days"),
    (ax1, examples, "Number of examples for each N-value", "Number of patients" if PER_PATIENT else "Number of ECGs"),
)
for _axis, _series, _title, _ylabel in _panels:
    _axis.plot(N_values, _series, marker='o', linestyle='-', color=MY_NAVY, zorder=2)
    # Panel-specific y range and ticks (the "days" panel keeps matplotlib's defaults).
    if _axis is ax2:
        _axis.set_ylim(0, 105)
    elif _axis is ax1:
        _axis.set_ylim(0, max(examples) + 50)
    _axis.set_xlim(0, x_max)
    _axis.set_xticks(range(0, x_max, 5))
    if _axis is ax2:
        _axis.set_yticks(range(0, 101, 20))
    elif _axis is ax1:
        _axis.set_yticks(range(0, max(examples) + 50, 100))
    _axis.set_title(_title, fontsize=14)
    _axis.set_xlabel("Subsequent false positive predictions", fontsize=12)
    _axis.set_ylabel(_ylabel, fontsize=12)
    _axis.grid(axis='y', linestyle='--', alpha=0.6)
    for _side in ('top', 'right'):
        _axis.spines[_side].set_visible(False)

# Adjust layout
plt.tight_layout()
plt.show()

In [66]:
def plot_trajectory(pt): 
    """Plot one patient's LVEF measurements together with the model's ECG scores over time.

    The left axis shows the patient's LVEF values from the global `lvef` table
    (values of 40 or below in firebrick, others in silver) with a dashed line
    at 40. The right axis shows `1 - pred` for each of the patient's ECGs in
    `pyd.data`, with a dashed line at `1 - threshold`.

    Parameters:
    pt: value of the 'empi' column identifying the patient.

    Reads the globals `pyd`, `lvef`, `pred`, `threshold` and `MY_NAVY`.
    """
    # Row positions of this patient's ECGs, in date order.
    idx = pyd.data.query('empi==@pt').sort_values('ecg_date').index.values
    
    fig, ax1 = plt.subplots(figsize=(10, 6))

    # Left y-axis: LVEF
    # ax1.set_ylabel('LVEF (%)', color='gray')
    ax1.set_ylim(0, 100)
    ax1.set_yticks(range(0, 101, 20))
    ax1.tick_params(axis='y', colors='gray')
    ax1.spines['left'].set_position(('outward', 10))
    ax1.spines['left'].set_color('gray')
    ax1.spines['left'].set_linewidth(1)

    plt_lvef_date = pd.to_datetime(lvef.filter(pl.col('empi') == pt).select('lvef_date').to_numpy().reshape(-1))
    plt_lvef_val = lvef.filter(pl.col('empi') == pt).select('lvef').to_numpy().reshape(-1)
    plt_lvef_col = ['silver' if x > 40 else 'firebrick' for x in plt_lvef_val]


    # Draw the trajectory segment by segment; each segment takes the colour of its end point.
    for i in range(len(plt_lvef_date)-1): 
        ax1.plot(plt_lvef_date[i:i+2], plt_lvef_val[i:i+2], c=plt_lvef_col[i+1], zorder=1)  

    # ax1.plot(plt_lvef_date, plt_lvef_val, c='silver', zorder=1) 
    ax1.scatter(plt_lvef_date, plt_lvef_val, marker='o', c=plt_lvef_col, zorder=2)

    # Left y-axis threshold
    ax1.axhline(y=40, color='firebrick', linestyle='--')

    # Configure x-axis
    # Limits run from 1 January of the first LVEF year to 31 December of the year after the last one.
    # NOTE(review): presumably fails for a patient without any LVEF rows (min/max of an empty index) — confirm.
    xlim_dates = plt_lvef_date
    ax1.set_xlim(left=xlim_dates.min().replace(month=1, day=1),
                 right=xlim_dates.max().replace(month=12, day=31, year=max(xlim_dates).year + 1))
    ax1.xaxis.set_major_locator(mdates.YearLocator(1))
    ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    # ax1.xaxis.set_minor_locator(mdates.MonthLocator(interval=3))
    # ax1.xaxis.set_minor_formatter(mdates.DateFormatter('%b'))

    # Set x-axis label and tick colors to gray
    ax1.tick_params(axis='x', colors='gray', which='both')
    ax1.spines['bottom'].set_color('gray')

    ax1.grid(True, axis='x', which='major', linestyle='--', alpha=0.7)
    ax1.grid(True, axis='x', which='minor', linestyle=':', alpha=0.5)

    # Shared x-axis tick styling
    plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45, ha='right', color='gray')
    plt.setp(ax1.xaxis.get_minorticklabels(), rotation=45, ha='right', color='gray', fontsize='small')

    # Axes styling
    ax1.spines['bottom'].set_position(('data', 0))  
    ax1.spines['bottom'].set_linewidth(1)

    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)

    # Right y-axis: Likelihood (0-1)
    ax2 = ax1.twinx() 
    # ax2.set_ylabel('Likelihood (0-1)', color='gray')
    ax2.set_ylim(0, 1.005)

    ax2.set_yticks([i / 10 for i in range(11)])
    ax2.tick_params(axis='y', colors='gray')
    ax2.spines['right'].set_position(('outward', 10))
    ax2.spines['right'].set_color('gray')
    ax2.spines['right'].set_linewidth(1)

    # ECG scores are plotted as 1 - pred; points below 1 - threshold are drawn in navy, the rest in silver.
    dates_ecg = np.array([pyd.data.loc[i, 'ecg_date'].date() for i in idx])
    plt_ecg_pred = 1 - pred[idx]
    plt_ecg_color = [MY_NAVY if x < (1-threshold) else 'silver' for x in plt_ecg_pred]
    ax2.scatter(dates_ecg, plt_ecg_pred, marker='x', zorder=3, color=plt_ecg_color) 

    # Right y-axis threshold
    # The literal colour below is the same value as MY_NAVY.
    ax2.axhline(y=1-threshold, color='#001F54', linestyle='--')

    # Axes styling
    ax2.spines['bottom'].set_visible(False)
    ax2.spines['top'].set_visible(False)
    ax2.spines['left'].set_visible(False)

    # Align ticks on both axes
    ax1.tick_params(axis='y', direction='in', pad=5)
    ax2.tick_params(axis='y', direction='in', pad=5)

    plt.tight_layout()
    plt.show()


In [None]:
# Plot the trajectory of each hand-picked patient identifier.
for pt in [100202452,100223049,101091462, 103200474, 105812520, 100358399, 100615898, 101137610, 103401384]: 
    plot_trajectory(pt)