In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from yolov2 import YOLOv2D19 as YOLOv2
from detection_datasets import VOCDatasetV2
import pickle
with open('anchors_VOC0712trainval.pickle', 'rb') as handle:
    anchors = pickle.load(handle)

In [2]:
model = YOLOv2(device=torch.device('cpu'), dtype=torch.float32)

  state_dict = torch.load(state_dict_path, map_location=self.device)


In [3]:
transforms = A.Compose([
    A.Resize(width=416, height=416),
    A.VerticalFlip(p=1.0),
    ToTensorV2()
], bbox_params=A.BboxParams(format='pascal_voc'))
train_set = VOCDatasetV2(devkit_path = '../../datasets/VOCdevkit/', scales=[13], anchors=anchors, transforms=transforms, 
                         dtype=torch.float32, device=torch.device('cpu'))

True ../../datasets/VOCdevkit/VOC2007\ImageSets\Main\trainval.txt
True ../../datasets/VOCdevkit/VOC2012\ImageSets\Main\trainval.txt


In [4]:
class YOLOv2Loss(nn.Module):
    def __init__(self, lambda_noobj=0.5, lambda_coord=5.0, num_classes=20):
        super().__init__()
        self.mse = torch.nn.MSELoss(reduction='sum')
        self.softmax = torch.nn.Softmax(dim=2)
        self.lambda_noobj = lambda_noobj
        self.lambda_coord = lambda_coord
        self.num_classes = num_classes
        
    def forward(self, out, gt_out, anchors):
        # [conf, obj_xc, obj_yc, obj_w, obj_h]
        is_obj = gt_out[:, 0::25, ...] == 1.0
        no_obj = gt_out[:, 0::25, ...] == 0.0

        # CONFIDENCE LOSS ===========
        conf_true = gt_out[:, 0::25, ...]
        conf_pred = out[:, 0::25, ...].sigmoid()

        is_obj_conf_pred = is_obj * conf_pred
        is_obj_conf_true = is_obj * conf_true
        
        no_obj_conf_pred = no_obj * conf_pred
        no_obj_conf_true = no_obj * conf_true

        is_obj_conf_loss = self.mse(is_obj_conf_pred, is_obj_conf_true)
        no_obj_conf_loss = self.mse(no_obj_conf_pred, no_obj_conf_true) 
        # ===========================

        # BOX LOSS ==================
        xc_true = gt_out[:, 1::25, ...]
        yc_true = gt_out[:, 2::25, ...]
        w_true = gt_out[:, 3::25, ...]
        h_true = gt_out[:, 4::25, ...]
        
        xc_pred = out[:, 1::25, ...].sigmoid()
        yc_pred = out[:, 2::25, ...].sigmoid()
        
        scale = gt_out.shape[-1]
        _anchors = torch.tensor(anchors) * scale
        pw = _anchors[:, 0]
        ph = _anchors[:, 1]
        
        w_pred = pw[None, :, None, None] * out[:, 3::25, ...].exp()
        h_pred = ph[None, :, None, None] * out[:, 4::25, ...].exp()

        xc_pred = is_obj * xc_pred
        xc_true = is_obj * xc_true
        yc_pred = is_obj * yc_pred
        yc_true = is_obj * yc_true
        
        w_pred = is_obj * w_pred
        w_true = is_obj * w_true
        h_pred = is_obj * h_pred
        h_true = is_obj * h_true

        xc_loss = self.mse(xc_pred, xc_true)
        yc_loss = self.mse(yc_pred, yc_true)
        w_loss = self.mse(w_pred.sqrt(), w_true.sqrt())
        h_loss = self.mse(h_pred.sqrt(), h_true.sqrt())
        # ===========================

        # CLASS LOSS ================
        class_true = []
        for i in range(len(anchors)):
            first_idx = 5 + i*(5+self.num_classes)
            last_idx = 25 + i*(5+self.num_classes)
            class_true.append(gt_out[:, first_idx:last_idx, ...])
        class_true = torch.stack(class_true, dim=1)

        class_pred = []
        for i in range(len(anchors)):
            first_idx = 5 + i*(5+self.num_classes)
            last_idx = 25 + i*(5+self.num_classes)
            class_pred.append(gt_out[:, first_idx:last_idx, ...])
        class_pred = torch.stack(class_pred, dim=1)

        class_pred = self.softmax(class_pred)
        
        class_pred = is_obj[:, :, None, :, :] * class_pred
        class_true = is_obj[:, :, None, :, :] * class_true

        class_loss = self.mse(class_pred, class_true)
        # ===========================

        loss =  self.lambda_coord * (xc_loss + yc_loss) + \
                self.lambda_coord * (w_loss + h_loss) + \
                is_obj_conf_loss + \
                self.lambda_noobj * no_obj_conf_loss + \
                class_loss

        return loss
        


In [5]:
loss = YOLOv2Loss()

In [6]:
image, gt_out = train_set[2]
gt_out = gt_out.unsqueeze(0)
image = image.unsqueeze(0)
out = model(image)

In [7]:
loss(out, gt_out, anchors)

tensor(116.3473, grad_fn=<AddBackward0>)