In [1]:
from pyhealth.datasets import MIMIC3Dataset, MIMIC4Dataset
from GraphCare.task_fn import drug_recommendation_fn, drug_recommendation_mimic4_fn, mortality_prediction_mimic3_fn, readmission_prediction_mimic3_fn, length_of_stay_prediction_mimic3_fn, length_of_stay_prediction_mimic4_fn, mortality_prediction_mimic4_fn, readmission_prediction_mimic4_fn

# Location of the MIMIC-IV v2.0 `hosp` tables and the tables to read from it.
MIMIC4_ROOT = "/data/physionet.org/files/mimiciv/2.0/hosp/"
MIMIC4_TABLES = ["diagnoses_icd", "procedures_icd", "prescriptions"]

# Source vocabulary -> target vocabulary used when loading the raw codes.
# NDC drug codes go to ATC with target level 3; ICD-9/ICD-10 diagnosis codes
# go to CCSCM and ICD-9/ICD-10 procedure codes go to CCSPROC.
CODE_MAPPING = {
    "NDC": ("ATC", {"target_kwargs": {"level": 3}}),
    "ICD9CM": "CCSCM",
    "ICD9PROC": "CCSPROC",
    "ICD10CM": "CCSCM",
    "ICD10PROC": "CCSPROC",
}

ds = MIMIC4Dataset(
    root=MIMIC4_ROOT,
    tables=MIMIC4_TABLES,
    code_mapping=CODE_MAPPING,
    dev=False,
)

# Turn the loaded records into task samples for MIMIC-IV drug recommendation.
sample_dataset = ds.set_task(drug_recommendation_mimic4_fn)

  from .autonotebook import tqdm as notebook_tqdm
Generating samples for drug_recommendation_mimic4_fn: 100%|██████████| 190279/190279 [00:11<00:00, 16141.38it/s]


In [2]:
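# Optional shortcut (disabled): load a previously pickled `sample_dataset`
# instead of regenerating it in the cell above. Only unpickle files you
# created yourself -- `pickle.load` can execute arbitrary code.
# import pickle

# with open('../../../data/pj20/exp_data/ccscm_ccsproc/sample_dataset_mimic4_drugrec_th015.pkl', 'rb') as f:
#     sample_dataset = pickle.load(f)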

In [3]:
from pyhealth.datasets import split_by_patient, get_dataloader

# Patient-level 80/10/10 split with a fixed seed.
train_dataset, val_dataset, test_dataset = split_by_patient(
    sample_dataset,
    [0.8, 0.1, 0.1],
    seed=528,
)

# One loader per split, built in train -> val -> test order; only the
# training loader shuffles its samples.
BATCH_SIZE = 64
train_loader, val_loader, test_loader = (
    get_dataloader(split, batch_size=BATCH_SIZE, shuffle=do_shuffle)
    for split, do_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [4]:
from pyhealth.trainer import Trainer
import torch
from pyhealth.models import Transformer, RETAIN, SafeDrug, MICRON, CNN, RNN, GAMENet
from collections import defaultdict

# Experiment settings. Every model class is trained from scratch N_RUNS times
# so that later cells can compute the mean and spread of each metric.
N_RUNS = 3
EPOCHS = 5
MODEL_CLASSES = [
    Transformer,
    RETAIN,
    SafeDrug,
    MICRON,
    GAMENet,
]

# Metric sets per task type. Only the multi-label set is used below; to switch
# task, pass the matching list to Trainer(metrics=...) and its monitor key to
# trainer.train(monitor=...).
#   binary:      ["pr_auc", "roc_auc", "accuracy", "f1", "jaccard"],
#                monitor="accuracy"
#   multi-class: ["roc_auc_weighted_ovr", "cohen_kappa", "accuracy", "f1_weighted"],
#                monitor="roc_auc_weighted_ovr"
MULTILABEL_METRICS = ["pr_auc_samples", "roc_auc_samples", "f1_samples", "jaccard_samples"]
MONITOR_METRIC = "pr_auc_samples"

# The device does not change between iterations, so resolve it once. Prefer the
# second GPU as before, but fall back to the first one when only a single GPU
# is visible (requesting 'cuda:1' there would fail), and to CPU without CUDA.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')

# results[<model class name>] -> list holding one trainer.evaluate(...) result
# per run.
results = defaultdict(list)

for i in range(N_RUNS):
    for model_ in MODEL_CLASSES:
        try:
            model = model_(
                dataset=sample_dataset,
                feature_keys=["conditions", "procedures"],
                label_key="drugs",
                mode="multilabel",
            )
        except Exception:
            # Fallback for model classes whose constructor does not accept the
            # generic keyword arguments above (presumably the drug-specific
            # SafeDrug / MICRON / GAMENet -- TODO confirm). Catching Exception
            # instead of using a bare `except:` lets KeyboardInterrupt and
            # SystemExit propagate, so the run can still be interrupted.
            model = model_(dataset=sample_dataset)

        ## multi-label
        trainer = Trainer(model=model, device=device, metrics=MULTILABEL_METRICS)
        trainer.train(
            train_dataloader=train_loader,
            val_dataloader=val_loader,
            epochs=EPOCHS,
            monitor=MONITOR_METRIC,
        )

        # NOTE(review): scores are computed on val_loader, the same split that
        # is passed to trainer.train as val_dataloader; test_loader is built
        # earlier in the notebook but never used here -- confirm this is
        # intended before reporting these numbers as held-out performance.
        results[model_.__name__].append(trainer.evaluate(val_loader))

Transformer(
  (embeddings): ModuleDict(
    (conditions): Embedding(280, 128, padding_idx=0)
    (procedures): Embedding(233, 128, padding_idx=0)
  )
  (linear_layers): ModuleDict()
  (transformer): ModuleDict(
    (conditions): TransformerLayer(
      (transformer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadedAttention(
            (linear_layers): ModuleList(
              (0): Linear(in_features=128, out_features=128, bias=False)
              (1): Linear(in_features=128, out_features=128, bias=False)
              (2): Linear(in_features=128, out_features=128, bias=False)
            )
            (output_linear): Linear(in_features=128, out_features=128, bias=False)
            (attention): Attention()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (feed_forward): PositionwiseFeedForward(
            (w_1): Linear(in_features=128, out_features=512, bias=True)
            (w_2): Linear(in_features=512, out_features=128, 

: 

: 

In [None]:
# Mean of every metric over the repeated runs, keyed by model name and then
# by metric name. The metric names are taken from the first run of each model.
avg_results = defaultdict(dict)

for model_name, run_scores in results.items():
    n_runs = len(run_scores)
    for metric in run_scores[0]:
        total = sum(scores[metric] for scores in run_scores)
        avg_results[model_name][metric] = total / n_runs

In [None]:
import numpy as np
# Standard deviation of every metric over the repeated runs, keyed by model
# name and then by metric name. np.std is called with its default ddof=0,
# i.e. the population standard deviation.
variation_results = defaultdict(dict)

for model_name, run_scores in results.items():
    for metric in run_scores[0]:
        per_run_values = [scores[metric] for scores in run_scores]
        variation_results[model_name][metric] = np.std(per_run_values)

In [None]:
# Display the per-model mean of each metric across the repeated runs.
avg_results

defaultdict(dict,
            {'Transformer': {'pr_auc_samples': 0.6783844682267832,
              'roc_auc_samples': 0.9203570394645626,
              'f1_samples': 0.5199858108988109,
              'jaccard_samples': 0.36742235572247056,
              'loss': 0.23415324012438457},
             'RETAIN': {'pr_auc_samples': 0.6996383061429623,
              'roc_auc_samples': 0.9277689975071753,
              'f1_samples': 0.5574803101467737,
              'jaccard_samples': 0.4039973902798657,
              'loss': 0.2264935106039047},
             'SafeDrug': {'pr_auc_samples': 0.5245888773670376,
              'roc_auc_samples': 0.8686946082288995,
              'f1_samples': 0.3472441005948089,
              'jaccard_samples': 0.21539324480813865,
              'loss': 0.27848473568757376},
             'MICRON': {'pr_auc_samples': 0.7085677770407818,
              'roc_auc_samples': 0.9289359130123493,
              'f1_samples': 0.5626078071601712,
              'jaccard_samples'

In [None]:
# Display the per-model standard deviation of each metric across the repeated runs.
variation_results

defaultdict(dict,
            {'Transformer': {'pr_auc_samples': 0.0009362267858230031,
              'roc_auc_samples': 0.0006469108050935868,
              'f1_samples': 0.0035958936284566606,
              'jaccard_samples': 0.0031124661318476876,
              'loss': 0.0006033271549201798},
             'RETAIN': {'pr_auc_samples': 0.0012243061850607973,
              'roc_auc_samples': 0.00031909088838670557,
              'f1_samples': 0.0028257346084777733,
              'jaccard_samples': 0.002540716746490498,
              'loss': 0.0007287530945779822},
             'SafeDrug': {'pr_auc_samples': 0.005010126702354339,
              'roc_auc_samples': 0.001191226819088192,
              'f1_samples': 0.009106649051422106,
              'jaccard_samples': 0.00757658232526652,
              'loss': 0.0007432820826128067},
             'MICRON': {'pr_auc_samples': 0.0009596725056481872,
              'roc_auc_samples': 0.0007391186346618564,
              'f1_samples': 0.0069027