#### Set up Google Colab and unzip train data


In [None]:
# Set up google drive in google colab
# Mounts Google Drive at /content/drive; later cells reach the dataset archive,
# val.txt and saved checkpoints through the relative path 'drive/My Drive/...'.
# NOTE(review): the relative paths presumably rely on the working directory being
# /content (the Colab default) - confirm before running elsewhere.

from google.colab import drive
drive.mount('/content/drive')

In [2]:

# Unzip training data from drive
# -q suppresses unzip's per-file listing. The archive is extracted into the
# current working directory; later cells read it as 'VOCdevkit/VOC2007/...'.

!unzip -q 'drive/My Drive/VOCdevkit.zip'

#### Import Libraries


In [3]:
import os
import math
import xml.etree.ElementTree as ET
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image,ImageFilter, ImageEnhance
import cv2
import torch.optim as optim
from torchvision.ops import nms
from tqdm import tqdm
import time
import random

#### Define Model Hyper Parameters


In [4]:
# Classes to train on (objects of any other class are ignored when reading annotations)
select_classes = {
    'person', 'bird', 'cat', 'cow', 'dog', 'horse', 'sheep',
    'aeroplane', 'bicycle', 'boat', 'bus', 'car', 'motorbike', 'train',
    'bottle', 'chair', 'diningtable', 'pottedplant', 'sofa', 'tvmonitor',
}

# Size (in pixels) every image is resized to before entering the network
input_image_height = 448
input_image_width = 448

# NOTE: the name is a misnomer - this is the number of boxes predicted per grid cell
anchors_per_box = 2

# Pixel height / width of one cell of the 7 x 7 prediction grid
grid_ht = input_image_height // 7
grid_wt = input_image_width // 7

# Decoding thresholds: class probability, box confidence and NMS overlap
prob_threshold = 0.1
conf_threshold = 0.1
nms_threshold = 0.5

batch_num = 8

# Use the first GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


#### Visualisation Code

In [5]:
# Given input image, draw rectangles as specified by gt_box and pred_box and display
def visualize_tensor(img, gt_box, pred_box, pred_labels = '', pred_score = ''):
    """Draw ground-truth and predicted boxes on an image and show it with matplotlib.

    Args:
        img: indexable whose first element is a normalized image tensor laid out
            as (channels, height, width); it is de-normalized with `inv_normalize`.
        gt_box: iterable of boxes in (x0, x1, y0, y1) order, drawn in (0, 255, 255).
            Pass '' (or any empty iterable) to draw none.
        pred_box: iterable of boxes in the same order, drawn in (255, 0, 0).
        pred_labels: optional per-box label strings; the default '' draws boxes only.
        pred_score: optional per-box scores, indexed in step with `pred_labels`.
    """
    plt.figure(figsize=(5,5))

    # permute() yields a non-contiguous view; copy() gives cv2 a writable,
    # contiguous HWC array to draw on.
    canvas = inv_normalize(img[0]).permute(1,2,0).to('cpu').numpy().copy()

    for box in gt_box:
        x0, x1, y0, y1 = (int(v) for v in box)
        cv2.rectangle(canvas, (x0, y0), (x1, y1), color=(0, 255, 255), thickness=2)

    # Only the literal '' sentinel means "no labels"; the isinstance guard keeps an
    # array-like label container from triggering an element-wise comparison.
    has_labels = not (isinstance(pred_labels, str) and pred_labels == '')

    for idx, box in enumerate(pred_box):
        # Boxes may hold tensor scalars; cv2 needs plain ints for every coordinate,
        # including the text origin passed to putText.
        x0, x1, y0, y1 = (int(v) for v in box)
        cv2.rectangle(canvas, (x0, y0), (x1, y1), color=(255, 0, 0), thickness=2)
        if has_labels:
            score_text = str(float(pred_score[idx]))[:4]
            cv2.putText(canvas, str(pred_labels[idx]), (x0, y0+15), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
            cv2.putText(canvas, score_text, (x0, y0+40), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    plt.imshow(canvas)

#### Handle Training Data


##### Get all classes and create label encoding


In [6]:
# Getting all the classes and creating label encoding

all_labels = []
max_gt_count = -1   # Largest number of selected-class objects found in a single annotation file

for xml_name in sorted(os.listdir('VOCdevkit/VOC2007/Annotations/')):
    annotation = ET.parse('VOCdevkit/VOC2007/Annotations/' + xml_name)
    # Names of the objects in this file that belong to one of the selected classes
    kept_names = [
        obj.find('name').text
        for obj in annotation.findall('object')
        if obj.find('name').text in select_classes
    ]
    all_labels.extend(kept_names)
    max_gt_count = max(len(kept_names), max_gt_count)

# Alphabetical order fixes the class-name <-> integer-id mapping
distict_labels = sorted(set(all_labels))

lab_to_val = {name: idx for idx, name in enumerate(distict_labels)}
val_to_lab = dict(enumerate(distict_labels))

num_classes = len(distict_labels)

unique_labels, label_counts = np.unique(all_labels, return_counts=True)
print("All Labels -- ",unique_labels)
print("Label Counts -- ",label_counts)
print("Num Classes -- ",num_classes)
print("Max Detection in an image -- ",max_gt_count)

All Labels --  ['aeroplane' 'bicycle' 'bird' 'boat' 'bottle' 'bus' 'car' 'cat' 'chair'
 'cow' 'diningtable' 'dog' 'horse' 'motorbike' 'person' 'pottedplant'
 'sheep' 'sofa' 'train' 'tvmonitor']
Label Counts --  [ 331  418  599  398  634  272 1644  389 1432  356  310  538  406  390
 5447  625  353  425  328  367]
Num Classes --  20
Max Detection in an image --  42


#####Create lists with train and test images 

In [7]:
# Read the validation image ids, one per line. The context manager closes the
# file, and blank lines are skipped explicitly instead of blindly dropping the
# last entry (which loses a real id when the file has no trailing newline).
with open('drive/My Drive/val.txt', "r") as file:
    valid_images = [line.strip() for line in file if line.strip()]

# Every image that is not listed for validation goes to the training split.
# A set makes the membership test O(1) per image; splitext removes the
# extension whatever its length.
_valid_lookup = set(valid_images)
train_images = [
    os.path.splitext(img_file)[0]
    for img_file in sorted(os.listdir('VOCdevkit/VOC2007/JPEGImages/'))
    if os.path.splitext(img_file)[0] not in _valid_lookup
]


##### Data Augmentation

In [8]:
# Random blur on training image
def random_blur(img):
  """Return `img` unchanged half of the time, otherwise box-blurred with radius 1 or 2."""
  if random.random() >= 0.5:
    radius = random.choice([1,2])
    return img.filter(ImageFilter.BoxBlur(radius=radius))
  return img


# Random brightness, contrast, saturation and hue
def random_color(img):
  """Apply a colour jitter to `img`, except for a 10% chance of returning it untouched."""
  skip_jitter = random.random() < 0.1
  if skip_jitter:
    return img

  jitter = transforms.ColorJitter(
      brightness=(0.5,2.0),
      contrast=(0.5,2.0),
      saturation=(0.5,2.0),
      hue=(-0.25,0.25),
  )
  return jitter(img)

# Random horizontal flip
def random_flip(img, gt_box):
  """Mirror the image and its boxes left-right with probability 0.5.

  `gt_box` is an array of (x1, x2, y1, y2) rows; when the flip happens it is
  updated in place and the same array object is returned.
  """
  if random.random() >= 0.5:
    img = transforms.RandomHorizontalFlip(p=1)(img)
    img_width = img.size[0]

    # Mirroring swaps the roles of the two x edges: the new left edge comes from
    # the old right edge and vice versa, so save the old right edge first.
    old_right = gt_box[:,1].copy()
    gt_box[:,1] = img_width - gt_box[:,0]
    gt_box[:,0] = img_width - old_right

  return img, gt_box


# Random crop on image
def random_crop(img, gt_box, labels):
  """Crop a random window (60-100% of each side) with probability 0.5.

  Boxes are (x1, x2, y1, y2) rows. A box is kept only when at least half of its
  width and half of its height survive the crop; kept boxes are clipped to the
  window and shifted into its coordinate frame. If no box survives, the
  original image, boxes and labels are returned unchanged.
  """
  if random.random() < 0.5:
    return img,gt_box,labels

  width, height = img.size
  crop_w = random.uniform(0.6*width, width)
  crop_h = random.uniform(0.6*height, height)

  left = random.uniform(0,width - crop_w)
  upper = random.uniform(0,height - crop_h)
  right = left + crop_w
  bottom = upper + crop_h

  # Clip a copy so the untouched originals remain available for the ratio test
  clipped = gt_box.copy()
  clipped[:,0] = np.maximum(clipped[:,0], left)
  clipped[:,1] = np.minimum(clipped[:,1], right)
  clipped[:,2] = np.maximum(clipped[:,2], upper)
  clipped[:,3] = np.minimum(clipped[:,3], bottom)

  kept_rows = []
  for row in range(clipped.shape[0]):
    # Drop the box when less than half of its width or height is still visible
    if (((clipped[row,1] - clipped[row,0])/(gt_box[row,1]-gt_box[row,0])) < 0.5) or \
       (((clipped[row,3] - clipped[row,2])/(gt_box[row,3] - gt_box[row,2])) < 0.5):
      continue
    kept_rows.append(row)

  if not kept_rows:
    return img,gt_box,labels

  shifted = np.array([clipped[row] for row in kept_rows])
  kept_labels = [labels[row] for row in kept_rows]

  # Move the surviving boxes into the cropped image's coordinate frame
  shifted[:,:2] = shifted[:,:2] - left
  shifted[:,2:] = shifted[:,2:] - upper

  return img.crop((left, upper, right, bottom)), shifted, kept_labels

##### Creating pytorch dataset

In [9]:
class pascal_voc_data(Dataset):
    """Pascal VOC detection dataset yielding padded box/label tensors.

    Args:
        img_dir: directory holding the image files.
        desc_dir: directory holding the XML annotation files.
        type_list: collection of image ids (file names without extension) to use.
        isTrain: when True, random blur/colour/flip/crop augmentation is applied.
        transform: callable applied to the PIL image (expected to return a
            (channels, height, width) tensor).

    Each item is `(img, boxes, labels, count)`: `boxes` is a
    (max_gt_count, 4) float tensor of (x1, x2, y1, y2) rows scaled to the
    transformed image, `labels` a (max_gt_count,) int tensor, both padded with
    -1 past the first `count` entries.
    """

    def __init__(self, img_dir,desc_dir, type_list,  isTrain=True,transform = None):
        super().__init__()
        self.img_dir = img_dir
        self.desc_dir = desc_dir
        self.type_list = type_list
        self.isTrain = isTrain
        self.transform = transform

        wanted = set(type_list)     # O(1) membership instead of scanning a list

        # Index images by id so every annotation is paired with its own image by
        # name, not by its position in a sorted directory listing.
        image_by_stem = {}
        for img in sorted(os.listdir(img_dir)):
            stem = os.path.splitext(img)[0]
            if stem in wanted:
                image_by_stem[stem] = os.path.join(img_dir, img)

        self.img_names = list(image_by_stem.values())
        self.img_descs = []
        self.loc_gts = []
        self.loc_labels = []
        self.final_img_names = []

        for desc in sorted(os.listdir(desc_dir)):
            stem = os.path.splitext(desc)[0]
            if stem not in wanted or stem not in image_by_stem:
                continue

            desc_path = os.path.join(desc_dir, desc)
            self.img_descs.append(desc_path)

            gt, loc_lab = self._read_annotation(desc_path)
            if len(gt) == 0:
                # Images without any usable object are left out of the dataset
                continue
            self.loc_gts.append(gt)
            self.loc_labels.append(loc_lab)
            self.final_img_names.append(image_by_stem[stem])

        self.img_names = self.final_img_names

    @staticmethod
    def _read_annotation(desc_path):
        """Return ([[xmin, xmax, ymin, ymax], ...], [label_id, ...]) for one XML file."""
        boxes = []
        label_ids = []
        for obj in ET.parse(desc_path).findall('object'):
            name = obj.find('name').text
            if name not in select_classes:
                continue

            bndbox = obj.find('bndbox')
            xmin = int(bndbox.find('xmin').text)
            xmax = int(bndbox.find('xmax').text)
            ymin = int(bndbox.find('ymin').text)
            ymax = int(bndbox.find('ymax').text)

            # if ht or width is less than 10, ignore the gt box
            if ((xmax - xmin) < 10) or ((ymax - ymin) < 10):
                continue

            boxes.append([xmin, xmax, ymin, ymax])
            label_ids.append(lab_to_val[name])
        return boxes, label_ids

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self,idx):
        img_name = self.img_names[idx]
        # Force 3 channels so the 3-channel Normalize in the transform always applies
        img = Image.open(img_name).convert('RGB')
        # Float boxes: an integer array would silently truncate on crop and rescale
        arr_loc_gts = np.array(self.loc_gts[idx], dtype=np.float32)
        label = self.loc_labels[idx]

        if self.isTrain:
            img = random_blur(img)
            img = random_color(img)
            img,arr_loc_gts = random_flip(img,arr_loc_gts)
            img,arr_loc_gts,label = random_crop(img,arr_loc_gts,label)

        img_h_pre = img.size[1]
        img_w_pre = img.size[0]

        if self.transform:
            img = self.transform(img)

        if hasattr(img, 'shape'):
            # Tensor laid out as (channels, height, width)
            img_h_post = img.shape[1]
            img_w_post = img.shape[2]
        else:
            # No transform was given: still a PIL image, whose size is (width, height)
            img_w_post, img_h_post = img.size

        height_ratio = img_h_post/img_h_pre
        width_ratio = img_w_post/img_w_pre

        # Rescale the boxes to the transformed image size
        arr_loc_gts[:,:2] = arr_loc_gts[:,:2]*width_ratio
        arr_loc_gts[:,2:] = arr_loc_gts[:,2:]*height_ratio

        gts = (arr_loc_gts).tolist()

        num_gt = len(gts)
        num_label = len(label)

        if (num_gt != num_label):
            raise Exception("invalid data - num_gt and num_labels do not match")

        # Pad to a fixed length so the default collate can stack a batch;
        # -1 marks the unused slots.
        fin_gt = torch.ones((max_gt_count,4))*(-1)
        fin_lab = torch.ones((max_gt_count))*(-1)
        count_gt = num_gt

        fin_gt[:count_gt] = torch.FloatTensor(gts)
        fin_lab[:count_gt] = torch.FloatTensor(label)

        return img, fin_gt,fin_lab.int(), count_gt


##### Define Transforms on train images

In [10]:
# While using pretrained models -
# Pytorch torchvision documentation - https://pytorch.org/docs/master/torchvision/models.html
# The images have to be loaded in to a range of [0, 1] and then
# normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]

_norm_mean = (0.485, 0.456, 0.406)
_norm_std = (0.229, 0.224, 0.225)

# Resize to the network input size, convert to a tensor, then normalize per channel
transform = transforms.Compose([
    transforms.Resize((input_image_height,input_image_width)),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Undoes the normalization above: with mean' = -mean/std and std' = 1/std,
# (x - mean') / std' == x * std + mean
inv_normalize = transforms.Normalize(
   mean=[-m/s for m, s in zip(_norm_mean, _norm_std)],
   std=[1/s for s in _norm_std]
)

##### Load dataset - train and valid

In [11]:

# Positional arguments: image dir, annotation dir, list of image ids, isTrain, transform.
# Augmentation is switched on only for the training split (isTrain=True).
train_dataset = pascal_voc_data('VOCdevkit/VOC2007/JPEGImages/', 'VOCdevkit/VOC2007/Annotations/',train_images ,True,transform)
valid_dataset = pascal_voc_data('VOCdevkit/VOC2007/JPEGImages/', 'VOCdevkit/VOC2007/Annotations/',valid_images ,False,transform)

In [12]:
# Number of usable images in each split (images without any kept box are excluded)
print(len(train_dataset), len(valid_dataset))

4683 328


#### YOLO Helper Functions

##### In a 7 * 7 grid, check which grid the center of object lies in


In [13]:
# Input is single box in format (x1,x2,y1,y2)
# returns grid where the center lies for this box
def get_grid_for_gt_box(gt_box):
    """Return the (grid_x, grid_y) cell indices containing the centre of `gt_box`.

    Args:
        gt_box: one box as (x1, x2, y1, y2) in pixels; a tensor or any sequence.

    Returns:
        Tuple of two Python ints: column index (x) and row index (y) of the cell,
        using cells of `grid_wt` x `grid_ht` pixels.
    """
    # as_tensor reuses an existing tensor; torch.tensor(tensor) copy-constructs it
    # and emits a UserWarning on every call.
    box = torch.as_tensor(gt_box)

    ctr_x = (box[0] + box[1])//2
    ctr_y = (box[2] + box[3])//2

    return (int(ctr_x//grid_wt), int(ctr_y//grid_ht))

##### Conversion between ground truth and prediction format



*   Ground Truth - x1,x2,y1,y2
*   Predicted - x,y,w,h (normalised wrt grid)



In [14]:
# grid_x and grid_y are the pixel coordinates where the grid cell starts
# gt_box is the gt in format - x1,x2,y1,y2
# converts gt_box to format - x,y,w,h - normalized

def convert_gt_box_to_pred_format(gt_box, grid_x, grid_y):
    """Encode an (x1, x2, y1, y2) box in the network's (x, y, w, h) format.

    x and y are the box centre's offset from the cell start, in units of one
    cell; w and h are the box size as a fraction of the whole input image.
    Returns a tensor of 4 values.
    """
    left, right, top, bottom = gt_box[0], gt_box[1], gt_box[2], gt_box[3]
    box_w = right - left
    box_h = bottom - top

    encoded = torch.zeros(4)
    encoded[0] = (left + (box_w/2) - grid_x)/(grid_wt*1.0)  # x: centre offset in cell widths
    encoded[1] = (top + (box_h/2) - grid_y)/(grid_ht*1.0)   # y: centre offset in cell heights
    encoded[2] = (box_w)/(input_image_width*1.0)            # w: fraction of image width
    encoded[3] = (box_h)/(input_image_height*1.0)           # h: fraction of image height

    return encoded




# grid_x_start and grid_y_start are the pixel coordinates where the grid cell starts
# pred_vector is the vector for one box of that cell - [x,y,w,h] - normalized
# returns output in format - x1,x2,y1,y2

def convert_pred_vector_to_gt_format(pred_vector, grid_x_start, grid_y_start):
    """Decode one predicted (x, y, w, h) vector into integer (x1, x2, y1, y2) pixels.

    x and y are offsets from the cell start in units of one cell; w and h are
    fractions of the input image size. The result is an int tensor of 4 values
    and is not clipped to the image.
    """
    # Centre and size in pixels, stored as [centre_x, centre_y, width, height]
    centre_size = torch.zeros(4)
    centre_size[0] = grid_x_start + grid_wt* pred_vector[0]
    centre_size[1] = grid_y_start + grid_ht* pred_vector[1]
    centre_size[2] = pred_vector[2]*input_image_width
    centre_size[3] = pred_vector[3]*input_image_height

    ctr_x, ctr_y, box_w, box_h = centre_size
    half_w = box_w/2
    half_h = box_h/2

    corners = torch.stack((ctr_x - half_w, ctr_x + half_w, ctr_y - half_h, ctr_y + half_h))
    return corners.int()


##### Given predicted Tensor as input, returns pred objects, scores and bounding box 

In [15]:
# input is the prediction for one image: (grid, grid, 5*anchors_per_box + num_classes)

def decode_pred_data(preds):
    """Turn one image's raw prediction grid into final detections.

    Boxes are kept when both confidence * class score exceeds `prob_threshold`
    and confidence exceeds `conf_threshold`; the survivors go through
    per-class non-maximum suppression ordered by confidence.

    Args:
        preds: tensor of shape (grid, grid, 5*anchors_per_box + num_classes).
            Each cell holds `anchors_per_box` groups of (x, y, w, h, conf)
            followed by the class scores.

    Returns:
        Three parallel lists: boxes (tensors of x1, x2, y1, y2 pixels),
        class scores (0-d tensors) and class names.
    """
    box_values = 5*anchors_per_box
    boxes_per_row = preds.shape[1]*anchors_per_box

    # Every predicted box and its confidence, flattened cell by cell
    box_block = preds[:,:,:box_values].reshape(-1,5)
    pred_loc = box_block[:,:4]
    pred_conf = box_block[:,4]

    # Best class and its score per cell - p(class | object) - repeated so that
    # every box of a cell shares the cell's class and score
    cell_score, cell_class = torch.max(preds[:,:,-1*num_classes:].reshape(-1,num_classes), dim=1)
    pred_class_score = cell_score.repeat_interleave(anchors_per_box)
    pred_class = cell_class.repeat_interleave(anchors_per_box)

    # P(class) = confidence * p(class | object)
    pred_prob = pred_conf*pred_class_score

    # Only consider boxes passing both prob_threshold and conf_threshold
    keep_mask = (pred_prob > prob_threshold) & (pred_conf > conf_threshold)
    kept_flat_idx = torch.nonzero(keep_mask).flatten().tolist()
    pred_loc = pred_loc[keep_mask]
    pred_conf = pred_conf[keep_mask]
    pred_class = pred_class[keep_mask]
    pred_class_score = pred_class_score[keep_mask]

    # Recover the (first-axis, second-axis) cell of each kept box from its flat index
    pred_conf_grid = [(i//boxes_per_row, (i%boxes_per_row)//anchors_per_box) for i in kept_flat_idx]

    # convert to x1,x2,y1,y2 format and clip to the image
    converted_preds = torch.zeros((len(pred_loc),4))
    for idx, (cell_x, cell_y) in enumerate(pred_conf_grid):
        converted_preds[idx] = convert_pred_vector_to_gt_format(pred_loc[idx], cell_x*grid_wt, cell_y*grid_ht)
    converted_preds[:,0].clamp_(min=0)
    converted_preds[:,1].clamp_(max=input_image_width)
    converted_preds[:,2].clamp_(min=0)
    converted_preds[:,3].clamp_(max=input_image_height)

    fin_class = np.array([val_to_lab[i.item()] for i in pred_class])
    fin_score = pred_class_score.detach().to('cpu')
    fin_conf = pred_conf.detach().to('cpu')

    final_conv_preds = []
    final_conv_score = []
    final_conv_class = []

    # dict.fromkeys keeps first-seen order, so the output order is deterministic
    for cls in dict.fromkeys(fin_class.tolist()):
        class_mask = torch.from_numpy(fin_class == cls)
        cur_class = fin_class[fin_class == cls]
        cur_score = fin_score[class_mask]
        cur_conf = fin_conf[class_mask]
        cur_preds = converted_preds[class_mask]

        # Greedy NMS: take the most confident box, drop the remaining boxes that
        # overlap it by nms_threshold or more, repeat on what is left
        order = torch.argsort(cur_conf, descending=True)
        keep_anchors = []
        while len(order) > 1:
            best = order[0]
            keep_anchors.append(best)
            overlaps = get_iou_matrix(cur_preds[order[1:]], cur_preds[best].reshape(1,4))
            # +1 because row r of `overlaps` describes order[r + 1]
            survivors = torch.nonzero(overlaps[:,0] < nms_threshold).flatten() + 1
            order = order[survivors]

        if len(order) == 1:
            keep_anchors.append(order[0])

        for k in keep_anchors:
            final_conv_preds.append(cur_preds[k])
            final_conv_score.append(cur_score[k])
            final_conv_class.append(cur_class[int(k)])

    return final_conv_preds, final_conv_score, final_conv_class


    

##### Get IOU between 2 set of bounding boxes

In [16]:
# Get Intersection over Union between all the boxes in pred_boxes and gt_boxes
# Taken from faster_rcnn repo - https://github.com/pranayKD/faster_rcnn_colab_pytorch/blob/master/Faster_RCNN.ipynb

def get_iou_matrix(pred_boxes, gt_boxes):
    """Pairwise IoU between predicted and ground-truth boxes.

    Args:
        pred_boxes: tensor of shape (P, 4) with rows (x1, x2, y1, y2). P may be 0.
        gt_boxes: iterable of G boxes in the same format; each box may be a
            tensor or a plain sequence of four numbers.

    Returns:
        Tensor of shape (P, G), created with torch.zeros, where entry [p, g] is
        the IoU of pred_boxes[p] and the g-th ground-truth box.
    """
    iou_matrix = torch.zeros((len(pred_boxes), len(gt_boxes)))
    for idx,box in enumerate(gt_boxes):
        # One (1, 4) row broadcast against every prediction. Unlike tiling with
        # torch.cat([box]*P), this also works when there are no predictions.
        if isinstance(box,torch.Tensor):
            gt = box.reshape(1,4)
        else:
            gt = torch.as_tensor(box, dtype=torch.float32, device=pred_boxes.device).reshape(1,4)

        # Intersection rectangle; a negative extent means the boxes do not overlap
        inter_x0 = torch.max(gt[:,0],pred_boxes[:,0])
        inter_x1 = torch.min(gt[:,1],pred_boxes[:,1])
        inter_y0 = torch.max(gt[:,2],pred_boxes[:,2])
        inter_y1 = torch.min(gt[:,3],pred_boxes[:,3])

        roi_area = (inter_x1 - inter_x0).clamp(min=0)*(inter_y1 - inter_y0).clamp(min=0)

        total_area = (gt[:,1] - gt[:,0])*(gt[:,3] - gt[:,2]) + \
                    (pred_boxes[:,1] - pred_boxes[:,0])*(pred_boxes[:,3]-pred_boxes[:,2]) - \
                     roi_area

        # The epsilon guards against division by zero for degenerate boxes
        iou_matrix[:,idx] = roi_area/(total_area + 1e-6)
    return iou_matrix

##### For given predictions, returns the mask where object is present. Also, returns the ground truth tensor which is used in calculating loss

In [17]:
# takes the predicted matrix (grid, grid, 5*anchors_per_box + num_classes), gt_locs and gt_labels
# returns the target matrix
# returns a mask for the grid entries where an object is present
# the mask cannot simply be "target is non zero", because a target value can be 0 even at an object loc

def convert_format_pred_targets(preds, gts, labels):
    """Build the regression target and object mask for one image.

    For every ground-truth box, the cell containing its centre is made
    responsible. Of that cell's predicted boxes, the one with the highest IoU
    against the ground truth receives the location target and a confidence
    target equal to that IoU; the cell's class slot for the label is set to 1.

    Returns:
        (targets, obj_mask), both shaped like `preds` and moved to `device`.
    """
    final_targets = torch.zeros_like(preds)
    obj_mask = torch.zeros_like(preds)
    box_values = 5*anchors_per_box

    for gt_box_idx, gt_box in enumerate(gts):
        label = labels[gt_box_idx]

        # Responsible cell and the pixel coordinates where it starts
        cell_x, cell_y = get_grid_for_gt_box(gt_box)
        start_x = grid_wt*cell_x
        start_y = grid_ht*cell_y

        # Location part (x, y, w, h) of every box predicted by that cell
        cell_boxes = preds[cell_x,cell_y][:box_values].contiguous().view(-1,5)[:,:4]

        # Decode them to pixel corners and clip to the image
        candidates = torch.stack([
            convert_pred_vector_to_gt_format(cell_boxes[i], start_x, start_y)
            for i in range(anchors_per_box)
        ]).float()
        candidates[:,0].clamp_(min=0)
        candidates[:,2].clamp_(min=0)
        candidates[:,1].clamp_(max=input_image_width)
        candidates[:,3].clamp_(max=input_image_height)

        # IoU of each candidate against the ground-truth box
        iou = get_iou_matrix(candidates.to(device), [gt_box.to(device)]).to(device)
        iou_val = iou.max()
        best = int(iou.argmax())
        lo = 5*best

        obj_mask[cell_x,cell_y,lo:(lo+5)] = 1.0                # only the box with max iou is responsible
        obj_mask[cell_x,cell_y,(-1*num_classes):] = 1.0        # class scores of this cell count too

        final_targets[cell_x,cell_y,lo:(lo+4)] = convert_gt_box_to_pred_format(gt_box, start_x, start_y)    # target locs are the gt
        final_targets[cell_x,cell_y,(lo+4):(lo+5)] = (iou_val)                                              # target conf is the max iou
        final_targets[cell_x,cell_y,(anchors_per_box*5+label)] = 1.0                                        # one-hot class target

    return final_targets.to(device), obj_mask.to(device)



#### Loss Function

##### Get loss for a single image

In [18]:
def get_loss_one_image(preds, targets,obj_mask):
    """Sum-of-squares YOLO loss for one image.

    Combines four terms: class error (weight 1), box location error (weight 5)
    and confidence error (weight 1) over the entries selected by `obj_mask`,
    plus confidence error over the unselected entries (weight 0.5).
    """

    def squared_error(a, b):
        diff = a - b
        return (diff*diff).sum()

    box_values = anchors_per_box*5

    obj_preds = preds*obj_mask
    obj_targets = targets*obj_mask

    # Loss 1 - predicted classes (if obj is present in grid)
    class_loss = squared_error(obj_preds[:,:,(-1*num_classes):], obj_targets[:,:,(-1*num_classes):])

    # One row of (x, y, w, h, conf) per predicted box
    obj_box_preds = obj_preds[:,:,:box_values].contiguous().view(-1,5)
    obj_box_targets = obj_targets[:,:,:box_values].contiguous().view(-1,5)

    ### TODO - for locs consider square root in w and h
    # Loss 2 - box locations (if obj is present in grid)
    loc_loss = squared_error(obj_box_preds[:,:4], obj_box_targets[:,:4])

    # Loss 3 - box confidence (if obj is present in grid)
    conf_loss = squared_error(obj_box_preds[:,4], obj_box_targets[:,4])

    # TODO - check this loss , even if there is a gt obj for grid do we need to consider loss for box with less IOU ?
    # Loss 4 - confidence for backgrounds (if obj is not present)
    no_obj_mask = (1-obj_mask)
    bg_box_preds = (preds*no_obj_mask)[:,:,:box_values].contiguous().view(-1,5)
    bg_box_targets = (targets*no_obj_mask)[:,:,:box_values].contiguous().view(-1,5)
    bg_conf_loss = squared_error(bg_box_preds[:,4], bg_box_targets[:,4])

    return class_loss + 5*loc_loss + conf_loss + 0.5*bg_conf_loss
    



##### Get loss for a batch of images

In [19]:
def get_loss_batch(batch_preds, batch_gt_boxes_inp, batch_labels_inp, batch_count):
    """Sum the per-image loss over a batch.

    `batch_count[i]` is the number of real (non-padding) boxes of image i; only
    that many leading rows of the padded box and label tensors are used.
    Targets are built from a detached copy of the prediction, so gradients
    flow only through the loss computed on `batch_preds` itself.
    """
    total = 0

    for img_idx in range(batch_preds.shape[0]):
        n_obj = batch_count[img_idx]
        frozen_pred = batch_preds[img_idx].clone().detach()

        targets, obj_mask = convert_format_pred_targets(
            frozen_pred,
            batch_gt_boxes_inp[img_idx][:n_obj],
            batch_labels_inp[img_idx][:n_obj],
        )

        total = total + get_loss_one_image(batch_preds[img_idx], targets, obj_mask)

    return total



    


#### Define YOLO Model Architecture

##### Model base - vgg16 pretrained

In [20]:

def base_model_vgg16(num_freeze_top):
    """Return the pretrained VGG16 feature layers (all but the last one) on `device`.

    The first `num_freeze_top` layers of the extractor are frozen, i.e. their
    parameters get requires_grad=False.
    """
    backbone = models.vgg16(pretrained=True).features[:-1]

    # Freeze learning of top few conv layers
    for frozen_param in backbone[:num_freeze_top].parameters():
        frozen_param.requires_grad = False

    return backbone.to(device)

##### Add a few convolution layers, fully connected layer and reshape the preds 

In [21]:
class YOLONetwork(nn.Module):
    """Detection head on top of a convolutional feature extractor.

    Two conv + max-pool stages, then a flatten, dropout and a single linear
    layer whose sigmoid output is reshaped to
    (batch, 7, 7, num_classes + 5*anchors_per_box).
    """

    def __init__(self, extractor):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay as they are
        self.extractor = extractor
        self.conv1 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.lin1 = nn.Flatten()
        self.drop1 = nn.Dropout(p=0.5)
        self.lin2 = nn.Linear(7*7*1024, 7*7*(num_classes + anchors_per_box*5))


    def forward(self,x):
        features = self.extractor(x)

        stage1 = self.pool1(F.relu(self.conv1(features)))
        stage2 = self.pool2(F.relu(self.conv2(stage1)))

        flat = self.drop1(F.relu(self.lin1(stage2)))
        scores = torch.sigmoid(self.lin2(flat))    # every output lies in (0, 1)

        batch = scores.shape[0]
        return scores.contiguous().view(batch,7,7,-1)

In [22]:
# Sanity check of the split sizes before training starts
print('Training Data - %d & Valid Data %d '%(len(train_dataset), len(valid_dataset)))

Training Data - 4683 & Valid Data 328 


#### Training Pipeline

##### Load a model from checkpoint or initiate a new model

In [23]:
# Optimizer choice ('SGD' or 'Adam') and checkpoint to resume from ('' = start fresh)
opt = 'Adam'
load_model = ''

# Network on top of a VGG16 extractor whose first 10 feature layers are frozen
extractor = base_model_vgg16(10)
net = YOLONetwork(extractor).to(device)
loss_hist = []          # training loss per iteration
valid_hist = []         # validation loss per epoch
best_valid_loss = 100000    # sentinel: any real validation loss is expected to be lower
# NOTE(review): any other value of `opt` leaves `optimizer` undefined
if opt =='SGD':
    optimizer = optim.SGD(net.parameters(), lr=0.001)
if opt == 'Adam':
    optimizer = optim.Adam(net.parameters(), lr=0.00001)

epoch_start = 0

# Restore weights, optimizer state and training history when resuming
if load_model != '':
    print('loading model ... ')
    checkpoint = torch.load(load_model, map_location=device)
    net.load_state_dict(checkpoint['net_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    loss_hist = checkpoint['loss_hist']
    valid_hist = checkpoint['valid_hist']
    best_valid_loss = checkpoint['best_valid_loss']
    epoch_start =checkpoint['epoch_start']
    
    net.train()
    print('model loaded ...' )




Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/checkpoints/vgg16-397923af.pth


HBox(children=(FloatProgress(value=0.0, max=553433881.0), HTML(value='')))




##### Model Training and saving after each epoch

In [None]:
img_dis_step = 100          # show loss curve + detections every this many iterations
train_loss_dis_step = 10    # print the training loss every this many iterations


def _save_checkpoint(path, next_epoch):
    """Save model, optimizer and history to `path`; `next_epoch` is where a resumed run starts."""
    torch.save({
        'net_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_hist':loss_hist,
        'valid_hist':valid_hist,
        'best_valid_loss':best_valid_loss,
        'epoch_start':next_epoch
      }, path)


# Histories restored from an older checkpoint may hold tensors; plain floats can
# be plotted and compared regardless of device.
loss_hist = [float(v) for v in loss_hist]
valid_hist = [float(v) for v in valid_hist]
best_valid_loss = float(best_valid_loss)

# shuffle=True reshuffles every time a loader is iterated, so one loader per
# split is enough for all epochs.
train_loader = DataLoader(train_dataset, batch_size=batch_num, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_num, shuffle=True, drop_last=True)

for epoch in range(epoch_start,epoch_start+150):
    net.train()
    for train_idx, train_data in enumerate(train_loader,0):
        img, gt, label,count = train_data

        img = img.to(device)
        gt = gt.to(device)
        label = label.to(device)
        count = count.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        preds = net(img)
        loss = get_loss_batch(preds, gt,label, count)/batch_num
        loss.backward()
        optimizer.step()

        # Store a plain float: keeping the tensor would keep its autograd history
        # (and device memory) alive for the whole run.
        loss_value = loss.item()
        loss_hist.append(loss_value)

        if (train_idx%train_loss_dis_step == 0):
            print('epoch %d - iteration %d ---- loss %.6f' %(epoch, train_idx, loss_value))

        if ((train_idx)%img_dis_step == 0):
            plt.plot(loss_hist)
            plt.title("Training Loss")
            # decode_pred_data returns (boxes, scores, classes) in that order
            decoded_preds, fin_score, fin_class = decode_pred_data(preds[0].detach())
            # visualize_tensor expects ground truth first, predictions second
            visualize_tensor([img[0]], gt[0][:count[0]], decoded_preds)
            plt.show()
            print(fin_class)
            print(fin_score)


    print("----- validation step --------")
    net.eval()
    valid_loss = 0.0
    for valid_idx, valid_data in enumerate(valid_loader):
        img, gt, label,count = valid_data
        img = img.to(device)
        gt = gt.to(device)
        label = label.to(device)
        count = count.to(device)

        with torch.no_grad():
            preds = net(img)
            valid_loss = valid_loss + (get_loss_batch(preds, gt,label, count)/batch_num).item()

        if (valid_idx == 0):
            decoded_preds, fin_score, fin_class = decode_pred_data(preds[0])
            visualize_tensor([img[0]], gt[0][:count[0]], decoded_preds)
            plt.show()
            print(fin_class)
            print(fin_score)

    # NOTE: every batch loss is already divided by batch_num, so dividing the sum
    # by the dataset size (instead of the number of batches) gives a value that is
    # smaller than the true per-image mean. The scale is kept on purpose:
    # best_valid_loss stored in existing checkpoints uses it, and only the
    # relative ordering between epochs matters for model selection.
    valid_loss = valid_loss/len(valid_dataset)
    valid_hist.append(valid_loss)
    print(valid_loss)
    plt.plot(valid_hist)
    plt.title("Valid Loss")
    plt.show()

    # Update the best score before saving, so the "current" checkpoint never
    # records a stale best_valid_loss.
    is_best = valid_loss < best_valid_loss
    if is_best:
        best_valid_loss = valid_loss

    # Store epoch + 1 so a resumed run continues with the next epoch instead of
    # repeating the one that just finished.
    _save_checkpoint('drive/My Drive/saved_models/current_'  + opt + '.pt', epoch + 1)

    if is_best:
        print("Found new best :) ")
        _save_checkpoint('drive/My Drive/saved_models/best_'  + opt + '.pt', epoch + 1)

    print("-------------------------------")
            



#### Model testing on a single image

###### Load the best saved model

In [None]:
# Rebuild the network and restore the best checkpoint written by the training loop
opt = 'Adam'
load_model = 'drive/My Drive/saved_models/best_'  + opt + '.pt'

extractor = base_model_vgg16(10)
net = YOLONetwork(extractor).to(device)
loss_hist = []
valid_hist = []
best_valid_loss = 100000
# The optimizer is created only so its saved state can be loaded below;
# this cell itself performs no training step.
if opt =='SGD':
    optimizer = optim.SGD(net.parameters(), lr=0.001)
if opt == 'Adam':
    optimizer = optim.Adam(net.parameters(), lr=0.00001)

epoch_start = 0

if load_model != '':
    print('loading model ... ')
    # map_location lets a checkpoint saved on GPU load on whatever `device` is
    checkpoint = torch.load(load_model, map_location=device)
    net.load_state_dict(checkpoint['net_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    loss_hist = checkpoint['loss_hist']
    valid_hist = checkpoint['valid_hist']
    best_valid_loss = checkpoint['best_valid_loss']
    epoch_start =checkpoint['epoch_start']
    
    
    print('model loaded ...' )

# Inference mode: disables the dropout layer of the head
net.eval()
print('epoch', epoch_start)

###### Run the model on image on eval mode

In [None]:
# Load one image, apply the same resize/normalize transform used for the datasets,
# and add a leading batch dimension.
# NOTE(review): the Normalize step uses three channel means - presumably the
# file must be an RGB image; confirm for other inputs.
img = Image.open('drive/My Drive/saved_models/plant.jpg')
img = transform(img)
img = img.unsqueeze(0)

# Forward pass without building the autograd graph
with torch.no_grad():
    preds = net(img.to(device))

# decode_pred_data returns (boxes, scores, classes); '' as gt_box draws no ground truth
decoded_preds, fin_score,fin_class = decode_pred_data(preds[0])
visualize_tensor([img[0]], '',decoded_preds, fin_class, fin_score)
plt.show()
print(fin_class)
print(fin_score)
