In [1]:
import os
import json
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch import nn
from torchvision.ops import box_convert
from PIL import Image, ImageDraw
from transformers import OwlViTProcessor, OwlViTForObjectDetection

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1553d46ab850>

In [3]:
from src.dataset import get_dataloaders

In [4]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

In [5]:
train_dataloader, test_dataloader = get_dataloaders(4)
batch_1 = next(iter(train_dataloader))
batch_1[0].keys()

dict_keys(['input_ids', 'attention_mask', 'pixel_values'])

In [6]:
batch_1[1]

tensor([[ 675.3600,  326.0500, 1007.0600, 1052.7300],
        [ 857.3300, 1010.7800,  185.2100,  285.7000],
        [ 716.8300,  750.2000,  116.9500,   76.9800],
        [1429.7600,  799.2400,  895.1000,  174.8000]])

In [7]:
batch_1[2]

{'width': tensor([2560, 1920, 1920, 2560]),
 'height': tensor([1920, 1440, 1080, 1920]),
 'impath': ['/scratch/hk3820/capstone/data/paco_frames/v1/paco_frames/81cee65a-afe3-4dc2-a31e-3b67b062bf35_007471.jpeg',
  '/scratch/hk3820/capstone/data/paco_frames/v1/paco_frames/69c9d98e-c125-4d24-b180-aea768ef900a_008159.jpeg',
  '/scratch/hk3820/capstone/data/paco_frames/v1/paco_frames/3efc152d-ea0e-4372-b552-7d5e1cf07259_386360.jpeg',
  '/scratch/hk3820/capstone/data/paco_frames/v1/paco_frames/a723d89c-78b7-4325-b18e-b8a4436a27ca_020117.jpeg']}

In [8]:
batch_1[0]['input_ids'] = batch_1[0]['input_ids'].view(-1,16)
batch_1[0]['attention_mask'] = batch_1[0]['attention_mask'].view(-1,16)

In [9]:
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32") # Image Processor + Text Tokenizer
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")
model = model.to(device)

In [10]:
# class PatchedOwlVit(OwlViTForObjectDetection):
#     def __init__(self):
#         super().__init__()
        
#     def forward():
        

In [11]:
outputs = model(**batch_1[0].to(device))
outputs.keys()

odict_keys(['logits', 'pred_boxes', 'text_embeds', 'image_embeds', 'class_embeds', 'text_model_output', 'vision_model_output'])

In [12]:
logits = outputs["logits"] # (B,N,C)
logits.shape

torch.Size([4, 576, 3])

In [13]:
pred_boxes = outputs["pred_boxes"] # (B,N,4)
pred_boxes.shape # Pred boxes in resized img

torch.Size([4, 576, 4])

In [14]:
# target_sizes = torch.cat([batch_1[2]['height'].view(-1,1), batch_1[2]['width'].view(-1,1)], dim=1)
# target_sizes = target_sizes.to(device)
# results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=0.1)
# results

## Focal Loss

In [15]:
def pos_neg_focal_loss(logits, focal_alpha = 0.5, focal_gamma = 2):
    p = nn.functional.sigmoid(logits) # P(y=1)
    pos_class_losses = -focal_alpha * torch.pow(1-p, focal_gamma) * torch.log(p) # Loss if y=1
    neg_class_losses = -(1-focal_alpha) * torch.pow(p, focal_gamma) * torch.log(1-p) # Loss if y=0
    return pos_class_losses, neg_class_losses

In [16]:
pos_class_losses, neg_class_losses =  pos_neg_focal_loss(logits, focal_alpha = 0.5, focal_gamma = 2)
pos_class_losses.shape # (B,N,C)

torch.Size([4, 576, 3])

In [17]:
batch_size = pos_class_losses.shape[0]
num_queries = pos_class_losses.shape[-1]

# Target Labels -- 1st query Pos, rest Neg
target_labels = nn.functional.one_hot(torch.zeros(1).to(torch.int64), num_classes=num_queries).to(device)
target_labels = target_labels.repeat(batch_size,1,1)
target_labels.shape # (B,M,C)

torch.Size([4, 1, 3])

In [18]:
def contrastive_focal_loss(pos_class_losses, neg_class_losses, target_labels):
    pos_class_loss = torch.einsum('bnc,bmc->bnm', pos_class_losses, target_labels.to(torch.float32)) # Sum of losses for pos queries
    pos_class_loss = pos_class_loss/target_labels.sum(dim=-1).unsqueeze(-1) # Scale by num pos queries
    neg_class_loss = torch.einsum('bnc,bmc->bnm', neg_class_losses, (1-target_labels.to(torch.float32))) # Sum of losses for neg queries
    neg_class_loss = neg_class_loss/(1-target_labels).sum(dim=-1).unsqueeze(-1) # Scale by num neg queries
    focal_loss = pos_class_loss + neg_class_loss
    return focal_loss

In [19]:
# def contrastive_focal_loss(pos_class_losses, neg_class_losses, target_labels):
#     pos_class_weights = target_labels/target_labels.sum(dim=-1).unsqueeze(-1) # Indicator for +ve queries scaled by num_pos_queries
#     neg_class_weights = (1-target_labels)/(1-target_labels).sum(dim=-1).unsqueeze(-1) # Indicator for -ve queries scaled by num_neg_queries
#     pos_class_loss = torch.einsum('bnc,bmc->bnm', pos_class_losses, pos_class_weights) # Avg loss for pos queries
#     neg_class_loss = torch.einsum('bnc,bmc->bnm', neg_class_losses, neg_class_weights) # Avg loss for neg queries
#     focal_loss = pos_class_loss + neg_class_loss
#     return focal_loss

In [20]:
focal_loss = contrastive_focal_loss(pos_class_losses, neg_class_losses, target_labels)
focal_loss.shape # (B,N,M)

torch.Size([4, 576, 1])

In [21]:
# l = focal_loss.mean()
# l.backward(retain_graph=True)

## BBox loss

In [22]:
pred_boxes = box_convert(pred_boxes, "cxcywh", "xyxy")

In [23]:
target_boxes = box_convert(batch_1[1], "xywh", "xyxy")
target_boxes = target_boxes[:,None,:] # (B,4) -> (B,M=1,4)
target_boxes = target_boxes.to(device)
target_boxes.shape

torch.Size([4, 1, 4])

In [24]:
coord_dists = torch.abs(pred_boxes[:, :, None] - target_boxes[:, None, :])  # [B, N, M, 4]
bbox_loss = torch.sum(coord_dists, axis=-1)  # [B, N, M]
bbox_loss.shape

torch.Size([4, 576, 1])

In [25]:
# l = bbox_loss.mean()
# l.backward(retain_graph=True)

## GIoU loss

In [26]:
# https://github.com/google-research/scenic/blob/main/scenic/model_lib/base_models/box_utils.py

In [27]:
pred_boxes.shape, target_boxes.shape

(torch.Size([4, 576, 4]), torch.Size([4, 1, 4]))

In [28]:
def box_iou(boxes1, boxes2, eps = 1e-6):
    """Computes IoU between two sets of boxes.

    Boxes are in [x, y, x', y'] format [x, y] is top-left, [x', y'] is bottom right.

    Args:
        boxes1: Predicted bounding-boxes in shape [bs, n, 4].
        boxes2: Target bounding-boxes in shape [bs, m, 4].
        eps: Epsilon for numerical stability.

    Returns:
        Pairwise IoU cost matrix of shape [bs, n, m].
    """
    # First, compute box areas. These will be used later for computing the union.
    wh1 = boxes1[..., 2:] - boxes1[..., :2] # W & H of box1
    area1 = wh1[..., 0] * wh1[..., 1]  # [bs, n]

    wh2 = boxes2[..., 2:] - boxes2[..., :2]
    area2 = wh2[..., 0] * wh2[..., 1]  # [bs, m]

    # Compute pairwise top-left and bottom-right corners of the intersection of the boxes.
    lt = torch.maximum(boxes1[..., :, None, :2], boxes2[..., None, :, :2])  # [bs, n, m, 2].
    rb = torch.minimum(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:])  # [bs, n, m, 2].

    # intersection = area of the box defined by [lt, rb]
    wh = (rb - lt).clip(0.0)  # [bs, n, m, 2]
    intersection = wh[..., 0] * wh[..., 1]  # [bs, n, m]

    # union = sum of areas - intersection
    union = area1[..., :, None] + area2[..., None, :] - intersection

    iou = intersection / (union + eps)
    return iou, union  # pytype: disable=bad-return-type  # jax-ndarray

In [29]:
def generalized_box_iou(boxes1, boxes2, eps = 1e-6):
    """Generalized IoU from https://giou.stanford.edu/.

    The boxes should be in [x, y, x', y'] format specifying top-left and bottom-right corners.

    Args:
        boxes1: Predicted bounding-boxes in shape [..., N, 4].
        boxes2: Target bounding-boxes in shape [..., M, 4].
        eps: Epsilon for numerical stability.

    Returns:
        A [bs, n, m] pairwise matrix, of generalized ious.
    """
    # Degenerate boxes gives inf / nan results, so do an early check.
    assert (boxes1[:, :, 2:] >= boxes1[:, :, :2]).all()
    assert (boxes2[:, :, 2:] >= boxes2[:, :, :2]).all()
    
    iou, union = box_iou(boxes1, boxes2, eps=eps)

    # Generalized IoU has an extra term which takes into account the area of
    # the box containing both of these boxes. The following code is very similar
    # to that for computing intersection but the min and max are flipped.
    lt = torch.minimum(boxes1[..., :, None, :2], boxes2[..., None, :, :2])  # [bs, n, m, 2]
    rb = torch.maximum(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:])  # [bs, n, m, 2]

    # Now, compute the covering box's area.
    wh = (rb - lt).clip(0.0)  # Either [bs, n, 2] or [bs, n, m, 2].
    area = wh[..., 0] * wh[..., 1]  # Either [bs, n] or [bs, n, m].

    # Finally, compute generalized IoU from IoU, union, and area.
    # Somehow the PyTorch implementation does not use eps to avoid 1/0 cases.
    return iou - (area - union) / (area + eps)

In [30]:
iou, union = box_iou(pred_boxes, target_boxes)
iou.shape, union.shape # (B,N,M)

(torch.Size([4, 576, 1]), torch.Size([4, 576, 1]))

In [31]:
giou_loss = generalized_box_iou(pred_boxes, target_boxes)
giou_loss.shape  # (B,N,M)

torch.Size([4, 576, 1])

In [32]:
# l = giou_loss.mean()
# l.backward(retain_graph=True)

## Total loss

In [33]:
focal_loss_coef = 1.0/3
bbox_loss_coef = 1.0/3
giou_loss_coef = 1.0/3

total_loss = focal_loss_coef * focal_loss + bbox_loss_coef * bbox_loss + giou_loss_coef + giou_loss
total_loss.shape

torch.Size([4, 576, 1])

In [34]:
# l = total_loss.mean()
# l.backward(retain_graph=True)

## Hungarian Matching

In [35]:
from src.DETR.matcher import HungarianMatcher

In [36]:
from src.utils import paco_to_owl_box

In [37]:
matcher = HungarianMatcher()

In [38]:
outputs_for_matcher = {
    "pred_logits": outputs.logits.to(device),
    "pred_boxes": outputs.pred_boxes.to(device)
}

In [39]:
targets = [{"labels": torch.tensor([0]).to(device), "boxes":box.to(device)} for box in paco_to_owl_box(batch_1[1][:, None, :], batch_1[2])]


In [40]:
targets

[{'labels': tensor([0], device='cuda:0'),
  'boxes': tensor([[202.6080, 130.4200, 504.7260, 551.5121]], device='cuda:0')},
 {'labels': tensor([0], device='cuda:0'),
  'boxes': tensor([[342.9320, 539.0827, 417.0161, 691.4561]], device='cuda:0')},
 {'labels': tensor([0], device='cuda:0'),
  'boxes': tensor([[286.7320, 533.4755, 333.5121, 588.2169]], device='cuda:0')},
 {'labels': tensor([0], device='cuda:0'),
  'boxes': tensor([[428.9280, 319.6960, 697.4580, 389.6160]], device='cuda:0')}]

In [41]:
matches = matcher(outputs_for_matcher, targets)
matches = torch.tensor(matches)
matches

tensor([[ 34,   0],
        [545,   0],
        [ 45,   0],
        [519,   0]])

In [42]:
total_loss.shape

torch.Size([4, 576, 1])

In [43]:
total_matched_loss = total_loss[torch.arange(batch_size), matches[:,0], matches[:,1]]
mean_loss = total_matched_loss.mean()
mean_loss

tensor(1410.3622, device='cuda:0', grad_fn=<MeanBackward0>)

In [44]:
mean_loss.backward(retain_graph=True)