# CIS 680 - Final Project - Mask R-CNN

#### Raima Sen
#### Shreyas Ramesh

This project develops a pixel-wise framework that combines semantic segmentation and object detection to perform instance segmentation. The primary goal is to extend Faster R-CNN, which outputs the class label and bounding box of each candidate object. This
extension adds one more output: an object mask for each candidate object.

References:

[1] Umer Farooq. From R-CNN to mask R-CNN. Feb. 2018. URL: https://medium.com/@umerfarooq_26378/from-r-cnn-to-mask-r-cnn-d6367b196cfd

[2] Ross Girshick. Fast R-CNN. Sept. 2015. URL: https://arxiv.org/abs/1504.08083.

[3] Ross Girshick et al. Rich feature hierarchies for accurate object detection and semantic segmentation. Oct. 2014. URL: https://arxiv.org/abs/1311.2524.

[4] Kaiming He et al. Mask R-CNN. Jan. 2018. URL: https://arxiv.org/abs/1703.06870.

[5] Chuan Liu, Yi Gao, and Jiancheng Lv. “Dynamic Normalization”. In: CoRR abs/2101.06073 (2021). arXiv: 2101.06073. URL: https://arxiv.org/abs/2101.06073.

[6] Shaoqing Ren et al. Faster R-CNN: Towards real-time object detection with region proposal networks. Jan. 2016. URL: https://arxiv.org/abs/1506.01497.



# Part 1 : Initialize 
- Imports \\
- Mount Drive \\
- Set Cuda \\
- Define global variables

In [None]:
import torch
import torchvision
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.ops import focal_loss
import h5py
import math
import numpy as np
import matplotlib.pyplot as plt
from torchvision.transforms.functional import pad, normalize, resize
import matplotlib.patches as patches
import os
from  matplotlib.patches import Rectangle as rec
from torch import nn, Tensor
from torchvision.ops import box_iou
import random
from copy import deepcopy
from tqdm import tqdm
from torchvision.models.detection.image_list import ImageList
from torchvision.ops import MultiScaleRoIAlign
from sklearn.metrics import auc
import pdb
!pip install pytorch_lightning &> /dev/null
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
import pytorch_lightning.callbacks as pl_callbacks

In [None]:
# Mount Google Drive (Colab only) so that the dataset and checkpoint paths
# under /content/drive used later in this notebook are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
# `device` is read as a global by the dataset and by several RPNHead methods.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
print(device)

cuda:0


In [None]:
# Size in pixels of the pre-processed (resized to 800x1066, then padded by 11 px
# on the left and right) images; RPNHead copies these into self.h_img / self.w_img.
h_img = 800
w_img = 1088
# Debug switch; not read anywhere in the visible part of this notebook.
debug = False

In [None]:
from functools import partial
def MultiApply(func, *args, **kwargs):
    """Call `func` element-wise over `args` and regroup the outputs by position.

    `func` is called once per element of the zipped positional iterables, with
    `kwargs` bound as fixed keyword arguments. Each call is expected to return
    a tuple; the i-th entries of all returned tuples are collected into the
    i-th output list.

    Returns:
        A tuple of lists, one list per output position of `func`.
    """
    bound = func if not kwargs else partial(func, **kwargs)
    per_call = list(map(bound, *args))
    return tuple(list(column) for column in zip(*per_call))

# Part 2 : Loading the Data
- class : BuildDataset
- class : BuildDataLoader

In [None]:
class BuildDataset(torch.utils.data.Dataset):
    """In-memory dataset of images with per-object labels, masks and boxes.

    Images and masks are read from HDF5 files (dataset key 'data'); labels and
    bounding boxes are read from pickled .npy files indexed per image. Masks
    are stored flat across all images, so `self.indexes[i]` records where the
    masks of image `i` start.
    """

    def __init__(self, path):
        """Load the four data files.

        Args:
            path: sequence of four paths, in this order: images (.h5),
                masks (.h5), labels (.npy), bounding boxes (.npy).
        """
        self.images_path = os.path.join(path[0])
        self.masks_path = os.path.join(path[1])
        self.labels_path = os.path.join(path[2])
        self.bboxes_path = os.path.join(path[3])
        # np.array(...) copies the data into memory, so the HDF5 handles can be
        # closed immediately instead of being left open for the process lifetime.
        with h5py.File(self.images_path, 'r') as images_file:
            self.images = np.array(images_file['data'])    # transformation needed: 0-1, resize, normalize, pad
        with h5py.File(self.masks_path, 'r') as masks_file:
            self.masks = np.array(masks_file['data'])      # transformation needed: resize, pad
        self.bboxes = np.load(self.bboxes_path, allow_pickle=True, encoding='latin1')    # transformation needed: scaling and padding shift
        self.labels = np.load(self.labels_path, allow_pickle=True, encoding='latin1')

        # indexes[i] = offset of the first mask of image i inside self.masks
        self.indexes = []
        offset = 0
        for i in range(len(self.images)):
            self.indexes.append(offset)
            offset += self.labels[i].shape[0]
        self.indexes = np.array(self.indexes)

    def __getitem__(self, index):
        """Return (image, labels, masks, bboxes, index) for one image.

        All returned tensors are moved to the global `device`. An image with
        k objects yields k labels, k masks and k bounding boxes.
        """
        image = torch.from_numpy(self.images[index].astype('float32')).to(device)
        labels = torch.from_numpy(self.labels[index]).to(device)
        n_obj = labels.shape[0]    # number of objects present in the image
        first_mask = self.indexes[index]
        masks = torch.from_numpy(self.masks[first_mask:first_mask + n_obj].astype('float32')).to(device)
        bboxes = torch.from_numpy(self.bboxes[index]).to(device)
        transed_img, transed_masks, transed_bboxes = self.pre_process_single(image, masks, bboxes)
        return transed_img, labels, transed_masks, transed_bboxes, index

    def pre_process_single(self, image, masks, bboxes):
        """Pre-process one image together with its masks and bounding boxes.

        Image: scale to 0-1, resize to 800x1066, normalize with the ImageNet
        mean/std, pad 11 px left and right. Masks: resize and pad the same way.
        Boxes: rescale and shift so they line up with the padded image.

        Returns:
            (img, new_masks, new_bboxes)
        """
        img = torch.div(image, 255)
        img = resize(img, size=(800, 1066))
        img = normalize(img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        img = pad(img, [11, 0])
        new_masks = pad(resize(masks, size=(800, 1066)), [11, 0])
        new_bboxes = self._rescale_bboxes(bboxes)
        return img, new_masks, new_bboxes

    @staticmethod
    def _rescale_bboxes(bboxes):
        """Map (x1, y1, x2, y2) boxes from the 300x400 source frame to the padded frame.

        The image content is resized to 800x1066 and then moved 11 px to the
        right by the left padding, so both x1 and x2 are shifted by +11.
        (Subtracting from x1 and adding to x2 would widen every box by 22 px
        instead of moving it with the image.)
        """
        old_h, old_w = 300, 400
        new_h, new_w = 800, 1066
        pad_x = 11    # pixels added on the left (and right) by pad(..., [11, 0])
        scale_x = new_w / old_w
        scale_y = new_h / old_h
        new_bboxes = torch.zeros_like(bboxes)
        new_bboxes[:, 0] = scale_x * bboxes[:, 0] + pad_x
        new_bboxes[:, 1] = scale_y * bboxes[:, 1]
        new_bboxes[:, 2] = scale_x * bboxes[:, 2] + pad_x
        new_bboxes[:, 3] = scale_y * bboxes[:, 3]
        return new_bboxes

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.images)

In [None]:
class BuildDataLoader(torch.utils.data.DataLoader):
    """Holds loader settings and builds a DataLoader with a custom collate function.

    Note that DataLoader.__init__ is not called here: the instance only stores
    the settings, and `loader()` creates the DataLoader that is actually
    iterated.
    """

    def __init__(self, dataset, batch_size, shuffle, num_workers):
        self.dataset, self.batch_size = dataset, batch_size
        self.shuffle, self.num_workers = shuffle, num_workers

    def collect_fn(self, batch):
        """Collate a list of samples.

        Images are stacked into one tensor; labels, masks, bounding boxes and
        indices are returned as per-sample tuples because their sizes differ
        from image to image.
        """
        imgs, lbls, msks, boxes, idxs = zip(*batch)
        stacked = torch.stack(imgs)
        return stacked, lbls, msks, boxes, idxs

    def loader(self):
        """Create the DataLoader configured with the stored settings."""
        settings = {
            'batch_size': self.batch_size,
            'shuffle': self.shuffle,
            'num_workers': self.num_workers,
            'collate_fn': self.collect_fn,
        }
        return DataLoader(self.dataset, **settings)

# Part 3 : Instantiate BuildDataset and BuildDataLoader classes
- Run this cell only right before training, to avoid holding a lot of memory unnecessarily

In [None]:
# Locations of the four data files on the mounted Drive.
imgs_path = '/content/drive/MyDrive/CIS680/HW4/data/images.h5'
masks_path = '/content/drive/MyDrive/CIS680/HW4/data/masks.h5'
labels_path = '/content/drive/MyDrive/CIS680/HW4/data/labels.npy'
bboxes_path = '/content/drive/MyDrive/CIS680/HW4/data/bboxes.npy'

paths = [imgs_path, masks_path, labels_path, bboxes_path]
dataset = BuildDataset(paths)

# Pre-process and keep the first 200 samples; change the count as needed.
subset = [dataset[i] for i in range(200)]

# 90 % / 10 % train / test split with a fixed seed for reproducibility.
full_size = len(subset)
train_size = int(full_size * 0.9)
test_size = full_size - train_size

torch.random.manual_seed(1)
train_dataset, test_dataset = torch.utils.data.random_split(subset, [train_size, test_size])

batch_size = 2
print("batch size:", batch_size)

# One loader per split, plus one over the whole subset; none of them shuffles.
train_build_loader = BuildDataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_build_loader = BuildDataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
full_build_loader = BuildDataLoader(subset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = test_build_loader.loader()
train_loader = train_build_loader.loader()
full_loader = full_build_loader.loader()

batch size: 2


In [None]:
# Drop the notebook-level names that are no longer needed; the loaders built
# above keep their own references to the split datasets.
del train_dataset
del test_dataset
del dataset

# Part 4 : RPN Head
- Backbone is Resnet50 FPN 

## class RPNHead definition

In [None]:
class RPNHead(nn.Module):
    def __init__(self, num_anchors=3, in_channels=256, device='cuda', anchors_param=None):
        """Build the backbone, the RPN heads and the anchors.

        Args:
            num_anchors: number of anchors (aspect ratios) per grid cell.
            in_channels: channel count of every FPN feature map.
            device: stored on the instance as `self.device`.
            anchors_param: dict with keys "ratio", "scale", "grid_size" and
                "stride", each a list with one entry per FPN level. When None,
                the five-level defaults below are used.
        """
        super(RPNHead, self).__init__()
        if anchors_param is None:
            # Built per call so that no mutable default object is shared between instances.
            anchors_param = dict(ratio=[[1, 0.5, 2] for _ in range(5)],
                                 scale=[32, 64, 128, 256, 512],
                                 grid_size=[(200, 272), (100, 136), (50, 68), (25, 34), (13, 17)],
                                 stride=[4, 8, 16, 32, 64])
        self.device = device
        self.num_anchors = num_anchors
        self.in_channels = in_channels
        self.ratio = anchors_param["ratio"]
        self.scale = anchors_param["scale"]
        self.grid_size = anchors_param["grid_size"]
        self.stride = anchors_param["stride"]
        # image size comes from the notebook-level globals
        self.h_img = h_img
        self.w_img = w_img
        # ResNet50-FPN backbone taken from torchvision's Mask R-CNN (pretrained backbone weights only)
        pretrained_model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=True)
        self.backbone = pretrained_model.backbone
        # shared 3x3 conv applied to every FPN level before the two heads
        self.intermediate = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=256, kernel_size=3, padding='same'),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU()
        )
        # objectness head: one sigmoid score per anchor
        self.classifier = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=1 * self.num_anchors, kernel_size=1, padding='same'),
            nn.Sigmoid()
        )
        # box head: four encoded coordinates per anchor
        self.Regressor = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=4 * 1 * self.num_anchors, kernel_size=1, padding='same')
        )
        self.anchors = self.create_anchors(self.ratio, self.scale, self.grid_size, self.stride)

    '''
    Forward each level of the FPN output through the intermediate layer and the RPN heads
    Input:
          X: list:len(FPN){(bz,256,grid_size[0],grid_size[1])}
    Ouput:
          logits: list:len(FPN){(bz,1*num_anchors,grid_size[0],grid_size[1])}
          bbox_regs: list:len(FPN){(bz,4*num_anchors, grid_size[0],grid_size[1])}
    '''
    def forward(self, images):
        logits = []
        bbox_regs = []
        feature_pyramid = [v.detach() for v in self.backbone(images).values()]
        for i in range(len(feature_pyramid)):
          logit, bbox_reg = self.forward_single(feature_pyramid[i])

          logits.append(logit)
          bbox_regs.append(bbox_reg)

        return logits, bbox_regs
        
    '''
    Forward a single level of the FPN output through the intermediate layer and the RPN heads
    Input:
          feature: (bz,256,grid_size[0],grid_size[1])}
    Ouput:
          logit: (bz,1*num_anchors,grid_size[0],grid_size[1])
          bbox_regs: (bz,4*num_anchors, grid_size[0],grid_size[1])
    '''
    def forward_single(self, feature):
        X = self.intermediate(feature)
        logit = self.classifier(X)
        bbox_reg = self.Regressor(X)
        return logit, bbox_reg

    '''
    This function creates the anchor boxes for all FPN level
    Input:
          aspect_ratio: list:len(FPN){list:len(number_of_aspect_ratios)}
          scale:        list:len(FPN)
          grid_size:    list:len(FPN){tuple:len(2)}
          stride:        list:len(FPN)
    Output:
          anchors_list: list:len(FPN){(num_anchors,grid_size[0],grid_size[1],4)}
    '''
    def create_anchors(self, aspect_ratio, scale, grid_size, stride):
        anchors_list = []
        for i in range(len(aspect_ratio)):
          anchors = self.create_anchors_single(aspect_ratio[i], scale[i], grid_size[i], stride[i])
          anchors_list.append(anchors)
        return anchors_list

    '''
    This function creates the anchor boxes for one FPN level
    Input:
          aspect_ratio: list:len(number_of_aspect_ratios)
          scale: scalar
          grid_size: tuple:len(2)
          stride: scalar
    Output:
          anchors: (num_anchors,grid_size[0],grid_size[1],4)
    '''
    def create_anchors_single(self, aspect_ratio, scale, grid_size, stride):
        x = torch.arange(start=int(stride/2) , end=int(self.w_img + stride / 2), step=stride)
        y = torch.arange(start=int(stride/2) , end=int(self.h_img + stride / 2), step=stride)
        xx,yy  = torch.meshgrid(x,y)
        anchors = []
        for aspect in aspect_ratio:
          h = scale / math.sqrt(aspect)
          w = scale * math.sqrt(aspect)
          anchor = torch.zeros((grid_size[0], grid_size[1], 4)) # (x, y, w, h) - Sy, Sx, 4
          anchor[:, :, 0] = xx.T
          anchor[:, : ,1] = yy.T
          anchor[:, :, 2] = w
          anchor[:, :, 3] = h
          anchors.append(anchor)
        anchors = torch.stack(anchors)       
        #anchors = anchors.reshape((-1,4))
        return anchors

    def get_anchors(self):
        # Returns the per-level anchor list built by create_anchors() in __init__
        # (the stored list itself, not a copy).
        return self.anchors

    '''
    This function creates the ground truth for a batch of images
    Input:
          bboxes_list: list: len(bz){(number_of_boxes,4)}
    Ouput:
          ground: list:len(FPN){(bz,num_anchors,grid_size[0],grid_size[1])}
          ground_coord: list:len(FPN){(bz,4*num_anchors,grid_size[0],grid_size[1])}
    '''
    # NOTE (resolved): the per-image ground truth comes back indexed as len(bz), len(FPN), ...
    # It is regrouped below so that the result is indexed as len(FPN), (bz, num_anchors, ...).
    def create_batch_truth(self, bboxes_list):
        """Build the ground truth for a batch and regroup it by FPN level.

        `create_ground_truth` is called once per image; its per-level outputs
        are then stacked across the batch, so each returned list has one
        tensor per FPN level whose first dimension is the batch.

        Returns:
            (ground, ground_coord): class targets and encoded box targets.
        """
        per_image = [
            self.create_ground_truth(boxes, self.grid_size, self.anchors)
            for boxes in bboxes_list
        ]
        num_levels = len(per_image[0][0])

        ground = []
        ground_coord = []
        for lvl in range(num_levels):
            ground.append(torch.stack([cls_levels[lvl] for cls_levels, _ in per_image]))
            ground_coord.append(torch.stack([crd_levels[lvl] for _, crd_levels in per_image]))
        return ground, ground_coord

    '''
    This function create the ground truth for ONE IMAGE for all the FPN levels
    Input:
          bboxes:      (n_boxes,4)
          grid_size:   list:len(FPN){tuple:len(2)}
          anchor_list: list:len(FPN){(num_anchors,grid_size[0],grid_size[1],4)}
    Output:
          ground_clas: list:len(FPN){(num_anchors,grid_size[0],grid_size[1])}
          ground_coord: list:len(FPN){(4*num_anchors,grid_size[0],grid_size[1])}
    '''
    def create_ground_truth(self, bboxes, grid_sizes, anchors_list):
        """Create class and box-regression targets for one image on all FPN levels.

        Labelling, per anchor, from its best IoU with any ground-truth box:
          *  1 when the best IoU is > 0.4 (this covers both the ">= 0.7" band
             and the "between 0.4 and 0.7" band),
          *  0 when the IoU with every box is <= 0.3,
          * -1 otherwise, and always -1 for anchors touching or crossing the
             image border.
        Anchors labelled from the first rule get the encoded offsets
        (tx, ty, tw, th) to their best-matching box; all other entries are 0.

        Args:
            bboxes: (n_boxes, 4) ground-truth boxes as (x1, y1, x2, y2).
            grid_sizes: not read; the grid sizes are taken from self.grid_size.
            anchors_list: per-level anchors, (num_anchors, Sy, Sx, 4) in
                (x, y, w, h) form.

        Returns:
            ground_clas: list per level of (num_anchors, Sy, Sx) labels.
            ground_coord: list per level of (4 * num_anchors, Sy, Sx) targets,
                laid out as four consecutive channels per anchor.
        """
        bboxes = bboxes.to(device)
        n_boxes = bboxes.shape[0]

        # Ground-truth boxes in centre/size form; identical for every level and anchor.
        gt_x = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
        gt_y = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
        gt_w = bboxes[:, 2] - bboxes[:, 0]
        gt_h = bboxes[:, 3] - bboxes[:, 1]

        ground_clas = []
        ground_coord = []
        for lvl in range(len(anchors_list)):
            grid_h, grid_w = self.grid_size[lvl][0], self.grid_size[lvl][1]
            n_cells = grid_h * grid_w
            anchors = anchors_list[lvl].to(device)
            class_ls = []
            coord_ls = []
            for a in range(self.num_anchors):
                flat_anchors = anchors[a].reshape(-1, 4)
                xa, ya, wa, ha = flat_anchors.unbind(dim=1)
                # anchors in corner form for the IoU computation
                x1 = xa - wa / 2.0
                y1 = ya - ha / 2.0
                x2 = xa + wa / 2.0
                y2 = ya + ha / 2.0
                corner_anchors = torch.stack((x1, y1, x2, y2), dim=1)

                ground_crd = torch.zeros((n_cells, 4), device=device)
                ground_cls = -1 * torch.ones(n_cells, device=device)

                if n_boxes > 0:
                    # IoU against all anchors, including those outside the image
                    iou = box_iou(corner_anchors, bboxes)
                    best_iou, best_box = torch.max(iou, dim=1)
                else:
                    # no objects: nothing can match, every anchor has IoU 0
                    best_iou = torch.zeros(flat_anchors.shape[0], device=device)
                    best_box = torch.zeros(flat_anchors.shape[0], dtype=torch.long, device=device)

                # anchors that receive label 1 and an encoded regression target
                matched_idx = (best_iou > 0.4).nonzero().flatten()
                ground_cls[matched_idx] = 1
                box_of_anchor = best_box[matched_idx]
                m_xa, m_ya = xa[matched_idx], ya[matched_idx]
                m_wa, m_ha = wa[matched_idx], ha[matched_idx]
                ground_crd[matched_idx, 0] = (gt_x[box_of_anchor] - m_xa) / m_wa
                ground_crd[matched_idx, 1] = (gt_y[box_of_anchor] - m_ya) / m_ha
                ground_crd[matched_idx, 2] = torch.log(gt_w[box_of_anchor] / m_wa)
                ground_crd[matched_idx, 3] = torch.log(gt_h[box_of_anchor] / m_ha)

                # negatives: low IoU with every ground-truth box
                ground_cls[best_iou <= 0.3] = 0

                # anchors touching or crossing the image border are ignored
                outside = (x1 <= 0) | (y1 <= 0) | (x2 >= self.w_img) | (y2 >= self.h_img)
                ground_cls[outside] = -1

                class_ls.append(ground_cls.reshape(1, grid_h, grid_w))
                coord_ls.append(ground_crd.T.reshape(4, grid_h, grid_w))

            ground_clas.append(torch.cat(class_ls))
            ground_coord.append(torch.cat(coord_ls))

        return ground_clas, ground_coord

    '''
    Compute the loss of the classifier for a level but for all images
    '''
    def loss_class(self, pos_class_gt, pos_class_pred, neg_class_gt, neg_class_pred, pos_anchors, neg_anchors):
        skip_flag_pos = False
        skip_flag_neg = False
        if (pos_anchors.shape[0]==0):
          skip_flag_pos = True
        if (neg_anchors.shape[0]==0):
          skip_flag_neg = True
        criterion = torch.nn.BCELoss()
        #seperately calc postiive class and negative class loss:
        neg_class_loss = criterion(neg_class_pred+1e-10,neg_class_gt) if skip_flag_neg == False else torch.tensor(0)
        pos_class_loss = criterion(pos_class_pred+1e-10, pos_class_gt) if skip_flag_pos == False else torch.tensor(0)
        loss_c = neg_class_loss + pos_class_loss
        #print(loss_c)
        return loss_c
   
    '''
    Compute the loss of the regressor for a level but over all images
    '''
    def loss_reg(self, pos_regr_gt, pos_regr_pred):
        criterion = nn.SmoothL1Loss()
        loss_r = sum([criterion(pos_regr_gt[i], pos_regr_pred[i]) for i in range(4)])
        #print(loss_r)
        return loss_r

    '''
    Compute the total loss
    Input:
          clas_out_list: list:len(FPN){(bz,1*num_anchors,grid_size[0],grid_size[1])}
          regr_out_list: list:len(FPN){(bz,4*num_anchors,grid_size[0],grid_size[1])}
          targ_clas_list: list:len(FPN){(bz,1*num_anchors,grid_size[0],grid_size[1])}
          targ_regr_list: list:len(FPN){(bz,4*num_anchors,grid_size[0],grid_size[1])}
          l: weighting lambda between the two losses
          effective_batch: the number of anchors in the effective batch (M in the handout)
    '''
    def compute_loss(self, class_out_list, regr_out_list, targ_class_list, targ_regr_list, l=1, effective_batch=50):
        fpn_levels = len(class_out_list)
        batch_size = len(class_out_list[0])
        loss_c = 0
        loss_r = 0
        for lvl in range(fpn_levels):
          #step 1: flatten the class_out and targ_class
          targ_class = targ_class_list[lvl].reshape(-1)
          class_out = class_out_list[lvl].reshape(-1)
          #step 2: create mini batch
          pos_anchors = (targ_class==1).nonzero()
          neg_anchors = (targ_class==0).nonzero()
          pos_size = int( min(pos_anchors.shape[0], effective_batch/2))
          neg_size = int(effective_batch - pos_size)
          pos_anchors = pos_anchors[torch.randperm(pos_anchors.shape[0]), :]
          pos_anchors = pos_anchors[:pos_size, :]
          neg_anchors = neg_anchors[torch.randperm(neg_anchors.shape[0]), :]
          neg_anchors = neg_anchors[:neg_size, :]
          #step 3: assign positive and negative ground truths and predictions for the classifier
          pos_class_gt = targ_class[pos_anchors]
          pos_class_pred = class_out[pos_anchors]
          neg_class_gt = targ_class[neg_anchors]
          neg_class_pred = class_out[neg_anchors]

          #step 4: calculate classifier loss for both positive and negative ground truth labels and sum over all FPN levels 
          loss_c += self.loss_class(pos_class_gt, pos_class_pred, neg_class_gt, neg_class_pred, pos_anchors, neg_anchors)

          # step 5: calculate regressor loss for only positive ground truth labels
          #reshape the regressor outputs and targets from (bs, 4*num_anchors, Sy, Sx) to (4*num_anchors, bs, Sy, Sx)
          targ_regr = targ_regr_list[lvl].permute(1,0,2,3)
          regr_out = regr_out_list[lvl].permute(1,0,2,3)
          #flatten the regressor output and targets to (4, bs*Sy*Sx)
          targ_regr = targ_regr.reshape(4,-1)
          regr_out = regr_out.reshape(4,-1)
          #get only postive ground truth labels for the regressor
          pos_regr_gt = targ_regr[:, pos_anchors]
          pos_regr_pred = regr_out[:, pos_anchors]
          #extremely kaam chalau method below to eliminate the fpn levels where no bounding boxes are predicted
          if len(pos_regr_pred[0])!=0 and len(pos_regr_gt[0])!=0:
            #call regressor loss and sum over all FPN levels where the list is not empty
            loss_r += self.loss_reg(pos_regr_gt, pos_regr_pred)

        #here i am taking the average over the batch size
        loss_classifier = loss_c/batch_size
        loss_regressor = loss_r/batch_size
        loss = loss_classifier + l * loss_regressor

        return loss, loss_classifier, loss_regressor
    
    '''
    Input  : clas:(top_k_boxes)
             prebox:(top_k_boxes,4)
    Output  :
    nms_clas : (remaining_boxes_after_nms)
    nms_prebox : (remaining_boxes_after_nms, 4)
    '''
    def NMS(self, clas, prebox, thresh):
        """Greedy non-maximum suppression of one set of boxes.

        A box is dropped when its IoU with an already kept, higher-scoring box
        exceeds `thresh`.

        Args:
            clas: (n,) scores.
            prebox: (4, n) boxes, rows are x1, y1, x2, y2.
            thresh: IoU threshold above which a box is suppressed.

        Returns:
            (nms_clas, nms_prebox): list of kept scores and list of kept boxes
            (each a NumPy array of four values), ordered by decreasing score.
        """
        boxes = prebox.permute(1, 0).detach().cpu().contiguous()
        scores = clas.detach().cpu()
        if boxes.shape[0] == 0:
            return [], []

        # torchvision's NMS suppresses every box with IoU > thresh against a
        # kept box; unlike deleting from a list while iterating over it, it
        # never skips a candidate.
        keep = torchvision.ops.nms(boxes, scores, thresh)
        nms_clas = list(scores[keep].numpy())
        nms_prebox = list(boxes[keep].numpy())
        return nms_clas, nms_prebox

    '''
    Post process the output for one image across one FPN level
    Input:
       mat_clas: {(1*num_anchors,grid_size[0],grid_size[1])}  (score of the output boxes)
       mat_coord: {(4*num_anchors,grid_size[0],grid_size[1])} (encoded coordinates of the output boxess)
       anchors_list[lvl] (num_anchors,grid_size[0],grid_size[1],4)
    Output:
        nms_clas: (Post_NMS_boxes)
        nms_prebox: (Post_NMS_boxes,4)
    '''
    def postprocessImg(self, mat_clas, mat_coord, anchors, IOU_thresh, keep_num_postNMS):
        """Post-process the network output of one image on one FPN level.

        Decodes the regressed offsets against the anchors, discards boxes
        touching or crossing the image border and boxes with a score <= 0.5,
        sorts the rest by decreasing score and applies NMS.

        Args:
            mat_clas: (num_anchors, Sy, Sx) scores.
            mat_coord: (4 * num_anchors, Sy, Sx) encoded boxes, four
                consecutive channels per anchor.
            anchors: (num_anchors, Sy, Sx, 4) anchors in (x, y, w, h) form.
            IOU_thresh: IoU threshold passed on to NMS.
            keep_num_postNMS: maximum number of boxes kept after NMS.

        Returns:
            nms_clas: NumPy array with the scores kept after NMS.
            nms_prebox: NumPy array with the boxes kept after NMS.
            clas: NumPy array with the sorted pre-NMS scores.
            prebox: NumPy array (n, 4) with the sorted pre-NMS boxes.
        """
        # Split the channels into (anchor, component): anchor a owns channels
        # 4a .. 4a+3 (the layout produced by create_ground_truth).
        grid_h, grid_w = mat_coord.shape[-2], mat_coord.shape[-1]
        coords = mat_coord.reshape(self.num_anchors, 4, grid_h, grid_w).permute(0, 2, 3, 1).reshape(-1, 4)
        flat_anchors = anchors.reshape(-1, 4).to(mat_coord.device)
        xa, ya, wa, ha = flat_anchors[:, 0], flat_anchors[:, 1], flat_anchors[:, 2], flat_anchors[:, 3]
        tx, ty, tw, th = coords[:, 0], coords[:, 1], coords[:, 2], coords[:, 3]

        # decode tx, ty, tw, th to x, y, w, h and then to corner form
        x, y = tx * wa + xa, ty * ha + ya
        w, h = torch.exp(tw) * wa, torch.exp(th) * ha
        x1, y1, x2, y2 = x - w / 2, y - h / 2, x + w / 2, y + h / 2

        # zero the score of boxes that touch or cross the image border;
        # work on a copy so the caller's tensor is left untouched
        scores = mat_clas.flatten().clone()
        outside = (x1 <= 0) | (y1 <= 0) | (x2 >= self.w_img) | (y2 >= self.h_img)
        scores[outside] = 0

        # keep only boxes with objectness > 0.5, sorted by decreasing score
        confident = scores > 0.5
        thresh_obj = scores[confident]
        x1, y1, x2, y2 = x1[confident], y1[confident], x2[confident], y2[confident]
        order = torch.argsort(thresh_obj, descending=True)
        prebox = torch.vstack((x1[order], y1[order], x2[order], y2[order]))
        clas = thresh_obj[order]

        post_nms_class, post_nms_bboxes = self.NMS(clas, prebox, IOU_thresh)
        post_nms_class = np.array(post_nms_class)
        post_nms_bboxes = np.array(post_nms_bboxes)
        keep = min(keep_num_postNMS, post_nms_class.shape[0])
        nms_clas = post_nms_class[:keep]
        nms_prebox = post_nms_bboxes[:keep]

        prebox = prebox.T.cpu().detach().numpy()
        clas = clas.cpu().detach().numpy()

        return nms_clas, nms_prebox, clas, prebox

    '''
    Post process for the outputs for a batch of images
    Input:
          out_c: list:len(FPN){(bz,1*num_anchors,grid_size[0],grid_size[1])}
          out_r: list:len(FPN){(bz,4*num_anchors,grid_size[0],grid_size[1])}
          IOU_thresh: scalar that is the IOU threshold for the NMS
          keep_num_postNMS: number of masks we will keep from each image after the NMS
    Output:
          nms_clas_list: list:len(bz){(Post_NMS_boxes)} (the score of the boxes that the NMS kept)
          nms_prebox_list: list:len(bz){(Post_NMS_boxes,4)} (the coordinate of the boxes that the NMS kept)
    '''
    def postprocess(self, out_c, out_r, IOU_thresh=0.5, keep_num_postNMS=100):
        """Post-process the network output for a whole batch.

        `postprocessImg` is run for every image on every FPN level and the
        per-level results of an image are concatenated.

        Args:
            out_c: list per level of (bz, num_anchors, Sy, Sx) scores.
            out_r: list per level of (bz, 4 * num_anchors, Sy, Sx) encoded boxes.
            IOU_thresh: IoU threshold for the NMS.
            keep_num_postNMS: boxes kept per image and level after the NMS.

        Returns:
            Four lists with one tensor per image: post-NMS scores, post-NMS
            boxes, pre-NMS scores, pre-NMS boxes. An image without any box
            yields empty tensors ((0,) for scores, (0, 4) for boxes).
        """
        def cat_or_empty(parts, empty_shape):
            # torch.cat raises on an empty list, which happens when no level
            # produced a box for the image
            return torch.cat(parts) if parts else torch.zeros(empty_shape)

        nms_class_list = []
        nms_box_list = []
        clas_list = []
        prebox_list = []
        anchors_list = self.anchors
        for img_id in range(len(out_c[0])):      # every image of the batch
            lvl_pre_cls = []
            lvl_pre_box = []
            lvl_post_cls = []
            lvl_post_box = []
            for lvl in range(len(out_c)):        # every FPN level
                results = self.postprocessImg(out_c[lvl][img_id], out_r[lvl][img_id],
                                              anchors_list[lvl], IOU_thresh, keep_num_postNMS)
                nms_class, nms_box, clas, prebox = (torch.tensor(r) for r in results)
                # levels that contribute nothing are skipped
                if nms_class.shape[0] != 0:
                    lvl_post_cls.append(nms_class)
                if nms_box.shape[0] != 0:
                    lvl_post_box.append(nms_box)
                if clas.shape[0] != 0:
                    lvl_pre_cls.append(clas)
                if prebox.shape[0] != 0:
                    lvl_pre_box.append(prebox)
            nms_class_list.append(cat_or_empty(lvl_post_cls, (0,)))
            nms_box_list.append(cat_or_empty(lvl_post_box, (0, 4)))
            clas_list.append(cat_or_empty(lvl_pre_cls, (0,)))
            prebox_list.append(cat_or_empty(lvl_pre_box, (0, 4)))
        return nms_class_list, nms_box_list, clas_list, prebox_list

## Visualization of ground truth for 1 example (only on CPU for now)

In [None]:
#before implementing the loss function lets visualize the ground truth (works only in cpu)
# Visualize the ground truth before implementing the loss function.
# Needs `dataset` from Part 3, so run this before the cell that deletes it.
rpn = RPNHead()
anchors_list = rpn.get_anchors()
# __getitem__ returns: transed_img, labels, transed_masks, transed_bboxes, index
image, _, _, bboxes, _ = dataset.__getitem__(1)
g_class, g_coord = rpn.create_ground_truth(bboxes = bboxes, grid_sizes = rpn.grid_size,anchors_list = anchors_list)

# one figure per FPN level: red = decoded ground-truth box, blue = matching anchor
for lvl in range(len(g_class)):
  fig = plt.figure()
  ax = fig.add_subplot()
  ax.imshow(image.detach().cpu().numpy().transpose(1,2,0))
  for a in range(rpn.num_anchors):
    # anchors of this aspect ratio, flattened to (Sy*Sx, 4)
    flat_anchors = anchors_list[lvl][a].permute(2,0,1).flatten(start_dim=1,end_dim=-1).T
    # anchor a owns the four consecutive channels 4a .. 4a+3 of the coordinate target
    flat_ground_coord = g_coord[lvl][4*a:4*a+4].flatten(start_dim=1,end_dim=-1)
    flat_ground_class = g_class[lvl][a].reshape(-1)
    objects = (flat_ground_class==1).nonzero().flatten()

    # decode every positive anchor back to an (x, y, w, h) box
    for elem in objects:
          x_a, y_a, w_a, h_a = flat_anchors[elem, 0], flat_anchors[elem, 1], flat_anchors[elem, 2], flat_anchors[elem, 3]
          tx, ty, tw, th = flat_ground_coord[0, elem], flat_ground_coord[1, elem], flat_ground_coord[2, elem], flat_ground_coord[3, elem]
          w = torch.exp(tw) * w_a
          h = torch.exp(th) * h_a
          x = (tx * w_a) + x_a
          y = (ty * h_a) + y_a
          # matplotlib needs CPU values
          x = x.to('cpu')
          y = y.to('cpu')
          w = w.to('cpu')
          h = h.to('cpu')
          rect = patches.Rectangle((x-w/2,y-h/2), w, h,fill=False,color='r')
          ax.add_patch(rect)
          rect = patches.Rectangle((x_a-w_a/2, y_a-h_a/2), w_a, h_a, fill=False,color='b')
          ax.add_patch(rect)

    plt.show()

## Training

In [None]:
rpn_network_obj = RPNHead().to(device)
optimizer = torch.optim.Adam(rpn_network_obj.parameters(), lr = 0.001)

# per-epoch averages of the total / classifier / regressor loss
train_loss = []
train_loss_c = []
train_loss_r = []
val_loss = []
val_loss_c = []
val_loss_r = []
num_epochs = 30

# Average each running loss over the loader it was accumulated on
# (guarded against an empty loader).
n_train_batches = max(len(train_loader), 1)
n_val_batches = max(len(test_loader), 1)

for epoch in tqdm(range(num_epochs)):
  run_loss = 0
  run_loss_c = 0
  run_loss_r = 0
  # train the model
  rpn_network_obj.train()
  for i, (images, labels, masks, bboxes, indexes) in enumerate(train_loader):
    # forward pass
    clas, regr = rpn_network_obj(images.to(device))
    # ground truth for this batch
    gt_clas_batch, gt_coord_batch = rpn_network_obj.create_batch_truth(bboxes)
    optimizer.zero_grad()
    loss, loss_c , loss_r = rpn_network_obj.compute_loss(clas, regr, gt_clas_batch, gt_coord_batch)
    # backward pass
    loss.backward()
    optimizer.step()
    # free the intermediate predictions
    del clas, regr, gt_clas_batch, gt_coord_batch
    torch.cuda.empty_cache()

    run_loss += loss.item()
    run_loss_c += loss_c.item()
    run_loss_r += loss_r.item()

  train_loss.append(run_loss/n_train_batches)
  train_loss_c.append(run_loss_c/n_train_batches)
  train_loss_r.append(run_loss_r/n_train_batches)

  # checkpoint every second epoch
  if(epoch%2==0):
    path = "/content/drive/MyDrive/CIS680/Final/Checkpoint/epoch-" +str(epoch+1)
    torch.save({
        'epoch': epoch+1,
        'model_state_dict': rpn_network_obj.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': run_loss/n_train_batches,
        'clas_loss': run_loss_c/n_train_batches,
        'regr_loss': run_loss_r/n_train_batches
        }, path)

  # validate the model; no gradients are needed, so no graph is built
  rpn_network_obj.eval()
  run_eval_loss = 0
  run_eval_loss_c = 0
  run_eval_loss_r = 0
  with torch.no_grad():
    for j, (images, labels, masks, bboxes,indexes) in enumerate(test_loader):
      vclas, vregr = rpn_network_obj(images.to(device))
      vgt_clas_batch, vgt_coord_batch=rpn_network_obj.create_batch_truth(bboxes)
      vloss, loss_cls , loss_re = rpn_network_obj.compute_loss(vclas, vregr, vgt_clas_batch, vgt_coord_batch)
      run_eval_loss += vloss.item()
      run_eval_loss_c += loss_cls.item()
      run_eval_loss_r += loss_re.item()
      del vclas, vregr, vgt_clas_batch, vgt_coord_batch
      torch.cuda.empty_cache()

  val_loss.append(run_eval_loss/n_val_batches)
  print("Validation Loss = ", val_loss[-1])
  val_loss_c.append(run_eval_loss_c/n_val_batches)
  val_loss_r.append(run_eval_loss_r/n_val_batches)

In [None]:
# Plot the per-epoch validation classifier and regressor losses collected by
# the training loop.
plt.plot(val_loss_c, label='classifier val loss')
plt.plot(val_loss_r, label='regressor val loss')
plt.xlabel("epochs")
plt.ylabel("loss")
plt.title("category loss")
plt.legend()
plt.grid()
plt.show()

## Post Processing

In [None]:
# Load a saved checkpoint onto the current device (also works on a CPU-only runtime).
path = "drive/MyDrive/CIS680/Final/Checkpoint/epoch-29"
checkpoint = torch.load(path, map_location=device)
test_model = RPNHead().to(device)

test_model.load_state_dict(checkpoint['model_state_dict'])

# Inference only: eval mode and no gradient tracking.
test_model.eval()
with torch.no_grad():
  for idx, (images, labels, masks, bboxes, indexes) in enumerate(test_loader):
    clas, regr = test_model(images) #predict
    # every iteration overwrites these names, so after the loop they hold the
    # results of the last test batch
    post_nms_clas, post_nms_predbox, clas, prebox = test_model.postprocess(clas, regr)  #post process

In [None]:
def visualize_post_process(boxes_list, images):
  """Show each image with its boxes drawn on top as red, unfilled rectangles.

  boxes_list : one entry per image to show; each entry is an iterable of boxes
               whose first four values are read as x1, y1, x2, y2
  images     : indexable batch of image tensors; images[i] is detached, moved to the
               CPU and transposed with (1,2,0) before being passed to imshow
  One figure is created per entry of boxes_list and all are displayed at the end.
  """
  for i, objs in enumerate(boxes_list):    #number of examples to show
    fig = plt.figure()
    ax = fig.add_subplot()
    ax.imshow(images[i].detach().cpu().numpy().transpose((1,2,0)))
    for box in objs:
      # plain Python floats: matplotlib cannot consume GPU tensors or tensors that require grad
      x1, y1, x2, y2 = (float(v) for v in box[:4])
      rect = rec((x1, y1), (x2 - x1), (y2 - y1), fill=False, color='r')
      ax.add_patch(rect)
  plt.show()

In [None]:
# Draw the boxes of the last processed batch before and after NMS, then dump the scores.
for stage_name, stage_boxes in (("Pre NMS", prebox), ("Post NMS", post_nms_predbox)):
  print(stage_name)
  visualize_post_process(stage_boxes, images)
print(clas)


# Part 5 : Box Head - Old
- Much of this is derived from HW 4B

## class BoxHead definition

In [None]:
class BoxHead(nn.Module):
    def __init__(self):
        super(BoxHead, self).__init__()
        self.num_classes = 3
        self.P = 7
        self.intermediate = nn.Sequential(#linear 1
                                          nn.Linear(in_features=256*self.P*self.P, out_features=1024),
                                          nn.ReLU(),
                                          #linear 2
                                          nn.Linear(in_features=1024, out_features=1024),
                                          nn.ReLU(),
                                          )
        #maybe softmax is not advised for training in the classifier. verify later
        self.classifier = nn.Sequential(
                                        nn.Linear(in_features=1024, out_features=self.num_classes + 1), 
                                        #nn.Softmax()
                                        )
        self.regressor = nn.Sequential(
                                       nn.Linear(in_features=1024, out_features=int(4*self.num_classes))
                                      )
        
    '''
    input : proposals:  list len batch_size each tensor of dim (top_k_proposals, 4) [format : (x1, y1, x2, y2)]
    output : feature_vectors : (bs*top_k_proposals, 256*P*P) = (total_proposals, 256*P*P)
    function is for a batch
    '''
    def find_feature_vectors(self, batch_proposals, fpn_feat_list):
        """RoI-align every proposal on its FPN level and flatten the pooled crop.

        input  : batch_proposals : per-image sequence of proposals, each (x1, y1, x2, y2)
                 fpn_feat_list   : sequence of feature maps indexed by FPN level;
                                   each map is indexed by image and read as (bs, C, H, W)
        output : feature_vectors : one row per proposal holding its flattened P x P crop,
                                   images and proposals kept in input order
        Image coordinates are assumed to span 1088 x 800 (width x height).
        """
        pooled_rows = []
        for img_idx, img_proposals in enumerate(batch_proposals):
          for box in img_proposals:
            box_w = box[2] - box[0]
            box_h = box[3] - box[1]
            #FPN level from the box scale, limited to levels 2..5 and shifted to a 0-based list index
            level = torch.floor(4 + torch.log2(torch.sqrt(box_w*box_h)/224))
            level = (torch.clamp(level, 2, 5) - 2.0).int()
            level_feat = fpn_feat_list[level]
            #image coords -> feature map coords of the chosen level
            ratio_x = level_feat.shape[3]/1088
            ratio_y = level_feat.shape[2]/800
            scaled_box = box.reshape(1,-1).clone()    #(1,4)
            scaled_box[:,0::2] = scaled_box[:,0::2] * ratio_x   #x1, x2
            scaled_box[:,1::2] = scaled_box[:,1::2] * ratio_y   #y1, y2
            #box is already in feature map coords, hence spatial_scale=1
            crop = torchvision.ops.roi_align(level_feat[img_idx].unsqueeze(0), [scaled_box.to(device)],
                                             output_size=self.P, spatial_scale=1, sampling_ratio=-1)
            pooled_rows.append(crop.flatten())

        return torch.vstack(pooled_rows)

    '''
    input : feature_vectors : (total_proposals , 256*P^2) (bs*top_k_proposals = total_proposals)
    outputs : class_prob : (total_proposals, num_classes+1)
              bbox_reg : (total_proposals, num_classes*4)
    '''
    def box_head(self, feature_vectors):
        X = self.intermediate(feature_vectors)
        class_prob = self.classifier(X)
        bbox_reg = self.regressor(X)

        return class_prob, bbox_reg

    '''
    input : feature_vectors : (total_proposals, 256*P^2) (bs*top_k_proposals = total_proposals)
    outputs : class_prob : (total_proposals, num_classes+1)
              bbox_reg : (total_proposals, num_classes*4)
    '''
    def forward(self, feature_vectors):
        #we do not want to retrain our backbone and rpn so even when eval=False for the rest of the network, it is true for backbone and rpn
        #pass feature vectors to box head and get class prob and bbox_regressionss
        class_prob, bbox_reg = self.box_head(feature_vectors)
        #print(class_prob.shape)
        #print(bbox_reg.shape)
        return class_prob, bbox_reg
    
    '''
    inputs : bboxes :  (n_objects, 4)
             proposals : (top_K proposals, 4) 
             labels :  (n_objects,)
    outputs : gt_class : (top_K proposals, num_classes+1)
              gt_coord : (top_K proposals, num_classes*4)
    note :
    defined for 1 image
    bboxes are in the form : x1,y1,x2,y2
    proposals are in the form : x1,y1,x2,y2
    '''
    def create_ground_truth(self, bboxes, proposals, labels):
        """Build classification and box-regression targets for the proposals of one image.

        inputs  : bboxes    : (n_objects, 4) ground truth boxes, x1,y1,x2,y2
                  proposals : (n_proposals, 4) proposal boxes, x1,y1,x2,y2 (NOT modified)
                  labels    : (n_objects,) class index of every ground truth box
        outputs : gt_class   : (n_proposals, num_classes+1) one-hot rows, column 0 = background
                  gt_coord   : (n_proposals, num_classes*4) the proposal itself as (x, y, w, h),
                               stored in the four columns of its assigned class
                  en_gtcoord : (n_proposals, num_classes*4) the matched ground truth box encoded
                               against the proposal as (tx, ty, tw, th), same column layout
        A proposal is positive when its best IoU with any ground truth box is >= 0.5 and it
        then takes the label of that box; every other proposal is background and keeps
        all-zero coordinate rows.
        """
        num_props = proposals.shape[0]
        gt_class = torch.zeros((num_props, self.num_classes+1)).to(device)
        gt_coord = torch.zeros((num_props, self.num_classes*4)).to(device)
        en_gtcoord = torch.zeros((num_props, self.num_classes*4)).to(device)

        #nothing to match against: every proposal is background
        if num_props == 0 or bboxes.shape[0] == 0:
            gt_class[:, 0] = 1
            return gt_class, gt_coord, en_gtcoord

        bboxes = bboxes.to(proposals.device)
        labels = labels.to(proposals.device).long()

        #calculate iou, shape is (n_proposals, n_obj)  #  ORDER IS IMPORTANT!
        iou = box_iou(proposals, bboxes)
        max_iou, idx = torch.max(iou, dim=1)
        chosen_idx = max_iou >= 0.5

        #ASSIGNING GROUND TRUTH LABELS
        #one label per proposal (0 = background) so every row gets exactly one 1
        assigned = torch.zeros(num_props, dtype=torch.long, device=proposals.device)
        assigned[chosen_idx] = labels[idx[chosen_idx]]
        gt_class[torch.arange(num_props, device=gt_class.device), assigned.to(gt_class.device)] = 1

        #ASSIGNING GROUND TRUTH BBOXES
        def to_xywh(boxes):
            #x1,y1,x2,y2 -> centre x, centre y, width, height, returned as a NEW tensor
            return torch.stack(((boxes[:, 0] + boxes[:, 2])/2.0,
                                (boxes[:, 1] + boxes[:, 3])/2.0,
                                boxes[:, 2] - boxes[:, 0],
                                boxes[:, 3] - boxes[:, 1]), dim=1)

        matched_xywh = to_xywh(bboxes)[idx]   #best-overlapping gt box of every proposal
        prop_xywh = to_xywh(proposals)        #the caller's proposals stay in x1,y1,x2,y2

        #only the non background rows are filled in
        pos_rows = (assigned > 0).nonzero().flatten()
        if pos_rows.numel() > 0:
            pos_prop = prop_xywh[pos_rows]
            pos_gt = matched_xywh[pos_rows]
            #encode to tx,ty,tw,th
            encoded = torch.stack(((pos_gt[:, 0] - pos_prop[:, 0])/pos_prop[:, 2],
                                   (pos_gt[:, 1] - pos_prop[:, 1])/pos_prop[:, 3],
                                   torch.log(pos_gt[:, 2]/pos_prop[:, 2]),
                                   torch.log(pos_gt[:, 3]/pos_prop[:, 3])), dim=1)
            #class c owns columns (c-1)*4 .. c*4-1
            cols = (assigned[pos_rows] - 1).unsqueeze(1)*4 + torch.arange(4, device=assigned.device)
            rows = pos_rows.unsqueeze(1).to(gt_coord.device)
            cols = cols.to(gt_coord.device)
            gt_coord[rows, cols] = pos_prop.to(gt_coord)
            en_gtcoord[rows, cols] = encoded.to(en_gtcoord)

        return gt_class, gt_coord, en_gtcoord
    
    '''
    inputs : bboxes :  list(bs)(n_objects, 4)
             proposals : list(bs)(top_K proposals, 4) 
             labels :  list (bs)(n_objects,)
    outputs : gt_class : (total_proposals, num_classes+1)
              gt_coord : (total_proposals, num_classes*4) (bs*top_k_proposals = total_proposals)
    '''
    def create_batch_truth(self, bboxes, proposals, labels):
        gt_class_batch = []
        gt_coord_batch = []
        for b in range(len(proposals)): #iterate over each image in the batch
          gt_class, _, gt_coord = self.create_ground_truth(bboxes[b], proposals[b], labels[b])
          gt_class_batch.append(gt_class)
          gt_coord_batch.append(gt_coord)
        gt_class_batch = torch.vstack(gt_class_batch)
        gt_coord_batch = torch.vstack(gt_coord_batch)

        return gt_class_batch, gt_coord_batch
        
    '''
    inputs : class_prob : (total_proposals, num_classes+1)
             bbox_reg : (total_proposals, num_classes*4)
             gt_class : (total_proposals, num_classes+1)
             gt_coord : (total_proposals, num_classes*4) (bs*top_k_proposals = total_proposals) these are encoded
    outputs : loss, loss_class, loss_reg
    notes : loss is for full batch of bs images
            labels : vehicle, human, animal means positive; background means negative
            make positive and negative stuff only for classifier loss
            use only positive classifier stuff for regressor loss
            effective batch is basically a subset taken from the batch of proposals (a batch sampled from total_proposals)
    '''
    def compute_loss(self, class_prob, bbox_reg, gt_class, gt_coord, l=1, effective_batch=50):
        #computing classifier loss on both positive and negative labels
        lbl = torch.where(gt_class==1)[1]
        l_p = (lbl!=0).nonzero().flatten()
        l_n = (lbl==0).nonzero().flatten()
        p_size = min(l_p.shape[0], int(effective_batch*3/4))
        n_size = min(l_n.shape[0], int(effective_batch/4))
        p_id = torch.randperm(l_p.shape[0])[:p_size]
        p_id = l_p[p_id]
        n_id = torch.randperm(l_n.shape[0])[:n_size]
        n_id = l_n[n_id]
        p_class_prob = class_prob[p_id,:]   #wherever the column element in zeroth column is NOT 1, take those corresponding rows and use the postive mask (some object is detected)
        n_class_prob = class_prob[n_id,:]   #wherever the column element in zeroth column is 1, take those corresponding rows and use the negative mask (this is the background)
        p_gt_class = gt_class[p_id,:]
        n_gt_class = gt_class[n_id,:]

        criterion_class = torch.nn.CrossEntropyLoss()
        #doing pos and neg wise CE loss and summing
        loss_class_p = criterion_class(p_class_prob, p_gt_class)
        loss_class_n = criterion_class(n_class_prob, n_gt_class)
        loss_class = loss_class_p + loss_class_n
        #print(loss_class)
        
        #computing regressor loss only on the positive bboxes 
        p_bbox_reg = bbox_reg[p_id,:]
        p_gt_coord = gt_coord[p_id,:]
        criterion_reg = torch.nn.SmoothL1Loss()   #default reduction is mean so we do not have to divide by N_reg
        #doing class label wise smooth L1 loss and then summing
        loss_reg_1 = sum([criterion_reg(p_bbox_reg[:,i], p_gt_coord[:,i]) for i in range(4)])     #vehicle
        loss_reg_2 = sum([criterion_reg(p_bbox_reg[:,i], p_gt_coord[:,i]) for i in range(4,8)])   #human
        loss_reg_3 = sum([criterion_reg(p_bbox_reg[:,i], p_gt_coord[:,i]) for i in range(8,12)])  #animal
        loss_reg = loss_reg_1 + loss_reg_2 + loss_reg_3
        #print(loss_reg)

        #computing total loss
        loss = loss_class + l * loss_reg

        return loss, loss_class, loss_reg

    '''
    for 1 image only (with multiple bboxes and corresponding confidence scores and labels)
    inputs : conf len(bs) (keep_num_preNMS)
             bboxes len(bs) (keep_num_preNMS,4)
             lbl len(bs) (keep_num_preNMS)
    outputs : conf len(bs) (keep_num_postNMS)
             bboxes len(bs) (keep_num_postNMS,4)
             lbl len(bs) (keep_num_postNMS)
    '''
    def NMS(self, conf, bbox, lbl, thresh):
        """Greedy non-maximum suppression for the detections of one image.

        inputs  : conf   : (n,) confidence scores
                  bbox   : (n, 4) boxes, x1,y1,x2,y2
                  lbl    : (n,) labels
                  thresh : a box is dropped when its IoU with an already kept box is > thresh
        outputs : conf, bbox, lbl of the kept detections, in their input order
        notes   : boxes are visited in the order given, so the caller has to pass them sorted
                  by descending confidence. Labels are not compared: suppression is done
                  across classes. The returned tensors are detached copies; for an empty
                  input the three inputs are returned unchanged.
        """
        num_boxes = len(bbox)
        if num_boxes == 0:
            return conf, bbox, lbl

        boxes = bbox.clone().detach().reshape(-1, 4)
        #pairwise overlaps computed once instead of one box_iou call per pair
        overlaps = (box_iou(boxes, boxes) > thresh).tolist()

        suppressed = [False] * num_boxes
        keep = []
        for i in range(num_boxes):
            if suppressed[i]:
                continue
            keep.append(i)
            #mark instead of deleting: removing items from a list while iterating over it
            #skips the element that follows each removal
            for j in range(i + 1, num_boxes):
                if overlaps[i][j]:
                    suppressed[j] = True

        keep = torch.tensor(keep, dtype=torch.long, device=boxes.device)
        return conf.detach()[keep], boxes[keep], lbl.detach()[keep]

    '''
    inputs : proposals len(bs) (top_K_proposals,4)
             classprob (total_proposals,C+1)    total_proposals = topK_proposals * bs
             bbox_reg (total_proposals, C*4)
    outputs : bboxes len(bs) (keep_num_postNMS,4)
              scores len(bs) (keep_num_postNMS)
              labels len(bs) (keep_num_postNMS)
              #same things for (keep_num_preNMS)
    '''
    def post_process(self, class_prob, bbox_reg, proposals, conf_thresh = 0.2, keep_num_preNMS = 10, keep_num_postNMS = 5, keep_topK = 200):
        """Turn the box head outputs into per-image detections, before and after NMS.

        inputs  : class_prob : (total_proposals, C+1) class scores, rows ordered image by image
                  bbox_reg   : (total_proposals, C*4) encoded box regressions, same row order
                  proposals  : len(bs) list of (n_proposals, 4) boxes, x1,y1,x2,y2
                  conf_thresh      : minimum winning score for a proposal to be kept
                  keep_num_preNMS  : max detections per image handed to NMS
                  keep_num_postNMS : max detections per image returned after NMS
                  keep_topK        : rows per image in class_prob / bbox_reg; only used when the
                                     row count does not match the lengths of `proposals`
        outputs : prenms_bboxes, prenms_scores, prenms_labels,
                  postnms_bboxes, postnms_scores, postnms_labels   (each a list of len bs)
        notes   : scores are compared against conf_thresh exactly as given, so the caller has
                  to apply a softmax first if probabilities are wanted.
                  decoded boxes are clamped to a 1088 x 800 (width x height) image.
        """
        postnms_bboxes = []
        postnms_scores = []
        postnms_labels = []
        prenms_bboxes = []
        prenms_scores = []
        prenms_labels = []

        #rows belonging to every image: follow the real proposal counts when they add up,
        #otherwise fall back to the fixed keep_topK stride
        counts = [len(prop) for prop in proposals]
        if sum(counts) != class_prob.shape[0]:
            counts = [keep_topK] * len(proposals)

        row_start = 0
        for prop, count in zip(proposals, counts):
          row_end = row_start + int(count)
          c_prob = class_prob[row_start:row_end, :]   #(n_proposals,C+1)
          b_reg = bbox_reg[row_start:row_end, :]      #(n_proposals,C*4)
          row_start = row_end

          lbl = torch.argmax(c_prob, dim=1)
          conf_max = torch.max(c_prob, dim=1)[0]
          #confident enough and not background
          mask = torch.logical_and(conf_max>conf_thresh, lbl>0)
          m_conf, m_lbl, m_b_reg, m_prop = conf_max[mask], lbl[mask], b_reg[mask], prop[mask]
          #keep the best scoring ones, in descending score order
          keep = min(keep_num_preNMS, len(m_conf))
          kept_idx = torch.argsort(m_conf, dim = 0, descending=True)[:keep]
          keep_conf, keep_lbl, keep_b_reg, keep_prop = m_conf[kept_idx], m_lbl[kept_idx], m_b_reg[kept_idx], m_prop[kept_idx]

          #(keep, C*4) -> (keep, 4): pick the four columns of the predicted class
          cols = (keep_lbl - 1).unsqueeze(1)*4 + torch.arange(4, device=keep_lbl.device)
          deltas = keep_b_reg.gather(1, cols)

          #decode against the proposal (x,y,w,h form), then back to x1,y1,x2,y2
          xp = (keep_prop[:, 0] + keep_prop[:, 2])/2.0
          yp = (keep_prop[:, 1] + keep_prop[:, 3])/2.0
          wp = (keep_prop[:, 2] - keep_prop[:, 0])
          hp = (keep_prop[:, 3] - keep_prop[:, 1])
          x = (deltas[:,0] * wp) + xp
          y = (deltas[:,1] * hp) + yp
          w = torch.exp(deltas[:,2]) * wp
          h = torch.exp(deltas[:,3]) * hp
          #every corner is clamped on both sides so no coordinate leaves the image
          decoded = torch.stack((torch.clamp(x-w/2, min=0, max=1088),    #x1
                                 torch.clamp(y-h/2, min=0, max=800),     #y1
                                 torch.clamp(x+w/2, min=0, max=1088),    #x2
                                 torch.clamp(y+h/2, min=0, max=800)),    #y2
                                dim=1)

          #collect all prenms bboxes, scores, labels
          prenms_bboxes.append(decoded)
          prenms_scores.append(keep_conf)
          prenms_labels.append(keep_lbl)
          #do NMS
          after_conf, after_bboxes, after_labels = self.NMS(keep_conf, decoded, keep_lbl, 0.5)
          #collect all postnms bboxes, scores, labels
          after_keep = min(keep_num_postNMS, len(after_conf))
          postnms_bboxes.append(after_bboxes[:after_keep])
          postnms_scores.append(after_conf[:after_keep])
          postnms_labels.append(after_labels[:after_keep])

        #return pre and post nms bboxes, scores, labels
        return prenms_bboxes, prenms_scores, prenms_labels, postnms_bboxes, postnms_scores, postnms_labels

## BoxHead Training

In [None]:
def pretrained_models_680(checkpoint_file, eval=True):
    """Return the backbone and RPN of a torchvision Mask R-CNN loaded from a checkpoint.

    checkpoint_file : path of a checkpoint holding 'backbone' and 'rpn' state dicts
    eval            : when True the model, and with it backbone and RPN, is put in eval mode
    returns         : (backbone, rpn), both on `device`; the RPN nms_thresh is set to 0.6
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)
    if(eval):
        model.eval()
    model.to(device)
    backbone = model.backbone
    rpn = model.rpn

    if(eval):
        backbone.eval()
        rpn.eval()

    rpn.nms_thresh=0.6
    # map_location: a checkpoint saved on a GPU can still be restored on a CPU-only runtime
    checkpoint = torch.load(checkpoint_file, map_location=device)
    backbone.load_state_dict(checkpoint['backbone'])
    rpn.load_state_dict(checkpoint['rpn'])

    return backbone, rpn

In [None]:
# Train the box head on top of the frozen, pretrained backbone + RPN and validate after every epoch.
network_obj = BoxHead().to(device)
optimizer = torch.optim.Adam(network_obj.parameters(), lr = 0.001)

train_loss = []
train_loss_c = []
train_loss_r = []
val_loss = []
val_loss_c = []
val_loss_r = []
num_epochs = 20
keep_topK = 200
path = '/content/drive/MyDrive/CIS680/Final/checkpoint680.pth'
print(path)
backbone, rpn1 = pretrained_models_680(path)
print(backbone)


def _box_head_batch_loss(images, bboxes, labels):
  """Run one batch through backbone, RPN and box head; returns (loss, loss_class, loss_reg)."""
  images = images.to(device)
  #backbone and RPN are not trained here, so no autograd graph is built for them
  with torch.no_grad():
    #pass through backbone
    backout = backbone(images)
    im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
    #pass through rpn, get proposals and fpn_feat_list
    rpnout = rpn1(im_lis, backout)
    proposals = [proposal[0:keep_topK,:] for proposal in rpnout[0]]
    fpn_feat_list = list(backout.values())
    #pool the features BEFORE building the targets, so the pooling always sees the
    #proposals exactly as the RPN produced them
    feature_vectors = network_obj.find_feature_vectors(proposals, fpn_feat_list)
    #get gt
    gt_class_batch, gt_coord_batch = network_obj.create_batch_truth(bboxes, proposals, labels)
  #pass through forward
  class_prob, bbox_reg = network_obj.forward(feature_vectors.detach())      #TO ENSURE NO RETRAINING OF THE BACKBONE AND RPN
  #CE loss already does softmax so we dont add it in the classifier network during training
  return network_obj.compute_loss(class_prob, bbox_reg, gt_class_batch.to(device), gt_coord_batch.to(device), effective_batch=150)


for epoch in tqdm(range(num_epochs)):
  run_loss = 0
  run_loss_c = 0
  run_loss_r = 0
  #train the model
  network_obj.train()
  for i, (images, labels, masks, bboxes, indexes) in enumerate(train_loader):
    #zero out optimizer
    optimizer.zero_grad()
    #calc the loss
    loss, loss_class, loss_reg = _box_head_batch_loss(images, bboxes, labels)
    #gather loss values
    run_loss += loss.item()
    run_loss_c += loss_class.item()
    run_loss_r += loss_reg.item()
    #backward pass
    loss.backward()
    optimizer.step()
    #release the graph of this batch
    del loss, loss_class, loss_reg
    torch.cuda.empty_cache()

  #average over the batches that were actually iterated
  train_loss.append(run_loss/len(train_loader))
  train_loss_c.append(run_loss_c/len(train_loader))
  train_loss_r.append(run_loss_r/len(train_loader))

  # checkpoint
  if(epoch%2==0):
    path = "/content/drive/MyDrive/CIS680/Final/Box1812/epoch-" +str(epoch+1)
    torch.save({
        'epoch': epoch+1,
        'model_state_dict': network_obj.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': train_loss[-1],
        'clas_loss': train_loss_c[-1],
        'regr_loss': train_loss_r[-1]
        }, path)

  #validating the model
  network_obj.eval()
  run_eval_loss = 0
  run_eval_loss_c = 0
  run_eval_loss_r = 0
  with torch.no_grad():   #validation never calls backward
    for j, (images, labels, masks, bboxes,indexes) in enumerate(test_loader):
      vloss, vloss_class, vloss_reg = _box_head_batch_loss(images, bboxes, labels)
      run_eval_loss += vloss.item()
      run_eval_loss_c += vloss_class.item()
      run_eval_loss_r += vloss_reg.item()
      del vloss, vloss_class, vloss_reg
      torch.cuda.empty_cache()

  val_loss.append(run_eval_loss/len(test_loader))
  val_loss_c.append(run_eval_loss_c/len(test_loader))
  val_loss_r.append(run_eval_loss_r/len(test_loader))

In [None]:
# Show the per-epoch total training loss collected by the training cell (notebook cell output).
train_loss

In [None]:
# Total training loss (left) and total validation loss (right), one point per epoch.
fig, ax = plt.subplots(1, 2, figsize=(10,4))

for panel, curve, panel_title in ((ax[0], train_loss, 'Total Train Loss'),
                                  (ax[1], val_loss, 'Total Validation Loss')):
  panel.plot(curve)
  panel.set_title(panel_title)
  panel.set_xlabel('epochs')
  panel.set_ylabel('loss')
  panel.grid()

In [None]:
# Load an externally trained box head whose state dict uses different sub-module names,
# renaming its keys in place to the attribute names used by BoxHead.
path = "/content/model_trained_boxhead.pth"
checkpoint = torch.load(path)
boxhead_network_obj = BoxHead().to(device)
# odict_keys(['intermediate_layer.0.weight', 'intermediate_layer.0.bias', 'intermediate_layer.2.weight', 'intermediate_layer.2.bias', 'classifier.0.weight', 'classifier.0.bias', 'regressor.0.weight', 'regressor.0.bias'])
key_renames = (('intermediate_layer', 'intermediate'),
               ('classifier_head', 'classifier'),
               ('regressor_head', 'regressor'))
for stored_key in list(checkpoint.keys()):
    renamed_key = stored_key
    for old_part, new_part in key_renames:
        renamed_key = renamed_key.replace(old_part, new_part)
    # pop + re-insert keeps the entries in their original order
    checkpoint[renamed_key] = checkpoint.pop(stored_key)
print(checkpoint.keys())
boxhead_network_obj.load_state_dict(checkpoint, strict = False)

In [None]:
# Run the box head over the test set and draw, for every image, the decoded box of each
# proposal that is not classified as background (red = vehicle, green = human, blue = animal).
topK = 200
network_obj.eval()
softmax = nn.Softmax(dim=1)   #normalise over the class dimension
box_colors = {1: 'r', 2: 'g', 3: 'b'}   #vehicle, human, animal

with torch.no_grad():
  for i, (images, labels, masks, bboxes, indexes) in enumerate(test_loader):
    backout = backbone(images.to(device))   #input moved to the device of the model
    im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
    rpnout = rpn1(im_lis, backout)
    proposals=[proposal[0:topK,:].to(device) for proposal in rpnout[0]]
    fpn_feat_list= list(backout.values())
    #get feature vectors
    feature_vectors = network_obj.find_feature_vectors(proposals, fpn_feat_list)
    class_prob, bbox_reg = network_obj(feature_vectors.detach())      #TO ENSURE NO RETRAINING OF THE BACKBONE AND RPN
    cls_soft = softmax(class_prob)    #do softmax now as we did not do it in forward function
    #get max confidence score and corresponding index
    print("class_prob = \n", class_prob)
    cls_with_softmax, max_indices = torch.max(cls_soft,dim=1)
    print("max_indices = ", max_indices)
    row_start = 0   #first row of the current image in the stacked predictions
    for img_idx in range(len(images)):
        #display the image
        fig = plt.figure()
        ax = fig.add_subplot()
        ax.imshow(images[img_idx].detach().cpu().numpy().transpose(1,2,0))
        print("___________")
        #rows of this image: as many as it has proposals
        num_props = len(proposals[img_idx])
        lbl = max_indices[row_start:row_start+num_props]
        img_reg = bbox_reg[row_start:row_start+num_props]
        row_start += num_props
        #iterate over proposals
        print("Proposals = ", num_props)
        for j in range(num_props):
          cls_id = int(lbl[j])
          #if label is 0, it is background so ignore
          if(cls_id==0):
            print("Background proposal ", j)
            continue
          #get single proposal and the regression of its predicted class
          edge_ind = (img_reg[j][(cls_id-1)*4 : (cls_id-1)*4+4]).cpu()
          prop = (proposals[img_idx][j]).cpu()
          #decode the bounding box with corresponding proposal
          tx, ty, tw, th = edge_ind[0], edge_ind[1], edge_ind[2], edge_ind[3]
          xp, yp, wp, hp = (prop[0]+prop[2])/2, (prop[1]+prop[3])/2, prop[2]-prop[0], prop[3]-prop[1]
          x, y, w, h = tx*wp+xp, ty*hp+yp, torch.exp(tw)*wp, torch.exp(th)*hp
          x1, y1 = x-(w/2), y-(h/2)
          print(x,y,w,h)
          print("--------------------")
          #plot the predicted box in the colour of its class
          rect = rec((float(x1), float(y1)), float(w), float(h), fill=False, color=box_colors[cls_id], linewidth=2)
          ax.add_patch(rect)
        plt.show()

# Part 6 : Mask Head - Raima

## class MaskHead definition

In [None]:
class MaskHead(nn.Module):
    def __init__(self):
        super(MaskHead, self).__init__()
        self.num_classes = 3
        self.P = 14
        self.layers1to4 = nn.Sequential( nn.Conv2d(in_channels=256 , out_channels= 256, kernel_size= 3, padding= 'same'),
                                         nn.ReLU(),
                                         nn.Conv2d(in_channels=256 , out_channels= 256, kernel_size= 3, padding= 'same'),
                                         nn.ReLU(),
                                         nn.Conv2d(in_channels=256 , out_channels= 256, kernel_size= 3, padding= 'same'),
                                         nn.ReLU(),
                                         nn.Conv2d(in_channels=256 , out_channels= 256, kernel_size= 3, padding= 'same'),
                                         nn.ReLU(),
                                          )
        self.deconv = nn.Sequential( nn.ConvTranspose2d(in_channels=256 , out_channels= 256, kernel_size= 3, stride=2, padding= 1, output_padding=1),
                                    nn.ReLU())
        self.conv = nn.Sequential(nn.Conv2d(in_channels=256 , out_channels= self.num_classes, kernel_size= 1),
                                  nn.Sigmoid())
        
    '''
    Pools one P x P feature patch per proposal for a whole batch.
    Input : batch_proposals: list of length batch_size; each tensor has shape (top_k_proposals, 4) [format: (x1, y1, x2, y2)]
            fpn_feat_list: list of FPN feature maps, one per level
    Output: feature_vectors: (bs*top_k_proposals, 256, P, P) = (total_proposals, 256, P, P)
    '''
    def find_feature_vectors(self, batch_proposals, fpn_feat_list):
        #regardless of number of objects in each image, there would only be top k proposal boxes fixed. the top_k value is up to us.
        feature_vectors = []
        #loop over batch
        for i in range(len(batch_proposals)):
          #loop over each proposal
          for p in range(len(batch_proposals[i])):
            pb = batch_proposals[i][p]    #these will be 1 proposal
            #convert proposal box from x1,y1,x2,y2 to x,y,w,h format
            x = (pb[0] + pb[2])/2.0
            y = (pb[1] + pb[3])/2.0
            w = pb[2] - pb[0]
            h = pb[3] - pb[1]
            #choose appropriate fpn level to pool features
            k = torch.clamp(torch.floor(4 + torch.log2(torch.sqrt(w*h)/224)), 2, 5)
            fpn_level = (k - 2.0).int()     #this is a single integer value
            #region of proposal box is given in image coords but we need to change them to feature map coords
            # pdb.set_trace()
            scale_x = fpn_feat_list[fpn_level].shape[3]/1088
            scale_y = fpn_feat_list[fpn_level].shape[2]/800
            #scale the proposal box to feature map coords
            p_scaled = pb.reshape(1,-1).clone()    #(1,4)
            p_scaled[:,0] = p_scaled[:,0] * scale_x
            p_scaled[:,1] = p_scaled[:,1] * scale_y
            p_scaled[:,2] = p_scaled[:,2] * scale_x
            p_scaled[:,3] = p_scaled[:,3] * scale_y
            #do ROI align of feature map and proposal box (both in x1,y1,x2,y2 format)
            fv  = torchvision.ops.roi_align(fpn_feat_list[fpn_level][i].unsqueeze(0), [p_scaled.to(device)] , output_size=self.P, spatial_scale=1,sampling_ratio=-1)
            # print("------\nIn Mask Head\nfv = ", fv.shape)
            # print("------------")
            feature_vectors.append(fv)  #appended (1 * 256 * P * P) 
          
        feature_vectors = torch.vstack(feature_vectors) #should be (bs * topK, 256*P^2)
        return feature_vectors
      # This function decodes the output that is given in the encoded format (defined in the handout)
      # into box coordinates where it returns the upper left and lower right corner of the proposed box
      # Input:
      #       flatten_out: (total_number_of_anchors*bz,4)
      #       flatten_anchors: (total_number_of_anchors*bz,4)
      # Output:
      #       box: (total_number_of_anchors*bz,4)
    def output_decoding_postprocess(self, flatten_out,flatten_anchors, device=device):
          #######################################
          # TODO decode the output
          
          fin_box = torch.zeros_like(flatten_anchors)
          fin_box[:,0] = (flatten_anchors[:,0] + flatten_anchors[:,2]) / 2
          fin_box[:,1] = (flatten_anchors[:,1] + flatten_anchors[:,3]) / 2
          fin_box[:,2] = (flatten_anchors[:,2] - flatten_anchors[:,0])
          fin_box[:,3] = (flatten_anchors[:,3] - flatten_anchors[:,1])
          flatten_anchors = fin_box

          conv_box = torch.zeros_like(flatten_anchors).to(device)
                   
          conv_box[:,3] = torch.exp(flatten_out[:,3]) * flatten_anchors[:,3]
          conv_box[:,2] = torch.exp(flatten_out[:,2]) * flatten_anchors[:,2]
          conv_box[:,1] = (flatten_out[:,1] * flatten_anchors[:,2]) + flatten_anchors[:,1]
          conv_box[:,0] = (flatten_out[:,0] * flatten_anchors[:,3]) + flatten_anchors[:,0]

          # COnvert from xywh to x1 y1 x2 y2
          box = torch.zeros_like(conv_box)
          box[:,0] = conv_box[:,0] - (conv_box[:,2]/2)
          box[:,1] = conv_box[:,1] - (conv_box[:,3]/2)
          box[:,2] = conv_box[:,0] + (conv_box[:,2]/2)
          box[:,3] = conv_box[:,1] + (conv_box[:,3]/2)
          return box

    def output_decoding(self, flatten_out, flatten_anchors, device='cpu'):
        # This function decodes the output that are given in the [t_x,t_y,t_w,t_h] format
        # into box coordinates where it returns the upper left and lower right corner of the bbox
        # Input:
        #       flatten_out: (total_number_of_anchors*bz,4)
        #       flatten_anchors: (total_number_of_anchors*bz,4)
        # Output:
        #       box: (total_number_of_anchors*bz,4)
        conv_box = torch.zeros_like(flatten_anchors)
        conv_box[:,3] = torch.exp(flatten_out[:,3]) * flatten_anchors[:,3]
        conv_box[:,2] = torch.exp(flatten_out[:,2]) * flatten_anchors[:,2]
        conv_box[:,1] = (flatten_out[:,1] * flatten_anchors[:,3]) + flatten_anchors[:,1]
        conv_box[:,0] = (flatten_out[:,0] * flatten_anchors[:,2]) + flatten_anchors[:,0]

        # box = conv_box_to_corners(conv_box)
        box = torch.zeros_like(conv_box)
        box[:,0] = conv_box[:,0] - (conv_box[:,2]/2)
        box[:,1] = conv_box[:,1] - (conv_box[:,3]/2)
        box[:,2] = conv_box[:,0] + (conv_box[:,2]/2)
        box[:,3] = conv_box[:,1] + (conv_box[:,3]/2)

        return box
 

    def preprocess_ground_truth_creation(self, proposals, class_logits, box_regression, gt_labels,bbox ,masks , IOU_thresh=0.5, keep_num_preNMS=1000, keep_num_postNMS=10):
        '''This function does the pre-prossesing of the proposals created by the Box Head (during the training of the Mask Head)
        and create the ground truth for the Mask Head
        
        Input:
              class_logits: (total_proposals,(C+1))
              box_regression: (total_proposal,4*C)           ([t_x,t_y,t_w,t_h] format)
              proposals: list:len(bz)(per_image_proposals,4) (the proposals are produced from RPN [x1,y1,x2,y2] format)
              conf_thresh: scalar
              keep_num_preNMS: scalar (number of boxes to keep pre NMS)
              keep_num_postNMS: scalar (number of boxes to keep post NMS)
        Output:
              boxes: list:len(bz){(post_NMS_boxes_per_image,4)} ([x1,y1,x2,y2] format)
              scores: list:len(bz){(post_NMS_boxes_per_image)}   ( the score for the top class for the regressed box)
              labels: list:len(bz){(post_NMS_boxes_per_image)}  (top category of each regressed box)
              gt_masks: list:len(bz){(post_NMS_boxes_per_image,C,2*P,2*P)}'''
        
        
        num_proposals = proposals[0].shape[0]
        # print("scores  : ", class_logits)
        boxes = []
        scores = []
        labels = []
        gt_masks = []
        for i, each_proposal in enumerate(proposals):
            
            each_proposal = each_proposal.to(device)
            box_regression = box_regression.to(device)
            class_logits = class_logits.to(device)
            # gt_labels[i] = gt_labels[i] - 1
            one_image_boxes = box_regression[i*num_proposals:(i+1)*num_proposals]          # Shape (num_proposals, 12)
            one_image_logits = class_logits[i*num_proposals:(i+1)*num_proposals]           # Shape (num_proposals, 4)
            one_image_scores, one_image_label = torch.max(one_image_logits, dim=1)
            one_image_label = one_image_label.clone().int() - 1
            non_bg_label_idx = torch.where(one_image_label >= 0)[0]

            if len(non_bg_label_idx) != 0: 
                class_labels = one_image_label[non_bg_label_idx]
                all_class_boxes = one_image_boxes[non_bg_label_idx]

                #Get the boxes corresponding to the labels 

                class_boxes =  torch.stack([all_class_boxes[i, x*4:(x+1)*4] for i, x in enumerate(class_labels)])      # Shape(filtered_labels, 4) ([t_x,t_y,t_w,t_h])
                decoded_boxes = self.output_decoding_postprocess(class_boxes, each_proposal[non_bg_label_idx])                          # (x1,y1,x2,y2)
                
                valid_boxes_idx = torch.where((decoded_boxes[:,0] >= 0) & (decoded_boxes[:,2] < 1088) & (decoded_boxes[:,1] > 0) & (decoded_boxes[:,3] < 800))
                valid_boxes = decoded_boxes[valid_boxes_idx]

                valid_clases = one_image_label[non_bg_label_idx][valid_boxes_idx]
                valid_scores = one_image_scores[non_bg_label_idx][valid_boxes_idx]
                sorted_scores_pre_nms, sorted_idx = torch.sort(valid_scores, descending=True)
                sorted_clases_pre_nms = valid_clases[sorted_idx]

                #Rearrange the boxes from x2 y2 x1 y1

                sorted_boxes_pre_nms = sorted_boxes_pre_nms[:, [2,3,0,1]]

                sorted_boxes_pre_nms = valid_boxes[sorted_idx]

                iou_check = box_iou(sorted_boxes_pre_nms, bbox[i])
                iou_idx = (iou_check > 0.3).nonzero()
                above_thres_idx = iou_idx[:,0]
                above_thres_gt = iou_idx[:,1]


                masks_gt_all = masks[i][above_thres_gt]
                                
                sorted_boxes_pre_nms = sorted_boxes_pre_nms[above_thres_idx]
                sorted_clases_pre_nms = sorted_clases_pre_nms[above_thres_idx]
                sorted_scores_pre_nms = sorted_scores_pre_nms[above_thres_idx]

                if len(sorted_clases_pre_nms) > keep_num_preNMS:
                    clases_pre_nms = sorted_clases_pre_nms[:keep_num_preNMS]
                    boxes_pre_nms = sorted_boxes_pre_nms[:keep_num_preNMS]
                    scores_pre_nms = sorted_scores_pre_nms[:keep_num_preNMS]
                    masks_pre_nms = masks_gt_all[:keep_num_preNMS]
                else:
                    clases_pre_nms = sorted_clases_pre_nms
                    boxes_pre_nms = sorted_boxes_pre_nms
                    scores_pre_nms = sorted_scores_pre_nms
                    masks_pre_nms = masks_gt_all
                pdb.set_trace()
                clases_post_nms, scores_post_nms, boxes_post_nms, masks_post_nms = self.nms_preprocess_gt(clases_pre_nms, boxes_pre_nms, scores_pre_nms, masks_pre_nms, IOU_thres=IOU_thresh, keep_num_postNMS=keep_num_postNMS)
              
                gt_mask_one = torch.zeros(clases_post_nms.shape[0],self.image_size[0], self.image_size[1]).to(device)

                for j in range(clases_post_nms.shape[0]):
                    b0 = boxes_post_nms[j,0].int()
                    b1 = boxes_post_nms[j,1].int()
                    b2 = boxes_post_nms[j,2].int()
                    b3 = boxes_post_nms[j,3].int()

                    gt_mask_one[j , b1:b3 , b0:b2] = 1
                gt_mask_one = gt_mask_one * masks_post_nms
                gt_mask_one = F.interpolate(gt_mask_one.unsqueeze(0), size=(2*self.P,2*self.P),mode='nearest').squeeze(0)
            
            gt_masks.append(gt_mask_one)
            boxes.append(boxes_post_nms)
            scores.append(scores_post_nms)
            labels.append(clases_post_nms)

        return boxes, scores, labels, gt_masks

    def nms_preprocess_gt(self,clases,boxes,scores, masks, IOU_thres=0.5, keep_num_postNMS=100):
        # Input:
        #       clases: (num_preNMS, )
        #       boxes:  (num_preNMS, 4)
        #       scores: (num_preNMS,)
        # Output:
        #       boxes:  (post_NMS_boxes_per_image,4) ([x1,y1,x2,y2] format)
        #       scores: (post_NMS_boxes_per_image)   ( the score for the top class for the regressed box)
        #       labels: (post_NMS_boxes_per_image)  (top category of each regressed box)
        
        # TODO - NMS for the given classes, boxes and labels 

        boxes = boxes.to(device)
        clases = clases.to(device)
        scores = scores.to(device)
        scores_all = [[],[],[]]
        boxes_all = [[],[],[]]
        clas_all = [[],[],[]]
        masks_all = [[],[],[]]

        for i in range(3):
            each_label_idx = torch.where(clases == i)[0]
            if len(each_label_idx) == 0: #Ensures empty list issues are taken care of
              continue
            each_clas_boxes = boxes[each_label_idx]
            each_clas_score = scores[each_label_idx]

            start_x_torched = each_clas_boxes[:, 0]
            start_y_torched = each_clas_boxes[:, 1]
            end_x_torched   = each_clas_boxes[:, 2]
            end_y_torched   = each_clas_boxes[:, 3]

            areas_torched = (end_x_torched - start_x_torched + 1) * (end_y_torched - start_y_torched + 1)

            order_torched = torch.argsort(each_clas_score)

            while len(order_torched) > 0:
                # The index of largest confidence score
                index = order_torched[-1]
                
                # Pick the bounding box with largest confidence score
                boxes_all[i].append(boxes[index].detach())
                scores_all[i].append(each_clas_score[index].detach())
                masks_all[i].append(masks[index].detach())

                if len(boxes_all[i]) == keep_num_postNMS:
                    break

                # Compute ordinates of intersection-over-union(IOU)
                x1 = torch.maximum(start_x_torched[index], start_x_torched[order_torched[:-1]]).to(device)
                x2 = torch.minimum(end_x_torched[index], end_x_torched[order_torched[:-1]]).to(device)
                y1 = torch.maximum(start_y_torched[index], start_y_torched[order_torched[:-1]]).to(device)
                y2 = torch.minimum(end_y_torched[index], end_y_torched[order_torched[:-1]]).to(device)

                # Compute areas of intersection-over-union
                w = torch.maximum(torch.tensor([0]).to(device), x2 - x1 + 1)
                h = torch.maximum(torch.tensor([0]).to(device), y2 - y1 + 1)
                intersection = w * h

                # Compute the ratio between intersection and union
                ratio = intersection / (areas_torched[index] + areas_torched[order_torched[:-1]] - intersection)
                left = torch.where(ratio < IOU_thres)[0]
                order_torched = order_torched[left]
            clas_all[i] = [i]*len(scores_all[i])
       
        pdb.set_trace()
       
        fin_masks = torch.cat([torch.stack(one_mask) for one_mask in masks_all if len(one_mask)!=0]).reshape(-1,800,1088) ## ISSUES WITH EMPTY LIST (01-12)
        fin_scores = torch.cat([torch.tensor(one_score).reshape(-1,1) for one_score in scores_all if len(one_score)!=0],dim=0).reshape(-1,1)
        fin_boxes = torch.cat([torch.stack(one_box) for one_box in boxes_all if len(one_box)!=0]).reshape(-1,4)
        fin_clas = torch.cat([torch.tensor(one_clas) for one_clas in clas_all if len(one_clas)!=0]).reshape(-1,1)
        return fin_clas, fin_scores, fin_boxes, fin_masks


        
    def forward(self, features):
        X = self.layers1to4(features)
        X = self.deconv(X)
        X = self.conv(X)
        return X

    '''
    Input:
          clas: (top_k_boxes) (scores of the top k boxes)
          prebox: (top_k_boxes,4) (coordinate of the top k boxes)
    Output:
          nms_clas: (Post_NMS_boxes)
          nms_prebox: (Post_NMS_boxes,4)
    '''
    def NMS(self,clas,prebox, thresh=0.5):
        method = 'gauss'
        gauss_sigma=0.5
        n = len(clas)
        sorted_boxs = prebox.reshape(n, -1)
        intersection = torch.mm(sorted_boxs, sorted_boxs.T)
        areas = sorted_boxs.sum(dim=1).expand(n, n)
        union = areas + areas.T - intersection
        ious = (intersection / union).triu(diagonal=1)  
        ious_cmax = ious.max(0)[0].expand(n, n).T        
        if method == 'gauss':
            decay = torch.exp(-(ious ** 2 - ious_cmax ** 2) / gauss_sigma)
        else:
            decay = (1 - ious) / (1 - ious_cmax)       
        # move decay to device
        decay = decay.min(dim=0)[0].to('cuda')

        return clas * decay

    '''
    general function that takes the input list of tensors and concatenates them along the first tensor dimension
    Input:
         input_list: list:len(bz){(dim1,?)}
    Output:
         output_tensor: (sum_of_dim1,?)
    '''
    def flatten_inputs(self,input_list):
        output_tensor = torch.cat(input_list, dim=0)
        return output_tensor

    '''
    This function does the post processing for the result of the Mask Head for a batch of images. It project the predicted mask
    back to the original image size
    Use the regressed boxes to distinguish between the images
    Input:
          masks_outputs: (total_boxes,C,2*P,2*P)
          boxes: list:len(bz){(post_NMS_boxes_per_image,4)} ([x1,y1,x2,y2] format) ; bz = 1
          labels: list:len(bz){(post_NMS_boxes_per_image)}  (top category of each regressed box) ; bz = 1
          image_size: tuple:len(2)
    Output:
          projected_masks: list:len(bz){(post_NMS_boxes_per_image,image_size[0],image_size[1]
    '''
    def postprocess_mask(self, masks_outputs, boxes, labels, image_size=(800,1088)):
        # choose masks that correspond to the classes predicted by the Box Head
        boxes_stacked = self.flatten_inputs(boxes)
        labels_stacked = self.flatten_inputs(labels)

        mask_target = []
        for i in range(len(labels_stacked)):
            one_mask_output = masks_outputs[i]
            temp = one_mask_output[int(labels_stacked[i].item()) - 1, :, :]
            temp = torch.nn.functional.interpolate(temp.unsqueeze(dim=0).unsqueeze(dim=0), size=(800,1088), mode='bilinear')
            mask_target.append(temp.squeeze().squeeze())
            
        projected_masks = torch.stack(mask_target)

        projected_masks[projected_masks > 0.5] = 1
        projected_masks[projected_masks < 0.5] = 0

        return projected_masks # This is compared with the gt mask 

    '''
    Compute the total loss of the Mask Head
    Input:
         mask_output: (total_boxes,C,2*P,2*P)
         labels: (total_boxes)
         gt_masks: (total_boxes,2*P,2*P) - needs to be created 
    Output:
         mask_loss
    '''
    def compute_loss(self,mask_output,labels,gt_masks):
        mask_target = []
        # print("\n-----\nComputeLoss")
        # print("Mask first element = \n", mask_output.shape)
        # print("labels = ", len(labels))
        # print("gt_masks = ", len(gt_masks))
        # print("gt_masks[0] = ", gt_masks[0].shape)
        gt_masks_cat = []
        # print("gt_cat_masks = ", gt_masks_cat.shape)
        # pdb.set_trace()
        mask_target = []
        for i in range(len(labels)):
            one_mask_output = mask_output[i]
            mask_target.append(one_mask_output[int(labels[i].item()) - 1, :, :])
        mask_target = torch.stack(mask_target)
        
        criterion = nn.BCELoss()

        mask_loss = criterion(mask_target, gt_masks)
        # for i in range(len(labels)):
        #     one_mask_output = mask_output[i]
        #     print("Mask_output = ", one_mask_output.shape)
        #     print("labels[i].item() = ", labels[i].cpu().detach())
        #     for j in range(labels[i].shape[0]):
        #       mask_target.append(one_mask_output[labels[i][j] - 1, :, :])
        #       temp = gt_masks[i][j-1].unsqueeze(0).unsqueeze(0)
        #       pdb.set_trace()
        #       gt_mask_cat = F.interpolate(temp, (28, 28), mode='bilinear', align_corners=True)
        #       # print("gt_mask_cat = ", gt_mask_cat.shape)
        #       gt_masks_cat.append(gt_mask_cat)

        
        # mask_target = torch.stack(mask_target)
        # # print("Mask_target = ", mask_target.shape)
        # gt_masks_cat = torch.cat(gt_masks_cat).squeeze()
        # # print("GT_Maaks_cat = ",gt_masks_cat.shape)
        # # gt_masks_cat = F.interpolate(gt_masks_cat, (28, 28), mode='bilinear', align_corners=True)
        # criterion = nn.BCELoss()
        # mask_loss = criterion(mask_target, gt_masks_cat)
        # # print("Mask - Loss = ", mask_loss.mean())
        return mask_loss.mean()

        '''
    Compute the total loss of the Mask Head - Focal Loss
    Input:
         mask_output: (total_boxes,C,2*P,2*P)
         labels: (total_boxes)
         gt_masks: (total_boxes,2*P,2*P) - needs to be created 
    Output:
         mask_loss
    '''
    def compute_loss_focal(self,mask_output,labels,gt_masks):
        mask_target = []
        # print("\n-----\nComputeLoss")
        # print("Mask first element = \n", mask_output.shape)
        # print("labels = ", len(labels))
        # print("gt_masks = ", len(gt_masks))
        # print("gt_masks[0] = ", gt_masks[0].shape)
        gt_masks_cat = []
        # print("gt_cat_masks = ", gt_masks_cat.shape)
        # pdb.set_trace()
        mask_target = []
        for i in range(len(labels)):
            one_mask_output = mask_output[i]
            mask_target.append(one_mask_output[int(labels[i].item()) - 1, :, :])
        mask_target = torch.stack(mask_target)
        
        criterion = focal_loss(gamma = 2.0)

        mask_loss = criterion(mask_target, gt_masks)
       
        return mask_loss.mean()

## Training

In [None]:
# Drop any box head left over from an earlier cell so its memory can be
# reclaimed before the reload. Rebinding to None (instead of `del`) also works
# on a fresh kernel where the name does not exist yet.
boxhead_network_obj = None
path = "/content/model_trained_boxhead.pth"
# map_location lets a checkpoint that was saved on a GPU load on any runtime.
checkpoint = torch.load(path, map_location=device)
boxhead_network_obj = BoxHead().to(device)

# The checkpoint keys use different sub-module names than BoxHead does
# (presumably an older version of the class); rename them before loading.
key_renames = {
    'intermediate_layer': 'intermediate',
    'classifier_head': 'classifier',
    'regressor_head': 'regressor',
}
for old_key in list(checkpoint.keys()):
    new_key = old_key
    for old_name, new_name in key_renames.items():
        new_key = new_key.replace(old_name, new_name)
    checkpoint[new_key] = checkpoint.pop(old_key)
print(checkpoint.keys())

# strict=False skips mismatching keys without raising, so print what was
# skipped - otherwise a badly renamed checkpoint loads "successfully" while
# leaving layers at their random initialisation.
load_result = boxhead_network_obj.load_state_dict(checkpoint, strict = False)
print("Missing keys: ", load_result.missing_keys)
print("Unexpected keys: ", load_result.unexpected_keys)

In [None]:
maskhead_network_obj = MaskHead().to(device)
optimizer = torch.optim.Adam(maskhead_network_obj.parameters(), lr = 0.001)
train_loss = []
train_loss_c = []
train_loss_r = []
val_loss = []
val_loss_c = []
val_loss_r = []
num_epochs = 20
keep_topK = 200


def _mask_head_batch_loss(images, labels, masks, bboxes):
  """Run one batch through backbone -> RPN -> box head -> mask head and return the mask loss.

  The backbone, RPN and box head only provide inputs and targets for the mask
  head, so they run without gradient tracking. Returns None when the box head
  yields no usable box for the whole batch (nothing to supervise).
  """
  images = images.to(device)
  with torch.no_grad():
    # Features from the backbone
    backout = backbone(images)
    # The RPN implementation takes an image list as its first argument
    im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
    rpnout = rpn1(im_lis, backout)
    # A list of proposal tensors: list:len(bz){(keep_topK,4)}
    proposals = [proposal[0:keep_topK,:] for proposal in rpnout[0]]
    # A list of features produced by the backbone's FPN levels: list:len(FPN){(bz,256,H_feat,W_feat)}
    fpn_feat_list = list(backout.values())

    feature_vectors = boxhead_network_obj.find_feature_vectors(proposals, fpn_feat_list)
    class_logits, box_pred = boxhead_network_obj(feature_vectors)

    mask_boxes, mask_scores, mask_final_labels, mask_gt_masks = maskhead_network_obj.preprocess_ground_truth_creation(proposals, class_logits, box_pred, labels, bboxes, masks, IOU_thresh=0.5, keep_num_preNMS=200, keep_num_postNMS=10)

    mask_labels = maskhead_network_obj.flatten_inputs(mask_final_labels)
    mask_gt_mks = maskhead_network_obj.flatten_inputs(mask_gt_masks)
    if mask_labels.shape[0] == 0:
      return None

    mask_feature_vectors = maskhead_network_obj.find_feature_vectors(mask_boxes, fpn_feat_list)

  # Only this part is tracked by autograd: the mask head itself.
  mask_output = maskhead_network_obj(mask_feature_vectors)
  # NOTE(review): compute_loss selects channel `label - 1`; confirm that the
  # labels returned by preprocess_ground_truth_creation are 1-based.
  return maskhead_network_obj.compute_loss(mask_output.to(device), mask_labels.to(device), mask_gt_mks.to(device))


# The box head is frozen here (it is not in the optimizer), so keep it in
# inference mode for the whole run.
boxhead_network_obj.eval()

for epoch in tqdm(range(num_epochs)):
  run_loss = 0
  #train the model
  maskhead_network_obj.train()
  for i, (images, labels, masks, bboxes, indexes) in enumerate(train_loader):
    loss = _mask_head_batch_loss(images, labels, masks, bboxes)
    if loss is None:
      continue
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    run_loss += loss.item()
    torch.cuda.empty_cache()

  # Average over the loader that was actually iterated.
  train_loss.append(run_loss/len(train_loader))

  # checkpoint
  if(epoch%2==0):
    path = "/content/drive/MyDrive/CIS680/Final/Mask1812/epoch-" +str(epoch+1)
    torch.save({
        'epoch': epoch+1,
        'model_state_dict': maskhead_network_obj.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': run_loss/len(train_loader),
        }, path)

  #validating the model: same pipeline as training, without any gradient tracking
  maskhead_network_obj.eval()
  run_eval_loss = 0
  with torch.no_grad():
    for j, (images, labels, masks, bboxes,indexes) in enumerate(test_loader):
      vloss = _mask_head_batch_loss(images, labels, masks, bboxes)
      if vloss is None:
        continue
      run_eval_loss += vloss.item()
      torch.cuda.empty_cache()

  val_loss.append(run_eval_loss/len(test_loader))

In [None]:

# Focal Loss

maskhead_network_obj = MaskHead().to(device)
optimizer = torch.optim.Adam(maskhead_network_obj.parameters(), lr = 0.001)
# boxhead_network_obj
train_loss = []
train_loss_c = []
train_loss_r = []
val_loss = []
val_loss_c = []
val_loss_r = []
num_epochs = 20
keep_topK = 200
# del proposals
for epoch in tqdm(range(num_epochs)):
  run_loss = 0
  run_loss_c = 0
  run_loss_r = 0
  #train the model
  maskhead_network_obj.train()
  for i, (images, labels, masks, bboxes, indexes) in enumerate(train_loader):
    proposals = []
    del proposals, rpnout1
    # Take the features from the backbone
    backout = backbone(images)

    # The RPN implementation takes as first argument the following image list
    im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
    # Then we pass the image list and the backbone output through the rpn
    rpnout1 = rpn1(im_lis, backout)
    # pdb.set_trace()
    print("RPNOUT = \n", rpnout)
    #The final output is
    # A list of proposal tensors: list:len(bz){(keep_topK,4)}
    proposals=[proposal[0:keep_topK,:] for proposal in rpnout1[0]]
    print("Proposals = \n", proposals)
    # pdb.set_trace()
    # A list of features produces by the backbone's FPN levels: list:len(FPN){(bz,256,H_feat,W_feat)}
    fpn_feat_list= list(backout.values())

    feature_vectors              = boxhead_network_obj.find_feature_vectors(proposals, fpn_feat_list)      
    
    class_logits, box_pred       = boxhead_network_obj.forward(feature_vectors)
    # pdb.set_trace()
    new_labels, regressor_target  = boxhead_network_obj.create_batch_truth(bboxes, proposals, labels)
    # pdb.set_trace()
    print("Proposals after a while = \n", proposals)

    mask_boxes, mask_scores, mask_final_labels, mask_gt_masks = maskhead_network_obj.preprocess_ground_truth_creation(proposals, class_logits, box_pred, labels, bboxes ,masks ,     IOU_thresh=0.5, keep_num_preNMS=200, keep_num_postNMS=10)

    mask_feature_vectors = maskhead_network_obj.find_feature_vectors(mask_boxes, fpn_feat_list)
    mask_output = maskhead_network_obj.forward(mask_feature_vectors)
    # pdb.set_trace()
    # calculate loss
    mask_labels = maskhead_network_obj.flatten_inputs(mask_final_labels)
    mask_gt_mks = maskhead_network_obj.flatten_inputs(mask_gt_masks)
    # mask_output = maskhead_network_obj.flatten_inputs(mask_output)
    pdb.set_trace()
    loss = maskhead_network_obj.compute_loss(mask_output.to(device), mask_labels.to(device), mask_gt_mks.to(device))
    '''#pass through backbone
    backout = backbone(images.to(device))
    im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
    #pass thorugh rpn, get proposals and fpn_feat_list
    rpnout = rpn1(im_lis, backout)
    proposals=[proposal[0:keep_topK,:] for proposal in rpnout[0]]
    fpn_feat_list = list(backout.values())
    #get gt
    #gt_class_batch, gt_coord_batch = boxhead_network_obj.create_batch_truth(bboxes, proposals, labels)
    #get feature vectors
    feature_vectors = maskhead_network_obj.find_feature_vectors(proposals, fpn_feat_list)
    # print("feature_vectors = ", feature_vectors.detach().shape)
    #pass through forward
    maskout  = maskhead_network_obj.forward(feature_vectors.detach())      #TO ENSURE NO RETRAINING OF THE BACKBONE AND RPN'''
    
    # print("Maskout = ", maskout.shape)
    #zero out optimizer
    optimizer.zero_grad()
    #calc the loss
    #CE loss already does softmax so we dont add it in the classifier network during training
    # mask_boxes, mask_scores, mask_labels, mask_gt_masks = maskhead_network_obj.preprocess_ground_truth_creation(label_preds, regr_preds, proposals, labels, bboxes, masks, IOU_thresh=0.5, keep_num_preNMS=700, keep_num_postNMS=100)
    # if len(mask_boxes[0]) == 0:
    #   continue
    # loss= maskhead_network_obj.compute_loss(maskout, labels, masks)
    #gather loss values
    run_loss += loss.item()
    # run_loss_c += loss_class.item()
    # run_loss_r += loss_reg.item()
    #backward pass
    loss.backward()
    optimizer.step()
    #empty intermediate predcitions 
    # del gt_class_batch, gt_coord_batch
    torch.cuda.empty_cache()

  train_loss.append(run_loss/len(full_loader))
  # train_loss_c.append(run_loss_c/len(full_loader))
  # train_loss_r.append(run_loss_r/len(full_loader))

  # checkpoint
  if(epoch%2==0):
    path = "/content/drive/MyDrive/CIS680/Final/Mask1812/epoch-" +str(epoch+1)
    torch.save({
        'epoch': epoch+1,
        'model_state_dict': maskhead_network_obj.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': run_loss/len(full_loader),
        # 'clas_loss': run_loss_c/len(full_loader),
        # 'regr_loss': run_loss_r/len(full_loader)
        }, path)
  
  #validating the model
  maskhead_network_obj.eval()
  run_eval_loss = 0
  run_eval_loss_c = 0
  run_eval_loss_r = 0
  for j, (images, labels, masks, bboxes,indexes) in enumerate(test_loader):
    #pass through backbone
    backout = backbone(images.to(device))
    im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
    #pass thorugh rpn, get proposals and fpn_feat_list
    rpnout = rpn1(im_lis, backout)
    proposals=[proposal[0:keep_topK,:] for proposal in rpnout[0]]
    fpn_feat_list= list(backout.values())
    #get gt
    #vgt_class_batch, vgt_coord_batch = boxhead_network_obj.create_batch_truth(bboxes, proposals, labels)
    #get feature vectors
    # vfeature_vectors = maskhead_network_obj.find_feature_vectors(proposals, fpn_feat_list)
    vfeature_vectors = maskhead_network_obj.find_feature_vectors(proposals, fpn_feat_list)
    # print("feature_vectors = ", vfeature_vectors.detach().shape)
    #pass through forward
    maskout = maskhead_network_obj.forward(vfeature_vectors.detach())      #TO ENSURE NO RETRAINING OF THE BACKBONE AND RPN
    #calc the loss
    vloss= maskhead_network_obj.compute_loss_focal(maskout, labels, masks)

    # vloss= maskhead_network_obj.compute_loss(vclass_prob, vbbox_reg, vgt_class_batch.to(device), vgt_coord_batch.to(device),effective_batch=150)
    run_eval_loss += vloss.item()
    # run_eval_loss_c += vloss_class.item()
    # run_eval_loss_r += vloss_reg.item()
    # del vclass_prob, vbbox_reg, vgt_class_batch, vgt_coord_batch
    torch.cuda.empty_cache()

  val_loss.append(run_eval_loss/len(full_loader))
  # val_loss_c.append(run_eval_loss_c/len(full_loader))
  # val_loss_r.append(run_eval_loss_r/len(full_loader))

### Plots

In [None]:
# Draw the recorded training and validation loss histories side by side.
fig, ax = plt.subplots(1, 2, figsize=(10,4))

# One (history, title) pair per panel, left to right.
loss_panels = (
  (train_loss, 'Total Train Loss'),
  (val_loss, 'Total Validation Loss'),
)
for panel, (history, heading) in zip(ax, loss_panels):
  panel.plot(history)
  panel.set_title(heading)
  panel.set_xlabel('epochs')
  panel.set_ylabel('loss')
  panel.grid()

## Post-processing

In [None]:
# Visualise box-head predictions: for every image, decode the regressed box of
# each non-background proposal and draw it in a colour chosen by its class.
topK = 50
boxhead_network_obj.eval()
# Softmax is applied here because the box head returns raw class scores.
# The class dimension is given explicitly: the dimension-less constructor is
# deprecated and only falls back to dim=1 implicitly for 2-D input.
softmax = nn.Softmax(dim=1)
# Outline colour per predicted class label (0 is background and is skipped).
label_colors = {1: 'r', 2: 'g', 3: 'b'}   # 1 = vehicle, 2 = human, 3 = animal

with torch.no_grad():
  for images, labels, masks, bboxes, indexes in train_loader:
    # NOTE(review): other cells call backbone(images.to(device)); confirm the
    # backbone is meant to run on the loader's device in this cell.
    backout = backbone(images)
    im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
    rpnout = rpn(im_lis, backout)
    # keep at most topK proposals per image
    proposals = [proposal[0:topK,:].to(device) for proposal in rpnout[0]]
    fpn_feat_list = list(backout.values())
    #get feature vectors
    feature_vectors = boxhead_network_obj.find_feature_vectors(proposals, fpn_feat_list)
    class_prob, bbox_reg = boxhead_network_obj(feature_vectors.detach())      #TO ENSURE NO RETRAINING OF THE BACKBONE AND RPN
    cls_soft = softmax(class_prob)
    #get max confidence score and corresponding class index for every proposal
    print("class_prob = \n", class_prob)
    cls_with_softmax, max_indices = torch.max(cls_soft,dim=1)
    print("max_indices = ", max_indices)

    # Rows of class_prob / bbox_reg are assumed to be the proposals of image 0,
    # then image 1, ... (the layout the per-image slicing has always relied
    # on) -- TODO confirm against find_feature_vectors. A running offset is
    # used so that an image with fewer than topK proposals does not shift the
    # rows of the images that follow it.
    row_start = 0
    for img_idx, img_proposals in enumerate(proposals):
      #display the image
      fig = plt.figure()
      ax = fig.add_subplot()
      ax.imshow(images[img_idx].detach().cpu().numpy().transpose(1,2,0))
      print("___________")

      num_props = len(img_proposals)
      row_stop = row_start + num_props
      img_labels = max_indices[row_start:row_stop]
      # Regression rows must be sliced with the same per-image offset as the
      # labels; indexing bbox_reg by the local proposal index alone would
      # always read the first image's rows.
      img_bbox_reg = bbox_reg[row_start:row_stop]
      row_start = row_stop

      #iterate over every proposal of this image
      print("Proposals = ", num_props)
      for prop_idx, (label_t, deltas, prop) in enumerate(zip(img_labels, img_bbox_reg, img_proposals)):
        label = int(label_t)
        #if label is 0, it is background so ignore
        if label == 0:
          print("Background proposal ", prop_idx)
          continue
        #pick the 4 regression values that belong to the predicted class
        first = (label - 1) * 4
        edge_ind = deltas[first:first + 4].cpu()
        prop = prop.cpu()
        #decode the bounding box with corresponding proposal
        tx, ty, tw, th = edge_ind[0], edge_ind[1], edge_ind[2], edge_ind[3]
        xp, yp, wp, hp = (prop[0]+prop[2])/2, (prop[1]+prop[3])/2, prop[2]-prop[0], prop[3]-prop[1]
        x, y, w, h = tx*wp+xp, ty*hp+yp, torch.exp(tw)*wp, torch.exp(th)*hp
        x1, y1 = x-(w/2), y-(h/2)
        print(x,y,w,h)
        print("--------------------")
        #draw the decoded box; labels without an assigned colour are not drawn
        color = label_colors.get(label)
        if color is not None:
          rect = rec((float(x1), float(y1)), float(w), float(h), fill=False, color=color, linewidth=2)
          ax.add_patch(rect)
      plt.show()

### Visualization

In [None]:
def final_visualize(boxes_list, project_masks, clas, images):
  """Show each image with its boxes outlined and its mask overlaid.

  One figure is created per entry of ``boxes_list``; all figures are shown
  with a single ``plt.show()`` after the last one has been built.

  Args:
    boxes_list: per-image collections of boxes. Each box is indexed as
      (x1, y1, x2, y2) and its entries must support ``.cpu()``.
    project_masks: per-image masks. ``project_masks[i]`` is written into one
      colour channel of an overlay that has the image's height and width
      (presumably an (H, W) array -- confirm against the caller).
    clas: per-image category, compared against 0 (vehicle), 1 (person) and
      2 (animal) to choose the overlay channel. Assumed to hold a single
      value per image -- TODO confirm.
    images: indexable batch of image tensors; each is moved to the CPU and
      its axes are reordered with ``transpose(1, 2, 0)`` before display.
  """
  for i, objs in enumerate(boxes_list):    #number of examples to show
    fig = plt.figure()
    ax = fig.add_subplot()
    print("Objs = ", objs)
    one_image = images[i].detach().cpu().numpy().transpose(1,2,0)
    mask_vis = np.zeros_like(one_image, dtype=np.float64)
    cat = clas[i]
    # Read the masks from the argument; relying on a module-level variable
    # with a similar name only works if the caller happens to define it.
    final_mask = project_masks[i]
    ax.imshow(one_image)

    has_boxes = False
    for box in objs:
      has_boxes = True
      x1, y1, x2, y2 = box[0].cpu(), box[1].cpu(), box[2].cpu(), box[3].cpu()
      rect = rec((x1, y1), (x2 - x1), (y2 - y1), fill=False, color='r')
      ax.add_patch(rect)

    # The overlay depends only on the image (not on the individual box), so it
    # is drawn once; drawing it per box would stack identical translucent
    # layers. As before, nothing is overlaid when the image has no boxes.
    if has_boxes:
      print("cat = ", cat.shape)
      # channel 0 = vehicle, 1 = person, 2 = animal
      for channel in range(3):
        if(cat==channel):
          mask_vis[:,:,channel] = final_mask
      # draw on this figure's axes rather than on pyplot's implicit current axes
      ax.imshow(mask_vis,alpha=0.5)
  plt.show()

In [None]:
# Restore a trained mask head from a checkpoint and visualise its predictions
# on the test set.
path = "/content/mask.ckpt"
# map_location lets a checkpoint saved on one device be restored on `device`.
checkpoint = torch.load(path, map_location=device)
test_model = MaskHead().to(device)

test_model.load_state_dict(checkpoint['model_state_dict'])

#in eval mode
test_model.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
  for idx, (images, labels, masks, bboxes, indexes) in enumerate(test_loader):
    #pass through backbone
    backout = backbone(images.to(device))
    im_lis = ImageList(images, [(800, 1088)]*images.shape[0])
    #pass through rpn, get proposals and fpn_feat_list
    rpnout = rpn(im_lis, backout)
    proposals = [proposal[0:keep_topK,:] for proposal in rpnout[0]]
    fpn_feat_list = list(backout.values())
    # Run the restored model (test_model) for every mask-head step, so that
    # the weights loaded from the checkpoint are the ones being evaluated.
    feature_vectors = test_model.find_feature_vectors(proposals, fpn_feat_list)
    maskout = test_model(feature_vectors.detach())
    projected_masks = test_model.postprocess_mask(maskout, bboxes, labels)  #post process
    # NOTE(review): post_nms_predbox is not assigned in this cell; it must
    # already exist from an earlier cell -- confirm it matches this batch.
    final_visualize(post_nms_predbox, projected_masks, labels , images)