In [2]:
# Grad-CAM on BEST weights

import os, zipfile
import numpy as np
from PIL import Image
import cv2

import torch
import torch.nn as nn
from torchvision import transforms, datasets


# Paths / settings
# Input split (expects train/val/test sub-folders) and the output location.
DATA_ROOT = "/kaggle/input/plant-disease-dataset/Dataset_Final_V2_Split"
OUT_DIR = "/kaggle/working/plantanet_relu_final"
ZIP_NAME = "gradcam_best.zip"

# Side length (pixels) every image is resized to before entering the model.
IMG_SIZE = 160
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", DEVICE)

# Weight files, listed from most to least preferred:
#   1. best weights written by this run into OUT_DIR,
#   2. best weights from the attached Kaggle dataset (the path actually present),
#   3. an alternative best-weights dataset name,
#   4. the final (not best) weights, as a last resort.
BEST_WEIGHTS_PATH = os.path.join(OUT_DIR, "PlantaNet_ReLU_best_weights.pth")
FALLBACK_BEST_1 = "/kaggle/input/best-weights-dataset/PlantaNet_ReLU_best_weights.pth"
FALLBACK_BEST_2 = "/kaggle/input/best-weights/PlantaNet_ReLU_best_weights.pth"
FALLBACK_FINAL = "/kaggle/input/final-weights/PlantaNet_ReLU_final_weights.pth"


# Transforms
# Deterministic evaluation pipeline (no augmentation): resize to a fixed
# IMG_SIZE x IMG_SIZE square, convert to a float tensor, then normalise each
# channel with the widely used ImageNet mean/std values.
eval_tfm = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


# Dataset (test only)
# ImageFolder treats each sub-directory of DATA_ROOT/test as one class.
test_dir = os.path.join(DATA_ROOT, "test")
test_ds = datasets.ImageFolder(root=test_dir, transform=eval_tfm)

# Class names in ImageFolder's index order; index i is the model's output unit i.
class_names = test_ds.classes
NUM_CLASSES = len(class_names)
print("Classes:", NUM_CLASSES, "| Test:", len(test_ds))


# Model definition 
class DepthwiseSeparableConv(nn.Module):
    """Depthwise 3x3 conv -> pointwise 1x1 conv -> GroupNorm -> ReLU.

    The depthwise stage (``groups=in_ch``, padding 1) carries ``stride``; the
    pointwise stage maps ``in_ch`` to ``out_ch`` channels. GroupNorm uses 8
    groups when ``out_ch`` is divisible by 8 and 4 groups otherwise.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        n_groups = 4 if out_ch % 8 else 8
        # Attribute names form the state_dict keys of saved checkpoints: keep them.
        self.depth = nn.Conv2d(
            in_ch, in_ch, kernel_size=3, stride=stride, padding=1,
            groups=in_ch, bias=False,
        )
        self.point = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.norm = nn.GroupNorm(n_groups, out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.depth(x)
        y = self.point(y)
        y = self.norm(y)
        return self.act(y)

class PlantaNetReLU(nn.Module):
    """Compact CNN classifier built from depthwise-separable stages.

    Pipeline: 3x3 stem conv (stride 1) -> three depthwise-separable stages,
    each with stride 2 and followed by dropout (0.15 / 0.20 / 0.25) -> an extra
    3x3 conv block -> global average pooling -> MLP head
    (Linear 448->1024, GroupNorm, ReLU, dropout 0.4, Linear 1024->num_classes).

    Attribute names and the layer order inside every ``nn.Sequential`` define
    the checkpoint's state_dict keys, so they must not change.
    """

    @staticmethod
    def _group_norm(channels):
        # 8 groups when the channel count allows it, otherwise 4.
        return nn.GroupNorm(8 if channels % 8 == 0 else 4, channels)

    def __init__(self, num_classes):
        super().__init__()
        w_stem, w_mid1, w_mid2, w_head = 144, 224, 320, 448

        self.stem = nn.Sequential(
            nn.Conv2d(3, w_stem, 3, 1, 1, bias=False),
            self._group_norm(w_stem),
            nn.ReLU(inplace=True),
        )

        # block1, block2, block3: (in channels, out channels, dropout rate).
        stage_specs = (
            (w_stem, w_mid1, 0.15),
            (w_mid1, w_mid2, 0.20),
            (w_mid2, w_head, 0.25),
        )
        for idx, (c_in, c_out, p_drop) in enumerate(stage_specs, start=1):
            stage = nn.Sequential(
                DepthwiseSeparableConv(c_in, c_out, stride=2),
                nn.Dropout(p_drop),
            )
            setattr(self, f"block{idx}", stage)

        self.conv_extra = nn.Sequential(
            nn.Conv2d(w_head, w_head, 3, 1, 1, bias=False),
            self._group_norm(w_head),
            nn.ReLU(inplace=True),
            nn.Dropout(0.25),
        )

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(w_head, 1024, bias=False),
            self._group_norm(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        feature_path = (
            self.stem,
            self.block1,
            self.block2,
            self.block3,
            self.conv_extra,
            self.gap,
        )
        for layer in feature_path:
            x = layer(x)
        return self.classifier(x)


# Load BEST weights safely
model = PlantaNetReLU(NUM_CLASSES).to(DEVICE)

# Candidate checkpoints in priority order; the first one that exists is used.
weight_candidates = [BEST_WEIGHTS_PATH, FALLBACK_BEST_1, FALLBACK_BEST_2, FALLBACK_FINAL]
weights_path = next((p for p in weight_candidates if os.path.exists(p)), None)

if weights_path is None:
    raise FileNotFoundError(
        "No weights found.\nTried:\n" + "\n".join(f"- {p}" for p in weight_candidates)
    )

print("Loading weights from:", weights_path)
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
state = torch.load(weights_path, map_location=DEVICE)

# Checkpoints saved from nn.DataParallel prefix every key with "module.".
# Strip only that *leading* prefix: str.replace would also rewrite any
# "module." appearing later inside a key.
DP_PREFIX = "module."
if any(k.startswith(DP_PREFIX) for k in state):
    state = {
        (k[len(DP_PREFIX):] if k.startswith(DP_PREFIX) else k): v
        for k, v in state.items()
    }

# strict=False tolerates extra keys in the checkpoint. A *missing* key,
# however, would leave that layer at its random initialisation and make every
# Grad-CAM meaningless, so fail loudly instead of continuing silently.
load_report = model.load_state_dict(state, strict=False)
if load_report.missing_keys:
    raise RuntimeError(
        f"Checkpoint {weights_path} does not match PlantaNetReLU; "
        f"missing keys: {load_report.missing_keys}"
    )
if load_report.unexpected_keys:
    print("Warning: ignoring unexpected checkpoint keys:", load_report.unexpected_keys)

model.eval()


# Grad-CAM implementation
# Hook the 3x3 convolution at the head of `conv_extra`: it is the last
# convolution before global average pooling (the "conv_extra.0" submodule).
target_layer = model.get_submodule("conv_extra.0")

class GradCAM:
    """Gradient-weighted Class Activation Mapping for one layer of ``model``.

    A forward hook on ``target_layer`` records its output activations, and a
    full backward hook records the gradient of the chosen class score with
    respect to that output. ``generate`` weights each channel by its spatially
    averaged gradient, sums over channels, applies ReLU and rescales the map
    to [0, 1].

    The hooks stay attached until ``remove()`` is called.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Keep the handles so the hooks can be detached later (see remove()).
        self._handles = [
            target_layer.register_forward_hook(self._save_activation),
            target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def _save_activation(self, module, inp, out):
        self.activations = out.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        # grad_output[0] is d(score)/d(layer output), same shape as the activations.
        self.gradients = grad_output[0].detach()

    def remove(self):
        """Detach both hooks from ``target_layer``. It is safe to call this more than once."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def generate(self, input_tensor, class_idx=None):
        """Return the Grad-CAM map for ``input_tensor`` as a numpy array in [0, 1].

        Args:
            input_tensor: model input; assumed to hold a single image (batch of 1),
                since the map is squeezed and min/max-normalised as one array.
            class_idx: class whose score is explained; defaults to the
                model's top prediction.

        Returns:
            A float array with the target layer's spatial size (H, W).

        Raises:
            RuntimeError: if the hooks did not fire (e.g. ``target_layer`` is
                not used in ``model``'s forward pass, or ``remove()`` was called).
        """
        # Clear results from a previous call so a silent hook failure is detected.
        self.gradients = None
        self.activations = None
        self.model.zero_grad()

        # Gradients are needed even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            outputs = self.model(input_tensor)
            if class_idx is None:
                class_idx = int(outputs.argmax(dim=1).item())
            # .sum() turns the per-sample scores into a scalar that backward() accepts.
            score = outputs[:, class_idx].sum()
            score.backward()

        grads = self.gradients
        acts = self.activations
        if grads is None or acts is None:
            raise RuntimeError(
                "Grad-CAM hooks did not fire; is target_layer part of the model's forward pass?"
            )

        # Assumes a 4-D (N, C, H, W) feature map: GAP over the gradients gives one weight per channel.
        weights = grads.mean(dim=(2, 3), keepdim=True)
        cam = (weights * acts).sum(dim=1).squeeze(0)
        cam = torch.relu(cam)

        cam_np = cam.cpu().numpy()
        cam_np -= cam_np.min()
        cam_np /= (cam_np.max() + 1e-9)  # epsilon avoids 0/0 for an all-zero map
        return cam_np

# Single Grad-CAM helper reused for every class below; its hooks stay attached to target_layer.
gradcam = GradCAM(model=model, target_layer=target_layer)

# denorm helper
# Per-channel normalisation statistics (the same values as in eval_tfm),
# shaped (3, 1, 1) so they broadcast over a channels-first array.
mean = np.asarray([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.asarray([0.229, 0.224, 0.225]).reshape(3, 1, 1)


def denorm_tensor(img_tensor):
    """Undo the Normalize step and return an (H, W, C) float image in [0, 1].

    ``img_tensor`` is a single (C, H, W) tensor; it is moved to the CPU first.
    """
    chw = np.clip(img_tensor.cpu().numpy() * std + mean, 0, 1)
    return chw.transpose(1, 2, 0)


# Pick one correct sample per class
print("Collecting correct samples...")
# Maps class index -> correctly classified image paths. Only the first entry
# of each class is consumed when generating Grad-CAMs, so each list is capped
# at one path, the first correct image in ImageFolder sample order.
correct_samples = {cls: [] for cls in range(NUM_CLASSES)}
classes_remaining = NUM_CLASSES

with torch.no_grad():
    for path, label in test_ds.samples:
        # This class already has its sample, so skip the forward pass.
        if correct_samples[label]:
            continue
        with Image.open(path) as img:
            tensor = eval_tfm(img.convert("RGB")).unsqueeze(0).to(DEVICE)
        pred = model(tensor).argmax(1).item()
        if pred == label:
            correct_samples[label].append(path)
            classes_remaining -= 1
            if classes_remaining == 0:
                break  # every class is covered; the rest of the split is not needed


# Generate and save Grad-CAM
gradcam_dir = os.path.join(OUT_DIR, "gradcam_best")
os.makedirs(gradcam_dir, exist_ok=True)

print("Generating Grad-CAM...")
for cls in range(NUM_CLASSES):
    if not correct_samples[cls]:
        print(f"No correct sample for {class_names[cls]}")
        continue

    path = correct_samples[cls][0]
    with Image.open(path) as pil:
        tensor = eval_tfm(pil.convert("RGB")).unsqueeze(0).to(DEVICE)

    # CAM for the true class at feature-map resolution, upsampled to the input size.
    cam_map = gradcam.generate(tensor, class_idx=cls)
    cam_resized = cv2.resize(cam_map, (IMG_SIZE, IMG_SIZE))

    # Stay in OpenCV's BGR order: applyColorMap returns BGR and cv2.imwrite
    # expects BGR, so only the RGB input image needs converting.
    heat_bgr = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    orig_rgb = np.ascontiguousarray(
        np.round(denorm_tensor(tensor[0]) * 255).astype(np.uint8)
    )
    orig_bgr = cv2.cvtColor(orig_rgb, cv2.COLOR_RGB2BGR)

    # Blend: 60% original image, 40% heatmap.
    overlay_bgr = cv2.addWeighted(orig_bgr, 0.6, heat_bgr, 0.4, 0)

    cls_name = class_names[cls].replace("/", "_").replace(" ", "_")
    cls_folder = os.path.join(gradcam_dir, cls_name)
    os.makedirs(cls_folder, exist_ok=True)

    save_path = os.path.join(cls_folder, f"gradcam_{cls_name}.png")
    # cv2.imwrite signals failure through its return value instead of raising.
    if not cv2.imwrite(save_path, overlay_bgr):
        raise OSError(f"cv2.imwrite failed to write {save_path}")
    print("Saved:", cls_name)

print("Done. Saved under:", gradcam_dir)


# Auto-zip for download
zip_path = os.path.join(OUT_DIR, ZIP_NAME)
print("Zipping to:", zip_path)

# Every file below gradcam_dir, in os.walk order. Each one is stored under a
# top-level "gradcam_best/" folder inside the archive.
files_to_zip = [
    os.path.join(dirpath, fname)
    for dirpath, _subdirs, fnames in os.walk(gradcam_dir)
    for fname in fnames
]
with zipfile.ZipFile(zip_path, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
    for file_path in files_to_zip:
        inner_name = os.path.join("gradcam_best", os.path.relpath(file_path, gradcam_dir))
        archive.write(file_path, arcname=inner_name)

print("Zip ready:", zip_path)


Device: cpu
Classes: 51 | Test: 20506
Loading weights from: /kaggle/input/best-weights-dataset/PlantaNet_ReLU_best_weights.pth
Collecting correct samples...
Generating Grad-CAM...
Saved: Apple___Apple_scab
Saved: Apple___Black_rot
Saved: Apple___Cedar_apple_rust
Saved: Apple___healthy
Saved: Banana___cordana
Saved: Banana___healthy
Saved: Banana___pestalotiopsis
Saved: Banana___sigatoka
Saved: Bean___angular_leaf_spot
Saved: Bean___bean_rust
Saved: Bean___healthy
Saved: Blueberry___healthy
Saved: Corn___Cercospora_leaf_spot_Gray_leaf_spot
Saved: Corn___Common_rust_
Saved: Corn___Northern_Leaf_Blight
Saved: Corn___healthy
Saved: Grape___Black_rot
Saved: Grape___Esca_(Black_Measles)
Saved: Grape___Leaf_blight_(Isariopsis_Leaf_Spot)
Saved: Grape___healthy
Saved: Mango___Anthracnose
Saved: Mango___Bacterial_Canker
Saved: Mango___Cutting_Weevil
Saved: Mango___Die_Back
Saved: Mango___Gall_Midge
Saved: Mango___Healthy
Saved: Mango___Powdery_Mildew
Saved: Mango___Sooty_Mould
Saved: Pepper,_bel