[PyHealth](https://pyhealth.readthedocs.io/en/latest/) is a Python library designed for healthcare data analysis, providing a
range of tools and functionalities that simplify the process of building, training, and
evaluating healthcare models.

- Here, we use the PyHealth package to tackle the <u>readmission prediction</u> task on the [MIMIC-IV](https://physionet.org/content/mimiciv/0.4/) dataset. The package provides a whole suite of API modules covering MIMIC-IV data processing, DNN model initialization, training, and evaluation.

- The readmission task is defined as <u>using the procedure, medication, and diagnosis codes from the current visit to predict whether the patient will be readmitted to the hospital within 7 days</u>.

Let us start step by step.

## STEP 1: dataset processing

The first step is to preprocess the MIMIC-IV data. MIMIC-IV, like other EHR databases, has a heterogeneous and complicated file structure: each type of clinical event is stored in its own CSV table, and the tables are linked by unique patient and visit IDs.

Before feeding the dataset into deep learning models, researchers usually join these tables, filter out invalid records, and so on. Here, we use the diagnosis, procedure, and medication tables as feature sources for the readmission prediction task. Luckily, PyHealth provides the [pyhealth.datasets.MIMIC4Dataset](https://pyhealth.readthedocs.io/en/latest/api/datasets/pyhealth.datasets.MIMIC4Dataset.html) API to handle this data processing step.
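To make "join the tables and filter invalid records" concrete, here is a minimal, standard-library-only sketch of what that preprocessing amounts to. The column names (subject_id, hadm_id, icd_code, ndc) follow the MIMIC-IV hosp tables, but the rows are made-up toy data, and the helpers group_codes_by_visit and join_tables are hypothetical. This is an illustration of the idea, not how MIMIC4Dataset is implemented internally.

```python
import csv
import io
from collections import defaultdict

# Toy stand-ins for three MIMIC-IV hosp tables (made-up rows, MIMIC-IV column names).
DIAGNOSES = """subject_id,hadm_id,icd_code
1,100,I10
1,100,E785
1,101,4019
2,200,Z87891
"""
PROCEDURES = """subject_id,hadm_id,icd_code
1,100,02HV33Z
2,200,3893
"""
PRESCRIPTIONS = """subject_id,hadm_id,ndc
1,100,63323026201
1,101,00904224461
2,200,0
"""


def group_codes_by_visit(table_text, code_column):
    """Hypothetical helper: map (subject_id, hadm_id) -> list of codes from one table."""
    grouped = defaultdict(list)
    for row in csv.DictReader(io.StringIO(table_text)):
        grouped[(row["subject_id"], row["hadm_id"])].append(row[code_column])
    return grouped


def join_tables():
    """Join the three tables on the visit key and drop visits missing any table."""
    conditions = group_codes_by_visit(DIAGNOSES, "icd_code")
    procedures = group_codes_by_visit(PROCEDURES, "icd_code")
    drugs = group_codes_by_visit(PRESCRIPTIONS, "ndc")
    # Keep only visits that appear in all three tables (a simple validity filter).
    shared_visits = set(conditions) & set(procedures) & set(drugs)
    return {
        visit: {
            "conditions": conditions[visit],
            "procedures": procedures[visit],
            "drugs": drugs[visit],
        }
        for visit in sorted(shared_visits)
    }


if __name__ == "__main__":
    for visit, codes in join_tables().items():
        print(visit, codes)
```

Visit ("1", "101") is dropped because it has no procedure rows, mirroring the kind of filtering a real pipeline applies before modeling.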

Calling the MIMIC4Dataset API requires the root directory of the MIMIC-IV database and the
data tables we want to use as features (diagnoses, procedures, and prescriptions).
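Because the dataset is built by reading CSV files under root, a wrong path or a missing table only fails once loading starts. A small preflight check can catch this earlier. This sketch assumes, as in the PyHealth 1.x releases this tutorial appears to use, that MIMIC4Dataset reads uncompressed patients.csv and admissions.csv plus one `<table>.csv` per entry in tables directly under root; missing_mimic4_files is a hypothetical helper, not part of PyHealth.

```python
import os
import tempfile

# Assumption: base files MIMIC4Dataset reads in addition to the requested tables.
BASE_FILES = ["patients.csv", "admissions.csv"]


def missing_mimic4_files(root, tables):
    """Hypothetical helper: list the expected CSV files that are absent from root."""
    expected = BASE_FILES + [f"{table}.csv" for table in tables]
    return [name for name in expected if not os.path.isfile(os.path.join(root, name))]


if __name__ == "__main__":
    tables = ["diagnoses_icd", "procedures_icd", "prescriptions"]
    with tempfile.TemporaryDirectory() as root:
        # Simulate a partially prepared hosp directory.
        for name in ["patients.csv", "diagnoses_icd.csv"]:
            open(os.path.join(root, name), "w").close()
        print(missing_mimic4_files(root, tables))
        # ['admissions.csv', 'procedures_icd.csv', 'prescriptions.csv']
```

PhysioNet distributes MIMIC-IV as .csv.gz files, so a check like this also flags tables that have not been decompressed yet, assuming the loader expects plain .csv as stated above.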

In [2]:
from pyhealth.datasets import MIMIC4Dataset

mimic4_ds = MIMIC4Dataset(
    root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp",
    tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
)

mimic4_ds.stat()


In [4]:
# query the patient level information

patients = mimic4_ds.patients
patient_id, patient_obj = list(patients.items())[2]
patient_id, patient_obj


In [3]:
# query the visit level information

visits = patient_obj.visits
visit_id, visit_obj = list(visits.items())[0] # the 0-th visit
visit_id, visit_obj


In [28]:
# query the event level information

diagnoses = visit_obj.get_event_list("diagnoses_icd")
diagnosis = diagnoses[0] # the 0-th diagnosis
diagnosis

Event with ICD10CM code G3183 from table diagnoses_icd

In [29]:
# query the event level information

prescriptions = visit_obj.get_event_list("prescriptions")
prescription = prescriptions[0] # the 0-th prescription
prescription

Event with NDC code 0 from table prescriptions
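The queries above walk a three-level hierarchy: mimic4_ds.patients maps patient IDs to patient objects, each patient's visits maps visit IDs to visit objects, and get_event_list returns the events recorded in one table during that visit. The toy model below mirrors that nesting with dataclasses so it can be explored without MIMIC-IV access. The class and field names are simplified, hypothetical stand-ins, not PyHealth's actual Patient, Visit, and Event classes.

```python
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List


@dataclass
class ToyEvent:
    code: str
    table: str
    vocabulary: str


@dataclass
class ToyVisit:
    visit_id: str
    encounter_time: datetime
    events: Dict[str, List[ToyEvent]] = field(default_factory=dict)

    def get_event_list(self, table: str) -> List[ToyEvent]:
        # Events recorded in one source table during this visit.
        return self.events.get(table, [])

    def get_code_list(self, table: str) -> List[str]:
        # Just the codes, which is what the task function in STEP 2 feeds the model.
        return [event.code for event in self.get_event_list(table)]


@dataclass
class ToyPatient:
    patient_id: str
    visits: Dict[str, ToyVisit] = field(default_factory=dict)


if __name__ == "__main__":
    visit = ToyVisit("v1", datetime(2150, 1, 1))
    visit.events["diagnoses_icd"] = [ToyEvent("G3183", "diagnoses_icd", "ICD10CM")]
    patients = {"p1": ToyPatient("p1", {"v1": visit})}

    first_visit = list(patients["p1"].visits.values())[0]
    print(first_visit.get_event_list("diagnoses_icd")[0])
    print(first_visit.get_code_list("diagnoses_icd"))  # ['G3183']
```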

## STEP 2: task definition

To define the healthcare AI task, the [pyhealth.tasks](https://pyhealth.readthedocs.io/en/latest/api/tasks.html) module provides a variety of example task functions, each of which defines a distinct healthcare AI task.

- As the readmission task example below illustrates, every task function takes a pyhealth.data.Patient object as input and returns a list of samples containing the “X” and “Y” for supervised learning. For a patient-level prediction task, the list contains a single sample; for a visit-level prediction task, it contains at most one sample per visit (the readmission function below, for instance, skips the last visit and any visit missing a code type).

Readmission prediction is a visit-level clinical prediction task. From the structured dataset above, we extract the diagnosis, procedure, and prescription information as features and compare the time gap between the start of the current visit and the start of the next visit to decide whether the patient is readmitted to the hospital within 7 days. If so, the label is 1; otherwise, it is 0. This makes it a binary classification task.

Luckily, PyHealth provides a default readmission prediction task that takes the time window as an argument (here, the threshold is 7 days).
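The labeling rule can be checked on a toy timeline using only the standard library. This sketch reproduces the time comparison used in the task function below: the gap between the encounter times of consecutive visits is truncated to whole days by timedelta.days and counts as a readmission when it is strictly less than time_window, and the last visit gets no label. It deliberately ignores the code-based filtering; the function name and dates are illustrative.

```python
from datetime import datetime
from typing import List


def readmission_labels(encounter_times: List[datetime], time_window: int = 7) -> List[int]:
    """Label every visit except the last: 1 if the next visit starts < time_window days later."""
    times = sorted(encounter_times)  # process visits in chronological order
    labels = []
    for current, nxt in zip(times, times[1:]):
        gap_days = (nxt - current).days  # whole days; partial days are truncated
        labels.append(1 if gap_days < time_window else 0)
    return labels


if __name__ == "__main__":
    visits = [
        datetime(2150, 1, 1, 8, 0),    # next visit 6 days 12 hours later: label 1
        datetime(2150, 1, 7, 20, 0),   # next visit exactly 7 days later: label 0
        datetime(2150, 1, 14, 20, 0),  # next visit 45+ days later: label 0
        datetime(2150, 3, 1),          # last visit: no next visit, so no label
    ]
    print(readmission_labels(visits))  # [1, 0, 0]
```

Note the boundary: because the comparison is strict and partial days are truncated, a gap of 6 days and 23 hours is a readmission, while exactly 7 days is not.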

In [31]:
def readmission_prediction_mimic4_fn(patient, time_window=7):
    """Processes a single patient for the readmission prediction task.
    ...
    """
    samples = []

    # we will drop the last visit since we cannot tell its label
    for i in range(len(patient) - 1):
        visit = patient[i]
        next_visit = patient[i + 1]

        # get time difference between current visit and next visit
        time_diff = (next_visit.encounter_time - visit.encounter_time).days
        readmission_label = 1 if time_diff < time_window else 0

        conditions = visit.get_code_list(table="diagnoses_icd")
        procedures = visit.get_code_list(table="procedures_icd")
        drugs = visit.get_code_list(table="prescriptions")
        # exclude: visits without condition, procedure, or drug code
        if len(conditions) * len(procedures) * len(drugs) == 0:
            continue
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "label": readmission_label,
            }
        )

        # use patient or visit level information for cohort selection
        # ...
    return samples

In [32]:
# STEP 2: define the readmission prediction task
readmission_dataset = mimic4_ds.set_task(
    lambda x: readmission_prediction_mimic4_fn(x, time_window=7)
)
readmission_dataset.stat()

Generating samples for <lambda>: 100%|█████████████████████████████████████████████████████████████| 190279/190279 [00:06<00:00, 30338.08it/s]


Statistics of sample dataset:
	- Dataset: MIMIC4Dataset
	- Task: <lambda>
	- Number of samples: 132301
	- Number of patients: 59273
	- Number of visits: 132301
	- Number of visits per patient: 2.2321
	- conditions:
		- Number of conditions per sample: 13.5967
		- Number of unique conditions: 18858
		- Distribution of conditions (Top-10): [('4019', 34015), ('2724', 25635), ('53081', 18716), ('E785', 15686), ('25000', 15530), ('41401', 14924), ('4280', 14814), ('I10', 14065), ('42731', 13960), ('Z87891', 13776)]
	- procedures:
		- Number of procedures per sample: 2.7078
		- Number of unique procedures: 10220
		- Distribution of procedures (Top-10): [('3893', 8320), ('3897', 6165), ('3995', 6035), ('02HV33Z', 5966), ('8856', 5171), ('0040', 4588), ('966', 4479), ('9925', 4310), ('4513', 3395), ('5491', 3313)]
	- drugs:
		- Number of drugs per sample: 29.6466
		- Number of unique drugs: 5458
		- Distribution of drugs (Top-10): [('0', 127393), ('63323026201', 79098), ('00904224461', 70124),

"Statistics of sample dataset:\n\t- Dataset: MIMIC4Dataset\n\t- Task: <lambda>\n\t- Number of samples: 132301\n\t- Number of patients: 59273\n\t- Number of visits: 132301\n\t- Number of visits per patient: 2.2321\n\t- conditions:\n\t\t- Number of conditions per sample: 13.5967\n\t\t- Number of unique conditions: 18858\n\t\t- Distribution of conditions (Top-10): [('4019', 34015), ('2724', 25635), ('53081', 18716), ('E785', 15686), ('25000', 15530), ('41401', 14924), ('4280', 14814), ('I10', 14065), ('42731', 13960), ('Z87891', 13776)]\n\t- procedures:\n\t\t- Number of procedures per sample: 2.7078\n\t\t- Number of unique procedures: 10220\n\t\t- Distribution of procedures (Top-10): [('3893', 8320), ('3897', 6165), ('3995', 6035), ('02HV33Z', 5966), ('8856', 5171), ('0040', 4588), ('966', 4479), ('9925', 4310), ('4513', 3395), ('5491', 3313)]\n\t- drugs:\n\t\t- Number of drugs per sample: 29.6466\n\t\t- Number of unique drugs: 5458\n\t\t- Distribution of drugs (Top-10): [('0', 127393), (

In [5]:
# look at the first sample

readmission_dataset[0]

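The statistics above report feature counts but not the label distribution, which matters when reading the metrics later: for instance, a classifier with random scores has an expected PR-AUC equal to the positive rate. Assuming each sample is a dict with a "label" key, as produced by readmission_prediction_mimic4_fn, a quick count looks like this. The helper label_summary is hypothetical, and the toy samples stand in for the real ones; in the notebook one could pass something like [readmission_dataset[i] for i in range(len(readmission_dataset))], assuming the sample dataset supports len() and integer indexing as the cell above suggests.

```python
from collections import Counter
from typing import Dict, Iterable


def label_summary(samples: Iterable[Dict]) -> Dict[str, float]:
    """Hypothetical helper: count labels and report the positive (readmitted) rate."""
    counts = Counter(sample["label"] for sample in samples)
    total = sum(counts.values())
    return {
        "negatives": counts.get(0, 0),
        "positives": counts.get(1, 0),
        "positive_rate": counts.get(1, 0) / total if total else 0.0,
    }


if __name__ == "__main__":
    toy_samples = [{"label": 1}, {"label": 0}, {"label": 0}, {"label": 1}, {"label": 0}]
    print(label_summary(toy_samples))
    # {'negatives': 3, 'positives': 2, 'positive_rate': 0.4}
```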

In [40]:
from pyhealth.datasets import split_by_patient, get_dataloader

# split the dataset into train/val/test
train_dataset, val_dataset, test_dataset = split_by_patient(
    readmission_dataset, [0.8, 0.1, 0.1]
)
train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
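split_by_patient assigns whole patients, rather than individual visits, to the train/val/test splits, so visits from the same patient cannot leak between training and evaluation. The idea can be sketched in plain Python: shuffle the unique patient IDs, cut them by the given ratios, and route each sample to the split that owns its patient. This illustrates the technique and is not PyHealth's implementation; the function name split_samples_by_patient and its seed handling are assumptions.

```python
import random
from typing import Dict, List, Sequence, Tuple


def split_samples_by_patient(
    samples: List[Dict], ratios: Sequence[float], seed: int = 0
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Split samples into train/val/test so that each patient lands in exactly one split."""
    if abs(sum(ratios) - 1.0) > 1e-6:
        raise ValueError("ratios must sum to 1")
    patient_ids = sorted({sample["patient_id"] for sample in samples})
    random.Random(seed).shuffle(patient_ids)  # reproducible shuffle of patients, not visits

    n_train = int(len(patient_ids) * ratios[0])
    n_val = int(len(patient_ids) * ratios[1])
    train_ids = set(patient_ids[:n_train])
    val_ids = set(patient_ids[n_train:n_train + n_val])

    train, val, test = [], [], []
    for sample in samples:
        pid = sample["patient_id"]
        if pid in train_ids:
            train.append(sample)
        elif pid in val_ids:
            val.append(sample)
        else:
            test.append(sample)  # all remaining patients go to the test split
    return train, val, test


if __name__ == "__main__":
    toy = [{"patient_id": f"p{i // 2}", "visit_id": f"v{i}"} for i in range(20)]
    train, val, test = split_samples_by_patient(toy, [0.8, 0.1, 0.1])
    print(len(train), len(val), len(test))  # 16 2 2
```

Because the ratios apply to patients, the sample counts per split only approximate 80/10/10 when patients have different numbers of visits.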

## STEP 3: initialize ML models

The [pyhealth.models](https://pyhealth.readthedocs.io/en/latest/api/models.html) module provides a large variety of ML models to choose from. These methods fall broadly into two categories: (i) general deep learning
models and (ii) healthcare-specific deep learning models. We continually implement new healthcare-specific models and add them to the package.

In this section, we use the simple three-layer multi-layer perceptron (MLP) in
PyHealth, specifying the number of layers, the number of hidden neurons in each layer,
and the activation function. More advanced models and examples can be found in
later sections or on the [PyHealth website](https://pyhealth.readthedocs.io/en/latest/).

During model initialization, we pass in keys from the sample dictionary. To leverage all
three types of clinical event information, we use “conditions”, “procedures”, and “drugs”
as the feature keys, and “label” as the label key for readmission prediction.

In [41]:
from pyhealth.models import MLP

# STEP 3: define model
model = MLP(
    dataset=readmission_dataset,
    feature_keys=["conditions", "procedures", "drugs"],
    label_key="label",
    mode="binary",
    embedding_dim=128,
    hidden_dim=128,
    n_layers=3,
    activation="relu",
)
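The model summary printed during training shows Embedding(18860, 128, padding_idx=0) for conditions, two rows more than the 18,858 unique conditions reported by stat(); the same +2 holds for procedures (10,222 vs. 10,220) and drugs (5,460 vs. 5,458). The most likely reading is that each feature's tokenizer reserves index 0 for a padding token and index 1 for an unknown-code token. The sketch below builds such a vocabulary in plain Python and pads a batch of variable-length code lists; the token names and both helpers are assumptions, not PyHealth's tokenizer API.

```python
from typing import Dict, Iterable, List

PAD, UNK = "<pad>", "<unk>"  # assumed special tokens occupying indices 0 and 1


def build_vocab(code_lists: Iterable[List[str]]) -> Dict[str, int]:
    """Map every unique code to an index, after reserving 0 for padding and 1 for unknown."""
    vocab = {PAD: 0, UNK: 1}
    for codes in code_lists:
        for code in codes:
            if code not in vocab:
                vocab[code] = len(vocab)
    return vocab


def encode_batch(batch: List[List[str]], vocab: Dict[str, int]) -> List[List[int]]:
    """Convert codes to indices and right-pad every row to the longest row with index 0."""
    max_len = max(len(codes) for codes in batch)
    encoded = []
    for codes in batch:
        ids = [vocab.get(code, vocab[UNK]) for code in codes]  # unseen codes map to <unk>
        encoded.append(ids + [vocab[PAD]] * (max_len - len(ids)))
    return encoded


if __name__ == "__main__":
    train_conditions = [["4019", "2724"], ["E785"], ["4019", "I10", "53081"]]
    vocab = build_vocab(train_conditions)
    print(len(vocab))  # 7 = 5 unique codes + 2 special tokens
    print(encode_batch([["4019"], ["I10", "NEWCODE"]], vocab))  # [[2, 0], [5, 1]]
```

An embedding table sized len(vocab) then has exactly unique codes + 2 rows, and with padding_idx=0 PyTorch keeps the padding row at zero and excludes it from gradient updates.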

## STEP 4: model training

In PyHealth, the [pyhealth.trainer.Trainer](https://pyhealth.readthedocs.io/en/latest/api/trainer.html) module manages the training process for all models.

The trainer simplifies logging during training, automatically selects an available device, and saves the best model
to disk. It also offers a range of arguments for customization.

Let’s proceed with the programming steps outlined below.

In [43]:
from pyhealth.trainer import Trainer

# STEP 4: define trainer
trainer = Trainer(model=model)

trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=5,
    monitor="roc_auc",
) # model is training ...


MLP(
  (embeddings): ModuleDict(
    (conditions): Embedding(18860, 128, padding_idx=0)
    (procedures): Embedding(10222, 128, padding_idx=0)
    (drugs): Embedding(5460, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (activation): ReLU()
  (mlp): ModuleDict(
    (conditions): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=128, bias=True)
    )
    (procedures): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=128, bias=True)
    )
    (drugs): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): 

Epoch 0 / 5:   0%|          | 0/3305 [00:00<?, ?it/s]

--- Train epoch-0, step-3305 ---
loss: 0.5563


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:00<00:00, 472.71it/s]

--- Eval epoch-0, step-3305 ---
pr_auc: 0.7000
roc_auc: 0.7106
f1: 0.6522
loss: 0.6354
New best roc_auc score (0.7106) at epoch-0, step-3305








Epoch 1 / 5:   0%|          | 0/3305 [00:00<?, ?it/s]

--- Train epoch-1, step-6610 ---
loss: 0.5106


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:00<00:00, 479.10it/s]

--- Eval epoch-1, step-6610 ---
pr_auc: 0.6829
roc_auc: 0.6961
f1: 0.5838
loss: 0.6740






Epoch 2 / 5:   0%|          | 0/3305 [00:00<?, ?it/s]

--- Train epoch-2, step-9915 ---
loss: 0.4470


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:00<00:00, 480.47it/s]

--- Eval epoch-2, step-9915 ---
pr_auc: 0.6641
roc_auc: 0.6779
f1: 0.5818
loss: 0.7736






Epoch 3 / 5:   0%|          | 0/3305 [00:00<?, ?it/s]

--- Train epoch-3, step-13220 ---
loss: 0.3718


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:00<00:00, 435.57it/s]

--- Eval epoch-3, step-13220 ---
pr_auc: 0.6505
roc_auc: 0.6643
f1: 0.5828
loss: 0.8855






Epoch 4 / 5:   0%|          | 0/3305 [00:00<?, ?it/s]

--- Train epoch-4, step-16525 ---
loss: 0.2942


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:00<00:00, 469.68it/s]

--- Eval epoch-4, step-16525 ---
pr_auc: 0.6489
roc_auc: 0.6610
f1: 0.6060
loss: 1.1949
Loaded best model
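The log shows a classic overfitting pattern: training loss falls every epoch (0.5563 to 0.2942) while validation loss rises (0.6354 to 1.1949) and validation ROC-AUC peaks at epoch 0. Because monitor="roc_auc" was set, the trainer keeps the checkpoint with the best monitored score and reloads it at the end ("Loaded best model"). That selection rule amounts to an argmax over epochs; this sketch replays it on the validation values copied from the log above. It illustrates the idea and is not the Trainer's code; select_best_epoch is a hypothetical helper.

```python
from typing import List, Tuple

# Validation ROC-AUC per epoch, copied from the training log above.
VAL_ROC_AUC = [0.7106, 0.6961, 0.6779, 0.6643, 0.6610]


def select_best_epoch(scores: List[float], higher_is_better: bool = True) -> Tuple[int, float]:
    """Return (epoch, score) of the best monitored value; ties keep the earliest epoch."""
    best_epoch, best_score = 0, scores[0]
    for epoch, score in enumerate(scores[1:], start=1):
        improved = score > best_score if higher_is_better else score < best_score
        if improved:
            best_epoch, best_score = epoch, score
    return best_epoch, best_score


if __name__ == "__main__":
    print(select_best_epoch(VAL_ROC_AUC))  # (0, 0.7106)
    # A loss-like metric would be minimized instead:
    val_loss = [0.6354, 0.6740, 0.7736, 0.8855, 1.1949]
    print(select_best_epoch(val_loss, higher_is_better=False))  # (0, 0.6354)
```

Since the best checkpoint here comes from the first epoch, the remaining epochs mostly cost time; the test results in STEP 5 come from that reloaded epoch-0 model.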





## STEP 5: model evaluation

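Following model training, we offer two ways to evaluate the optimized deep learning model with [pyhealth.metrics](https://pyhealth.readthedocs.io/en/latest/api/metrics.html): calling trainer.evaluate on a dataloader, or calling trainer.inference and passing its outputs to a metric function such as binary_metrics_fn.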

PyHealth offers an array
of metrics suitable for assessing binary classification, multi-class classification, and
multi-label classification problems, alongside fairness metrics and advanced model
calibration metrics.

In [45]:
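# method 1: the trainer runs inference and computes its default binary metrics (plus loss)
result = trainer.evaluate(test_dataloader)
print(result)

# method 2: run inference ourselves, then choose the metrics explicitly
from pyhealth.metrics.binary import binary_metrics_fn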

y_true, y_prob, loss = trainer.inference(test_dataloader)
binary_metrics_fn(
    y_true,
    y_prob,
    metrics=["pr_auc", "roc_auc", "f1"]
)

Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 413/413 [00:00<00:00, 477.67it/s]


{'pr_auc': 0.6920530474148421, 'roc_auc': 0.7015410673568111, 'f1': 0.6474005660658425, 'loss': 0.6424725677267835}


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 413/413 [00:00<00:00, 477.06it/s]


{'pr_auc': 0.6920530474148421,
 'roc_auc': 0.7015410673568111,
 'f1': 0.6474005660658425}
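Both methods report identical test scores, as expected, since they score the same predictions. To make these numbers easier to interpret, here is a dependency-free sketch of two of them: ROC-AUC computed as the probability that a randomly chosen positive is scored higher than a randomly chosen negative (ties count half), and F1 after thresholding probabilities at 0.5. PyHealth's binary_metrics_fn most likely relies on scikit-learn for the actual computation (and on average precision for PR-AUC), and the >= tie convention at the threshold is an assumption, so treat this as an explanation of the metrics rather than a re-implementation of PyHealth. The pairwise loop is O(P·N) and only suitable for small inputs.

```python
from typing import Sequence


def roc_auc(y_true: Sequence[int], y_prob: Sequence[float]) -> float:
    """P(score of a random positive > score of a random negative); ties count as 0.5."""
    pos = [p for t, p in zip(y_true, y_prob) if t == 1]
    neg = [p for t, p in zip(y_true, y_prob) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC-AUC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


def f1_at_threshold(y_true: Sequence[int], y_prob: Sequence[float], threshold: float = 0.5) -> float:
    """F1 of the positive class after predicting 1 whenever y_prob >= threshold."""
    y_pred = [1 if p >= threshold else 0 for p in y_prob]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0


if __name__ == "__main__":
    y_true = [1, 0, 1, 0, 1, 0]
    y_prob = [0.9, 0.2, 0.6, 0.7, 0.4, 0.1]
    print(roc_auc(y_true, y_prob))          # 7 of 9 positive-negative pairs ranked correctly
    print(f1_at_threshold(y_true, y_prob))  # 2*2 / (2*2 + 1 + 1), about 0.667
```

Read this way, the test ROC-AUC of about 0.70 means the model ranks a readmitted visit above a non-readmitted one roughly 70% of the time.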