# 03 — Evaluation

Full evaluation of the trained model on the test set:
- Classification report (precision, recall, F1)
- Confusion matrix
- ROC curves and AUC
- Per-class sensitivity & specificity
- Grad-CAM visual explanations

In [None]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import torch
import numpy as np
import matplotlib.pyplot as plt

from src.config import CONFIG, CLASS_NAMES, MODELS_DIR, IDX_TO_LABEL, get_device
from src.utils.seed import set_seed
from src.data.dataset import get_dataloaders
from src.models.model import build_model, get_grad_cam
from src.models.loss import get_criterion
from src.train.evaluate import evaluate
from src.utils.metrics import compute_metrics, compute_binary_auc, sensitivity_specificity, get_classification_report
from src.utils.visualization import plot_confusion_matrix, plot_roc_curve_binary, plot_grad_cam
from src.data.transforms import denormalize

%matplotlib inline

# Choose the compute device once; the model, the evaluation call and the
# Grad-CAM inputs in later cells are all moved to / run on it.
DEVICE = get_device()

# Seed via the project helper using the shared config value so runs repeat.
set_seed(CONFIG['seed'])

print('Device:', DEVICE)

## 1. Load best model

In [None]:
# Rebuild the architecture without requesting pretrained weights (the
# checkpoint below supplies the trained ones), then restore the best weights.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints; if it holds only tensors/primitives, consider weights_only=True.
model = build_model(pretrained=False)
ckpt = torch.load(MODELS_DIR / 'best_model.pth', map_location=DEVICE)
model.load_state_dict(ckpt['model_state_dict'])

# Module.to/.eval both return the module itself, so this is a single chain:
# place it on DEVICE and switch layers such as dropout/batch-norm to
# inference behaviour.
model = model.to(DEVICE).eval()

print("Loaded checkpoint from epoch {} with val_acc={}".format(
    ckpt.get('epoch', '?'), ckpt.get('val_acc', '?')))

## 2. Run evaluation on test set

In [None]:
# Build all loaders but evaluate only on the held-out test split.
loaders = get_dataloaders()
test_loader = loaders['test']

# Loss from the project's criterion factory; evaluate() returns the mean loss,
# accuracy, and the per-sample predictions/labels/probabilities that the
# following cells reuse.
criterion = get_criterion(device=DEVICE)
loss, acc, preds, labels, probs = evaluate(model, test_loader, criterion, DEVICE)

print(f'Test Loss:     {loss:.4f}\n'
      f'Test Accuracy: {acc:.4f}')

## 3. Classification report

In [None]:
# Aggregate scores from the project's metrics helper, followed by a blank
# line and the per-class precision/recall/F1 report.
metrics = compute_metrics(labels, preds)
print('Aggregate metrics:', metrics, end='\n\n')
print(get_classification_report(labels, preds, class_names=CLASS_NAMES))

## 4. Confusion matrix

In [None]:
# Same matrix twice: first with the helper's default settings, then with
# normalize=True. The bare print() separates the two outputs.
cm_kwargs = {'class_names': CLASS_NAMES}
plot_confusion_matrix(labels, preds, **cm_kwargs)
print()
plot_confusion_matrix(labels, preds, **cm_kwargs, normalize=True)

## 5. ROC curves & AUC

In [None]:
# Threshold-free view of the binary classifier: the AUC value, then its curve.
auc_score = compute_binary_auc(labels, probs)
print('Binary AUC: {:.4f}'.format(auc_score))

plot_roc_curve_binary(labels, probs)

## 6. Per-class sensitivity & specificity

In [None]:
import pandas as pd

# Per-class sensitivity and specificity as returned by the project helper;
# ss['sensitivity'] / ss['specificity'] are iterated once per class, aligned
# with CLASS_NAMES.
ss = sensitivity_specificity(labels, preds)

# Keep the values numeric (float) rather than pre-formatted strings so the
# table can still be sorted, compared against thresholds, or exported.
# float() also unwraps numpy/tensor scalars into plain Python floats.
ss_df = pd.DataFrame({
    'class': CLASS_NAMES,
    'sensitivity': [float(v) for v in ss['sensitivity']],
    'specificity': [float(v) for v in ss['specificity']],
})

# Round only for display; ss_df itself keeps full precision.
ss_df.round(4)

## 7. Grad-CAM visualisation

In [None]:
# How many images from the first test batch to explain, and the probability
# cut-off that maps the sigmoid output to class index 1.
N_GRADCAM_SAMPLES = 5
DECISION_THRESHOLD = 0.5

cam = get_grad_cam(model)

# Grab one batch from the test loader and explain its first few images.
test_iter = iter(test_loader)
images, true_labels = next(test_iter)

for img, true_lbl in zip(images[:N_GRADCAM_SAMPLES], true_labels[:N_GRADCAM_SAMPLES]):
    # Re-add the batch dimension the model expects for a single image.
    img_tensor = img.unsqueeze(0).to(DEVICE)

    # Called outside torch.no_grad() on purpose: Grad-CAM presumably
    # backpropagates through the model to build the heatmap.
    heatmap = cam(img_tensor)
    orig_np = denormalize(img_tensor)

    with torch.no_grad():
        # .item() only works on a one-element tensor, so this assumes a
        # single-logit binary head (sigmoid -> probability of class 1).
        logit = model(img_tensor).squeeze()
        prob_mal = torch.sigmoid(logit).item()
    pred_idx = int(prob_mal >= DECISION_THRESHOLD)

    # Labels may be stored as floats (e.g. for a BCE-style loss); cast so the
    # lookup uses an integer index whether IDX_TO_LABEL is a dict or a list.
    true_name = IDX_TO_LABEL[int(true_lbl.item())]
    pred_name = IDX_TO_LABEL[pred_idx]

    # Report the probability of the class that was actually predicted.
    confidence = prob_mal if pred_idx == 1 else 1.0 - prob_mal
    plot_grad_cam(
        orig_np, heatmap,
        predicted_label=pred_name,
        true_label=true_name,
        confidence=confidence,
    )