# 1. Environment Setup
Seed the random generators, import core dependencies, and detect the training device.

In [1]:
!pip install openpyxl



In [2]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from IPython.display import display

from pyhealth.datasets import COVID19CXRDataset
from pyhealth.datasets.splitter import split_by_sample
from pyhealth.datasets.utils import get_dataloader
from pyhealth.tasks.covid19_cxr_classification import COVID19CXRClassification

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

Running on device: cuda


# 2. Load COVID-19 CXR Metadata
Point to the processed COVID-19 Radiography dataset root and trigger metadata preparation if necessary.

In [3]:
dataset = COVID19CXRDataset(
    root="/home/logic/Github/cxr",
)
dataset.stats()

No config path provided, using default config
Initializing covid19_cxr dataset from /home/logic/Github/cxr (dev mode: False)
Scanning table: covid19_cxr from /home/logic/Github/cxr/covid19_cxr-metadata-pyhealth.csv
Collecting global event dataframe...
Collected dataframe with shape: (21165, 6)
Dataset: covid19_cxr
Dev mode: False
Number of patients: 21165
Number of events: 21165


# 3. Prepare PyHealth Dataset
Instantiate the COVID-19 classification task, convert raw samples into PyHealth format, and confirm schema details.

In [4]:
task = COVID19CXRClassification()
sample_dataset = dataset.set_task(task)

print(f"Total task samples: {len(sample_dataset)}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

if len(sample_dataset) == 0:
    raise RuntimeError("The task did not produce any samples. Verify the dataset root or disable dev mode.")

label_processor = sample_dataset.output_processors["disease"]
IDX_TO_LABEL = {index: label for label, index in label_processor.label_vocab.items()}
print(f"Label mapping (index -> name): {IDX_TO_LABEL}")

# Build label histogram to confirm class balance
label_indices = [sample_dataset[i]["disease"].item() for i in range(len(sample_dataset))]
label_distribution = (
    pd.Series(label_indices)
    .map(IDX_TO_LABEL)
    .value_counts()
    .sort_index()
    .to_frame(name="count")
)
label_distribution["proportion"] = label_distribution["count"] / label_distribution["count"].sum()
display(label_distribution)

Setting task COVID19CXRClassification for covid19_cxr base dataset...
Generating samples with 1 worker(s)...


Generating samples for COVID19CXRClassification with 1 worker: 100%|██████████| 21165/21165 [00:05<00:00, 3703.62it/s]

Label disease vocab: {'COVID': 0, 'Lung Opacity': 1, 'Normal': 2, 'Viral Pneumonia': 3}



Processing samples: 100%|██████████| 21165/21165 [00:22<00:00, 945.33it/s]

Generated 21165 samples for task COVID19CXRClassification





Total task samples: 21165
Input schema: {'image': 'image'}
Output schema: {'disease': 'multiclass'}
Label mapping (index -> name): {0: 'COVID', 1: 'Lung Opacity', 2: 'Normal', 3: 'Viral Pneumonia'}


Unnamed: 0,count,proportion
COVID,3616,0.170848
Lung Opacity,6012,0.284054
Normal,10192,0.48155
Viral Pneumonia,1345,0.063548


# 4. Split Dataset
Divide the processed samples into training, validation, and test subsets before building dataloaders.

In [5]:
BATCH_SIZE = 32

train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.7, 0.1, 0.2], seed=SEED)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE) if len(val_ds) else None
test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE) if len(test_ds) else None

if len(train_loader) == 0:
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the split ratios.")

Train/Val/Test sizes: 14815, 2117, 4233


# 5. Inspect Batch Structure
Peek at the first training batch to understand feature shapes and label encodings.

In [6]:
first_batch = next(iter(train_loader))

def describe(value):
    if hasattr(value, "shape"):
        return f"{type(value).__name__}(shape={tuple(value.shape)})"
    if isinstance(value, (list, tuple)):
        return f"{type(value).__name__}(len={len(value)})"
    return type(value).__name__

batch_summary = {key: describe(value) for key, value in first_batch.items()}
print(batch_summary)

disease_targets = first_batch["disease"]
preview_indices = disease_targets[:5].cpu().tolist()
preview_labels = [IDX_TO_LABEL[idx] for idx in preview_indices]
print(f"Sample disease labels: {list(zip(preview_indices, preview_labels))}")

{'image': 'Tensor(shape=(32, 1, 299, 299))', 'disease': 'Tensor(shape=(32,))'}
Sample disease labels: [(2, 'Normal'), (2, 'Normal'), (2, 'Normal'), (2, 'Normal'), (0, 'COVID')]


# 6. Instantiate CNN Model
Create the PyHealth CNN with image embeddings and review its parameter footprint.

In [7]:
from pyhealth.models import CNN

model = CNN(
    dataset=sample_dataset,
    embedding_dim=64,
    hidden_dim=64,
    num_layers=2,
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Feature keys: {model.feature_keys}")
print(f"Label key: {model.label_key}")
print(f"Model mode: {model.mode}")
print(f"Total parameters: {total_params:,}")

  from .autonotebook import tqdm as notebook_tqdm


Feature keys: ['image']
Label key: disease
Model mode: multiclass
Total parameters: 112,452


# 7. Configure Trainer
Wrap the model with the PyHealth Trainer and define optimisation hyperparameters and metrics.

In [8]:
from pyhealth.trainer import Trainer

trainer = Trainer(
    model=model,
    metrics=["accuracy", "f1_macro", "f1_micro"],
    device=str(device),
    enable_logging=False,
 )

training_config = {
    "epochs": 3,
    "optimizer_params": {"lr": 1e-3},
    "max_grad_norm": 5.0,
    "monitor": "accuracy",
    "monitor_criterion": "max",
}

CNN(
  (embedding_model): EmbeddingModel(embedding_layers=ModuleDict())
  (cnn): ModuleDict(
    (image): CNNLayer(
      (cnn): ModuleList(
        (0): CNNBlock(
          (conv1): Sequential(
            (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
          (conv2): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (downsample): Sequential(
            (0): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (relu): ReLU()
        )
        (1): CNNBlock(
          (conv1): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), 

# 8. Train the Model
Launch the training loop with optional validation monitoring for early diagnostics.

In [9]:
train_kwargs = dict(training_config)
if val_loader is None:
    train_kwargs.pop("monitor", None)
    train_kwargs.pop("monitor_criterion", None)

trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    **train_kwargs,
 )

Training:
Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
Weight decay: 0.0
Max grad norm: 5.0
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7fac7a919a90>
Monitor: accuracy
Monitor criterion: max
Epochs: 3



Epoch 0 / 3: 100%|██████████| 463/463 [02:12<00:00,  3.50it/s]

--- Train epoch-0, step-463 ---
loss: 0.9918



Evaluation: 100%|██████████| 67/67 [00:03<00:00, 18.46it/s]

--- Eval epoch-0, step-463 ---
accuracy: 0.4766
f1_macro: 0.1628
f1_micro: 0.4766
loss: 1.4306
New best accuracy score (0.4766) at epoch-0, step-463




Epoch 1 / 3: 100%|██████████| 463/463 [02:12<00:00,  3.49it/s]

--- Train epoch-1, step-926 ---
loss: 0.8374



Evaluation: 100%|██████████| 67/67 [00:03<00:00, 18.75it/s]

--- Eval epoch-1, step-926 ---
accuracy: 0.7171
f1_macro: 0.6673
f1_micro: 0.7171
loss: 0.7513
New best accuracy score (0.7171) at epoch-1, step-926




Epoch 2 / 3: 100%|██████████| 463/463 [02:12<00:00,  3.49it/s]

--- Train epoch-2, step-1389 ---
loss: 0.7259



Evaluation: 100%|██████████| 67/67 [00:03<00:00, 18.65it/s]

--- Eval epoch-2, step-1389 ---
accuracy: 0.6802
f1_macro: 0.6627
f1_micro: 0.6802
loss: 0.8749





# 9. Evaluate on Validation/Test Splits
Compute accuracy and F1 scores on the held-out loaders to assess generalisation.

In [11]:
evaluation_results = {}
for split_name, loader in {"validation": val_loader, "test": test_loader}.items():
    if loader is None:
        continue
    metrics = trainer.evaluate(loader)
    evaluation_results[split_name] = metrics
    formatted = ", ".join(f"{k}={v:.4f}" for k, v in metrics.items())
    print(f"{split_name.title()} metrics: {formatted}")

Evaluation: 100%|██████████| 67/67 [00:03<00:00, 18.60it/s]



Validation metrics: accuracy=0.6802, f1_macro=0.6627, f1_micro=0.6802, loss=0.8749


Evaluation: 100%|██████████| 133/133 [00:08<00:00, 14.79it/s]

Test metrics: accuracy=0.6674, f1_macro=0.6568, f1_micro=0.6674, loss=0.9159





# 10. Inspect Sample Predictions
Run an inference pass and preview top predictions alongside their probabilities.

In [12]:
target_loader = val_loader if val_loader is not None else train_loader

y_true, y_prob, mean_loss = trainer.inference(target_loader)
top_indices = y_prob.argmax(axis=-1)
preview = []
for i, (true_idx, pred_idx) in enumerate(zip(y_true[:5], top_indices[:5])):
    prob = float(y_prob[i, pred_idx])
    preview.append({
        "true_index": int(true_idx),
        "true_label": IDX_TO_LABEL[int(true_idx)],
        "pred_index": int(pred_idx),
        "pred_label": IDX_TO_LABEL[int(pred_idx)],
        "pred_prob": prob,
    })

print(f"Mean loss: {mean_loss:.4f}")
for sample in preview:
    print(sample)

Evaluation: 100%|██████████| 67/67 [00:03<00:00, 18.60it/s]

Mean loss: 0.8749
{'true_index': 2, 'true_label': 'Normal', 'pred_index': 1, 'pred_label': 'Lung Opacity', 'pred_prob': 0.8432581424713135}
{'true_index': 0, 'true_label': 'COVID', 'pred_index': 1, 'pred_label': 'Lung Opacity', 'pred_prob': 0.7460941672325134}
{'true_index': 2, 'true_label': 'Normal', 'pred_index': 1, 'pred_label': 'Lung Opacity', 'pred_prob': 0.49037978053092957}
{'true_index': 0, 'true_label': 'COVID', 'pred_index': 1, 'pred_label': 'Lung Opacity', 'pred_prob': 0.6506214141845703}
{'true_index': 2, 'true_label': 'Normal', 'pred_index': 1, 'pred_label': 'Lung Opacity', 'pred_prob': 0.9451408386230469}



