## Import

In [1]:
import os
import pandas as pd
from datetime import datetime
from pyhealth.datasets import MIMIC3Dataset

import json
import tqdm as tqdm

## Data

In [2]:
# Location of the unpacked MIMIC-III v1.4 CSV tables. Set the MIMIC3_ROOT
# environment variable to point at your own copy; the fallback is the path
# this notebook was originally developed against.
MIMIC3_ROOT = os.environ.get(
    "MIMIC3_ROOT",
    "/Users/home/Professor Zijun Yao Lab/GPT experiment/EHR Data Sample/MIMIC-III/mimic-iii-clinical-database-1.4",
)

# Per the PyHealth docs, dev mode loads only a small subset of the data.
# Keep it False for the full run; flip to True for quick iteration.
DEV_MODE = False

mimic3_ds = MIMIC3Dataset(
    root=MIMIC3_ROOT,
    tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    # Map NDC drug codes to level-1 ATC codes (single-letter groups).
    code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 1}})},
    dev=DEV_MODE,
)

# stat() prints the summary itself and also returns it as a string; the
# trailing semicolon stops Jupyter from echoing that string a second time.
mimic3_ds.stat();


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 46520
	- Number of visits: 58976
	- Number of visits per patient: 1.2678
	- Number of events per visit in DIAGNOSES_ICD: 11.0384
	- Number of events per visit in PROCEDURES_ICD: 4.0711
	- Number of events per visit in PRESCRIPTIONS: 87.1287



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 46520\n\t- Number of visits: 58976\n\t- Number of visits per patient: 1.2678\n\t- Number of events per visit in DIAGNOSES_ICD: 11.0384\n\t- Number of events per visit in PROCEDURES_ICD: 4.0711\n\t- Number of events per visit in PRESCRIPTIONS: 87.1287\n'

In [3]:
# Print the nested layout of the dataset object:
# patients -> visits -> per-table event lists -> events.
mimic3_ds.info()


dataset.patients: patient_id -> <Patient>

<Patient>
    - visits: visit_id -> <Visit> 
    - other patient-level info
    
    <Visit>
        - event_list_dict: table_name -> List[Event]
        - other visit-level info
    
        <Event>
            - code: str
            - other event-level info



In [4]:
from pyhealth.tasks import drug_recommendation_mimic3_fn

# Convert the base dataset into drug-recommendation samples (one sample per
# visit, with conditions/procedures as inputs and ATC drug codes as labels).
# NOTE: this rebinds `mimic3_ds` from the base dataset to the sample dataset,
# so the cell is not idempotent -- re-run the loading cell before re-running
# this one.
mimic3_ds = mimic3_ds.set_task(task_fn=drug_recommendation_mimic3_fn)

# stat() prints the summary itself and also returns it as a string; the
# trailing semicolon stops Jupyter from echoing that string a second time.
mimic3_ds.stat();

Generating samples for drug_recommendation_mimic3_fn: 100%|██████████| 46520/46520 [00:21<00:00, 2147.70it/s] 


Statistics of sample dataset:
	- Dataset: MIMIC3Dataset
	- Task: drug_recommendation_mimic3_fn
	- Number of samples: 14141
	- Number of patients: 5449
	- Number of visits: 14141
	- Number of visits per patient: 2.5952
	- conditions:
		- Number of conditions per sample: 31.6520
		- Number of unique conditions: 4491
		- Distribution of conditions (Top-10): [('4280', 11109), ('4019', 10091), ('42731', 8266), ('5849', 7046), ('41401', 6581), ('25000', 5623), ('51881', 5520), ('5990', 4977), ('2724', 4921), ('5856', 4646)]
	- procedures:
		- Number of procedures per sample: 9.4601
		- Number of unique procedures: 1412
		- Distribution of procedures (Top-10): [('3893', 11161), ('966', 5746), ('9604', 5744), ('9904', 5716), ('3995', 5667), ('9671', 5016), ('9672', 3750), ('3891', 2824), ('9915', 2373), ('4513', 2222)]
	- drugs:
		- Number of drugs per sample: 8.7723
		- Number of unique drugs: 14
		- Distribution of drugs (Top-10): [('A', 14040), ('B', 14008), ('N', 13988), ('C', 13509), ('V'

"Statistics of sample dataset:\n\t- Dataset: MIMIC3Dataset\n\t- Task: drug_recommendation_mimic3_fn\n\t- Number of samples: 14141\n\t- Number of patients: 5449\n\t- Number of visits: 14141\n\t- Number of visits per patient: 2.5952\n\t- conditions:\n\t\t- Number of conditions per sample: 31.6520\n\t\t- Number of unique conditions: 4491\n\t\t- Distribution of conditions (Top-10): [('4280', 11109), ('4019', 10091), ('42731', 8266), ('5849', 7046), ('41401', 6581), ('25000', 5623), ('51881', 5520), ('5990', 4977), ('2724', 4921), ('5856', 4646)]\n\t- procedures:\n\t\t- Number of procedures per sample: 9.4601\n\t\t- Number of unique procedures: 1412\n\t\t- Distribution of procedures (Top-10): [('3893', 11161), ('966', 5746), ('9604', 5744), ('9904', 5716), ('3995', 5667), ('9671', 5016), ('9672', 3750), ('3891', 2824), ('9915', 2373), ('4513', 2222)]\n\t- drugs:\n\t\t- Number of drugs per sample: 8.7723\n\t\t- Number of unique drugs: 14\n\t\t- Distribution of drugs (Top-10): [('A', 14040), 

In [5]:
# Inspect the first generated sample: visit/patient ids, the condition and
# procedure code lists, the target drug codes, and the drug history.
mimic3_ds.samples[0]

{'visit_id': '161106',
 'patient_id': '10004',
 'conditions': [['80503',
   '85200',
   '3453',
   '5180',
   '41401',
   '25000',
   'E8859',
   'V453']],
 'procedures': [['8102', '8051', '8162']],
 'drugs': ['B', 'A', 'N', 'V', 'J', 'C', 'R', 'D'],
 'drugs_hist': [[]]}

## Train / Validation / Test Split

In [8]:
from pyhealth.datasets import split_by_patient, get_dataloader

# Fixed seed so the patient-level split is reproducible across kernel restarts.
SPLIT_SEED = 42

# Patient-level split: all visits of a patient land in the same partition.
# split_by_patient returns (train, val, test), so three ratios and three
# targets are required; the evaluation cells below rely on the test partition.
train_dataset, val_dataset, test_dataset = split_by_patient(
    mimic3_ds, [0.8, 0.1, 0.1], seed=SPLIT_SEED
)

# create dataloaders (they are <torch.data.DataLoader> object)
# Only the training loader is shuffled; val/test keep a stable order.
train_loader = get_dataloader(train_dataset, batch_size=64, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=64, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=64, shuffle=False)

In [9]:
# Sanity check: number of batches in the train / val / test loaders.
# NOTE(review): requires `test_loader` to have been created by the split cell.
print(len(train_loader), len(val_loader), len(test_loader))

5 1 1


## Model

In [10]:
from pyhealth.models import Transformer

# Transformer over the two input code sequences; the target `drugs` is a list
# of codes per sample, hence the multilabel mode.
model = Transformer(
    dataset=mimic3_ds,
    feature_keys=["conditions", "procedures"],
    label_key="drugs",
    mode="multilabel",
)


Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



## Model Training

In [11]:
from pyhealth.trainer import Trainer

# Train for a few epochs, evaluating on the validation loader after each one.
# `monitor` names the validation metric used to pick the best checkpoint,
# which is reloaded when training finishes (see "Loaded best model" in the log).
trainer = Trainer(model=model)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=3,
    monitor="pr_auc_samples",
)

Transformer(
  (embeddings): ModuleDict(
    (conditions): Embedding(917, 128, padding_idx=0)
    (procedures): Embedding(305, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (transformer): ModuleDict(
    (conditions): TransformerLayer(
      (transformer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadedAttention(
            (linear_layers): ModuleList(
              (0-2): 3 x Linear(in_features=128, out_features=128, bias=False)
            )
            (output_linear): Linear(in_features=128, out_features=128, bias=False)
            (attention): Attention()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (feed_forward): PositionwiseFeedForward(
            (w_1): Linear(in_features=128, out_features=512, bias=True)
            (w_2): Linear(in_features=512, out_features=128, bias=True)
            (dropout): Dropout(p=0.5, inplace=False)
            (activation): GELU(approximate='none')
          )
          (in

  from tqdm.autonotebook import trange
Epoch 0 / 3: 100%|██████████| 5/5 [00:00<00:00, 10.67it/s]

--- Train epoch-0, step-5 ---
loss: 1.0871



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 138.89it/s]

--- Eval epoch-0, step-5 ---
pr_auc_samples: 0.2439
loss: 0.7037
New best pr_auc_samples score (0.2439) at epoch-0, step-5




Epoch 1 / 3: 100%|██████████| 5/5 [00:00<00:00, 28.38it/s]

--- Train epoch-1, step-10 ---
loss: 0.7978



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 139.55it/s]


--- Eval epoch-1, step-10 ---
pr_auc_samples: 0.3720
loss: 0.5377
New best pr_auc_samples score (0.3720) at epoch-1, step-10



Epoch 2 / 3: 100%|██████████| 5/5 [00:00<00:00, 29.49it/s]

--- Train epoch-2, step-15 ---
loss: 0.6209



Evaluation: 100%|██████████| 1/1 [00:00<00:00, 140.18it/s]

--- Eval epoch-2, step-15 ---
pr_auc_samples: 0.5088
loss: 0.4193
New best pr_auc_samples score (0.5088) at epoch-2, step-15





Loaded best model


In [12]:
from pyhealth.metrics.multilabel import multilabel_metrics_fn

# Route 1: let the trainer compute its built-in metrics on the test loader.
score = trainer.evaluate(test_loader)
print(score)

# Route 2: collect raw labels/probabilities and score them explicitly with
# pyhealth.metrics; the resulting dict is displayed as the cell output.
y_true, y_prob, loss = trainer.inference(test_loader)
multilabel_metrics_fn(y_true, y_prob, metrics=["pr_auc_samples"])

Evaluation: 100%|██████████| 1/1 [00:00<00:00, 87.62it/s]


{'pr_auc_samples': 0.5090478799050038, 'loss': 0.43610066175460815}


Evaluation: 100%|██████████| 1/1 [00:00<00:00, 100.64it/s]


{'pr_auc_samples': 0.5090478799050038}