In [57]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm
import wandb
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import confusion_matrix
import seaborn as sns
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchmetrics.classification import MulticlassCalibrationError

In [None]:
# Central run configuration. Everything a reader might want to tune lives here;
# the whole dict is also forwarded to W&B as the run config.
CONFIG = {
    # Checkpoint to evaluate. The default is machine-specific, so it can be
    # overridden without editing the notebook by setting CIFAR_MODEL_PATH.
    "model_path": os.environ.get(
        "CIFAR_MODEL_PATH",
        "C:/cifar-week3/artifacts/day3_resnet18-full-unfreeze.pth",
    ),
    "data_root": "./data",
    "artifacts_dir": "artifacts",
    "batch_size": 64,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "num_classes": 10,
    "num_misclassified": 20,   # how many high-confidence errors to keep
    "num_gradcam": 20,         # how many Grad-CAM overlays to render
    "noise_std": 0.1,          # std of the Gaussian noise used in the robustness check
    "ece_bins": 15,            # number of confidence bins for ECE / reliability diagram
    "seed": 42,                # RNG seed so stochastic cells are repeatable
}

# CIFAR-10 class names, indexed by integer label.
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Per-channel statistics used by the Normalize transform (and to undo it).
NORMALIZE_MEAN = [0.4914, 0.4822, 0.4465]
NORMALIZE_STD = [0.2470, 0.2430, 0.2610]

# Seed torch's global RNG early so random probes/noise are reproducible on re-run.
torch.manual_seed(CONFIG["seed"])

print("Imports complete")
print(f"Using device: {CONFIG['device']}")

In [None]:
# Authenticate, then open the W&B run that every later cell logs into.
wandb.login()

# Descriptive fields first, followed by every entry of CONFIG.
run_config = {
    "model": "ResNet18 (fine-tuned)",
    "task": "Diagnostics + Attribution + Calibration",
}
run_config.update(CONFIG)

wandb.init(
    config=run_config,
    group="Day4_Diagnostics_GradCAM_Calibration",
    name="day4-diagnostics-gradcam-calibration",
    project="cifar10-week3",
)
print("W&B initialized")

In [None]:
# Build the network and restore the fine-tuned weights for evaluation.
device = torch.device(CONFIG["device"])

# Fail early with a clear message instead of a cryptic torch.load error.
if not os.path.exists(CONFIG["model_path"]):
    raise FileNotFoundError(f"Model not found at {CONFIG['model_path']}")

model = timm.create_model("resnet18", pretrained=False, num_classes=CONFIG["num_classes"])
# The checkpoint is consumed as a plain state_dict, so restrict unpickling to
# tensors/containers (weights_only=True) rather than executing arbitrary pickles.
model.load_state_dict(
    torch.load(CONFIG["model_path"], map_location=device, weights_only=True)
)
model.eval().to(device)
print("Model loaded and ready")

In [None]:
# Evaluation preprocessing: convert to tensor, then apply per-channel normalisation.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
    ]
)

# CIFAR-10 test split (downloaded into data_root on first use).
test_dataset = datasets.CIFAR10(
    CONFIG["data_root"], train=False, transform=test_transform, download=True
)

# One image per batch, in dataset order.
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)

print(f"Test dataset loaded: {len(test_dataset)} images")

In [62]:
def denormalize_image(img_tensor, mean=None, std=None):
    """Undo per-channel normalisation and clamp the result to [0, 1].

    Args:
        img_tensor: Normalised image of shape (C, H, W) or (N, C, H, W).
        mean: Optional per-channel means; defaults to NORMALIZE_MEAN.
        std: Optional per-channel stds; defaults to NORMALIZE_STD.

    Returns:
        A new tensor on the same device as ``img_tensor`` holding
        ``img_tensor * std + mean`` clamped to the displayable range [0, 1].
    """
    if mean is None:
        mean = NORMALIZE_MEAN
    if std is None:
        std = NORMALIZE_STD
    # Build the statistics on the input's device so CUDA tensors work too;
    # the (C, 1, 1) shape broadcasts over both single images and batches.
    mean_t = torch.tensor(mean, device=img_tensor.device).view(-1, 1, 1)
    std_t = torch.tensor(std, device=img_tensor.device).view(-1, 1, 1)
    img = img_tensor * std_t + mean_t
    return torch.clamp(img, 0, 1)

In [None]:
print("Finding top misclassified examples...")
misclassified = []
# autocast needs a device-type string; later cells reuse this variable.
device_type = "cuda" if device.type == "cuda" else "cpu"
# Running dataset offset of the current batch. Valid as a dataset index
# because test_loader iterates in order (shuffle=False).
samples_seen = 0

with torch.no_grad():
    for img, label in test_loader:
        img, label = img.to(device), label.to(device)

        with autocast(device_type):
            output = model(img)

        # Autocast returns reduced-precision logits; take the softmax in float32
        # so near-1.0 confidences are not quantised into ties before ranking.
        prob = F.softmax(output.float(), dim=1)
        confidence, pred = torch.max(prob, 1)

        # Works for any batch size: visit only the wrong predictions in the batch.
        wrong_positions = torch.nonzero(pred != label).flatten().tolist()
        for j in wrong_positions:
            misclassified.append({
                "index": samples_seen + j,
                "true": CLASSES[label[j].item()],
                "pred": CLASSES[pred[j].item()],
                "confidence": confidence[j].item(),
                "img": img[j].cpu()
            })

        samples_seen += label.size(0)

# Keep the CONFIG["num_misclassified"] most confident mistakes (descending confidence)
misclassified = sorted(misclassified, key=lambda x: -x["confidence"])[:CONFIG["num_misclassified"]]
print(f"Found {len(misclassified)} misclassified examples")

In [None]:
# Write each selected mistake to disk as a PNG and collect them in a W&B table.
os.makedirs(f"{CONFIG['artifacts_dir']}/misclassified", exist_ok=True)

to_pil = transforms.ToPILImage()
misclassified_table = wandb.Table(columns=["Image", "True", "Pred", "Confidence"])

for item in misclassified:
    # File name encodes the dataset index plus the true and predicted class
    path = f"{CONFIG['artifacts_dir']}/misclassified/{item['index']}_{item['true']}_as_{item['pred']}.png"

    # Stored tensors are normalised; map back to [0, 1] before saving
    to_pil(denormalize_image(item["img"])).save(path)

    row = (
        wandb.Image(path),
        item["true"],
        item["pred"],
        f"{item['confidence']:.3f}",
    )
    misclassified_table.add_data(*row)

wandb.log({"Misclassified Top 20": misclassified_table})
print(f"Logged {len(misclassified)} misclassified examples to W&B")

In [None]:
# Sanity check: does a gradient reach the input when we back-propagate one logit?
model.eval()
for param in model.parameters():
    param.requires_grad = True

# Create the probe directly on the target device. Building it on CPU with
# requires_grad=True and then calling .to(device) yields a NON-leaf tensor on
# CUDA, whose .grad stays None and makes this check report a false failure.
test_input = torch.randn(1, 3, 32, 32, device=device, requires_grad=True)
test_output = model(test_input)
test_loss = test_output[0, 3]  # Pick one output
test_loss.backward()

print("=" * 60)
print("GRADIENT FLOW TEST")
print("=" * 60)
print(f"Input has gradient: {test_input.grad is not None}")
if test_input.grad is not None:
    print(f"Input gradient norm: {test_input.grad.norm().item():.6f}")
    print("✓ Gradients ARE flowing through the model")
else:
    print("✗ Gradients are NOT flowing - model architecture issue!")
print("=" * 60)

# The probe also accumulated gradients on every parameter; drop them so they
# do not linger in memory for the rest of the notebook.
model.zero_grad(set_to_none=True)

In [None]:
print("Generating Grad-CAM heatmaps...")
print("=" * 60)
print("DIAGNOSTIC: Checking GradCAM compatibility")
print("=" * 60)

# Ensure model is ready
model.eval()

# Enable gradients on all parameters
for param in model.parameters():
    param.requires_grad = True

# Candidate target layers, tried in order until one yields a non-flat heatmap
print("\nTesting target layers...")
target_configs = [
    ("layer4 (entire)", [model.layer4]),
    ("layer4[-1] (last block)", [model.layer4[-1]]),
    ("layer4[1] (second block)", [model.layer4[1]]),
    ("layer3[-1] (layer3 last)", [model.layer3[-1]])
]

working_config = None

# The probe image is the same for every candidate, so fetch it once
# (first test image, with a leading batch dimension).
test_img = test_dataset[0][0].unsqueeze(0).to(device)

for config_name, target_layers in target_configs:
    print(f"\nTrying: {config_name}")
    test_heatmap = None
    try:
        # Context-manager form so the trial CAM detaches from the model when
        # the block exits (per the pytorch-grad-cam documentation).
        with GradCAM(model=model, target_layers=target_layers) as test_cam:
            # Only the predicted class is needed here; the CAM call below
            # runs its own gradient-enabled forward pass.
            with torch.no_grad():
                test_pred = model(test_img).argmax(1).item()

            # Generate heatmap
            test_heatmap = test_cam(
                input_tensor=test_img,
                targets=[ClassifierOutputTarget(test_pred)]
            )
    except Exception as e:
        print(f"  ✗ Error: {e}")
        continue

    if test_heatmap is None:
        # NOTE(review): the CAM context manager appears able to report and
        # suppress some errors itself; treat "no heatmap" as a failed candidate.
        print(f"  ✗ Error: no heatmap produced")
        continue

    hm_min, hm_max, hm_mean = test_heatmap.min(), test_heatmap.max(), test_heatmap.mean()
    print(f"  Heatmap: min={hm_min:.4f}, max={hm_max:.4f}, mean={hm_mean:.4f}")

    if hm_max > 0.01:
        working_config = (config_name, target_layers)
        print(f"  ✓✓ SUCCESS! This configuration works!")
        break
    else:
        print(f"  ✗ Still all zeros")

if working_config is None:
    print("\n" + "=" * 60)
    print("CRITICAL ERROR: No working GradCAM configuration found!")
    print("=" * 60)
    print("\nThis means your model has an architecture issue.")
    print("Please share:")
    print("1. How you created the model:")
    print("   Example: model = models.resnet18()")
    print("2. How you modified it:")
    print("   Example: model.fc = nn.Linear(512, 10)")
    print("3. How you loaded weights:")
    print("   Example: model.load_state_dict(torch.load(...))")

    # Additional diagnostics
    print("\nModel structure check:")
    print(f"  Has layer4: {hasattr(model, 'layer4')}")
    print(f"  Layer4 type: {type(model.layer4) if hasattr(model, 'layer4') else 'N/A'}")

    raise RuntimeError("Cannot generate GradCAM - see diagnostics above")

print("\n" + "=" * 60)
print(f"Using working configuration: {working_config[0]}")
print("=" * 60)

# Now generate actual visualizations with the working config
os.makedirs(f"{CONFIG['artifacts_dir']}/gradcam", exist_ok=True)

gradcam_table = wandb.Table(columns=["GradCAM Overlay", "Original Image", "True Label", "Predicted", "Correct?"])
clean_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

print("\nGenerating GradCAM visualizations...")

# Use the CAM as a context manager so it detaches from the model when this
# cell finishes (per the pytorch-grad-cam documentation); otherwise it would
# keep intercepting every forward pass made by the later evaluation cells.
with GradCAM(model=model, target_layers=working_config[1]) as cam:
    for idx, (img, label) in enumerate(clean_loader):
        if idx >= CONFIG["num_gradcam"]:
            break

        # Prepare image for display
        img_tensor = img[0].clone()  # (C, H, W)
        img_unnorm = denormalize_image(img_tensor)  # Denormalized for visualization
        img_np = img_unnorm.permute(1, 2, 0).cpu().numpy().astype(np.float32)
        img_np = np.clip(img_np, 0, 1)  # show_cam_on_image expects floats in [0, 1]

        pil_original = transforms.ToPILImage()(img_unnorm)

        input_tensor = img.to(device)
        # Only the predicted class is needed from this pass, so skip autograd;
        # the cam(...) call below performs its own gradient-enabled forward pass.
        with torch.no_grad():
            pred = model(input_tensor).argmax(1).item()

        # Generate heatmap - must NOT be wrapped in torch.no_grad()
        grayscale_cam = cam(
            input_tensor=input_tensor,
            targets=[ClassifierOutputTarget(pred)]
        )
        grayscale_cam = grayscale_cam[0, :]

        # Print heatmap statistics for the first few samples as a sanity check
        if idx < 3:
            print(f"Sample {idx}: Heatmap min={grayscale_cam.min():.4f}, max={grayscale_cam.max():.4f}, mean={grayscale_cam.mean():.4f}")
            print(f"  Predicted: {CLASSES[pred]}, True: {CLASSES[label.item()]}")

        # Create overlay
        visualization = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)
        pil_overlay = Image.fromarray(visualization.astype(np.uint8))

        correct = "Yes" if pred == label.item() else "No"

        # Save
        overlay_path = f"{CONFIG['artifacts_dir']}/gradcam/overlay_{idx}.png"
        pil_overlay.save(overlay_path)

        gradcam_table.add_data(
            wandb.Image(overlay_path),
            wandb.Image(pil_original),
            CLASSES[label.item()],
            CLASSES[pred],
            correct
        )

        if (idx + 1) % 5 == 0:
            print(f"  Grad-CAM {idx+1}/{CONFIG['num_gradcam']} done")

wandb.log({"Grad-CAM Gallery (20 examples)": gradcam_table})
print("✓ Grad-CAM complete — all heatmaps uploaded to W&B!")

In [None]:
# Robustness check: accuracy under additive Gaussian pixel noise.
# (The noise is random, not an optimised adversarial attack; the metric name
# is kept as-is so existing W&B dashboards keep working.)
print("Running adversarial noise check...")
noisy_correct = 0
total = 0

# The loader yields NORMALISED images, so noise and the [0, 1] clamp must be
# applied in pixel space; clamping normalised values to [0, 1] would wipe out
# every below-mean pixel and corrupt the image itself.
norm_mean = torch.tensor(NORMALIZE_MEAN, device=device).view(1, 3, 1, 1)
norm_std = torch.tensor(NORMALIZE_STD, device=device).view(1, 3, 1, 1)

# Dedicated, seeded generator so the reported accuracy is repeatable.
noise_gen = torch.Generator(device=device).manual_seed(CONFIG.get("seed", 0))

with torch.no_grad():
    for img, label in test_loader:
        img, label = img.to(device), label.to(device)

        pixels = img * norm_std + norm_mean  # back to [0, 1] pixel space
        noise = torch.randn(pixels.shape, generator=noise_gen, device=device) * CONFIG["noise_std"]
        noisy_pixels = torch.clamp(pixels + noise, 0, 1)
        noisy_img = (noisy_pixels - norm_mean) / norm_std  # re-normalise for the model

        with autocast(device_type):
            pred = model(noisy_img).argmax(1)

        noisy_correct += (pred == label).sum().item()
        total += label.size(0)  # count samples, not batches

noisy_acc = noisy_correct / total
wandb.log({f"adversarial_noise_accuracy (σ={CONFIG['noise_std']})": noisy_acc})
print(f"Noisy accuracy (σ={CONFIG['noise_std']}): {noisy_acc:.4f}")

In [None]:
print("Computing ECE (Expected Calibration Error)...")

all_logits = []
all_labels = []

# This pass only concatenates per-batch results, so it can use the configured
# batch size instead of pushing one image at a time through the model.
eval_loader = DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False)

with torch.no_grad():
    for img, label in eval_loader:
        img = img.to(device)
        with autocast(device_type):
            batch_logits = model(img)
        # Autocast yields reduced-precision logits; store float32 so the
        # calibration maths in later cells is not done in half/bfloat16.
        all_logits.append(batch_logits.float().cpu())
        all_labels.append(label)

# Full-test-set tensors consumed by the calibration and confusion-matrix cells.
logits = torch.cat(all_logits)
labels = torch.cat(all_labels)

In [None]:
# Expected calibration error of the raw (un-scaled) logits, L1 norm.
ece_kwargs = {
    "num_classes": CONFIG["num_classes"],
    "n_bins": CONFIG["ece_bins"],
    "norm": 'l1',
}
ece_metric = MulticlassCalibrationError(**ece_kwargs)

ece_tensor = ece_metric(logits, labels)
ece_value = ece_tensor.item()
print(f"✓ ECE (before calibration) = {ece_value:.4f}")

In [None]:
# Reliability diagram: per-confidence-bin accuracy against the ideal diagonal.
# Softmax in float32 regardless of how the logits were stored.
probs = F.softmax(logits.float(), dim=1)
conf, pred = torch.max(probs, dim=1)

n_bins = CONFIG["ece_bins"]
fig, ax = plt.subplots(figsize=(6, 5))
bins = np.linspace(0, 1, n_bins + 1)
accs = []
confs_bin = []

for i in range(n_bins):
    lower, upper = float(bins[i]), float(bins[i + 1])
    if i == n_bins - 1:
        # Close the last bin on the right so confidence == 1.0 is not dropped.
        mask = (conf >= lower) & (conf <= upper)
    else:
        mask = (conf >= lower) & (conf < upper)

    if mask.sum() > 0:
        accuracy = (pred[mask] == labels[mask]).float().mean().item()
        confidence = conf[mask].mean().item()
    else:
        # Empty bin: zero-height bar placed at the bin centre.
        accuracy = 0
        confidence = (lower + upper) / 2
    accs.append(accuracy)
    confs_bin.append(confidence)

# Bar width follows the bin count (0.9 / 15 == 0.06, the previous fixed value).
ax.bar(confs_bin, accs, width=0.9 / n_bins, alpha=0.8, color="skyblue", edgecolor="black")
ax.plot([0, 1], [0, 1], "--", color="red", linewidth=2, label="Perfect calibration")
ax.set_xlabel("Confidence")
ax.set_ylabel("Accuracy")
ax.set_title(f"Reliability Diagram – ECE = {ece_value:.4f}")
ax.legend()
ax.grid(True, alpha=0.3)

wandb.log({
    "ECE": ece_value,
    "Reliability Diagram": wandb.Image(fig)
})
plt.close(fig)
print("Reliability diagram saved")

In [None]:
print("Applying temperature scaling...")

# Single learnable scalar temperature, initialised at 1.5.
# NOTE(review): T is fitted on the same logits/labels that ECE is later
# measured on; a held-out calibration split would give a fairer estimate.
T = nn.Parameter(torch.full((1,), 1.5))
optimizer = torch.optim.LBFGS(params=[T], lr=0.01, max_iter=1000)


def closure():
    """Recompute the cross-entropy of the temperature-scaled logits for LBFGS."""
    optimizer.zero_grad()
    nll = F.cross_entropy(logits / T, labels)
    nll.backward()
    return nll


optimizer.step(closure)
print(f"Learned temperature: {T.item():.3f}")

In [None]:
# Re-measure calibration error after dividing the logits by the learned temperature.
learned_temperature = T.item()
scaled_logits = logits / learned_temperature

ece_metric_scaled = MulticlassCalibrationError(
    norm='l1',
    n_bins=CONFIG["ece_bins"],
    num_classes=CONFIG["num_classes"],
)
ece_scaled = ece_metric_scaled(scaled_logits, labels).item()

calibration_metrics = {
    "temperature_scaling_T": learned_temperature,
    "ECE_after_temperature_scaling": ece_scaled,
}
wandb.log(calibration_metrics)

print(f"ECE before: {ece_value:.4f} → after T-scaling: {ece_scaled:.4f}")

In [None]:
print("Generating confusion matrix...")

y_true = labels.cpu().numpy()
# Derive predictions from the collected logits here rather than relying on a
# `pred` variable left behind by an earlier cell.
y_pred = logits.argmax(dim=1).cpu().numpy()

# Pin the label set so the matrix is always num_classes x num_classes and
# lines up with the CLASSES tick labels even if a class is never predicted.
cm = confusion_matrix(y_true, y_pred, labels=list(range(CONFIG["num_classes"])))

cm_fig, cm_ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap="Blues",
    xticklabels=CLASSES,
    yticklabels=CLASSES,
    cbar_kws={'label': 'Count'},
    ax=cm_ax
)
cm_ax.set_xlabel("Predicted", fontsize=14)
cm_ax.set_ylabel("True", fontsize=14)
cm_ax.set_title("Confusion Matrix – CIFAR-10 Test Set", fontsize=16)
cm_ax.tick_params(axis="x", rotation=45)
cm_ax.tick_params(axis="y", rotation=0)
cm_fig.tight_layout()

wandb.log({"Confusion Matrix": wandb.Image(cm_fig)})
plt.show()
plt.close(cm_fig)

print("Confusion matrix uploaded to W&B!")

In [None]:
# Final recap of the headline numbers produced by the cells above.
rule = "=" * 80
summary_lines = [
    f"Total test images: {len(test_dataset)}",
    f"Misclassified examples found: {len(misclassified)}",
    f"Grad-CAM visualizations: {CONFIG['num_gradcam']}",
    f"Adversarial noise accuracy (σ={CONFIG['noise_std']}): {noisy_acc:.4f}",
    f"ECE (before calibration): {ece_value:.4f}",
    f"Temperature learned: {T.item():.3f}",
    f"ECE (after calibration): {ece_scaled:.4f}",
]

print("\n" + rule)
print("SUMMARY OF RESULTS")
print(rule)
for summary_line in summary_lines:
    print(summary_line)
print(rule)

# Close the W&B run so everything logged above is flushed
wandb.finish()
print("\n✓ All diagnostics complete! W&B run finished.")