#### Read Cohort, get train/val splits

In [3]:
import os
import numpy as np
import pandas as pd

cohort_dir = "/local-scratch/nigam/projects/jlemmon/cl-clmbr/experiments/main/data/cohort"

# Admission years shared by the train and validation splits. Defined once so
# the two filters below cannot drift apart (they used to repeat the list).
SPLIT_YEARS = list(range(2008, 2017))  # 2008..2016 inclusive
# Fold labels that are excluded from the training split.
NON_TRAIN_FOLDS = ["val", "test", "ignore"]

# load cohort
df_cohort = pd.read_parquet(
    os.path.join(cohort_dir, "cohort_split.parquet"),
    engine="pyarrow",
)

# datetime -> date
df_cohort = df_cohort.assign(date=pd.to_datetime(df_cohort["admit_date"]).dt.date)

# get train/val sets
# Explicit boolean masks replace the query strings: `col == [..]` / `col != [..]`
# inside DataFrame.query is membership (isin / not isin), which is what these
# masks spell out directly.
in_split_years = df_cohort["admission_year"].isin(SPLIT_YEARS)
fold_id = df_cohort["hospital_mortality_fold_id"]

train = df_cohort[in_split_years & ~fold_id.isin(NON_TRAIN_FOLDS)]

val = df_cohort[in_split_years & fold_id.isin(["val"])]

#### Convert and save person IDs for CLMBR info

In [4]:
# Every path below lives under the same experiment root and two of them end in
# the same pretrained-model name; both pieces are spelled out once so the four
# paths cannot fall out of sync.
EXPERIMENT_ROOT = '/local-scratch/nigam/projects/jlemmon/cl-clmbr/experiments/main'
CLMBR_MODEL_NAME = 'gru_sz_800_do_0.1_cd_0_dd_0_lr_0.001_l2_0.01'

# Directory holding the per-split CSVs (patient ids, day indices, labels).
data_path = f'{EXPERIMENT_ROOT}/data/labelled_data/hospital_mortality/pretrained/{CLMBR_MODEL_NAME}'
# Directory holding extract.db / ontology.db.
extracts_dir = f'{EXPERIMENT_ROOT}/data/extracts/20210723'
# Directory of the pretrained CLMBR model (contains info.json).
clmbr_model_path = f'{EXPERIMENT_ROOT}/artifacts/models/clmbr/pretrained/models/{CLMBR_MODEL_NAME}'
# Directory that receives train.log.
model_debug_path = f'{EXPERIMENT_ROOT}/debug/logging'

In [5]:
def _read_flat(csv_path):
    """Read one CSV with pandas and return its values as a flattened array."""
    return pd.read_csv(csv_path).to_numpy().flatten()


# Patient info for the hospital_mortality task, one flat array per file.
# NOTE(review): pd.read_csv is called with its default header handling --
# confirm these CSVs carry a header row.
train_labels = _read_flat(f'{data_path}/labels_train.csv')
train_pids = _read_flat(f'{data_path}/ehr_ml_patient_ids_train.csv')
train_days = _read_flat(f'{data_path}/day_indices_train.csv')

val_labels = _read_flat(f'{data_path}/labels_val.csv')
val_pids = _read_flat(f'{data_path}/ehr_ml_patient_ids_val.csv')
val_days = _read_flat(f'{data_path}/day_indices_val.csv')

# Each split is packed as (labels, pids, days).
train_data = (train_labels, train_pids, train_days)
val_data = (val_labels, val_pids, val_days)

# Sanity check before building the dataset: the three arrays of a split must
# have the same length.
assert len(train_labels) == len(train_pids) == len(train_days)
assert len(val_labels) == len(val_pids) == len(val_days)

#### Create Patient Timeline Dataset

In [6]:
from ehr_ml.clmbr import PatientTimelineDataset

# Resolve the three input files first, then build the dataset from them plus
# the per-split (labels, pids, days) tuples.
extract_db_path = os.path.join(extracts_dir, "extract.db")
ontology_db_path = os.path.join(extracts_dir, "ontology.db")
clmbr_info_path = os.path.join(clmbr_model_path, "info.json")

dataset = PatientTimelineDataset(
    extract_db_path,
    ontology_db_path,
    clmbr_info_path,
    train_data,
    val_data,
)

#### End-to-end Binary Classification Model with CLMBR as Encoder

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from ehr_ml.clmbr.rnn_model import PatientRNN
from tqdm import tqdm

class BinaryLinearCLMBRClassifier(nn.Module):
    """Binary classifier: a CLMBR timeline encoder followed by one linear unit.

    The encoder (``model.timeline_model``) is reused as-is; a single
    ``nn.Linear(config["size"], 1)`` head turns each selected embedding row
    into one logit.
    """

    def __init__(self, model, device=None):
        """Wrap ``model``'s encoder and attach the linear head.

        Args:
            model: object exposing ``config`` (indexable, with a ``"size"``
                entry) and ``timeline_model`` (the encoder module).
            device: torch device for the module; when ``None``, CUDA is used
                if available, otherwise CPU.
        """
        super().__init__()
        self.config = model.config
        self.timeline_model = model.timeline_model
        self.linear = nn.Linear(model.config["size"], 1)
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        # Despite its name this is the output activation (sigmoid), not a loss;
        # the attribute name is kept for backward compatibility.
        self.criterion = nn.Sigmoid()
        # Module.to() moves parameters in place, so no rebinding is needed.
        self.to(self.device)

    def forward(self, batch):
        """Compute probabilities and the summed BCE loss for one batch.

        Args:
            batch: mapping with keys ``"rnn"`` (encoder input), ``"label"``
                (an ``(indices, values)`` pair) and ``"pid"``.

        Returns:
            dict with ``pids`` (``batch["pid"]`` unchanged), ``pred_probs``
            (sigmoid of the logits), ``labels`` (the label values) and
            ``loss`` (binary cross-entropy on the logits, summed over the
            batch).
        """
        embedding = self.timeline_model(batch["rnn"])

        label_indices, label_values = batch["label"]

        # Collapse all leading dimensions so label_indices can address rows
        # of a 2-D (rows, embedding_size) matrix.
        flat_embeddings = embedding.view((-1, embedding.shape[-1]))

        # Row lookup: one embedding per labelled position.
        target_embeddings = F.embedding(label_indices, flat_embeddings)

        logits = self.linear(target_embeddings).flatten()

        return {
            'pids': batch['pid'],
            'pred_probs': self.criterion(logits),
            'labels': label_values,
            'loss': F.binary_cross_entropy_with_logits(
                logits, label_values.float(), reduction="sum"
            ),
        }

    def predict(self, dataloader):
        """Run the model over ``dataloader`` in eval mode without gradients.

        Args:
            dataloader: iterable of batches that also exposes ``num_batches``
                (used as the progress-bar total).

        Returns:
            dict with flat lists ``pid``, ``labels``, ``pred_probs``, plus
            ``mismatch_pids``: the pids of every batch whose pid count
            differs from its number of label values.
        """
        self.eval()

        pred_probs = []
        labels = []
        pids = []
        mismatch_pids = []

        # The context manager closes the progress bar even if a batch raises.
        with tqdm(total=dataloader.num_batches) as pbar, torch.no_grad():
            for batch in dataloader:
                # batch['label'] is the (indices, values) pair, so len() of the
                # pair itself is always 2; the pid count has to be compared
                # with the number of label values instead.
                _, label_values = batch['label']
                if len(batch['pid']) != len(label_values):
                    mismatch_pids.extend(batch['pid'])

                outputs = self(batch)
                pred_probs.extend(outputs['pred_probs'].cpu().numpy())
                labels.extend(outputs['labels'].cpu().numpy())
                pids.extend(outputs['pids'])
                pbar.update(1)

        return {
            'pid': pids,
            'labels': labels,
            'pred_probs': pred_probs,
            'mismatch_pids': mismatch_pids,
        }

    def load_weights(self, model_dir):
        """Load the state dict stored at ``<model_dir>/best`` into this module.

        Args:
            model_dir: directory containing the checkpoint file named ``best``.

        Returns:
            ``self``, so the call can be chained or re-assigned.
        """
        checkpoint_path = os.path.join(model_dir, "best")

        # map_location places the tensors on this module's device, so a
        # checkpoint saved from a different device still loads.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location=self.device)

        self.load_state_dict(state_dict)

        return self

#### Construct Model Config and Load CLMBR info

In [8]:
from ehr_ml.clmbr.utils import read_info
import ehr_ml
import json

# Path literal kept under its own name.
model_dir = '/local-scratch/nigam/projects/jlemmon/cl-clmbr/experiments/main/artifacts/models/clmbr/pretrained/models/gru_sz_800_do_0.1_cd_0_dd_0_lr_0.001_l2_0.01'

# Load the pretrained CLMBR model on CUDA device 0.
clmbr_device = 'cuda:0'
clmbr_model = ehr_ml.clmbr.CLMBR.from_pretrained(clmbr_model_path, clmbr_device).to(clmbr_device)

# Override two config entries on the loaded model.
clmbr_config = clmbr_model.config
clmbr_config['model_dir'] = '/local-scratch/nigam/projects/jlemmon/cl-clmbr/experiments/main/debug/model'
clmbr_config['epochs_per_cycle'] = 5

#### Train model with Trainer class

In [9]:
# Build the classifier around the loaded CLMBR model on CUDA device 0.
device = torch.device('cuda:0')
model = BinaryLinearCLMBRClassifier(model=clmbr_model, device=device)

In [10]:
# A trained model is already saved, so this cell can be skipped.
from ehr_ml.clmbr import Trainer
from ehr_ml.utils import set_up_logging
import logging

# Log file path handed to set_up_logging; the model config is logged first.
train_log_path = os.path.join(model_debug_path, 'train.log')
set_up_logging(train_log_path)
logging.info("Args: %s", str(model.config))

# Train on the dataset with the progress bar disabled.
trainer = Trainer(model)
trainer.train(dataset, use_pbar=False)

2022-04-05 18:09:25,456 Args: {'batch_size': 2000, 'eval_batch_size': 2000, 'num_first': 9262, 'num_second': 10044, 'size': 800, 'lr': 0.001, 'dropout': 0.1, 'encoder_type': 'gru', 'rnn_layers': 1, 'tied_weights': True, 'l2': 0.01, 'b1': 0.9, 'b2': 0.999, 'e': 1e-08, 'epochs_per_cycle': 5, 'warmup_epochs': 2, 'code_dropout': 0.0, 'day_dropout': 0.0, 'model_dir': '/local-scratch/nigam/projects/jlemmon/cl-clmbr/experiments/main/debug/model'}
2022-04-05 18:09:25,460 Batches per epoch = 1124
2022-04-05 18:09:25,461 Total batches = 5620
2022-04-05 18:09:25,462 Start training
2022-04-05 18:09:25,462 About to start epoch 0
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1055.)
  exp_avg.mul_(beta1).add_(1 - beta1, grad)
2022-04-05 18:09:25,751 Seen batch 0
2022-04-05 18:10:48,601 Epoch 0 is complete
2022-04-05 18:11:39,404 Train loss: 2.46325217233

KeyboardInterrupt: 

In [11]:
# Reload weights from the directory recorded in the model config.
weights_dir = model.config['model_dir']
model = model.load_weights(weights_dir)

#### Construct dataloader for validation set to get model predictions

In [12]:
from ehr_ml.clmbr.dataset import DataLoader

# Loader built with is_val=True; the threshold comes from the config's 'num_first' entry.
num_first = model.config['num_first']
dataloader = DataLoader(dataset, threshold=num_first, is_val=True)

In [13]:
# Run inference, then check that the three per-example lists line up.
outputs = model.predict(dataloader)
n_pids, n_labels, n_probs = (
    len(outputs[key]) for key in ('pid', 'labels', 'pred_probs')
)
assert n_pids == n_labels == n_probs

241it [00:08, 29.12it/s]


In [14]:
# Lengths of the pid, label and predicted-probability lists, in that order.
# Author's note: these have been seen to differ (pids vs. labels & pred_probs),
# seemingly for only 31 pids.
tuple(len(outputs[key]) for key in ('pid', 'labels', 'pred_probs'))

(18867, 18867, 18867)

In [15]:
from ehr_ml import timeline

timelines = timeline.TimelineReader(os.path.join(extracts_dir, "extract.db"))

# Debugging notes:
# - The main script prints a batch's pid list whenever a mismatch is found.
# - The pids below come from the last validation batch; one of them has no
#   prediction/label, but it is unclear which.
# - All of them appear to fall inside the validation range
#   (2008-01-01 to 2016-12-31), so it is not clear why the dataloader/dataset
#   does not load the timeline/labels for them.
suspect_pids = [2574937, 2163575, 1216845, 162573, 1985692, 1389474, 1023077]

for pid in suspect_pids:
    patient = timelines.get_patient(pid)
    # First position of this pid in val_pids; reused for both lookups below.
    # NOTE(review): val.iloc[row_idx] assumes val's row order matches
    # val_pids -- confirm.
    row_idx = np.where(val_pids == pid)[0][0]
    day_date = patient.days[val_days[row_idx]].date
    admit_date = val.iloc[row_idx]['admit_date']
    print(
        "date for day ID:",
        day_date,
        "patient's admission date:",
        admit_date,
    )

date for day ID: 2011-03-14 patient's admission date: 2011-03-14 23:00:00
date for day ID: 2016-03-13 patient's admission date: 2016-03-13 16:42:00
date for day ID: 2016-08-03 patient's admission date: 2016-08-03 07:52:00
date for day ID: 2014-11-24 patient's admission date: 2014-11-24 13:38:00
date for day ID: 2016-07-28 patient's admission date: 2016-07-28 21:24:00
date for day ID: 2013-04-04 patient's admission date: 2013-04-04 23:52:00
date for day ID: 2008-01-22 patient's admission date: 2008-01-22 12:20:00


In [16]:
from sklearn.metrics import roc_auc_score as auc
# ROC AUC of the predicted probabilities against the labels.
auc(y_true=outputs['labels'], y_score=outputs['pred_probs'])

0.960986464846224