In [7]:
import os
import numpy as np
import pandas as pd
from ehr_ml.clmbr import PatientTimelineDataset
from ehr_ml.clmbr.dataset import DataLoader


In [15]:
# constant variables
cuda_device='cuda:2'
seed = 44
hparams_fpath = "/local-scratch/nigam/projects/jlemmon/cl-clmbr/experiments/main/hyperparams"
extract_path = '/local-scratch/nigam/projects/jlemmon/cl-clmbr/experiments/main/data/extracts/20210723'
model_path = '/local-scratch/nigam/projects/jlemmon/cl-clmbr/experiments/main/artifacts/models/clmbr/cl_ete/models'
cohort_fpath = "/local-scratch/nigam/projects/jlemmon/cl-clmbr/experiments/main/data/cohort"
pt_info_path = '/local-scratch/nigam/projects/jlemmon/cl-clmbr/experiments/main/artifacts/models/clmbr/pretrained/info/info.json'

In [19]:
# functions
def load_data():
    """
    Load datasets from split csv files.
    """

    data_path = "/local-scratch/nigam/projects/jlemmon/cl-clmbr/experiments/main/data/labelled_data/hospital_mortality/pretrained/gru_sz_800_do_0.1_cd_0_dd_0_lr_0.001_l2_0.01"

    train_pids = pd.read_csv(f'{data_path}/ehr_ml_patient_ids_train.csv')[:1000] #load small number of samples to speed up debug process
    val_pids = pd.read_csv(f'{data_path}/ehr_ml_patient_ids_val.csv')[:1000]

    train_days = pd.read_csv(f'{data_path}/day_indices_train.csv')[:1000]
    val_days = pd.read_csv(f'{data_path}/day_indices_val.csv')[:1000]

    train_labels = pd.read_csv(f'{data_path}/labels_train.csv')[:1000]
    val_labels = pd.read_csv(f'{data_path}/labels_val.csv')[:1000]

    train_data = (train_labels.to_numpy().flatten(),train_pids.to_numpy().flatten(),train_days.to_numpy().flatten())
    val_data = (val_labels.to_numpy().flatten(),val_pids.to_numpy().flatten(),val_days.to_numpy().flatten())

    return train_data, val_data

def read_file(filename, columns=None, **kwargs):
    print(filename)
    load_extension = os.path.splitext(filename)[-1]
    if load_extension == ".parquet":
        return pd.read_parquet(filename, columns=columns,**kwargs)
    elif load_extension == ".csv":
        return pd.read_csv(filename, usecols=columns, **kwargs)

In [13]:
train_data, val_data = load_data()
dataset = PatientTimelineDataset(extract_path + '/extract.db', 
                                         extract_path + '/ontology.db', 
                                         pt_info_path, 
                                         train_data, 
                                         val_data )

In [25]:
 with DataLoader(dataset, 9262, is_val=False, batch_size=1, device=cuda_device) as train_loader:
            for batch in train_loader:
                print(batch['rnn'])
                break

(tensor([9262,    1,    0,   49,   36,   55,   90,   83,   63,   64,   51,   44,
          62,   46,   45,   47,   61,   75,   76,   74,   54,   52,  106,  105,
         102,   81,   80,  112,  111,  113,  114,  108,  107,  130,  128,   82,
         126,  125,  127,  123,  117,  134,  132,  142,  135,   98,   94,   93,
          78,   67,   66,   65,   68,   59,   73,   70,   69,   72,   71,   79,
          77,  104,  103,   91,   60,   58,   92,   97,   96,   95,  101,  109,
         110,  115,  116,  118,  119,  120,  153,  152,  151,  203,  165,  162,
         150,  149,  161,  160,  159,  158,  163,  164,  173,  174,  170,  169,
         167,  181,  182,  140,  139,  138,  137,   99,  146,  143,  121,  191,
         189,  176,  175,  185,  186,  184,  177,  187,  183,  180,  179,  178,
         201,  198,  211,  200,  193,   53,    3,  217,  228,  232,  250,  260,
         206,  242,  243,  267,  256,  255,  209,  253,  265,  221,  244,  245,
         231,  291,  518,  124,  122,  

In [22]:
cohort = read_file(
        os.path.join(
            cohort_fpath,
            "cohort.parquet"
        ),
        engine='pyarrow'
    )

/local-scratch/nigam/projects/jlemmon/cl-clmbr/experiments/main/data/cohort/cohort.parquet


In [23]:
print(cohort)

        person_id          admit_date      discharge_date admit_date_midnight  \
0        29936887 2019-12-22 22:44:00 2019-12-26 16:00:00 2019-12-22 23:59:00   
1        29936888 2010-07-12 06:34:00 2010-07-14 10:38:00 2010-07-12 23:59:00   
2        29936900 2014-10-31 15:08:00 2014-11-03 13:25:00 2014-10-31 23:59:00   
3        29936914 2018-06-23 05:20:00 2018-06-25 13:48:00 2018-06-23 23:59:00   
4        29936936 2018-02-08 12:05:00 2018-02-10 16:30:00 2018-02-08 23:59:00   
...           ...                 ...                 ...                 ...   
224132   64773432 2021-05-13 19:06:00 2021-05-15 19:57:00 2021-05-13 23:59:00   
224133   64826084 2020-09-15 08:10:00 2020-09-17 16:24:00 2020-09-15 23:59:00   
224134   68184642 2021-06-21 11:30:00 2021-06-23 17:16:00 2021-06-21 23:59:00   
224135   68930745 2021-05-31 20:30:00 2021-06-08 13:47:00 2021-05-31 23:59:00   
224136   69021508 2021-07-13 05:01:00 2021-07-25 13:45:00 2021-07-13 23:59:00   

       discharge_date_midni