## STEP 1: dataset processing

In [2]:
from pyhealth.datasets import MIMIC4Dataset

# Location of the MIMIC-IV v2.0 "hosp" tables on this machine.
# Site-specific: point this at your own copy of the data before running.
MIMIC4_ROOT = "/srv/local/data/physionet.org/files/mimiciv/2.0/hosp"

# Parse the three event tables that the readmission task reads later on.
mimic4_ds = MIMIC4Dataset(
    root=MIMIC4_ROOT,
    tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
)

# stat() prints the summary and also returns it as a string; the trailing
# semicolon stops the notebook from echoing that string a second time.
mimic4_ds.stat();

INFO: Pandarallel will run on 64 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
finish basic patient information parsing : 26.628898859024048s
finish parsing diagnoses_icd : 59.32148051261902s
finish parsing procedures_icd : 18.32806658744812s
finish parsing prescriptions : 267.0019724369049s


Mapping codes: 100%|███████████████████████████████████████████████████████████████████████████████| 190279/190279 [00:12<00:00, 15554.53it/s]



Statistics of base dataset (dev=False):
	- Dataset: MIMIC4Dataset
	- Number of patients: 190279
	- Number of visits: 454324
	- Number of visits per patient: 2.3877
	- Number of events per visit in diagnoses_icd: 11.0205
	- Number of events per visit in procedures_icd: 1.5498
	- Number of events per visit in prescriptions: 35.6424



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC4Dataset\n\t- Number of patients: 190279\n\t- Number of visits: 454324\n\t- Number of visits per patient: 2.3877\n\t- Number of events per visit in diagnoses_icd: 11.0205\n\t- Number of events per visit in procedures_icd: 1.5498\n\t- Number of events per visit in prescriptions: 35.6424\n'

In [6]:
# Inspect a single patient. `patients` is iterated with .items(), which yields
# (patient_id, Patient) pairs; index 2 picks the third pair in iteration order.
patients = mimic4_ds.patients
patient_id, patient_obj = list(patients.items())[2]
# Bare tuple as the last expression so the notebook displays both values.
patient_id, patient_obj

('10000084', Patient 10000084 with 2 visits)

In [27]:
# Drill into the selected patient: `visits` is iterated with .items(), which
# yields (visit_id, Visit) pairs.
visits = patient_obj.visits
visit_id, visit_obj = list(visits.items())[0] # first visit in iteration order
# Display the id together with the Visit summary.
visit_id, visit_obj

('23052089',
 Visit 23052089 from patient 10000084 with 19 events from tables ['diagnoses_icd', 'prescriptions'])

In [28]:
# List the events this visit has in the diagnoses_icd table.
diagnoses = visit_obj.get_event_list("diagnoses_icd")
diagnosis = diagnoses[0] # first diagnosis event of the visit
# Display the event (its repr shows the code, vocabulary and source table).
diagnosis

Event with ICD10CM code G3183 from table diagnoses_icd

In [29]:
# List the events this visit has in the prescriptions table.
prescriptions = visit_obj.get_event_list("prescriptions")
prescription = prescriptions[0] # first prescription event of the visit
# NOTE(review): the recorded output shows NDC code "0" here, and "0" is also
# the most frequent drug code in the task statistics. It looks like a
# placeholder for a missing NDC rather than a real drug; confirm before
# treating it as a feature.
prescription

Event with NDC code 0 from table prescriptions

## STEP 2: Task definition

In [31]:
def readmission_prediction_mimic4_fn(patient, time_window=7):
    """Build readmission-prediction samples from one patient's visits.

    Every visit except the last is paired with the visit that follows it.
    The label is 1 when the whole-day gap between the two visits'
    ``encounter_time`` values is strictly smaller than ``time_window``,
    and 0 otherwise. The last visit is never emitted because no later
    visit exists to derive its label from.

    A visit is skipped when any of its three code lists (diagnoses_icd,
    procedures_icd, prescriptions) is empty.

    Args:
        patient: object supporting ``len()``, integer indexing into its
            visits, and a ``patient_id`` attribute. Each visit must expose
            ``visit_id``, ``encounter_time`` and ``get_code_list(table=...)``.
        time_window: readmission threshold in days (exclusive upper bound
            on the gap that counts as a readmission).

    Returns:
        A list of dicts, one per kept visit, with the keys ``visit_id``,
        ``patient_id``, ``conditions``, ``procedures``, ``drugs`` and
        ``label``.
    """
    samples = []

    # Stop one short of the end: the final visit has no successor to label it.
    for idx in range(len(patient) - 1):
        current = patient[idx]
        following = patient[idx + 1]

        # Whole days between the starts of the two consecutive visits.
        gap_days = (following.encounter_time - current.encounter_time).days
        label = int(gap_days < time_window)

        conditions = current.get_code_list(table="diagnoses_icd")
        procedures = current.get_code_list(table="procedures_icd")
        drugs = current.get_code_list(table="prescriptions")

        # Keep only visits that carry at least one code of every kind.
        if 0 in (len(conditions), len(procedures), len(drugs)):
            continue

        sample = {
            "visit_id": current.visit_id,
            "patient_id": patient.patient_id,
            "conditions": conditions,
            "procedures": procedures,
            "drugs": drugs,
            "label": label,
        }
        samples.append(sample)

        # Hook for patient- or visit-level cohort selection, if ever needed.
        # ...
    return samples

In [32]:
# STEP 2: define the readmission prediction task.
# The task function is run over the patients of the base dataset and the
# samples it returns form the new dataset. time_window=7 marks a visit as
# positive when the next visit starts fewer than 7 days after it.
readmission_dataset = mimic4_ds.set_task(
    lambda x: readmission_prediction_mimic4_fn(x, time_window=7)
)
# stat() prints the summary and also returns it as a string; the trailing
# semicolon stops the notebook from echoing that string a second time.
readmission_dataset.stat();

Generating samples for <lambda>: 100%|█████████████████████████████████████████████████████████████| 190279/190279 [00:06<00:00, 30338.08it/s]


Statistics of sample dataset:
	- Dataset: MIMIC4Dataset
	- Task: <lambda>
	- Number of samples: 132301
	- Number of patients: 59273
	- Number of visits: 132301
	- Number of visits per patient: 2.2321
	- conditions:
		- Number of conditions per sample: 13.5967
		- Number of unique conditions: 18858
		- Distribution of conditions (Top-10): [('4019', 34015), ('2724', 25635), ('53081', 18716), ('E785', 15686), ('25000', 15530), ('41401', 14924), ('4280', 14814), ('I10', 14065), ('42731', 13960), ('Z87891', 13776)]
	- procedures:
		- Number of procedures per sample: 2.7078
		- Number of unique procedures: 10220
		- Distribution of procedures (Top-10): [('3893', 8320), ('3897', 6165), ('3995', 6035), ('02HV33Z', 5966), ('8856', 5171), ('0040', 4588), ('966', 4479), ('9925', 4310), ('4513', 3395), ('5491', 3313)]
	- drugs:
		- Number of drugs per sample: 29.6466
		- Number of unique drugs: 5458
		- Distribution of drugs (Top-10): [('0', 127393), ('63323026201', 79098), ('00904224461', 70124),

"Statistics of sample dataset:\n\t- Dataset: MIMIC4Dataset\n\t- Task: <lambda>\n\t- Number of samples: 132301\n\t- Number of patients: 59273\n\t- Number of visits: 132301\n\t- Number of visits per patient: 2.2321\n\t- conditions:\n\t\t- Number of conditions per sample: 13.5967\n\t\t- Number of unique conditions: 18858\n\t\t- Distribution of conditions (Top-10): [('4019', 34015), ('2724', 25635), ('53081', 18716), ('E785', 15686), ('25000', 15530), ('41401', 14924), ('4280', 14814), ('I10', 14065), ('42731', 13960), ('Z87891', 13776)]\n\t- procedures:\n\t\t- Number of procedures per sample: 2.7078\n\t\t- Number of unique procedures: 10220\n\t\t- Distribution of procedures (Top-10): [('3893', 8320), ('3897', 6165), ('3995', 6035), ('02HV33Z', 5966), ('8856', 5171), ('0040', 4588), ('966', 4479), ('9925', 4310), ('4513', 3395), ('5491', 3313)]\n\t- drugs:\n\t\t- Number of drugs per sample: 29.6466\n\t\t- Number of unique drugs: 5458\n\t\t- Distribution of drugs (Top-10): [('0', 127393), (

In [33]:
# Show the first generated sample to sanity-check its keys and code lists.
readmission_dataset[0]

{'visit_id': '22595853',
 'patient_id': '10000032',
 'conditions': ['5723',
  '78959',
  '5715',
  '07070',
  '496',
  '29680',
  '30981',
  'V1582'],
 'procedures': ['5491'],
 'drugs': ['0',
  '63323026201',
  '19515089452',
  '00245004101',
  '63739054410',
  '51079007220',
  '00904198861',
  '00006022761',
  '00173068224',
  '61958070101',
  '00135019502',
  '00487980125',
  '51079007320'],
 'label': 0}

In [40]:
from pyhealth.datasets import split_by_patient, get_dataloader

# Split the samples into train/val/test (80% / 10% / 10%).
# A fixed seed keeps the split, and therefore the reported metrics,
# comparable across kernel restarts.
# NOTE(review): split_by_patient presumably keeps all samples of one patient
# in the same split; confirm against the PyHealth documentation.
SPLIT_SEED = 42
BATCH_SIZE = 32

train_dataset, val_dataset, test_dataset = split_by_patient(
    readmission_dataset, [0.8, 0.1, 0.1], seed=SPLIT_SEED
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_dataloader = get_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## STEP 3: initialize ML models

In [41]:
from pyhealth.models import MLP

# STEP 3: define model
# The printed architecture shows what this builds: one Embedding table and one
# Sequential stack of three Linear layers (ReLU in between) per feature key.
model = MLP(
    dataset=readmission_dataset,
    feature_keys=["conditions", "procedures", "drugs"],  # sample fields used as inputs
    label_key="label",  # sample field holding the 0/1 readmission target
    mode="binary",
    embedding_dim=128,  # width of each per-feature embedding
    hidden_dim=128,  # width of the Linear layers
    n_layers=3,  # three Linear layers per feature in the printed model
    activation="relu",
)

## STEP 4: model training

In [43]:
from pyhealth.trainer import Trainer

# STEP 4: define trainer
trainer = Trainer(model=model)

# Train for 5 epochs with an evaluation pass on the validation loader after
# each one. With monitor="roc_auc" the recorded log reports "New best roc_auc
# score" and ends with "Loaded best model", so the best-scoring state on
# validation ROC-AUC is the one left in `model`.
# NOTE(review): in the recorded run validation ROC-AUC peaks at epoch 0
# (0.7106) and falls every epoch after, while training loss keeps dropping.
# That pattern suggests overfitting; consider fewer epochs or regularization.
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=5,
    monitor="roc_auc",
) # model is training ...


MLP(
  (embeddings): ModuleDict(
    (conditions): Embedding(18860, 128, padding_idx=0)
    (procedures): Embedding(10222, 128, padding_idx=0)
    (drugs): Embedding(5460, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (activation): ReLU()
  (mlp): ModuleDict(
    (conditions): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=128, bias=True)
    )
    (procedures): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): Linear(in_features=128, out_features=128, bias=True)
    )
    (drugs): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
      (4): 

Epoch 0 / 5:   0%|          | 0/3305 [00:00<?, ?it/s]

--- Train epoch-0, step-3305 ---
loss: 0.5563


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:00<00:00, 472.71it/s]

--- Eval epoch-0, step-3305 ---
pr_auc: 0.7000
roc_auc: 0.7106
f1: 0.6522
loss: 0.6354
New best roc_auc score (0.7106) at epoch-0, step-3305








Epoch 1 / 5:   0%|          | 0/3305 [00:00<?, ?it/s]

--- Train epoch-1, step-6610 ---
loss: 0.5106


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:00<00:00, 479.10it/s]

--- Eval epoch-1, step-6610 ---
pr_auc: 0.6829
roc_auc: 0.6961
f1: 0.5838
loss: 0.6740






Epoch 2 / 5:   0%|          | 0/3305 [00:00<?, ?it/s]

--- Train epoch-2, step-9915 ---
loss: 0.4470


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:00<00:00, 480.47it/s]

--- Eval epoch-2, step-9915 ---
pr_auc: 0.6641
roc_auc: 0.6779
f1: 0.5818
loss: 0.7736






Epoch 3 / 5:   0%|          | 0/3305 [00:00<?, ?it/s]

--- Train epoch-3, step-13220 ---
loss: 0.3718


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:00<00:00, 435.57it/s]

--- Eval epoch-3, step-13220 ---
pr_auc: 0.6505
roc_auc: 0.6643
f1: 0.5828
loss: 0.8855






Epoch 4 / 5:   0%|          | 0/3305 [00:00<?, ?it/s]

--- Train epoch-4, step-16525 ---
loss: 0.2942


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 418/418 [00:00<00:00, 469.68it/s]

--- Eval epoch-4, step-16525 ---
pr_auc: 0.6489
roc_auc: 0.6610
f1: 0.6060
loss: 1.1949
Loaded best model





## STEP 5: model evaluation

In [45]:
# Method 1: let the trainer score the test split with its own metric set.
result = trainer.evaluate(test_dataloader)
print(result)

# Method 2: collect labels and predicted probabilities from the trainer,
# then score them with an explicitly chosen list of metrics.
from pyhealth.metrics.binary import binary_metrics_fn

y_true, y_prob, loss = trainer.inference(test_dataloader)
binary_metrics_fn(y_true, y_prob, metrics=["pr_auc", "roc_auc", "f1"])

Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 413/413 [00:00<00:00, 477.67it/s]


{'pr_auc': 0.6920530474148421, 'roc_auc': 0.7015410673568111, 'f1': 0.6474005660658425, 'loss': 0.6424725677267835}


Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 413/413 [00:00<00:00, 477.06it/s]


{'pr_auc': 0.6920530474148421,
 'roc_auc': 0.7015410673568111,
 'f1': 0.6474005660658425}