In [1]:
import torch
from torchvision import __version__ as tv_version

# Report the runtime stack so results can be tied to a specific torch/CUDA/cuDNN build.
cuda_ok = torch.cuda.is_available()
environment_report = [
    ("CUDA available:", cuda_ok),
    ("PyTorch version:", torch.__version__),
    ("TorchVision version:", tv_version),
    ("PyTorch CUDA build:", torch.version.cuda),
    ("cuDNN:", torch.backends.cudnn.version()),
    ("Device:", torch.cuda.get_device_name(0) if cuda_ok else "CPU"),
]
for caption, value in environment_report:
    print(caption, value)


CUDA available: True
PyTorch version: 2.9.0+cu128
TorchVision version: 0.24.0+cu128
PyTorch CUDA build: 12.8
cuDNN: 91002
Device: NVIDIA GeForce RTX 4070 Ti


In [2]:
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.models import densenet121, DenseNet121_Weights
from pathlib import Path
from PIL import Image, ImageOps
from IPython.display import display
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score


In [None]:
# CheXpert dataset configuration
data_root = Path("/data/CheXpert-v1.0-small")  # change to your CheXpert root
train_csv, valid_csv = data_root / "train.csv", data_root / "valid.csv"

IMG_SIZE = 320
# Single grayscale statistic replicated over the 3 channels fed to the backbone.
# NOTE(review): a pixel std of ~0.035 on [0, 1] images is unusually small and scales
# inputs by ~28x after Normalize; confirm how it was computed before reusing it.
# These values are also exported to metadata.json, so serving must stay in sync.
CHEXPERT_MEAN = [0.5330] * 3
CHEXPERT_STD = [0.0349] * 3
CLASS_NAMES = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Enlarged Cardiomediastinum",
    "Fracture",
    "Lung Lesion",
    "Lung Opacity",
    "No Finding",
    "Pleural Effusion",
    "Pleural Other",
    "Pneumonia",
    "Pneumothorax",
    "Support Devices",
]

for caption, value in (
    ("Train CSV", train_csv),
    ("Validation CSV", valid_csv),
    ("Number of labels", len(CLASS_NAMES)),
):
    print(f"{caption}: {value}")


In [None]:
class CheXpertDataset(Dataset):
    """Minimal CheXpert CSV loader that returns ``(image, multi-label target)`` pairs.

    CheXpert label cells hold ``1.0`` (positive), ``0.0`` (negative), ``-1.0``
    (uncertain) or are blank (the finding was not mentioned in the report).
    Blank cells are treated as negative under every policy; ``uncertain_policy``
    only decides what happens to ``-1.0``:

    * ``"zeros"``  -- uncertain -> 0 (U-Zeros).
    * ``"ones"``   -- uncertain -> 1 (U-Ones).
    * ``"ignore"`` -- rows containing any uncertain label are dropped, so every
      returned target is strictly 0/1 and valid for ``BCEWithLogitsLoss``.

    Args:
        csv_file: path or file-like object of a CheXpert ``train.csv`` / ``valid.csv``.
        root_dir: directory that image paths are resolved against. The official
            CSVs prefix each ``Path`` with the dataset folder name (e.g.
            ``CheXpert-v1.0-small/train/...``); when ``root_dir`` is that folder
            itself, the duplicated leading component is stripped.
        transform: optional callable applied to the RGB PIL image.
        uncertain_policy: ``"zeros"``, ``"ones"`` or ``"ignore"``.
        class_names: label columns to read; defaults to the notebook-level ``CLASS_NAMES``.

    Attributes:
        labels: ``float32`` array of shape ``(len(self), len(class_names))``.
        paths: image paths relative to ``root_dir`` as POSIX-style strings.

    Raises:
        ValueError: if ``uncertain_policy`` is not one of the values above.
    """

    def __init__(self, csv_file, root_dir, transform=None, uncertain_policy="zeros", class_names=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.class_names = list(CLASS_NAMES if class_names is None else class_names)
        self.df = pd.read_csv(csv_file)

        # Blank means "not mentioned", i.e. negative. Filling blanks with 1 (as a
        # naive U-Ones would) marks every unmentioned finding as positive.
        raw_labels = self.df[self.class_names].fillna(0.0)

        if uncertain_policy == "zeros":
            processed = raw_labels.replace(-1, 0.0)
        elif uncertain_policy == "ones":
            processed = raw_labels.replace(-1, 1.0)
        elif uncertain_policy == "ignore":
            # Keep only studies without any uncertain label; -1 is not a valid BCE target.
            certain_rows = ~(raw_labels == -1).any(axis=1)
            processed = raw_labels[certain_rows]
            self.df = self.df[certain_rows].reset_index(drop=True)
        else:
            raise ValueError(f"Unknown uncertain_policy: {uncertain_policy}")

        self.labels = processed.astype(np.float32).values
        self.paths = [self._relative_to_root(p) for p in self.df["Path"]]

    def _relative_to_root(self, rel_path):
        """Drop a leading path component equal to ``root_dir``'s own folder name."""
        rel = Path(rel_path)
        if len(rel.parts) > 1 and rel.parts[0] == self.root_dir.name:
            rel = Path(*rel.parts[1:])
        return rel.as_posix()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image_path = self.root_dir / self.paths[idx]
        # convert() returns an independent in-memory copy, so the file can be closed here.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        target = torch.from_numpy(self.labels[idx])
        if self.transform is not None:
            image = self.transform(image)
        return image, target

# Shared tail of both pipelines: PIL -> tensor in [0, 1] -> dataset normalisation.
to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=CHEXPERT_MEAN, std=CHEXPERT_STD),
]

# Training adds light geometric augmentation.
# NOTE(review): horizontal flips swap heart laterality on chest X-rays; confirm this is
# acceptable for labels such as Cardiomegaly.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop((IMG_SIZE, IMG_SIZE), scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
    ]
    + to_normalized_tensor
)

# Evaluation is deterministic: a plain resize to the model's input size.
valid_transform = transforms.Compose([transforms.Resize((IMG_SIZE, IMG_SIZE))] + to_normalized_tensor)

train_dataset = CheXpertDataset(train_csv, data_root, transform=train_transform)
valid_dataset = CheXpertDataset(valid_csv, data_root, transform=valid_transform)

for split_name, split_dataset in (("Training", train_dataset), ("Validation", valid_dataset)):
    print(f"{split_name} samples: {len(split_dataset)}")


In [None]:
# create data loaders
batch_size = 16
# NOTE(review): worker processes must be able to unpickle CheXpertDataset; on spawn-based
# platforms (Windows/macOS) classes defined inside a notebook may fail there -- use 0 if so.
num_workers = 4
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
# pin_memory=True merely emits a warning, so enable it only when CUDA is present.
pin_memory = torch.cuda.is_available()

loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, **loader_kwargs)

print(f"Train batches per epoch: {len(train_loader)}")
print(f"Validation batches: {len(valid_loader)}")


In [None]:
# DenseNet-121 backbone (ImageNet-pretrained) with a fresh multi-label head for CheXpert.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = DenseNet121_Weights.IMAGENET1K_V1
model = densenet121(weights=weights)

# Swap the 1000-way ImageNet classifier for one logit per CheXpert label;
# BCEWithLogitsLoss applies the sigmoid, so the head outputs raw logits.
in_features = model.classifier.in_features
dropout = nn.Dropout(p=0.2)
label_head = nn.Linear(in_features, len(CLASS_NAMES))
model.classifier = nn.Sequential(dropout, label_head)

model = model.to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

print(model.classifier)


In [None]:
# training loop
num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for inputs, targets in progress:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # criterion uses the default "mean" reduction, so weight each batch by its size
        # to get an exact per-sample mean over the epoch (the last batch may be short).
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1}: train BCE loss = {epoch_loss:.4f}")


In [None]:
# evaluation on the validation set: collect sigmoid scores, then per-label ROC AUC
model.eval()
true_batches = []
score_batches = []

with torch.no_grad():
    for inputs, targets in tqdm(valid_loader, desc="Validation"):
        probs = torch.sigmoid(model(inputs.to(device, non_blocking=True)))
        # targets are only needed on the host for sklearn, so they never go to the GPU
        true_batches.append(targets.numpy())
        score_batches.append(probs.cpu().numpy())

y_true = np.concatenate(true_batches, axis=0)
y_score = np.concatenate(score_batches, axis=0)


def _safe_auc(labels, scores):
    """ROC AUC for one label, or NaN when sklearn cannot compute it (e.g. a single class present)."""
    try:
        return roc_auc_score(labels, scores)
    except ValueError:
        return float("nan")


roc_aucs = {name: _safe_auc(y_true[:, col], y_score[:, col]) for col, name in enumerate(CLASS_NAMES)}

# Labels with undefined AUC are excluded from the macro average.
macro_auc = np.nanmean(list(roc_aucs.values()))
print(f"Macro-average ROC AUC: {macro_auc:.4f}")
roc_aucs


In [None]:
# Binarise scores at a fixed cut-off and report how often each label is predicted positive.
threshold = 0.5
y_pred = np.asarray(y_score >= threshold, dtype=int)
positive_rates = y_pred.mean(axis=0)

print("Positive prediction rate per label:")
for label_name, label_rate in zip(CLASS_NAMES, positive_rates):
    print(f"  {label_name:<26}: {label_rate:.3f}")


In [None]:
# run inference on a single study image and list its highest-scoring findings
example_idx = 0
example_rel_path = valid_dataset.paths[example_idx]
example_path = data_root / example_rel_path
with Image.open(example_path) as raw_example:
    example_image = raw_example.convert("RGB")
input_tensor = valid_transform(example_image).unsqueeze(0).to(device)

model.eval()
with torch.no_grad():
    logits = model(input_tensor)
    probs = torch.sigmoid(logits)[0].cpu().numpy()

top_k = 5
prob_table = sorted(zip(CLASS_NAMES, probs), key=lambda pair: pair[1], reverse=True)
print(f"Top findings for {example_rel_path}:")
for name, p in prob_table[:top_k]:
    print(f"  {name:<26}: {p:.3f}")

display(example_image.resize((IMG_SIZE, IMG_SIZE)))


In [None]:
# Grad-CAM utilities for DenseNet-121
def _register_backward_hook(module, hook):
    """Attach ``hook`` as a backward hook on ``module`` and return its removable handle.

    Prefers ``register_full_backward_hook``; falls back to the legacy
    ``register_backward_hook`` on PyTorch builds that lack it.
    """
    if hasattr(module, "register_full_backward_hook"):
        return module.register_full_backward_hook(hook)
    return module.register_backward_hook(hook)


def compute_grad_cam(model, input_tensor, target_layer, target_index=None):
    """Compute a Grad-CAM heatmap for one output logit of ``model``.

    Args:
        model: network returning logits of shape ``(batch, num_classes)``.
        input_tensor: ``(batch, C, H, W)`` input. It is copied internally and is
            never modified (its ``requires_grad`` flag is left untouched).
        target_layer: sub-module whose output activations and gradients are used;
            its output is assumed to be ``(batch, channels, h, w)``.
        target_index: logit to explain. Defaults to the highest-scoring class of
            the first sample (argmax over logits equals argmax over sigmoid).

    Returns:
        ``(logits, cam)``. ``logits`` is detached. ``cam`` is a NumPy array
        normalised per sample to ``[0, 1]`` and upsampled to the input's
        ``(H, W)``. Its shape is ``(H, W)`` for a single image and
        ``(batch, H, W)`` otherwise. ``cam`` is ``None`` if the hooks captured
        nothing (e.g. ``target_layer`` is not on the forward path).

    Side effects:
        Parameter gradients produced by the backward pass are cleared
        (``set_to_none``). Both hooks are removed even if the forward or
        backward pass raises.
    """
    activations = []
    gradients = []

    def forward_hook(_module, _inputs, output):
        activations.append(output.detach())

    def backward_hook(_module, _grad_input, grad_output):
        gradients.append(grad_output[0].detach())

    # Private copy: the caller's tensor must not be flipped to requires_grad.
    input_tensor = input_tensor.detach().clone().requires_grad_(True)

    handle_fwd = target_layer.register_forward_hook(forward_hook)
    handle_bwd = _register_backward_hook(target_layer, backward_hook)
    try:
        logits = model(input_tensor)
        if target_index is None:
            target_index = int(torch.argmax(logits[0]).item())
        model.zero_grad(set_to_none=True)
        # .sum() gives a scalar for any batch size; samples are independent in eval mode,
        # so each sample still receives the gradient of its own logit.
        logits[:, target_index].sum().backward()
    finally:
        # Always detach hooks: a leaked forward hook would keep storing activations
        # on every later forward pass of the model.
        handle_fwd.remove()
        handle_bwd.remove()
        model.zero_grad(set_to_none=True)

    if not activations or not gradients:
        return logits.detach(), None

    grad = gradients[0]
    act = activations[0]

    channel_weights = grad.mean(dim=(2, 3), keepdim=True)
    cam = torch.relu((channel_weights * act).sum(dim=1, keepdim=True))
    cam = torch.nn.functional.interpolate(
        cam, size=tuple(input_tensor.shape[-2:]), mode="bilinear", align_corners=False
    )
    cam = cam - cam.amin(dim=(2, 3), keepdim=True)
    cam = cam / (cam.amax(dim=(2, 3), keepdim=True) + 1e-8)
    cam = cam[:, 0]  # drop the singleton channel -> (batch, H, W)
    if cam.shape[0] == 1:
        cam = cam[0]
    return logits.detach(), cam.cpu().numpy()


def build_heatmap_overlay(base_image, heatmap, alpha=0.45):
    """Blend a colour-mapped heatmap onto ``base_image``.

    Args:
        base_image: PIL image; converted to RGB and resized to ``IMG_SIZE`` x ``IMG_SIZE``.
        heatmap: 2-D array of scores expected in ``[0, 1]`` (e.g. a single-image
            ``compute_grad_cam`` result). Values are clipped to that range and NaNs
            become 0 before the 8-bit conversion.
        alpha: weight of the heatmap layer in the blend (0 = base only, 1 = heatmap only).

    Returns:
        RGB PIL image with low scores tinted dark navy and high scores orange.
    """
    base = base_image.convert("RGB").resize((IMG_SIZE, IMG_SIZE))
    heat_u8 = (np.clip(np.nan_to_num(np.asarray(heatmap, dtype=np.float32)), 0.0, 1.0) * 255).astype(np.uint8)
    # A 2-D uint8 array already maps to mode "L"; the explicit mode= argument of
    # Image.fromarray is deprecated in recent Pillow releases.
    heat = Image.fromarray(heat_u8)
    heat = heat.resize(base.size, resample=Image.BILINEAR)
    colored = ImageOps.colorize(heat, black="#0b1f3a", white="#f97316")
    return Image.blend(base, colored, alpha)


In [None]:
# Grad-CAM for the example image, using the last dense block as the target layer.
model.eval()
target_layer = model.features.denseblock4

logits, heatmap = compute_grad_cam(model, input_tensor, target_layer)
if heatmap is None:
    print("Grad-CAM heatmap could not be computed.")
else:
    overlay = build_heatmap_overlay(example_image, heatmap)
    display(overlay)


In [None]:
# export TorchScript model, weights, and metadata compatible with the FastAPI app
export_dir = Path("models")
export_dir.mkdir(exist_ok=True)

# nn.Module.to() moves parameters in place, so model_cpu is the same object as model;
# it is moved back to `device` at the end of this cell.
model_cpu = model.to("cpu").eval()
state_path = export_dir / "chexpert_densenet121_state.pt"
torch.save(model_cpu.state_dict(), state_path)

example = torch.randn(1, 3, IMG_SIZE, IMG_SIZE)
traced = torch.jit.trace(model_cpu, example)
# jit.trace only warns on a divergent trace; fail hard before shipping it to the app.
with torch.no_grad():
    torch.testing.assert_close(traced(example), model_cpu(example))
traced.save(str(export_dir / "chexpert_densenet121.ts"))

# Everything the serving side needs to reproduce valid_transform and decode outputs.
metadata = {
    "dataset": "chexpert",
    "task": "multi-label-binary",
    "img_size": IMG_SIZE,
    "n_channels": 3,
    "mean": CHEXPERT_MEAN,
    "std": CHEXPERT_STD,
    "class_names": CLASS_NAMES,
}

with open(export_dir / "metadata.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f, ensure_ascii=False, indent=2)

print("✅ Saved TorchScript model and metadata in models/")

# Move the model back for further experiments. Assigning instead of leaving a bare
# expression keeps Jupyter from printing the full DenseNet-121 module repr.
model = model.to(device)


In [None]:
# dump a few CheXpert validation images into sample_uploads/ for the web UI
# NOTE(review): confirm the CheXpert data-use terms allow shipping these images with the app.
NUM_SAMPLE_UPLOADS = 5
sample_dir = Path("sample_uploads")
sample_dir.mkdir(exist_ok=True)

# transform=None -> the dataset yields raw RGB PIL images, ready to save as-is.
raw_valid = CheXpertDataset(valid_csv, data_root, transform=None)
# Clamp so a validation CSV with fewer rows than requested does not raise IndexError.
for idx in range(min(NUM_SAMPLE_UPLOADS, len(raw_valid))):
    image, target = raw_valid[idx]
    # Encode the positive labels in the filename so the UI user knows the ground truth.
    labels = [CLASS_NAMES[i].replace(" ", "_") for i, value in enumerate(target.numpy()) if value > 0.5]
    label_suffix = "_".join(labels) if labels else "No_Finding"
    filename = f"valid_{idx}_{label_suffix}.png"
    image.save(sample_dir / filename)

print("✅ Saved samples in sample_uploads/")
