In [None]:
from pyhealth.datasets import MIMIC4Dataset
if __name__ == "__main__":
    # Everything below runs only when this code is executed as the main module
    # (always true inside a notebook kernel).
    # NOTE(review): the guard presumably exists because set_task is called with
    # num_workers=4 -- confirm before removing it.

    # Cell-local imports are grouped first so the cell's dependencies are
    # visible at a glance.
    from pyhealth.datasets import split_by_sample
    from pyhealth.tasks import InHospitalMortalityMIMIC4

    # Build the MIMIC-IV EHR dataset from the four listed tables.
    # NOTE(review): dev=True presumably restricts loading to a small subset --
    # confirm and switch it off before reporting final numbers.
    dataset = MIMIC4Dataset(
        ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/",
        ehr_tables=["diagnoses_icd", "procedures_icd", "labevents", "prescriptions"],
        dev=True,
        cache_dir="../benchmark_cache/mimic4_ehr_with_prescr/"
    )

    # Apply the in-hospital mortality task to obtain the sample dataset; the
    # task output uses its own cache directory, separate from the dataset cache.
    task = InHospitalMortalityMIMIC4()
    samples = dataset.set_task(task, num_workers=4, cache_dir="../benchmark_cache/mimic4_ihm_w_pre2/")

    # 70% / 10% / 20% train / validation / test split at the sample level.
    # A fixed seed keeps split membership identical across kernel restarts, so
    # the metrics reported further down are comparable between runs.
    train_dataset, val_dataset, test_dataset = split_by_sample(
        dataset=samples, ratios=[0.7, 0.1, 0.2], seed=42
    )




In [None]:
# Display how many samples the task produced.
len(samples)

In [None]:
# Take the first entry of the dataset's unique patient IDs and display the
# events recorded for that patient.
first_patient_id = dataset.unique_patient_ids[0]
dataset.get_patient(first_patient_id).get_events()

In [None]:
# Display the first task sample to check what a single sample looks like.
samples[0]

In [None]:
from pyhealth.datasets import get_dataloader
from pyhealth.models import RNN
from pyhealth.trainer import Trainer

# Wrap each split in a dataloader with batches of 32. Only the training
# loader shuffles; the validation and test loaders keep their order.
train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader, test_dataloader = (
    get_dataloader(held_out_split, batch_size=32, shuffle=False)
    for held_out_split in (val_dataset, test_dataset)
)

# The RNN is constructed from the full sample dataset, not from a single split.
model = RNN(dataset=samples)

# Track ROC-AUC as the only metric.
trainer = Trainer(model=model, metrics=["roc_auc"])

# Evaluate the untrained model on the test loader first, as a baseline.
baseline_scores = trainer.evaluate(test_dataloader)
print(baseline_scores)

# Train for 10 epochs with learning rate 1e-4, monitoring roc_auc.
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=10,
    monitor="roc_auc",
    optimizer_params={"lr": 1e-4},
)
