# Prepare data for box-level loss estimation

In [1]:
import argparse
import datetime
import json
import random
import time
from pathlib import Path
import os, sys
import numpy as np

import torch
from datasets import build_dataset
from util.slconfig import SLConfig
from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
import torch.nn.functional as F

In [49]:
def read_one_image_results(path):
    """Load the JSON document stored at ``path`` and return the parsed object."""
    with open(path, "r") as handle:
        return json.loads(handle.read())

def write_one_results(path, json_data):
    """Serialize ``json_data`` as JSON into the file at ``path``, overwriting it."""
    with open(path, "w") as sink:
        json.dump(json_data, sink)
        
def transform_tensor_to_list(l):
    """Move the tensor ``l`` to the CPU and return its content as nested Python lists."""
    host_copy = l.cpu()
    return host_copy.tolist()

def transform_tensors_to_list(l):
    """Recursively replace every tensor found in ``l`` by plain Python lists.

    Tensors are converted with ``transform_tensor_to_list``; lists and dicts
    are rebuilt with their elements / values converted recursively. Any other
    object (including tuples) is returned unchanged.
    """
    if torch.is_tensor(l):
        return transform_tensor_to_list(l)
    if isinstance(l, list):
        return [transform_tensors_to_list(item) for item in l]
    if isinstance(l, dict):
        return {key: transform_tensors_to_list(value) for key, value in l.items()}
    return l

def generate_one_image_results(path):
    """Select the confident (query, class) predictions stored for one image.

    The JSON file at ``path`` is expected to hold ``input.pred_logits`` and
    ``input.pred_boxes``; a leading size-1 dimension is squeezed away from both.
    Every (query, class) pair whose sigmoid score exceeds the module-level
    ``score_threshold`` is selected. If no pair qualifies, the single
    highest-scoring pair is selected instead, so the result is never empty.

    Returns:
        selected_index: query index of each selected pair (a query appears
            once per class that passed the threshold).
        score: sigmoid score of each selected pair.
        labels: class index of each selected pair, shape [k, 1].
        out_logits: logits row of the query behind each selected pair.
        out_boxes: box of the query behind each selected pair.
    """
    stored = read_one_image_results(path)['input']
    pred_logits = torch.FloatTensor(stored['pred_logits']).squeeze(dim=0)
    pred_boxes = torch.FloatTensor(stored['pred_boxes']).squeeze(dim=0)
    prob = pred_logits.sigmoid()
    num_classes = prob.shape[1]

    select_mask = prob > score_threshold
    if not select_mask.any():
        # Nothing passed the threshold: fall back to the best-scoring pair.
        best = prob.argmax()
        best_query = torch.div(best, num_classes, rounding_mode='floor')
        best_class = best % num_classes
        select_mask[best_query, best_class] = True

    # Flat (row-major) positions of the selected pairs, shape [k, 1].
    flat_hits = torch.nonzero(select_mask.reshape(-1))
    selected_index = torch.div(flat_hits, num_classes, rounding_mode='floor').squeeze(dim=1)
    labels = flat_hits % num_classes
    score = prob[select_mask]
    return selected_index, score, labels, pred_logits[selected_index], pred_boxes[selected_index]

def hungarian_matching(out_logits, out_boxes, targets, cost_class = 2.0, cost_bbox = 5.0, cost_giou = 2.0, focal_alpha = 0.25):
    """Assign each prediction to the ground-truth box with the lowest matching cost.

    Despite the name, no one-to-one (Hungarian) assignment is solved here: every
    prediction independently takes the argmin over targets, so several
    predictions may share the same target.

    Args:
        out_logits: class logits of the predictions, [num_queries, num_classes].
        out_boxes: predicted boxes in cxcywh format, one row per prediction.
        targets: dict with "labels" (target class ids) and "boxes" (target
            boxes in cxcywh format).
        cost_class: weight of the focal-style classification cost.
        cost_bbox: weight of the L1 box cost.
        cost_giou: weight of the generalized-IoU cost.
        focal_alpha: alpha used in the focal-style classification cost.

    Returns:
        A tensor with one target index per prediction, or None when the image
        has no targets.
    """
    if targets["boxes"] is None or targets["labels"].shape[0] == 0:
        return None

    num_queries = out_logits.shape[0]
    out_prob = out_logits.sigmoid()  # [num_queries, num_classes]

    tgt_ids = targets["labels"]
    tgt_bbox = targets["boxes"]

    # Compute the classification cost.
    alpha = focal_alpha
    gamma = 2.0
    neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
    pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
    # The cost matrices get their own names: reusing the weight parameter names
    # would shadow the weights and make the final sum square each cost matrix.
    class_cost_matrix = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]

    # Compute the L1 cost between boxes
    bbox_cost_matrix = torch.cdist(out_boxes, tgt_bbox, p=1)

    # Compute the giou cost between boxes (negated: higher overlap => lower cost)
    giou_cost_matrix = -generalized_box_iou(box_cxcywh_to_xyxy(out_boxes), box_cxcywh_to_xyxy(tgt_bbox))

    # Final cost matrix: weighted sum of the three cost terms
    C = cost_bbox * bbox_cost_matrix + cost_class * class_cost_matrix + cost_giou * giou_cost_matrix
    C = C.view(num_queries, -1)
    return torch.argmin(C, dim=1)


def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    """
    Element-wise sigmoid focal loss (RetinaNet, https://arxiv.org/abs/1708.02002).
    Args:
        inputs: A float tensor of arbitrary shape holding the raw logits.
        targets: A float tensor with the same shape as inputs holding the
                 binary label of each element (0 = negative, 1 = positive).
        alpha: Weighting factor balancing positive vs negative examples.
                Default = 0.25; a negative value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) that balances
               easy vs hard examples. Default = 2.
    Returns:
        Loss tensor with the same shape as inputs (no reduction is applied).
    """
    probabilities = inputs.sigmoid()
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # Probability the model assigns to the true label of each element.
    p_true = probabilities * targets + (1 - probabilities) * (1 - targets)
    focal = bce * ((1 - p_true) ** gamma)

    if alpha >= 0:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        return class_weight * focal
    return focal

def compute_loss(out_logits, out_boxes, targets, matched_target_indexes, cls_loss_coef = 1.0, bbox_loss_coef = 5.0, giou_loss_coef = 2.0):
    """Compute the per-prediction loss against the matched ground-truth boxes.

    Args:
        out_logits: class logits of the predictions, one row per prediction.
        out_boxes: predicted boxes in cxcywh format, one row per prediction.
        targets: dict with "boxes" and "labels" of the ground truth.
        matched_target_indexes: target index assigned to each prediction, or
            None when the image has no targets.
        cls_loss_coef: weight of the classification (focal) loss.
        bbox_loss_coef: weight of the L1 box loss.
        giou_loss_coef: weight of the generalized-IoU loss.

    Returns:
        (loss, loss_ce, loss_bbox, loss_giou), each with one value per
        prediction. Without targets, every prediction is treated as background:
        only the classification term contributes and the box losses are zero.
    """
    target_classes_onehot = torch.zeros_like(out_logits)

    # `is None` rather than `== None`: the argument may be a tensor, for which
    # `==` is an element-wise comparison operator.
    if matched_target_indexes is None:
        loss_ce = sigmoid_focal_loss(out_logits, target_classes_onehot).sum(dim=1)
        loss = loss_ce * cls_loss_coef
        # Zero box losses on the same device / dtype as the other outputs.
        return loss, loss_ce, torch.zeros_like(loss), torch.zeros_like(loss)

    target_boxes = targets['boxes'][matched_target_indexes]
    loss_bbox = F.l1_loss(out_boxes, target_boxes, reduction='none') # [num_queries, 4]
    loss_bbox = loss_bbox.mean(dim=1) # [num_queries]
    # Only the diagonal is needed: prediction i is compared with its own match.
    loss_giou = 1 - torch.diag(generalized_box_iou(box_cxcywh_to_xyxy(out_boxes),
                box_cxcywh_to_xyxy(target_boxes))) # [num_queries]

    target_labels = targets['labels'][matched_target_indexes]
    target_classes_onehot.scatter_(1, target_labels.unsqueeze(-1), 1)
    loss_ce = sigmoid_focal_loss(out_logits, target_classes_onehot).sum(dim=1)

    # The coefficients come from the arguments (their defaults equal the values
    # that used to be hard-coded here).
    loss = loss_ce * cls_loss_coef + loss_bbox * bbox_loss_coef + loss_giou * giou_loss_coef
    return loss, loss_ce, loss_bbox, loss_giou

## Parameters

In [51]:
# Minimum sigmoid score for a (query, class) prediction to be kept.
score_threshold = 0.25
# Number of images to process.
image_nums = 5000

# Dataset split whose stored predictions are read and annotated.
split = "train"
# Directory holding the per-image prediction files.
data_path = f"./data/5_scale_31/{split}/data/"
# Directory receiving the per-image box annotation files.
store_path = f"./data/5_scale_31/{split}/box_annotation/"

In [52]:
# Load the model configuration and point it at the COCO dataset.
model_config_path = "config/DINO/DINO_5scale.py"
args = SLConfig.fromfile(model_config_path)
args.dataset_file = 'coco'
args.coco_path = "../coco/" # root directory of the COCO dataset
args.fix_size = False

# Build the dataset for the configured split (the variable is named
# dataset_val even though `split` may be "train").
# NOTE(review): for the train split the dataset may apply random training
# augmentations to the targets; confirm that the targets still line up with the
# stored predictions.
dataset_val = build_dataset(image_set=split, args=args)

data_aug_params: {
  "scales": [
    480,
    512,
    544,
    576,
    608,
    640,
    672,
    704,
    736,
    768,
    800
  ],
  "max_size": 1333,
  "scales2_resize": [
    400,
    500,
    600
  ],
  "scales2_crop": [
    384,
    600
  ]
}
loading annotations into memory...
Done (t=15.29s)
creating index...
index created!


In [None]:
# Match the stored predictions of every image against its ground truth and
# store the per-box matching and loss annotations as one JSON file per image.
# NOTE: this replaces the image_nums configured earlier with the dataset size.
image_nums = len(dataset_val)
# Create the output directory up front so the first write cannot fail on a
# missing folder.
os.makedirs(store_path, exist_ok=True)
for image_idx in range(image_nums):
    image_path = data_path + str(image_idx) + ".json"
    selected_index, score, labels, out_logits, out_boxes = generate_one_image_results(image_path)
    _, targets = dataset_val[image_idx]
    # matched_target_indexes is None for images without targets; it is then
    # written to the JSON file as null.
    matched_target_indexes = hungarian_matching(out_logits, out_boxes, targets)
    loss, loss_ce, loss_bbox, loss_giou = compute_loss(out_logits, out_boxes, targets, matched_target_indexes)
    json_object = {'matched_target_indexes': matched_target_indexes, 'out_labels': labels.squeeze(axis=1), 'loss': loss,
               'loss_ce': loss_ce, 'loss_bbox': loss_bbox, 'loss_giou': loss_giou}
    image_store_path = store_path + str(image_idx) + ".json"
    # Tensors are not JSON-serializable: convert them to plain lists first.
    json_object = transform_tensors_to_list(json_object)
    write_one_results(image_store_path, json_object)
    if image_idx % 1000 == 0:
        print(f"Complete {image_idx+1}/{image_nums}")

Complete 1/118287
Complete 1001/118287
Complete 2001/118287
Complete 3001/118287
Complete 4001/118287
Complete 5001/118287
Complete 6001/118287
Complete 7001/118287
Complete 8001/118287
Complete 9001/118287
Complete 10001/118287
Complete 11001/118287
Complete 12001/118287
Complete 13001/118287
Complete 14001/118287
Complete 15001/118287
Complete 16001/118287
Complete 17001/118287
Complete 18001/118287
Complete 19001/118287
Complete 20001/118287
Complete 21001/118287
Complete 22001/118287


In [54]:
# Display the number of images covered by the processing loop.
image_nums

118287