In [54]:
# Mount Google Drive inside the Colab VM; the dataset archive is read from it in the next cell.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Create the local working directory and extract the training archive from Drive into it.
!mkdir "/content/vos"
!sudo unzip "/content/drive/MyDrive/vos/train.zip" -d "/content/vos"

In [1]:
import torch

In [2]:
# Global configuration shared by the dataset and the training loop.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 1
pos_ratio = 0.5   # target fraction of positive anchors in an RPN mini-batch
n_sample = 256    # number of anchors (positive + negative) kept per image
# Keep the count an integer: a float (128.0) cannot be used as a sample size or slice bound.
n_pos = int(pos_ratio * n_sample)
device

'cuda:0'

## Visualize Utils

In [3]:
import cv2
import matplotlib.pyplot as plt
from PIL import Image

# Palette cycled through when drawing anchor / predicted boxes.
colors = [(255, 128, 0), (255, 255, 0), (128, 255, 0), (0, 255, 0), (0, 255, 255), (0, 0, 255), (255, 0, 255), (96, 96, 96)]

def visualize(image_path, anchor_boxes, gt_boxes=None):
  """Show an image resized to 800x800 with boxes drawn on top of it.

  Args:
    image_path: path of the image file to open.
    anchor_boxes: iterable of boxes in (x1, y1, x2, y2) format; each one is
      drawn with a color taken from `colors`, cycling through the palette.
    gt_boxes: optional iterable of boxes in the same format, drawn first with
      the fixed color (255, 0, 0).
  """
  canvas = np.array(Image.open(image_path).resize((800, 800)))

  def draw(box, color):
    # cv2 needs integer pixel coordinates; the rectangle is drawn in place on the canvas.
    top_left = (int(box[0]), int(box[1]))
    bottom_right = (int(box[2]), int(box[3]))
    cv2.rectangle(canvas, top_left, bottom_right, color, 2)

  if gt_boxes is not None:
    for box in gt_boxes:
      draw(box, (255, 0, 0))

  for position, box in enumerate(anchor_boxes):
    draw(box, colors[position % len(colors)])

  plt.figure(figsize=(8, 8))
  plt.imshow(canvas)
  plt.show()

In [4]:
## UNIT TEST
# valid_index = get_valid_anchor_boxes(anchor_boxes)
# visualize("/content/image.jpg", anchor_boxes[valid_index][1000:1005], get_bounding_box("/content/image.jpg", "/content/label.png"))

In [5]:
def visualize_a_box(image_path, box):
  """Display a single (x1, y1, x2, y2) box on the image by delegating to `visualize`."""
  single_box_list = [box]
  visualize(image_path, single_box_list)

## Generate Anchor Boxes

In [6]:
import numpy as np

In [7]:
def gen_anchor_box_for_single_feature_map(center_X, center_Y, scales, ratios, stride=None):
  """Generate the anchor boxes centered on one feature-map location.

  Args:
    center_X, center_Y: anchor center in image pixels.
    scales: iterable of scale multipliers.
    ratios: iterable of width/height ratios.
    stride: pixel size of one scale unit. When omitted, the module-level
      `subsample` value is used (the original behaviour), which requires the
      configuration cell to have been executed first.

  Returns:
    Array of shape (len(scales) * len(ratios), 4) in (x1, y1, x2, y2) format.
    Rows are ordered ratio-major: all scales for the first ratio, then all
    scales for the second ratio, and so on.
  """
  if stride is None:
    stride = subsample

  boxes = np.zeros((len(scales) * len(ratios), 4))

  k = 0
  for ratio in ratios:
    for scale in scales:
      # Height depends on the scale only; the ratio stretches the width.
      H = scale * stride
      W = H * ratio

      half_w = (1/2) * W
      half_h = (1/2) * H

      boxes[k] = [center_X - half_w, center_Y - half_h, center_X + half_w, center_Y + half_h]
      k += 1

  return boxes


In [8]:
## UNIT TEST
# gen_anchor_box_for_single_feature_map(400, 400, [8, 16, 32], [0.5, 1, 2]) # output has to be in shape of (3 * 3, 4)

In [9]:
# Anchor configuration: an anchor's height is scale * subsample pixels and its width is that height times ratio.
subsample = 16
scales = [8, 16, 32]
ratios = [0.5, 1, 2]

In [10]:
def generate_all_anchor_boxes():
  """Build every anchor box of the 50 x 50 grid as one flat array.

  Anchor centers start at pixel 8 and advance by 16 pixels along both axes.
  The first grid index runs over center_X and the second over center_Y, so the
  flattened result is ordered x-major: (x index, y index, anchor index).

  Returns:
    Array of shape (50 * 50 * len(scales) * len(ratios), 4) in
    (x1, y1, x2, y2) format.
  """
  centers = np.arange(8, 16 * (50), 16)
  boxes_per_cell = len(scales) * len(ratios)
  grid = np.zeros((50, 50, boxes_per_cell, 4))

  for i in range(len(centers)):
    for j in range(len(centers)):
      grid[i, j] = gen_anchor_box_for_single_feature_map(centers[i], centers[j], scales, ratios)

  return grid.reshape(-1, 4)

### Valid anchors

In [11]:
def get_valid_anchor_boxes(anchor_boxes):
  """Return a boolean mask selecting the boxes that lie inside the 800 x 800 image.

  A box (x1, y1, x2, y2) is kept when x1 >= 0, y1 >= 0, x2 < 800 and y2 < 800.
  The mask (not the filtered boxes) is returned so callers can index other
  per-anchor arrays with it.
  """
  x1, y1, x2, y2 = (anchor_boxes[:, column] for column in range(4))
  starts_inside = (x1 >= 0) & (y1 >= 0)
  ends_inside = (x2 < 800) & (y2 < 800)
  return starts_inside & ends_inside

In [12]:
# UNIT TEST
# the exact number depends on the scales and ratios you used
# (inside_anchors_index == True).sum() > 0

## Compute IoU

In [13]:
def compute_iou(box1, box2):
  """Intersection-over-union of two boxes given in (x1, y1, x2, y2) format.

  Returns 0 when the boxes do not overlap, and also when the union has no
  area (e.g. both boxes are degenerate points or lines), instead of dividing
  by zero.
  """
  inter_w = min(box1[2], box2[2]) - max(box1[0], box2[0])
  inter_h = min(box1[3], box2[3]) - max(box1[1], box2[1])

  # A non-positive extent along either axis means there is no intersection.
  if inter_w > 0 and inter_h > 0:
    inter_area = inter_w * inter_h
  else:
    inter_area = 0

  box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
  box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
  union_area = box1_area + box2_area - inter_area

  if union_area <= 0:
    return 0.0

  return inter_area / union_area

In [14]:
## UNIT TEST
# box1 = [0, 0, 5, 5]
# box2 = [5, 5, 10, 10]
# box3 = [2.5, 2.5, 7.5, 7.5]

# compute_iou(box1, box2) # it has to be zero
# compute_iou(box1, box3) # ~ 0.14
# compute_iou(box3, box1) # it's a symmetric function, so the result is the same: ~ 0.14

## Get Bounding Box from Mask

In [15]:
def get_number_of_instances(mask_path):
  """Return how many distinct pixel values the mask image contains.

  Every distinct value is counted, so a mask holding a single value yields 1.
  """
  pixel_values = np.array(Image.open(mask_path))
  distinct_values = np.unique(pixel_values)
  return len(distinct_values)

In [16]:
from PIL import Image

def get_bounding_box(mask_path, size=(800, 800)):
  """Extract one bounding box per instance from a label mask.

  Args:
    mask_path: path of the mask image; each pixel holds an instance id and
      id 0 is skipped (treated as background).
    size: (width, height) the mask is resized to before the boxes are
      measured, so that coordinates match the resized input image.

  Returns:
    List of [x1, y1, x2, y2] boxes (inclusive pixel coordinates), one per
    non-zero instance id, ordered by ascending id.
  """
  mask = Image.open(mask_path)

  # Nearest-neighbour resampling is mandatory for label maps: any interpolating
  # filter would blend neighbouring ids into values that are not real instances.
  mask = mask.resize(size, resample=Image.NEAREST)

  mask = np.array(mask)

  # We do not care about the category of an instance, only about its extent.
  boxes = []

  for instance_number in np.unique(mask):
    if instance_number == 0:
      continue

    Y, X = np.where(mask == instance_number)
    boxes.append([int(X.min()), int(Y.min()), int(X.max()), int(Y.max())])

  return boxes

In [17]:
# UNIT TEST
# boxes = get_bounding_box("/content/vos/train/JPEGImages/2d8f5e5025/00055.jpg", "/content/vos/train/Annotations/2d8f5e5025/00055.png")
# visualize("/content/vos/train/JPEGImages/10b31f5431/00030.jpg", boxes)

## Compute (50 * 50 * 9, 2)

In [18]:
def get_anchors_iou(image_path, label_path, anchors):
  """Compute the IoU between every given box and every ground-truth box.

  Args:
    image_path: not used by this function; kept for the callers' signature.
    label_path: mask file the ground-truth boxes are extracted from.
    anchors: indexable collection of boxes in (x1, y1, x2, y2) format.

  Returns:
    Array of shape (len(anchors), number of ground-truth boxes).
  """
  gt_boxes = get_bounding_box(label_path)
  iou_table = np.zeros((len(anchors), len(gt_boxes)))

  # Visit each (box, ground-truth) cell of the table in row-major order.
  for (row, col), _ in np.ndenumerate(iou_table):
    iou_table[row, col] = compute_iou(anchors[row], gt_boxes[col])

  return iou_table

In [19]:
def iou_to_label(anchors_iou, pos_threshold=0.7, neg_threshold=0.3):
  """Turn an anchor-vs-ground-truth IoU matrix into per-anchor labels.

  Labels:
     1 (positive): the anchor's best IoU exceeds `pos_threshold`, or the anchor
                   is the best match of some ground-truth box it overlaps.
     0 (negative): the anchor's best IoU is below `neg_threshold`.
    -1 (ignored):  everything else.

  Args:
    anchors_iou: array of shape (n_anchors, n_gt_boxes).
    pos_threshold: IoU above which an anchor is positive.
    neg_threshold: IoU below which an anchor is negative.

  Returns:
    Integer array of shape (n_anchors, 1).
  """
  anchors_iou = np.asarray(anchors_iou)
  n_anchors, n_gt = anchors_iou.shape

  anchor_label = np.full((n_anchors, 1), fill_value=-1)

  if n_anchors == 0:
    return anchor_label

  if n_gt == 0:
    # No object in the image: every anchor is background.
    anchor_label[:] = 0
    return anchor_label

  best_iou_per_anchor = anchors_iou.max(axis=1)

  anchor_label[best_iou_per_anchor < neg_threshold] = 0
  anchor_label[best_iou_per_anchor > pos_threshold] = 1

  # Each ground-truth box promotes its best anchor, but only if they overlap at all;
  # otherwise argmax would silently return anchor 0 for an all-zero column.
  best_anchor_per_gt = anchors_iou.argmax(axis=0)
  gt_has_overlap = anchors_iou.max(axis=0) > 0
  anchor_label[best_anchor_per_gt[gt_has_overlap]] = 1

  return anchor_label

In [20]:
# UNIT TEST
# anchors_iou = get_anchors_iou("/content/image.jpg", "/content/label.png", anchor_boxes)
# anchor_label = iou_to_label(anchors_iou)
# X1, _ = np.where(anchor_label == 1)
# visualize("/content/image.jpg", anchor_boxes[X1])

In [21]:
# SANITY CHECK
# anchors with label zero should be much higher in number than label 1 and -1
# a reasonable range for the number of label-1 anchors is (ground_truth, 3 * ground_truth)

# (anchor_label == -1).sum(), (anchor_label == 0).sum(), (anchor_label == 1).sum() 

## Utilities

In [22]:
def corner2center(box):
  """Convert a box from (x1, y1, x2, y2) to [center_x, center_y, width, height]."""
  left, top, right, bottom = box[0], box[1], box[2], box[3]

  return [(left + right) / 2, (top + bottom) / 2, (right - left), (bottom - top)]

In [23]:
# UNIT TEST
# box1 = corner2center([0, 0, 5, 5]) 
# assert box1 == [2.5, 2.5, 5, 5]
# print("pass!")

In [24]:
def center2corner(box):
  """Convert a box from (center_x, center_y, width, height) to [x1, y1, x2, y2]."""
  center_x, center_y, width, height = box
  half_width = (1/2) * width
  half_height = (1/2) * height

  return [
    center_x - half_width,
    center_y - half_height,
    center_x + half_width,
    center_y + half_height,
  ]

In [25]:
# UNIT TEST
# box1 = center2corner([2.5, 2.5, 5, 5]) 
# assert box1 == [0, 0, 5, 5]
# print("pass!")

$$ t_x = \frac{A_x - G_x}{A_w} $$
\
$$ t_w = \log{\frac{G_w}{A_w}} $$

In [26]:
from math import log

def get_parameterized_target(anchor_box, gt_box, eps=1e-8):
  """Encode a ground-truth box relative to an anchor box.

  Both boxes are in center format (x, y, w, h). The encoding is
    t_x = (A_x - G_x) / A_w,  t_y = (A_y - G_y) / A_h,
    t_w = log(G_w / A_w),     t_h = log(G_h / A_h)
  which `parametrized_to_corner` inverts.

  Args:
    anchor_box: (A_x, A_y, A_w, A_h).
    gt_box: (G_x, G_y, G_w, G_h).
    eps: lower bound applied to every width and height so that degenerate
      boxes (zero width or height) produce large finite targets instead of a
      division by zero or a log domain error.

  Returns:
    [t_x, t_y, t_w, t_h]
  """
  A_x, A_y, A_w, A_h = anchor_box
  G_x, G_y, G_w, G_h = gt_box

  A_w, A_h = max(A_w, eps), max(A_h, eps)
  G_w, G_h = max(G_w, eps), max(G_h, eps)

  t_x = (A_x - G_x) / A_w
  t_y = (A_y - G_y) / A_h
  t_w = log(G_w / A_w)
  t_h = log(G_h / A_h)

  return [t_x, t_y, t_w, t_h]

In [27]:
# UNIT TEST
# box1 = corner2center([0., 0., 5., 5.])
# box2 = corner2center([2.5, 2.5, 7.5, 7.5])

# assert (get_parameterized_target(box1, box2) == [-1/2, -1/2, 0, 0]) # expected output [-1/2, -1/2, 0, 0]
# print("pass!")

In [28]:
def parametrized_to_corner(t_x, t_y, t_w, t_h, anchor_box):
  """Decode regression offsets back into a corner-format box.

  Args:
    t_x, t_y, t_w, t_h: offsets as produced by `get_parameterized_target`.
    anchor_box: reference box in center format (x, y, w, h).

  Returns:
    [x1, y1, x2, y2] of the decoded box.
  """
  A_x, A_y, A_w, A_h = anchor_box

  # Inverse of the encoding: centers are shifted against the offset sign,
  # sizes are scaled by the exponential of the log-ratio.
  decoded_center_box = [
    A_x - t_x * A_w,
    A_y - t_y * A_h,
    np.exp(t_w) * A_w,
    np.exp(t_h) * A_h,
  ]

  return center2corner(decoded_center_box)

In [29]:
# UNIT TEST
# box1 = corner2center([0., 0., 5., 5.])
# box2 = corner2center([2.5, 2.5, 7.5, 7.5])

# assert (parametrized_to_corner(*[-1/2, -1/2, 0, 0], box1) == center2corner(box2))
# print("pass!")

## RPN Network

In [30]:
import torch
import torch.nn as nn
from torchvision.models import vgg16
import torch.nn.functional as F

In [31]:
# Pretrained VGG16 backbone truncated to the first 30 modules of `.features`, moved to `device`.
feature_extractor = vgg16(pretrained=True).features[:30].to(device)

In [32]:
class RPN(nn.Module):
  """Region proposal head applied to a 512-channel feature map.

  A 3x3 convolution is followed by two sibling 1x1 convolutions: one emitting
  2 class scores per anchor and one emitting 4 box offsets per anchor.
  All three layers are initialised with N(0, 0.01) weights and zero biases.
  """

  def __init__(self, n_anchors):
    """Args:
      n_anchors: number of anchors per feature-map location.
    """
    super(RPN, self).__init__()

    self.conv1 = nn.Conv2d(512, 512, 3, 1, 1)
    self.cls_layer = nn.Conv2d(512, n_anchors * 2, 1, 1, 0)
    self.reg_layer = nn.Conv2d(512, n_anchors * 4, 1, 1, 0)

    # Initialise every layer, including the regression head.
    for layer in (self.conv1, self.cls_layer, self.reg_layer):
      nn.init.normal_(layer.weight, mean=0.0, std=0.01)
      nn.init.constant_(layer.bias, 0.0)

  def forward(self, x):
    """Return (pred_cls, pred_loc) with shapes (N, n_anchors * 2, H, W) and (N, n_anchors * 4, H, W)."""
    # NOTE(review): no non-linearity is applied after conv1; the reference RPN
    # design uses a ReLU here — confirm whether the omission is intentional.
    x = self.conv1(x)

    pred_cls = self.cls_layer(x)
    pred_loc = self.reg_layer(x)

    return pred_cls, pred_loc

In [101]:
class Faster_RCNN(nn.Module):
  """RPN plus ROI head operating on a precomputed 512-channel feature map.

  The ROI head max-pools every region of interest to 7x7, flattens it to
  512 * 7 * 7 = 25088 features and feeds it through two linear layers before
  the classification (21 scores) and box-regression (21 * 4 offsets) outputs.
  """

  def __init__(self):

    super(Faster_RCNN, self).__init__()
    n_anchors = 9
    self.rpn_model = RPN(n_anchors).to(device)

    self.adaptive_max_pool = nn.AdaptiveMaxPool2d((7, 7))

    # NOTE(review): the two linear layers are stacked without a non-linearity
    # in between — confirm whether that is intentional.
    self.roi_head_classifier = nn.Sequential(
      nn.Linear(25088, 4096),
      nn.Linear(4096, 4096)
    )

    self.score = nn.Linear(4096, 21)
    self.cls_loc = nn.Linear(4096, 21 * 4)

  def rpn(self, x):
    """Run only the region proposal network on feature map `x`."""
    return self.rpn_model(x)

  def forward(self, x, indices_and_rois):
    """Classify and refine the given regions of interest.

    Args:
      x: feature map of shape (N, 512, H, W).
      indices_and_rois: integer tensor of shape (R, 5); every row is
        (image_index, x1, y1, x2, y2) with the box in image pixels.

    Returns:
      (pred_roi_cls_score, pred_roi_loc) of shapes (R, 21) and (R, 84).
    """
    feat_h, feat_w = x.shape[-2], x.shape[-1]
    pooled = []

    for roi in indices_and_rois:

      image_index, x1, y1, x2, y2 = (int(value) for value in roi)

      # Image pixels -> feature-map cells (the feature map is 16x smaller than the image).
      x1, y1, x2, y2 = x1 // 16, y1 // 16, x2 // 16, y2 // 16

      # Keep the crop inside the map and at least one cell wide/high, so that
      # degenerate or inverted boxes cannot produce an empty tensor.
      x1 = min(max(x1, 0), feat_w - 1)
      y1 = min(max(y1, 0), feat_h - 1)
      x2 = min(max(x2, x1), feat_w - 1)
      y2 = min(max(y2, y1), feat_h - 1)

      # Crop from the image the ROI belongs to, not always from the first one.
      im = x.narrow(0, image_index, 1)[..., y1:y2+1, x1:x2+1]

      pooled.append(self.adaptive_max_pool(im)) # (1, 512, 7, 7)

    if len(pooled) == 0:
      output = x.new_zeros((0, x.shape[1] * 7 * 7))
    else:
      output = torch.cat(pooled, 0)
      # Flatten per ROI; the number of ROIs is taken from the input, not assumed to be 128.
      output = output.view(output.size(0), -1) # (R, 25088)

    output = self.roi_head_classifier(output)
    pred_roi_cls_score = self.score(output)
    pred_roi_loc = self.cls_loc(output)

    return pred_roi_cls_score, pred_roi_loc

In [34]:
# n_anchors = len(scales) * len(ratios)
# n_anchors = 9
# rpn = RPN(n_anchors).to(device)

In [35]:
# UNIT Test
# assert (rpn.feature_extractor(torch.randn(4, 3, 800, 800)).shape) == (4, 512, 50, 50), "something wrong with feature extractor output shape!"

## NMS

In [36]:
import numpy as np
from tqdm.auto import tqdm

def NMS(boxes, iou_threshold=0.7, max_boxes=2000):
  """Greedy non-maximum suppression.

  Args:
    boxes: array-like of shape (N, 4) in (x1, y1, x2, y2) format, already
      sorted by descending objectness score.
    iou_threshold: a remaining box is discarded when its IoU with the box
      just selected is greater than this value.
    max_boxes: maximum number of boxes to return.

  Returns:
    List with at most `max_boxes` selected boxes, in selection order.
  """
  boxes_detected = []
  boxes_remained = np.copy(boxes)

  while len(boxes_remained) > 0 and len(boxes_detected) < max_boxes:

    box_detected = boxes_remained[0]
    boxes_detected.append(box_detected)

    rest = boxes_remained[1:]

    # IoU of the selected box against all remaining boxes at once.
    inter_w = np.minimum(box_detected[2], rest[:, 2]) - np.maximum(box_detected[0], rest[:, 0])
    inter_h = np.minimum(box_detected[3], rest[:, 3]) - np.maximum(box_detected[1], rest[:, 1])
    inter_area = np.where((inter_w > 0) & (inter_h > 0), inter_w * inter_h, 0.0)

    area_detected = (box_detected[2] - box_detected[0]) * (box_detected[3] - box_detected[1])
    area_rest = (rest[:, 2] - rest[:, 0]) * (rest[:, 3] - rest[:, 1])
    union_area = area_detected + area_rest - inter_area

    # Guard against zero-area unions (degenerate boxes): their IoU is defined as 0.
    has_area = union_area > 0
    iou = np.where(has_area, inter_area / np.where(has_area, union_area, 1.0), 0.0)

    # Actually drop the boxes that overlap the selected one too much.
    boxes_remained = rest[iou <= iou_threshold]

  return boxes_detected
  

## Dataset

In [37]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tqdm.auto import tqdm
import glob
import os

In [38]:
# Precompute all anchor boxes once; the dataset and the decoding helpers read this module-level array.
anchor_boxes = generate_all_anchor_boxes()

In [39]:
class VOS(Dataset):
  """Dataset yielding resized images together with their RPN training targets.

  Images are looked up under `<root>/JPEGImages/<folder>/<name>.jpg` and their
  masks under `<root>/Annotations/<folder>/<name>.png`. Only the first 2000
  globbed images are considered, and images whose mask contains a single pixel
  value (no object) are dropped.
  """

  def __init__(self, root):
    self.root = root
    self.images = np.array(glob.glob(os.path.join(root, "JPEGImages/*/*.jpg")))[:2000]

    print("Removing images with no GT-box")
    idx_to_keep = [
      idx
      for idx, image_path in enumerate(tqdm(self.images))
      if get_number_of_instances(self._label_path(image_path)) > 1
    ]

    self.images = self.images[idx_to_keep]


  def _label_path(self, image_path):
    """Return the annotation path that corresponds to an image path."""
    image_name = image_path.split("/")[-1].split(".")[0]
    folder_id = image_path.split("/")[-2]
    return os.path.join(self.root, "Annotations", folder_id, image_name + ".png")


  @staticmethod
  def _ignore_random(label, candidates, n_drop):
    """Set `n_drop` randomly chosen entries of `candidates` to -1 (ignored) in `label`."""
    if n_drop > 0:
      label[np.random.choice(candidates, size=n_drop, replace=False)] = -1


  def __len__(self):
    return len(self.images)


  def __getitem__(self, idx):
    """Return (image_path, label_path, image_tensor, label, loc).

    label: (n_anchors, 1) array with 1 = positive, 0 = negative, -1 = ignored;
      at most `n_sample` anchors are left non-ignored.
    loc: (n_anchors, 4) regression targets; filled only for the anchors that
      are labelled positive (the only ones a regression loss uses), zero elsewhere.
    """
    image_path = self.images[idx]
    label_path = self._label_path(image_path)

    iou = get_anchors_iou(image_path, label_path, anchor_boxes)

    highest_iou = np.argmax(iou, axis=-1) # only GT-box with highest IoU is considered
    gt_boxes = get_bounding_box(label_path)

    label = iou_to_label(iou)

    # ignore anchor boxes outside the image
    valid_indexes = get_valid_anchor_boxes(anchor_boxes)
    label[~valid_indexes] = -1

    # sampling positive and negative anchor boxes with ratio 1:1 (if possible)
    max_pos = int(n_pos)
    pos_index = np.where(label == 1)[0]
    self._ignore_random(label, pos_index, len(pos_index) - max_pos)

    # Negatives fill the rest of the mini-batch. When there are fewer negatives
    # than free slots nothing is dropped (the excess would otherwise be negative).
    neg_index = np.where(label == 0)[0]
    pos_count = int(np.sum(label == 1))
    self._ignore_random(label, neg_index, len(neg_index) - (n_sample - pos_count))

    # for each positive anchor box we have (t_x, t_y, t_w, t_h) towards the
    # ground truth box it overlaps the most
    loc = np.zeros((len(anchor_boxes), 4))
    for anchor_idx in np.where(label == 1)[0]:
      anchor_center = corner2center(anchor_boxes[anchor_idx])
      gt_center = corner2center(gt_boxes[highest_iou[anchor_idx]])
      loc[anchor_idx] = get_parameterized_target(anchor_center, gt_center)

    # Force 3 channels so grayscale or RGBA files still match the backbone input.
    image_file = Image.open(image_path).convert("RGB")
    image_file = image_file.resize((800, 800))
    image_tensor = transforms.ToTensor()((np.array(image_file)))

    return image_path, label_path, image_tensor, label, loc

In [40]:
# Build the dataset (images without any object are dropped in VOS.__init__) and wrap it in a shuffling loader.
vos = VOS("/content/vos/train")
train_loader = DataLoader(vos, batch_size=batch_size, shuffle=True)
len(train_loader)  # number of batches per epoch

Removing images with no GT-box


  0%|          | 0/2000 [00:00<?, ?it/s]

1994

In [41]:
# vos = VOS("/content/vos/train")
# image_file, label, loc = vos[0]

# # Sanity Check
# image_file.shape, label.shape, loc.shape
# print((label == 0).sum())
# print((label == 1).sum())

# # UNIT TEST
# (label == 0).sum() + (label == 1).sum() == n_sample

In [42]:
# image_tensor, label, loc = next(iter(train_loader))
# image_tensor.shape, label.shape, loc.shape

## Post Process

$$ t_x = \frac{A_x - G_x}{A_w} $$
\
$$ t_w = \log{\frac{G_w}{A_w}} $$

In [43]:
import numpy as np

In [44]:
from tqdm.auto import tqdm

def get_pred_loc_corner(pred_loc):
  """Decode predicted offsets into corner-format boxes.

  Row `i` of `pred_loc` holds (t_x, t_y, t_w, t_h) and is decoded against
  `anchor_boxes[i]`, so the rows must follow the anchor ordering.

  Args:
    pred_loc: tensor of per-anchor offsets.

  Returns:
    Numpy array with the same shape and dtype as the detached input, holding
    (x1, y1, x2, y2) boxes.
  """
  offsets = pred_loc.detach().cpu().numpy()
  decoded_boxes = np.zeros_like(offsets)

  for idx in range(len(offsets)):
    t_x, t_y, t_w, t_h = offsets[idx]
    anchor_center = corner2center(anchor_boxes[idx])
    decoded_boxes[idx] = parametrized_to_corner(t_x, t_y, t_w, t_h, anchor_center)

  return decoded_boxes


In [45]:
# UNIT TEST
# box1 = corner2center([0., 0., 5., 5.])
# box2 = corner2center([2.5, 2.5, 7.5, 7.5])
# (get_pred_loc_corner(torch.tensor([[-1/2, -1/2, 0, 0]])))
# assert (get_pred_loc_corner(torch.tensor([[-1/2, -1/2, 0, 0]]))) == [[  8.,   8.,  72., 136.]]
# print("pass!")

## Filter By Size

In [None]:
def filter_by_size(pred_loc, min_size=16):
  """Keep only the boxes that are wider and taller than `min_size`.

  Args:
    pred_loc: array of shape (N, 4) in (x1, y1, x2, y2) format.
    min_size: strict lower bound, in pixels, for both width and height.

  Returns:
    The rows of `pred_loc` that pass the test. Row positions change, so
    per-box data held elsewhere (e.g. scores) is no longer index-aligned
    with the result.
  """
  widths = pred_loc[:, 2] - pred_loc[:, 0]
  heights = pred_loc[:, 3] - pred_loc[:, 1]

  filtered_indexes = (widths > min_size) & (heights > min_size)

  return pred_loc[filtered_indexes]

In [None]:
# UNIT TEST
# filter_by_size(np.random.rand(22500, 4) * 100).shape

## Clip to the Image

In [46]:
def clip_into_image(boxes, image_size=800, margin=5):
  """Clip box coordinates into [margin, image_size - margin], in place.

  Args:
    boxes: 2-D array of (x1, y1, x2, y2) rows; it is modified in place.
    image_size: side length of the (square) image in pixels.
    margin: distance kept from every image border.

  Returns:
    The same `boxes` object, after clipping.
  """
  low, high = 0 + margin, image_size - margin

  # Each coordinate is clipped from its own column (y2 used to be overwritten with x2).
  for column in range(4):
    boxes[:, column] = np.clip(boxes[:, column], low, high)

  return boxes

In [47]:
# UNIT TEST
# Smoke test: clipping must keep the (2000, 4) shape of its input.
boxes = torch.randn((2000, 4))
clip_into_image(boxes).shape

torch.Size([2000, 4])

## Sort Scores

In [48]:
def get_high_score(objectness_score, pred_loc, top_n=12000):
  """Return the boxes with the highest objectness scores, best first.

  Args:
    objectness_score: tensor with one score per box (any shape; it is flattened).
    pred_loc: boxes of shape (N, 4), index-aligned with the scores; either a
      numpy array or a tensor.
    top_n: maximum number of boxes to keep.

  Returns:
    The `top_n` highest-scoring rows of `pred_loc`, in descending score order.
  """
  scores = objectness_score.detach().cpu().reshape(-1)
  order = scores.argsort(descending=True)[:top_n]

  # Index numpy arrays with numpy indices rather than with a torch tensor.
  if not torch.is_tensor(pred_loc):
    order = order.numpy()

  pred_ordered = pred_loc[order, :]

  return pred_ordered

In [49]:
# UNIT TEST
# pred_loc = torch.randn((22500, 4))
# objectness_score = torch.randn(len(pred_loc))

# get_high_score(objectness_score, pred_loc).shape

## Sampling ROIs

In [102]:
import numpy as np
import torch

In [103]:
def get_roi_sample(image_path, label_path, ROIs):
  """Sample up to 32 positive and 96 negative regions of interest.

  A ROI is a positive candidate when its best IoU with a ground-truth box is
  above 0.5, and a negative candidate when that IoU lies strictly between 0.1
  and 0.5.

  Args:
    image_path: forwarded to `get_anchors_iou`.
    label_path: mask file providing the ground-truth boxes.
    ROIs: array of shape (N, 4) in (x1, y1, x2, y2) format.

  Returns:
    (labels, sampled_rois, gt_index), all with one entry per sampled ROI:
    labels is 1 for positives and 0 for negatives (positives come first),
    gt_index is the index of the best-matching ground-truth box. Fewer than
    128 ROIs are returned when there are not enough candidates.
  """
  max_positive, max_negative = 32, 96

  anchors_iou = get_anchors_iou(image_path, label_path, ROIs) # (N, n_gt)

  if anchors_iou.size == 0:
    # No ROI or no ground-truth box: nothing can be sampled.
    nothing = np.zeros(0, dtype=int)
    return np.zeros(0), ROIs[nothing], nothing

  max_iou = np.max(anchors_iou, axis=-1)
  gt_index = np.argmax(anchors_iou, axis=-1)

  pos_candidates = np.where(max_iou > 0.5)[0]
  neg_candidates = np.where((max_iou > 0.1) & (max_iou < 0.5))[0]

  def pick(candidates, limit):
    count = min(len(candidates), limit)
    if count == 0:
      return candidates[:0]
    return np.random.choice(candidates, size=count, replace=False)

  positive_sample = pick(pos_candidates, max_positive)
  negative_sample = pick(neg_candidates, max_negative)

  sample_index = np.concatenate((positive_sample, negative_sample)).astype(int)

  # Labels follow the actual sample sizes instead of assuming 32 + 96.
  labels = np.concatenate((np.ones(len(positive_sample)), np.zeros(len(negative_sample))))

  return labels, ROIs[sample_index], gt_index[sample_index]

In [104]:
# UNIT TEST 
# get_roi_sample("/content/image.jpg", "/content/label.png", torch.randn(2000, 4))

## Training

In [106]:
# Instantiate the detector (RPN + ROI head) and move all of its parameters to `device`.
faster_rcnn = Faster_RCNN().to(device)

In [107]:
# One Adam optimizer over every Faster_RCNN parameter, which includes the RPN submodule (rpn_model).
faster_rcnn_optim = torch.optim.Adam(faster_rcnn.parameters(), lr=0.001)

In [None]:
regression_loss_acc = []
classification_loss_acc = []

for iteration, (image_path, mask_path, image, cls, loc) in enumerate(tqdm(train_loader)):
  # image in shape of (1, 3, 800, 800)
  # cls in shape of (1, 22500, 1)
  # loc in shape of (1, 22500, 4)
  # batch_size is 1: unwrap the single path / target set of this batch.
  image_path = image_path[0]
  mask_path = mask_path[0]

  cls = cls[0].reshape(-1).long().to(device)
  loc = loc[0].float().to(device)
  image = image.to(device)

  # The backbone is frozen (its parameters are not in the optimizer), so no graph is needed.
  with torch.no_grad():
    output_map = feature_extractor(image)

  faster_rcnn_optim.zero_grad()

  # expected shapes: (1, 9 * 2, 50, 50) and (1, 9 * 4, 50, 50)
  pred_cls, pred_loc = faster_rcnn.rpn(output_map)

  # Move the channel axis last BEFORE flattening so that every row holds the
  # values of one anchor. The spatial axes are ordered (W, H) because
  # anchor_boxes is enumerated x-major (first index = center_X).
  pred_loc = pred_loc.permute(0, 3, 2, 1).contiguous().view(-1, 4)  # (22500, 4)
  pred_cls = pred_cls.permute(0, 3, 2, 1).contiguous().view(-1, 2)  # (22500, 2)
  objectness_score = pred_cls[:, 1].detach().cpu()

  # ---- proposals -----------------------------------------------------------
  pred_loc_corner = get_pred_loc_corner(pred_loc) # (22500, 4) corner format

  # Filter small boxes with an explicit mask so the scores stay aligned with the boxes.
  box_widths = pred_loc_corner[:, 2] - pred_loc_corner[:, 0]
  box_heights = pred_loc_corner[:, 3] - pred_loc_corner[:, 1]
  size_ok = (box_widths > 16) & (box_heights > 16)

  proposals = clip_into_image(pred_loc_corner[size_ok]) # boolean indexing copies: pred_loc_corner is untouched
  proposal_scores = objectness_score[torch.from_numpy(size_ok)]

  pred_high_loc_corner = get_high_score(proposal_scores, proposals) # (<= 12000, 4)
  # pred_nms_loc_corner = NMS(pred_high_loc_corner) # (2000, 4)
  roi_labels, roi_loc, gt_index = get_roi_sample(image_path, mask_path, pred_high_loc_corner)

  n_roi = len(roi_loc)
  # Never keep more labels than sampled ROIs.
  roi_labels = np.asarray(roi_labels)[:n_roi]

  # ---- RPN losses ----------------------------------------------------------
  classification_loss = F.cross_entropy(pred_cls, cls, ignore_index=-1)
  classification_loss_acc.append(classification_loss.item())

  pos_mask = cls == 1
  n_rpn_pos = int(pos_mask.sum().item())

  # Smooth-L1 over positive anchors, normalised by their count
  # (len() of the torch.where tuple was always 1).
  regression_loss = F.smooth_l1_loss(pred_loc[pos_mask], loc[pos_mask], reduction="sum")
  regression_loss = regression_loss / max(n_rpn_pos, 1)
  regression_loss_acc.append(regression_loss.item())

  rpn_loss = classification_loss + (10 * regression_loss)
  loss = rpn_loss

  # ---- ROI head losses -----------------------------------------------------
  if n_roi > 0:
    gt_boxes = get_bounding_box(mask_path)

    # Regression targets are only needed for positive ROIs; both boxes must be
    # converted to center format before encoding.
    positive_rois = np.flatnonzero(roi_labels == 1)
    roi_loc_parametrized = np.zeros((n_roi, 4), dtype=np.float32)
    for idx in positive_rois:
      roi_center = corner2center(roi_loc[idx])
      gt_center = corner2center(gt_boxes[gt_index[idx]])
      roi_loc_parametrized[idx] = get_parameterized_target(roi_center, gt_center)

    image_index = torch.zeros((n_roi, 1))
    indices_and_rois = torch.cat((image_index, torch.from_numpy(roi_loc).float()), dim=1)

    pred_roi_cls_score, pred_roi_loc = faster_rcnn(output_map, indices_and_rois.long())

    roi_labels = torch.from_numpy(roi_labels).long().to(device)
    roi_loc_parametrized = torch.from_numpy(roi_loc_parametrized).to(device)

    roi_cls_loss = F.cross_entropy(pred_roi_cls_score, roi_labels, ignore_index=-1)

    # Use the regressor of class 1 (object) for every ROI.
    pred_roi_loc = pred_roi_loc.view(-1, 21, 4)[:, 1] # (n_roi, 4)

    roi_pos_mask = roi_labels == 1
    roi_reg_loss = F.smooth_l1_loss(pred_roi_loc[roi_pos_mask], roi_loc_parametrized[roi_pos_mask], reduction="sum")
    roi_reg_loss = roi_reg_loss / max(len(positive_rois), 1)

    faster_rcnn_loss = roi_cls_loss + (10 * roi_reg_loss)

    loss = rpn_loss + faster_rcnn_loss

  loss.backward()

  faster_rcnn_optim.step()

  if iteration % 100 == 0:

    # visualize the highest-scoring boxes (descending objectness)
    top_index = objectness_score.argsort(descending=True)[:12000].numpy()

    # pred_loc_corner is already decoded against the matching anchors.
    predicted_boxes = clip_into_image(pred_loc_corner[top_index])
    detected_boxes = NMS(predicted_boxes)
    print(len(detected_boxes))

    # image_path / mask_path were unwrapped to plain strings at the top of the loop.
    visualize(image_path, detected_boxes, get_bounding_box(mask_path))

    print("classification loss: ", np.mean(classification_loss_acc[-100:]))
    print("regression loss:", np.mean(regression_loss_acc[-100:]))