# Preparing Data For Med-ROAR
This notebook demonstrates how to prepare mixtures of tabular and time-series EHR data for training Med-ROAR. ROAR architectures require that all input tokens consist of a Type, Position, Value triplit to enable random order mixed type modeling.
* Type = The type of time series record or tabular variable.
* Position (timestamp) = The position of that record in a sequence. This is a continuous (float) number. Tabular variables have no relative position so their position is always 0.
* Value = The discrete value of the token. i.e. the value being predicted.

This notebook converts MIMIC-VI ICU EHR patient data into this format.

In [1]:
import pandas as pd
import numpy as np
from zipfile import ZipFile
from sklearn.preprocessing import LabelEncoder
from tqdm.notebook import tqdm

## Load MIMIC-IV From Its Zip

In [2]:
z = ZipFile("mimic-iv-1.0.zip")

## Define How Many Bins We Want to Break Continuous Data Into

In [3]:
continuous_data_levels = 32

## Define The Minimum Number of Records a Record Type Needs to be Included

In [4]:
min_record_count = 1000

## Define the Max Number of Diagnoses
Many diagnoses are rare, we will use the top N most common

In [5]:
max_diagnoses = 256

## Define a function to get time-series information from MIMIC-IV tables
Discards types of records with too few occurrences. Discretizes numeric records into N quantiles on a per-record-type basis.

In [6]:
def get_info(data_type_name, filename, timestampcol, itemtypecol, valuecol, value_is_numeric=True, min_count=1000, num_bins=8, nrows=10000000000):
    if valuecol is not None:
        d = pd.read_csv(z.open(filename), compression="gzip", nrows=nrows, usecols=["hadm_id", timestampcol, itemtypecol, valuecol])
    else:
        d = pd.read_csv(z.open(filename), compression="gzip", nrows=nrows, usecols=["hadm_id", timestampcol, itemtypecol])
        d[valuecol] = 1
    d[timestampcol] = pd.to_datetime(d[timestampcol], errors="coerce")
    d[itemtypecol] = data_type_name+"_"+ d[itemtypecol].astype(str)
    d.columns = ["hadm_id", "timestamp", "itemtype", "value"]
    d = d.set_index("itemtype").dropna()
    d = d.sort_index()
    d["value"] = d["value"].astype("object")
    for item_type in tqdm(d.index.unique()):
        if d.loc[[item_type], "value"].shape[0] < min_count:
            d.loc[[item_type], "value"] = np.nan
        elif value_is_numeric:
            if d.loc[[item_type], "value"].unique().shape[0] == 1:
                d.loc[[item_type], "value"] = 0
            
            d.loc[[item_type], "value"] = pd.qcut(d.loc[[item_type], "value"], num_bins, duplicates="drop")
        
        replacement = {}
        for i, val in enumerate(pd.Series(d.loc[[item_type], "value"].unique()).sort_values()):
            replacement[val] = i
        
        d.loc[[item_type], "value"] = d.loc[[item_type], "value"].astype("category").cat.rename_categories(replacement)
                
    d = d.dropna().reset_index().set_index("hadm_id")
    return d

## Get Data From Time-Series Tables In MIMIC-IV

In [7]:
chart_data = get_info("chart", "mimic-iv-1.0/icu/chartevents.csv.gz", "charttime", "itemid", "valuenum", 
                value_is_numeric=True, min_count=min_record_count, num_bins=continuous_data_levels)
chart_data

  0%|          | 0/939 [00:00<?, ?it/s]

Unnamed: 0_level_0,itemtype,timestamp,value
hadm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
27477323,chart_220045,2139-04-19 17:00:00,2
24056235,chart_220045,2173-01-30 20:00:00,13
22606270,chart_220045,2158-04-28 12:15:00,20
27537527,chart_220045,2168-12-03 14:00:00,5
28579091,chart_220045,2179-06-10 14:00:00,29
...,...,...,...
22363461,chart_229882,2145-07-16 10:02:00,0
22363461,chart_229882,2145-07-16 20:14:00,0
22363461,chart_229882,2145-07-17 00:04:00,0
22363461,chart_229882,2145-07-14 20:00:00,0


In [8]:
procedure_data = get_info("procedure", "mimic-iv-1.0/icu/procedureevents.csv.gz", "starttime", "itemid", "value", 
                value_is_numeric=True, min_count=min_record_count, num_bins=continuous_data_levels)
procedure_data

  0%|          | 0/157 [00:00<?, ?it/s]

Unnamed: 0_level_0,itemtype,timestamp,value
hadm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
24731829,procedure_224263,2139-09-15 00:00:00,14
20048588,procedure_224263,2164-03-04 17:11:00,29
28358146,procedure_224263,2190-08-23 09:20:00,11
25884420,procedure_224263,2146-01-05 12:19:00,16
25095997,procedure_224263,2125-02-25 09:20:00,1
...,...,...,...
24943226,procedure_229526,2166-02-05 15:30:00,15
27900642,procedure_229526,2189-10-29 10:50:00,19
25410797,procedure_229526,2129-03-19 13:18:00,12
28773698,procedure_229526,2183-02-18 13:05:00,11


In [9]:
input_data_start = get_info("input", "mimic-iv-1.0/icu/inputevents.csv.gz", "starttime", "itemid", "amount", 
                value_is_numeric=True, min_count=min_record_count, num_bins=continuous_data_levels)
input_data_start

  0%|          | 0/325 [00:00<?, ?it/s]

Unnamed: 0_level_0,itemtype,timestamp,value
hadm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
21509458,input_220862,2141-02-11 15:15:00,1
28113079,input_220862,2122-03-29 18:00:00,1
21552777,input_220862,2129-08-22 20:00:00,1
24572136,input_220862,2178-05-11 23:03:00,1
24572136,input_220862,2178-05-10 02:00:00,1
...,...,...,...
25563870,input_229654,2176-05-26 19:26:00,31
26074264,input_229654,2146-05-06 10:55:00,16
24012717,input_229654,2121-03-10 03:02:00,5
22634029,input_229654,2129-02-18 03:02:00,14


In [10]:
input_data_stop = get_info("input", "mimic-iv-1.0/icu/inputevents.csv.gz", "endtime", "itemid", "amount", 
                value_is_numeric=True, min_count=min_record_count, num_bins=continuous_data_levels)
input_data_stop

  0%|          | 0/325 [00:00<?, ?it/s]

Unnamed: 0_level_0,itemtype,timestamp,value
hadm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
21509458,input_220862,2141-02-11 15:45:00,1
28113079,input_220862,2122-03-29 18:30:00,1
21552777,input_220862,2129-08-22 20:30:00,1
24572136,input_220862,2178-05-11 23:33:00,1
24572136,input_220862,2178-05-10 02:30:00,1
...,...,...,...
25563870,input_229654,2176-05-27 09:43:00,31
26074264,input_229654,2146-05-06 22:54:00,16
24012717,input_229654,2121-03-10 05:53:00,5
22634029,input_229654,2129-02-18 08:40:00,14


In [11]:
output_data = get_info("input", "mimic-iv-1.0/icu/outputevents.csv.gz", "charttime", "itemid", "value", 
                value_is_numeric=True, min_count=min_record_count, num_bins=continuous_data_levels)
output_data

  0%|          | 0/71 [00:00<?, ?it/s]

Unnamed: 0_level_0,itemtype,timestamp,value
hadm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
24954351,input_226559,2152-06-21 10:00:00,15
24954351,input_226559,2152-06-24 15:00:00,23
24954351,input_226559,2152-06-24 13:00:00,23
25535556,input_226559,2125-12-29 08:00:00,24
24954351,input_226559,2152-06-24 22:00:00,7
...,...,...,...
22217710,input_229413,2150-08-02 06:00:00,0
23936898,input_229413,2185-11-01 04:00:00,5
22217710,input_229413,2150-08-02 07:00:00,5
22217710,input_229413,2150-08-02 08:00:00,3


In [12]:
datetime_data = get_info("input", "mimic-iv-1.0/icu/datetimeevents.csv.gz", "charttime", "itemid", None, 
                value_is_numeric=False, min_count=min_record_count, num_bins=continuous_data_levels)
datetime_data

  0%|          | 0/170 [00:00<?, ?it/s]

Unnamed: 0_level_0,itemtype,timestamp,value
hadm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
26413041,input_224183,2159-07-31 18:04:00,0
20594201,input_224183,2196-03-11 04:18:00,0
23535512,input_224183,2121-03-22 03:57:00,0
21533475,input_224183,2140-10-09 01:00:00,0
27898558,input_224183,2138-04-23 20:39:00,0
...,...,...,...
26846190,input_229739,2151-05-11 20:02:00,0
23604702,input_229739,2160-01-30 15:35:00,0
28765467,input_229739,2191-10-16 09:34:00,0
26846190,input_229739,2151-05-11 13:41:00,0


## Filter Data to Only ICU Patients

In [13]:
procedure_data = procedure_data[procedure_data.index.isin(chart_data.index)]
input_data_start = input_data_start[input_data_start.index.isin(chart_data.index)]
input_data_stop = input_data_stop[input_data_stop.index.isin(chart_data.index)]
output_data = output_data[output_data.index.isin(chart_data.index)]
datetime_data = datetime_data[datetime_data.index.isin(chart_data.index)]

## Combine Tables

In [14]:
data = pd.concat([chart_data, procedure_data, input_data_start, input_data_stop, datetime_data, output_data])

## Count Each Record Type

In [15]:
counts_by_type = data.reset_index(drop=True)[["itemtype", "value"]].drop_duplicates().groupby("itemtype").count().reset_index()
counts_by_type["source"] = [x.split("_")[0] for x in counts_by_type.itemtype]
counts_by_type["id"] = [x.split("_")[1] for x in counts_by_type.itemtype]
counts_by_type = counts_by_type.drop("itemtype", axis=1).set_index(["source", "id"])

## Filter and Prepare Data

Remove patients with more than 2048 records in an encounter and define the max stay length as 2 weeks.

Normalize timestamps based on the max stay length.

Convert record type strings into int

In [16]:
#discard patients with extreme stay lengths
max_stay = "14 days"

counts = data.groupby("hadm_id").count()["value"]
data = data.loc[counts < 2048].copy()

#swap names for ints in a reversible way
le = LabelEncoder()
le.fit(data.itemtype)

#record original record names
key = pd.DataFrame({"source":[x.split("_")[0] for x in data.itemtype], "id":[x.split("_")[1] for x in data.itemtype], "label":le.transform(data.itemtype)}).drop_duplicates()

#convert record type string into int
data["itemtype"] = le.transform(data.itemtype)

#timestamps are reference to the first one and normalized within the max stay length
data["min_timestamp"] =  data.groupby("hadm_id").timestamp.min()
data["timestamp"] = data.timestamp - data.min_timestamp
data["timestamp"] = data.timestamp/pd.to_timedelta(max_stay)
data = data.drop("min_timestamp", axis=1)

#drop patients with stays longer than 14 days
data = data.drop(data[data.timestamp >= 1].index)

#sort values for convenience
data = data.reset_index().drop_duplicates().sort_values(["hadm_id", "timestamp"]).set_index("hadm_id")

#again remove records which show up less than the min record threshold after all the data processing
vc = data.itemtype.value_counts()
data = data[data.itemtype.isin(vc[vc > min_record_count].index)].copy()

## Get Tabular Demographic Information

In [17]:
patients = pd.read_csv(z.open("mimic-iv-1.0/core/patients.csv.gz"), compression="gzip").set_index("subject_id")
admissions = pd.read_csv(z.open("mimic-iv-1.0/core/admissions.csv.gz"), compression="gzip").set_index("subject_id")
patient_data = admissions.merge(patients, left_index=True, right_index=True).reset_index().set_index("hadm_id")
outcomes = patient_data[["hospital_expire_flag"]]
patient_data = patient_data[["admission_type", "admission_location", "insurance", "language", "marital_status", "ethnicity", "gender", "anchor_age"]]
patient_data["anchor_age"] = pd.qcut(patient_data["anchor_age"], continuous_data_levels, duplicates="drop")
patient_data

Unnamed: 0_level_0,admission_type,admission_location,insurance,language,marital_status,ethnicity,gender,anchor_age
hadm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
21038362,ELECTIVE,,Other,ENGLISH,SINGLE,UNKNOWN,F,"(-0.001, 19.0]"
24941086,ELECTIVE,,Other,ENGLISH,,WHITE,F,"(-0.001, 19.0]"
21965160,ELECTIVE,,Other,ENGLISH,,UNKNOWN,M,"(-0.001, 19.0]"
24709883,ELECTIVE,,Other,ENGLISH,,OTHER,F,"(-0.001, 19.0]"
23272159,ELECTIVE,,Other,ENGLISH,,BLACK/AFRICAN AMERICAN,M,"(-0.001, 19.0]"
...,...,...,...,...,...,...,...,...
20786062,SURGICAL SAME DAY ADMISSION,PHYSICIAN REFERRAL,Medicare,ENGLISH,SINGLE,WHITE,M,"(30.0, 33.0]"
20943099,EW EMER.,TRANSFER FROM HOSPITAL,Other,ENGLISH,DIVORCED,HISPANIC/LATINO,F,"(44.0, 46.0]"
23176714,SURGICAL SAME DAY ADMISSION,PHYSICIAN REFERRAL,Other,ENGLISH,MARRIED,WHITE,M,"(74.0, 76.0]"
22347500,SURGICAL SAME DAY ADMISSION,PHYSICIAN REFERRAL,Other,ENGLISH,MARRIED,WHITE,F,"(46.0, 49.0]"


## Hold Out Patient Mortality Separately as It Happens Later

In [18]:
outcomes

Unnamed: 0_level_0,hospital_expire_flag
hadm_id,Unnamed: 1_level_1
21038362,0
24941086,0
21965160,0
24709883,0
23272159,0
...,...
20786062,0
20943099,0
23176714,0
22347500,0


## Convert Tabular Demographic Variable Strings Into Ints

In [19]:
na_mask = patient_data.isna()

for col in patient_data:
    le = LabelEncoder()
    le.fit(patient_data[col])
    patient_data[col] = le.transform(patient_data[col])

patient_data[na_mask] = np.nan
patient_data

Unnamed: 0_level_0,admission_type,admission_location,insurance,language,marital_status,ethnicity,gender,anchor_age
hadm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
21038362,3,,2,1,2.0,6,0,0
24941086,3,,2,1,,7,0,0
21965160,3,,2,1,,6,1,0
24709883,3,,2,1,,4,0,0
23272159,3,,2,1,,2,1,0
...,...,...,...,...,...,...,...,...
20786062,7,6.0,1,1,2.0,7,1,4
20943099,5,8.0,2,1,0.0,3,0,8
23176714,7,6.0,2,1,1.0,7,1,23
22347500,7,6.0,2,1,1.0,7,0,9


## Get Patient Diagnoses

These do not have a timestamp and need to be handled differently that time-series data. Patients will have between zero and some maximum number of them per-encounter.

In [20]:
icd_codes = pd.read_csv(z.open('mimic-iv-1.0/hosp/diagnoses_icd.csv.gz'), compression="gzip").set_index(["icd_code", "icd_version"])
icd_info = pd.read_csv(z.open('mimic-iv-1.0/hosp/d_icd_diagnoses.csv.gz'), compression="gzip").set_index(["icd_code", "icd_version"])
icd_codes["long_title"] = icd_info.long_title

## Save Patient Diagnoses

In [24]:
icd_codes.reset_index()[["hadm_id","icd_code", "icd_version", "long_title"]].to_csv("mimic_patient_diagnoses.csv.gz", index=False)

## Limit ICD Codes to Ones Which Are Not Rare

Discard codes which are barely ever seen

In [27]:
icd_codes = icd_codes[icd_codes.long_title.isin(icd_codes.long_title.value_counts()[:max_diagnoses].index)].copy().reset_index().set_index("hadm_id")
icd_codes

Unnamed: 0_level_0,icd_code,icd_version,subject_id,seq_num,long_title
hadm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
20475282,V270,9,15734973,5,"Outcome of delivery, single liveborn"
20475282,64891,9,15734973,1,Other current conditions classifiable elsewher...
21518990,V270,9,11442057,6,"Outcome of delivery, single liveborn"
20817034,7840,9,10072949,4,Headache
20817034,V270,9,10072949,5,"Outcome of delivery, single liveborn"
...,...,...,...,...,...
25594844,Z8673,10,13747041,11,Personal history of transient ischemic attack ...
25594844,N189,10,13747041,10,"Chronic kidney disease, unspecified"
25594844,R0902,10,13747041,8,Hypoxemia
25594844,J189,10,13747041,2,"Pneumonia, unspecified organism"


## Transform ICD Type Identifiers Into Int Labels

In [28]:
le = LabelEncoder()
le.fit(icd_codes["long_title"])
icd_codes["long_title"] = le.transform(icd_codes["long_title"])
icd_codes = icd_codes[icd_codes.seq_num < 32].copy()
icd_codes

Unnamed: 0_level_0,icd_code,icd_version,subject_id,seq_num,long_title
hadm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
20475282,V270,9,15734973,5,177
20475282,64891,9,15734973,1,166
21518990,V270,9,11442057,6,177
20817034,7840,9,10072949,4,96
20817034,V270,9,10072949,5,177
...,...,...,...,...,...
25594844,Z8673,10,13747041,11,195
25594844,N189,10,13747041,10,49
25594844,R0902,10,13747041,8,119
25594844,J189,10,13747041,2,199


## Convert to Type, Position, Value format

In [29]:
icd_codes["itemtype"] = icd_codes["long_title"]
icd_codes["timestamp"] = 0
icd_codes["value"] = 0
icd_codes = icd_codes[["itemtype", "timestamp", "value"]].copy()
icd_codes

Unnamed: 0_level_0,itemtype,timestamp,value
hadm_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
20475282,177,0,0
20475282,166,0,0
21518990,177,0,0
20817034,96,0,0
20817034,177,0,0
...,...,...,...
25594844,195,0,0
25594844,49,0,0
25594844,119,0,0
25594844,199,0,0


## Put The Final Dataset Together

Output is:
* A list of patient IDs
* Time-Series Records: 3 2D matrices corresponding to Type, Position, and Value. Each row encodes the time series information for 1 patient encounter.
* ICD Records: 3 2D matrices corresponding to Type, Position, and Value. Each row encodes the diagnosis (ICD codes) information for 1 patient encounter. In MIMIC-IV diagnoses do not have timestamps so all positions are 0. Value is determined by presence of an ICD code. Every code that is present for patient has a value of 1. Dummy codes which the patient does not have are randomly sampled from the dataset and given a value of 0.
* Demographics: 1 2D matrix containing the tabular patient demographics in a normal table.
* A list of patient outcomes (death vs survived).

In [30]:
def make_datasets(data, diagnoses, patient_data, patient_outcomes, max_diagnoses=32, max_len=2048):
    data = data.sort_index()
    patient_data = patient_data.sort_index()
    diagnoses = diagnoses.sort_index()

    patient_ids = []
    timestamps = []
    itemtypes = []
    values = []

    timestamps_diag = []
    itemtypes_diag = []
    values_diag = []
    
    patient_infos = []
    outcomes = []
    all_diagnoses = pd.Series(diagnoses.itemtype.unique())
    for h in tqdm(data.index.unique()):

        d = data.loc[h]
        ts, it, va = np.full((1,max_len), np.nan), np.full((1,max_len), np.nan), np.full((1,max_len), np.nan)
        ts[0,0:len(d)] = d.timestamp*max_len
        it[0,0:len(d)] = d.itemtype
        va[0,0:len(d)] = d.value
        timestamps.append(ts)
        itemtypes.append(it)
        values.append(va)
        
        
        tsd, itd, vad = np.full((1,max_diagnoses*2), np.nan), np.full((1,max_diagnoses*2), np.nan), np.full((1,max_diagnoses*2), np.nan)
        if h in diagnoses.index:
            diag = diagnoses.loc[[h]]
            non_diagnoses = all_diagnoses[~all_diagnoses.isin(diag.itemtype)].sample(n=len(diag)).unique()
            itd[0,0:len(diag)] = diag.itemtype
            vad[0,0:len(diag)] = 1
            itd[0,max_diagnoses:max_diagnoses+len(non_diagnoses)] = non_diagnoses
            vad[0,max_diagnoses:max_diagnoses+len(non_diagnoses)] = 0
        timestamps_diag.append(tsd)
        itemtypes_diag.append(itd)
        values_diag.append(vad)
        
        patient_infos.append(np.array(patient_data.loc[h]).reshape(1,-1))
        outcomes.append(np.array(patient_outcomes.loc[h]).reshape(1,-1))

        patient_ids.append(np.array(h).reshape(1,-1))

    return np.concatenate(patient_ids), \
        np.concatenate(timestamps), np.concatenate(itemtypes),  np.concatenate(values), \
        np.concatenate(timestamps_diag), np.concatenate(itemtypes_diag),  np.concatenate(values_diag), \
        np.concatenate(patient_infos), \
        np.concatenate(outcomes)

## Make the Dataset

In [31]:
id, ts, it, va, tsd, itd, vad, pi, o = make_datasets(data, icd_codes, patient_data, outcomes)
key["adjusted_label"] = key.label

key = key.set_index(["source", "id"])
key["num_levels"] = counts_by_type.value

  0%|          | 0/50474 [00:00<?, ?it/s]

## Save The Original Labels

In [32]:
key.to_csv("label_key.csv")

## Save the Dataset In a Compact Format

In [33]:
np.savez_compressed("medical_data.npz", 
                    patient_ids=id,
                    timestamps=ts, 
                    itemtypes=it, 
                    values=va,
                    timestamps_diagnoses=tsd,
                    itemtypes_diagnoses=itd,
                    values_diagnoses=vad,
                    patient_info=pi,
                    patient_outcomes=o)