<a href="https://colab.research.google.com/github/santule/ERA/blob/main/S13/yolo3_gradcam_gradio.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive inside the Colab VM so files stored there can be read.
from google.colab import drive
import os
drive.mount('/content/drive/')
# Make the project folder the working directory. Later cells import `config`,
# `model`, `loss`, `utils` and `utils_org` by bare name and open
# "yolo3_model.ckpt" by relative path (presumably all resolved from this folder).
%cd /content/drive/MyDrive/AI/ERA_course/session13_part3_eval

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).
/content/drive/MyDrive/AI/ERA_course/session13_part3_eval


### LOAD LIGHTNING MODEL

In [2]:
!pip install pytorch-lightning --quiet
!pip install lightning-bolts --quiet

In [3]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from torch.optim.lr_scheduler import OneCycleLR
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.loggers import TensorBoardLogger
import config
import torch
import torch.optim as optim
import matplotlib

from model import YOLOv3
from tqdm import tqdm
from utils_org import (
    mean_average_precision,
    cells_to_bboxes,
    get_evaluation_bboxes,
    save_checkpoint,
    load_checkpoint,
    check_class_accuracy,
    plot_couple_examples,
    accuracy_fn,
    get_loaders,
    non_max_suppression,
    plot_image
)
from loss import YoloLoss
import warnings
warnings.filterwarnings("ignore")

In [4]:
# custom functions for yolo

# YOLOv3 loss: a single shared YoloLoss instance, applied once per prediction scale
loss_fn = YoloLoss()
def criterion(out, y, anchors):
  """Return the summed loss over the three prediction scales.

  For i = 0, 1, 2 this calls ``loss_fn(out[i], y[i], anchors[i])`` and adds
  the three results together.

  Args:
    out: indexable holding the model outputs for the three scales.
    y: indexable holding the matching targets.
    anchors: indexable holding the anchors used for each scale.
  """
  per_scale = [loss_fn(out[i], y[i], anchors[i]) for i in range(3)]
  return per_scale[0] + per_scale[1] + per_scale[2]

# class / no-object / object accuracy for yolov3
def accuracy_fn(y, out, threshold,
                correct_class, correct_obj,
                correct_noobj, tot_class_preds,
                tot_obj, tot_noobj):
  """Compute class, no-object and object accuracy over the three scales.

  Args:
    y: indexable with the targets of the three scales; channel 0 of the last
      axis is the objectness flag and channel 5 the class index.
    out: indexable with the predictions of the three scales; channel 0 of the
      last axis is the objectness logit and channels 5: the class scores.
    threshold: cut-off applied to ``sigmoid(objectness logit)``.
    correct_class, correct_obj, correct_noobj, tot_class_preds, tot_obj,
    tot_noobj: starting values of the counters. They are updated with ``+=``,
      so a tensor passed here is modified in place while a plain int is not.

  Returns:
    Tuple ``(class_accuracy, no_object_accuracy, object_accuracy)`` in percent.
  """
  for scale in range(3):
      target, pred = y[scale], out[scale]
      objectness = target[..., 0]
      has_obj = objectness == 1   # cells holding an object (I_obj in the paper)
      is_empty = objectness == 0  # cells without an object (I_noobj in the paper)

      # class accuracy is measured only on cells that hold an object
      predicted_class = torch.argmax(pred[..., 5:][has_obj], dim=-1)
      correct_class += (predicted_class == target[..., 5][has_obj]).sum()
      tot_class_preds += has_obj.sum()

      # objectness decision, scored separately on object and empty cells
      predicted_obj = torch.sigmoid(pred[..., 0]) > threshold
      correct_obj += (predicted_obj[has_obj] == objectness[has_obj]).sum()
      tot_obj += has_obj.sum()
      correct_noobj += (predicted_obj[is_empty] == objectness[is_empty]).sum()
      tot_noobj += is_empty.sum()

  # the small epsilon keeps the division defined when a counter is zero
  class_acc = (correct_class / (tot_class_preds + 1e-16)) * 100
  noobj_acc = (correct_noobj / (tot_noobj + 1e-16)) * 100
  obj_acc = (correct_obj / (tot_obj + 1e-16)) * 100
  return (class_acc, noobj_acc, obj_acc)

# pytorch lightning
class LitYolo(LightningModule):
    """Lightning wrapper around ``YOLOv3`` for training, validation and testing.

    Args:
        num_classes: forwarded to ``YOLOv3(num_classes=...)``.
        lr: learning rate given to the Adam optimizer.
        weight_decay: weight decay given to the Adam optimizer.
        threshold: objectness threshold passed to ``accuracy_fn``.
        my_dataset: optional training dataset. When provided, its
            ``set_image_size`` is called at the start of every training epoch.
    """

    def __init__(self, num_classes=config.NUM_CLASSES, lr=config.LEARNING_RATE,weight_decay=config.WEIGHT_DECAY,threshold=config.CONF_THRESHOLD,my_dataset=None):
        super().__init__()

        self.save_hyperparameters()
        self.model = YOLOv3(num_classes=self.hparams.num_classes)
        self.my_dataset = my_dataset
        self.criterion = criterion
        self.accuracy_fn = accuracy_fn
        # Start values handed to accuracy_fn. They are plain ints, so the `+=`
        # inside accuracy_fn does not modify them: accuracies are per batch.
        self.tot_class_preds, self.correct_class = 0, 0
        self.tot_noobj, self.correct_noobj = 0, 0
        self.tot_obj, self.correct_obj = 0, 0
        self.scaled_anchors = 0

    def set_scaled_anchor(self, scaled_anchors):
        """Store the anchors that ``criterion`` receives in the step methods."""
        self.scaled_anchors = scaled_anchors

    def _scaled_anchors_for(self, size_idx):
        """Return ``config.ANCHORS`` multiplied by the grid sizes ``config.S[size_idx]``."""
        grid_sizes = torch.tensor(config.S[size_idx]).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        # Keep the anchors on the module's device so they can be combined with
        # the predictions (harmless if the loss moves them again itself).
        return (torch.tensor(config.ANCHORS) * grid_sizes).to(self.device)

    def forward(self, x):
        """Run the wrapped YOLOv3 model on ``x`` and return its output."""
        return self.model(x)

    def on_train_epoch_start(self):
        """Pick an image size for this epoch and scale the anchors to match it."""
        # Local import: `random` is not imported in the visible notebook cells.
        import random

        size_idx = random.choice(range(len(config.IMAGE_SIZES)))
        # Use the SAME index for the dataset and for the anchors; previously the
        # dataset was always set to index 0 while the anchors used the random one.
        if self.my_dataset is not None:
            self.my_dataset.set_image_size(size_idx=size_idx)
        self.set_scaled_anchor(self._scaled_anchors_for(size_idx))

    def on_validation_epoch_start(self):
        """Use the anchors scaled for ``config.S[1]`` during validation."""
        self.set_scaled_anchor(self._scaled_anchors_for(1))

    def on_test_epoch_start(self):
        """Use the anchors scaled for ``config.S[1]`` during testing."""
        self.set_scaled_anchor(self._scaled_anchors_for(1))

    def _loss_and_accuracy(self, batch):
        """Forward ``batch`` and return ``(loss, (class_acc, noobj_acc, obj_acc))``."""
        x, y = batch
        out = self(x)
        loss = self.criterion(out, y, self.scaled_anchors)
        acc = self.accuracy_fn(y, out, self.hparams.threshold,
                               self.correct_class, self.correct_obj, self.correct_noobj,
                               self.tot_class_preds, self.tot_obj, self.tot_noobj)
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracies; return the loss."""
        loss, acc = self._loss_and_accuracy(batch)

        self.log('train_loss', loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log_dict({"class_accuracy": acc[0], "no_object_accuracy": acc[1], "object_accuracy":acc[2]},prog_bar=True,on_step=False, on_epoch=True)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracies for ``batch``; log them when ``stage`` is set."""
        loss, acc = self._loss_and_accuracy(batch)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log_dict({"class_accuracy": acc[0], "no_object_accuracy": acc[1], "object_accuracy":acc[2]},prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch and log under the ``test`` prefix."""
        self.evaluate(batch, "test")

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch and log under the ``val`` prefix."""
        self.evaluate(batch, "val")

    def configure_optimizers(self):
        """Return Adam with a OneCycleLR schedule that is stepped every batch."""
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        scheduler = OneCycleLR(
                optimizer,
                max_lr= 1E-2,
                # Ask the trainer for the number of optimizer steps instead of
                # reading a notebook-global `train_loader` that may not exist.
                total_steps=self.trainer.estimated_stepping_batches,
                pct_start = 5/self.trainer.max_epochs,
                div_factor=100,
                three_phase=False
            )
        # OneCycleLR is sized in optimizer steps, so it must advance per batch;
        # Lightning's default scheduler interval is per epoch.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

### GRADIO APP AND GRADCAM

In [5]:
!pip install gradio --quiet
!pip install albumentations --quiet
!pip install grad-cam --quiet

from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.model_targets import FasterRCNNBoxScoreTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

import albumentations as Al
from albumentations.pytorch import ToTensorV2
from PIL import Image
import cv2
import gradio as gr
from torchvision import transforms
import albumentations as Al
import utils
import matplotlib.pyplot as plt

In [8]:
with gr.Blocks() as demo:
    #examples = [["/content/drive/MyDrive/AI/ERA_course/session13_old/PASCAL_VOC/images/009948.jpg"],["/content/drive/MyDrive/AI/ERA_course/session13_old/PASCAL_VOC/images/009948.jpg"]]

    # one hex colour per class label, sampled evenly from the "tab20b" colormap
    class_labels = config.PASCAL_CLASSES
    cmap = plt.get_cmap("tab20b")
    colors = [cmap(position) for position in np.linspace(0, 1, len(class_labels))]
    colors_hex = {
        label: matplotlib.colors.rgb2hex(rgba)
        for label, rgba in zip(class_labels, colors)
    }

    def yolov3_reshape_transform(x):
      """Merge the per-scale model outputs into one activation map for the CAM.

      Each item of ``x`` is indexed (batch, anchors, height, width, attributes).
      Anchors and attributes are folded into a single channel axis, absolute
      values are taken, every scale is resized bilinearly to the spatial size
      of ``x[0]``, and the scales are concatenated along the channel axis.
      With the shapes noted below (75 channels per scale, three scales) the
      result has 225 channels.
      """
      target_hw = x[0].size()[2:4]  # e.g. 13 x 13
      resized = []
      for scale_out in x:
        channels_first = scale_out.permute(0, 1, 4, 2, 3)  # e.g. 1,3,25,13,13
        batch, n_anchors, n_attrs = channels_first.shape[:3]
        merged = channels_first.reshape(
            (batch, n_anchors * n_attrs, *channels_first.shape[3:])
        )  # e.g. 1,75,13,13
        resized.append(
            torch.nn.functional.interpolate(torch.abs(merged), target_hw, mode='bilinear')
        )
      return torch.cat(resized, axis=1)

    def yolo3_inference(input_img,gradcam=True,gradcam_opa=0.5): # function for yolo inference
      """Detect objects in ``input_img`` and optionally overlay an EigenCAM heat map.

      Args:
        input_img: image array delivered by the Gradio image component.
        gradcam: when true, return the CAM overlay instead of the plain image.
        gradcam_opa: passed to ``show_cam_on_image`` as ``image_weight``.

      Returns:
        Tuple ``(image, sections)``. ``image`` is the 416x416 letterboxed input
        (or the CAM overlay); ``sections`` is a list of
        ``((x1, y1, x2, y2), class_label)`` entries in pixel coordinates of
        that image.

      Raises:
        gr.Error: if no image was provided.
      """
      if input_img is None:
        raise gr.Error("Please provide an input image.")

      # Load the checkpoint once and reuse it for every request. The loader is
      # called on the class (it is a classmethod) and the weights are mapped to
      # the CPU because the input tensor and the CAM below stay on the CPU.
      inference_model = getattr(yolo3_inference, "_cached_model", None)
      if inference_model is None:
        inference_model = LitYolo.load_from_checkpoint("yolo3_model.ckpt", map_location="cpu")
        inference_model.eval()
        yolo3_inference._cached_model = inference_model

      # anchors scaled to the grid sizes in config.S[1]
      anchors  = (torch.tensor(config.ANCHORS) * torch.tensor(config.S[1]).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2))
      bboxes   = [[]]
      sections = [] # (box corners, label) pairs for the annotated image

      # image transformation: letterbox to 416x416 and scale pixels to [0, 1]
      test_transforms = Al.Compose(
        [
            Al.LongestMaxSize(max_size=416),
            Al.PadIfNeeded(
                min_height=416, min_width=416, border_mode=cv2.BORDER_CONSTANT
            ),
            Al.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
        ]
      )

      input_img_copy = test_transforms(image=input_img)['image']
      transform = transforms.ToTensor()
      input_img_tensor = transform(input_img_copy).unsqueeze(0)

      # infer the image; no autograd graph is needed for the detections
      with torch.no_grad():
        test_img_out = inference_model(input_img_tensor)

      # process the outputs to create bounding boxes
      for i in range(3):
          batch_size, A, S, _, _ = test_img_out[i].shape # 1, anchors = 3, scaling = 13/26/52
          anchor = anchors[i]
          boxes_scale_i = utils.cells_to_bboxes(test_img_out[i], anchor, S=S, is_preds=True)
          for idx, (box) in enumerate(boxes_scale_i):
              bboxes[idx] += box
      # nms
      nms_boxes = utils.non_max_suppression(bboxes[0], iou_threshold=0.6, threshold=0.5, box_format="midpoint",)

      # use gradio image annotations
      height, width = 416, 416

      def to_pixel(fraction, extent):
        # convert a relative coordinate to pixels and keep it inside the image
        return min(max(int(fraction * extent), 0), extent)

      for box in nms_boxes:
        class_pred = box[0]
        x_mid, y_mid, box_w, box_h = box[2:6]
        # Derive every corner from the box centre and clamp each one on its own,
        # so clamping one edge does not shift the opposite edge.
        upper_left_x  = to_pixel(x_mid - box_w / 2, width)
        upper_left_y  = to_pixel(y_mid - box_h / 2, height)
        lower_right_x = to_pixel(x_mid + box_w / 2, width)
        lower_right_y = to_pixel(y_mid + box_h / 2, height)
        sections.append(((upper_left_x,upper_left_y,lower_right_x,lower_right_y), class_labels[int(class_pred)]))

      if not gradcam:
        return (np.array(input_img_tensor.squeeze(0).permute(1,2,0)),sections)

      # for gradcam
      objs = [b[1] for b in nms_boxes]
      bbox_coord = [b[2:] for b in nms_boxes]
      # NOTE(review): EigenCAM is presumably gradient-free, so these targets may
      # only act as a placeholder - confirm against the grad-cam documentation.
      targets = [FasterRCNNBoxScoreTarget(objs, bbox_coord)]

      target_layers = [inference_model.model]
      # The context manager releases the CAM hooks afterwards; this matters now
      # that the same model instance is reused across requests.
      with EigenCAM(inference_model, target_layers, use_cuda=False,reshape_transform=yolov3_reshape_transform) as cam:
        grayscale_cam = cam(input_tensor = input_img_tensor, targets= targets)
      grayscale_cam = grayscale_cam[0, :]
      visualization = show_cam_on_image(input_img_copy, grayscale_cam, use_rgb=True, image_weight=gradcam_opa)

      return (visualization,sections)

    # app GUI: input image next to the annotated output, Grad-CAM controls below
    with gr.Row():
        img_input = gr.Image()
        img_output = gr.AnnotatedImage(shape=(100, 100)).style(color_map=colors_hex)
    with gr.Row():
        gradcam_check = gr.Checkbox(label="Gradcam")
        gradcam_opa = gr.Slider(minimum=0, maximum=1, value=0.5, label="Opacity of GradCAM")

    # the button feeds the three inputs to yolo3_inference and shows its result
    section_btn = gr.Button(value="Identify Objects")
    section_btn.click(
        fn=yolo3_inference,
        inputs=[img_input, gradcam_check, gradcam_opa],
        outputs=[img_output],
    )
    gr.Markdown(value="## Some Examples")
    # Example gallery, currently disabled (needs the `examples` list defined above):
    # gr.Examples(examples=examples,
    #                          inputs =[img_input,gradcam_check,gradcam_opa],
    #                          outputs=img_output,
    #                          fn=yolo3_inference, cache_examples=False)

# Start the Gradio server only when this code runs as the main module.
# With debug=True the cell keeps running so errors and logs stay visible
# (see the launch message printed below this cell).
if __name__ == "__main__":
    demo.launch(debug=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Note: opening Chrome Inspector may crash demo inside Colab notebooks.

To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>

Keyboard interruption in main thread... closing server.
