# CheXpert inference with TorchXRayVision DenseNet-121

Αυτό το notebook ρυθμίζει τον φορτωτή του dataset, φορτώνει το προεκπαιδευμένο μοντέλο `torchxrayvision` και δίνει παραδείγματα αξιολόγησης, inference και Grad-CAM για Chest X-ray εικόνες.

In [None]:
import json
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image, ImageOps
from sklearn.metrics import roc_auc_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import torchxrayvision as xrv


In [None]:
# CheXpert dataset configuration.
# Point data_root at the directory the CheXpert archive was extracted to.
# NOTE(review): the official CSVs' "Path" column may already start with
# "CheXpert-v1.0-small/" -- confirm that data_root / Path resolves to real files.
data_root = Path("/data/CheXpert-v1.0-small")
train_csv = data_root.joinpath("train.csv")
valid_csv = data_root.joinpath("valid.csv")

# Model settings: input resolution and the torchxrayvision weights identifier
# (the "res224" in the weights name matches IMG_SIZE).
IMG_SIZE = 224
WEIGHTS_NAME = "densenet121-res224-chex"

# The 14 CheXpert label columns selected from the CSVs, listed alphabetically.
CLASS_NAMES = [
    "Atelectasis", "Cardiomegaly", "Consolidation", "Edema",
    "Enlarged Cardiomediastinum", "Fracture", "Lung Lesion", "Lung Opacity",
    "No Finding", "Pleural Effusion", "Pleural Other", "Pneumonia",
    "Pneumothorax", "Support Devices",
]


In [None]:
def build_xrv_transform(img_size: int = IMG_SIZE):
    """Return a callable that preprocesses a PIL image for a torchxrayvision model.

    The returned function converts the image to grayscale, rescales 8-bit pixel
    values to torchxrayvision's expected range of roughly [-1024, 1024],
    centre-crops it to a square and resizes it to ``img_size`` x ``img_size``.

    Args:
        img_size: Output side length in pixels.

    Returns:
        A function ``PIL.Image -> torch.Tensor`` producing a float32 tensor of
        shape ``(1, img_size, img_size)``.
    """
    center_crop = xrv.datasets.XRayCenterCrop()
    resizer = xrv.datasets.XRayResizer(img_size)

    def _transform(image: Image.Image) -> torch.Tensor:
        arr = np.asarray(image.convert("L"), dtype=np.float32)
        # torchxrayvision models are trained on inputs scaled to about
        # [-1024, 1024]; a plain /255 (to [0, 1]) is far outside that range.
        arr = xrv.datasets.normalize(arr, 255)
        # XRayCenterCrop / XRayResizer operate on channel-first (1, H, W) arrays,
        # so add the channel axis before cropping rather than after.
        arr = arr[None, ...]
        arr = center_crop(arr)
        arr = resizer(arr)
        return torch.from_numpy(np.ascontiguousarray(arr, dtype=np.float32))

    return _transform

class CheXpertDataset(Dataset):
    """CheXpert images paired with a float32 multi-hot vector over CLASS_NAMES.

    CheXpert label cells are 1 (positive), 0 (negative), -1 (uncertain) or
    blank/NaN (not mentioned). Blank cells are always treated as negative;
    ``uncertain_policy`` decides what happens to -1:

    * ``"zeros"``  -- map uncertain labels to 0.
    * ``"ones"``   -- map uncertain labels to 1.
    * ``"ignore"`` -- drop every row that has at least one uncertain label.

    Args:
        csv_file: Path to a CheXpert CSV with a ``Path`` column and the
            CLASS_NAMES label columns.
        root_dir: Directory that image paths are resolved against.
        transform: Callable applied to each RGB PIL image; defaults to
            ``build_xrv_transform()``.
        uncertain_policy: One of ``"zeros"``, ``"ones"``, ``"ignore"``.

    Raises:
        ValueError: If ``uncertain_policy`` is not a known policy.
    """

    def __init__(self, csv_file, root_dir, transform=None, uncertain_policy="zeros"):
        # Validate before paying for a (potentially large) CSV read.
        if uncertain_policy not in ("zeros", "ones", "ignore"):
            raise ValueError(f"Unknown uncertain_policy: {uncertain_policy}")

        self.root_dir = Path(root_dir)
        self.df = pd.read_csv(csv_file)
        self.transform = transform or build_xrv_transform()

        # Blank = not mentioned -> negative under every policy. Filling blanks
        # with the uncertain value would, under "ones", mark every unmentioned
        # finding as positive.
        labels = self.df[CLASS_NAMES].fillna(0.0)
        if uncertain_policy == "ignore":
            # Drop rows containing an uncertain label so no -1 reaches the targets.
            keep = ~(labels == -1).any(axis=1)
            self.df = self.df.loc[keep].reset_index(drop=True)
            labels = labels.loc[keep]
        else:
            labels = labels.replace(-1, 0.0 if uncertain_policy == "zeros" else 1.0)

        self.labels = labels.astype(np.float32).values
        self.paths = self.df["Path"].map(self._relative_to_root).to_numpy()

    def _relative_to_root(self, rel_path):
        """Strip a leading ``<root_dir.name>/`` component from a CSV path, if present.

        CheXpert CSVs typically store paths such as
        ``CheXpert-v1.0-small/valid/patient64541/study1/view1_frontal.jpg``;
        when ``root_dir`` already points at that folder, joining the raw path
        would duplicate the folder name. Paths without the prefix are returned
        unchanged.
        """
        parts = Path(rel_path).parts
        if len(parts) > 1 and parts[0] == self.root_dir.name:
            return Path(*parts[1:]).as_posix()
        return rel_path

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(transform(image), label_vector)`` for row ``idx``."""
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(self.root_dir / self.paths[idx]) as img:
            image = img.convert("RGB")
        tensor = self.transform(image)
        target = torch.from_numpy(self.labels[idx])
        return tensor, target


In [None]:
# Build the train/validation datasets and their DataLoaders.
train_dataset = CheXpertDataset(train_csv, data_root)
valid_dataset = CheXpertDataset(valid_csv, data_root)

batch_size = 16
num_workers = 4
# Pinned host memory only speeds up host-to-GPU copies, so enable it only
# when CUDA is available.
pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)


In [None]:
# Load the pretrained torchxrayvision DenseNet-121 onto the best available
# device and switch it to inference mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = xrv.models.DenseNet(weights=WEIGHTS_NAME).to(device)
model.eval()
print(f'Loaded {WEIGHTS_NAME} on {device}.')


In [None]:
# Evaluate per-class ROC-AUC on the validation split.
# NOTE(review): torchxrayvision models emit one column per entry of
# `model.pathologies` (their own label list, e.g. "Effusion" rather than
# "Pleural Effusion"), not in CLASS_NAMES order -- so scores are matched to the
# CheXpert label columns by name instead of by position.
XRV_NAME_ALIASES = {"Pleural Effusion": "Effusion"}  # CheXpert name -> xrv name

model.eval()
y_true, y_score = [], []

with torch.no_grad():
    for inputs, targets in tqdm(valid_loader, desc="Validation"):
        inputs = inputs.to(device, non_blocking=True)
        outputs = model(inputs)
        # NOTE(review): pretrained xrv weights with operating-point thresholds
        # appear to already return probabilities; squash only raw logits.
        # ROC-AUC is rank-based, so this does not change the AUCs below.
        if getattr(model, "op_threshs", None) is None and not getattr(model, "apply_sigmoid", False):
            outputs = torch.sigmoid(outputs)
        # Targets stay on the CPU; they are only needed for sklearn.
        y_true.append(targets.numpy())
        y_score.append(outputs.cpu().numpy())

y_true = np.concatenate(y_true, axis=0)
y_score = np.concatenate(y_score, axis=0)

model_labels = list(model.pathologies)
roc_aucs = {}
for label_idx, name in enumerate(CLASS_NAMES):
    xrv_name = XRV_NAME_ALIASES.get(name, name)
    if xrv_name not in model_labels:
        # The pretrained model has no output for this finding.
        roc_aucs[name] = float('nan')
        continue
    scores = y_score[:, model_labels.index(xrv_name)]
    try:
        roc_aucs[name] = roc_auc_score(y_true[:, label_idx], scores)
    except ValueError:
        # e.g. only one class present in y_true, or non-finite scores.
        roc_aucs[name] = float('nan')

roc_aucs


In [None]:
# Single-image inference demo: print the top-5 findings for one validation image.
example_idx = 0
example_rel_path = valid_dataset.paths[example_idx]
# convert() loads the pixels, so the file handle can be closed immediately.
with Image.open(data_root / example_rel_path) as img:
    example_image = img.convert('RGB')
input_tensor = valid_dataset.transform(example_image).unsqueeze(0).to(device)

model.eval()
with torch.no_grad():
    outputs = model(input_tensor)
    # NOTE(review): pretrained xrv weights with operating-point thresholds
    # appear to already return probabilities; apply a sigmoid only to raw logits.
    if getattr(model, 'op_threshs', None) is None and not getattr(model, 'apply_sigmoid', False):
        outputs = torch.sigmoid(outputs)
    probs = outputs[0].cpu().numpy()

# Output columns follow model.pathologies (not CLASS_NAMES); skip unnamed
# heads and non-finite scores.
findings = [
    (name, float(p))
    for name, p in zip(model.pathologies, probs)
    if name and np.isfinite(p)
]
top = sorted(findings, key=lambda item: item[1], reverse=True)
print(f'Top findings for {example_rel_path}:')
for name, p in top[:5]:
    print(f'  {name:<26}: {p:.3f}')

example_image.resize((IMG_SIZE, IMG_SIZE))


In [None]:
# Grad-CAM utilities (intended to match the backend implementation -- keep the two in sync).
def _register_backward_hook(module, hook):
    """Attach a gradient hook to *module* and return its removable handle.

    Uses ``register_full_backward_hook`` when the installed PyTorch provides it,
    otherwise falls back to the older ``register_backward_hook``.
    """
    register_full = getattr(module, 'register_full_backward_hook', None)
    if register_full is not None:
        return register_full(hook)
    return module.register_backward_hook(hook)


def compute_grad_cam(model, input_tensor, target_layer, target_index=None):
    """Compute a Grad-CAM heatmap for one output of *model*.

    Args:
        model: Module mapping ``input_tensor`` to a ``(batch, n_outputs)`` tensor.
        input_tensor: 4-D ``(batch, channels, H, W)`` input. It is not modified;
            a detached copy is used for the gradient computation.
        target_layer: Sub-module of *model* whose activations and output
            gradients are combined; its output is expected to be
            ``(batch, C, h, w)``.
        target_index: Output column to explain. Defaults to the highest
            non-NaN output of the first sample.

    Returns:
        ``(logits, cam)``: the detached model output and a float numpy heatmap
        scaled to [0, 1] with the input's spatial size (``(H, W)`` for a
        single-sample batch), or ``None`` if the hooks captured nothing.
    """
    activations, gradients = [], []

    def forward_hook(_module, _inp, output):
        activations.append(output.detach())

    def backward_hook(_module, _grad_input, grad_output):
        gradients.append(grad_output[0].detach())

    # Work on a fresh leaf so the caller's tensor is not flagged requires_grad.
    input_tensor = input_tensor.detach().clone().requires_grad_(True)

    handle_fwd = target_layer.register_forward_hook(forward_hook)
    handle_bwd = _register_backward_hook(target_layer, backward_hook)
    try:
        logits = model(input_tensor)

        if target_index is None:
            scores = logits[0].detach()
            # Exclude NaN outputs so argmax cannot select them.
            scores = scores.masked_fill(torch.isnan(scores), float('-inf'))
            target_index = int(torch.argmax(scores).item())

        model.zero_grad(set_to_none=True)
        # sum() makes the target a scalar, so batches larger than one also work.
        logits[:, target_index].sum().backward()
    finally:
        # Always detach the hooks, even if forward/backward raised.
        handle_fwd.remove()
        handle_bwd.remove()

    if not activations or not gradients:
        return logits.detach(), None

    grad = gradients[0]
    act = activations[0]

    # Channel weights are the spatially averaged gradients (Grad-CAM).
    weights = grad.mean(dim=(2, 3), keepdim=True)
    cam = torch.relu((weights * act).sum(dim=1, keepdim=True))
    # Upsample to the input's spatial size (IMG_SIZE for this notebook's inputs).
    cam = torch.nn.functional.interpolate(
        cam, size=tuple(input_tensor.shape[-2:]), mode='bilinear', align_corners=False
    )
    cam = cam.squeeze().cpu()
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)
    return logits.detach(), cam.numpy()

def build_heatmap_overlay(base_image: Image.Image, heatmap, alpha: float = 0.45):
    """Blend a colourised Grad-CAM heatmap over the image it was computed from.

    Args:
        base_image: Original PIL image (any mode).
        heatmap: 2-D array of values intended to lie in [0, 1], e.g. the output
            of ``compute_grad_cam``. Values outside [0, 1] are clipped and NaNs
            become 0.
        alpha: Weight of the heatmap in the blend (0 = image only, 1 = heatmap only).

    Returns:
        An RGB PIL image of size ``IMG_SIZE`` x ``IMG_SIZE``.
    """
    # The model input is centre-cropped to a square (XRayCenterCrop in
    # build_xrv_transform) before resizing; crop the display image the same way
    # so the heatmap lines up instead of being stretched over a non-square X-ray.
    resized = ImageOps.fit(base_image.convert('RGB'), (IMG_SIZE, IMG_SIZE), centering=(0.5, 0.5))
    # Clip and zero NaNs so the uint8 cast cannot wrap around.
    heat_arr = np.nan_to_num(np.clip(np.asarray(heatmap, dtype=np.float32), 0.0, 1.0))
    # A 2-D uint8 array is converted to an 'L' (grayscale) image.
    heat = Image.fromarray((heat_arr * 255).astype(np.uint8))
    heat = heat.resize(resized.size, resample=Image.BILINEAR)
    colorized = ImageOps.colorize(heat, black='#0b1f3a', white='#f97316')
    return Image.blend(resized, colorized, alpha)


In [None]:
# Grad-CAM visualisation for the example image prepared in the inference demo.
# The CAM is taken from the model's `denseblock4` feature block.
target_layer = model.features.denseblock4
model.eval()
logits, heatmap = compute_grad_cam(model, input_tensor, target_layer)

if heatmap is None:
    print('Grad-CAM heatmap could not be computed.')
else:
    overlay = build_heatmap_overlay(example_image, heatmap)
    display(overlay)  # Jupyter/IPython display helper


In [None]:
# Update the metadata JSON consumed by the FastAPI app.
# NOTE(review): 'mean'/'std' = 0.5 describe a [0, 1] -> [-1, 1] normalisation,
# whereas torchxrayvision models expect inputs scaled to roughly [-1024, 1024];
# confirm which convention the backend actually applies.
metadata = dict(
    dataset='chexpert',
    task='multi-label-binary',
    img_size=IMG_SIZE,
    n_channels=1,
    mean=[0.5],
    std=[0.5],
    pretrained_source='torchxrayvision',
    weights=WEIGHTS_NAME,
    class_names=CLASS_NAMES,
)

metadata_path = Path('medmnist_web/models/metadata.json')
metadata_path.parent.mkdir(parents=True, exist_ok=True)
metadata_path.write_text(json.dumps(metadata, indent=2), encoding='utf-8')

metadata


In [None]:
# Export a few validation images for the web UI, named after their positive labels.
sample_dir = Path('sample_uploads')
sample_dir.mkdir(parents=True, exist_ok=True)

n_samples = 5
# Identity transform: items come back as RGB PIL images instead of tensors.
raw_valid = CheXpertDataset(valid_csv, data_root, transform=lambda x: x)
# Bound by the dataset size so a short CSV does not raise IndexError.
for idx in range(min(n_samples, len(raw_valid))):
    image, target = raw_valid[idx]
    labels = [
        name.replace(' ', '_')
        for name, value in zip(CLASS_NAMES, target.numpy())
        if value > 0.5
    ]
    suffix = '_'.join(labels) if labels else 'No_Finding'
    filename = f'valid_{idx}_{suffix}.png'
    image.save(sample_dir / filename)

print('✅ Saved samples in sample_uploads/')
