# Grad-CAM and Activation Analysis for Satellite Segmentation

In [None]:
import torch
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import numpy as np
import cv2
import matplotlib.pyplot as plt

In [None]:
# Use the Metal (MPS) backend when PyTorch reports it as available;
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

In [None]:
NUM_CLASSES = 5  # number of output channels requested from the U-Net
ENCODER = "resnet34"  # encoder backbone name handed to smp.Unet


def build_model():
    """Create the segmentation U-Net and move it to ``device``.

    The network is built with ``ENCODER`` as its encoder, three input
    channels and ``NUM_CLASSES`` output channels.  ``encoder_weights=None``
    means no pretrained encoder weights are requested here.

    Returns:
        The ``smp.Unet`` instance after ``.to(device)``.
    """
    unet = smp.Unet(
        classes=NUM_CLASSES,
        in_channels=3,
        encoder_weights=None,
        encoder_name=ENCODER,
    )
    return unet.to(device)

In [None]:
# Build the network and restore trained weights from disk.  The checkpoint is
# passed straight to load_state_dict, so it is expected to be a plain
# state_dict; map_location places its tensors on `device` while loading.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
model = build_model()
model.load_state_dict(torch.load("../models/best_model.pth", map_location=device))
# Switch to evaluation mode for the analysis below.
model.eval()

In [None]:
def preprocess_image(path, size=256):
    """Load an image file and turn it into a display array and a model input.

    Args:
        path: Path of the image file, read with ``cv2.imread``.
        size: Side length in pixels; the image is resized to ``size x size``.

    Returns:
        A tuple ``(img, x)``: ``img`` is the resized RGB image as a float32
        array (values divided by 255, channels last); ``x`` is the same data
        with channels first and a leading batch axis, moved to ``device``.

    Raises:
        FileNotFoundError: If ``path`` does not point to an existing file.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    import os  # local import keeps this cell self-contained

    if not os.path.isfile(path):
        raise FileNotFoundError(f"Image file not found: {path!r}")

    img_bgr = cv2.imread(path)
    if img_bgr is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the error would surface later inside cvtColor.
        raise ValueError(f"OpenCV could not decode image: {path!r}")

    # OpenCV loads channels as BGR; convert to RGB before display/inference.
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_rgb = cv2.resize(img_rgb, (size, size))
    img = img_rgb.astype(np.float32) / 255.0
    # HWC -> CHW, then add the batch axis expected by the model.
    x = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
    return img, x.to(device)

In [None]:
# Load the test image: `img` is the float RGB array used for display and `x`
# is the batched tensor on `device` fed to the model.
# The path below is a placeholder — point it at a real image file.
img, x = preprocess_image("path/to/test_image.jpg")
plt.imshow(img)
plt.axis("off")

In [None]:
# Layer whose output and gradient are used for Grad-CAM and for the
# activation plots: the encoder's `layer4` sub-module.
target_layer = model.encoder.layer4

In [None]:
# Filled in by `forward_hook`: "act" is the hooked layer's output from the
# most recent forward pass; "grad" is the gradient that reaches that output
# during the following backward pass (None until backward has run).
cache = {"act": None, "grad": None}


def forward_hook(module, inp, out):
    """Forward hook that records a layer's output and, later, its gradient.

    Args:
        module: The hooked module (unused).
        inp: The module's inputs (unused).
        out: The module's output tensor.

    The output is stored in ``cache["act"]``.  When ``out`` takes part in
    autograd, a tensor hook is attached so that the gradient arriving at
    ``out`` during backward is stored in ``cache["grad"]``.
    """
    cache["act"] = out
    # Drop any gradient left over from a previous forward/backward pass so it
    # cannot be mistaken for the gradient of this activation.
    cache["grad"] = None

    def _store_grad(g):
        # Must return None: a non-None return value would replace the gradient.
        cache["grad"] = g

    # Tensor.register_hook raises on tensors that do not require grad (for
    # example when the forward pass runs under torch.no_grad()).
    if out.requires_grad:
        out.register_hook(_store_grad)

# Attach the hook to the target layer; the returned handle is kept so the
# hook can be detached again with handle.remove().
handle = target_layer.register_forward_hook(forward_hook)

In [None]:
# Run the forward pass with autograd ENABLED.  `logits` must carry a graph
# because the Grad-CAM score built from it is back-propagated later, and the
# forward hook on `target_layer` attaches a gradient hook to the layer output,
# which is only possible when that output requires grad.
logits = model(x)

# The predicted class map needs no gradient, so take it from a detached view.
pred = torch.argmax(logits.detach(), dim=1)

plt.imshow(pred.squeeze().cpu())
plt.title("Predicted classes")
plt.axis("off")

In [None]:
# Channel index into `logits` selecting the class that Grad-CAM explains.
CLASS_ID = 2

In [None]:
# Start from clean parameter gradients.
model.zero_grad(set_to_none=True)

# Grad-CAM target: the CLASS_ID logit summed over the pixels that were
# predicted as CLASS_ID (all other pixels are masked out with zeros).
mask = pred.eq(CLASS_ID).to(torch.float32)
score = torch.sum(mask * logits[:, CLASS_ID])

# Back-propagate the scalar target through the network.
score.backward()

In [None]:
act = cache["act"]
grad = cache["grad"]
if act is None or grad is None:
    # Without this check a missing gradient shows up as an opaque
    # AttributeError on `None.mean`.
    raise RuntimeError(
        "Grad-CAM cache is empty: run the forward pass with gradients enabled "
        "and call score.backward() while the hook is still registered."
    )

# Channel importance = gradient averaged over the two spatial axes.
weights = grad.mean(dim=(2, 3), keepdim=True)
# Weighted sum of the activation channels; ReLU keeps only positive evidence.
cam = (weights * act).sum(dim=1, keepdim=True)
cam = F.relu(cam)

# Upsample to the spatial size of the model input instead of a hard-coded
# 256x256, so the map still lines up with `img` when another size is used.
cam = F.interpolate(cam, size=tuple(x.shape[-2:]), mode="bilinear", align_corners=False)
cam = cam.squeeze().detach().cpu().numpy()
# Min-max normalise to [0, 1]; the epsilon avoids division by zero for a
# constant map.
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

# The Grad-CAM hook is no longer needed.
handle.remove()

In [None]:
# One row of three panels.  Each panel is (title, layers), where every layer
# is an (image, imshow-kwargs) pair drawn in order on the same axes.
panels = (
    ("Input", ((img, {}),)),
    ("Grad-CAM", ((cam, {}),)),
    ("Overlay", ((img, {}), (cam, {"alpha": 0.45}))),
)

plt.figure(figsize=(12, 4))
for position, (title, layers) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    for image, imshow_kwargs in layers:
        plt.imshow(image, **imshow_kwargs)
    plt.title(title)
    plt.axis("off")

plt.show()

In [None]:
# Written by `save_activation`: key "feat" holds the hooked layer's latest
# output, detached from the autograd graph.
acts = {}


def save_activation(module, inp, out):
    """Forward hook that stores the detached layer output under ``acts["feat"]``.

    Args:
        module: The hooked module (unused).
        inp: The module's inputs (unused).
        out: The module's output tensor.
    """
    acts.update(feat=out.detach())

# Capture the target layer's activation with a temporary forward hook.
h = target_layer.register_forward_hook(save_activation)
try:
    # Only the activation values are needed, so skip building an autograd
    # graph for this pass.
    with torch.no_grad():
        _ = model(x)
finally:
    # Always detach the hook, even if the forward pass raises; otherwise it
    # would stay attached and fire on every later forward pass.
    h.remove()

# Drop the batch axis: keep the feature maps of the first (only) sample.
feat = acts["feat"][0]
print(feat.shape)

In [None]:
k = 8  # number of leading channels of `feat` to display
cols = 4  # panels per row

# Never ask for more channels than the feature map has, and size the grid
# from the number actually shown instead of a fixed 2x4 layout.
n_show = min(k, feat.shape[0])
rows = max(1, -(-n_show // cols))  # ceiling division, at least one row

plt.figure(figsize=(12, 3 * rows))
for i in range(n_show):
    plt.subplot(rows, cols, i + 1)
    plt.imshow(feat[i].cpu().numpy())
    plt.title(f"channel {i}")
    plt.axis("off")
plt.show()