In [1]:
import sys
sys.path.insert(1, '../')

In [2]:
import torch
import torch.nn as nn
from Models.block import *
from torch.autograd.profiler import record_function
from torch.nn.modules.upsampling import Upsample

In [21]:
from evaluation import count_parameters

In [3]:
class SPPF(nn.Module):
    """Spatial pyramid pooling (fast variant).

    A 1x1 ``Conv`` reduces the input to ``out_channels // 2`` channels. The
    result is max-pooled three times in succession with the same kernel
    (stride 1, padding ``kernel_size // 2``, so odd kernels keep the spatial
    size). The un-pooled map and the three pooled maps are concatenated on
    dim 1 and projected back to ``out_channels`` by a second 1x1 ``Conv``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, 
                 device=None, dtype=None):
        super().__init__()
        self.hidden_channels = out_channels // 2
        # Options shared by both 1x1 projections.
        pointwise = dict(kernel_size=(1, 1), stride=(1, 1), padding=(0, 0),
                         bias=False, device=device, dtype=dtype)
        self.conv1 = Conv(in_channels, out_channels=self.hidden_channels, **pointwise)
        # Four maps are concatenated in forward: the input map plus three pooled ones.
        self.conv2 = Conv(4 * self.hidden_channels, out_channels=out_channels, **pointwise)
        self.max_pool = nn.MaxPool2d(kernel_size=kernel_size, stride=(1, 1), padding=kernel_size // 2)

    def forward(self, x):
        """Reduce, pool three times, concatenate all four maps, then project."""
        reduced = self.conv1(x)
        pooled_once = self.max_pool(reduced)
        pooled_twice = self.max_pool(pooled_once)
        pooled_thrice = self.max_pool(pooled_twice)
        stacked = torch.cat((reduced, pooled_once, pooled_twice, pooled_thrice), dim=1)
        return self.conv2(stacked)

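# TODO: can SPPF.forward create fewer intermediate tensors? It currently keeps all four feature maps (the input and three pooled maps) alive until the final torch.cat.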

In [5]:
class FPN(nn.Module):
    """Top-down feature pyramid neck.

    The stride-32 map is upsampled 2x (nearest) and concatenated with the
    stride-16 map, then refined by ``c2f_16``. That result is upsampled again,
    concatenated with the stride-8 map and refined by ``c2f_8``.

    Args:
        residual_connection, CSP, add_hidden: forwarded to both ``C2f`` blocks.
        variant: key of ``Model.variants`` selecting ``mc`` / ``w`` / ``d``.
        device, dtype: forwarded to the ``C2f`` constructors.

    Raises:
        Exception: if ``variant`` is not a key of ``Model.variants``.
    """

    def __init__(self, residual_connection=False, 
                 CSP=True, add_hidden=True, variant='n',
                 device=None, dtype=None):
        super().__init__()

        if variant not in Model.variants:
            raise Exception("Invalid variant.")

        self.variant = variant
        scale = Model.variants[variant]
        self.mc = scale['mc']
        self.w = scale['w']
        self.d = scale['d']

        self.upsample = Upsample(scale_factor=2.0, mode='nearest')

        # Keyword arguments shared by both C2f blocks.
        c2f_kwargs = dict(n=self._d(3), residual_connection=residual_connection,
                          CSP=CSP, add_hidden=add_hidden, bottleneck=1.0,
                          device=device, dtype=dtype)
        # Input: upsampled stride-32 map concatenated with the stride-16 map.
        self.c2f_16 = C2f(self._ch(512) + self._ch(1024), out_channels=self._ch(512), **c2f_kwargs)
        # Input: upsampled c2f_16 output concatenated with the stride-8 map.
        self.c2f_8 = C2f(self._ch(256) + self._ch(512), out_channels=self._ch(256), **c2f_kwargs)

    def forward(self, out_8, out_16, out_32):
        """Fuse the pyramid top-down.

        NOTE(review): the C2f input widths assume ``out_16`` / ``out_32`` carry
        ``_ch(512)`` / ``_ch(1024)`` channels (as produced by ``Model``) -- confirm
        for other callers.

        Returns:
            ``(c2f_8 output, c2f_16 output, out_32 unchanged)``.
        """
        fused_16 = self.c2f_16(torch.cat([self.upsample(out_32), out_16], 1))
        fused_8 = self.c2f_8(torch.cat([self.upsample(fused_16), out_8], 1))
        return fused_8, fused_16, out_32

    def _ch(self, ch):
        """Channel count ``ch`` capped at ``mc`` and scaled by the width factor ``w``."""
        capped = min(ch, self.mc)
        return int(capped * self.w)

    def _d(self, d):
        """Repeat count ``d`` scaled by the depth factor, truncated to int."""
        scaled = d * self.d
        return int(scaled)

In [6]:
class PANet(nn.Module):
    """Bottom-up path-aggregation neck.

    The stride-8 map is downsampled by a stride-2 3x3 ``Conv`` and concatenated
    with the stride-16 map, then refined by ``c2f_16``. That result is
    downsampled again, concatenated with the stride-32 map and refined by
    ``c2f_32``.

    Args:
        residual_connection, CSP, add_hidden: forwarded to both ``C2f`` blocks.
        variant: key of ``Model.variants`` selecting ``mc`` / ``w`` / ``d``.
        device, dtype: forwarded to every layer constructor.

    Raises:
        Exception: if ``variant`` is not a key of ``Model.variants``.
    """

    def __init__(self, residual_connection=False, 
                 CSP=True, add_hidden=True, variant='n',
                 device=None, dtype=None):
        super().__init__()

        if variant not in Model.variants:
            raise Exception("Invalid variant.")

        self.variant = variant
        scale = Model.variants[variant]
        self.mc = scale['mc']
        self.w = scale['w']
        self.d = scale['d']

        # Stride-2 3x3 downsampling convolutions share these options.
        down_kwargs = dict(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1),
                           bias=False, device=device, dtype=dtype)
        # Keyword arguments shared by both C2f blocks.
        c2f_kwargs = dict(n=self._d(3), residual_connection=residual_connection,
                          CSP=CSP, add_hidden=add_hidden, bottleneck=1.0,
                          device=device, dtype=dtype)

        # Registration order is kept as conv8_16, c2f_16, conv16_32, c2f_32.
        self.conv8_16 = Conv(self._ch(256), out_channels=self._ch(256), **down_kwargs)
        self.c2f_16 = C2f(self._ch(256) + self._ch(512), out_channels=self._ch(512), **c2f_kwargs)
        self.conv16_32 = Conv(self._ch(512), out_channels=self._ch(512), **down_kwargs)
        self.c2f_32 = C2f(self._ch(1024) + self._ch(512), out_channels=self._ch(1024), **c2f_kwargs)

    def _ch(self, ch):
        """Channel count ``ch`` capped at ``mc`` and scaled by the width factor ``w``."""
        capped = min(ch, self.mc)
        return int(capped * self.w)

    def _d(self, d):
        """Repeat count ``d`` scaled by the depth factor, truncated to int."""
        scaled = d * self.d
        return int(scaled)

    def forward(self, out_8, out_16, out_32):
        """Fuse the pyramid bottom-up.

        Returns:
            ``(out_8 unchanged, c2f_16 output, c2f_32 output)``.
        """
        fused_16 = self.c2f_16(torch.cat([self.conv8_16(out_8), out_16], 1))
        fused_32 = self.c2f_32(torch.cat([self.conv16_32(fused_16), out_32], 1))
        return out_8, fused_16, fused_32

In [140]:
class Detect(nn.Module):
    """Detection head for one feature map.

    With ``decoupled=True``, two independent branches are built. Each is two
    3x3 ``Conv`` blocks followed by a 1x1 ``nn.Conv2d`` with bias:

    * box branch   (``bbox_conv1..3``) -> ``num_boxes * 4`` channels
    * class branch (``cls_conv1..3``)  -> ``num_classes`` channels

    The coupled head (``decoupled=False``) is not implemented. No layers are
    built, and ``forward`` raises ``NotImplementedError`` rather than silently
    returning ``None``.

    Args:
        in_channels: channel count of the feature map passed to ``forward``.
        decoupled: build the two-branch head described above.
        num_classes: output channels of the class branch.
        num_boxes: the box branch outputs ``num_boxes * 4`` channels
            (NOTE(review): presumably the DFL bin count per box side, matching
            ``DFL(c1=16)`` -- confirm).
        variant: key of ``Model.variants``; its ``w`` sets the hidden widths.
        device, dtype: forwarded to every layer constructor.

    Raises:
        Exception: if ``variant`` is not a key of ``Model.variants``.
    """

    def __init__(self, in_channels, decoupled=True,
                 num_classes=80, num_boxes=16, variant='n',
                 device=None, dtype=None):
        super().__init__()
        self.decoupled = decoupled

        if variant not in Model.variants:
            raise Exception("Invalid variant.")
        self.variant = variant
        self.mc = Model.variants[self.variant]['mc']
        self.w = Model.variants[self.variant]['w']
        self.d = Model.variants[self.variant]['d']

        # Hidden widths come from the width-scaled 256 channels, not from
        # in_channels, so every head of a model gets the same hidden widths.
        ch_0 = int(self.w * 256)
        bbox_hidden = max(16, ch_0 // 4, num_boxes * 4)
        cls_hidden = max(ch_0, min(num_classes, 100))

        if decoupled:
            self.bbox_conv1 = Conv(in_channels, out_channels=bbox_hidden, kernel_size=(3,3), stride=(1,1),
                              padding=(1,1), bias=False,
                              device=device, dtype=dtype)
    
            self.bbox_conv2 = Conv(bbox_hidden, out_channels=bbox_hidden, kernel_size=(3,3), stride=(1,1),
                              padding=(1,1), bias=False,
                              device=device, dtype=dtype)
            
            self.bbox_conv3 = nn.Conv2d(bbox_hidden, out_channels=num_boxes * 4, kernel_size=(1,1), stride=(1,1), 
                                        padding=(0,0), bias=True, 
                                        device=device, dtype=dtype)

            self.cls_conv1 = Conv(in_channels, out_channels=cls_hidden, kernel_size=(3,3), stride=(1,1),
                              padding=(1,1), bias=False,
                              device=device, dtype=dtype)
    
            self.cls_conv2 = Conv(cls_hidden, out_channels=cls_hidden, kernel_size=(3,3), stride=(1,1),
                              padding=(1,1), bias=False,
                              device=device, dtype=dtype)

            self.cls_conv3 = nn.Conv2d(cls_hidden, out_channels=num_classes, kernel_size=(1,1), stride=(1,1), 
                                        padding=(0,0), bias=True, 
                                        device=device, dtype=dtype)

    def forward(self, x):
        """Apply both branches to ``x``.

        Returns:
            ``(bbox, cls)``: outputs of the box and class branches.

        Raises:
            NotImplementedError: if the head was built with ``decoupled=False``.
        """
        if not self.decoupled:
            # Previously this path fell through and returned None, which Model
            # then passed on as if it were a head output.
            raise NotImplementedError("Coupled detection head (decoupled=False) is not implemented.")

        out_bb = self.bbox_conv3(self.bbox_conv2(self.bbox_conv1(x)))
        out_cls = self.cls_conv3(self.cls_conv2(self.cls_conv1(x)))
        return out_bb, out_cls

In [141]:
class DFL(nn.Module):
    """
    Integral module of Distribution Focal Loss (DFL).

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391

    Converts per-side discrete distributions into scalars. The input holds
    ``4 * c1`` logits per anchor (``c1`` bins for each of 4 sides). A softmax is
    taken over the bins, and the result is reduced by a frozen 1x1 convolution
    with weights ``0, 1, ..., c1 - 1``. The output is therefore the expected bin
    index for each side.
    """

    def __init__(self, c1=16):
        """Build the frozen projection conv for ``c1`` bins per side.

        Args:
            c1: number of bins per side; the input must have ``4 * c1`` channels.
        """
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        bins = torch.arange(c1, dtype=torch.float)
        # Fixed weights 0..c1-1 make the conv compute sum_i i * p_i (the expectation).
        with torch.no_grad():
            self.conv.weight.copy_(bins.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Decode box-side distributions into expected bin indices.

        Args:
            x: tensor of shape ``(batch, 4 * c1, anchors)``.

        Returns:
            Tensor of shape ``(batch, 4, anchors)``.
        """
        b, _, a = x.shape  # batch, channels, anchors
        # (b, 4*c1, a) -> (b, 4, c1, a) -> (b, c1, 4, a): bins move to dim 1 so
        # the softmax normalises over them and the 1x1 conv sums over them.
        probs = x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)
        return self.conv(probs).view(b, 4, a)

In [142]:
class Model(nn.Module):
    """YOLOv8-style detector: backbone, optional SPPF / FPN / PANet neck, detection heads.

    The backbone downsamples with five stride-2 ``Conv`` blocks interleaved with
    ``C2f`` stages. The stride-8, stride-16 and stride-32 maps are kept for the
    neck and heads.

    Args:
        three_heads: attach ``Detect`` heads to the stride-8/16/32 maps.
        decoupled: forwarded to each ``Detect`` head.
        _FPN, _PANet, _SPPF: enable the corresponding neck stage.
        v8_loss: build a ``DFL`` module as ``self.dfl`` (it is not applied in
            ``forward``).
        num_classes, num_boxes: forwarded to each ``Detect`` head.
        variant: one of the keys of ``variants`` (``'n'``, ``'s'``, ``'m'``,
            ``'l'``, ``'x'``).
        device, dtype: forwarded to layer constructors.

    Raises:
        Exception: if ``variant`` is not a key of ``variants``.
    """

    # Per-variant scaling factors, also read by FPN, PANet and Detect:
    #   d  -- multiplier for the ``n`` argument of every C2f (see ``_d``)
    #   w  -- multiplier for channel counts (see ``_ch``)
    #   mc -- channel cap applied before the width multiplier (see ``_ch``)
    variants = {'n': {'d': 0.34, 'w': 0.25, 'mc': 1024},
                's': {'d': 0.34, 'w': 0.50, 'mc': 1024},
                'm': {'d': 0.67, 'w': 0.75, 'mc': 768},
                'l': {'d': 1.00, 'w': 1.00, 'mc': 512},
                'x': {'d': 1.00, 'w': 1.25, 'mc': 512}}

    def __init__(self, three_heads=True, decoupled=True,
                 _FPN=True, _PANet=True, _SPPF=True, v8_loss=True,
                 num_classes=80, num_boxes=16, variant='n', 
                 device=None, dtype=None):
        super().__init__()
        self.three_heads = three_heads
        self._FPN = _FPN
        self._PANet = _PANet
        self._SPPF = _SPPF
        self.v8_loss = v8_loss

        # Backbone C2f configuration (the neck stages use residual_connection=False).
        residual_connection = True
        CSP = True
        add_hidden = True
        bottleneck = 1.0
        
        if variant not in Model.variants:
            raise Exception("Invalid variant.")
        self.variant = variant
        self.mc = Model.variants[self.variant]['mc']
        self.w = Model.variants[self.variant]['w']
        self.d = Model.variants[self.variant]['d']

        # ---- Backbone ------------------------------------------------------
        self.conv1 = Conv(3, out_channels=self._ch(64), kernel_size=(3, 3), stride=(2, 2), 
                         padding=(1, 1), bias=False, 
                         device=device, dtype=dtype)
        self.conv2 = Conv(self._ch(64), out_channels=self._ch(128), kernel_size=(3, 3), stride=(2, 2), 
                          padding=(1, 1), bias=False,
                          device=device, dtype=dtype)
        self.c2f1 = C2f(self._ch(128), out_channels=self._ch(128), n=self._d(3), residual_connection=residual_connection, 
                        CSP=CSP, add_hidden=add_hidden, bottleneck=bottleneck,
                        device=device, dtype=dtype)
        self.conv3 = Conv(self._ch(128), out_channels=self._ch(256), kernel_size=(3, 3), stride=(2, 2), 
                         padding=(1, 1), bias=False, 
                         device=device, dtype=dtype)
        self.c2f2 = C2f(self._ch(256), out_channels=self._ch(256), n=self._d(6), residual_connection=residual_connection, 
                        CSP=CSP, add_hidden=add_hidden, bottleneck=bottleneck,
                        device=device, dtype=dtype)
        # c2f2 --- stride 8 ---->
        
        self.conv4 = Conv(self._ch(256), out_channels=self._ch(512), kernel_size=(3, 3), stride=(2, 2), 
                         padding=(1, 1), bias=False, 
                         device=device, dtype=dtype)
        self.c2f3 = C2f(self._ch(512), out_channels=self._ch(512), n=self._d(6), residual_connection=residual_connection, 
                        CSP=CSP, add_hidden=add_hidden, bottleneck=bottleneck,
                        device=device, dtype=dtype)
        # c2f3 --- stride 16 ---->
        
        self.conv5 = Conv(self._ch(512), out_channels=self._ch(1024), kernel_size=(3, 3), stride=(2, 2), 
                         padding=(1, 1), bias=False, 
                         device=device, dtype=dtype)
        self.c2f4 = C2f(self._ch(1024), out_channels=self._ch(1024), n=self._d(3), residual_connection=residual_connection, 
                        CSP=CSP, add_hidden=add_hidden, bottleneck=bottleneck,
                        device=device, dtype=dtype)

        # ---- Neck ----------------------------------------------------------
        if _SPPF:
            self.sppf = SPPF(self._ch(1024), out_channels=self._ch(1024), kernel_size=5,
                            device=device, dtype=dtype)
        # sppf --- stride 32 ---->

        if _FPN:
            self.fpn = FPN(residual_connection=False, 
                           CSP=CSP, add_hidden=add_hidden, 
                           variant=variant,
                           device=device, dtype=dtype)

        if _PANet:
            self.panet = PANet(residual_connection=False, 
                           CSP=CSP, add_hidden=add_hidden, 
                           variant=variant,
                           device=device, dtype=dtype)

        # ---- Heads ---------------------------------------------------------
        if three_heads:
            self.detect_8 = Detect(self._ch(256), decoupled=decoupled,
                                   num_classes=num_classes, num_boxes=num_boxes, 
                                   variant=variant, device=device, dtype=dtype)
    
            self.detect_16 = Detect(self._ch(512), decoupled=decoupled,
                                   num_classes=num_classes, num_boxes=num_boxes, 
                                   variant=variant, device=device, dtype=dtype)
    
            self.detect_32 = Detect(self._ch(1024), decoupled=decoupled,
                                   num_classes=num_classes, num_boxes=num_boxes, 
                                   variant=variant, device=device, dtype=dtype)
            
        if v8_loss:
            self.dfl = DFL()

    def _ch(self, ch):
        """Channel count ``ch`` capped at ``mc`` and scaled by the width factor ``w``."""
        return int(min(ch, self.mc)*self.w)

    def _d(self, d):
        """Repeat count ``d`` scaled by the depth factor ``d``, truncated to int."""
        return int(d * self.d)

    def forward(self, x):
        """Run the backbone, the enabled neck stages and (optionally) the heads.

        Each stage is wrapped in ``record_function`` so it appears by name in
        ``torch.profiler`` traces.

        Args:
            x: input batch; ``conv1`` is built for 3 input channels
               (NOTE(review): presumably ``(batch, 3, H, W)`` -- confirm).

        Returns:
            ``(out_8, out_16, out_32)``. With ``three_heads=True`` each element
            is what the corresponding ``Detect`` head returns (a ``(bbox, cls)``
            pair for the decoupled head). Otherwise each element is the
            neck/backbone feature map at that stride.
        """
        with record_function('conv1'):
            out = self.conv1(x)
        
        with record_function('conv2'):
            out = self.conv2(out)
        with record_function('c2f1'):
            out = self.c2f1(out)

        with record_function('conv3'):
            out = self.conv3(out)
        with record_function('c2f2'):
            out = self.c2f2(out)
        
        # c2f2 --- stride 8 ---->
        out_8 = out
            
        with record_function('conv4'):
            out = self.conv4(out)
        with record_function('c2f3'):
            out = self.c2f3(out)

        # c2f3 --- stride 16 ---->
        out_16 = out
            
        with record_function('conv5'):
            out = self.conv5(out)
        with record_function('c2f4'):
            out = self.c2f4(out)

        if self._SPPF:
            with record_function('sppf'):
                out = self.sppf(out)

        # sppf --- stride 32 ---->
        if self._FPN:
            with record_function('fpn'):
                out_8, out_16, out = self.fpn(out_8, out_16, out)

        if self._PANet:
            with record_function('panet'):
                out_8, out_16, out = self.panet(out_8, out_16, out)

        if self.three_heads:
            with record_function('detect'):
                out_8 = self.detect_8(out_8)
                out_16 = self.detect_16(out_16)
                out = self.detect_32(out)

        # ``self.dfl`` exists when v8_loss=True but is deliberately not applied here.
        # Calling it without an input (``self.dfl()``) raised TypeError on every
        # forward pass, and its result was discarded anyway.
        # NOTE(review): presumably DFL decoding belongs in the loss / post-processing
        # code -- confirm.
        return out_8, out_16, out

Background on why the convolutions above are built with `bias=False`: https://discuss.pytorch.org/t/any-purpose-to-set-bias-false-in-densenet-torchvision/22067/2

In [145]:
# Compare our parameter counts with the reference YOLOv8 checkpoints, per variant.
# NOTE(review): `YOLO` is not imported in any visible cell -- presumably it comes from
# `ultralytics` (or the `Models.block` star import); confirm and import it explicitly.
# NOTE(review): the constant gap of 16 equals the weight count of DFL's 1x1 conv
# (16 -> 1 channels, no bias), which is not built here because v8_loss=False --
# presumably that explains the difference; confirm.
for variant in ['n', 's', 'm', 'l', 'x']:
    ours = Model(three_heads=True, decoupled=True,
                 _FPN=True, _PANet=True, _SPPF=True, v8_loss=False,
                 num_classes=80, num_boxes=16, variant=variant,
                 device=None, dtype=None)
    my_count = count_parameters(ours)
    reference = YOLO(f"yolov8{variant}.pt").model.model
    their_count = count_parameters(reference)
    # Label each line with its variant so the output is readable on its own.
    print(variant, their_count - my_count)

16
16
16
16
16


In [147]:
# Smoke test: push one random 640x640 RGB image through the nano model.
torch.manual_seed(0)  # reproducible weight init and input across Restart & Run All
model = Model(three_heads=True, decoupled=True,
              _FPN=True, _PANet=True, _SPPF=True, v8_loss=False,
              num_classes=80, num_boxes=16, variant='n',
              device=None, dtype=None)
img = torch.rand(1, 3, 640, 640)
with torch.no_grad():  # shape check only; no autograd graph needed
    out = model(img)

In [149]:
# One output per stride (8, 16, 32).
assert len(out) == 3, "expected one output per stride (8, 16, 32)"
len(out)

3

In [152]:
# A decoupled head returns a (bbox, cls) pair.
assert len(out[0]) == 2, "expected a (bbox, cls) pair from the decoupled head"
len(out[0])

2