In [None]:
from pyhealth.datasets import CheXphotoDataset
from pyhealth.tasks import chexphoto_multilabel_task
from pyhealth.datasets import split_by_patient, get_dataloader
from pyhealth.models import CNN
from pyhealth.trainer import Trainer

# Load the CheXphoto dataset and turn it into multilabel task samples.
# NOTE(review): assumes the raw data lives under data/chexphoto — adjust as needed.
dataset = CheXphotoDataset(root="data/chexphoto")
samples = chexphoto_multilabel_task(dataset)

# Experiment configuration, kept in one place instead of repeated literals.
SEED = 42  # fixed so the patient split (and thus reported metrics) is reproducible
BATCH_SIZE = 32
SPLIT_RATIOS = [0.8, 0.1, 0.1]  # train / val / test
EPOCHS = 50
# PyHealth's multilabel metrics are named with an averaging suffix
# (e.g. "roc_auc_macro", "roc_auc_samples"); a bare "roc_auc" is not produced,
# so monitoring it fails with KeyError after the first validation pass.
# Request the metrics explicitly and monitor one that is actually computed.
METRICS = ["roc_auc_macro", "roc_auc_samples", "pr_auc_samples"]
MONITOR = "roc_auc_macro"

# Split by patient so images from one patient never span train/val/test.
train_ds, val_ds, test_ds = split_by_patient(samples, SPLIT_RATIOS, seed=SEED)
train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

# Model and training.
model = CNN(dataset=samples, feature_keys=["input"], label_key="label", mode="multilabel")
trainer = Trainer(model=model, metrics=METRICS)
trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    epochs=EPOCHS,
    monitor=MONITOR,
)

# Report held-out performance explicitly; a bare expression is only echoed
# when it is the last line of a notebook cell, never when run as a script.
test_scores = trainer.evaluate(test_loader)
print(test_scores)
