In [1]:
# Load (or reload) IPython's autoreload extension and reload modules before each
# cell runs, so edits to local modules such as config / utils / fixed_model are picked up.
%reload_ext autoreload
%autoreload 2

In [2]:
import os
import random
import warnings

import torch, torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM, EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import config
import cv2
from utils import non_max_suppression, cells_to_bboxes, get_bboxes, YoloCAM
from PIL import Image

# from models.resnet import ResNet18
from fixed_model import s13Model, ScalePrediction, CNNBlock, ResidualBlock

import gradio as gr
# Location of the trained weights. Defaults to the original checkpoint path; set the
# S13_MODEL_PATH environment variable to load from elsewhere without editing the notebook.
MODEL_PATH = os.environ.get("S13_MODEL_PATH", "/home/sn/object_detection_demo/s13Model.pth")
# Seed for the per-class colour table so box colours are identical across runs.
COLOR_SEED = 0

model = s13Model(num_classes=config.NUM_CLASSES)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load trusted checkpoints (newer torch versions also accept weights_only=True).
state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
# strict=False keeps loading when parameter names differ; report any mismatch instead of
# silently running with some layers left at their initial (untrained) values.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    warnings.warn(
        f"Checkpoint {MODEL_PATH} does not fully match the model: "
        f"missing keys={load_result.missing_keys}, "
        f"unexpected keys={load_result.unexpected_keys}"
    )
model.eval()

classes = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor"
]

# One random RGB colour per class, drawn from a seeded generator for reproducible plots.
COLORS = np.random.default_rng(COLOR_SEED).uniform(0, 255, size=(len(classes), 3))
# config.ANCHORS multiplied by the grid sizes in config.S (each S value is broadcast to a
# (3, 2) block via repeat(1, 3, 2)); presumably one row of 3 (w, h) anchors per scale — TODO confirm.
scaled_anchors = (torch.tensor(config.ANCHORS)
    * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)).to('cpu')

In [29]:
def draw_boxes(image, boxes, class_names=None, colors=None):
    """Draw each detection as a rectangle plus a "<class> : <confidence>" label.

    Args:
        image: Image convertible with ``np.array`` to an (H, W, C) array. It is
            copied first, so the caller's object is not modified.
        boxes: Iterable of 6-element boxes ``[class_pred, confidence, x, y, w, h]``.
            ``x``/``y`` is the box centre, and ``x, y, w, h`` are fractions of the
            image size (they are multiplied by the image width/height below).
        class_names: Optional index -> label sequence. Defaults to the
            module-level ``classes`` list.
        colors: Optional per-class colour table indexed by class id. Defaults to
            the module-level ``COLORS`` array.

    Returns:
        np.ndarray: the annotated copy of ``image``.

    Raises:
        AssertionError: if a box does not have exactly 6 entries.
    """
    if class_names is None:
        class_names = classes
    if colors is None:
        colors = COLORS

    canvas = np.array(image)  # np.array copies, so drawing never touches the input
    img_h, img_w, _ = canvas.shape
    for box in boxes:
        assert len(box) == 6, "box should contain class pred, confidence, x, y, width, height"
        class_idx = int(box[0])
        # Read the confidence from the FULL box before separating the geometry;
        # indexing the geometry slice at [1] would yield the y-centre instead.
        confidence = float(box[1])
        x_c, y_c, box_w, box_h = (float(v) for v in box[2:])

        x1 = int((x_c - box_w / 2) * img_w)
        y1 = int((y_c - box_h / 2) * img_h)
        x2 = int((x_c + box_w / 2) * img_w)
        y2 = int((y_c + box_h / 2) * img_h)

        # OpenCV drawing functions take the colour as a sequence of plain numbers.
        color = tuple(int(c) for c in colors[class_idx])
        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 1)

        # putText's origin is the text's bottom-left corner: place it just above the
        # box, but never above row 12 so labels of boxes at the top edge stay visible.
        label_y = max(y1 - 5, 12)
        cv2.putText(canvas, f"{class_names[class_idx]} : {confidence:.2f}", (x1, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1,
                    lineType=cv2.LINE_AA)
    return canvas

def inference(input_img, transparency=0.5):
    """Detect objects in one image and build a Grad-CAM overlay for it.

    Args:
        input_img: Image from the Gradio ``gr.Image`` input. Presumably an RGB
            uint8 NumPy array with values in 0..255, since it is divided by 255
            for display below — TODO confirm.
        transparency: Forwarded unchanged to ``show_cam_on_image`` as
            ``image_weight``, i.e. the weight given to the original image when
            blending it with the CAM heat-map.

    Returns:
        tuple: ``(detections, cam_overlay)``, where ``detections`` is the image with
        boxes drawn (divided by 255), and ``cam_overlay`` is the output of
        ``show_cam_on_image``.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_img is None:
        # Gradio passes None when the user submits without an image.
        raise gr.Error("Please provide an input image.")

    org_img = input_img
    transformed = config.test_transforms(image=input_img, bboxes=[])
    batch = transformed["image"].unsqueeze(0)  # add a leading batch dimension

    # The detection pass only needs forward outputs, so skip building an autograd graph.
    with torch.no_grad():
        outputs = model(batch)
    obj_bboxes = get_bboxes(out=outputs, anchors=scaled_anchors)
    obj_detected = draw_boxes(org_img, obj_bboxes)

    # `cam` is created in a separate notebook cell and looked up as a global at call time.
    grayscale_cam = cam.forward(input_tensor=batch, scaled_anchors=scaled_anchors, targets=None)
    grayscale_cam = grayscale_cam[0, :]

    visualization = show_cam_on_image(org_img / 255, grayscale_cam, use_rgb=True, image_weight=transparency)

    return obj_detected / 255, visualization


In [20]:
# Grad-CAM explainer used by inference(); it hooks one entry of model.layers.
target_layer_number = -2  # index into model.layers (second-to-last entry)
target_layers = []
target_layers.append(model.layers[target_layer_number])
cam = YoloCAM(
    model=model,
    target_layers=target_layers,
    use_cuda=False,
)

In [30]:
title = "PASCAL VOC 2007 Dataset trained on Custom Model with GradCAM"
description = "A simple Gradio interface for Object Detection using Custom model, and get GradCAM"
# Each example supplies both inputs: [image path, slider value].
examples = [
    ["example_imgs/dog.jpeg", 0.5],
    ['example_imgs/000030.jpg', 0.6],
    ['example_imgs/000050.jpg', 0.5],
    ['example_imgs/dogs.jpeg', 0.6],
    ['example_imgs/train.jpeg', 0.6],
    ['example_imgs/bird1.jpeg', 0.7],
    ['example_imgs/cars1.jpeg', 0.5],
    ['example_imgs/horse1.jpeg', 0.6],
]
demo = gr.Interface(
    inference,
    inputs=[
        gr.Image(shape=(416, 416), label="Input Image"),
        # inference() passes this value to show_cam_on_image as image_weight (the weight
        # of the original photo), so higher values make the heat-map fainter; the label
        # therefore says "Transparency" rather than "Opacity".
        gr.Slider(0, 1, value=0.5, label="Transparency of GradCAM"),
    ],
    outputs=gr.Gallery(rows=2, columns=1, min_width=416),
    title=title,
    description=description,
    examples=examples,
)
demo.launch()

Running on local URL:  http://127.0.0.1:7870

To create a public link, set `share=True` in `launch()`.


