In [1]:
%load_ext autoreload
%autoreload 2

In [368]:
from src.yolo.v1.dataset import YOLOv1Dataset
from src.yolo.v1.model import YOLOv1Model
from src.yolo.v1.loss import YOLOv1Loss
from src.yolo.utils import calculate_boxes_iou
from src.utils.utils import DATA_PATH
import albumentations as A
import cv2
import matplotlib.pyplot as plt
import torch
from torch import nn

In [155]:
transform = A.Compose(
    [
        A.LongestMaxSize(max_size=448, interpolation=1),
        A.PadIfNeeded(min_height=448, min_width=448, border_mode=0, value=(0, 0, 0)),
        A.Rotate(limit=10, border_mode=cv2.BORDER_REPLICATE, p=1),
    ],
    bbox_params=A.BboxParams(format="yolo", label_fields=["labels"], min_visibility=0.7),
)

In [156]:
S = 7
C = 10
B = 2

ds = YOLOv1Dataset(S, C, B, DATA_PATH / "yolo_HWD+", transform=transform)
model = YOLOv1Model(S, C, B)

In [182]:
img, annots = ds[0]
img = torch.Tensor(img).permute(2, 0, 1)
out = model(img.unsqueeze(0))

In [183]:
targets = annots.unsqueeze(0)
preds = out

In [321]:
mse = torch.nn.MSELoss(reduction="sum")

In [343]:
boxes_iou = []
for box_idx in range(B):
    box_start_idx = C + 1 + box_idx * 5
    box_end_idx = box_start_idx + 5
    box_iou = calculate_boxes_iou(
        preds[..., box_start_idx:box_end_idx], targets[..., C + 1 : C + 5]
    ).squeeze(-1)
    boxes_iou.append(box_iou.unsqueeze(0))

ious = torch.cat(boxes_iou, dim=0)

iou_maxes, bestbox_idxs = torch.max(ious, dim=0)
obj_mask = targets[..., C] == 1  # Iobj_i

bestbox_idxs = bestbox_idxs[obj_mask]  # only idxs with objects

# Coord loss
obj_preds = preds[obj_mask]  # only preds with objects
obj_targets = targets[obj_mask]  # only targets with objects

noobj_preds = preds[~obj_mask]  # only preds without objects
noobj_targets = targets[~obj_mask]  # only targets without objects

idxs = torch.tensor([range(box_idx * 5 + C + 1, box_idx * 5 + C + 5) for box_idx in bestbox_idxs])
obj_box_preds = obj_preds.gather(1, idxs)
obj_box_targets = obj_targets[:, C + 1 : C + 5]

obj_box_preds[:, 2:] = torch.sign(obj_box_preds[:, 2:]) * torch.sqrt(
    abs(obj_box_preds[:, 2:]) + 1e-10
)
obj_box_targets[:, 2:] = torch.sqrt(obj_box_targets[:, 2:])

coord_loss = mse(obj_box_preds.flatten(), obj_box_targets.flatten())

In [390]:
obj_object_preds = obj_preds[:, C]
obj_object_targets = obj_targets[:, C]
object_loss = mse(obj_object_preds, obj_object_targets)

In [393]:
obj_box_preds

tensor([[-0.2286, -0.0132,  0.2917,  0.4485],
        [ 0.0432, -0.3115,  0.3230, -0.1297],
        [-0.0113, -0.2784,  0.6748,  0.4544],
        [-0.1606,  0.1319,  0.5168,  0.5707],
        [-0.0621, -0.0523,  0.4084, -0.0588]], grad_fn=<CopySlices>)

In [391]:
obj_preds.shape

torch.Size([5, 20])

In [400]:
noobj_object_preds = noobj_preds[:, C]
noobj_object_targets = noobj_targets[:, C]
noobject_loss = mse(noobj_object_preds, noobj_object_targets)

In [354]:
obj_class_preds = obj_preds[:, :C]
obj_class_targets = obj_targets[:, :C]
class_loss = mse(torch.flatten(obj_class_preds), torch.flatten(obj_class_targets))

In [355]:
loss = 5 * coord_loss + object_loss + 0.5 * noobject_loss + class_loss

In [407]:
loss_fn = YOLOv1Loss(C, B, 5, 0.5)

In [413]:
loss_fn(preds, targets)

tensor(14.0815, grad_fn=<MseLossBackward0>) tensor(4.6638, grad_fn=<MseLossBackward0>) tensor(3.5814, grad_fn=<MseLossBackward0>) tensor(6.9436, grad_fn=<MseLossBackward0>)


tensor(83.8056, grad_fn=<AddBackward0>)