In [1]:
from dataloader import CustomImageDataset, DataLoader
from torchvision.transforms import Compose
import torchvision.transforms as transforms

transform = Compose([
    transforms.ToTensor(), # Scales data into [0,1]
    transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1]
])

dataset = CustomImageDataset(annotations_dir = '/home/pqbas/dl/detection/MNIST-ObjectDetection/data/mnist_detection/test/labels',
                             img_dir = '/home/pqbas/dl/detection/MNIST-ObjectDetection/data/mnist_detection/test/images',
                             data_transform = transform,
                             size=(448,448))

train_loader = DataLoader(dataset, batch_size=5, shuffle=True)

In [13]:
import torch
import torch.nn as nn
from metrics import intersection_over_union
from model import YOLOv1

model = YOLOv1(B=2, C=11, S=7)
mse = nn.MSELoss(reduction='sum')
from torch.optim import SGD

optimizer = SGD(model.parameters(), lr = 1e-3, momentum=0.5)

device = torch.device('cuda:0')
model = model.to(device)

n_break = 30
for idx,(img, target) in enumerate(train_loader):
    
    img = img.to(device)
    prediction = model(img)
    prediction = prediction.to(device)

    target_box = target['bbox'].to(device)
    target_prob = target['class'].to(device)
    target_obj = target['one_obj'].to(device)
    target_noObj = 1 - target['one_obj']

    # ========================== #
    #        Object Loss         #
    # ========================== #

    pred_box1 = torch.sigmoid(prediction[...,0:4])
    pred_box2 = torch.sigmoid(prediction[...,4:8])

    iou1 = intersection_over_union(pred_box1, target_box)
    iou2 = intersection_over_union(pred_box2, target_box)
    
    ious = torch.cat([iou1,iou2], dim=3)
    max, bestbox = torch.max(ious, dim=3)
    bestbox = bestbox

    pred_box = (bestbox[...,None]*pred_box2 +  (1 - bestbox[...,None])*pred_box1) * target_obj
    pred_box[...,2:4] = torch.sign(pred_box[...,2:4])*torch.sqrt(pred_box[...,2:4] + 1e-6)

    target_box[...,2:4] = torch.sqrt(target_box[...,2:4])

    box_coordinates_loss = mse(pred_box,target_box)

    #print('Box Loss:', box_coordinates_loss)

    # ========================== #
    #        Classf Loss         #
    # ========================== #

    pred_prob = prediction[...,10:22]*target_obj
    class_loss = mse(pred_prob, target_prob)

    #print('Class Loss:', class_loss)

    # ========================== #
    #          Obj Loss          #
    # ========================== #

    pred_obj = torch.sigmoid(prediction[...,8:10])
    target_obj_ = torch.concatenate([target_obj, target_obj], dim=3)
    obj_loss = mse(pred_obj, target_obj_)

    #print('Obj Loss:', obj_loss)
    
    # ========================== #
    #          Nobj Loss         #
    # ========================== #

    # pred_noobj = torch.sigmoid(prediction[...,8:10])
    # target_noobj_ = torch.concatenate([target_obj, target_obj], dim=3)
    # noobj_loss = mse(pred_noobj, target_noobj_)

    # print('No Object Loss:', noobj_loss)

    # # one_obj = target['one_obj']   
    # # bbox = target['bbox']
    # # prob_class = target['class']
    # # print(bbox.shape)

    # ========================== #
    #         Optimization       #
    # ========================== #

    lambda_obj = 1.0
    lambda_class = 1.0
    lambda_box = 3.0
    loss = lambda_obj*obj_loss + lambda_class*class_loss + lambda_box*box_coordinates_loss
    
    print('Total Loss:', loss)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if idx == n_break:
        break

Total Loss: tensor(404.9922, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(429.0859, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(438.9860, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(451.8215, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(449.5104, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(434.7503, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(437.0022, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(420.1757, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(413.4536, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(441.9079, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(416.6808, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(440.4563, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(445.4779, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(435.7072, device='cuda:0', grad_fn=<AddBackward0>)
Total Loss: tensor(4

In [5]:
import torch
import torch.nn as nn
from metrics import intersection_over_union
from model import YOLOv1

model = YOLOv1(B=2, C=10, S=7)

mse = nn.MSELoss(reduction='sum')

n_break = 0
for idx,(img, target) in enumerate(train_loader):
    
    one_obj = one_obj[...,0][...,None]
    
    prediction = model(img)

    # ========================== #
    #        Object Loss         #
    # ========================== #

    target_box = target[...,0:4]

    box1 = torch.sigmoid(prediction[...,0:4])
    box2 = torch.sigmoid(prediction[...,4:8])

    iou1 = intersection_over_union(box1, target_box)
    iou2 = intersection_over_union(box2, target_box)
    
    ious = torch.cat([iou1,iou2], dim=3)
    max, bestbox = torch.max(ious, dim=3)
    
    box = (bestbox[...,None]*box2 +  (1 - bestbox[...,None])*box1) * one_obj

    box[...,2:4] = torch.sqrt(box[...,2:4])
    target[...,2:4] = torch.sqrt(target_box[...,2:4])

    box_coordinates_loss = mse(box,target_box)
    
    # ========================== #
    #        Classf Loss         #
    # ========================== #

    pred_prob = prediction[...,10:21]*one_obj
    target_prob = target[...,4:15]*one_obj

    classf_loss = mse(pred_prob, target_prob)


    # ========================= #
    #        Noobj Loss         #
    # ========================= #

    noobj = 1 - one_obj
    noobj_pred_prob = prediction[...,10:21]*noobj
    noobj_target_prob = target[...,4:15]*noobj

    print(noobj_pred_prob.shape)
    print(noobj_target_prob.shape)

    mse(noobj_pred_prob, noobj_target_prob)


    if idx == n_break:
        break

NameError: name 'one_obj' is not defined