In [1]:
# ================================
# Cell: Grad-CAM++ (Standalone) + Auto-Zip
# Uses weights from: /kaggle/input/final-weights
# ================================

import os, random, zipfile
import numpy as np
from PIL import Image
import cv2

import torch
import torch.nn as nn
from torchvision import transforms, datasets


# -----------------------
# Settings: paths, image size, device and RNG seeding
# -----------------------
# Input locations (dataset root and the folder searched for .pth weights).
DATA_ROOT = "/kaggle/input/plant-disease-dataset/Dataset_Final_V2_Split"
WEIGHTS_DIR = "/kaggle/input/final-weights"  # <-- your folder
# Everything this cell produces is written below OUT_DIR; create it up front.
OUT_DIR = "/kaggle/working/plantanet_gradcampp"
os.makedirs(OUT_DIR, exist_ok=True)

IMG_SIZE = 160  # side length (pixels) used by the eval resize
SEED = 42

# Query CUDA availability once and reuse the answer for device + seeding.
_cuda_ok = torch.cuda.is_available()
DEVICE = torch.device("cuda") if _cuda_ok else torch.device("cpu")
print("Device:", DEVICE)

# Seed every random number generator this cell imports.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(SEED)
if _cuda_ok:
    torch.cuda.manual_seed_all(SEED)


# -----------------------
# Eval transform (same as training)
# -----------------------
# NOTE(review): "same as training" is the author's statement; the training
# pipeline is not part of this cell, so confirm it matches.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing: fixed-size resize, conversion to a tensor,
# then per-channel normalisation with the statistics above.
_eval_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
eval_tfm = transforms.Compose(_eval_steps)

# -----------------------
# Load test dataset
# -----------------------
# The test split is expected in a "test" sub-folder of DATA_ROOT; fail early
# with a clear message when it is absent.
test_dir = os.path.join(DATA_ROOT, "test")
_test_dir_missing = not os.path.exists(test_dir)
if _test_dir_missing:
    raise FileNotFoundError(f"Test folder not found: {test_dir}")

# ImageFolder derives the class list from the sub-folder names of test_dir.
test_ds = datasets.ImageFolder(root=test_dir, transform=eval_tfm)
class_names = test_ds.classes
NUM_CLASSES = len(class_names)

_num_test_images = len(test_ds)
print(f"Loaded test set: {_num_test_images} images | Classes={NUM_CLASSES}")


# -----------------------
# Model definition (PlantaNet-ReLU)
# -----------------------
class DepthwiseSeparableConv(nn.Module):
    """Depthwise 3x3 conv -> pointwise 1x1 conv -> GroupNorm -> ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (produced by the 1x1 convolution).
        stride: stride of the depthwise 3x3 convolution.

    The attribute names (``depth``, ``point``, ``norm``, ``act``) are part of
    the state-dict keys, so they must stay as they are for saved weights to load.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # groups=in_ch gives every input channel its own 3x3 filter.
        self.depth = nn.Conv2d(
            in_ch,
            in_ch,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_ch,
            bias=False,
        )
        # The 1x1 convolution mixes channels and sets the output width.
        self.point = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        # 8 groups when out_ch is a multiple of 8, otherwise fall back to 4.
        num_groups = 4 if out_ch % 8 else 8
        self.norm = nn.GroupNorm(num_groups, out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply depthwise conv, pointwise conv, normalisation and ReLU in order."""
        return self.act(self.norm(self.point(self.depth(x))))


class PlantaNetReLU(nn.Module):
    """Convolutional classifier: stem -> three stride-2 separable blocks ->
    extra 3x3 conv stage -> global average pooling -> two-layer head.

    Args:
        num_classes: size of the final linear layer's output.

    Submodule names and the position of each layer inside its ``nn.Sequential``
    determine the state-dict keys, so both are kept fixed for checkpoint
    compatibility.
    """

    def __init__(self, num_classes):
        super().__init__()
        c1, c2, c3, c4 = 144, 224, 320, 448
        hidden = 1024

        def group_norm(channels):
            # 8 groups when the width is a multiple of 8, otherwise 4.
            return nn.GroupNorm(4 if channels % 8 else 8, channels)

        # Full-resolution stem (stride 1).
        self.stem = nn.Sequential(
            nn.Conv2d(3, c1, kernel_size=3, stride=1, padding=1, bias=False),
            group_norm(c1),
            nn.ReLU(inplace=True),
        )

        # Each block halves the spatial size (stride 2) and widens the channels.
        self.block1 = nn.Sequential(
            DepthwiseSeparableConv(c1, c2, stride=2),
            nn.Dropout(0.15),
        )
        self.block2 = nn.Sequential(
            DepthwiseSeparableConv(c2, c3, stride=2),
            nn.Dropout(0.20),
        )
        self.block3 = nn.Sequential(
            DepthwiseSeparableConv(c3, c4, stride=2),
            nn.Dropout(0.25),
        )

        # Final convolutional stage; keeps both width and resolution.
        self.conv_extra = nn.Sequential(
            nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1, bias=False),
            group_norm(c4),
            nn.ReLU(inplace=True),
            nn.Dropout(0.25),
        )

        self.gap = nn.AdaptiveAvgPool2d(1)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(c4, hidden, bias=False),
            group_norm(hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Run the stages in order and return the classifier output."""
        stages = (
            self.stem,
            self.block1,
            self.block2,
            self.block3,
            self.conv_extra,
            self.gap,
            self.classifier,
        )
        for stage in stages:
            x = stage(x)
        return x


# -----------------------
# Auto-find weights
# -----------------------
if not os.path.exists(WEIGHTS_DIR):
    raise FileNotFoundError(f"Weights folder not found: {WEIGHTS_DIR}")

# Collect every .pth file below WEIGHTS_DIR (recursively).
pth_files = []
for root, _, files in os.walk(WEIGHTS_DIR):
    pth_files.extend(os.path.join(root, f) for f in files if f.endswith(".pth"))

# os.walk yields entries in arbitrary filesystem order; sort so that the index
# used below always refers to the same file from run to run.
pth_files.sort()

print("Found .pth files:")
for p in pth_files:
    print(" -", p)

if not pth_files:
    raise FileNotFoundError(f"No .pth file found inside {WEIGHTS_DIR}")

if len(pth_files) > 1:
    print(f"Note: {len(pth_files)} checkpoints found; using the first in sorted order.")

WEIGHTS_PATH = pth_files[0]   # if multiple, change index
print("\nUsing weights:", WEIGHTS_PATH)


# -----------------------
# Load model + weights
# -----------------------
model = PlantaNetReLU(NUM_CLASSES).to(DEVICE)
# NOTE(security): torch.load unpickles the file; only point WEIGHTS_PATH at
# checkpoints you trust.
state = torch.load(WEIGHTS_PATH, map_location=DEVICE)

# Some training scripts save {"state_dict": ...} / {"model_state_dict": ...}
# instead of the bare state dict; unwrap those so the keys can match.
if isinstance(state, dict):
    for wrapper_key in ("state_dict", "model_state_dict"):
        if isinstance(state.get(wrapper_key), dict):
            state = state[wrapper_key]
            break

# handle DataParallel keys: strip the leading "module." only (a plain
# str.replace would also remove the substring from the middle of a key).
dp_prefix = "module."
if isinstance(state, dict) and any(k.startswith(dp_prefix) for k in state.keys()):
    state = {
        (k[len(dp_prefix):] if k.startswith(dp_prefix) else k): v
        for k, v in state.items()
    }

# strict=False tolerates key mismatches, so inspect the result instead of
# assuming success: otherwise a wrong checkpoint leaves the model randomly
# initialised without any visible sign.
load_result = model.load_state_dict(state, strict=False)
missing_keys = list(load_result.missing_keys)
unexpected_keys = list(load_result.unexpected_keys)

if missing_keys:
    print(f"WARNING: {len(missing_keys)} model key(s) not found in checkpoint, e.g. {missing_keys[:5]}")
if unexpected_keys:
    print(f"WARNING: {len(unexpected_keys)} checkpoint key(s) not used by the model, e.g. {unexpected_keys[:5]}")
if missing_keys and len(missing_keys) == len(model.state_dict()):
    raise RuntimeError(
        f"No parameter in {WEIGHTS_PATH} matched the model; refusing to continue "
        "with randomly initialised weights."
    )

model.eval()
print("Weights loaded!")


# -----------------------
# Grad-CAM++ implementation
# -----------------------
# Hook the first module of conv_extra (its 3x3 Conv2d). Activations and
# gradients are therefore captured at the convolution's output, i.e. before
# the GroupNorm / ReLU / Dropout that follow it in the same Sequential.
target_layer = model.conv_extra[0]

class GradCAMPlusPlus:
    """Grad-CAM++ heatmap generator hooked onto one layer of a model.

    A forward hook stores the target layer's output and a full backward hook
    stores the gradient of the class score with respect to that output.
    ``generate`` combines the two into a heatmap normalised to [0, 1].

    Args:
        model: the network to explain. It is used as-is; put it in eval mode
            beforehand if dropout should be disabled.
        target_layer: module of ``model`` whose output is 4-D
            (batch, channels, height, width).

    Call ``remove_hooks()`` when the instance is no longer needed; otherwise
    the hooks stay registered on ``target_layer``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Keep the handles so the hooks can be detached again. Without them,
        # every new instance would leave another pair of hooks on the layer.
        self._hook_handles = [
            target_layer.register_forward_hook(self._save_activation),
            target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def remove_hooks(self):
        """Detach the forward/backward hooks from the target layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_activation(self, module, inp, out):
        # Only the values are needed later; detaching avoids holding on to
        # the autograd graph between calls.
        self.activations = out.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate(self, input_tensor, class_idx=None):
        """Compute the Grad-CAM++ map for a single image.

        Args:
            input_tensor: model input with a leading batch dimension of 1.
            class_idx: class whose score is explained; defaults to the
                model's top prediction for this input.

        Returns:
            2-D ``numpy`` array with the spatial size of the target layer's
            output, min-max scaled into [0, 1].

        Raises:
            ValueError: if the batch dimension is not exactly 1.
            RuntimeError: if the hooks did not capture activations/gradients
                (for example after ``remove_hooks()``).
        """
        if input_tensor.shape[0] != 1:
            raise ValueError(
                f"generate() expects a batch of exactly one image, got shape {tuple(input_tensor.shape)}"
            )

        # Clear the previous call's tensors so stale values can never be
        # mistaken for this call's results.
        self.gradients = None
        self.activations = None

        self.model.zero_grad()
        # Gradients are required even when the caller is inside torch.no_grad().
        with torch.enable_grad():
            outputs = self.model(input_tensor)

            if class_idx is None:
                class_idx = int(outputs.argmax(dim=1).item())

            score = outputs[0, class_idx]
            score.backward(retain_graph=False)

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM++ hooks did not capture activations/gradients; "
                "were the hooks removed, or is the target layer unused in forward()?"
            )

        grads = self.gradients  # [B,C,H,W]
        acts = self.activations

        # Grad-CAM++ weights
        grads_sq = grads ** 2
        grads_cb = grads ** 3

        sum_acts = torch.sum(acts, dim=(2, 3), keepdim=True)
        eps = 1e-9

        alpha_num = grads_sq
        alpha_den = 2 * grads_sq + sum_acts * grads_cb
        # Avoid division by zero: wherever the denominator vanishes, divide by
        # eps instead.
        alpha_den = torch.where(alpha_den != 0.0, alpha_den, torch.ones_like(alpha_den) * eps)

        alphas = alpha_num / alpha_den
        positive_grads = torch.relu(grads)

        # One weight per channel: alpha-weighted sum of the positive gradients.
        weights = torch.sum(alphas * positive_grads, dim=(2, 3), keepdim=True)

        # Channel-weighted sum of activations, keeping only positive evidence.
        cam = torch.sum(weights * acts, dim=1).squeeze(0)
        cam = torch.relu(cam)

        # Min-max normalise to [0, 1]; eps guards an all-zero map.
        cam_np = cam.detach().cpu().numpy()
        cam_np -= cam_np.min()
        cam_np /= (cam_np.max() + eps)
        return cam_np


# Single shared Grad-CAM++ helper. Constructing it registers the hooks on
# target_layer, so every later forward pass through the model (including plain
# inference) also updates its stored activations.
gradcampp = GradCAMPlusPlus(model, target_layer)

# De-normalisation helpers: per-channel statistics shaped (3, 1, 1) so that
# they broadcast over a channel-first image.
mean = np.array([0.485, 0.456, 0.406])[:, None, None]
std = np.array([0.229, 0.224, 0.225])[:, None, None]


def denorm_tensor(img_tensor):
    """Undo the mean/std normalisation of one channel-first image tensor.

    Returns a channel-last ``numpy`` array with values clipped to [0, 1].
    """
    chw = img_tensor.detach().cpu().numpy()
    restored = np.clip(chw * std + mean, 0, 1)
    return restored.transpose(1, 2, 0)


# -----------------------
# Pick 1 correctly classified image per class
# -----------------------
print("\nFinding correctly classified samples...")
correct_samples = {cls: [] for cls in range(NUM_CLASSES)}

# Only the first correct image of each class is visualised, so stop running
# inference for a class once it has enough samples instead of classifying the
# whole test set. Raise this value to collect more paths per class.
MAX_CORRECT_PER_CLASS = 1
classes_remaining = NUM_CLASSES

with torch.no_grad():
    for path, label in test_ds.samples:
        if len(correct_samples[label]) >= MAX_CORRECT_PER_CLASS:
            continue

        # Close the file handle as soon as the pixels have been decoded.
        with Image.open(path) as img:
            pil = img.convert("RGB")
        tensor = eval_tfm(pil).unsqueeze(0).to(DEVICE)
        pred = model(tensor).argmax(1).item()
        if pred != label:
            continue

        correct_samples[label].append(path)
        if len(correct_samples[label]) >= MAX_CORRECT_PER_CLASS:
            classes_remaining -= 1
            if classes_remaining == 0:
                # Every class is covered; nothing left to look for.
                break

print("Done.")


# -----------------------
# Generate Grad-CAM++ overlays
# -----------------------
gradcampp_dir = os.path.join(OUT_DIR, "gradcampp")
os.makedirs(gradcampp_dir, exist_ok=True)

print("\nGenerating Grad-CAM++ ...")
for cls in range(NUM_CLASSES):
    if not correct_samples[cls]:
        print(f"No correct sample for class {class_names[cls]}")
        continue

    img_path = correct_samples[cls][0]

    # Best-effort per class: a failure is reported and the loop moves on to
    # the next class instead of aborting the whole run.
    try:
        with Image.open(img_path) as img:
            pil = img.convert("RGB")
        tensor = eval_tfm(pil).unsqueeze(0).to(DEVICE)

        cam_map = gradcampp.generate(tensor, class_idx=cls)
        cam_resized = cv2.resize(cam_map, (IMG_SIZE, IMG_SIZE))

        # applyColorMap returns BGR, which is also what imwrite expects, so
        # the blend is done in BGR and only the RGB original is converted.
        # Blending is per-element, so the pixels match an RGB blend exactly.
        heat_bgr = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)

        original = denorm_tensor(tensor[0])
        orig_rgb = np.uint8(original * 255)
        # Contiguous copy: OpenCV functions do not reliably accept
        # negative-stride views.
        orig_bgr = np.ascontiguousarray(orig_rgb[:, :, ::-1])

        overlay_bgr = cv2.addWeighted(orig_bgr, 0.6, heat_bgr, 0.4, 0)

        cls_name = class_names[cls].replace("/", "_").replace(" ", "_")
        cls_folder = os.path.join(gradcampp_dir, cls_name)
        os.makedirs(cls_folder, exist_ok=True)

        save_path = os.path.join(cls_folder, f"gradcampp_{cls_name}.png")
        # imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(save_path, overlay_bgr):
            raise OSError(f"cv2.imwrite could not write {save_path}")
        print("Saved:", save_path)

    except Exception as e:
        print(f"Grad-CAM++ error for {img_path}: {e}")


print("\n Grad-CAM++ images saved in:", gradcampp_dir)


# -----------------------
# Auto-zip for download
# -----------------------
# The archive is written next to (not inside) gradcampp_dir, so the walk below
# never encounters the zip file itself.
zip_path = os.path.join(OUT_DIR, "gradcampp_results.zip")

with zipfile.ZipFile(zip_path, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
    for folder, _subdirs, filenames in os.walk(gradcampp_dir):
        for filename in filenames:
            source_file = os.path.join(folder, filename)
            # Store paths relative to gradcampp_dir so the archive unpacks
            # into per-class folders without the /kaggle/working prefix.
            archive.write(
                source_file,
                arcname=os.path.relpath(source_file, gradcampp_dir),
            )

print("Zipped file created at:", zip_path)
print("You can download it from the Kaggle 'Output' panel.")


Device: cpu
Loaded test set: 20506 images | Classes=51
Found .pth files:
 - /kaggle/input/final-weights/PlantaNet_ReLU_final_weights.pth

Using weights: /kaggle/input/final-weights/PlantaNet_ReLU_final_weights.pth
Weights loaded!

Finding correctly classified samples...
Done.

Generating Grad-CAM++ ...
Saved: /kaggle/working/plantanet_gradcampp/gradcampp/Apple___Apple_scab/gradcampp_Apple___Apple_scab.png
Saved: /kaggle/working/plantanet_gradcampp/gradcampp/Apple___Black_rot/gradcampp_Apple___Black_rot.png
Saved: /kaggle/working/plantanet_gradcampp/gradcampp/Apple___Cedar_apple_rust/gradcampp_Apple___Cedar_apple_rust.png
Saved: /kaggle/working/plantanet_gradcampp/gradcampp/Apple___healthy/gradcampp_Apple___healthy.png
Saved: /kaggle/working/plantanet_gradcampp/gradcampp/Banana___cordana/gradcampp_Banana___cordana.png
Saved: /kaggle/working/plantanet_gradcampp/gradcampp/Banana___healthy/gradcampp_Banana___healthy.png
Saved: /kaggle/working/plantanet_gradcampp/gradcampp/Banana___pestalot