In [None]:
"""
Environment this notebook is known to run with (other versions untested):
 - torch: 1.9.1
 - torchvision: 0.10.0+cu102
 - torch_sparse: 0.6.12  (presumably needed by the graph neighbour sampler — confirm)
"""

In [1]:
import os
# Directory to run the notebook from (presumably the PyHealth checkout, so that
# `import pyhealth` below resolves to this repo — confirm). Set PYHEALTH_ROOT to
# override the machine-specific default instead of editing this cell.
PATH = os.environ.get("PYHEALTH_ROOT", "/home/chaoqiy2/github/PyHealth")
os.chdir(PATH)

In [2]:
from pyhealth.sampler import NeighborSampler
from pyhealth.models import Graph_TorchvisionModel
from pyhealth.models import GCN
from torchvision import transforms
from pyhealth.datasets import COVID19CXRDataset

  from .autonotebook import tqdm as notebook_tqdm


## Load Dataset

In [3]:
from pyhealth.datasets import COVID19CXRDataset

# Root directory of the COVID-19 Radiography Dataset. Set COVID19_CXR_ROOT to
# override the machine-specific default instead of editing this cell.
root = os.environ.get(
    "COVID19_CXR_ROOT", "/srv/local/data/COVID-19_Radiography_Dataset"
)
base_dataset = COVID19CXRDataset(root)

In [4]:
# Show the dataset's default task: its input/output schemas name the sample
# keys ("path", "label") that the model and transform below rely on.
default_task = base_dataset.default_task
default_task

COVID19CXRClassification(task_name='COVID19CXRClassification', input_schema={'path': 'image'}, output_schema={'label': 'label'})

In [5]:
# Build per-sample records for the dataset's task. No task argument is passed,
# so presumably `default_task` (inspected in the previous cell) is used —
# confirm against COVID19CXRDataset.set_task.
sample_dataset = base_dataset.set_task()

Generating samples for COVID19CXRClassification: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21165/21165 [00:00<00:00, 1282116.21it/s]


In [11]:
from torchvision import transforms


# Normalization statistics. A single value is broadcast across all three
# channels by Normalize. NOTE(review): presumably computed on this dataset's
# grayscale pixel values — confirm provenance.
CXR_MEAN = [0.5862785803043838]
CXR_STD = [0.27950088968644304]


def to_three_channels(image):
    """Return `image` with exactly 3 leading (channel) planes, as ResNet's conv1 expects.

    3-channel inputs pass through unchanged; 1-channel (grayscale) inputs are
    replicated to 3 channels. Any other channel count raises a ValueError
    instead of being silently tiled into 6, 12, ... channels and failing later
    inside the network with a less obvious error.
    """
    channels = image.shape[0]
    if channels == 3:
        return image
    if channels == 1:
        return image.repeat(3, 1, 1)
    raise ValueError(
        f"expected a 1- or 3-channel image, got {channels} channels "
        f"(shape {tuple(image.shape)})"
    )


transform = transforms.Compose([
    transforms.Lambda(to_three_channels),
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=CXR_MEAN, std=CXR_STD),
])


def encode(sample, transform_fn=None):
    """Apply the image transform to a single sample's ``"path"`` entry.

    Args:
        sample: A sample dict. Its ``"path"`` value is passed to the transform
            (presumably the loaded image tensor, given the ``'path': 'image'``
            input schema — confirm).
        transform_fn: The callable to apply. Defaults to the notebook-level
            ``transform`` pipeline.

    Returns:
        A new dict equal to ``sample`` except that ``"path"`` holds the
        transformed value. ``sample`` itself is left unmodified, so if the
        dataset hands out its stored dicts, a later access (e.g. the next
        epoch) will not re-apply the transform to an already-transformed value.
    """
    if transform_fn is None:
        transform_fn = transform
    return {**sample, "path": transform_fn(sample["path"])}


# Register `encode` as the dataset's per-sample transform (presumably applied
# lazily when samples are accessed — confirm against set_transform).
sample_dataset.set_transform(encode)

In [12]:
from pyhealth.datasets import split_by_sample

# Split into train / validation / test (70/10/20). Indices are requested
# (get_index=True) because NeighborSampler below takes node indices into the
# full sample graph. A fixed seed makes the split reproducible across kernel
# restarts.
SPLIT_SEED = 42
train_index, val_index, test_index = split_by_sample(
    dataset=sample_dataset,
    ratios=[0.7, 0.1, 0.2],
    seed=SPLIT_SEED,
    get_index=True,
)

In [13]:
import ssl
# NOTE(security): HTTPS certificate verification is disabled only while the
# model is constructed (presumably so backbone weights can be downloaded on a
# host with a broken CA setup — confirm it is needed at all, since
# model_config is empty). The original context factory is restored in
# `finally`, so later HTTPS traffic in this kernel is verified again.
_default_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
try:
    model = Graph_TorchvisionModel(
        dataset=sample_dataset,
        feature_keys=["path"],
        label_key="label",
        mode="multiclass",
        model_name="resnet18",
        model_config={},
        gnn_config={"input_dim": 256, "hidden_dim": 128},
    )
finally:
    ssl._create_default_https_context = _default_https_context

| Uniform Initialization
| Uniform Initialization


In [14]:
# Build the sample graph consumed by the neighbour samplers below.
# random=True builds a randomly-wired graph (per the original note).
graph = model.build_graph(sample_dataset, random=True)

In [15]:
# NeighborSampler instances act as the train / validation / test dataloaders.
SAMPLER_BATCH_SIZE = 64
SAMPLER_NUM_WORKERS = 12


def make_sampler(node_idx, sizes, shuffle):
    """Build a NeighborSampler over `node_idx` on the global sample graph."""
    return NeighborSampler(
        sample_dataset,
        graph["edge_index"],
        node_idx=node_idx,
        sizes=sizes,
        batch_size=SAMPLER_BATCH_SIZE,
        shuffle=shuffle,
        num_workers=SAMPLER_NUM_WORKERS,
    )


# Training: cap the neighbours sampled per node at 15 and 10 for the two layers.
train_dataloader = make_sampler(train_index, sizes=[15, 10], shuffle=True)

# Validation / test: sizes=[-1, -1] takes every edge connected to each target node.
valid_dataloader = make_sampler(val_index, sizes=[-1, -1], shuffle=False)
test_dataloader = make_sampler(test_index, sizes=[-1, -1], shuffle=False)

In [18]:
from pyhealth.trainer import Trainer

# Training runs on CPU; the recorded epoch took ~18.5 min at ~4.8 s/it.
TRAIN_DEVICE = "cpu"
NUM_EPOCHS = 1

resnet_trainer = Trainer(model=model, device=TRAIN_DEVICE)
resnet_trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=valid_dataloader,
    epochs=NUM_EPOCHS,
    # Best checkpoint is selected (and reloaded) by validation accuracy.
    monitor="accuracy",
)

Graph_TorchvisionModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tr

Epoch 0 / 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 232/232 [18:34<00:00,  4.80s/it]

--- Train epoch-0, step-232 ---
loss: 1.3025



Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [01:19<00:00,  2.34s/it]

--- Eval epoch-0, step-232 ---
accuracy: 0.4872
f1_macro: 0.1643
f1_micro: 0.4872
loss: 1.2535
New best accuracy score (0.4872) at epoch-0, step-232
Loaded best model





In [20]:
# Metrics on the held-out test split.
# NOTE(review): in the recorded run f1_macro (0.1619) equals
# (2*acc / (1 + acc)) / 4 for acc = 0.4787. That is exactly the score of always
# predicting a single class when there are four classes (assuming four labels
# — verify), so the 1-epoch model looks like a majority-class baseline.
test_metrics = resnet_trainer.evaluate(test_dataloader)
print(test_metrics)

Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [02:41<00:00,  2.42s/it]

{'accuracy': 0.4786590097780537, 'f1_macro': 0.1618557783142647, 'f1_micro': 0.4786590097780537, 'loss': 1.256981566770753}



