In [15]:
import sys
sys.path.append("/home/chaoqiy2/github/PyHealth-OMOP")

from warnings import simplefilter
# ignore all warnings
simplefilter(action='ignore')

# 1. load general-purpose dataset

In [2]:
from pyhealth.datasets import MIMIC3BaseDataset

base_ds = MIMIC3BaseDataset(
    root="/srv/local/data/physionet.org/files/mimiciii/1.4", 
    flag="dev",
)

# 2. process to task-specific dataset

In [3]:
from pyhealth.tasks import DrugRecDataset
drugrec_ds = DrugRecDataset(base_ds)

----- preparing code mappings -----
source loaded from /home/chaoqiy2/.cache/medcode/ cache
mapping finished: NDC11 -> ATC4
mapping finished: ATC4 -> NDC11
load time: 2.453434705734253s
-----------------------------------------


100%|█████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 8553.57it/s]

1. finish cleaning the dataset for drug recommendation task
2. tokenized the medical codes





In [4]:
from pyhealth.data.split import split_by_pat
train_loader, val_loader, test_loader = split_by_pat(
                                            drugrec_ds,
                                            ratios = [2/3, 1/6, 1/6], 
                                            batch_size = 64, 
                                            seed = 12345,
                                        )

1. finish data splitting
2. generate train / val / test data loaders


# 3. initialize DL models

In [11]:
from pyhealth.models import RETAIN, MLModel, RNN, Transformer, GAMENet, SafeDrug, MICRON

model = RETAIN(
    task = "drug_recommendation",
    voc_size = drugrec_ds.voc_size,
    tokenizers = drugrec_ds.tokenizers,
    emb_dim = 64,
)

# 4. traning

In [12]:
from pyhealth.trainer import Trainer
from pyhealth.evaluator.evaluating_multilabel import evaluate_multilabel

# # ----------- DL model ---------
trainer = Trainer(enable_logging=True, output_path="../output")
trainer.fit(model,
            train_loader=train_loader,
            epochs=5,
            evaluate_fn=evaluate_multilabel,
            eval_loader=val_loader,
            monitor="jaccard")

Iteration:   0%|          | 0/4 [00:00<?, ?it/s]

{'loss': 0.6882487, 'ddi': 0.05085359163521137, 'jaccard': 0.2195066463140662, 'prauc': 0.2922156486015231, 'f1': 0.354803225467975}


Iteration:   0%|          | 0/4 [00:00<?, ?it/s]

{'loss': 0.68483543, 'ddi': 0.0508727051647824, 'jaccard': 0.21958755674078465, 'prauc': 0.31559299570817506, 'f1': 0.3549116226050744}


Iteration:   0%|          | 0/4 [00:00<?, ?it/s]

{'loss': 0.6805311, 'ddi': 0.05085310124692861, 'jaccard': 0.2195577572328803, 'prauc': 0.34413829518510713, 'f1': 0.35487562377228576}


Iteration:   0%|          | 0/4 [00:00<?, ?it/s]

{'loss': 0.6746202, 'ddi': 0.05088227287351201, 'jaccard': 0.2199741811492879, 'prauc': 0.3799587621551031, 'f1': 0.3554049335016782}


Iteration:   0%|          | 0/4 [00:00<?, ?it/s]

{'loss': 0.6662972, 'ddi': 0.05114055010108061, 'jaccard': 0.22125389917842644, 'prauc': 0.42511402781409857, 'f1': 0.35708387991987833}
best_model_path: ../output/1665016652.567665/best.ckpt


# 5. evaluation

In [14]:
from pyhealth.evaluator.evaluating_multilabel import evaluate_multilabel

# load the best model
best_model = trainer.load(model, path="../output/1665016652.567665/best.ckpt") 

evaluate_multilabel(model, val_loader)

{'loss': 0.6662975,
 'ddi': 0.05114055010108061,
 'jaccard': 0.22125389917842644,
 'prauc': 0.42501261449677036,
 'f1': 0.35708387991987833}