In [1]:
import torchvision.models as models
import torch.nn as nn
import torch

In [2]:
# Instantiate torchvision's ResNet-50 with pretrained weights. This one module
# object is shared by every later cell that builds a backbone from it.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# a `weights=` argument — confirm against the installed version.
resnet50 = models.resnet50(pretrained=True)

In [3]:
def get_resnet50_backbone(resnet50: nn.Module):
    """Assemble a feature extractor from the direct children of ``resnet50``.

    The last two children are discarded (for torchvision's ResNet-50 these are
    presumably the pooling and classifier layers — verify against the model).
    Of the children that remain, all but the final two have ``requires_grad``
    switched off on every parameter.

    The child modules are reused rather than copied, so the freeze is applied
    in place and is also visible through the ``resnet50`` object passed in.

    Args:
        resnet50: Module whose direct children are chained in order.

    Returns:
        An ``nn.Sequential`` wrapping the retained child modules.
    """
    kept_layers = [child for child in resnet50.children()][:-2]
    frozen_layers = kept_layers[:-2]
    # Flatten "every parameter of every frozen layer" into a single stream.
    frozen_params = (p for layer in frozen_layers for p in layer.parameters())
    for param in frozen_params:
        param.requires_grad = False
    return nn.Sequential(*kept_layers)


In [4]:
from torchsummary import summary

In [5]:
# Print a per-layer summary of the truncated backbone for a 3x448x448 input.
# `.cuda()` makes this cell require a CUDA-capable device.
# Note: get_resnet50_backbone turns off requires_grad on parameters of the
# shared `resnet50` object in place, so this call already freezes them.
summary(get_resnet50_backbone(resnet50).cuda(),(3,448,448))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           9,408
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
         MaxPool2d-4         [-1, 64, 112, 112]               0
            Conv2d-5         [-1, 64, 112, 112]           4,096
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
            Conv2d-8         [-1, 64, 112, 112]          36,864
       BatchNorm2d-9         [-1, 64, 112, 112]             128
             ReLU-10         [-1, 64, 112, 112]               0
           Conv2d-11        [-1, 256, 112, 112]          16,384
      BatchNorm2d-12        [-1, 256, 112, 112]             512
           Conv2d-13        [-1, 256, 112, 112]          16,384
      BatchNorm2d-14        [-1, 256, 1

In [6]:
class Resnet50Yolo(nn.Module):
    """ResNet-50 feature extractor followed by a small detection head.

    The head is a stride-2 3x3 convolution (2048 -> 1024 channels), a flatten,
    and two fully connected layers ending in ``7*7*30`` outputs per sample.
    ``fc1`` is sized for ``7*7*1024`` flattened features, so the input
    resolution must make ``last_conv`` emit 7x7 maps (the notebook uses
    3x448x448 inputs).

    A LeakyReLU follows ``last_conv`` and ``fc1``. Without a non-linearity the
    conv and the two linear layers compose into a single affine map, so the
    extra layers would add parameters but no modelling capacity. The activation
    holds no parameters, so ``state_dict`` keys are unchanged.

    Args:
        resnet50: ResNet-50 module handed to ``get_resnet50_backbone``; its
            child modules are reused by this model, not copied.
    """

    def __init__(self, resnet50):
        super(Resnet50Yolo, self).__init__()
        self.resnet50_backbone = get_resnet50_backbone(resnet50=resnet50)
        self.last_conv = nn.Conv2d(2048, 1024, (3, 3), padding=1, stride=2)
        # Slope 0.1 follows the original YOLO head; shared because it is stateless.
        self.activation = nn.LeakyReLU(0.1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(7*7*1024, 4096)
        self.fc2 = nn.Linear(4096, 7*7*30)

    def forward(self, x: torch.Tensor):
        """Run the backbone and head.

        Args:
            x: Batch of images accepted by the backbone.

        Returns:
            Tensor with ``7*7*30`` raw (un-activated) values per sample.
        """
        x = self.resnet50_backbone(x)
        x = self.activation(self.last_conv(x))
        x = self.flatten(x)
        x = self.activation(self.fc1(x))
        # The final layer stays linear so downstream code sees raw predictions.
        x = self.fc2(x)
        return x


In [7]:
# Build the detector on top of the shared pretrained `resnet50` and move it to
# the GPU; `.cuda()` makes this cell require a CUDA-capable device.
resnet50_yolo = Resnet50Yolo(resnet50).cuda()

In [8]:
# Print a per-layer summary of the full detector for a 3x448x448 input.
summary(resnet50_yolo, (3, 448, 448))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           9,408
       BatchNorm2d-2         [-1, 64, 224, 224]             128
              ReLU-3         [-1, 64, 224, 224]               0
         MaxPool2d-4         [-1, 64, 112, 112]               0
            Conv2d-5         [-1, 64, 112, 112]           4,096
       BatchNorm2d-6         [-1, 64, 112, 112]             128
              ReLU-7         [-1, 64, 112, 112]               0
            Conv2d-8         [-1, 64, 112, 112]          36,864
       BatchNorm2d-9         [-1, 64, 112, 112]             128
             ReLU-10         [-1, 64, 112, 112]               0
           Conv2d-11        [-1, 256, 112, 112]          16,384
      BatchNorm2d-12        [-1, 256, 112, 112]             512
           Conv2d-13        [-1, 256, 112, 112]          16,384
      BatchNorm2d-14        [-1, 256, 1

In [9]:
class YoloLoss(nn.Module):
    """Weighted sum of five YOLO-style loss terms.

    ``forward`` combines position, size, object-confidence,
    no-object-confidence and classification terms. The five term functions are
    still unimplemented placeholders: each raises ``NotImplementedError`` so
    that calling the loss fails with a clear message instead of an opaque
    ``int * NoneType`` ``TypeError``.

    Args:
        l_coord: Weight applied to the position and size terms.
        l_noobj: Weight applied to the no-object confidence term.
    """

    def __init__(self, l_coord=5, l_noobj=0.5):
        super(YoloLoss, self).__init__()
        self.l_coord = l_coord
        self.l_noobj = l_noobj

    @staticmethod
    def _cell_masks(truth: torch.Tensor):
        """Split cells into those with and without an object.

        A cell counts as containing an object when entry 4 of its last
        dimension is greater than zero (presumably the first box confidence —
        confirm against the label encoding).

        Args:
            truth: Target tensor whose last dimension has at least 5 entries.

        Returns:
            ``(obj_mask, noobj_mask)``: boolean tensors shaped like ``truth``.
            The two masks are complementary.
        """
        has_object = truth[..., 4] > 0
        obj_mask = has_object.unsqueeze(-1).expand_as(truth)
        # The no-object mask must be the complement, not a copy, of obj_mask.
        noobj_mask = (~has_object).unsqueeze(-1).expand_as(truth)
        return obj_mask, noobj_mask

    @staticmethod
    def position_loss(pred: torch.Tensor, truth: torch.Tensor):
        """Box-centre term (placeholder)."""
        raise NotImplementedError("YoloLoss.position_loss is not implemented yet")

    @staticmethod
    def size_loss(pred: torch.Tensor, truth: torch.Tensor):
        """Box-size term (placeholder)."""
        raise NotImplementedError("YoloLoss.size_loss is not implemented yet")

    @staticmethod
    def obj_box_loss(pred: torch.Tensor, truth: torch.Tensor):
        """Confidence term for cells that contain an object (placeholder)."""
        raise NotImplementedError("YoloLoss.obj_box_loss is not implemented yet")

    @staticmethod
    def noobj_box_loss(pred: torch.Tensor, truth: torch.Tensor):
        """Confidence term for cells without an object (placeholder)."""
        raise NotImplementedError("YoloLoss.noobj_box_loss is not implemented yet")

    @staticmethod
    def classification_loss(pred: torch.Tensor, truth: torch.Tensor):
        """Class-probability term (placeholder)."""
        raise NotImplementedError("YoloLoss.classification_loss is not implemented yet")

    def forward(self, pred: torch.Tensor, truth: torch.Tensor):
        """Compute the weighted total loss.

        Args:
            pred: Network output with the same number of elements as
                ``truth``; it is reshaped to ``truth``'s shape, so flattened
                and batched predictions are both accepted.
            truth: Target tensor; its shape defines the layout the term
                functions receive.

        Returns:
            ``l_coord * (position + size) + obj + l_noobj * noobj + class``.

        Raises:
            NotImplementedError: Until the term functions are implemented or
                overridden in a subclass.
        """
        # view_as (rather than a fixed (7, 7, 30)) keeps any batch dimension.
        pred = pred.view_as(truth)
        # Term functions take only (pred, truth); they can get the object /
        # no-object split from self._cell_masks(truth).
        loss = self.l_coord * self.position_loss(pred, truth) \
            + self.l_coord * self.size_loss(pred, truth) \
            + self.obj_box_loss(pred, truth) \
            + self.l_noobj * self.noobj_box_loss(pred, truth) \
            + self.classification_loss(pred, truth)
        return loss
