In [None]:
import os
# Run from the PyHealth checkout so relative paths such as ./data/... resolve.
# Set the PYHEALTH_ROOT environment variable to point elsewhere instead of
# editing this hardcoded, machine-specific default.
PATH = os.environ.get("PYHEALTH_ROOT", '/home/namkyeong/PyHealth')
os.chdir(PATH)

In [None]:
from pyhealth.sampler import NeighborSampler
from pyhealth.models import Graph_TorchvisionModel

from torchvision import transforms
from pyhealth.datasets import COVID19CXRDataset

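## Load Dataset

Load the COVID-19 Radiography images and convert them into a labelled sample dataset for the graph model.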

In [None]:
from pyhealth.datasets import COVID19CXRDataset

# Dataset root, relative to the working directory chosen in the first cell.
root = "./data/COVID-19_Radiography_Dataset"

# Fail fast with an actionable message: a wrong working directory otherwise
# surfaces as an obscure error from deep inside the dataset loader.
if not os.path.isdir(root):
    raise FileNotFoundError(
        f"Dataset root {root!r} not found (cwd: {os.getcwd()!r}); "
        "check the working directory set in the first cell."
    )

base_dataset = COVID19CXRDataset(root)

In [None]:
# Show the task definition this dataset uses by default; the next cell calls
# set_task() with no arguments, which presumably falls back to this task.
base_dataset.default_task

In [None]:
# Apply the task to the raw dataset to obtain the labelled samples that the
# split, model and samplers below all consume.
# NOTE(review): no task is passed, so this presumably uses `default_task` — confirm.
sample_dataset = base_dataset.set_task()

In [None]:
from torchvision import transforms


# Single-channel intensity statistics used by Normalize.
# NOTE(review): presumably the mean/std of the grayscale COVID-19 Radiography
# images — record where these values were computed.
CXR_MEAN = [0.5862785803043838]
CXR_STD = [0.27950088968644304]

# Normalize only accepts tensors, so this pipeline assumes the image arrives
# as a tensor (not a PIL image) — confirm against the dataset's __getitem__.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(),
    transforms.Normalize(mean=CXR_MEAN, std=CXR_STD),
])


def encode(sample):
    """Return a copy of ``sample`` whose ``"path"`` entry has been passed through ``transform``.

    The image is stored under the ``"path"`` key (the model below uses
    ``feature_keys=["path"]``). A new dict is returned rather than mutating
    ``sample``, so that if the dataset hands out its stored dicts, a sample is
    never resized/normalized a second time on a later access.

    Args:
        sample: one sample dict from ``sample_dataset``.

    Returns:
        A new dict with the same keys and the transformed image under ``"path"``.
    """
    return {**sample, "path": transform(sample["path"])}


sample_dataset.set_transform(encode)

In [None]:
from pyhealth.datasets import split_by_sample

# Fixed seed so train/val/test membership is identical on every Restart & Run All.
SPLIT_SEED = 42

# Indices of the train, validation and test samples (70% / 10% / 20%).
train_index, val_index, test_index = split_by_sample(
    dataset=sample_dataset,
    ratios=[0.7, 0.1, 0.2],
    seed=SPLIT_SEED,
    get_index=True,
)

In [None]:
# Torchvision vit_b_16 backbone with its "DEFAULT" pretrained weights, presumably
# followed by a GNN configured by `gnn_config`.
# NOTE(review): how input_dim=256 relates to the backbone's feature size is
# decided inside Graph_TorchvisionModel — confirm it is the expected value.
model = Graph_TorchvisionModel(
    dataset=sample_dataset,
    feature_keys=["path"],
    label_key="label",
    mode="multiclass",
    model_name="vit_b_16",
    model_config={"weights": "DEFAULT"},
    gnn_config={"input_dim": 256, "hidden_dim": 128},
)

In [None]:
# Build the sample graph whose graph["edge_index"] the NeighborSamplers below use.
# random=True builds a randomly connected graph.
# NOTE(review): confirm how edges are formed when random=False.
graph = model.build_graph(sample_dataset, random=True)

In [None]:
# NeighborSampler plays the role of the DataLoader for the graph model.
BATCH_SIZE = 64
NUM_WORKERS = 12
# Neighbours sampled per hop during training (one entry per hop).
TRAIN_SIZES = [15, 10]
# For validation and test, take every edge connected to the target node (-1 = all).
EVAL_SIZES = [-1, -1]


def make_neighbor_sampler(node_idx, sizes, shuffle):
    """Build a NeighborSampler over ``graph`` restricted to the target nodes ``node_idx``.

    Args:
        node_idx: indices of the target samples (from ``split_by_sample``).
        sizes: neighbours to sample per hop; ``-1`` keeps all neighbours.
        shuffle: whether to shuffle target nodes between epochs.

    Returns:
        A NeighborSampler sharing the notebook-wide batch size and worker count.
    """
    return NeighborSampler(
        sample_dataset,
        graph["edge_index"],
        node_idx=node_idx,
        sizes=list(sizes),  # copy so loaders never share one mutable list
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
    )


train_dataloader = make_neighbor_sampler(train_index, TRAIN_SIZES, shuffle=True)
valid_dataloader = make_neighbor_sampler(val_index, EVAL_SIZES, shuffle=False)
test_dataloader = make_neighbor_sampler(test_index, EVAL_SIZES, shuffle=False)

In [None]:
from pyhealth.trainer import Trainer

# Trainer wrapping the graph model built above (model_name="vit_b_16").
# NOTE(review): the "resnet_" prefix is a misnomer — the backbone is a ViT, not a
# ResNet. The name is kept because the evaluation and training cells use it.
resnet_trainer = Trainer(model=model)

In [None]:
# Baseline: score the still-untrained model on the test split so results after
# training have a reference point. Keep the metrics in a named variable and let
# Jupyter render them as the cell output instead of print()-ing plain text.
baseline_test_metrics = resnet_trainer.evaluate(test_dataloader)
baseline_test_metrics

In [None]:
# A single epoch looks like a smoke-test setting; raise it for a real run.
NUM_EPOCHS = 1

resnet_trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=valid_dataloader,
    epochs=NUM_EPOCHS,
    monitor="accuracy",
)

# The earlier test-set evaluation ran before training; report the trained model too.
trained_test_metrics = resnet_trainer.evaluate(test_dataloader)
trained_test_metrics