In [1]:
import torch
import torch.nn.functional as F
import numpy as np
from scipy.optimize import linear_sum_assignment
from torchvision.ops import box_convert, generalized_box_iou_loss
import math

In [2]:
# --- hyperparams
l1_lambda = 5 # weight of the L1 box term; l1 and iou weights from the original DETR paper
IoU_lambda = 2 # weight of the generalized-IoU box term

def Lmatch(gtbox, gtlabel, pbox, plogits):
    """Matching cost between one ground-truth object and one prediction.

    Used to fill the cost matrix for Hungarian matching; lower means a better match.

    Args:
        gtbox: ground-truth box, used as-is as (x1, y1, x2, y2) -- assumed shape (4,).
        gtlabel: integer class index of the ground-truth object.
        pbox: predicted box in (cx, cy, w, h); converted to (x1, y1, x2, y2) here.
            Assumed shape (4,).
        plogits: unnormalized class scores of the prediction -- assumed 1-D.

    Returns:
        0-d tensor: -p(gtlabel) + l1_lambda * L1 + IoU_lambda * (1 - GIoU).
    """
    # probability mass on the GT class; the cost uses -p rather than -log p
    prob = plogits.softmax(dim=-1)[gtlabel]

    pbox = box_convert(boxes=pbox, in_fmt='cxcywh', out_fmt='xyxy')
    l_iou = generalized_box_iou_loss(gtbox, pbox)
    # Sum (not mean) over the 4 coordinates so the matching cost weighs L1 on the
    # same scale as the training box loss, which sums over coordinates per box.
    l_l1 = F.l1_loss(gtbox, pbox, reduction='sum')
    l_box = l1_lambda * l_l1 + IoU_lambda * l_iou

    return -prob + l_box

In [17]:
# Toy DETR-style set-prediction loss on random tensors.
# Convention (mirrors Lmatch): predicted boxes are (cx, cy, w, h),
# ground-truth boxes are used as (x1, y1, x2, y2).
torch.manual_seed(0)  # reproducible random "predictions" and displayed loss

N = 100                   # number of predicted queries
NUM_CLASSES = 20          # real object classes
NO_OBJECT = NUM_CLASSES   # extra last logit: class assigned to unmatched predictions
EOS_COEF = 0.1            # class weight that down-weights the no-object class

pboxes = torch.randn((N, 4))
plogits = torch.randn((N, NUM_CLASSES + 1))

ngt = 4
gtboxes = torch.randn((ngt, 4))
gtlabels = torch.arange(ngt)

# step 1: optimal one-to-one matching; the cost matrix needs no gradients
Lmatrix = np.zeros((ngt, N))
with torch.no_grad():
    for i in range(ngt):
        for j in range(N):
            Lmatrix[i, j] = Lmatch(gtboxes[i], gtlabels[i], pboxes[j], plogits[j]).item()

row, cols = linear_sum_assignment(Lmatrix)

# step 2: classification loss; every unmatched query is trained towards NO_OBJECT
targets = torch.full((N,), NO_OBJECT, dtype=gtlabels.dtype)
targets[cols] = gtlabels[row]
weights = torch.ones(NUM_CLASSES + 1)
weights[NO_OBJECT] = EOS_COEF
nll_loss = F.cross_entropy(plogits, targets, weight=weights)

# step 3: box loss on matched pairs only, summed over coordinates and
# averaged over ground-truth boxes (guarded against an empty target set)
perm_boxes = box_convert(boxes=pboxes[cols], in_fmt='cxcywh', out_fmt='xyxy')
matched_gtboxes = gtboxes[row]  # index GT by `row` too, so pairs always line up
l_iou = generalized_box_iou_loss(matched_gtboxes, perm_boxes, reduction='sum')
l_l1 = F.l1_loss(matched_gtboxes, perm_boxes, reduction='sum')
l_box = (l1_lambda * l_l1 + IoU_lambda * l_iou) / max(ngt, 1)

total_loss = nll_loss + l_box
total_loss


tensor(12.3258)

In [13]:
# Sanity check on the loss scale: a uniform prediction over the 21 logits
# has cross-entropy -log(1/21) (i.e. log 21), for comparison with nll_loss.
-1.0 * math.log(1 / 21)

3.044522437723423