In [7]:
import os
import cv2
import numpy as np

import torch
from torch.utils.data import DataLoader

from data.yolo_dataset import YoloDataset, collate_fn

In [None]:
from torch import nn


class YoloV3Loss(nn.Module):
    """YOLOv3-style loss for a single detection scale.

    The last axis of ``pred`` and ``target`` is read as
    ``[x, y, w, h, objectness, class_0 .. class_{C-1}]`` (this is the
    indexing used below -- TODO confirm against the model head and the
    label encoder). ``pred`` holds raw logits/offsets; ``target`` holds
    decoded boxes, an objectness flag and per-class flags.
    """

    ## __init__() serves as the constructor of our loss function.
    def __init__(self, class_num=None, pos_weight=None):
        """
        Args:
            class_num: number of classes C. When None, falls back to the
                notebook-level ``dataset_option["DATASET"]["CLASSES_NUM"]``
                (backward-compatible; passing it explicitly removes the
                dependency on the config cell having run first).
            pos_weight: optional per-class positive weights of shape (C,).
                Defaults to all ones.
        """
        super().__init__()
        if class_num is None:
            class_num = dataset_option["DATASET"]["CLASSES_NUM"]
        self.class_num = int(class_num)

        self.mse = nn.MSELoss()

        ## pos_weight controls per-class recall/precision trade-offs:
        ## pos_weight > 1 => ++recall, pos_weight < 1 => ++precision.
        if pos_weight is None:
            pos_weight = torch.ones([self.class_num])
        self.multiBCE = nn.BCEWithLogitsLoss(
            pos_weight=torch.as_tensor(pos_weight, dtype=torch.float32))

        ## Objectness is one logit per box. A per-class pos_weight of shape (C,)
        ## would broadcast against an unrelated spatial axis, so none is used.
        self.objBCE = nn.BCEWithLogitsLoss()

    ## forward() performs the loss computation; gradients come from autograd.
    def forward(self, pred, target, cell, anchor):
        """Compute the summed coordinate, objectness and class loss.

        Args:
            pred: raw network output; last axis must be ``5 + class_num``.
            target: encoded targets with the same layout as ``pred``;
                ``target[..., 4] > 0`` is taken to mark the boxes responsible
                for a ground-truth object (assumed encoding -- TODO confirm).
            cell: grid-cell offsets added to ``sigmoid(pred[..., 0:2])``;
                must broadcast against ``pred[..., 0:2]``.
            anchor: anchor sizes multiplied with ``exp(pred[..., 2:4])``;
                must broadcast against ``pred[..., 2:4]``.

        Returns:
            Scalar tensor ``coord_loss + confi_loss + class_loss``. The
            coordinate and class terms are 0 when no box is responsible
            for an object.

        Raises:
            ValueError: if ``pred``'s last axis is not ``5 + class_num``.
        """
        expected = 5 + self.class_num
        if pred.shape[-1] != expected:
            raise ValueError(
                f"pred last dim must be 5 + class_num = {expected}, "
                f"got {pred.shape[-1]}")

        dtype = pred.dtype
        device = pred.device
        target = target.to(dtype=dtype, device=device)
        # Plain Python lists (as in the config cell) cannot be multiplied
        # with tensors directly, so convert them first.
        cell = torch.as_tensor(cell, dtype=dtype, device=device)
        anchor = torch.as_tensor(anchor, dtype=dtype, device=device)

        ## Only boxes responsible for a g.t. object contribute to the
        ## coordinate and class terms.
        obj_mask = target[..., 4] > 0
        has_obj = bool(obj_mask.any())
        zero = pred.new_zeros(())

        ## Bbox coordinate loss.
        pred_bbox = torch.cat([torch.sigmoid(pred[..., 0:2]) + cell,
                               torch.exp(pred[..., 2:4]) * anchor], dim=-1)
        if has_obj:
            coord_loss = self.mse(pred_bbox[obj_mask], target[..., 0:4][obj_mask])
        else:
            # MSE over an empty selection would be NaN.
            coord_loss = zero

        ## Bbox confidence-score (objectness) loss, over every box.
        confi_loss = self.objBCE(pred[..., 4], target[..., 4])

        ## Classification loss over the class channels [5, 5 + C). A single
        ## BCE over all C channels equals the mean of the per-class means and
        ## lets pos_weight (shape (C,)) broadcast along the class axis.
        class_slice = slice(5, expected)
        if has_obj:
            class_loss = self.multiBCE(pred[..., class_slice][obj_mask],
                                       target[..., class_slice][obj_mask])
        else:
            class_loss = zero

        return coord_loss + confi_loss + class_loss


## Greedy, class-aware non-maximum suppression, applied image by image.
def nms(pred, target, length, score_threshold=0.5, iou_threshold=0.5):
    """Greedy per-class NMS over a batch of decoded detections.

    Args:
        pred: indexable of per-image detection collections; ``pred[i]`` is an
            iterable of detections for image i. Each detection is assumed to be
            ``[class_id, score, x, y, w, h]`` with ``(x, y)`` the top-left
            corner (the layout the original greedy loop indexed) -- TODO
            confirm; raw YOLO feature maps must be decoded into this layout
            before calling.
        target: unused. NMS does not look at ground truth; the parameter is
            kept only so existing call sites keep working.
        length: number of images in ``pred`` to process.
        score_threshold: detections with ``score <= score_threshold`` are
            discarded before suppression.
        iou_threshold: a detection is suppressed when its IoU with an
            already-kept detection of the same class is ``>= iou_threshold``.

    Returns:
        list of ``length`` lists; entry i holds the kept detections of image i
        (the original objects), ordered by descending score.
    """

    def _xywh_iou(a, b):
        # Same convention as the notebook's `iou` helper: x2 = x + w.
        ax, ay, aw, ah = a
        bx, by, bw, bh = b
        inter_w = min(ax + aw, bx + bw) - max(ax, bx)
        inter_h = min(ay + ah, by + bh) - max(ay, by)
        inter = max(inter_w, 0.0) * max(inter_h, 0.0)
        union = abs(aw * ah) + abs(bw * bh) - inter
        return inter / (union + 1e-6)

    def _coords(box):
        # float() accepts both Python numbers and 0-d tensors.
        return [float(v) for v in box[2:6]]

    nms_pred = []
    for i in range(length):
        ## Drop low-score boxes, then visit the rest from highest score down.
        bboxes = [box for box in pred[i] if float(box[1]) > score_threshold]
        bboxes.sort(key=lambda box: float(box[1]), reverse=True)

        kept = []
        while bboxes:
            chosen_box = bboxes.pop(0)
            chosen_xywh = _coords(chosen_box)
            ## Survivors: other classes, or same class with little overlap.
            bboxes = [
                box
                for box in bboxes
                if float(box[0]) != float(chosen_box[0])
                or _xywh_iou(chosen_xywh, _coords(box)) < iou_threshold
            ]
            kept.append(chosen_box)
        nms_pred.append(kept)

    return nms_pred


def iou(pred, target, eps=1e-6):
    """Element-wise intersection-over-union of boxes in ``(x, y, w, h)`` form.

    ``(x, y)`` is treated as the top-left corner (``x2 = x + w``).
    NOTE(review): YOLO boxes are often centre-based; confirm that callers
    pass top-left coordinates.

    Args:
        pred: tensor whose last axis holds at least ``x, y, w, h``.
        target: same layout as ``pred``; leading axes must broadcast with it.
        eps: small constant added to the union to avoid division by zero.

    Returns:
        Tensor of IoU values with the broadcast leading shape.
    """
    b1_x1, b1_y1 = pred[..., 0], pred[..., 1]
    b1_x2, b1_y2 = b1_x1 + pred[..., 2], b1_y1 + pred[..., 3]
    b2_x1, b2_y1 = target[..., 0], target[..., 1]
    b2_x2, b2_y2 = b2_x1 + target[..., 2], b2_y1 + target[..., 3]

    # Intersection rectangle: top-left is the max of the two top-left
    # corners, bottom-right is the MIN of the two bottom-right corners.
    x1 = torch.max(b1_x1, b2_x1)
    y1 = torch.max(b1_y1, b2_y1)
    x2 = torch.min(b1_x2, b2_x2)
    y2 = torch.min(b1_y2, b2_y2)

    # Disjoint boxes give a negative extent; clamp to zero overlap.
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)

    b1_area = abs((b1_x2 - b1_x1) * (b1_y2 - b1_y1))
    b2_area = abs((b2_x2 - b2_x1) * (b2_y2 - b2_y1))

    return inter / (b1_area + b2_area - inter + eps)

In [None]:
def train(model, loss_func, batch_input, anchors, dataset_option, epoch,
          cells=None, optimizer=None):
    """Run one training step on a single batch.

    Args:
        model: module whose forward returns an indexable of three raw
            per-scale outputs.
        loss_func: callable ``(pred, target, cell, anchor=...)`` returning a
            scalar tensor; applied once per scale.
        batch_input: ``(images, labels, image_paths)`` tuple from the loader.
        anchors: per-scale anchors; ``anchors[k]`` is passed for scale k.
        dataset_option: currently unused; kept for interface compatibility.
        epoch: currently unused; kept for interface compatibility.
        cells: per-scale grid offsets; ``cells[k]`` is passed for scale k.
            Falls back to the notebook-level ``cell`` when None.
        optimizer: optional optimizer. When given, gradients are zeroed
            before the forward pass and ``optimizer.step()`` is called after
            ``backward()``. Without it gradients only accumulate and the
            weights never change.

    Returns:
        float: the loss summed over the three scales.
    """
    if cells is None:
        cells = cell
    batch_img, batch_label, _batch_img_path = batch_input

    model.train()
    if optimizer is not None:
        optimizer.zero_grad()

    ###############################
    # Raw per-scale outputs. NMS is deliberately NOT applied here: it is an
    # inference-time, non-differentiable post-processing step, and the loss
    # needs the raw predictions aligned with the targets.
    batch_output = model(batch_img)

    ###############################
    # NOTE(review): the same batch_label is used as the target for every
    # scale; it presumably needs per-scale encoding first -- TODO confirm.
    loss = sum(
        loss_func(batch_output[k], batch_label, cells[k], anchor=anchors[k])
        for k in range(3)
    )

    loss.backward()
    if optimizer is not None:
        optimizer.step()
    return loss.item()

In [None]:
## Dataset configuration.
## Class keys are the dataset's Korean labels and are kept verbatim:
##   선박 = vessel, 부표 = buoy, 어망부표 = fishing-net buoy,
##   해상풍력 = offshore wind turbine, 등대 = lighthouse,
##   기타부유물 = other floating object.
## Current setup is binary: vessel -> 0, everything else -> 1.
## Original 6-class mapping, kept for reference:
##   "선박": 0, "부표": 1, "어망부표": 2, "해상풍력": 3, "등대": 4, "기타부유물": 5
CLASS_MAP = {
    "선박": 0, "부표": 1, "어망부표": 1,
    "해상풍력": 1, "등대": 1, "기타부유물": 1,
}

dataset_option = {
    "DATASET": {
        "NAME": "yolo-dataset",
        "ROOT": "../datasets/yolo-dataset",
        "CLASSES": CLASS_MAP,
        # Derived from the mapping so the count cannot drift out of sync.
        "CLASSES_NUM": len(set(CLASS_MAP.values())),
    }
}

batch_size = 32
epochs = 1

# TODO: per-scale grid-cell offsets. These are empty placeholders; they must
# be filled with tensors that broadcast against pred[..., 0:2] of each scale
# before the loss can run.
cell = [[],
        [],
        []]

# Per-scale anchor sizes, presumably (w, h) in input-image pixels. These are
# the standard YOLOv3 anchors, listed smallest scale first.
# NOTE(review): confirm this order matches the order of the model's outputs.
anchors = [[(10, 13), (16, 30), (33, 23)],
           [(30, 61), (62, 45), (59, 119)],
           [(116, 90), (156, 198), (373, 326)]]

In [5]:
# NOTE(review): `Net` is not defined or imported anywhere in this notebook; it
# presumably comes from a deleted cell or a project module -- add its import to
# the imports cell so Restart & Run All works.
model = Net()
# Loss criterion; constructed with no arguments it reads the class count from
# the notebook-level `dataset_option`, so the config cell must run first.
loss_function = YoloV3Loss()

In [None]:
# Validation split and its loader (sequential order, project collate function).
# NOTE(review): the training loop below iterates this *validation* split;
# switch to a training split once one is available.
valid_dataset = YoloDataset(dataset_option, split="valid")
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
)

In [None]:
## Training loop: one `train` call per batch, for `epochs` passes.
## NOTE(review): no optimizer step is taken anywhere in this loop, so the
## model weights never change -- TODO wire an optimizer in.
for epoch in range(epochs):
    for batch_input in valid_loader:
        train(model, loss_function, batch_input, anchors, dataset_option, epoch)
