# Multilabel Classification Using the ChestX-ray14 Dataset

## Step 0: Install PyHealth

In [1]:
!git clone https://github.com/EricSchrock/PyHealth.git
!cd PyHealth && git checkout ChestX-ray14 && pip install -e .

Cloning into 'PyHealth'...
remote: Enumerating objects: 8126, done.
remote: Counting objects: 100% (1718/1718), done.
remote: Compressing objects: 100% (530/530), done.
remote: Total 8126 (delta 1503), reused 1199 (delta 1188), pack-reused 6408 (from 2)
Receiving objects: 100% (8126/8126), 113.90 MiB | 16.80 MiB/s, done.
Resolving deltas: 100% (5260/5260), done.
Branch 'ChestX-ray14' set up to track remote branch 'ChestX-ray14' from 'origin'.
Switched to a new branch 'ChestX-ray14'
Obtaining file:///content/PyHealth
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Installing backend dependencies ... done
  Preparing editable metadata (pyproject.toml) ... done
Building wheels for collected packages: pyhealth
  Building editable for pyhealth (pyproject.toml) ... done
  Created wheel for pyhealth: filen

## Step 1: Load Dataset

In [1]:
from pyhealth.datasets import ChestXray14Dataset

dataset = ChestXray14Dataset(download=True, partial=True)
dataset.stats()

Downloading ./images_01.tar.gz...
Checking MD5 checksum for ./images_01.tar.gz...
Extracting ./images_01.tar.gz...
Deleting ./images_01.tar.gz...
Download complete
Initializing ChestX-ray14 dataset from . (dev mode: False)
Scanning table: chestxray14 from /content/chestxray14-metadata-pyhealth.csv
Collecting global event dataframe...
Collected dataframe with shape: (4999, 26)

Dataset: ChestX-ray14
Dev mode: False
Number of patients: 1335
Number of events: 4999
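The download log shows the loader checking an MD5 checksum for `images_01.tar.gz` before extracting it. A minimal sketch of that kind of check with Python's standard `hashlib`, reading the file in chunks so a multi-gigabyte archive never has to fit in memory. The helper names `md5_of_file` and `verify_md5` are hypothetical, and this is not PyHealth's own implementation:

```python
import hashlib


def md5_of_file(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, read chunk_size bytes at a time."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a b"" sentinel keeps calling f.read() until end of file
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_md5(path, expected_hex):
    """Hypothetical helper: True if the file's MD5 matches the expected hex digest."""
    return md5_of_file(path) == expected_hex.lower()
```

A loader would call something like `verify_md5(archive_path, known_digest)` after downloading and refuse to extract on a mismatch, which catches truncated or corrupted downloads before they turn into unreadable images.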


## Step 2: Define Task

In [2]:
samples = dataset.set_task()

Setting task ChestXray14MultilabelClassification for ChestX-ray14 base dataset...
Generating samples with 1 worker(s)...
Generating samples for ChestXray14MultilabelClassification with 1 worker: 100%|██████████| 1335/1335 [00:00<00:00, 1475.55it/s]
Label labels vocab: {'atelectasis': 0, 'cardiomegaly': 1, 'consolidation': 2, 'edema': 3, 'effusion': 4, 'emphysema': 5, 'fibrosis': 6, 'hernia': 7, 'infiltration': 8, 'mass': 9, 'nodule': 10, 'pleural_thickening': 11, 'pneumonia': 12, 'pneumothorax': 13}
Processing samples: 100%|██████████| 4999/4999 [01:18<00:00, 63.31it/s]
Generated 4999 samples for task ChestXray14MultilabelClassification
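Each image becomes a sample whose target is a 14-dimensional multi-hot vector, one position per entry in the label vocabulary printed above. A minimal sketch of that encoding, assuming the raw findings arrive as a `|`-separated string in the style of the original NIH `Finding Labels` column (for example `Cardiomegaly|Effusion`, with `No Finding` meaning every label is 0). The function `to_multi_hot` is hypothetical, and PyHealth's label processor may work differently internally:

```python
# Same order as the vocab in the log above: alphabetical, indices 0..13
LABELS = [
    "atelectasis", "cardiomegaly", "consolidation", "edema", "effusion",
    "emphysema", "fibrosis", "hernia", "infiltration", "mass", "nodule",
    "pleural_thickening", "pneumonia", "pneumothorax",
]
VOCAB = {name: i for i, name in enumerate(LABELS)}


def to_multi_hot(finding_labels: str) -> list[int]:
    """Hypothetical helper: map 'Cardiomegaly|Effusion' to a 14-element 0/1 list."""
    target = [0] * len(VOCAB)
    for finding in finding_labels.split("|"):
        key = finding.strip().lower()
        if key == "no finding":
            continue  # the normal case leaves every position at 0
        target[VOCAB[key]] = 1  # raises KeyError for a finding outside the vocab
    return target


print(to_multi_hot("Cardiomegaly|Effusion"))
# [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```

Unlike a multiclass target, a row can contain several 1s, which is why multilabel models typically score each label with its own sigmoid rather than a single softmax over all 14.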


In [3]:
from pyhealth.datasets import get_dataloader, split_by_sample

train_dataset, val_dataset, test_dataset = split_by_sample(samples, [0.7, 0.1, 0.2])

train_loader = get_dataloader(train_dataset, batch_size=16, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=16, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)
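`split_by_sample` shuffles individual samples and cuts them by the given ratios. A sketch of the same idea in plain Python, assuming the cut points are taken with `int()` on `n * cumulative_ratio` (an assumption about the rounding, not a copy of PyHealth's code). It also accounts for the batch counts in the logs that follow: 219 training steps and 32 validation and 63 test batches at a batch size of 16:

```python
import math
import random


def split_indices(n, ratios, seed=0):
    """Shuffle range(n) and cut it into train/val/test lists by the given ratios."""
    assert abs(sum(ratios) - 1.0) < 1e-9
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    cut1 = int(n * ratios[0])
    cut2 = int(n * (ratios[0] + ratios[1]))
    return idx[:cut1], idx[cut1:cut2], idx[cut2:]


def num_batches(n, batch_size):
    """Batches per epoch when the last, partial batch is kept."""
    return math.ceil(n / batch_size)


train, val, test = split_indices(4999, [0.7, 0.1, 0.2])
print(len(train), len(val), len(test))                        # 3499 500 1000
print([num_batches(len(s), 16) for s in (train, val, test)])  # [219, 32, 63]
```

Because the split is over samples (4999 images) rather than patients (1335), images of the same patient can land in both the training and test sets.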

## Step 3: Define Model

In [4]:
from pyhealth.models import CNN

model = CNN(dataset=samples)
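The `CNN` printout in the Step 4 training log shows one residual block for the single-channel `image` input: a 3x3 conv, BatchNorm and ReLU; a second 3x3 conv and BatchNorm; a 1x1 `downsample` projection on the skip path; a final ReLU; global average pooling; and a linear head. A minimal PyTorch sketch of that shape, written independently of PyHealth. The class names are hypothetical, the 14 outputs assume one logit per vocabulary label (the printout's `out_features` is cut off), and the loss assumes the usual multilabel setup of an independent sigmoid plus binary cross-entropy per label:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Two 3x3 convs with a 1x1 projection on the skip path when channels change."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
        )
        self.downsample = (
            nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=1), nn.BatchNorm2d(out_ch))
            if in_ch != out_ch
            else nn.Identity()
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv2(self.conv1(x)) + self.downsample(x))


class TinyChestCNN(nn.Module):
    def __init__(self, num_labels: int = 14, hidden: int = 128):
        super().__init__()
        self.block = ResidualBlock(1, hidden)  # grayscale X-ray: 1 input channel
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(hidden, num_labels)

    def forward(self, x):
        h = self.pool(self.block(x)).flatten(1)  # (batch, hidden)
        return self.fc(h)  # (batch, num_labels) raw logits


model_sketch = TinyChestCNN()
images = torch.randn(4, 1, 64, 64)
targets = torch.randint(0, 2, (4, 14)).float()
logits = model_sketch(images)
loss = F.binary_cross_entropy_with_logits(logits, targets)  # one BCE term per label
probs = torch.sigmoid(logits)  # independent per-label probabilities
```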

## Step 4: Train Model

In [5]:
from pyhealth.trainer import Trainer

# Only measure accuracy: with the "partial" dataset, some labels are likely to have
# no positive samples in the validation and test sets, which leaves metrics such as
# ROC AUC undefined for those labels
trainer = Trainer(model=model, metrics=["accuracy"])
trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, epochs=1)

CNN(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())
  (cnn): ModuleDict(
    (image): CNNLayer(
      (cnn): ModuleList(
        (0): CNNBlock(
          (conv1): Sequential(
            (0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
          (conv2): Sequential(
            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (downsample): Sequential(
            (0): Conv2d(1, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (relu): ReLU()
        )
      )
      (pooling): AdaptiveAvgPool2d(output_size=1)
    )
  )
  (fc): Linear(in_features=128, out_features=1

Metrics: ['accuracy']
Device: cuda

Training:
Batch size: 16
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7b047ae4d700>
Monitor: None
Monitor criterion: max
Epochs: 1
Patience: None

Epoch 0 / 1:   0%|          | 0/219 [00:00<?, ?it/s]

--- Train epoch-0, step-219 ---
loss: 0.2041
Evaluation: 100%|██████████| 32/32 [00:02<00:00, 14.63it/s]
--- Eval epoch-0, step-219 ---
accuracy: 0.9553
loss: 0.1695
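A validation accuracy around 0.95 after one epoch says little on its own. If the metric is computed over every (image, label) pair, which is an assumption about how PyHealth scores multilabel accuracy, then the many absent findings dominate it: a model that predicts "absent" for every label scores 1 minus the fraction of positive entries. A small sketch with made-up labels (not from this dataset) and a hypothetical 0.5 threshold, contrasting that element-wise accuracy with the stricter exact-match (subset) accuracy:

```python
import numpy as np


def elementwise_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of (sample, label) entries where the thresholded prediction is correct."""
    y_pred = (y_prob >= threshold).astype(int)
    return float((y_pred == y_true).mean())


def subset_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of samples whose entire label vector is predicted exactly."""
    y_pred = (y_prob >= threshold).astype(int)
    return float((y_pred == y_true).all(axis=1).mean())


# Toy targets: 4 samples x 5 labels, only 2 positives among 20 entries
y_true = np.array([
    [1, 0, 0, 0, 0],
    [0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0],
    [0, 0, 0, 0, 0],
])
all_negative = np.zeros_like(y_true, dtype=float)  # never predicts any finding

print(elementwise_accuracy(y_true, all_negative))  # 0.9 (18 of 20 entries correct)
print(subset_accuracy(y_true, all_negative))       # 0.5 (2 of 4 rows fully correct)
```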


## Step 5: Evaluate Model

In [6]:
trainer.evaluate(test_loader)

Evaluation: 100%|██████████| 63/63 [00:04<00:00, 14.36it/s]


{'accuracy': 0.9500714285714286, 'loss': 0.17985984730342078}
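The comment in Step 4 explains why only accuracy is tracked: ROC AUC for a label is undefined when a split has no positive (or no negative) examples of it, and scikit-learn's `roc_auc_score` raises a `ValueError` in that case. One workaround, sketched here with NumPy under the assumption that the test labels and predicted probabilities are available as `(n_samples, n_labels)` arrays, is to compute AUC per label and macro-average only over labels where both classes appear. AUC is computed as the probability that a random positive is scored above a random negative, with ties counting half. The function names, label subset and toy data are hypothetical:

```python
import numpy as np


def binary_auc(y, scores):
    """ROC AUC as P(score_pos > score_neg), counting ties as 0.5."""
    pos = scores[y == 1]
    neg = scores[y == 0]
    diff = pos[:, None] - neg[None, :]  # every positive/negative pair
    return float(((diff > 0).sum() + 0.5 * (diff == 0).sum()) / (len(pos) * len(neg)))


def macro_auc_defined_labels(y_true, y_prob, label_names):
    """Per-label AUC for labels with both classes present, plus their macro mean."""
    per_label = {}
    for j, name in enumerate(label_names):
        y = y_true[:, j]
        if y.min() == y.max():
            continue  # only one class in this split: AUC undefined, skip the label
        per_label[name] = binary_auc(y, y_prob[:, j])
    macro = float(np.mean(list(per_label.values()))) if per_label else float("nan")
    return per_label, macro


y_true = np.array([[1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1]])
y_prob = np.array([[0.9, 0.1, 0.5], [0.3, 0.2, 0.4], [0.6, 0.1, 0.3], [0.2, 0.3, 0.8]])
per_label, macro = macro_auc_defined_labels(y_true, y_prob, ["mass", "hernia", "edema"])
print(per_label, macro)  # {'mass': 1.0, 'edema': 0.75} 0.875  ('hernia' has no positives)
```

The pairwise comparison uses memory proportional to positives times negatives per label, which is fine for a test split of about 1000 images; a rank-based formula scales better for larger sets.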