In [4]:
import statsmodels.api as sm
import scipy.stats as stats
import numpy as np
from scipy.stats import chi2
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.dates as mdates
import polars as pl
import pandas as pd
import seaborn as sns
import torch
import lightning as L
from sklearn.calibration import calibration_curve
from sklearn.metrics import roc_auc_score, average_precision_score, confusion_matrix, roc_curve
from omegaconf import OmegaConf, DictConfig
import hydra
import wandb
from dataset import SupervisedDataset
from lightning_modules import SupervisedTask
from models.ecg_models import *
from run import interpolate
pl.Config.set_tbl_rows(50)
MY_NAVY = '#001F54'

  torch.utils._pytree._register_pytree_node(


In [5]:
alpha = 0.05/50

In [8]:
def get_duration_bet_echo(lvef, subj_col='empi', date_col='lvef_date'):
    durations = []
    for pt, group in lvef.select([subj_col, date_col]).sort([subj_col, date_col]).group_by(subj_col): 
        if group.height == 1: continue 
        lv_dur = group.select(date_col).with_columns(pl.col(date_col).diff().dt.total_days()).drop_nulls().mean().item()
        durations.append(lv_dur)
    return np.mean(durations) 

def fraction_echoes_within_12mo(lvef, subj_col='empi', date_col='lvef_date'):
    inter_echo_days = []
    for pt, group in lvef.select([subj_col, date_col]).sort([subj_col, date_col]).group_by(subj_col): 
        if group.height < 2:
            continue
        diffs = group.select(date_col).with_columns(pl.col(date_col).diff().dt.total_days()).drop_nulls()
        inter_echo_days.extend(diffs.to_series().to_list())
    # inter_echo_days = np.array(inter_echo_days)
    return np.mean([x <= 365 for x in inter_echo_days])


In [9]:
print('ALL PATIENTS')

df_mgh = pl.read_parquet('/storage2/payal/Dropbox (Partners HealthCare)/private/SILVER/data/data.parquet').filter(pl.col('split')!='external')
df_bwh = pl.read_parquet('/storage2/payal/Dropbox (Partners HealthCare)/private/SILVER/data/data.parquet').filter(pl.col('split')=='external')
df_mim = pl.read_parquet('/storage/shared/mimic/data.parquet')

size_mgh = df_mgh.unique('empi').height
size_bwh = df_bwh.unique('empi').height
size_mim = df_mim.unique('empi').height

print(f"patients \t {size_mgh} \t\t {size_bwh} \t\t {size_mim}")
print(f"num ecg \t {df_mgh.height} \t {df_bwh.height}  \t {df_mim.height}")

for cf in [
    'sex',
    'diabetes_mellitus',
    'hypertension',
    'atheroscler',
    'chronic_obstructive_pulmonary_disease',
    'atrial_fibrillation',
    'angio',
    'betablocker',
    'mra',
    'diuretic',
]: 
    count_mgh = df_mgh.select(['empi',cf]).group_by('empi').sum().filter(pl.col(cf)>0).unique('empi').height
    count_bwh = df_bwh.select(['empi',cf]).group_by('empi').sum().filter(pl.col(cf)>0).unique('empi').height
    count_mim = df_mim.select(['empi',cf]).group_by('empi').sum().filter(pl.col(cf)>0).unique('empi').height

    prop_mgh = round(100 * count_mgh / size_mgh, 1)
    prop_bwh = round(100 * count_bwh / size_bwh, 1)
    prop_mim = round(100 * count_mim / size_mim, 1)

    _, p_bwh = sm.stats.proportions_ztest([count_mgh, count_bwh], [size_mgh, size_bwh], alternative='two-sided')
    _, p_mimic = sm.stats.proportions_ztest([count_mgh, count_mim], [size_mgh, size_mim], alternative='two-sided')

    print(f"{cf[:5]} \t\t {prop_mgh}% \t\t {prop_bwh}% {'*' if p_bwh < alpha else ''} \t {prop_mim}% {'*' if p_mimic < alpha else ''}")

mgh_age = df_mgh.select(['empi','age']).group_by('empi').mean().select('age').to_numpy().reshape(-1)
bwh_age = df_bwh.select(['empi','age']).group_by('empi').mean().select('age').to_numpy().reshape(-1)
mim_age = df_mim.select(['empi','age']).group_by('empi').mean().select('age').to_numpy().reshape(-1)
_, p_bwh = stats.ttest_ind(mgh_age, bwh_age, equal_var=False)
_, p_mim = stats.ttest_ind(mgh_age, mim_age, equal_var=False)
print(f"age \t\t {np.mean(mgh_age):.1f} {np.std(mgh_age):.1f} \t {np.mean(bwh_age):.1f} {np.std(bwh_age):.1f} {'*' if p_bwh < alpha else ''} \t {np.mean(mim_age):.1f} {np.std(mim_age):.1f} {'*' if p_mimic < alpha else ''}")

lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='MGH'
).filter(pl.col('lvef').is_not_null())
print('DURATION BET ECHOS', get_duration_bet_echo(lvef))
print('FRAC 2 ECHOS WITHIN YEAR', fraction_echoes_within_12mo(lvef))
size_mgh = lvef.height
count_mgh = lvef.filter(pl.col('lvef')<=40).height
pt_lvef_40_mgh = lvef.filter(pl.col('lvef')<=40).select('empi').unique().height
prop_mgh = round(100 * count_mgh / size_mgh, 1)

lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='BWH'
).filter(pl.col('lvef').is_not_null())
print('DURATION BET ECHOS', get_duration_bet_echo(lvef))
print('FRAC 2 ECHOS WITHIN YEAR', fraction_echoes_within_12mo(lvef))
size_bwh = lvef.height
count_bwh = lvef.filter(pl.col('lvef')<=40).height
pt_lvef_40_bwh = lvef.filter(pl.col('lvef')<=40).select('empi').unique().height
prop_bwh = round(100 * count_bwh / size_bwh, 1)

lvef = pl.read_csv('/storage/shared/mimic/raw/lvef.csv'
).join(df_mim, right_on='empi', left_on='subject_id',
).with_columns(
    pl.col("study_datetime").str.to_date(format="%Y-%m-%dT%H:%M:%S")  # Convert to date (YYYY-MM-DD)
).select(['subject_id', 'study_datetime', 'measurement', 'result',]
).unique().filter(pl.col('result').is_not_null())
print('DURATION BET ECHOS', get_duration_bet_echo(lvef, subj_col='subject_id', date_col='study_datetime'))
print('FRAC 2 ECHOS WITHIN YEAR', fraction_echoes_within_12mo(lvef, subj_col='subject_id', date_col='study_datetime'))
size_mim = lvef.height
count_mim = lvef.filter(pl.col('result').cast(int)<=40).height
pt_lvef_40_mim = lvef.filter(pl.col('result').cast(int)<=40).select('subject_id').unique().height
prop_mim = round(100 * count_mim / size_mim, 1)

_, p_bwh = sm.stats.proportions_ztest([count_mgh, count_bwh], [size_mgh, size_bwh], alternative='two-sided')
_, p_mimic = sm.stats.proportions_ztest([count_mgh, count_mim], [size_mgh, size_mim], alternative='two-sided')

print(f"lvef \t\t {size_mgh} \t {size_bwh} \t {size_mim}")
print(f"ptlv40 \t\t {pt_lvef_40_mgh} \t\t {pt_lvef_40_bwh} \t\t {pt_lvef_40_mim}")
print(f"lv<40 \t\t {prop_mgh}% \t\t {prop_bwh}% {'*' if p_bwh < alpha else ''} \t {prop_mim}% {'*' if p_mimic < alpha else ''}")


ALL PATIENTS
patients 	 83433 		 26942 		 20850
num ecg 	 947905 	 274897  	 196584
sex 		 43.4% 		 45.2% * 	 47.2% *
diabe 		 41.4% 		 39.0% * 	 41.2% 
hyper 		 83.8% 		 67.4% * 	 61.4% *
ather 		 68.3% 		 60.4% * 	 59.2% *
chron 		 12.4% 		 5.7% * 	 10.8% *
atria 		 49.2% 		 45.8% * 	 51.3% *
angio 		 44.4% 		 40.1% * 	 34.8% *
betab 		 52.9% 		 49.6% * 	 56.0% *
mra 		 13.2% 		 8.3% * 	 7.8% *
diure 		 56.5% 		 60.6% * 	 48.7% *
age 		 70.2 15.6 	 70.8 14.4 * 	 74.2 13.4 *
DURATION BET ECHOS 710.6187356614435
FRAC 2 ECHOS WITHIN YEAR 0.5544447463710875
DURATION BET ECHOS 627.4907507708567
FRAC 2 ECHOS WITHIN YEAR 0.620834531454679
DURATION BET ECHOS 425.97449679955685
FRAC 2 ECHOS WITHIN YEAR 0.6779554437657205
lvef 		 235476 	 192060 	 36667
ptlv40 		 20112 		 21685 		 5723
lv<40 		 19.8% 		 25.8% * 	 31.6% *


In [45]:
pl.read_csv('/storage/shared/mimic/raw/lvef.csv'
)

subject_id,study_datetime,measurement,result
i64,date,str,str
10002155,2128-08-03,"""Left ventricle ejection fracti…",
10002155,2129-08-09,"""Left ventricle ejection fracti…",
10004457,2143-03-09,"""Left ventricle ejection fracti…",
10011668,2131-06-18,"""Left ventricle ejection fracti…",
10013049,2114-06-18,"""Left ventricle ejection fracti…",
10013643,2200-10-02,"""Left ventricle ejection fracti…",
10014547,2146-03-02,"""Left ventricle ejection fracti…",
10014610,2174-06-04,"""Left ventricle ejection fracti…",
10016084,2155-11-28,"""Left ventricle ejection fracti…",
10017531,2159-10-10,"""Left ventricle ejection fracti…",


In [None]:
df = df.sort_values(by=["empi", "lvef_date"])
df["delta"] = df.groupby("empi")["lvef_date"].diff().dt.days

# Computing mean delta for each empi
result = df.groupby("empi")["delta"].mean().reset_index()



In [4]:
print('PATIENTS WITH AN LVEF > 40')

df_mgh = pl.read_parquet('/storage2/payal/Dropbox (Partners HealthCare)/private/SILVER/data/data.parquet').filter(pl.col('split')!='external')
df_bwh = pl.read_parquet('/storage2/payal/Dropbox (Partners HealthCare)/private/SILVER/data/data.parquet').filter(pl.col('split')=='external')
df_mim = pl.read_parquet('/storage/shared/mimic/data.parquet')

mgh_subset = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='MGH'
).filter(pl.col('lvef').is_not_null()
).filter(pl.col('lvef')>40
).select('empi'
).unique().to_numpy().reshape(-1)
df_mgh = df_mgh.filter(pl.col('empi').is_in(mgh_subset))

bwh_subset = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='BWH'
).filter(pl.col('lvef').is_not_null()
).filter(pl.col('lvef')>40
).select('empi'
).unique().to_numpy().reshape(-1)
df_bwh = df_bwh.filter(pl.col('empi').is_in(bwh_subset))

mim_subset = pl.read_csv('/storage/shared/mimic/raw/lvef.csv'
).select(['subject_id', 'measurement', 'result',]
).unique().filter(pl.col('result').is_not_null()
).filter(pl.col('result').cast(int)>40
).select('subject_id'
).unique().to_numpy().reshape(-1)
df_mim = df_mim.filter(pl.col('empi').is_in(mim_subset))

size_mgh = df_mgh.unique('empi').height
size_bwh = df_bwh.unique('empi').height
size_mim = df_mim.unique('empi').height

print(f"patients \t {size_mgh} \t\t {size_bwh} \t\t {size_mim}")
print(f"num ecg \t {df_mgh.height} \t {df_bwh.height}  \t {df_mim.height}")

for cf in [
    'sex',
    'diabetes_mellitus',
    'hypertension',
    'atheroscler',
    'chronic_obstructive_pulmonary_disease',
    'atrial_fibrillation',
    'angio',
    'betablocker',
    'mra',
    'diuretic',
]: 
    count_mgh = df_mgh.select(['empi',cf]).group_by('empi').sum().filter(pl.col(cf)>0).unique('empi').height
    count_bwh = df_bwh.select(['empi',cf]).group_by('empi').sum().filter(pl.col(cf)>0).unique('empi').height
    count_mim = df_mim.select(['empi',cf]).group_by('empi').sum().filter(pl.col(cf)>0).unique('empi').height

    prop_mgh = round(100 * count_mgh / size_mgh, 1)
    prop_bwh = round(100 * count_bwh / size_bwh, 1)
    prop_mim = round(100 * count_mim / size_mim, 1)

    _, p_bwh = sm.stats.proportions_ztest([count_mgh, count_bwh], [size_mgh, size_bwh], alternative='two-sided')
    _, p_mimic = sm.stats.proportions_ztest([count_mgh, count_mim], [size_mgh, size_mim], alternative='two-sided')

    print(f"{cf[:5]} \t\t {prop_mgh}% \t\t {prop_bwh}% {'*' if p_bwh < alpha else ''} \t {prop_mim}% {'*' if p_mimic < alpha else ''}")

mgh_age = df_mgh.select(['empi','age']).group_by('empi').mean().select('age').to_numpy().reshape(-1)
bwh_age = df_bwh.select(['empi','age']).group_by('empi').mean().select('age').to_numpy().reshape(-1)
mim_age = df_mim.select(['empi','age']).group_by('empi').mean().select('age').to_numpy().reshape(-1)
_, p_bwh = stats.ttest_ind(mgh_age, bwh_age, equal_var=False)
_, p_mim = stats.ttest_ind(mgh_age, mim_age, equal_var=False)
print(f"age \t\t {np.mean(mgh_age):.1f} {np.std(mgh_age):.1f} \t {np.mean(bwh_age):.1f} {np.std(bwh_age):.1f} {'*' if p_bwh < alpha else ''} \t {np.mean(mim_age):.1f} {np.std(mim_age):.1f} {'*' if p_mimic < alpha else ''}")

lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='MGH'
).filter(pl.col('lvef').is_not_null()
).filter(pl.col('empi').is_in(mgh_subset))
size_mgh = lvef.height
count_mgh = lvef.filter(pl.col('lvef')<=40).height
pt_lvef_40_mgh = lvef.filter(pl.col('lvef')<=40).select('empi').unique().height
prop_mgh = round(100 * count_mgh / size_mgh, 1)

lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='BWH'
).filter(pl.col('lvef').is_not_null()
).filter(pl.col('empi').is_in(bwh_subset))
size_bwh = lvef.height
count_bwh = lvef.filter(pl.col('lvef')<=40).height
pt_lvef_40_bwh = lvef.filter(pl.col('lvef')<=40).select('empi').unique().height
prop_bwh = round(100 * count_bwh / size_bwh, 1)

lvef = pl.read_csv('/storage/shared/mimic/raw/lvef.csv'
).join(df_mim, right_on='empi', left_on='subject_id',
).select(['subject_id', 'measurement', 'result',]
).unique().filter(pl.col('result').is_not_null()
).filter(pl.col('subject_id').is_in(mim_subset))
size_mim = lvef.height
count_mim = lvef.filter(pl.col('result').cast(int)<=40).height
pt_lvef_40_mim = lvef.filter(pl.col('result').cast(int)<=40).select('subject_id').unique().height
prop_mim = round(100 * count_mim / size_mim, 1)

_, p_bwh = sm.stats.proportions_ztest([count_mgh, count_bwh], [size_mgh, size_bwh], alternative='two-sided')
_, p_mimic = sm.stats.proportions_ztest([count_mgh, count_mim], [size_mgh, size_mim], alternative='two-sided')

print(f"lvef \t\t {size_mgh} \t {size_bwh} \t {size_mim}")
print(f"ptlv40 \t\t {pt_lvef_40_mgh} \t\t {pt_lvef_40_bwh} \t\t {pt_lvef_40_mim}")
print(f"lv<=40 \t\t {prop_mgh}% \t\t {prop_bwh}% {'*' if p_bwh < alpha else ''} \t {prop_mim}% {'*' if p_mimic < alpha else ''}")



PATIENTS WITH AN LVEF > 40
patients 	 54100 		 13675 		 10795
num ecg 	 772033 	 168381  	 129151
sex 		 45.4% 		 49.4% * 	 52.2% *
diabe 		 43.7% 		 40.5% * 	 42.5% 
hyper 		 88.4% 		 74.0% * 	 70.2% *
ather 		 70.5% 		 59.3% * 	 58.0% *
chron 		 14.9% 		 7.3% * 	 10.2% *
atria 		 52.9% 		 49.1% * 	 53.0% 
angio 		 45.8% 		 38.5% * 	 37.0% *
betab 		 53.6% 		 49.5% * 	 58.9% *
mra 		 13.0% 		 7.5% * 	 7.3% *
diure 		 58.7% 		 60.0%  	 52.9% *
age 		 70.2 15.5 	 70.1 14.7  	 74.2 13.5 *
lvef 		 213723 	 165416 	 21870
ptlv40 		 10041 		 9038 		 2115
lv<=40 		 11.6% 		 13.8% * 	 15.3% *


In [10]:
print('ECHO DISTRIBUTION')

lvef
data = []
for pt, group in lvef.select(['empi','lvef_date']).unique().group_by('empi'): 
    if group.height > 1: break
    years_w_lvef = group['lvef_date'].to_list()
    years_in_record = max(years_w_lvef) - min(years_w_lvef) + 1
    percent_w_lvef = 100*len(years_w_lvef)/years_in_record
    data.append((percent_w_lvef, years_in_record))

ECHO DISTRIBUTION


In [12]:
group

data = {
    "empi": [1, 1, 1, 2, 3, 3],
    "lvef_date": pd.to_datetime(["2002-07-29", "2004-11-12", "2007-02-28", "2002-07-29", "2003-11-12", "2005-11-12"])
}

df = pd.DataFrame(data)

# Calculating mean deltas (differences between subsequent dates) per empi
df = df.sort_values(by=["empi", "lvef_date"])
df["delta"] = df.groupby("empi")["lvef_date"].diff().dt.days

# Computing mean delta for each empi
result = df.groupby("empi")["delta"].mean().reset_index()

# Display result
import ace_tools as tools
tools.display_dataframe_to_user(name="Mean Deltas per EMPI", dataframe=result)


empi,lvef_date
i64,date
100240487,2002-07-29
100240487,2004-11-12
100240487,2007-02-28


In [13]:
print('YEARLY DROP PROBABILITY')

df_mgh = pl.read_parquet('/storage2/payal/Dropbox (Partners HealthCare)/private/SILVER/data/data.parquet').filter(pl.col('split')!='external')
df_bwh = pl.read_parquet('/storage2/payal/Dropbox (Partners HealthCare)/private/SILVER/data/data.parquet').filter(pl.col('split')=='external')
df_mim = pl.read_parquet('/storage/shared/mimic/data.parquet')

mgh_subset = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='MGH'
).filter(pl.col('lvef').is_not_null()
).filter(pl.col('lvef')>40
).select('empi'
).unique().to_numpy().reshape(-1)
df_mgh = df_mgh.filter(pl.col('empi').is_in(mgh_subset))

bwh_subset = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='BWH'
).filter(pl.col('lvef').is_not_null()
).filter(pl.col('lvef')>40
).select('empi'
).unique().to_numpy().reshape(-1)
df_bwh = df_bwh.filter(pl.col('empi').is_in(bwh_subset))

mim_subset = pl.read_csv('/storage/shared/mimic/raw/lvef.csv'
).select(['subject_id', 'measurement', 'result',]
).unique().filter(pl.col('result').is_not_null()
).filter(pl.col('result').cast(int)>40
).select('subject_id'
).unique().to_numpy().reshape(-1)
df_mim = df_mim.filter(pl.col('empi').is_in(mim_subset))

def get_prob(lvef, DATE, LVEF, SUBJ):
    data = []
    for pt, group in lvef.group_by(SUBJ): 
        group = group.sort(DATE)
        if group.select(LVEF).head(1).item() <= 40: 
            continue 
        if group.filter(pl.col(LVEF)<=40).is_empty(): 
            data.append((0,None))
        else: 
            drop = group.filter(pl.col(LVEF)<=40).select(DATE).head(1).item()
            last_high_before_drop = group.filter(pl.col(LVEF)>40).filter(pl.col(DATE)<=drop).select(DATE).tail(1).item()
            years = (drop - last_high_before_drop).days / 365.25
            data.append((1, years))
    print('Mean years', np.mean([x[1] for x in data if x[0]]))
    print('Mean years', np.median([x[1] for x in data if x[0]]))
    print('% patients who worsen', np.mean([x[0] for x in data]))
    prob = []
    for drop, years in data: 
        if not drop:
            p=0
        else: 
            if years < 1: 
                p = 1
            else: 
                p = 1/years
        prob.append(p)
    return np.mean(prob)

lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='MGH'
).filter(pl.col('lvef').is_not_null()
).filter(pl.col('empi').is_in(mgh_subset))
print('MGH', get_prob(lvef, 'lvef_date', 'lvef', 'empi'))

lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='BWH'
).filter(pl.col('lvef').is_not_null()
).filter(pl.col('empi').is_in(bwh_subset))
print('BWH', get_prob(lvef, 'lvef_date', 'lvef', 'empi'))

lvef = pl.read_csv('/storage/shared/mimic/raw/lvef.csv'
).join(df_mim, right_on='empi', left_on='subject_id',
).select(['subject_id', 'study_datetime', 'result',]
).unique().filter(pl.col('result').is_not_null()
).filter(pl.col('subject_id').is_in(mim_subset)
).with_columns(
    pl.col('result').cast(int),
    pl.col("study_datetime").str.to_date("%Y-%m-%dT%H:%M:%S").alias("study_date"),
)
print('MIMIC', get_prob(lvef, 'study_date', 'result', 'subject_id'))


YEARLY DROP PROBABILITY


KeyboardInterrupt: 

In [33]:
def has_comorbidity(df, col): 
    return df.filter(df["empi"].is_in(df.filter(pl.col(col)==False).select('empi').unique()))

print('YEARLY DROP PROBABILITY')

df_mgh = pl.read_parquet('/storage2/payal/Dropbox (Partners HealthCare)/private/SILVER/data/data.parquet').filter(pl.col('split')!='external')
df_bwh = pl.read_parquet('/storage2/payal/Dropbox (Partners HealthCare)/private/SILVER/data/data.parquet').filter(pl.col('split')=='external')
df_mim = pl.read_parquet('/storage/shared/mimic/data.parquet')

mgh_subset = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='MGH'
).filter(pl.col('lvef').is_not_null()
).filter(pl.col('lvef')>40
).select('empi'
).unique().to_numpy().reshape(-1)
df_mgh = df_mgh.filter(pl.col('empi').is_in(mgh_subset)).pipe(
    has_comorbidity, 'diabetes_mellitus'
).pipe(
    has_comorbidity, 'hypertension'
).pipe(
    has_comorbidity, 'atheroscler'
).pipe(
    has_comorbidity, 'chronic_obstructive_pulmonary_disease'
).pipe(
    has_comorbidity, 'atrial_fibrillation'
)

bwh_subset = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='BWH'
).filter(pl.col('lvef').is_not_null()
).filter(pl.col('lvef')>40
).select('empi'
).unique().to_numpy().reshape(-1)
df_bwh = df_bwh.filter(pl.col('empi').is_in(bwh_subset)).pipe(
    has_comorbidity, 'diabetes_mellitus'
).pipe(
    has_comorbidity, 'hypertension'
).pipe(
    has_comorbidity, 'atheroscler'
).pipe(
    has_comorbidity, 'chronic_obstructive_pulmonary_disease'
).pipe(
    has_comorbidity, 'atrial_fibrillation'
)

mim_subset = pl.read_csv('/storage/shared/mimic/raw/lvef.csv'
).select(['subject_id', 'measurement', 'result',]
).unique().filter(pl.col('result').is_not_null()
).filter(pl.col('result').cast(int)>40
).select('subject_id'
).unique().to_numpy().reshape(-1)
df_mim = df_mim.filter(pl.col('empi').is_in(mim_subset)).pipe(
    has_comorbidity, 'diabetes_mellitus'
).pipe(
    has_comorbidity, 'hypertension'
).pipe(
    has_comorbidity, 'atheroscler'
).pipe(
    has_comorbidity, 'chronic_obstructive_pulmonary_disease'
).pipe(
    has_comorbidity, 'atrial_fibrillation'
)

def get_prob(lvef, DATE, LVEF, SUBJ):
    data = []
    for pt, group in lvef.group_by(SUBJ): 
        group = group.sort(DATE)
        if group.select(LVEF).head(1).item() <= 40: 
            continue 
        if group.filter(pl.col(LVEF)<=40).is_empty(): 
            data.append((0,None))
        else: 
            drop = group.filter(pl.col(LVEF)<=40).select(DATE).head(1).item()
            last_high_before_drop = group.filter(pl.col(LVEF)>40).filter(pl.col(DATE)<=drop).select(DATE).tail(1).item()
            years = (drop - last_high_before_drop).days / 365.25
            data.append((1, years))
    # print('Mean years', np.mean([x[1] for x in data if x[0]]))
    # print('Mean years', np.median([x[1] for x in data if x[0]]))
    # print('% patients who worsen', np.mean([x[0] for x in data]))
    prob = []
    for drop, years in data: 
        if not drop:
            p=0
        else: 
            if years < 1: 
                p = 1
            else: 
                p = 1/years
        prob.append(p)
    return np.mean(prob)

lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='MGH'
).filter(pl.col('lvef').is_not_null()
).filter(pl.col('empi').is_in(mgh_subset))
print('MGH', round(get_prob(lvef, 'lvef_date', 'lvef', 'empi')*100, 1))

lvef = pl.read_parquet('/storage2/payal/dropbox/private/data/processed/lvef.parquet'
).filter(pl.col('hospital')=='BWH'
).filter(pl.col('lvef').is_not_null()
).filter(pl.col('empi').is_in(bwh_subset))
print('BWH', round(get_prob(lvef, 'lvef_date', 'lvef', 'empi')*100, 1))

lvef = pl.read_csv('/storage/shared/mimic/raw/lvef.csv'
).join(df_mim, right_on='empi', left_on='subject_id',
).select(['subject_id', 'study_datetime', 'result',]
).unique().filter(pl.col('result').is_not_null()
).filter(pl.col('subject_id').is_in(mim_subset)
).with_columns(
    pl.col('result').cast(int),
    pl.col("study_datetime").str.to_date("%Y-%m-%dT%H:%M:%S").alias("study_date"),
)
print('MIMIC', round(get_prob(lvef, 'study_date', 'result', 'subject_id')*100, 1))


YEARLY DROP PROBABILITY
MGH 6.3
BWH 7.4
MIMIC 10.2
