<a href="https://colab.research.google.com/github/chcomet/CAP-CS4MS/blob/main/Gradio%2BGradCAM%2BResnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

ResNet-18

In [None]:
import torch
from torch import nn
from torchvision import models, transforms
import numpy as np
from PIL import Image

In [None]:
# runtime configuration for the classifier
# label names; predict() pairs classes[i] with the i-th model output
classes = ['cat', 'dog']
# use the GPU when one is visible to torch, otherwise the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# per-channel normalization statistics; also used by tensor2np to undo the normalization
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# preprocessing pipeline applied to every input image:
# resize to 224x224 -> convert to tensor -> normalize per channel
tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])
# convert normalized tensor to numpy array
def tensor2np(tensor, mean, std):
  """Undo per-channel normalization and return the image as an H*W*C numpy array.

  Args:
    tensor: normalized image tensor laid out as C*H*W. It is not modified.
    mean: per-channel means that were subtracted during normalization.
    std: per-channel standard deviations that were divided out.

  Returns:
    numpy array of shape H*W*C with values clipped to [0, 1].
  """
  # work on a detached CPU copy so the caller's tensor is untouched and
  # tensors that require grad or live on a GPU can still be converted
  img = tensor.detach().cpu().clone()
  mean_tensor = torch.as_tensor(list(mean), dtype=img.dtype).view(-1, 1, 1)
  std_tensor = torch.as_tensor(list(std), dtype=img.dtype).view(-1, 1, 1)
  # inverse of normalization: x * std + mean
  img.mul_(std_tensor).add_(mean_tensor)
  # convert tensor to numpy format for plt presentation
  npimg = np.transpose(img.numpy(), (1, 2, 0)) # C*H*W => H*W*C
  # floating point rounding can leave values slightly outside [0, 1];
  # clip so consumers that require a [0, 1] image accept the result
  return np.clip(npimg, 0.0, 1.0)

In [None]:
# build a ResNet-18 without pretrained weights; the checkpoint below supplies them
model = models.resnet18(pretrained=False)
# replace the final fully connected layer with a two-output head that ends in
# LogSoftmax, so the model emits log-probabilities for the two classes
head_layers = [
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 2),
    nn.LogSoftmax(dim=1),
]
model.fc = nn.Sequential(*head_layers)
# fetch the trained checkpoint and map its tensors onto the selected device
checkpoint_url = 'https://github.com/CS4MS/CS4MS_W21/raw/main/checkpoints/dogs-vs-cats.pth'
state_dict_train = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=device)
# the weights are stored under the "state_dict" key of the checkpoint
model.load_state_dict(state_dict_train["state_dict"])
# switch to eval mode for inference
model.eval()

Grad-CAM

In [None]:
# clone repository
!git clone https://github.com/jacobgil/pytorch-grad-cam.git
# install the related dependencies
!pip install -r /content/pytorch-grad-cam/requirements.txt
# move the core module to working directory /content
!mv /content/pytorch-grad-cam/pytorch_grad_cam /content/pytorch_grad_cam

In [None]:
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.image import show_cam_on_image

In [None]:
# draw Grad-CAM on image
def image_grad_cam(model, input_tensor, input_float_np, target_layers):
  """Compute a Grad-CAM heatmap and overlay it on the input image.

  Args:
    model: the network to explain.
    input_tensor: batched input tensor fed to the model; only the heatmap of
      the first item in the batch is used.
    input_float_np: the image to draw on, as a float numpy array
      (show_cam_on_image expects values in [0, 1] -- see library docs).
    target_layers: list of layers whose activations and gradients are used.

  Returns:
    The RGB image returned by show_cam_on_image with the heatmap blended in.
  """
  # `use_cuda` is intentionally not passed: False was the default in the
  # releases that had it, and current pytorch-grad-cam releases reject the
  # argument, so omitting it works with both
  cam = GradCAM(model=model, target_layers=target_layers)
  # one heatmap per batch item; keep the first (and only) one
  grayscale_cam = cam(input_tensor=input_tensor)[0, :]
  return show_cam_on_image(input_float_np, grayscale_cam, use_rgb=True)

Gradio

In [None]:
!pip install gradio
import gradio as gr

In [None]:
# prediction function for Gradio; the input image arrives as a numpy.ndarray
def predict(inp):
  """Classify an image and draw its Grad-CAM heatmap.

  Args:
    inp: image as a numpy array, as delivered by the Gradio image input.

  Returns:
    A tuple (label_dict, output_img): label_dict maps every name in `classes`
    to its predicted probability, output_img is the input image with the
    Grad-CAM heatmap overlaid.
  """
  # numpy.ndarray -> PIL.Image; convert() also copes with inputs that are not
  # already 3-channel RGB (e.g. RGBA or grayscale)
  img = Image.fromarray(inp.astype('uint8')).convert('RGB')
  # resize and normalize the image to fit the input size of our model
  input_tensor = tf(img)
  # get a de-normalized copy of the input as a float numpy array for the overlay
  input_float_np = tensor2np(input_tensor, mean, std)
  # add the batch dimension
  input_tensor = input_tensor.unsqueeze(dim=0)
  # predict
  with torch.no_grad():
    log_probs = model(input_tensor)
  # the model ends with LogSoftmax, so exp() alone yields the class
  # probabilities; applying softmax on top of that would distort them
  pred_probs = torch.exp(log_probs).cpu().numpy()[0]
  # grad_cam image, computed on the last block of layer4
  target_layers = [model.layer4[-1]]
  output_img = image_grad_cam(model, input_tensor, input_float_np, target_layers)
  # return label dict and heatmap image
  return {classes[i]: float(pred_probs[i]) for i in range(len(classes))}, output_img

In [None]:
# download sample images
!wget https://cdn.pixabay.com/photo/2014/11/30/14/11/cat-551554_960_720.jpg
!mv /content/cat-551554_960_720.jpg /content/cat.jpg
!wget https://cdn.pixabay.com/photo/2015/11/17/13/13/bulldog-1047518_960_720.jpg
!mv /content/bulldog-1047518_960_720.jpg /content/bulldog.jpg

In [None]:
# start gradio application
# gr.Image / gr.Label are used instead of the gr.inputs / gr.outputs namespaces,
# which no longer exist in current Gradio releases
demo = gr.Interface(
        fn=predict,
        inputs=gr.Image(),
        outputs=[gr.Label(label="Classification Result"), gr.Image(label="GRADCAM")],
        examples=[['cat.jpg'], ['bulldog.jpg']],
        title="Cat and Dog Classification"
      )
demo.launch()