In [8]:
import sys
# Put the parent directory on the module search path (at index 1) -- presumably
# so that the `Models` package imported in the next cell resolves; TODO confirm.
# The entry is relative, so it depends on the kernel's working directory.
sys.path.insert(1, '../')

In [9]:
import torch
import torch.nn as nn
from Models.block import *
from torch.autograd.profiler import record_function
from torch.nn.modules.upsampling import Upsample

In [10]:
class SPPF(nn.Module):
    """Pyramid-pooling block that applies one max-pool three times in sequence.

    The input is projected to ``out_channels // 2`` channels by a 1x1 ``Conv``.
    That projection is then max-pooled three times in a row (stride 1, padding
    ``kernel_size // 2``). The projection and the three pooled maps are
    concatenated along dim 1 and fused to ``out_channels`` by a second 1x1
    ``Conv``.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count of the returned feature map.
        kernel_size: max-pool window size. Must be a positive odd integer:
            with stride 1 and padding ``kernel_size // 2`` only an odd window
            keeps the spatial size unchanged, which the concatenation in
            ``forward`` requires.
        device, dtype: forwarded unchanged to both ``Conv`` layers.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5,
                 device=None, dtype=None):
        super().__init__()

        # Fail early: an even window makes every pooled map one pixel larger
        # than its input, so torch.cat in forward() would fail with an opaque
        # shape-mismatch error instead.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size!r}")

        # NOTE(review): the hidden width is derived from out_channels; the
        # upstream YOLOv8 SPPF appears to halve in_channels instead -- confirm.
        # The two agree whenever in_channels == out_channels.
        self.hidden_channels = out_channels//2
        self.conv1 = Conv(in_channels, out_channels=self.hidden_channels, kernel_size=(1,1), stride=(1,1),
                          padding=(0,0), bias=False,
                          device=device, dtype=dtype)
        # conv2 consumes the projection plus the three pooled maps -> 4x hidden width.
        self.conv2 = Conv(self.hidden_channels * 4, out_channels=out_channels, kernel_size=(1,1), stride=(1,1),
                          padding=(0,0), bias=False,
                          device=device, dtype=dtype)
        self.max_pool = nn.MaxPool2d(kernel_size=kernel_size, stride=(1,1), padding=kernel_size // 2)

    def forward(self, x):
        """Return the fused feature map for input ``x``."""
        out = self.conv1(x)

        # Each pooling step is applied to the previous pooled map, not to the
        # projection, so the three maps come from progressively repeated pooling.
        pyramid = [out]
        for _ in range(3):
            out = self.max_pool(out)
            pyramid.append(out)

        return self.conv2(torch.cat(pyramid, 1))

In [11]:
# Smoke test: run SPPF (32 -> 32 channels) on a random (1, 32, 20, 20) tensor
# and display the output shape.
inp = torch.rand(1,32,20,20)
sppf = SPPF(32,32)
out = sppf(inp)
out.shape

torch.Size([1, 32, 20, 20])

# Can I decrease the number of tensors created?

In [29]:
class FPN(nn.Module):
    """Top-down fusion neck over three feature maps (named stride 8 / 16 / 32).

    ``out_32`` is upsampled x2 (nearest), concatenated with ``out_16`` along
    dim 1 and fused by ``c2f_16``. That result is upsampled x2, concatenated
    with ``out_8`` and fused by ``c2f_8``.

    Args:
        residual_connection, CSP, add_hidden: forwarded unchanged to both
            ``C2f`` blocks.
        variant: key into ``FPN.variants`` selecting the depth/width scaling.
        device, dtype: forwarded unchanged to both ``C2f`` blocks.

    Raises:
        ValueError: if ``variant`` is not a key of ``FPN.variants``.
    """

    # The scaling table lives on FPN itself so the class can be instantiated
    # before (or without) `Model` being defined; previously the constructor
    # read `Model.variants`, which is only defined in a later cell.
    # Keep these values in sync with `Model.variants`.
    # 'd' scales C2f repeat counts (see _d); 'w' scales channel widths and
    # 'mc' caps the nominal width before scaling (see _ch).
    variants = {'n': {'d': 0.34, 'w': 0.25, 'mc': 1024},
                's': {'d': 0.34, 'w': 0.50, 'mc': 1024},
                'm': {'d': 0.67, 'w': 0.75, 'mc': 768},
                'l': {'d': 1.00, 'w': 1.00, 'mc': 512},
                'xl': {'d': 1.00, 'w': 1.25, 'mc': 512}}

    def __init__(self, residual_connection=True,
                 CSP=True, add_hidden=True, variant='n',
                 device=None, dtype=None):
        super().__init__()

        if variant not in FPN.variants:
            # ValueError is a subclass of Exception, so existing handlers still match.
            raise ValueError("Invalid variant.")

        self.variant = variant
        scale = FPN.variants[variant]
        self.mc = scale['mc']
        self.w = scale['w']
        self.d = scale['d']

        self.upsample = Upsample(scale_factor=2.0, mode='nearest')

        # Input width = upsampled stride-32 map + stride-16 map.
        self.c2f_16 = C2f(self._ch(512)+self._ch(1024), out_channels=self._ch(512), n=self._d(3), residual_connection=residual_connection,
                        CSP=CSP, add_hidden=add_hidden, bottleneck=1.0,
                        device=device, dtype=dtype)

        # Input width = upsampled fused stride-16 map + stride-8 map.
        self.c2f_8 = C2f(self._ch(256)+self._ch(512), out_channels=self._ch(256), n=self._d(3), residual_connection=residual_connection,
                        CSP=CSP, add_hidden=add_hidden, bottleneck=1.0,
                        device=device, dtype=dtype)

    def forward(self, out_8, out_16, out_32):
        """Fuse the three maps top-down.

        Each finer map must have twice the spatial size of the next coarser
        one, otherwise the concatenation after upsampling fails.

        Returns:
            ``(out_8, out_16, out_32)`` where the first two are the fused maps
            and ``out_32`` is passed through unchanged.
        """
        out = self.upsample(out_32)

        out = torch.cat([out, out_16], 1)
        out_16 = self.c2f_16(out)

        out = self.upsample(out_16)
        out = torch.cat([out, out_8], 1)
        out_8 = self.c2f_8(out)

        return out_8, out_16, out_32

    def _ch(self, ch):
        """Scale a nominal channel count: cap at ``mc``, multiply by ``w``, truncate."""
        return int(min(ch, self.mc)*self.w)

    def _d(self, d):
        """Scale a nominal repeat count by ``d`` and truncate to int."""
        return int(d * self.d)

In [30]:
# Scratch arithmetic -- presumably the input width of FPN.c2f_8 for variant 'n'
# (_ch(256) + _ch(512) = 64 + 128); TODO confirm.
64+128

192

In [31]:
# Smoke test for the default ('n') FPN using random stand-in feature maps:
# 256 / 128 / 64 channels at 20x20 / 40x40 / 80x80.
fpn = FPN()
out_32 = torch.rand(1, 256, 20, 20)
out_16 = torch.rand(1, 128, 40, 40)
out_8 = torch.rand(1, 64, 80, 80)
_t = fpn(out_8, out_16, out_32)

In [32]:
# Print the shape of each tensor returned by the FPN call above,
# in the order (out_8, out_16, out_32).
for t in _t:
    print(t.shape)

torch.Size([1, 64, 80, 80])
torch.Size([1, 128, 40, 40])
torch.Size([1, 256, 20, 20])


In [35]:
class Model(nn.Module):
    """Backbone (five stride-2 ``Conv`` stages with ``C2f`` blocks, then ``SPPF``)
    followed by the ``FPN`` neck.

    NOTE(review): the layout appears to follow the YOLOv8 checkpoint that is
    loaded as ``reference`` later in this notebook -- confirm.

    Args:
        device, dtype: forwarded unchanged to every sub-block.
        num_classes: accepted for interface compatibility; not used by any
            layer built here.
        variant: key into ``Model.variants`` selecting the depth/width scaling.

    Raises:
        ValueError: if ``variant`` is not a key of ``Model.variants``.
    """

    # 'd' scales C2f repeat counts (see _d); 'w' scales channel widths and
    # 'mc' caps the nominal width before scaling (see _ch).
    variants = {'n': {'d': 0.34, 'w': 0.25, 'mc': 1024},
                's': {'d': 0.34, 'w': 0.50, 'mc': 1024},
                'm': {'d': 0.67, 'w': 0.75, 'mc': 768},
                'l': {'d': 1.00, 'w': 1.00, 'mc': 512},
                'xl': {'d': 1.00, 'w': 1.25, 'mc': 512}}

    def __init__(self, device=None, dtype=None,
                 num_classes=1000, variant='n'):
        super().__init__()

        if variant not in Model.variants:
            # ValueError is a subclass of Exception, so existing handlers still match.
            raise ValueError("Invalid variant.")
        self.variant = variant
        scale = Model.variants[variant]
        self.mc = scale['mc']
        self.w = scale['w']
        self.d = scale['d']

        # Settings shared by every C2f stage (and, except `bottleneck`, the neck).
        residual_connection = True
        CSP = True
        add_hidden = True
        bottleneck = 1.0

        def downsample(in_ch, nominal_out):
            # 3x3 stride-2 Conv; `in_ch` is already scaled, `nominal_out` is not.
            return Conv(in_ch, out_channels=self._ch(nominal_out), kernel_size=(3, 3), stride=(2, 2),
                        padding=(1, 1), bias=False,
                        device=device, dtype=dtype)

        def stage(nominal_ch, repeats):
            # Width-preserving C2f block with a depth-scaled repeat count.
            return C2f(self._ch(nominal_ch), out_channels=self._ch(nominal_ch), n=self._d(repeats),
                       residual_connection=residual_connection,
                       CSP=CSP, add_hidden=add_hidden, bottleneck=bottleneck,
                       device=device, dtype=dtype)

        # Sub-modules are created in the same order as before so that module
        # registration order and parameter-initialisation order are unchanged.
        self.conv1 = downsample(3, 64)
        self.conv2 = downsample(self._ch(64), 128)
        self.c2f1 = stage(128, 3)
        self.conv3 = downsample(self._ch(128), 256)
        self.c2f2 = stage(256, 6)
        # c2f2 --- stride 8 ---->

        self.conv4 = downsample(self._ch(256), 512)
        self.c2f3 = stage(512, 6)
        # c2f3 --- stride 16 ---->

        self.conv5 = downsample(self._ch(512), 1024)
        self.c2f4 = stage(1024, 3)

        self.sppf = SPPF(self._ch(1024), out_channels=self._ch(1024), kernel_size=5,
                        device=device, dtype=dtype)
        # sppf --- stride 32 ---->

        self.fpn = FPN(residual_connection=residual_connection,
                       CSP=CSP, add_hidden=add_hidden,
                       variant=variant,
                       device=device, dtype=dtype)

    def _ch(self, ch):
        """Scale a nominal channel count: cap at ``mc``, multiply by ``w``, truncate."""
        return int(min(ch, self.mc)*self.w)

    def _d(self, d):
        """Scale a nominal repeat count by ``d`` and truncate to int."""
        return int(d * self.d)

    def forward(self, x):
        """Run the backbone and neck.

        Every stage is wrapped in ``record_function`` so it appears under its
        own label when the call is profiled.

        Returns:
            The 3-tuple produced by ``self.fpn`` from the ``c2f2`` output
            (stride 8), the ``c2f3`` output (stride 16) and the ``sppf``
            output (stride 32).
        """
        with record_function('conv1'):
            out = self.conv1(x)

        with record_function('conv2'):
            out = self.conv2(out)
        with record_function('c2f1'):
            out = self.c2f1(out)

        with record_function('conv3'):
            out = self.conv3(out)
        with record_function('c2f2'):
            out = self.c2f2(out)

        # c2f2 --- stride 8 ---->
        out_8 = out

        with record_function('conv4'):
            out = self.conv4(out)
        with record_function('c2f3'):
            out = self.c2f3(out)

        # c2f3 --- stride 16 ---->
        out_16 = out

        with record_function('conv5'):
            out = self.conv5(out)
        with record_function('c2f4'):
            out = self.c2f4(out)

        with record_function('sppf'):
            out = self.sppf(out)

        # sppf --- stride 32 ---->

        with record_function('fpn'):
            out_8, out_16, out = self.fpn(out_8, out_16, out)

        return out_8, out_16, out

In [36]:
# Build the model with its defaults (variant 'n', no explicit device/dtype).
model = Model()

In [37]:
# Forward a random (1, 3, 640, 640) input; `out` is the 3-tuple returned by Model.forward.
img = torch.rand(1,3,640,640)
out = model(img)

In [40]:
# Shapes of the three returned feature maps.
out[0].shape, out[1].shape, out[2].shape

(torch.Size([1, 64, 80, 80]),
 torch.Size([1, 128, 40, 40]),
 torch.Size([1, 256, 20, 20]))

In [48]:
# NOTE(review): rebinding `img` to itself plus one makes this cell non-idempotent --
# every re-run shifts the values again. Its purpose is not evident from the notebook.
img = img+1

In [9]:
# Load the reference checkpoint and keep the object at ['model'].model
# (its repr further down shows a Sequential).
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code --
# only load checkpoints from a trusted source. The path is relative to the
# kernel's working directory.
reference = torch.load('yolov8m.pt')['model'].model

  reference = torch.load('yolov8m.pt')['model'].model


In [20]:
# Pick out the submodule registered under the name '10' in the reference
# network (inspected in the following cells). Names yielded by
# named_modules() are unique, so stop at the first match; `up` stays None
# when no such module exists.
up = None
for name, m in reference.named_modules():
    if name == '10':
        up = m
        break

In [22]:
# Which module class is registered under '10'?
type(up)

torch.nn.modules.upsampling.Upsample

In [23]:
# Show the module's configuration via its repr.
up

Upsample(scale_factor=2.0, mode='nearest')

In [28]:
# Tiny 3-D probe input (shape (1, 2, 2)) used to see what the upsample module does.
tensor = torch.tensor([[[1.,2.],[3.,4.]]])
tensor.shape

torch.Size([1, 2, 2])

In [29]:
# Apply the reference upsample module to the probe. As the displayed result
# shows, for this 3-D input only the last dimension is doubled (each value is
# repeated along it); the two rows are not duplicated.
out = up(tensor)
out

tensor([[[1., 1., 2., 2.],
         [3., 3., 4., 4.]]])

In [30]:
# Shape after upsampling the (1, 2, 2) probe.
out.shape

torch.Size([1, 2, 4])

In [19]:
# Display the full module tree of the reference network (long output).
reference

Sequential(
  (0): Conv(
    (conv): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(48, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
    (act): SiLU(inplace=True)
  )
  (1): Conv(
    (conv): Conv2d(48, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(96, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
    (act): SiLU(inplace=True)
  )
  (2): C2f(
    (cv1): Conv(
      (conv): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (cv2): Conv(
      (conv): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(96, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (m): ModuleList(
      (0-1): 2 x Bottleneck(
        (cv1): Conv(
 