In [1]:
from torch.utils.data import DataLoader
import sys
sys.path.append("../")

from pyhealth.datasets import MIMIC3Dataset, eICUDataset, MIMIC4Dataset, OMOPDataset
from pyhealth.models import MLModel
from pyhealth.split import split_by_patient
from pyhealth.tasks import (
    drug_recommendation_mimic3_fn,
    drug_recommendation_eicu_fn,
    drug_recommendation_mimic4_fn,
    drug_recommendation_omop_fn,
    readmission_prediction_mimic3_fn,
)
from pyhealth.utils import collate_fn_dict
from pyhealth.trainer import Trainer
from pyhealth.evaluator import evaluate
from pyhealth.metrics import *
from sklearn.svm import SVC

###############
# Dataset selector: must be one of "mimic3", "eicu", "mimic4", "omop".
data = "omop"
###############

# STEP 1 & 2: load data and set task
#
# Every branch follows the same recipe: build the source-specific dataset,
# print its statistics, attach the drug-recommendation task function for that
# source, print the statistics again, and expose the result under the common
# name `dataset`, which the following cells use.

if data == "mimic3":
    # NOTE(review): this is the only branch that passes refresh_cache=True (the
    # others pass False) and the only one that passes the task function
    # positionally — confirm both differences are intentional.
    mimic3dataset = MIMIC3Dataset(
        root="/srv/local/data/physionet.org/files/mimiciii/1.4",
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS", "LABEVENTS"],
        dev=True,
        code_mapping={"PRESCRIPTIONS": "ATC3"},
        refresh_cache=True,
    )
    mimic3dataset.stat()
    mimic3dataset.set_task(drug_recommendation_mimic3_fn)
    mimic3dataset.stat()
    dataset = mimic3dataset

elif data == "eicu":
    eicudataset = eICUDataset(
        root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
        tables=["diagnosis", "medication", "physicalExam"],
        dev=True,
        refresh_cache=False,
    )
    eicudataset.stat()
    eicudataset.set_task(task_fn=drug_recommendation_eicu_fn)
    eicudataset.stat()
    dataset = eicudataset

elif data == "mimic4":
    mimic4dataset = MIMIC4Dataset(
        root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp",
        tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
        dev=True,
        code_mapping={"prescriptions": "ATC3"},
        refresh_cache=False,
    )
    mimic4dataset.stat()
    mimic4dataset.set_task(task_fn=drug_recommendation_mimic4_fn)
    mimic4dataset.stat()
    dataset = mimic4dataset

elif data == "omop":
    omopdataset = OMOPDataset(
        root="/srv/local/data/zw12/pyhealth/raw_data/synpuf1k_omop_cdm_5.2.2",
        tables=[
            "condition_occurrence",
            "procedure_occurrence",
            "drug_exposure",
            "measurement",
        ],
        dev=True,
        refresh_cache=False,
    )
    omopdataset.stat()
    omopdataset.set_task(task_fn=drug_recommendation_omop_fn)
    omopdataset.stat()
    dataset = omopdataset

else:
    # Without this branch a typo in `data` would leave `dataset` unbound and
    # only surface as a NameError in a later cell.
    raise ValueError(
        f"Unsupported data source {data!r}; "
        "expected one of: 'mimic3', 'eicu', 'mimic4', 'omop'"
    )


  from .autonotebook import tqdm as notebook_tqdm


Loaded OMOP base dataset from /home/pj20/.cache/pyhealth/5721324acc404b87cb3e2abaa51618c5.pkl

Statistics of OMOP dataset (dev=True):
	- Number of patients: 1000
	- Number of visits: 55261
	- Number of visits per patient: 55.2610
	- Number of condition_occurrence per visit: 2.6635
	- Number of procedure_occurrence per visit: 2.4886
	- Number of drug_exposure per visit: 0.1387
	- Number of measurement per visit: 0.6253



Generating samples for drug_recommendation_omop_fn: 100%|███████████████████| 1000/1000 [00:00<00:00, 3398.50it/s]


Statistics of drug_recommendation_omop_fn task:
	- Dataset: OMOP (dev=True)
	- Number of patients: 633
	- Number of visits: 3342
	- Number of visits per patient: 5.2796
	- Number of visit_id per visit: 1.0000
	- Number of unique visit_id: 3342
	- Number of patient_id per visit: 1.0000
	- Number of unique patient_id: 633
	- Number of conditions per visit: 2.3372
	- Number of unique conditions: 1715
	- Number of procedures per visit: 2.8169
	- Number of unique procedures: 1173
	- Number of drugs per visit: 1.2959
	- Number of unique drugs: 230
	- Label distribution: {'2213440': 606, '2213483': 90, '46275982': 44, '40950844': 127, '787787': 28, '40836918': 164, '40220869': 23, '40018375': 4, '19081294': 20, '40102878': 23, '1305085': 7, '1304850': 32, '19081616': 30, '1310317': 4, '529411': 21, '529303': 21, '40080069': 203, '44817887': 14, '2718906': 3, '40227542': 77, '2718836': 13, '1518606': 162, '40100282': 7, '43532421': 19, '35603172': 10, '1551192': 3, '1560524': 15, '41083320': 




In [2]:
# Display the sample at index 6 of the task dataset to check the sample structure.
dataset[6]

{'visit_id': '50231',
 'patient_id': '1002',
 'conditions': [['79740', '437798'], ['200051', '439777', '439926']],
 'procedures': [['2314227', '2314209', '2314231', '2414398', '2314229'],
  ['2414398', '2314225', '2314232', '2314209']],
 'drugs': [['40220869', '40018375', '19081294']],
 'label': ['40102878']}

In [3]:
# Split the task dataset into train / validation / test parts (ratios 0.8 / 0.1 / 0.1).
train_dataset, val_dataset, test_dataset = split_by_patient(dataset, [0.8, 0.1, 0.1])

# Options common to all three loaders; only the training loader shuffles.
loader_options = {"batch_size": 64, "collate_fn": collate_fn_dict}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# STEP 3: define model
# A scikit-learn SVC with probability estimates enabled is wrapped by MLModel,
# which reads the "conditions" and "procedures" tables and targets "drugs"
# in multilabel mode.
svm_classifier = SVC(gamma="auto", probability=True, verbose=1)
model = MLModel(
    dataset=dataset,
    tables=["conditions", "procedures"],
    target="drugs",
    classifier=svm_classifier,
    mode="multilabel",
    output_path="./ckpt/mlmodel",
)

# STEP 4: fit the model on the training loader (the model is fitted directly;
# no Trainer object is created in this cell).
model.fit(train_loader=train_loader)

(2674, 2892) (2674, 228)
[LibSVM]*
optimization finished, #iter = 325
obj = -37.986570, rho = 1.026273
nSV = 139, nBSV = 20
Total nSV = 139
*
optimization finished, #iter = 395
obj = -39.980175, rho = 1.013967
nSV = 145, nBSV = 20
Total nSV = 145
*
optimization finished, #iter = 331
obj = -39.981799, rho = 1.019714
nSV = 135, nBSV = 21
Total nSV = 135
*
optimization finished, #iter = 332
obj = -37.979941, rho = 1.026671
nSV = 120, nBSV = 21
Total nSV = 120
*
optimization finished, #iter = 400
obj = -35.976727, rho = 1.039763
nSV = 124, nBSV = 18
Total nSV = 124
*
optimization finished, #iter = 623
obj = -47.980845, rho = -1.028772
nSV = 201, nBSV = 24
Total nSV = 201
[LibSVM]*
optimization finished, #iter = 10
obj = -5.999875, rho = 1.002625
nSV = 13, nBSV = 5
Total nSV = 13
*
optimization finished, #iter = 17
obj = -3.999689, rho = 1.004839
nSV = 17, nBSV = 2
Total nSV = 17
*
optimization finished, #iter = 8
obj = -3.999878, rho = 1.001699
nSV = 10, nBSV = 2
Total nSV = 10
*
optimizat

In [4]:
import warnings

# STEP 5: evaluate
# Warnings are silenced only while evaluating and scoring. The previous filter
# state is restored when the `with` block exits, so later cells still surface
# their own warnings instead of inheriting a session-wide "ignore" filter.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # Ground-truth labels, predicted probabilities and hard predictions on the test loader.
    y_gt, y_prob, y_pred = evaluate(model, test_loader, isMLModel=True)

    print(y_gt)
    print(y_prob)
    print(y_pred)

    jaccard = jaccard_multilabel(y_gt, y_pred, average="micro")
    accuracy = accuracy_multilabel(y_gt, y_pred)
    f1 = f1_multilabel(y_gt, y_pred, average="micro")
    prauc = pr_auc_multilabel(y_gt, y_prob)

# NOTE(review): in the recorded output of this cell the visible entries of
# y_pred are all zero, yet jaccard/accuracy/f1 are ~0.99 (accuracy and f1 are
# identical) while prauc is ~0.23. The first three may mostly reflect label
# sparsity — verify how these metrics aggregate before reporting them.

# print metric name and score
print("jaccard: ", jaccard)
print("accuracy: ", accuracy)
print("f1: ", f1)
print("prauc: ", prauc)

[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0.00785653 0.0014507  0.00087168 ... 0.00024216 0.0007744  0.00165535]
 [0.00838695 0.00139299 0.00086467 ... 0.00028918 0.00080275 0.00126916]
 [0.00809015 0.00132017 0.00082316 ... 0.00032088 0.0007158  0.00118882]
 ...
 [0.00799931 0.00131159 0.00082727 ... 0.00032521 0.00071673 0.0011999 ]
 [0.00765453 0.00147987 0.00085456 ... 0.00032776 0.00079606 0.00124576]
 [0.00803738 0.00130101 0.00082438 ... 0.00031103 0.00071673 0.00119886]]
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
jaccard:  0.9909464756923241
accuracy:  0.9954347346353304
f1:  0.9954347346353304
prauc:  0.233433581664978
