In [1]:
import os
import torch
import pandas as pd 
from ehr_ml.clmbr.dataset import DataLoader
from ehr_ml.clmbr.prediction_model import CLMBR
from ehr_ml.clmbr import PatientTimelineDataset
from ehr_ml.clmbr import convert_patient_data
from ehr_ml.clmbr.utils import read_config, read_info
from ehr_ml.clmbr.trainer import Trainer

In [2]:
info_dir = "/local-scratch/nigam/projects/lguo/temp_ds_shift_robustness/clmbr/experiments/clmbr/clmbr_artifacts/infos/2009_2012"
info = read_info(os.path.join(info_dir,'info.json'))

In [3]:
df_cohort = pd.read_parquet(
    os.path.join(
        "/local-scratch/nigam/projects/lguo/temp_ds_shift_robustness/clmbr/cohorts/admissions/cohort",
        "cohort_split.parquet",
    )
)

df_cohort = df_cohort.assign(
    date = pd.to_datetime(df_cohort['admit_date']).dt.date
)

train = df_cohort.query(
    f"fold_id!=['val','test','1'] and admission_year==[2009,2010,2011,2012]"
)

# use same validation set as clmbr pretraining
val = df_cohort.query(
    f"fold_id==['1'] and admission_year==[2009,2010,2011,2012]"
)

# convert patient ids and save to train_splits path
train_person_ids, train_day_ids = convert_patient_data(
    info['extract_dir'], 
    train['person_id'], 
    train['date']
)

val_person_ids, val_day_ids = convert_patient_data(
    info['extract_dir'], 
    val['person_id'], 
    val['date']
)

In [4]:
len(train_person_ids), len(val_person_ids)

(29061, 7267)

In [5]:
dataset = PatientTimelineDataset(
    os.path.join(info["extract_dir"], "extract.db"),
    os.path.join(info["extract_dir"], "ontology.db"),
    os.path.join(info_dir, "info.json"),
    ([0 for _ in train_person_ids], train_person_ids, train_day_ids),
    ([0 for _ in val_person_ids], val_person_ids, val_day_ids),
)

In [6]:
from torch.nn.functional import binary_cross_entropy_with_logits as bce

losses = {}

for encoder in ['gru','transformer']:
    
    model_dir = f"/local-scratch/nigam/projects/lguo/temp_ds_shift_robustness/clmbr/experiments/clmbr/clmbr_artifacts/models/2009_2012/{encoder}/"
    models = os.listdir(model_dir)
    losses[encoder]={}
    
    for model in models:
        print(f"evaluating {encoder} model {model}")
        
        m = CLMBR.from_pretrained(os.path.join(model_dir,model))

        m.config['day_dropout']=0
        m.config['code_dropout']=0

        #trainer = Trainer(m)
        #losses[model] = trainer.evaluate(dataset, num_batches=len(val_person_ids),)

        m.eval()

        batches = DataLoader(
            dataset=dataset,
            threshold=m.config['num_first'],
            is_val=True,
            batch_size=2000,
        )

        num_batches = len(batches)

        with torch.no_grad():
            loss = 0
            for batch in batches:
                out = m.forward(batch)

                (
                    non_text_indices,
                    non_text_expected_output,
                    seen_before,
                    non_text_indices1,
                    non_text_expected_output1,
                    seen_before1,
                ) = batch['task']

                loss += bce(out['values'],non_text_expected_output,reduction='mean').cpu().numpy()
        
        loss/=num_batches
        
        print(f"{encoder} model {model} loss: {loss}")
        losses[encoder][model]=loss

evaluating gru model 0
gru model 0 loss: 0.6733505258505995
evaluating gru model 11
gru model 11 loss: 0.7340359112078493
evaluating gru model 26
gru model 26 loss: 0.7286979989572004
evaluating gru model 12
gru model 12 loss: 0.68930495665832
evaluating gru model 7
gru model 7 loss: 0.6976799504323439
evaluating gru model 6
gru model 6 loss: 0.72289858690717
evaluating gru model 8
gru model 8 loss: 0.7338464646176859
evaluating gru model 17
gru model 17 loss: 0.7290476154197346
evaluating gru model 22
gru model 22 loss: 0.6835146051916209
evaluating gru model 1
gru model 1 loss: 0.6924780017950318
evaluating gru model 18
gru model 18 loss: 0.6726470433852889
evaluating gru model 15
gru model 15 loss: 0.6940031688321721
evaluating gru model 24
gru model 24 loss: 0.7024940489368006
evaluating gru model 23
gru model 23 loss: 0.7359668531201102
evaluating gru model 19
gru model 19 loss: 0.6872853826392781
evaluating gru model 3
gru model 3 loss: 0.6918611932884563
evaluating gru model 20


In [7]:
df = pd.concat((
    (
        pd.DataFrame.from_dict(losses[x],orient='index')
        .reset_index()
        .rename(columns={'index':'model_num',0:'loss'})
        .assign(encoder=x)
    ) for x in losses.keys()
))

In [8]:
df

Unnamed: 0,model_num,loss,encoder
0,0,0.673351,gru
1,11,0.734036,gru
2,26,0.728698,gru
3,12,0.689305,gru
4,7,0.697680,gru
...,...,...,...
43,45,0.768995,transformer
44,14,1.023786,transformer
45,36,0.840886,transformer
46,9,0.848912,transformer


In [9]:
#df.to_csv('tables/clmbr_losses.csv')