# In [None]:
# demo_inference.ipynb

# ===========================================================
# 1. Imports & Setup
# ===========================================================

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Import your modules (adjust these imports if your paths are different)
from src.dataset import CustomImageDataset
from src.model_resnet50 import MultiscaleFusionClassifier  # <-- choose your model class

# -----------------------------------------------------------
# 2. Set Paths & Parameters
# -----------------------------------------------------------

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Using device:", device)

# Dataset root, trained checkpoint, and the test-split metadata / image folder.
DATA_DIR = "./data"
MODEL_CKPT = "./outputs/mhgf2/best_model_fold1.pth"  # point this at your own checkpoint/backbone
TEST_META_CSV = os.path.join(DATA_DIR, "padufes20/padufes20-test-metadata.csv")
TEST_IMG_DIR = os.path.join(DATA_DIR, "padufes20/padufes20-test-set")

# Kept small so the demo batch is quick to run and easy to plot.
BATCH_SIZE = 8

# Feature columns (copy from your code)
# Metadata columns selected from the test CSV. The order is significant:
# the metadata array is built by indexing the DataFrame with this list.
cols = [
    # Patient-level attributes
    "age",
    "smoke",
    "drink",
    "pesticide",
    "gender",
    "skin_cancer_history",
    "cancer_history",
    "has_piped_water",
    "has_sewage_system",
    # background_father_* columns
    "background_father_10",
    "background_father_12",
    "background_father_2",
    "background_father_4",
    "background_father_6",
    "background_father_7",
    "background_father_9",
    "background_father_Other",
    # background_mother_* columns
    "background_mother_0",
    "background_mother_10",
    "background_mother_2",
    "background_mother_3",
    "background_mother_4",
    "background_mother_7",
    "background_mother_8",
    "background_mother_Other",
    # region_* columns (string-sorted order: 0, 1, 10, 11, ...)
    "region_0",
    "region_1",
    "region_10",
    "region_11",
    "region_12",
    "region_13",
    "region_2",
    "region_3",
    "region_4",
    "region_5",
    "region_6",
    "region_7",
    "region_8",
    "region_9",
    # Lesion-level columns
    "itch_1.0",
    "grew_1.0",
    "hurt_1.0",
    "changed_1.0",
    "bleed_1.0",
    "elevation_1.0",
    "fitspatrick",  # spelling is kept as-is: it is used as a CSV column name
]

# -----------------------------------------------------------
# 3. Load Data (Test Set)
# -----------------------------------------------------------

# Read the test-split metadata table.
test_df = pd.read_csv(TEST_META_CSV)

# One image per row, stored as <TEST_IMG_DIR>/<img_id>.png.
test_img_paths = [
    os.path.join(TEST_IMG_DIR, f"{name}.png") for name in test_df["img_id"]
]

# Integer-encoded diagnosis labels and the selected metadata features,
# both as arrays aligned row-by-row with test_img_paths.
test_labels = test_df["diagnostic_encoded"].values
test_meta = test_df.loc[:, cols].values

# Define image transforms (same as validation)
# Deterministic (validation-style) preprocessing: fixed 224x224 resize,
# conversion to a tensor, then per-channel normalisation. The mean/std
# values are the ones commonly used with ImageNet-pretrained backbones.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_norm_mean, std=_norm_std),
    ]
)

# Create dataset and loader
demo_ds = CustomImageDataset(
    test_img_paths,
    test_meta,
    test_labels,
    transform=val_transform,
)
# Shuffling means each run of the demo draws a different random batch.
demo_ld = DataLoader(dataset=demo_ds, batch_size=BATCH_SIZE, shuffle=True)

# -----------------------------------------------------------
# 4. Load Model and Weights
# -----------------------------------------------------------

# Model params (update if needed to match your training)
# Constructor arguments for the classifier; update them if needed so they
# match the configuration the checkpoint was trained with.
# NOTE(review): num_classes is inferred from the labels present in the test
# split, so it only equals the training class count if every class occurs
# in the test set — confirm for your data.
num_classes = len(np.unique(test_labels))
meta_in = test_meta.shape[1]  # number of metadata feature columns
meta_out = 768
K_init = [8, 8, 8]
ref_delta = 4
ref_epochs = [50, 100, 150]

model = MultiscaleFusionClassifier(num_classes, meta_in, meta_out, K_init, ref_delta, ref_epochs)

# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
# The file is fed straight to load_state_dict(), i.e. it is expected to hold
# a plain state dict, which this restricted loader supports.
state_dict = torch.load(MODEL_CKPT, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # evaluation mode for inference

print("Model loaded from:", MODEL_CKPT)

# -----------------------------------------------------------
# 5. Inference on a Batch
# -----------------------------------------------------------

# Pull a single batch from the shuffled demo loader and run a forward pass
# with gradient tracking disabled.
with torch.no_grad():
    batch_imgs, batch_meta, batch_labels = next(iter(demo_ld))
    batch_imgs = batch_imgs.to(device)
    batch_meta = batch_meta.to(device)
    logits = model(batch_imgs, batch_meta)
    probs = logits.softmax(dim=1)  # per-sample class probabilities
    preds = probs.argmax(dim=1)    # most probable class index per sample

print("Ground truth labels:", batch_labels.tolist())
print("Predicted labels:   ", preds.cpu().tolist())

# -----------------------------------------------------------
# 6. Show Images with Predictions
# -----------------------------------------------------------

# Class index -> diagnosis abbreviation; update as needed to match your label encoding.
class_names = dict(enumerate(["BCC", "SCC", "ACK", "NEV", "MEL", "SEK"]))

def plot_images(images, labels, preds, probs, class_names, max_images=8,
                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show a row of images titled with ground truth, prediction and confidence.

    Args:
        images: Batch of normalised image tensors; each ``images[i]`` is
            permuted from channel-first to channel-last before display.
        labels: Tensor of ground-truth class indices, one per image.
        preds: Tensor of predicted class indices, one per image.
        probs: Per-image class probabilities, indexed as ``probs[i][class]``.
        class_names: Mapping from class index to display name; indices
            missing from the mapping are shown as their number.
        max_images: Upper bound on the number of images drawn.
        mean: Per-channel mean that was subtracted during normalisation.
        std: Per-channel standard deviation used during normalisation.

    The defaults for ``mean`` and ``std`` reproduce the values previously
    hard-coded here, so existing calls behave exactly as before.
    """
    n = min(len(images), max_images)

    # Loop-invariant: build the de-normalisation arrays once. They broadcast
    # over the trailing channel axis of the (H, W, C) image.
    mean_arr = np.asarray(mean)
    std_arr = np.asarray(std)

    plt.figure(figsize=(15, 5))
    for i in range(n):
        # detach() keeps .numpy() working even if the tensor tracks gradients.
        img = images[i].detach().permute(1, 2, 0).cpu().numpy()
        # Undo the normalisation and clamp to the range imshow accepts for floats.
        img = np.clip(img * std_arr + mean_arr, 0, 1)

        plt.subplot(1, n, i + 1)
        plt.imshow(img)
        plt.axis('off')

        gt_idx = labels[i].item()
        pr_idx = preds[i].item()
        gt = class_names.get(gt_idx, str(gt_idx))
        pr = class_names.get(pr_idx, str(pr_idx))
        # Confidence is the probability assigned to the predicted class.
        plt.title(f"GT:{gt}\nPR:{pr}\nConf:{probs[i][preds[i]].item():.2f}", fontsize=9)
    plt.tight_layout()
    plt.show()

# Visualise the demo batch together with the model's predictions.
plot_images(
    images=batch_imgs,
    labels=batch_labels,
    preds=preds,
    probs=probs,
    class_names=class_names,
)

# -----------------------------------------------------------
# 7. Confusion Matrix on the Whole Test Set (optional)
# -----------------------------------------------------------

# Run the model over the entire test set (a fresh, unshuffled loader with
# batches of 32) and collect ground-truth and predicted class indices.
all_gt, all_pr = [], []
with torch.no_grad():
    for imgs, meta, labels in DataLoader(demo_ds, batch_size=32):
        imgs, meta = imgs.to(device), meta.to(device)
        logits = model(imgs, meta)
        pred = torch.argmax(logits, dim=1).cpu().numpy()
        all_gt.extend(labels.numpy())
        all_pr.extend(pred)

# Pin the label set explicitly. Without `labels=`, confusion_matrix sizes the
# matrix from whichever classes happen to occur in y_true/y_pred, which can
# disagree with the num_classes-long list of display labels below.
class_ids = list(range(num_classes))
cm = confusion_matrix(all_gt, all_pr, labels=class_ids)
disp = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=[class_names.get(i, str(i)) for i in class_ids],
)
fig, ax = plt.subplots(figsize=(6, 6))
disp.plot(ax=ax, cmap="Blues", values_format="d")
# Title the matrix axes explicitly rather than whichever axes is "current".
ax.set_title("Confusion Matrix (Test Set)")
plt.show()