In [2]:
import torch
import torch.nn as nn
from utils import intersection_over_union

In [3]:
class YoloLoss(nn.Module):
    """Sum-of-squared-errors detection loss over an ``S x S`` grid (YOLO-v1 style).

    Per-cell layout read by this loss (last tensor dimension):

    * ``[0, C)``: class scores;
    * then ``B`` consecutive groups of five values: one confidence followed
      by four box values.

    Targets use the same layout, but only the first group (confidence at
    index ``C``, box at ``C + 1 .. C + 4``) is read.

    The last two box values are square-rooted before comparison, i.e. they are
    treated as size terms (presumably width/height -- confirm against the
    dataset encoding). The exact box format is whatever ``iou_fn`` expects.

    Args:
        S: grid size along each spatial dimension.
        B: number of predicted boxes per cell (must be >= 1).
        C: number of classes.
        iou_fn: optional callable ``(pred_boxes, target_boxes) -> iou`` used to
            pick the responsible box in each cell. It must return one value per
            cell with a trailing dimension of size 1. Defaults to
            ``intersection_over_union`` imported by this module.

    Raises:
        ValueError: if ``B`` is smaller than 1.
    """

    def __init__(self, S=7, B=2, C=20, *, iou_fn=None) -> None:
        super().__init__()
        if B < 1:
            raise ValueError("B must be at least 1")
        self.B = B
        self.C = C
        self.S = S
        self.mse = nn.MSELoss(reduction="sum")
        # Weight of the box-coordinate term (attribute name kept for
        # compatibility with existing code that reads it).
        self.lambda_obj = 5
        # Weight of the confidence term for cells that contain no object.
        self.lambda_nobj = 0.5
        # Only fall back to the module-level helper when no override is given.
        self.iou_fn = intersection_over_union if iou_fn is None else iou_fn

    def forward(self, predtictions, targets):
        """Return the total loss as a scalar tensor.

        Args:
            predtictions: network output with ``S * S * (C + 5 * B)`` values
                per sample in any shape that can be reshaped to
                ``(-1, S, S, C + 5 * B)``. (The misspelled name is kept so
                existing keyword calls keep working.)
            targets: tensor whose leading dimensions match the reshaped
                predictions and whose last dimension holds at least
                ``C + 5`` values.

        Returns:
            ``lambda_obj * box_loss + obj_loss + lambda_nobj * no_obj_loss
            + class_loss`` summed over every cell of the batch.
        """
        num_classes = self.C
        predictions = predtictions.reshape(
            -1, self.S, self.S, num_classes + 5 * self.B
        )

        target_conf = targets[..., num_classes:num_classes + 1]
        target_box = targets[..., num_classes + 1:num_classes + 5]
        # 1 where the ground truth places an object in the cell, else 0.
        box_exists = target_conf
        no_object = 1 - box_exists

        # Split every predicted (confidence, box) group out of the cell vector.
        pred_confs = []
        pred_boxes = []
        for box_idx in range(self.B):
            start = num_classes + 5 * box_idx
            pred_confs.append(predictions[..., start:start + 1])
            pred_boxes.append(predictions[..., start + 1:start + 5])

        # The box with the highest IoU against the target is "responsible".
        ious = torch.stack(
            [self.iou_fn(box, target_box) for box in pred_boxes], dim=0
        )
        _, best_box = torch.max(ious, dim=0)

        responsible_box = pred_boxes[0]
        responsible_conf = pred_confs[0]
        for box_idx in range(1, self.B):
            is_best = best_box == box_idx
            responsible_box = torch.where(
                is_best, pred_boxes[box_idx], responsible_box
            )
            responsible_conf = torch.where(
                is_best, pred_confs[box_idx], responsible_conf
            )

        # Coordinate cost: only cells that contain an object contribute.
        box_pred = box_exists * responsible_box
        box_tgt = box_exists * target_box
        # Square-root the size terms. Predictions may be negative, so take the
        # root of the magnitude (plus epsilon for a finite gradient at zero)
        # and restore the sign. Built out-of-place to stay autograd-friendly.
        pred_size = torch.sign(box_pred[..., 2:4]) * torch.sqrt(
            torch.abs(box_pred[..., 2:4] + 1e-6)
        )
        tgt_size = torch.sqrt(box_tgt[..., 2:4])
        box_loss = self.mse(
            torch.cat([box_pred[..., :2], pred_size], dim=-1),
            torch.cat([box_tgt[..., :2], tgt_size], dim=-1),
        )

        # Object existence cost: confidence of the responsible box.
        obj_loss = self.mse(
            box_exists * responsible_conf,
            box_exists * target_conf,
        )

        # No-object cost: every predicted confidence in empty cells.
        no_obj_loss = sum(
            self.mse(no_object * conf, no_object * target_conf)
            for conf in pred_confs
        )

        # Class cost: only cells that contain an object contribute.
        class_loss = self.mse(
            box_exists * predictions[..., :num_classes],
            box_exists * targets[..., :num_classes],
        )

        return (
            self.lambda_obj * box_loss
            + obj_loss
            + self.lambda_nobj * no_obj_loss
            + class_loss
        )
