In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.datasets import ImageFolder

# =========================================================
#  Load SAME LeafCNN model architecture from your main code
# =========================================================

class ConvBlock(nn.Module):
    def __init__(self, in_c, out_c, pool=True):
        super().__init__()
        self.conv = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False)
        self.bn   = nn.BatchNorm2d(out_c)
        self.act  = nn.ReLU(inplace=True)
        self.pool = pool
        if pool:
            self.pool_layer = nn.MaxPool2d(2,2)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        if self.pool:
            x = self.pool_layer(x)
        return x


class LeafCNN(nn.Module):
    def __init__(self, num_classes=52):
        super().__init__()
        ch = [48, 72, 120, 192, 256]
        self.layer1 = ConvBlock(3,   ch[0], pool=True)
        self.layer2 = ConvBlock(ch[0], ch[1], pool=True)
        self.layer3 = ConvBlock(ch[1], ch[2], pool=True)
        self.layer4 = ConvBlock(ch[2], ch[3], pool=True)
        self.layer5 = ConvBlock(ch[3], ch[4], pool=True)

        self.global_pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc1 = nn.Linear(ch[-1], 1536)
        self.dropout = nn.Dropout(0.4)
        self.fc2 = nn.Linear(1536, num_classes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.global_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# =========================================================
# SETUP
# =========================================================

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SIZE = 224
MODEL_PATH = r"C:\Users\AL IMRAN\Desktop\CSE498R\Local\best_model.pth"
VAL_DIR = r"C:\Users\AL IMRAN\Desktop\CSE498R\Local\Dataset_Final_V2_Split\val"

# Load model
model = LeafCNN(num_classes=52).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# ======================================================================
# REGISTER HOOKS FOR GRAD-CAM (ONLY LAST CONV: layer5.conv)
# ======================================================================

feature_maps = []
gradients = []

def forward_hook(module, inp, out):
    feature_maps.clear()
    feature_maps.append(out)

def backward_hook(module, grad_in, grad_out):
    gradients.clear()
    gradients.append(grad_out[0])

target_layer = model.layer5.conv
target_layer.register_forward_hook(forward_hook)
target_layer.register_backward_hook(backward_hook)

# ======================================================================
# GRAD-CAM FUNCTION (NO GRADCAM++)
# ======================================================================

def generate_gradcam(input_tensor, class_idx=None):
    model.zero_grad()
    feature_maps.clear()
    gradients.clear()

    output = model(input_tensor)
    if class_idx is None:
        class_idx = int(output.argmax(dim=1).item())

    score = output[0, class_idx]
    score.backward(retain_graph=True)

    activations = feature_maps[0].detach()[0]   # (C,H,W)
    grads = gradients[0].detach()[0]            # (C,H,W)

    weights = grads.mean(dim=(1, 2))            # (C,)

    cam = torch.sum(weights[:, None, None] * activations, dim=0)
    cam = torch.relu(cam)

    cam -= cam.min()
    cam /= (cam.max() + 1e-8)

    cam = cam.unsqueeze(0).unsqueeze(0)
    cam = F.interpolate(cam, size=(IMG_SIZE, IMG_SIZE),
                        mode="bilinear", align_corners=False)
    return cam.squeeze().cpu()

# ======================================================================
# LOAD ONE IMAGE PER CLASS
# ======================================================================

val_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

val_ds = ImageFolder(VAL_DIR, transform=val_tf)

samples = {}
for img, lbl in val_ds:
    if lbl not in samples:
        samples[lbl] = img
    if len(samples) == len(val_ds.classes):
        break

print("âœ” Collected", len(samples), "classes")

# ======================================================================
# SAVE GRAD-CAM IMAGES
# ======================================================================

SAVE_DIR = os.path.join(os.path.dirname(MODEL_PATH), "gradcam_only")
os.makedirs(SAVE_DIR, exist_ok=True)

for class_idx, img in samples.items():

    img_tensor = img.unsqueeze(0).to(DEVICE)

    heatmap = generate_gradcam(img_tensor, class_idx)

    img_np = img_tensor.squeeze().cpu().permute(1, 2, 0).numpy()

    fname = f"{class_idx:02d}_{val_ds.classes[class_idx]}_gradcam.png"
    save_path = os.path.join(SAVE_DIR, fname.replace(" ", "_"))

    plt.figure(figsize=(5,5))
    plt.imshow(img_np)
    plt.imshow(heatmap.numpy(), cmap='jet', alpha=0.4)
    plt.axis("off")
    plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close()

print("\nðŸŽ‰ All Grad-CAM results saved to:", SAVE_DIR)


  model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))


âœ” Collected 51 classes


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)



ðŸŽ‰ All Grad-CAM results saved to: C:\Users\AL IMRAN\Desktop\CSE498R\Local\gradcam_only
