In [1]:
import sys
sys.path.insert(1, "../")
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from yolov2 import YOLOv2D19 as YOLOv2
from detection_datasets import VOCDatasetV2
from torch import optim
from loss import YOLOv2Loss
from train import *
import torch.optim.lr_scheduler as lr_scheduler
from data_preprocessing import get_norms
import pickle
# Read the anchor definitions from a local pickle file.
# NOTE(review): pickle.load can execute arbitrary code, so this file must come
# from a trusted source.
with open('anchors_VOC0712trainval.pickle', 'rb') as anchors_file:
    anchors = pickle.load(anchors_file)

In [2]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
# Load the normalisation statistics stored next to the dataset and pull out the
# 'means' and 'stds' entries (presumably per-channel values - TODO confirm).
norms = get_norms('../../datasets/VOCdevkit/trainval_norms.json')
means, stds = norms['means'], norms['stds']

In [70]:
# Pre-processing pipeline shared by the training and validation datasets.
# Bounding boxes are passed to albumentations in Pascal VOC format.
# NOTE(review): VerticalFlip(p=1.0) flips *every* image (validation included)
# and Normalize is commented out, so the tensors presumably keep raw pixel
# values - confirm both are intentional before a real training run.
transforms = A.Compose([
    A.Resize(width=416, height=416),
    A.VerticalFlip(p=1.0),
    # A.Normalize(mean=means, std=stds),
    ToTensorV2()
], bbox_params=A.BboxParams(format='pascal_voc'))

# Both datasets use the `device` selected earlier in the notebook instead of a
# hard-coded 'cuda:0', so the notebook also runs on a CPU-only machine.
train_set = VOCDatasetV2(devkit_path='../../datasets/VOCdevkit/',
                         scales=[13], anchors=anchors, transforms=transforms,
                         dtype=torch.float32, device=device)
# Validation uses the VOC2007 test split.
val_set = VOCDatasetV2(devkit_path='../../datasets/VOCdevkit/',
                       subsets=[('VOC2007', 'test')],
                       scales=[13], anchors=anchors, transforms=transforms,
                       dtype=torch.float32, device=device)

True ../../datasets/VOCdevkit/VOC2007\ImageSets\Main\trainval.txt
True ../../datasets/VOCdevkit/VOC2012\ImageSets\Main\trainval.txt
True ../../datasets/VOCdevkit/VOC2007\ImageSets\Main\test.txt


In [5]:
# Mini-batch size shared by both loaders.
BATCH_SIZE = 64

# Training batches are reshuffled every epoch so SGD does not see the samples
# in the same fixed order each time; validation order stays deterministic.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

In [6]:
# Build the loss from the anchors loaded above. Reading top to bottom, this is
# the YOLOv2Loss imported from the `loss` module; a class with the same name is
# defined further down in this notebook and `loss_fn` is re-created from it later.
loss_fn = YOLOv2Loss(anchors=anchors)

In [7]:
# Instantiate the detector (YOLOv2D19 imported as YOLOv2) on the selected device.
# NOTE(review): the recorded output suggests the constructor calls torch.load on
# a state dict - presumably pretrained weights; confirm.
model = YOLOv2(device=device)

  state_dict = torch.load(state_dict_path, map_location=self.device)


In [8]:
# Number of passes over the training set for this run.
epochs = 1

# SGD with momentum and weight decay over all model parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0001,
)

In [9]:
# Exponential learning-rate schedule attached to the optimizer.
# NOTE(review): with gamma=1.0 the multiplicative decay factor is 1, so the
# learning rate presumably stays constant - confirm this is intended.
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=1.0)

In [10]:
# Run training via `train` (presumably provided by `from train import *`).
# The call returns a history object and gradient statistics; outputs are
# directed to '../log/YOLOv2/training/'. The recorded run below was
# interrupted manually (KeyboardInterrupt) during the second batch.
history, gradient_stats = train(epochs, train_loader, val_loader, model, optimizer, loss_fn, scheduler, outputs_path='../log/YOLOv2/training/')

2024-12-22 10:29:14.818788 Epoch 1 
2024-12-22 10:29:15.464221 After batch load / before to(cuda)
2024-12-22 10:29:15.464221 After to(cuda)
2024-12-22 10:29:15.464221 Batch 1 
2024-12-22 10:29:15.464221 Before inference on 64 batch
2024-12-22 10:29:31.394065 After inference on 64 batch / before loss fn
2024-12-22 10:29:39.074182 After loss fn / before zero grad
2024-12-22 10:29:39.074182 After zero grad / before backward
2024-12-22 10:30:28.028630 After backward / before optimizer.step
2024-12-22 10:30:28.055681 After optimizer.step / before += loss.item()
2024-12-22 10:30:28.196609 After += loss.item() / before the next batch
2024-12-22 10:30:28.895733 After batch load / before to(cuda)
2024-12-22 10:30:28.895733 After to(cuda)
2024-12-22 10:30:28.895733 Batch 2 
2024-12-22 10:30:28.895733 Before inference on 64 batch
2024-12-22 10:30:29.183639 After inference on 64 batch / before loss fn
2024-12-22 10:30:49.256513 After loss fn / before zero grad
2024-12-22 10:30:49.257558 After zero

KeyboardInterrupt: 

In [38]:
class testLoss(nn.Module):
    """Minimal loss module: mean-squared error between prediction and target."""

    def __init__(self):
        super().__init__()
        # 'mean' is the default reduction of nn.MSELoss; spelled out for clarity.
        self.mse = nn.MSELoss(reduction='mean')

    def forward(self, out, gt_out):
        """Return the mean-squared error between `out` and `gt_out`."""
        loss = self.mse(out, gt_out)
        return loss

In [52]:
class YOLOv2Loss(nn.Module):
    def __init__(self, anchors, lambda_noobj=0.5, lambda_coord=5.0, num_classes=20):
        super().__init__()
        self.mse = torch.nn.MSELoss(reduction='sum')
        self.softmax = torch.nn.Softmax(dim=2)
        self.lambda_noobj = lambda_noobj
        self.lambda_coord = lambda_coord
        self.num_classes = num_classes
        self.anchors = anchors
        
    def forward(self, out, gt_out):
        
        # [conf, obj_xc, obj_yc, obj_w, obj_h]
        is_obj = gt_out[:, 0::25, ...] == 1.0
        no_obj = gt_out[:, 0::25, ...] == 0.0

        # CONFIDENCE LOSS ===========
        conf_true = gt_out[:, 0::25, ...]
        conf_pred = out[:, 0::25, ...].sigmoid()

        is_obj_conf_pred = is_obj * conf_pred
        is_obj_conf_true = is_obj * conf_true
        
        no_obj_conf_pred = no_obj * conf_pred
        no_obj_conf_true = no_obj * conf_true

        is_obj_conf_loss = self.mse(is_obj_conf_pred, is_obj_conf_true)
        no_obj_conf_loss = self.mse(no_obj_conf_pred, no_obj_conf_true) 
        # ===========================

        # BOX LOSS ==================
        xc_true = gt_out[:, 1::25, ...]
        yc_true = gt_out[:, 2::25, ...]
        w_true = gt_out[:, 3::25, ...]
        h_true = gt_out[:, 4::25, ...]
        
        xc_pred = out[:, 1::25, ...].sigmoid()
        yc_pred = out[:, 2::25, ...].sigmoid()
        
        scale = gt_out.shape[-1]
        _anchors = torch.tensor(self.anchors).to(out.device) * scale
        pw = _anchors[:, 0]
        ph = _anchors[:, 1]
        
        w_pred = pw[None, :, None, None] * out[:, 3::25, ...].exp()
        h_pred = ph[None, :, None, None] * out[:, 4::25, ...].exp()

        xc_pred = is_obj * xc_pred
        xc_true = is_obj * xc_true
        yc_pred = is_obj * yc_pred
        yc_true = is_obj * yc_true
        
        w_pred = is_obj * w_pred
        w_true = is_obj * w_true
        h_pred = is_obj * h_pred
        h_true = is_obj * h_true

        xc_loss = self.mse(xc_pred, xc_true)
        yc_loss = self.mse(yc_pred, yc_true)
        w_loss = self.mse(w_pred.sqrt(), w_true.sqrt())
        h_loss = self.mse(h_pred.sqrt(), h_true.sqrt())
        # ===========================

        # CLASS LOSS ================
        class_true = []
        for i in range(len(self.anchors)):
            first_idx = 5 + i*(5+self.num_classes)
            last_idx = 25 + i*(5+self.num_classes)
            class_true.append(gt_out[:, first_idx:last_idx, ...])
        class_true = torch.stack(class_true, dim=1)

        class_pred = []
        for i in range(len(self.anchors)):
            first_idx = 5 + i*(5+self.num_classes)
            last_idx = 25 + i*(5+self.num_classes)
            class_pred.append(gt_out[:, first_idx:last_idx, ...])
        class_pred = torch.stack(class_pred, dim=1)

        class_pred = self.softmax(class_pred)
        
        class_pred = is_obj[:, :, None, :, :] * class_pred
        class_true = is_obj[:, :, None, :, :] * class_true

        class_loss = self.mse(class_pred, class_true)
        # ===========================

        loss =  self.lambda_coord * (xc_loss + yc_loss) + \
                self.lambda_coord * (w_loss + h_loss) + \
                is_obj_conf_loss + \
                self.lambda_noobj * no_obj_conf_loss + \
                class_loss

        return loss

In [72]:
# Take the first training sample and give its label a leading batch axis of
# size 1 so it can be fed to the loss together with a batched model output.
img, label = train_set[0]
label = label[None, ...]

In [74]:
# Forward pass on the single sample; unsqueeze(0) adds the batch dimension.
out = model(img.unsqueeze(0))

In [75]:
# Re-create the loss. Reading top to bottom, YOLOv2Loss now refers to the class
# defined in this notebook, which shadows the one imported from `loss`.
loss_fn = YOLOv2Loss(anchors=anchors)

In [76]:
# Evaluate the loss for the single-sample prediction against its batched label.
loss = loss_fn(out, label)

In [57]:
# Rough timing of the backward pass: print a wall-clock timestamp before and
# after loss.backward().
# NOTE(review): `datetime` is not imported explicitly in this notebook;
# presumably it arrives through `from train import *` - confirm.
# NOTE(review): this cell's execution count (57) is lower than that of the cells
# that build `loss` above (72-76), so the recorded timing belongs to an earlier
# `loss` object.
_datetime = datetime.datetime.now()
print(f"{_datetime} Before")
loss.backward()
_datetime = datetime.datetime.now()
print(f"{_datetime} After")

2024-12-22 10:49:31.458892 Before
2024-12-22 10:49:31.468289 After


In [77]:
# Display the loss value; the recorded output below is NaN.
loss

tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)

In [66]:
# Display the loss module's repr to check which sub-modules it holds.
loss_fn

YOLOv2Loss(
  (mse): MSELoss()
  (softmax): Softmax(dim=2)
)

In [73]:
# Display the image tensor to inspect the value range fed to the model.
img

tensor([[[ 50.,  64.,  82.,  ...,  61.,  67.,  69.],
         [ 39.,  40.,  53.,  ...,  63.,  68.,  69.],
         [ 33.,  32.,  30.,  ...,  58.,  61.,  61.],
         ...,
         [ 33.,  19.,  12.,  ..., 162., 162., 162.],
         [  5.,   9.,  10.,  ..., 161., 161., 161.],
         [ 11.,  12.,  10.,  ..., 164., 163., 162.]],

        [[ 17.,  22.,  29.,  ...,  76.,  82.,  84.],
         [  8.,   9.,  18.,  ...,  77.,  83.,  84.],
         [ 11.,  13.,  11.,  ...,  72.,  74.,  74.],
         ...,
         [ 36.,  20.,  12.,  ..., 189., 189., 189.],
         [  5.,   9.,  10.,  ..., 187., 188., 187.],
         [  9.,  12.,  11.,  ..., 187., 186., 185.]],

        [[  5.,  13.,  16.,  ..., 105., 111., 113.],
         [  2.,   3.,   6.,  ..., 106., 111., 113.],
         [  7.,  11.,   3.,  ..., 103., 106., 105.],
         ...,
         [ 39.,  22.,  14.,  ..., 198., 198., 198.],
         [  7.,   9.,  10.,  ..., 194., 194., 194.],
         [ 10.,  12.,  11.,  ..., 193., 192., 191.]]]