In [1]:
from torch.utils.data import DataLoader
import sys
sys.path.append("../")

from pyhealth.datasets import MIMIC3Dataset, eICUDataset, MIMIC4Dataset, OMOPDataset
from pyhealth.models import MLModel
from pyhealth.split import split_by_patient
from pyhealth.tasks import (
    drug_recommendation_mimic3_fn,
    drug_recommendation_eicu_fn,
    drug_recommendation_mimic4_fn,
    drug_recommendation_omop_fn,
    readmission_prediction_mimic3_fn,
)
from pyhealth.utils import collate_fn_dict
from pyhealth.trainer import Trainer
from pyhealth.evaluator import evaluate
from pyhealth.metrics import *
from sklearn.svm import SVC

###############
# Which EHR source this example runs on.
# Supported values: "mimic3", "eicu", "mimic4", "omop".
data = "omop"
################

# STEP 1 & 2: load data and set task
#
# Every branch follows the same recipe: build the base dataset with dev=True,
# print its statistics, apply the matching drug-recommendation task function,
# print the statistics again, and expose the result under the common name
# `dataset`, which is what the following cells use.
#
# NOTE(review): the `root` values are absolute paths from the original
# author's machine; point them at your own copies of the data before running.

if data == "mimic3":
    mimic3dataset = MIMIC3Dataset(
        root="/srv/local/data/physionet.org/files/mimiciii/1.4",
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS", "LABEVENTS"],
        dev=True,
        code_mapping={"PRESCRIPTIONS": "ATC3"},
        # NOTE(review): this is the only branch that passes refresh_cache=True
        # (presumably forcing the cached dataset to be rebuilt on every run);
        # confirm this difference from the other branches is intentional.
        refresh_cache=True,
    )
    mimic3dataset.stat()
    mimic3dataset.set_task(drug_recommendation_mimic3_fn)
    mimic3dataset.stat()
    dataset = mimic3dataset

elif data == "eicu":
    eicudataset = eICUDataset(
        root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
        tables=["diagnosis", "medication", "physicalExam"],
        dev=True,
        refresh_cache=False,
    )
    eicudataset.stat()
    eicudataset.set_task(task_fn=drug_recommendation_eicu_fn)
    eicudataset.stat()
    dataset = eicudataset

elif data == "mimic4":
    mimic4dataset = MIMIC4Dataset(
        root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp",
        tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
        dev=True,
        code_mapping={"prescriptions": "ATC3"},
        refresh_cache=False,
    )
    mimic4dataset.stat()
    mimic4dataset.set_task(task_fn=drug_recommendation_mimic4_fn)
    mimic4dataset.stat()
    dataset = mimic4dataset

elif data == "omop":
    omopdataset = OMOPDataset(
        root="/srv/local/data/zw12/pyhealth/raw_data/synpuf1k_omop_cdm_5.2.2",
        tables=[
            "condition_occurrence",
            "procedure_occurrence",
            "drug_exposure",
            "measurement",
        ],
        dev=True,
        refresh_cache=False,
    )
    omopdataset.stat()
    omopdataset.set_task(task_fn=drug_recommendation_omop_fn)
    omopdataset.stat()
    dataset = omopdataset

else:
    # Fail here with a clear message rather than leaving `dataset` undefined,
    # which would only surface later as a NameError in another cell.
    raise ValueError(
        f"Unsupported data source {data!r}; "
        "expected one of 'mimic3', 'eicu', 'mimic4', 'omop'."
    )


  from .autonotebook import tqdm as notebook_tqdm


Loaded OMOP base dataset from /home/pj20/.cache/pyhealth/datasets/5721324acc404b87cb3e2abaa51618c5.pkl

Statistics of OMOP dataset (dev=True):
	- Number of patients: 1000
	- Number of visits: 55261
	- Number of visits per patient: 55.2610
	- Number of condition_occurrence per visit: 2.6635
	- Number of procedure_occurrence per visit: 2.4886
	- Number of drug_exposure per visit: 0.1387
	- Number of measurement per visit: 0.6253



Generating samples for drug_recommendation_omop_fn: 100%|███████████████████| 1000/1000 [00:00<00:00, 3365.90it/s]


Statistics of drug_recommendation_omop_fn task:
	- Dataset: OMOP (dev=True)
	- Number of patients: 633
	- Number of visits: 3342
	- Number of visits per patient: 5.2796
	- Number of visit_id per visit: 1.0000
	- Number of unique visit_id: 3342
	- Number of patient_id per visit: 1.0000
	- Number of unique patient_id: 633
	- Number of conditions per visit: 2.3372
	- Number of unique conditions: 1715
	- Number of procedures per visit: 2.8169
	- Number of unique procedures: 1173
	- Number of drugs per visit: 1.2959
	- Number of unique drugs: 230
	- Label distribution: {'2213440': 606, '2213483': 90, '46275982': 44, '40950844': 127, '787787': 28, '40836918': 164, '40220869': 23, '40018375': 4, '19081294': 20, '40102878': 23, '1305085': 7, '1304850': 32, '19081616': 30, '1310317': 4, '529411': 21, '529303': 21, '40080069': 203, '44817887': 14, '2718906': 3, '40227542': 77, '2718836': 13, '1518606': 162, '40100282': 7, '43532421': 19, '35603172': 10, '1551192': 3, '1560524': 15, '41083320': 




In [2]:
# Display one task sample (index 6) to inspect the fields each sample carries.
dataset[6]

{'visit_id': '50231',
 'patient_id': '1002',
 'conditions': [['79740', '437798'], ['200051', '439777', '439926']],
 'procedures': [['2314227', '2314209', '2314231', '2414398', '2314229'],
  ['2414398', '2314225', '2314232', '2314209']],
 'drugs': [['40220869', '40018375', '19081294']],
 'label': ['40102878']}

In [3]:
# Data split: partition the task dataset with split_by_patient using the
# ratios [0.8, 0.1, 0.1] for train / validation / test.
split_ratios = [0.8, 0.1, 0.1]
train_dataset, val_dataset, test_dataset = split_by_patient(dataset, split_ratios)

# All three loaders share the batch size and collate function; only the
# training loader shuffles its samples.
loader_options = {"batch_size": 64, "collate_fn": collate_fn_dict}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# STEP 3: define model
# A scikit-learn SVC (with probability estimates enabled and verbose solver
# output) is handed to MLModel, which is configured to use the "conditions"
# and "procedures" fields as inputs and "drugs" as the multilabel target.
svm_classifier = SVC(gamma="auto", probability=True, verbose=1)
model = MLModel(
    dataset=dataset,
    tables=["conditions", "procedures"],
    target="drugs",
    classifier=svm_classifier,
    mode="multilabel",
    output_path="./ckpt/mlmodel",
)

# STEP 4: fit the model on the training loader (no separate Trainer is used)
model.fit(train_loader=train_loader)

(2637, 100) (2637, 228)
[LibSVM]*
optimization finished, #iter = 318
obj = -41.983242, rho = 1.030136
nSV = 131, nBSV = 21
Total nSV = 131
*
optimization finished, #iter = 218
obj = -27.976547, rho = 1.022572
nSV = 97, nBSV = 14
Total nSV = 97
*
optimization finished, #iter = 419
obj = -43.972400, rho = 1.022583
nSV = 133, nBSV = 23
Total nSV = 133
*
optimization finished, #iter = 456
obj = -43.974954, rho = 1.018278
nSV = 151, nBSV = 22
Total nSV = 151
*
optimization finished, #iter = 477
obj = -49.972485, rho = 1.031596
nSV = 157, nBSV = 26
Total nSV = 157
*
optimization finished, #iter = 558
obj = -51.975055, rho = -1.037094
nSV = 192, nBSV = 27
Total nSV = 192
[LibSVM]*
optimization finished, #iter = 10
obj = -5.999798, rho = 1.003417
nSV = 13, nBSV = 4
Total nSV = 13
*
optimization finished, #iter = 19
obj = -5.999687, rho = 1.006933
nSV = 22, nBSV = 3
Total nSV = 22
*
optimization finished, #iter = 9
obj = -3.999861, rho = 1.004008
nSV = 11, nBSV = 3
Total nSV = 11
*
optimization

In [4]:
import warnings

# STEP 5: evaluate
# Warnings are silenced only for the duration of evaluation and scoring.
# A bare warnings.filterwarnings('ignore') would install a global filter and
# hide every warning raised by any later cell for the rest of the session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # Run the fitted model over the test loader; isMLModel=True selects the
    # evaluation path for MLModel instances.
    y_gt, y_prob, y_pred = evaluate(model, test_loader, isMLModel=True)

    # NOTE(review): these print the complete outputs, which can be very long;
    # consider printing a slice or a summary when sharing the notebook.
    print(y_gt)
    print(y_prob)
    print(y_pred)

    # Multilabel metrics: label-based scores use the hard predictions,
    # PR-AUC uses the probability outputs.
    jaccard = jaccard_multilabel(y_gt, y_pred, average="micro")
    accuracy = accuracy_multilabel(y_gt, y_pred)
    f1 = f1_multilabel(y_gt, y_pred, average="micro")
    prauc = pr_auc_multilabel(y_gt, y_prob)

    # print metric name and score
    print("jaccard: ", jaccard)
    print("accuracy: ", accuracy)
    print("f1: ", f1)
    print("prauc: ", prauc)

[array([[0.99022114, 0.00977886],
       [0.99246746, 0.00753254],
       [0.987942  , 0.012058  ],
       [0.9900271 , 0.0099729 ],
       [0.99169599, 0.00830401],
       [0.99131254, 0.00868746],
       [0.98762616, 0.01237384],
       [0.99174998, 0.00825002],
       [0.98700871, 0.01299129],
       [0.9912551 , 0.0087449 ],
       [0.992655  , 0.007345  ],
       [0.99116095, 0.00883905],
       [0.99201274, 0.00798726],
       [0.99196664, 0.00803336],
       [0.99189724, 0.00810276],
       [0.99044331, 0.00955669],
       [0.99032672, 0.00967328],
       [0.99161799, 0.00838201],
       [0.99191414, 0.00808586],
       [0.99207763, 0.00792237],
       [0.99254735, 0.00745265],
       [0.99250588, 0.00749412],
       [0.99209877, 0.00790123],
       [0.99201274, 0.00798726],
       [0.988405  , 0.011595  ],
       [0.99229791, 0.00770209],
       [0.9919307 , 0.0080693 ],
       [0.99259933, 0.00740067],
       [0.99205119, 0.00794881],
       [0.99209252, 0.00790748],
       [0