# Step1: create dataset

In [27]:
from pyhealth.datasets import MIMIC3BaseDataset, MIMIC4BaseDataset, eICUBaseDataset, OMOPBaseDataset
# base_dataset = MIMIC3BaseDataset(root="/srv/local/data/physionet.org/files/mimiciii/1.4")
# base_dataset = eICUBaseDataset(root="/srv/local/data/physionet.org/files/eicu-crd/2.0")
# base_dataset = MIMIC4BaseDataset(root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp")
base_dataset = OMOPBaseDataset(root="/srv/local/data/zw12/pyhealth/raw_data/synpuf1k_omop_cdm_5.2.2")
base_dataset.info()


        ----- Output Data Structure -----
        Dataset.patients: [
            {
                patient_id: patient_id, 
                visits: [
                    {
                        visit_id: visit_id, 
                        patient_id: patient_id, 
                        conditions: [List], 
                        procedures: [List],
                        drugs: [List],
                        visit_info: <dict>
                    }
                    ...
                ]                    
            } 
            ...
        ]
        


user can use this module for data processing
- **[researchers from CS]** build their own model on top of it
- **[reserachers from medical area]** use the models in the package

In [28]:
# create task-specific dataset
from pyhealth.tasks import DrugRecDataset
drug_rec_dataset = DrugRecDataset(base_dataset)
drug_rec_dataset.info()


        ----- Output Data Structure -----
        >> drug_rec_dataloader[0]
        >> {
            "conditions": List[tensor],
            "procedures": List[tensor],
            "drugs": List[tensor]
        }
        


In [29]:
print (len(base_dataset), len(drug_rec_dataset))

1000 532


# Step2: healthcare predictive model

In [30]:
voc_size = drug_rec_dataset.voc_size
params = drug_rec_dataset.params

from pyhealth.models import RETAIN
model = RETAIN(voc_size, params)

# Step 3: create train / val / test dataloader

In [31]:
from pyhealth.data import split
from torch.utils.data import DataLoader

drug_rec_trainset, drug_rec_valset, drug_rec_testset = split.random_split(drug_rec_dataset, [0.8, 0.1, 0.1])
drug_rec_train_loader = DataLoader(drug_rec_trainset, batch_size=1, collate_fn=lambda x: x[0])
drug_rec_val_loader = DataLoader(drug_rec_valset, batch_size=1, collate_fn=lambda x: x[0])
drug_rec_test_loader = DataLoader(drug_rec_testset, batch_size=1, collate_fn=lambda x: x[0])

# Step 4: model training and evaluation

In [32]:
# train
from pytorch_lightning import Trainer

trainer = Trainer(
    gpus=1,
    max_epochs=3,
    progress_bar_refresh_rate=5,
)

trainer.fit(
    model=model,
    train_dataloaders=drug_rec_train_loader,
    val_dataloaders=drug_rec_val_loader,
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True, used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.gpu:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name                | Type       | Params
---------------------------------------------------
0 | condition_embedding | Sequential | 103 K 
1 | procedure_embedding | Sequential | 71.0 K
2 | alpha_gru           | GRU        | 25.0 K
3 | beta_gru            | GRU        | 25.0 K
4 | alpha_li            | Linear     | 65    
5 | beta_li             | Linear     | 4.2 K 
6 | output              | Linear     | 15.0 K
---------------------------------------------------
243 K     Trainable params
0         Non-trainable params
243 K    

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [33]:
# evaluation
from pyhealth.evaluator import DrugRecEvaluator
evaluator = DrugRecEvaluator(model)
evaluator.evaluate(drug_rec_test_loader)


Jaccard: 0.2067,  PRAUC: 0.3496, AVG_PRC: 0.3785, AVG_RECALL: 0.2712, AVG_F1: 0.2854, AVG_MED: 2.53

