# 1. Environment Setup
Install the extra dependency, import the core libraries, seed the random number generators, and detect the training device.

In [None]:
!pip install openpyxl

In [None]:
import os
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from IPython.display import display

from pyhealth.datasets import COVID19CXRDataset
from pyhealth.datasets.splitter import split_by_sample
from pyhealth.datasets.utils import get_dataloader
from pyhealth.tasks.covid19_cxr_classification import COVID19CXRClassification

SEED = 42

# Seed the Python, NumPy and PyTorch generators with one shared value so the
# data split, loader shuffling and weight initialisation repeat across runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one; the CUDA generators are seeded only then.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.cuda.manual_seed_all(SEED)

print(f"Running on device: {device}")

# 2. Load COVID-19 CXR Metadata
Point to the processed COVID-19 Radiography dataset root and trigger metadata preparation if necessary.

In [None]:
# The default root is a machine-specific absolute path; set the COVID19_CXR_ROOT
# environment variable to point the notebook at a local copy of the dataset instead.
CXR_ROOT = os.environ.get("COVID19_CXR_ROOT", "/home/logic/Github/cxr")

dataset = COVID19CXRDataset(
    root=CXR_ROOT,
)
dataset.stats()

# 3. Prepare PyHealth Dataset
Instantiate the COVID-19 classification task, convert raw samples into PyHealth format, and confirm schema details.

In [None]:
task = COVID19CXRClassification()
sample_dataset = dataset.set_task(task)
num_task_samples = len(sample_dataset)

print(f"Total task samples: {num_task_samples}")
print(f"Input schema: {sample_dataset.input_schema}")
print(f"Output schema: {sample_dataset.output_schema}")

if not num_task_samples:
    raise RuntimeError("The task did not produce any samples. Verify the dataset root or disable dev mode.")

# The output processor's vocabulary maps class name -> integer index; invert it
# so encoded targets can be reported by name for the rest of the notebook.
label_processor = sample_dataset.output_processors["disease"]
IDX_TO_LABEL = dict(
    (encoded, name) for name, encoded in label_processor.label_vocab.items()
)
print(f"Label mapping (index -> name): {IDX_TO_LABEL}")

# Decode every sample's target to check class balance.
# NOTE(review): this indexes the whole dataset, so each full sample (presumably
# including its image tensor) is materialised just to read the label.
label_indices = [
    sample_dataset[position]["disease"].item()
    for position in range(num_task_samples)
]
label_counts = pd.Series(label_indices).map(IDX_TO_LABEL).value_counts().sort_index()
label_distribution = label_counts.to_frame(name="count")
label_distribution["proportion"] = (
    label_distribution["count"] / label_distribution["count"].sum()
)
display(label_distribution)

# 4. Split Dataset
Divide the processed samples into training, validation, and test subsets before building dataloaders.

In [None]:
BATCH_SIZE = 32

# 70/10/20 train/validation/test split, seeded so it is reproducible.
train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.7, 0.1, 0.2], seed=SEED)
print(f"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}")

# Validate the split itself before building any loader. get_dataloader presumably
# wraps torch's DataLoader, and with shuffle=True torch's RandomSampler rejects an
# empty dataset at construction with an unrelated ValueError. A check on
# len(train_loader) after construction would therefore never be reached.
if len(train_ds) == 0:
    raise RuntimeError("The training loader is empty. Increase the dataset size or adjust the split ratios.")

train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
# Held-out loaders are optional: an empty split yields None, which later cells skip.
val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE) if len(val_ds) else None
test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE) if len(test_ds) else None

# 5. Inspect Batch Structure
Peek at the first training batch to understand feature shapes and label encodings.

In [None]:
# Draw a single batch from the training loader to inspect its keys, shapes and labels.
first_batch = next(iter(train_loader))

def describe(value):
    """Summarise one batch entry as a short, human-readable string.

    Objects exposing a ``shape`` attribute (tensors, arrays) report their
    shape. Lists and tuples report their length. Anything else is reduced
    to its bare type name.
    """
    type_name = type(value).__name__
    if hasattr(value, "shape"):
        detail = f"shape={tuple(value.shape)}"
    elif isinstance(value, (list, tuple)):
        detail = f"len={len(value)}"
    else:
        return type_name
    return f"{type_name}({detail})"

# One short description per batch key (tensor shapes, list lengths, or type names).
batch_summary = dict(
    (field_name, describe(field_value)) for field_name, field_value in first_batch.items()
)
print(batch_summary)

# Show the first few encoded disease targets next to their decoded class names.
disease_targets = first_batch["disease"]
preview_indices = disease_targets[:5].cpu().tolist()
preview_labels = [IDX_TO_LABEL[code] for code in preview_indices]
print(f"Sample disease labels: {list(zip(preview_indices, preview_labels))}")

# 6. Instantiate CNN Model
Create the PyHealth CNN with image embeddings and review its parameter footprint.

In [None]:
from pyhealth.models import CNN

# PyHealth CNN built from the sample dataset's schema, then moved to the chosen device.
model = CNN(
    dataset=sample_dataset,
    embedding_dim=64,
    hidden_dim=64,
    num_layers=2,
)
model = model.to(device)

# Count every parameter element, trainable or not.
total_params = 0
for weight in model.parameters():
    total_params += weight.numel()

print(f"Feature keys: {model.feature_keys}")
print(f"Label key: {model.label_key}")
print(f"Model mode: {model.mode}")
print(f"Total parameters: {total_params:,}")

# 7. Configure Trainer
Wrap the model with the PyHealth Trainer and define optimisation hyperparameters and metrics.

In [None]:
from pyhealth.trainer import Trainer

# Trainer reports accuracy plus macro/micro F1; logging to disk is disabled.
trainer = Trainer(
    model=model,
    metrics=["accuracy", "f1_macro", "f1_micro"],
    device=str(device),
    enable_logging=False,
)

# Keyword arguments forwarded to trainer.train(). The monitor settings select on
# validation accuracy, higher being better.
training_config = dict(
    epochs=3,
    optimizer_params={"lr": 1e-3},
    max_grad_norm=5.0,
    monitor="accuracy",
    monitor_criterion="max",
)

# 8. Train the Model
Launch the training loop. Validation accuracy is monitored only when a validation split is available; otherwise the monitoring options are dropped.

In [None]:
# Monitoring needs a validation loader; without one, drop the monitor options
# and pass every other training setting through unchanged.
has_validation = val_loader is not None
train_kwargs = {
    option: setting
    for option, setting in training_config.items()
    if has_validation or option not in ("monitor", "monitor_criterion")
}

trainer.train(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    **train_kwargs,
)

# 9. Evaluate on Validation/Test Splits
Compute accuracy and F1 scores on the held-out loaders to assess generalisation.

In [None]:
# Evaluate every held-out split that exists, keeping the raw metric dicts by split name.
evaluation_results = {}
held_out_loaders = (("validation", val_loader), ("test", test_loader))
for split_name, loader in held_out_loaders:
    if loader is not None:
        metrics = trainer.evaluate(loader)
        evaluation_results[split_name] = metrics
        rendered = ", ".join(f"{name}={score:.4f}" for name, score in metrics.items())
        print(f"{split_name.title()} metrics: {rendered}")

# 10. Inspect Sample Predictions
Run an inference pass and preview the top-1 prediction and its probability for the first five samples.

In [None]:
# Preview on held-out data whenever possible: validation first, then test.
# Fall back to the training split only when neither held-out loader exists,
# since predictions on training data look optimistic.
target_loader = next(
    loader for loader in (val_loader, test_loader, train_loader) if loader is not None
)

y_true, y_prob, mean_loss = trainer.inference(target_loader)
# Top-1 class per sample. y_prob is indexed as [row, class] below, so it is 2-D;
# presumably (num_samples, num_classes) — confirm.
top_indices = y_prob.argmax(axis=-1)
preview = [
    {
        "true_index": int(true_idx),
        "true_label": IDX_TO_LABEL[int(true_idx)],
        "pred_index": int(pred_idx),
        "pred_label": IDX_TO_LABEL[int(pred_idx)],
        "pred_prob": float(y_prob[row, pred_idx]),
    }
    for row, (true_idx, pred_idx) in enumerate(zip(y_true[:5], top_indices[:5]))
]

print(f"Mean loss: {mean_loss:.4f}")
for sample in preview:
    print(sample)