<a href="https://colab.research.google.com/github/venkatasl/AIML_TRAINING_VENKAT/blob/main/DRDO2024_OcclusionSaliency.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Saliency maps using occlusion

In this notebook, we will use the simple concept of occlusion maps to find the salient parts of an input image.

An occlusion map records how the model's score for a class changes when different parts of the image are occluded. Regions whose occlusion lowers the score the most are the most salient.

First, let us load the model we want to use:

In [None]:
# import everything
import torch
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
import math

In [None]:
# get imagenet classes
response = requests.get('https://raw.githubusercontent.com/pytorch/hub/refs/heads/master/imagenet_classes.txt')
text = response.text
ImageNetClasses = text.splitlines()

In [None]:
# Load a pre-trained model
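# (torchvision >= 0.13 uses the `weights` argument; `pretrained=True` is deprecated)
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)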
# go to https://pytorch.org/serve/model_zoo.html to find the names of more pretrained models

# put the model on the gpu
model.cuda()
model.eval()

In [None]:
# Define the image transformation. This is important because this is how the model was trained
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # The below values are based on the mean and st.deviation of the ImageNet dataset
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# inverse of the transform: undo the normalization, then convert back to a PIL image
invert = transforms.Compose([
    transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]),
    transforms.ToPILImage()])
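
The inverse transform works because normalizing with mean `-m/s` and standard deviation `1/s` undoes a normalization with mean `m` and standard deviation `s`. Here is a minimal check of that arithmetic in plain Python. The `normalize` helper and the sample pixel are made up for illustration and only mimic the per-channel operation of `transforms.Normalize`:

```python
# ImageNet statistics used by the preprocessing step
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def normalize(pixel, mean, std):
    """Per-channel (x - mean) / std, the same arithmetic as transforms.Normalize."""
    return [(x - m) / s for x, m, s in zip(pixel, mean, std)]

# parameters of the inverse: mean' = -mean/std and std' = 1/std
inv_mean = [-m / s for m, s in zip(mean, std)]
inv_std = [1 / s for s in std]

pixel = [0.2, 0.5, 0.9]  # hypothetical RGB pixel with values in [0, 1]
restored = normalize(normalize(pixel, mean, std), inv_mean, inv_std)
print(restored)  # should match `pixel` up to floating-point error
```

If any of the three inverse values is mistyped, the restored colours will be slightly off in that channel.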

In [None]:
# Function to load and preprocess the image
def load_image(image_path, display=False):
    if image_path.startswith('http'):
        response = requests.get(image_path)
        img = Image.open(BytesIO(response.content))
    else:
        img = Image.open(image_path)
    # make sure we have 3 channels (some files are RGBA or grayscale)
    img = preprocess(img.convert('RGB'))
    if display:
        plt.imshow(invert(img))

    # img = img.unsqueeze(0)  # Add batch dimension
    return img

In [None]:
# Let us load an image
# image_path = 'https://raw.githubusercontent.com/pytorch/serve/refs/heads/master/examples/image_classifier/kitten.jpg'  # Replace with your image path or URL
image_path = 'https://www.pixelstalk.net/wp-content/uploads/2016/03/Animals-baby-cat-dog-HD-wallpaper.jpg'
img = load_image(image_path, display=True)


In [None]:
# Now let us do inference on the image
with torch.no_grad():
  outputs = model(img.unsqueeze(0).cuda())

# class indices ordered from highest to lowest score
ranking = torch.argsort(-outputs.squeeze())

print('The top 10 predictions are: ')
for i in range(10):
  idx = ranking[i].item()
  print(f'{idx}: {ImageNetClasses[idx]}')

prediction = outputs.argmax(1).item()
print(f'Predicted class: {prediction}: {ImageNetClasses[prediction]}')
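
Note that the values in `outputs` are raw class scores (logits), not probabilities. To read them as confidences we can pass them through a softmax. Below is a minimal sketch in plain Python with made-up scores; the `softmax` helper and `toy_logits` are only for illustration:

```python
import math

def softmax(scores):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    # subtract the maximum first so that exp() cannot overflow
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# hypothetical logits for three classes (not taken from the model above)
toy_logits = [2.0, 1.0, 0.1]
probs = softmax(toy_logits)
print([round(p, 3) for p in probs])
```

On the real model output, `torch.softmax(outputs, dim=1)` performs the same computation for all 1000 classes.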


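In order to use occlusion maps, we need to create images where one patch is occluded. The patch is set to zero in the normalized image, which corresponds to the ImageNet mean colour rather than pure black. I am using this Dataset class to do this: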

In [None]:
class OcclusionDataset(Dataset):
  def __init__(self, img, window=10, stride = 5):
    self.masterimage = img
    self.window = window
    self.stride = stride
    self.pos = math.floor(224/self.stride)


  def __len__(self):
    return self.pos*self.pos

  def display(self, index):
    img, mask = self[index]
    img = invert(img.cpu())
    fig, axes = plt.subplots(1,2,figsize = (4, 2))
    axes[0].imshow(img)
    axes[0].axis('off')
    axes[1].imshow(mask)
    axes[1].axis('off')

  def __getitem__(self, idx):
    img = self.masterimage.clone()
    mask = torch.ones([224,224]) # a tensor of all ones, same size as the image
    row = (idx // self.pos) * self.stride
    col = (idx % self.pos) * self.stride
    # set a window in the mask to zero (slice ends are exclusive, so clamp to 224)
    mask[row:min(row+self.window, 224), col:min(col+self.window, 224)] = 0
    # multiply the r,g,b channels with the mask (broadcast over the channel dimension)
    img = img * mask
    # return the occluded image and the inverted mask (1 inside the occluded window)
    return img.cuda(), 1 - mask
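
To make the index arithmetic in `__getitem__` concrete, here is a small sketch in plain Python that maps a dataset index to the bounds of its occlusion window. The helper `window_bounds` is hypothetical (it is not part of the class), and it assumes slice ends are exclusive and clamped to the image size:

```python
def window_bounds(idx, size=224, window=10, stride=5):
    """Return (row_start, row_end, col_start, col_end) of the idx-th window.

    Ends are exclusive and clamped to the image size.
    """
    positions = size // stride          # windows per row and per column
    row = (idx // positions) * stride   # top edge of the window
    col = (idx % positions) * stride    # left edge of the window
    return row, min(row + window, size), col, min(col + window, size)

# a very large window and stride give only 2 x 2 = 4 windows
for idx in range(4):
    print(idx, window_bounds(idx, window=100, stride=100))
```

With a window and stride of 100, the rows and columns from 200 onward are never occluded, because only `224 // 100 = 2` window positions fit along each axis.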

In [None]:
# let us see how this class works by giving it a very large window and stride
dataset = OcclusionDataset(img, window = 100, stride=100)
for i in range(len(dataset)):
  dataset.display(i)
  plt.show()

Now let us write a function that takes the occluded images from the dataset and accumulates the resulting drop in the class score into a heatmap:

In [None]:
def CalculateOcclusionMap(img, window=10, stride=5, pclass = -1):
  with torch.no_grad():
    outputs = model(img.unsqueeze(0).cuda())
  if pclass==-1:
    pclass = outputs.argmax(1).item() # -1 means default, so use the predicted class
  # score of the target class on the unoccluded image
  baseline = outputs[0, pclass]

  # create an all-zero array to accumulate the heatmap values
  heatmap = torch.zeros([224,224], dtype=torch.float).cuda()

  dataset =  OcclusionDataset(img, window = window, stride=stride)
  dataloader = DataLoader(dataset, batch_size=32, shuffle=False)
  numdata = 0
  with torch.no_grad():
    for oimg, mask in dataloader:
      outputs = model(oimg.cuda())
      # drop in the class score caused by each occlusion, shape [batch]
      drop = baseline - outputs[:, pclass]
      # add each drop to the pixels covered by its occlusion window
      heatmap = heatmap + (drop[:, None, None] * mask.cuda()).sum(dim=0)
      numdata = numdata + drop.shape[0]
      print(f'\r Done {numdata} of {len(dataset)}', end='     ')
  return heatmap.cpu()
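
The accumulation step is easier to follow on a toy problem. The sketch below repeats the same idea in plain Python on a 6 x 6 image, with a made-up `toy_score` function standing in for the network: it simply sums the four centre pixels, so only occlusions that touch the centre lower the score. All names here are hypothetical and only illustrate the technique:

```python
def toy_score(image):
    """Hypothetical stand-in for a class score: sum of the pixels in rows/cols 2..3."""
    return sum(image[r][c] for r in range(2, 4) for c in range(2, 4))

def occlusion_heatmap(image, score_fn, window=2, stride=1):
    size = len(image)
    baseline = score_fn(image)            # score of the unoccluded image
    heatmap = [[0.0] * size for _ in range(size)]
    positions = size // stride            # windows per row and per column
    for idx in range(positions * positions):
        row = (idx // positions) * stride
        col = (idx % positions) * stride
        rows = range(row, min(row + window, size))
        cols = range(col, min(col + window, size))
        # copy the image and zero out the current window
        occluded = [line[:] for line in image]
        for r in rows:
            for c in cols:
                occluded[r][c] = 0.0
        # how much the score fell because of this occlusion
        drop = baseline - score_fn(occluded)
        # credit the drop to every pixel inside the window
        for r in rows:
            for c in cols:
                heatmap[r][c] += drop
    return heatmap

image = [[1.0] * 6 for _ in range(6)]
heatmap = occlusion_heatmap(image, toy_score)
for line in heatmap:
    print(line)
```

The largest values appear on the four centre pixels, which are exactly the pixels the toy score depends on. Pixels near the border are covered by fewer windows, so their totals are built from fewer terms; this sketch, like the function above, does not correct for that.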

Let us find the saliency map for the default class:

In [None]:
heatmap = CalculateOcclusionMap(img, window=100,stride=10)

Let us now view the saliency map!

In [None]:
plt.imshow(invert(img))
# normalize heatmap for displaying
heatmap = (heatmap - heatmap.min())/(heatmap.max()-heatmap.min())
plt.imshow(heatmap, cmap='jet', alpha=heatmap)
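
The min-max normalization used for display divides by `max - min`, which is zero when the heatmap is constant (for example, if no occlusion changes the score). A minimal sketch of a guarded version in plain Python follows; the helper `min_max_normalize` and its `eps` threshold are assumptions for illustration, not part of the notebook:

```python
def min_max_normalize(values, eps=1e-8):
    """Scale values to [0, 1]; return all zeros if the values are (nearly) constant."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span < eps:
        # nothing to distinguish, so avoid dividing by zero
        return [0.0 for _ in values]
    return [(v - lo) / span for v in values]

print(min_max_normalize([2.0, 4.0, 6.0]))
print(min_max_normalize([3.0, 3.0, 3.0]))
```

The same guard can be applied to the tensor version by checking `heatmap.max() - heatmap.min()` before dividing.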

Now let us find the saliency map for class 209: Chesapeake Bay retriever

In [None]:
heatmap = CalculateOcclusionMap(img, window=100,stride=5, pclass=209)
plt.imshow(invert(img))
# normalize heatmap for displaying
heatmap = (heatmap - heatmap.min())/(heatmap.max()-heatmap.min())
plt.imshow(heatmap, cmap='jet', alpha=heatmap)

## Exercises

1. Try with your own images!
2. What happens when you change the window or stride?
3. What happens when the window is too small or too large? What would be the ideal size of the window?
4. In images with multiple objects, change the target class and recalculate the saliency map.
5. Try changing the model. Choose another model from the model zoo. Do you get the same saliency maps?