In [1]:
from torch.utils.data import DataLoader
import sys
sys.path.append("../")

from pyhealth.datasets import MIMIC3Dataset, eICUDataset, MIMIC4Dataset, OMOPDataset
from pyhealth.models import MLModel
from pyhealth.split import split_by_patient
from pyhealth.tasks import (
    mortality_prediction_mimic3_fn,
    mortality_prediction_eicu_fn,
    mortality_prediction_mimic4_fn,
    mortality_prediction_omop_fn,
)
from pyhealth.utils import collate_fn_dict
from pyhealth.trainer import Trainer
from pyhealth.evaluator import evaluate
from pyhealth.metrics import *
from sklearn.svm import SVC

###############
data = "omop"
################

# STEP 1 & 2: load data and set task

if data == "mimic3":
    mimic3dataset = MIMIC3Dataset(
        root="/srv/local/data/physionet.org/files/mimiciii/1.4",
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS", "LABEVENTS"],
        dev=True,
        code_mapping={"PRESCRIPTIONS": "ATC3"},
        refresh_cache=True,
    )
    mimic3dataset.stat()
    mimic3dataset.set_task(drug_recommendation_mimic3_fn)
    mimic3dataset.stat()
    dataset = mimic3dataset

elif data == "eicu":
    eicudataset = eICUDataset(
        root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
        tables=["diagnosis", "medication", "physicalExam"],
        dev=True,
        refresh_cache=False,
    )
    eicudataset.stat()
    eicudataset.set_task(task_fn=drug_recommendation_eicu_fn)
    eicudataset.stat()
    dataset = eicudataset

elif data == "mimic4":
    mimic4dataset = MIMIC4Dataset(
        root="/srv/local/data/physionet.org/files/mimiciv/2.0/hosp",
        tables=["diagnoses_icd", "procedures_icd", "prescriptions"],
        dev=True,
        code_mapping={"prescriptions": "ATC3"},
        refresh_cache=False,
    )
    mimic4dataset.stat()
    mimic4dataset.set_task(task_fn=drug_recommendation_mimic4_fn)
    mimic4dataset.stat()
    dataset = mimic4dataset

elif data == "omop":
    omopdataset = OMOPDataset(
        root="/srv/local/data/zw12/pyhealth/raw_data/synpuf1k_omop_cdm_5.2.2",
        tables=[
            "condition_occurrence",
            "procedure_occurrence",
            "drug_exposure",
            "measurement",
        ],
        dev=True,
        refresh_cache=False,
    )
    omopdataset.stat()
    omopdataset.set_task(task_fn=mortality_prediction_omop_fn)
    omopdataset.stat()
    dataset = omopdataset


  from .autonotebook import tqdm as notebook_tqdm


Loaded OMOP base dataset from /home/pj20/.cache/pyhealth/datasets/5721324acc404b87cb3e2abaa51618c5.pkl

Statistics of OMOP dataset (dev=True):
	- Number of patients: 1000
	- Number of visits: 55261
	- Number of visits per patient: 55.2610
	- Number of condition_occurrence per visit: 2.6635
	- Number of procedure_occurrence per visit: 2.4886
	- Number of drug_exposure per visit: 0.1387
	- Number of measurement per visit: 0.6253



Generating samples for mortality_prediction_omop_fn: 100%|██████████████████| 1000/1000 [00:00<00:00, 1187.32it/s]



Statistics of mortality_prediction_omop_fn task:
	- Dataset: OMOP (dev=True)
	- Number of patients: 903
	- Number of visits: 54147
	- Number of visits per patient: 59.9635
	- Number of visit_id per visit: 1.0000
	- Number of unique visit_id: 54147
	- Number of patient_id per visit: 1.0000
	- Number of unique patient_id: 903
	- Number of conditions per visit: 2.1424
	- Number of unique conditions: 5067
	- Number of procedures per visit: 2.0873
	- Number of unique procedures: 3149
	- Number of drugs per visit: 0.0974
	- Number of unique drugs: 238
	- Label distribution: {0: 54145, 1: 2}



In [2]:
dataset[6]

{'visit_id': '316',
 'patient_id': '10',
 'conditions': [['194286']],
 'procedures': [['2211511']],
 'drugs': [[]],
 'label': 0}

In [3]:
# data split
train_dataset, val_dataset, test_dataset = split_by_patient(dataset, [0.8, 0.1, 0.1])
train_loader = DataLoader(
    train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn_dict
)
val_loader = DataLoader(
    val_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn_dict
)
test_loader = DataLoader(
    test_dataset, batch_size=64, shuffle=False, collate_fn=collate_fn_dict
)

# STEP 3: define model

model = MLModel(
    dataset=dataset,
    tables=["conditions", "procedures", "drugs"],
    target="label",
    classifier = SVC(gamma='auto', probability=True ,verbose=1),
    mode="binary",
    output_path="./ckpt/mlmodel"
)

# STEP 4: define trainer
model.fit(train_loader=train_loader)

(44147, 100) (44147,)
[LibSVM]*
optimization finished, #iter = 9
obj = -1.993617, rho = 1.009045
nSV = 8, nBSV = 1
Total nSV = 8
*
optimization finished, #iter = 36
obj = -3.997707, rho = 1.008881
nSV = 23, nBSV = 2
Total nSV = 23
*
optimization finished, #iter = 49
obj = -3.995712, rho = 1.011013
nSV = 32, nBSV = 2
Total nSV = 32
*
optimization finished, #iter = 1
obj = -2.000000, rho = 1.000000
nSV = 2, nBSV = 2
Total nSV = 2
*
optimization finished, #iter = 38
obj = -3.997697, rho = 1.008371
nSV = 24, nBSV = 2
Total nSV = 24
*
optimization finished, #iter = 45
obj = -3.997733, rho = -1.011456
nSV = 35, nBSV = 2
Total nSV = 35
best_model_path: ./ckpt/mlmodel/1665548508.148289/best.ckpt


In [6]:
import warnings
warnings.filterwarnings('ignore')
# STEP 5: evaluate
y_gt, y_prob, y_pred = evaluate(model, test_loader, isMLModel=True)

print(y_gt)
print(y_prob)
print(y_pred)

jaccard = jaccard_score(y_gt, y_pred, average="micro")
accuracy = accuracy_score(y_gt, y_pred)
f1 = f1_score(y_gt, y_pred, average="micro")

# print metric name and score
print("jaccard: ", jaccard)
print("accuracy: ", accuracy)
print("f1: ", f1)

[[9.99947695e-01 5.23047428e-05]
 [9.99946573e-01 5.34271717e-05]
 [9.99950670e-01 4.93303890e-05]
 ...
 [9.99949818e-01 5.01818000e-05]
 [9.99947924e-01 5.20760967e-05]
 [9.99946688e-01 5.33122121e-05]] <class 'numpy.ndarray'>
[[1. 1.]
 [0. 0.]
 [0. 0.]
 ...
 [0. 0.]
 [0. 0.]
 [0. 0.]]
[[9.99947695e-01 5.23047428e-05]
 [9.99946573e-01 5.34271717e-05]
 [9.99950670e-01 4.93303890e-05]
 ...
 [9.99949818e-01 5.01818000e-05]
 [9.99947924e-01 5.20760967e-05]
 [9.99946688e-01 5.33122121e-05]]
[[1 0]
 [1 0]
 [1 0]
 ...
 [1 0]
 [1 0]
 [1 0]]
jaccard:  0.00021853146853146853
accuracy:  0.0
f1:  0.00043696744592527855


In [None]:
y_gt = np.zeros([len(y_true), 2])

In [None]:
for i in range(len(y_gt)):
    y_gt[y_true[i]] = 1
    