In [1]:
import sys
sys.path.append('../../src')

#######################################
import tortto as tt
import tortto.nn as nn

# import torch as tt
# import torch.nn as nn
###################################


import warnings
warnings.filterwarnings("ignore", category=UserWarning)

def intersection_over_union(boxes, boxes_tgt, S):
    boxes_tgt = boxes_tgt[:, None, :]  # (N, 1, 5)
    area = boxes[..., 3] * boxes[..., 4]  # (N,B)
    area_tgt = boxes_tgt[..., 3] * boxes_tgt[..., 4]  # (N,1)

    norm_xy = boxes[..., 1:3]/S  # /S if relative to box
    norm_xy_tgt = boxes_tgt[..., 1:3]/S  # /S if relative to box
    lt = tt.max(norm_xy-boxes[..., 3:]/2, norm_xy_tgt -
                boxes_tgt[..., 3:]/2)  # left-top corner (N, B, 2)
    # right-bottom corner (N, B, 2)
    rb = tt.min(norm_xy+boxes[..., 3:]/2, norm_xy_tgt+boxes_tgt[..., 3:]/2)

    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0]*wh[..., 1]  # (N, B)

    union = area+area_tgt-inter

    return inter/union


class YoloLoss(nn.Module):
    def __init__(self, S, B, C):
        super(YoloLoss, self).__init__()
        self.mse = nn.MSELoss(reduction='sum')
        self.S = S
        self.B = B
        self.C = C
        self.lambda_coord = 5
        self.lambda_noobj = 0.5

    def forward(self, predictions, target):
        # extract boxes and labels from predictions
        # (N, S*S*(C+5B))->(N*S*S, C+5B)
        predictions = predictions.view(-1, self.C + 5 * self.B)
        labels = predictions[..., :20]  # (N*S*S, C)
        boxes = predictions[..., 20:].view(-1, self.B, 5)  # (N*S*S, B, 5)

        # boxes_tgt: (N, S*S, 5). labels_tgt: (N, S*S, C). Iobj: (N, S*S)
        boxes_tgt, labels_tgt, Iobj = target
        N = boxes_tgt.shape[0]

        # flatten out batch dimension. select boxes that exists
        Iobj = Iobj.flatten()  # (N*S*S)
        boxes_tgt = boxes_tgt.view(-1, 5)  # (N*S*S, 5)
        labels_tgt = labels_tgt.view(-1, self.C)  # (N*S*S, C)

        # select cells that have bounding boxes. Nobj: number of boxes that exists
        boxes_tgt_obj = boxes_tgt[Iobj]  # (Nobj, 5), could be empty
        Nobj = boxes_tgt_obj.shape[0]
        loss = 0
        if Nobj != 0:  # if boxes exists:
            # find responsible boxes
            boxes_obj = boxes[Iobj]  # (Nobj, B, 5)
            with tt.no_grad():
                ious = intersection_over_union(
                    boxes_obj, boxes_tgt_obj, self.S)  # (Nobj, B)
            best_box = tt.argmax(ious, dim=-1)  # (Nobj,)
            responsible = boxes_obj[tt.arange(Nobj), best_box, :]  # (Nobj, 5)
            # calc loss
            wh = responsible[..., 3:]  # (Nobj, 2). select width and height
            loss += self.lambda_coord*(
                self.mse(responsible[..., 1:3], boxes_tgt_obj[..., 1:3])  # x,y
                + self.mse(tt.sign(wh) * tt.sqrt(tt.abs(wh)+1e-8),
                           tt.sqrt(boxes_tgt_obj[..., 3:]))  # w,h
            )  # coord loss
            # box confidence loss
            loss += self.mse(responsible[..., 0], boxes_tgt_obj[..., 0])
            # class confidence loss
            loss += self.mse(labels[Iobj], labels_tgt[Iobj])

        # select cells that have no bounding boxes.
        boxes_tgt_noobj = boxes_tgt[~Iobj]  # (Nnoobj, 5), could be empty
        if boxes_tgt_noobj.shape[0] != 0:
            boxes_noobj = boxes[~Iobj]  # (Nnoobj, B, 5)
            loss_noobj = self.lambda_noobj * \
                self.mse(boxes_noobj[..., 0], boxes_tgt_noobj[..., 0, None])
            loss += self.lambda_noobj*loss_noobj

        return loss/N

In [2]:
import numpy as np

loss_fn = YoloLoss(S=7, B=2, C=20)

boxes = tt.tensor(np.load('data/boxes.npy'))
labels = tt.tensor(np.load('data/labels.npy'))
Iobj = tt.tensor(np.load('data/Iobj.npy'))
out = tt.tensor(np.load('data/out.npy'), requires_grad=True)

loss = loss_fn(out, (boxes, labels, Iobj))
loss

tensor( 137.1291, grad_fn=<DivBackward>)