[PyHealth](https://pyhealth.readthedocs.io/en/latest/) is a Python library designed for healthcare data analysis, providing a
range of tools and functionalities that simplify the process of building, training, and
evaluating healthcare models.

- Here, we will use the PyHealth package to handle the <u>readmission prediction</u> task on the [MIMIC-IV-demo](https://physionet.org/content/mimic-iv-demo/2.2/) dataset. The package provides a full suite of API
modules covering MIMIC-IV data processing, DNN model initialization, training,
and evaluation.

- The readmission task is defined as <u>using the procedure, medication, and diagnosis codes from the current visit to predict
whether the patient will be readmitted to the hospital within 7 days</u>.

Let us start step by step.

## STEP1: dataset processing

The first step is to preprocess the MIMIC-IV-demo data. Since we are unable to distribute the original MIMIC-IV dataset, we use the [MIMIC-IV-demo dataset](https://physionet.org/content/mimic-iv-demo/2.2/) as a substitute here.

MIMIC-IV and other EHR
databases have heterogeneous and complicated file structures. Each type of clinical
event is stored in its own CSV table, and the tables are linked by the unique patient
visit ID.

Before feeding the dataset into deep learning models, researchers usually join
the tables and filter out invalid records. Here, we use
the diagnosis, procedure, and medication tables as features for the readmission
prediction task. Luckily, PyHealth provides the [pyhealth.datasets.MIMIC4Dataset](https://pyhealth.readthedocs.io/en/latest/api/datasets/pyhealth.datasets.MIMIC4Dataset.html) API to handle the data processing step.
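To make the manual join-and-filter step concrete, here is a minimal sketch that uses only the standard library. The two in-memory CSV snippets and their values are made up for illustration; only the `hadm_id` (visit ID) join key and the `icd_code` column follow MIMIC-IV naming. This is not how `MIMIC4Dataset` is implemented internally.

```python
import csv
import io
from collections import defaultdict

# Toy stand-ins for two event tables (values are illustrative, not real records).
DIAGNOSES_CSV = """subject_id,hadm_id,icd_code
1,100,4019
1,100,2724
1,101,E785
2,200,
"""
PROCEDURES_CSV = """subject_id,hadm_id,icd_code
1,100,3897
2,200,5491
"""


def group_by_visit(csv_text, table_name, visits):
    """Add the codes of one table to the per-visit dictionary, skipping empty codes."""
    for row in csv.DictReader(io.StringIO(csv_text)):
        code = row["icd_code"].strip()
        if not code:  # filter out invalid (empty) records
            continue
        visits[row["hadm_id"]][table_name].append(code)


# visit ID -> table name -> list of codes
visits = defaultdict(lambda: defaultdict(list))
group_by_visit(DIAGNOSES_CSV, "diagnoses_icd", visits)
group_by_visit(PROCEDURES_CSV, "procedures_icd", visits)

for hadm_id, tables in sorted(visits.items()):
    print(hadm_id, dict(tables))
```

Every table is keyed by the same visit ID, so each visit ends up with one code list per table, which is the shape the later steps consume.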

Calling the MIMIC4Dataset API requires the root path of the MIMIC-IV-demo database and the
data tables that we want to use as features (diagnoses, procedures, and prescriptions).

In [2]:
from pyhealth.datasets import MIMIC4Dataset

mimic4_ds = MIMIC4Dataset(
    # if you have access to the full MIMIC-IV dataset, replace the root below with its location
    root="https://storage.googleapis.com/pyhealth/mimiciv-demo/hosp",
    tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
)

mimic4_ds.stat()



INFO: Pandarallel will run on 64 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
finish basic patient information parsing : 1.3815855979919434s
finish parsing diagnoses_icd : 1.0474581718444824s
finish parsing procedures_icd : 1.006455659866333s
finish parsing prescriptions : 1.826772928237915s


Mapping codes: 100%|███████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 5759.82it/s]


Statistics of base dataset (dev=False):
	- Dataset: MIMIC4Dataset
	- Number of patients: 100
	- Number of visits: 275
	- Number of visits per patient: 2.7500
	- Number of events per visit in diagnoses_icd: 16.3855
	- Number of events per visit in procedures_icd: 2.6255
	- Number of events per visit in prescriptions: 65.6945







In [3]:
# query the patient level information

patients = mimic4_ds.patients
patient_id, patient_obj = list(patients.items())[2]
patient_id, patient_obj

('10001725', Patient 10001725 with 1 visits)

In [4]:
# query the visit level information

visits = patient_obj.visits
visit_id, visit_obj = list(visits.items())[0] # the 0-th visit
visit_id, visit_obj

('25563031',
 Visit 25563031 from patient 10001725 with 91 events from tables ['diagnoses_icd', 'procedures_icd', 'prescriptions'])

In [5]:
# query the event level information

diagnoses = visit_obj.get_event_list("diagnoses_icd")
diagnosis = diagnoses[0] # the 0-th diagnosis
diagnosis

Event with ICD9CM code 78829 from table diagnoses_icd

In [6]:
# query the event level information

prescriptions = visit_obj.get_event_list("prescriptions")
prescription = prescriptions[0] # the 0-th prescription
prescription

Event with NDC code 68084008901 from table prescriptions
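The queries above walk a three-level hierarchy: the dataset maps patient IDs to patients, a patient maps visit IDs to visits, and a visit holds one event list per table. The sketch below mimics that shape with plain dataclasses. `ToyPatient`, `ToyVisit`, and `ToyEvent` are hypothetical stand-ins, not PyHealth classes; only the IDs and codes are copied from the outputs above.

```python
from dataclasses import dataclass, field


@dataclass
class ToyEvent:
    code: str
    vocabulary: str
    table: str


@dataclass
class ToyVisit:
    visit_id: str
    events: dict = field(default_factory=dict)  # table name -> list of ToyEvent

    def get_event_list(self, table):
        """Return the events recorded for one table, or an empty list."""
        return self.events.get(table, [])


@dataclass
class ToyPatient:
    patient_id: str
    visits: dict = field(default_factory=dict)  # visit_id -> ToyVisit


visit = ToyVisit("25563031")
visit.events["diagnoses_icd"] = [ToyEvent("78829", "ICD9CM", "diagnoses_icd")]
visit.events["prescriptions"] = [ToyEvent("68084008901", "NDC", "prescriptions")]
patient = ToyPatient("10001725", {visit.visit_id: visit})

# Same drill-down as in the cells above: patient -> visit -> event.
first_visit = list(patient.visits.values())[0]
first_diagnosis = first_visit.get_event_list("diagnoses_icd")[0]
print(first_diagnosis.vocabulary, first_diagnosis.code)
```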

## STEP2: task definition

To define the healthcare AI task, the [pyhealth.tasks](https://pyhealth.readthedocs.io/en/latest/api/tasks.html) module provides a variety of example task functions. Each of them defines a unique healthcare
AI task.

- Following the readmission task example (below), every task function takes
a pyhealth.data.Patient object as input and outputs a list of samples containing the "X"
and "Y" for supervised learning. If the task is a patient-level prediction task,
the output is one sample; if it is a visit-level prediction task, the output is a
list of samples, one for each visit.

Readmission prediction is a visit-level clinical predictive task. Given the above
structured dataset, we extract the diagnosis, procedure, and prescription information
as the features and compare the time gap between the current visit and the next
visit to decide whether the patient is readmitted to the hospital within 7 days. If
the answer is yes, we set the label to 1; otherwise, the label is 0. It is a binary
classification task.

Luckily, PyHealth provides a default readmission prediction task with the time window
as an argument (in this task, the threshold is 7 days).

In [7]:
def readmission_prediction_mimic4_fn(patient, time_window=7):
    """Processes a single patient for the readmission prediction task.
    ...
    """
    samples = []

    # we will drop the last visit since we cannot tell its label
    for i in range(len(patient) - 1):
        visit = patient[i]
        next_visit = patient[i + 1]

        # get time difference between current visit and next visit
        time_diff = (next_visit.encounter_time - visit.encounter_time).days
        readmission_label = 1 if time_diff < time_window else 0

        conditions = visit.get_code_list(table="diagnoses_icd")
        procedures = visit.get_code_list(table="procedures_icd")
        drugs = visit.get_code_list(table="prescriptions")
        # exclude: visits without condition, procedure, or drug code
        if len(conditions) * len(procedures) * len(drugs) == 0:
            continue
        samples.append(
            {
                "visit_id": visit.visit_id,
                "patient_id": patient.patient_id,
                "conditions": conditions,
                "procedures": procedures,
                "drugs": drugs,
                "label": readmission_label,
            }
        )

        # use patient or visit level information for cohort selection
        # ...
    return samples
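The labeling rule in the function above can be checked in isolation. The sketch below applies the same comparison (`.days` of the gap, strictly less than `time_window`) to hypothetical admission timestamps; `label_visits` is an illustrative helper, not part of PyHealth.

```python
from datetime import datetime


def label_visits(admit_times, time_window=7):
    """Return one 0/1 label per visit except the last, whose outcome is unknown."""
    labels = []
    for current, following in zip(admit_times, admit_times[1:]):
        gap_days = (following - current).days  # whole days; the fractional part is dropped
        labels.append(1 if gap_days < time_window else 0)
    return labels


# Hypothetical admission timestamps for one patient.
admits = [
    datetime(2020, 1, 1, 8, 0),
    datetime(2020, 1, 7, 20, 0),   # 6 days 12 hours later: .days == 6, label 1
    datetime(2020, 1, 14, 20, 0),  # exactly 7 days later: .days == 7, label 0
]
print(label_visits(admits))      # [1, 0]
print(label_visits(admits[:1]))  # []  (a single visit yields no sample)
```

Two consequences follow from this rule: a gap of exactly `time_window` days counts as not readmitted, and patients with only one visit contribute no samples, which is one reason the sample dataset can cover fewer patients than the base dataset.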

In [8]:
# STEP 2: define the readmission prediction task
readmission_dataset = mimic4_ds.set_task(
    lambda x: readmission_prediction_mimic4_fn(x, time_window=7)
)
readmission_dataset.stat()

Generating samples for <lambda>: 100%|████████████████████████████████████████████| 100/100 [00:00<00:00, 14952.42it/s]

Statistics of sample dataset:
	- Dataset: MIMIC4Dataset
	- Task: <lambda>
	- Number of samples: 104
	- Number of patients: 42
	- Number of visits: 104
	- Number of visits per patient: 2.4762
	- conditions:
		- Number of conditions per sample: 17.6827
		- Number of unique conditions: 843
		- Distribution of conditions (Top-10): [('4019', 26), ('E039', 21), ('E785', 20), ('2724', 18), ('Z87891', 18), ('V1582', 15), ('F329', 15), ('Z794', 15), ('N179', 14), ('2859', 12)]
	- procedures:
		- Number of procedures per sample: 3.3077
		- Number of unique procedures: 199
		- Distribution of procedures (Top-10): [('3897', 15), ('02HV33Z', 14), ('966', 10), ('5491', 7), ('9671', 7), ('3893', 7), ('5A1D70Z', 7), ('5A1945Z', 6), ('3E0G76Z', 6), ('9604', 5)]
	- drugs:
		- Number of drugs per sample: 38.2500
		- Number of unique drugs: 928
		- Distribution of drugs (Top-10): [('0', 104), ('63323026201', 74), ('00409490234', 56), ('00338004904', 53), ('00904224461', 51), ('00904516561', 49), ('3839605

In [9]:
# look at the first sample

readmission_dataset[0]

{'visit_id': '22595853',
 'patient_id': '10000032',
 'conditions': ['5723',
  '78959',
  '5715',
  '07070',
  '496',
  '29680',
  '30981',
  'V1582'],
 'procedures': ['5491'],
 'drugs': ['63323026201',
  '19515089452',
  '0',
  '00245004101',
  '51079007220',
  '63739054410',
  '00904198861',
  '00135019502',
  '00006022761',
  '61958070101',
  '00173068224',
  '00487980125',
  '51079007320'],
 'label': 0}

In [16]:
from pyhealth.datasets import split_by_patient, get_dataloader

# split the dataset into train/val/test
train_dataset, val_dataset, test_dataset = split_by_patient(
    readmission_dataset, [0.6, 0.2, 0.2]
)
train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
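Splitting by patient rather than by sample keeps all visits of one patient in the same split, so the model is never evaluated on a patient it saw during training. A minimal sketch of that idea follows; `split_samples_by_patient` and the toy samples are hypothetical and do not reproduce PyHealth's `split_by_patient` implementation.

```python
import random


def split_samples_by_patient(samples, ratios, seed=0):
    """Assign whole patients to train/val/test so no patient spans two splits."""
    patient_ids = sorted({s["patient_id"] for s in samples})
    random.Random(seed).shuffle(patient_ids)
    n_train = int(len(patient_ids) * ratios[0])
    n_val = int(len(patient_ids) * ratios[1])
    groups = (
        set(patient_ids[:n_train]),
        set(patient_ids[n_train:n_train + n_val]),
        set(patient_ids[n_train + n_val:]),
    )
    return [[s for s in samples if s["patient_id"] in g] for g in groups]


# Hypothetical samples: 10 patients with 1 to 3 visits each.
samples = [
    {"patient_id": f"p{p}", "visit_id": f"p{p}-v{v}", "label": (p + v) % 2}
    for p in range(10)
    for v in range(1 + p % 3)
]
train, val, test = split_samples_by_patient(samples, [0.6, 0.2, 0.2])
print(len(train), len(val), len(test))
```

Because patients have different numbers of visits, the sample counts per split only approximate the requested ratios.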

## STEP3: initialize ML models

The [pyhealth.models](https://pyhealth.readthedocs.io/en/latest/api/models.html) module provides a large variety of ML models for users to
choose from. These models can be broadly categorized into (i) general deep learning
models and (ii) healthcare-specific deep learning models. For healthcare-specific DL
models, we are constantly implementing new ones and adding them to the package.

In this section, we will use the simple three-layer multi-layer perceptron (MLP) in
PyHealth, specifying the number of layers, the number of hidden neurons in each layer,
and the activation function. More advanced models and examples can be found in
later sections or on the [PyHealth website](https://pyhealth.readthedocs.io/en/latest/).

In model initialization, we input the keys listed in the sample dictionary. To leverage
three types of clinical event information, we use "conditions", "procedures", and "drugs"
as the feature keys, and we use "label" as the readmission prediction label.

In [21]:
from pyhealth.models import MLP

# STEP 3: define model
model = MLP(
    dataset=readmission_dataset,
    feature_keys=["conditions", "procedures", "drugs"],
    label_key="label",
    mode="binary",
    embedding_dim=16,
    hidden_dim=16,
    n_layers=2,
    activation="relu",
)
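The model summary printed at the start of training lists embedding tables with 845, 201, and 930 rows and a final layer with `in_features=48`. The sketch below reproduces that arithmetic from the dataset statistics. It assumes each table reserves two extra rows for special tokens: padding at index 0 (as `padding_idx=0` indicates) plus one more, taken here to be an unknown-code token. The token names and the `build_vocab` helper are illustrative, not PyHealth's.

```python
# Unique code counts reported by readmission_dataset.stat() above.
unique_codes = {"conditions": 843, "procedures": 199, "drugs": 928}

# Assumption: two reserved rows per table (padding at index 0 plus one more,
# taken here to be an unknown-code token). The token names are illustrative.
SPECIAL_TOKENS = ["<pad>", "<unk>"]

hidden_dim = 16


def build_vocab(codes):
    """Map each code to an integer index, reserving the first slots for special tokens."""
    vocab = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    for code in sorted(set(codes)):
        vocab[code] = len(vocab)
    return vocab


toy_vocab = build_vocab(["4019", "E785", "4019", "2724"])
print(toy_vocab)

embedding_rows = {k: n + len(SPECIAL_TOKENS) for k, n in unique_codes.items()}
print(embedding_rows)  # {'conditions': 845, 'procedures': 201, 'drugs': 930}

# Each feature is reduced to one hidden_dim vector; the three are concatenated.
fc_in_features = hidden_dim * len(unique_codes)
print(fc_in_features)  # 48
```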

## STEP4: model training

In PyHealth, the [pyhealth.trainer.Trainer](https://pyhealth.readthedocs.io/en/latest/api/trainer.html) class manages
the training process of all models.

This trainer simplifies logging during
training, automatically selects available devices, and conveniently saves the best model
to disk. The trainer module offers a range of selectable arguments for customization.

Let’s proceed with the programming steps outlined below.

In [22]:
from pyhealth.trainer import Trainer

# STEP 4: define trainer
trainer = Trainer(model=model)

trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=50,
    monitor="roc_auc",
) # model is training ...


MLP(
  (embeddings): ModuleDict(
    (conditions): Embedding(845, 16, padding_idx=0)
    (procedures): Embedding(201, 16, padding_idx=0)
    (drugs): Embedding(930, 16, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (activation): ReLU()
  (mlp): ModuleDict(
    (conditions): Sequential(
      (0): Linear(in_features=16, out_features=16, bias=True)
      (1): ReLU()
      (2): Linear(in_features=16, out_features=16, bias=True)
    )
    (procedures): Sequential(
      (0): Linear(in_features=16, out_features=16, bias=True)
      (1): ReLU()
      (2): Linear(in_features=16, out_features=16, bias=True)
    )
    (drugs): Sequential(
      (0): Linear(in_features=16, out_features=16, bias=True)
      (1): ReLU()
      (2): Linear(in_features=16, out_features=16, bias=True)
    )
  )
  (fc): Linear(in_features=48, out_features=1, bias=True)
)
Metrics: None
Device: cuda

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0

Epoch 0 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-0, step-2 ---
loss: 0.6956


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 217.30it/s]

--- Eval epoch-0, step-2 ---
pr_auc: 0.5418
roc_auc: 0.5714
f1: 0.5333
loss: 0.6931
New best roc_auc score (0.5714) at epoch-0, step-2






Epoch 1 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-1, step-4 ---
loss: 0.6926


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 123.83it/s]

--- Eval epoch-1, step-4 ---
pr_auc: 0.5180
roc_auc: 0.5556
f1: 0.5000
loss: 0.6916






Epoch 2 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-2, step-6 ---
loss: 0.6891


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 118.17it/s]

--- Eval epoch-2, step-6 ---
pr_auc: 0.5180
roc_auc: 0.5556
f1: 0.2000
loss: 0.6904






Epoch 3 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-3, step-8 ---
loss: 0.6867


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 140.61it/s]

--- Eval epoch-3, step-8 ---
pr_auc: 0.5156
roc_auc: 0.5556
f1: 0.0000
loss: 0.6894






Epoch 4 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-4, step-10 ---
loss: 0.6842


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 223.85it/s]

--- Eval epoch-4, step-10 ---
pr_auc: 0.5221
roc_auc: 0.5714
f1: 0.0000
loss: 0.6884






Epoch 5 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-5, step-12 ---
loss: 0.6811


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 221.36it/s]

--- Eval epoch-5, step-12 ---
pr_auc: 0.4982
roc_auc: 0.5556
f1: 0.0000
loss: 0.6877






Epoch 6 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-6, step-14 ---
loss: 0.6789


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 188.06it/s]


--- Eval epoch-6, step-14 ---
pr_auc: 0.4880
roc_auc: 0.5397
f1: 0.0000
loss: 0.6871



Epoch 7 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-7, step-16 ---
loss: 0.6767


Evaluation: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 92.04it/s]


--- Eval epoch-7, step-16 ---
pr_auc: 0.4960
roc_auc: 0.5556
f1: 0.0000
loss: 0.6865



Epoch 8 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-8, step-18 ---
loss: 0.6748


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 110.29it/s]

--- Eval epoch-8, step-18 ---
pr_auc: 0.5062
roc_auc: 0.5714
f1: 0.0000
loss: 0.6861






Epoch 9 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-9, step-20 ---
loss: 0.6742


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 230.32it/s]

--- Eval epoch-9, step-20 ---
pr_auc: 0.5062
roc_auc: 0.5714
f1: 0.0000
loss: 0.6857



Epoch 10 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-10, step-22 ---
loss: 0.6709


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 139.70it/s]


--- Eval epoch-10, step-22 ---
pr_auc: 0.5062
roc_auc: 0.5714
f1: 0.0000
loss: 0.6854



Epoch 11 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-11, step-24 ---
loss: 0.6686


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 122.49it/s]

--- Eval epoch-11, step-24 ---
pr_auc: 0.5300
roc_auc: 0.5873
f1: 0.0000
loss: 0.6851
New best roc_auc score (0.5873) at epoch-11, step-24






Epoch 12 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-12, step-26 ---
loss: 0.6656


Evaluation: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 84.68it/s]

--- Eval epoch-12, step-26 ---
pr_auc: 0.5300
roc_auc: 0.5873
f1: 0.0000
loss: 0.6849






Epoch 13 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-13, step-28 ---
loss: 0.6633


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 237.18it/s]


--- Eval epoch-13, step-28 ---
pr_auc: 0.5157
roc_auc: 0.5714
f1: 0.0000
loss: 0.6846



Epoch 14 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-14, step-30 ---
loss: 0.6607


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 291.01it/s]


--- Eval epoch-14, step-30 ---
pr_auc: 0.5157
roc_auc: 0.5714
f1: 0.0000
loss: 0.6844



Epoch 15 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-15, step-32 ---
loss: 0.6594


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 155.01it/s]

--- Eval epoch-15, step-32 ---
pr_auc: 0.5157
roc_auc: 0.5714
f1: 0.0000
loss: 0.6841






Epoch 16 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-16, step-34 ---
loss: 0.6579


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 225.13it/s]

--- Eval epoch-16, step-34 ---
pr_auc: 0.5395
roc_auc: 0.5873
f1: 0.0000
loss: 0.6838






Epoch 17 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-17, step-36 ---
loss: 0.6559


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 229.91it/s]

--- Eval epoch-17, step-36 ---
pr_auc: 0.5395
roc_auc: 0.5873
f1: 0.0000
loss: 0.6836






Epoch 18 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-18, step-38 ---
loss: 0.6508


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 137.01it/s]

--- Eval epoch-18, step-38 ---
pr_auc: 0.5395
roc_auc: 0.5873
f1: 0.0000
loss: 0.6833






Epoch 19 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-19, step-40 ---
loss: 0.6500


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 231.93it/s]

--- Eval epoch-19, step-40 ---
pr_auc: 0.5395
roc_auc: 0.5873
f1: 0.0000
loss: 0.6830






Epoch 20 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-20, step-42 ---
loss: 0.6467


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 149.18it/s]

--- Eval epoch-20, step-42 ---
pr_auc: 0.5330
roc_auc: 0.5714
f1: 0.0000
loss: 0.6826






Epoch 21 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-21, step-44 ---
loss: 0.6431


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 321.82it/s]

--- Eval epoch-21, step-44 ---
pr_auc: 0.5330
roc_auc: 0.5714
f1: 0.0000
loss: 0.6822






Epoch 22 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-22, step-46 ---
loss: 0.6403


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 158.04it/s]

--- Eval epoch-22, step-46 ---
pr_auc: 0.5473
roc_auc: 0.5873
f1: 0.0000
loss: 0.6817






Epoch 23 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-23, step-48 ---
loss: 0.6385


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 187.05it/s]

--- Eval epoch-23, step-48 ---
pr_auc: 0.5473
roc_auc: 0.5873
f1: 0.0000
loss: 0.6812






Epoch 24 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-24, step-50 ---
loss: 0.6349


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 144.79it/s]

--- Eval epoch-24, step-50 ---
pr_auc: 0.5371
roc_auc: 0.5714
f1: 0.0000
loss: 0.6808






Epoch 25 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-25, step-52 ---
loss: 0.6283


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 143.58it/s]

--- Eval epoch-25, step-52 ---
pr_auc: 0.5371
roc_auc: 0.5714
f1: 0.0000
loss: 0.6804






Epoch 26 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-26, step-54 ---
loss: 0.6234


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 112.50it/s]

--- Eval epoch-26, step-54 ---
pr_auc: 0.5413
roc_auc: 0.5873
f1: 0.0000
loss: 0.6798






Epoch 27 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-27, step-56 ---
loss: 0.6202


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 137.13it/s]

--- Eval epoch-27, step-56 ---
pr_auc: 0.5413
roc_auc: 0.5873
f1: 0.0000
loss: 0.6793






Epoch 28 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-28, step-58 ---
loss: 0.6140


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 148.93it/s]

--- Eval epoch-28, step-58 ---
pr_auc: 0.5413
roc_auc: 0.5873
f1: 0.0000
loss: 0.6787






Epoch 29 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-29, step-60 ---
loss: 0.6084


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 114.46it/s]

--- Eval epoch-29, step-60 ---
pr_auc: 0.5095
roc_auc: 0.5556
f1: 0.0000
loss: 0.6781






Epoch 30 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-30, step-62 ---
loss: 0.6035


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 147.19it/s]

--- Eval epoch-30, step-62 ---
pr_auc: 0.5063
roc_auc: 0.5556
f1: 0.0000
loss: 0.6775






Epoch 31 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-31, step-64 ---
loss: 0.5978


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 117.04it/s]

--- Eval epoch-31, step-64 ---
pr_auc: 0.5063
roc_auc: 0.5556
f1: 0.0000
loss: 0.6768



Epoch 32 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-32, step-66 ---
loss: 0.5906


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 104.32it/s]


--- Eval epoch-32, step-66 ---
pr_auc: 0.5063
roc_auc: 0.5556
f1: 0.0000
loss: 0.6761



Epoch 33 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-33, step-68 ---
loss: 0.5878


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 455.26it/s]

--- Eval epoch-33, step-68 ---
pr_auc: 0.5302
roc_auc: 0.5714
f1: 0.0000
loss: 0.6755






Epoch 34 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-34, step-70 ---
loss: 0.5796


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 458.90it/s]

--- Eval epoch-34, step-70 ---
pr_auc: 0.5302
roc_auc: 0.5714
f1: 0.0000
loss: 0.6746






Epoch 35 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-35, step-72 ---
loss: 0.5706


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 341.86it/s]


--- Eval epoch-35, step-72 ---
pr_auc: 0.5302
roc_auc: 0.5714
f1: 0.0000
loss: 0.6738



Epoch 36 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-36, step-74 ---
loss: 0.5630


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 238.52it/s]

--- Eval epoch-36, step-74 ---
pr_auc: 0.5302
roc_auc: 0.5714
f1: 0.0000
loss: 0.6730






Epoch 37 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-37, step-76 ---
loss: 0.5550


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 274.42it/s]

--- Eval epoch-37, step-76 ---
pr_auc: 0.5237
roc_auc: 0.5556
f1: 0.4000
loss: 0.6725






Epoch 38 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-38, step-78 ---
loss: 0.5472


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 249.42it/s]

--- Eval epoch-38, step-78 ---
pr_auc: 0.5237
roc_auc: 0.5556
f1: 0.3636
loss: 0.6718






Epoch 39 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-39, step-80 ---
loss: 0.5377


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 113.52it/s]

--- Eval epoch-39, step-80 ---
pr_auc: 0.5237
roc_auc: 0.5556
f1: 0.3636
loss: 0.6713






Epoch 40 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-40, step-82 ---
loss: 0.5259


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 241.14it/s]

--- Eval epoch-40, step-82 ---
pr_auc: 0.5237
roc_auc: 0.5556
f1: 0.5000
loss: 0.6707






Epoch 41 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-41, step-84 ---
loss: 0.5167


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 243.20it/s]

--- Eval epoch-41, step-84 ---
pr_auc: 0.5292
roc_auc: 0.5714
f1: 0.5000
loss: 0.6703






Epoch 42 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-42, step-86 ---
loss: 0.5077


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 347.44it/s]

--- Eval epoch-42, step-86 ---
pr_auc: 0.5292
roc_auc: 0.5714
f1: 0.5000
loss: 0.6701






Epoch 43 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-43, step-88 ---
loss: 0.4962


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 438.00it/s]

--- Eval epoch-43, step-88 ---
pr_auc: 0.5292
roc_auc: 0.5714
f1: 0.5000
loss: 0.6700






Epoch 44 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-44, step-90 ---
loss: 0.4850


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 204.12it/s]


--- Eval epoch-44, step-90 ---
pr_auc: 0.5292
roc_auc: 0.5714
f1: 0.5000
loss: 0.6702



Epoch 45 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-45, step-92 ---
loss: 0.4731


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 305.75it/s]

--- Eval epoch-45, step-92 ---
pr_auc: 0.5054
roc_auc: 0.5556
f1: 0.5000
loss: 0.6705








Epoch 46 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-46, step-94 ---
loss: 0.4581


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 333.41it/s]

--- Eval epoch-46, step-94 ---
pr_auc: 0.5054
roc_auc: 0.5556
f1: 0.5000
loss: 0.6712








Epoch 47 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-47, step-96 ---
loss: 0.4476


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 322.37it/s]

--- Eval epoch-47, step-96 ---
pr_auc: 0.4999
roc_auc: 0.5397
f1: 0.5000
loss: 0.6724






Epoch 48 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-48, step-98 ---
loss: 0.4350


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 158.69it/s]

--- Eval epoch-48, step-98 ---
pr_auc: 0.4999
roc_auc: 0.5397
f1: 0.5000
loss: 0.6736






Epoch 49 / 50:   0%|          | 0/2 [00:00<?, ?it/s]

--- Train epoch-49, step-100 ---
loss: 0.4221


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 233.82it/s]

--- Eval epoch-49, step-100 ---
pr_auc: 0.4999
roc_auc: 0.5397
f1: 0.5000
loss: 0.6751
Loaded best model
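The log shows how the `monitor="roc_auc"` setting plays out: only epochs 0 and 11 print "New best", even though later epochs tie the 0.5873 score. A minimal sketch of such a keep-the-best rule follows, assuming a strict greater-than comparison (inferred from the log, not read from the PyHealth source) and using a few validation scores copied from the output above.

```python
# (epoch, validation roc_auc) pairs copied from the training log above.
val_scores = [(0, 0.5714), (1, 0.5556), (11, 0.5873), (16, 0.5873), (49, 0.5397)]

best_score = float("-inf")
best_epoch = None
for epoch, score in val_scores:
    if score > best_score:  # strict comparison: ties do not replace the checkpoint
        best_score, best_epoch = score, epoch
        print(f"New best roc_auc score ({score:.4f}) at epoch-{epoch}")

print(best_epoch, best_score)
```

Under this rule, the weights restored by "Loaded best model" would be those from epoch 11, so the training loss of 0.4221 reached at epoch 49 does not describe the model that is evaluated next.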





## STEP5: model evaluation

Following model training, we offer two ways to evaluate the trained deep
learning model: calling `trainer.evaluate` directly, or passing the output of `trainer.inference` to the [pyhealth.metrics](https://pyhealth.readthedocs.io/en/latest/api/metrics.html) module.

PyHealth offers an array
of metrics suitable for assessing binary classification, multi-class classification, and
multi-label classification problems, alongside fairness metrics and advanced model
calibration metrics.

In [24]:
# method 1
result = trainer.evaluate(test_dataloader)
print(result)

# method 2
from pyhealth.metrics.binary import binary_metrics_fn

y_true, y_prob, loss = trainer.inference(test_dataloader)
binary_metrics_fn(
    y_true,
    y_prob,
    metrics=["pr_auc", "roc_auc", "f1"]
)

Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 205.15it/s]


{'pr_auc': 0.5818931986585775, 'roc_auc': 0.39772727272727276, 'f1': 0.0, 'loss': 0.7247440218925476}


Evaluation: 100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 155.38it/s]


{'pr_auc': 0.5818931986585775, 'roc_auc': 0.39772727272727276, 'f1': 0.0}
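The test result pairs an `f1` of 0.0 with non-zero `pr_auc` and `roc_auc` values. That combination is expected whenever no test sample is scored above the decision threshold, because F1 depends on thresholded predictions while the AUC metrics depend only on the ranking of the scores. A minimal pure-Python sketch, assuming a 0.5 threshold and hypothetical scores (this is not PyHealth's implementation):

```python
def roc_auc(y_true, y_prob):
    """Probability that a random positive outranks a random negative (ties count half)."""
    pos = [p for t, p in zip(y_true, y_prob) if t == 1]
    neg = [p for t, p in zip(y_true, y_prob) if t == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def f1_score(y_true, y_prob, threshold=0.5):
    """F1 of the positive class after thresholding the probabilities."""
    y_pred = [1 if p >= threshold else 0 for p in y_prob]
    tp = sum(1 for t, y in zip(y_true, y_pred) if t == 1 and y == 1)
    fp = sum(1 for t, y in zip(y_true, y_pred) if t == 0 and y == 1)
    fn = sum(1 for t, y in zip(y_true, y_pred) if t == 1 and y == 0)
    return 0.0 if tp == 0 else 2 * tp / (2 * tp + fp + fn)


# Hypothetical labels and scores: every probability is below 0.5.
y_true = [1, 0, 1, 0, 0]
y_prob = [0.30, 0.45, 0.40, 0.20, 0.35]

print(roc_auc(y_true, y_prob))   # 0.5
print(f1_score(y_true, y_prob))  # 0.0
```

A ROC-AUC below 0.5, as in the test output above, means the model ranks negatives above positives more often than the reverse on that split.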

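- Note that the performance is poor on the MIMIC-IV-demo dataset (test ROC-AUC of about 0.40, below the 0.5 chance level) since it only contains 100 patients, which makes the model overfit: the training loss falls from 0.6956 to 0.4221 while the validation loss stays between roughly 0.67 and 0.69. You could download the original MIMIC-IV dataset and re-run the pipeline.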