In [72]:
import torch
import torch.nn as nn

def iou(pred, target):
    
    b1_x1 = pred[...,0:1]
    b1_y1 = pred[...,1:2]
    b1_x2 = pred[...,2:3]
    b1_y2 = pred[...,3:4]
    
    b2_x1 = pred[...,0:1]
    b2_y1 = pred[...,1:2]
    b2_x2 = pred[...,2:3]
    b2_y2 = pred[...,3:4]
    
    x1 = torch.max(b1_x1, b2_x2)
    y1 = torch.max(b1_y1, b2_y1)
    x2 = torch.min(b1_x2, b2_x2)
    y2 = torch.min(b1_y2, b2_y2)
    
    inter = (x2 - x2).clamp(0) * (y2 - y1).clamp(0)
    
    b1 = torch.abs((b1_x2 - b1_x1) * (b1_y2 - b1_y1))
    b2 = torch.abs((b2_x2 - b2_x1) * (b2_y2 - b2_y1))
    
    iou = inter/(b1 + b2 - inter + 1e-6)
    return iou



class Loss(nn.Module):
    
    def __init__(self):
        super(Loss, self).__init__()
        self.mse = nn.MSELoss() 
        self.bce = nn.BCEWithLogitsLoss()
        self.ce =  nn.CrossEntropyLoss()
        self.sigmoid = nn.Sigmoid()
        
        self.l_class = 1
        self.l_noobj = 10
        self.l_obj = 1
        self.l_box = 10
        
    def forward(self, pred, target, anchors):
        ## pred shape = s,s,5+number of class, target s,s,6
        obj_present = target[..., 0] > 0.5
        obj_absent = target[..., 0] <= 0.5

        print(pred.shape, target.shape, obj_present.shape, obj_absent.shape)
        ## no object loss
        no_object_loss = self.bce(pred[..., 0:1][obj_absent], target[..., 0:1][obj_absent])
        
        ## class loss
        class_loss = self.ce(pred[...,5:][obj_present], target[...,5][obj_present].long())
        
        ## object_loss
        anchors = anchors.reshape(1,3,1,1,2)
        box_pred = torch.cat([self.sigmoid(pred[...,1:3]), (torch.exp(pred[...,3:5]) *anchors)], dim=-1)
        ious = iou(box_pred[obj_present], target[...,1:5])
        object_loss = self.mse(self.sigmoid(pred[...,0:1][obj_present]), ious*target[...,0:1][obj_present])
            
            
        ## coord_loss
        pred[...,1:3] = self.sigmoid(pred[...,1:3])
        target[...,3:5] = torch.log(1e-10  + target[...,3:5]/anchors)
        coord_loss = self.mse(pred[...,1:5][obj_present], target[...,1:5][obj_present])
        
        
        loss = self.l_noobj * no_object_loss + self.l_class * class_loss + self.l_obj * object_loss + self.l_box * coord_loss
        
        return loss
        

## implement yolov3 loss

In [73]:
loss = Loss()

In [74]:
pred = torch.rand((4,3,5,5,10))
target = torch.rand((4,3,5,5,6))
anchor = torch.rand((3,2))

In [75]:
l = loss(pred, target, anchor)

torch.Size([4, 3, 5, 5, 10]) torch.Size([4, 3, 5, 5, 6]) torch.Size([4, 3, 5, 5]) torch.Size([4, 3, 5, 5])


In [76]:
l.item()

21.098360061645508