In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2

import logging

logging.basicConfig(level=logging.INFO)

import numpy as np
import pandas as pd
import tsdm
import polars as pl
import pyarrow as pa

# Arrow-backed duration dtypes, one per time resolution.
ARROW_DURATION_TYPES = {
    pd.ArrowDtype(pa.duration("s")),
    pd.ArrowDtype(pa.duration("ms")),
    pd.ArrowDtype(pa.duration("us")),
    pd.ArrowDtype(pa.duration("ns")),
}

# Arrow-backed (timezone-naive) timestamp dtypes, one per time resolution.
ARROW_TIMESTAMP_TYPES = {
    pd.ArrowDtype(pa.timestamp("s")),
    pd.ArrowDtype(pa.timestamp("ms")),
    pd.ArrowDtype(pa.timestamp("us")),
    pd.ArrowDtype(pa.timestamp("ns")),
}

# Arrow-backed calendar date dtypes (32-bit and 64-bit variants).
ARROW_DATE_TYPES = {
    pd.ArrowDtype(pa.date32()),
    pd.ArrowDtype(pa.date64()),
}


def map_dtypes(df):
    """Replace pyarrow temporal columns of ``df`` with numpy-backed ones.

    Columns whose dtype is one of the known arrow duration, timestamp or date
    dtypes are cast to ``timedelta64[ms]``, ``datetime64[ms]`` and
    ``datetime64[s]`` respectively; every other column is left untouched.

    The frame is modified in place and the very same object is returned.

    Rationale: pyarrow types are currently bugged and do not support all operations.
    """
    # (family of arrow dtypes, numpy dtype to cast to), checked in this order.
    conversions = (
        (ARROW_DURATION_TYPES, "timedelta64[ms]"),
        (ARROW_TIMESTAMP_TYPES, "datetime64[ms]"),
        (ARROW_DATE_TYPES, "datetime64[s]"),
    )

    for column, current_dtype in df.dtypes.items():
        for arrow_family, numpy_dtype in conversions:
            if current_dtype in arrow_family:
                df[column] = df[column].astype(numpy_dtype)
                break  # at most one conversion per column
    return df

In [None]:
# Instantiate the MIMIC-III dataset wrapper provided by tsdm.
# NOTE(review): whether construction already loads or downloads data is not
# visible from this notebook - confirm against the tsdm documentation.
ds = tsdm.datasets.MIMIC_III()

In [None]:
# map dtypes in all tables.
# Every table is passed through `map_dtypes` and stored back under the same
# key, so the set of keys in `ds.tables` does not change during iteration.
for name, table in ds.tables.items():
    ds.tables[name] = map_dtypes(table)

# Display the dataset object as the cell output.
ds

## Processing Metadata

In [None]:
# Preprocessing: join every admission with the record of its patient.
admissions = ds.ADMISSIONS
patients = ds.PATIENTS
metadata = admissions.merge(patients, on="SUBJECT_ID")
metadata = metadata.assign(
    ELAPSED_TIME=lambda df: df["DISCHTIME"] - df["ADMITTIME"]
)

# Keep only subjects that appear exactly once in the merged table.
counts = metadata["SUBJECT_ID"].value_counts()
unique_patients = counts.loc[lambda c: c == 1].index
metadata = metadata.loc[
    lambda df: df["SUBJECT_ID"].isin(unique_patients)
].reset_index(drop=True)

# Keep stays that lasted between 2 and 30 days (both bounds inclusive).
# NOTE: The GRU-ODE-Bayes code compared the whole-day count (`ELAPSED_TIME.dt.days`)
#  with `> 2`, which is incorrect: it only keeps stays of at least 72 hours.
metadata = metadata.loc[
    lambda df: (df["ELAPSED_TIME"] >= "2d") & (df["ELAPSED_TIME"] <= "30d")
]

# Keep patients aged between 15 and 100 years at admission,
# where a year is counted as 365 days.
YEAR = np.timedelta64(365, "D")
metadata = metadata.assign(AGE=lambda df: df["ADMITTIME"] - df["DOB"])
metadata = metadata.loc[
    lambda df: (df["AGE"] >= 15 * YEAR) & (df["AGE"] <= 100 * YEAR)
]

# Keep admissions flagged as having chartevents data.
metadata = metadata.loc[lambda df: df["HAS_CHARTEVENTS_DATA"]]

# Restrict the table to the columns used downstream.
metadata = metadata[
    [
        "SUBJECT_ID",
        "HADM_ID",
        "ADMITTIME",
        "DISCHTIME",
        "AGE",
        "ETHNICITY",
        "GENDER",
        "INSURANCE",
        "MARITAL_STATUS",
        "RELIGION",
    ]
]

## Processing Inputevents (MetaVision system only)

In [None]:
# Fetch the INPUTEVENTS_MV table from the dataset.
inputevents = ds.INPUTEVENTS_MV

In [None]:
# Item dictionary used to attach item descriptions to each input event.
d_items = ds.D_ITEMS
# Restrict input events to the admissions that remain in `metadata`.
inputevents = inputevents.loc[
    lambda df: df["HADM_ID"].isin(metadata["HADM_ID"])
]
inputs = inputevents.merge(d_items, on="ITEMID")

In [None]:
# Discard rows whose ICUSTAY_ID is missing.
inputs = inputs.loc[lambda df: df["ICUSTAY_ID"].notna()]

In [None]:
# Names of the input items to retain.
# NOTE(review): this list is not referenced by the cells visible here;
# presumably it is matched against an item label column further down - confirm.
retained_list = [
    "Albumin 5%",
    "Calcium Gluconate",
    "D5 1/2NS",
    "Dextrose 5%",
    "Furosemide (Lasix)",
    "GT Flush",
    "Gastric Meds",
    "Heparin Sodium",
    "Hydralazine",
    "Insulin - Glargine",
    "Insulin - Humalog",
    "Insulin - Regular",
    "K Phos",
    "KCL (Bolus)",
    "LR",
    "Lorazepam (Ativan)",
    "Magnesium Sulfate (Bolus)",
    "Magnesium Sulfate",
    "Metoprolol",
    "Midazolam (Versed)",
    "Morphine Sulfate",
    "Nitroglycerin",
    "Norepinephrine",
    "OR Cell Saver Intake",
    "OR Crystalloid Intake",
    "PO Intake",
    "Packed Red Blood Cells",
    "Phenylephrine",
    "Piggyback",
    "Potassium Chloride",
    "Solution",
    "Sterile Water",
]

In [None]:
# Keep only the items that were recorded more than 5000 times.
item_counts = inputs.ITEMID.value_counts()
top_items = item_counts.loc[lambda c: c > 5000].index
inputs = inputs.loc[lambda df: df["ITEMID"].isin(top_items)]