In [43]:
import torch
import torch.nn as nn

In [44]:
# NOTE(review): this file holds a whole pickled nn.Module (not just a state_dict),
# so loading it executes arbitrary pickle code -- only load checkpoints you trust.
# Unpickling also needs the module's original classes to be importable
# (presumably ultralytics -- confirm in this environment).
# weights_only=False is required for full-module pickles on PyTorch >= 2.6, where
# it became the default; map_location="cpu" lets a GPU-saved file load anywhere.
reference = torch.load('yolov8n-cls_10.pt', map_location='cpu', weights_only=False)

In [45]:
# Display the loaded reference module; its repr lists the layer structure.
reference

Sequential(
  (0): Conv(
    (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): SiLU()
  )
  (1): Conv(
    (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act): SiLU()
  )
  (2): C2f(
    (cv1): Conv(
      (conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (cv2): Conv(
      (conv): Conv2d(48, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU()
    )
    (m): ModuleList(
      (0): Bottleneck(
        (cv1): Conv(
          (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=

In [46]:
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> SiLU, modelled on the ultralytics ``Conv`` block.

    Differences from ultralytics' version: ``padding`` defaults to 1 instead of
    being derived from ``kernel_size``, and ``bias`` defaults to True. Callers
    that want the ultralytics layout must pass ``padding`` and ``bias=False``
    explicitly.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode: forwarded unchanged to ``nn.Conv2d``.
        device, dtype: factory kwargs applied to BOTH the convolution and the
            BatchNorm, so the whole block lives on one device/dtype.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride=1, padding=1,
                 dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        super().__init__()
        # Keyword arguments rather than a long positional list, so this cannot
        # silently mis-bind if nn.Conv2d's parameter order ever differs.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias,
                              padding_mode=padding_mode,
                              device=device, dtype=dtype)
        # BatchNorm needs the same factory kwargs; otherwise a block built with
        # device=/dtype= has conv and bn on different devices / dtypes.
        self.bn = nn.BatchNorm2d(out_channels, device=device, dtype=dtype)
        self.act = nn.SiLU()

    def forward(self, x):
        """Apply the convolution, then batch norm, then SiLU."""
        return self.act(self.bn(self.conv(x)))

In [47]:
class Bottleneck(nn.Module):
    """Two stacked 3x3 ``Conv`` blocks with an optional residual connection.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second conv.
        residual_connection: request ``x + f(x)``; only honoured when
            ``in_channels == out_channels`` so the two tensors can be added.
        bottleneck: expansion ratio; the first conv outputs
            ``int(out_channels * bottleneck)`` channels.
    """

    def __init__(self, in_channels, out_channels,
                 residual_connection=True, bottleneck=0.5):
        super().__init__()
        self.hidden_channels = int(out_channels * bottleneck)
        # bias=False: Conv applies BatchNorm directly after the convolution, and
        # BN's per-channel mean subtraction makes a conv bias redundant. This also
        # matches Model1's own convs and the reference checkpoint's layers.
        self.conv1 = Conv(in_channels, out_channels=self.hidden_channels,
                          kernel_size=(3, 3), stride=1, padding=1, bias=False)
        self.conv2 = Conv(self.hidden_channels, out_channels=out_channels,
                          kernel_size=(3, 3), stride=1, padding=1, bias=False)
        self.add = residual_connection and in_channels == out_channels

    def forward(self, x):
        """Return ``conv2(conv1(x))``, plus ``x`` when the residual is enabled."""
        out = self.conv2(self.conv1(x))
        return x + out if self.add else out

In [48]:
class C2f(nn.Module):
    """A stack of ``Bottleneck`` blocks between two 1x1 ``Conv`` projections.

    Modes:
      * ``CSP=False``: conv1 -> n bottlenecks in sequence -> conv2.
      * ``CSP=True``: conv1's output is split into two equal channel halves.
        The first half runs through the bottlenecks and the second bypasses them.
          - ``add_hidden=True``: every bottleneck output is kept, and conv2 sees
            ``[first, second, b_1, ..., b_n]``, i.e. ``(2 + n) * hidden``
            channels.
          - ``add_hidden=False``: conv2 sees ``[b_n, second]``, i.e.
            ``2 * hidden`` channels (a plain CSP merge; ``b_n`` is the first
            half itself when ``n == 0``).

    Args:
        in_channels / out_channels: input and output channel counts.
        n: number of bottleneck blocks.
        residual_connection: forwarded to each ``Bottleneck``.
        CSP: enable the split/merge described above.
        add_hidden: keep every bottleneck output (only used when ``CSP``).
        bottleneck: expansion ratio forwarded to each ``Bottleneck``.
    """

    def __init__(self, in_channels, out_channels, n=1, residual_connection=False,
                 CSP=False, add_hidden=False, bottleneck=1.0):
        super().__init__()
        self.CSP = CSP
        self.add_hidden = CSP and add_hidden
        self.hidden_channels = out_channels // 2 if CSP else out_channels
        # With CSP, conv1 emits exactly 2 * hidden_channels. chunk(2) then yields
        # two halves matching the bottlenecks' in_channels even for an odd
        # out_channels (for an even out_channels this equals out_channels).
        split_channels = 2 * self.hidden_channels if CSP else out_channels
        # bias=False everywhere: Conv follows each convolution with BatchNorm,
        # which makes a conv bias redundant (same as Model1's convs).
        self.conv1 = Conv(in_channels, out_channels=split_channels,
                          kernel_size=(1, 1), stride=1, padding=0, bias=False)
        # nn.ModuleList, not a plain list: the blocks must be registered as
        # submodules so their weights show up in parameters()/state_dict() and
        # follow .to(), .train() and .eval().
        self.n_blocks = nn.ModuleList([
            Bottleneck(self.hidden_channels, self.hidden_channels,
                       residual_connection=residual_connection,
                       bottleneck=bottleneck)
            for _ in range(n)
        ])
        if self.add_hidden:
            concat_channels = (2 + n) * self.hidden_channels
        elif CSP:
            concat_channels = 2 * self.hidden_channels
        else:
            concat_channels = self.hidden_channels
        self.conv2 = Conv(concat_channels, out_channels=out_channels,
                          kernel_size=(1, 1), stride=1, padding=0, bias=False)

    def forward(self, x):
        out = self.conv1(x)
        if not self.CSP:
            for block in self.n_blocks:
                out = block(out)
            return self.conv2(out)

        # NOTE(review): ultralytics' C2f feeds the *second* half (y[-1]) into the
        # first bottleneck; this feeds the first half. That is harmless when
        # training from scratch, but it matters if reference weights are copied in.
        processed, bypass = out.chunk(2, dim=1)
        pieces = [processed, bypass]
        for block in self.n_blocks:
            processed = block(processed)
            if self.add_hidden:
                pieces.append(processed)
        if not self.add_hidden:
            # Plain CSP merge: the transformed branch next to the untouched one.
            pieces = [processed, bypass]
        return self.conv2(torch.cat(pieces, dim=1))

In [61]:
class Model1(nn.Module):
    """Plain backbone in the style of the reference YOLOv8 classifier.

    Five stride-2 3x3 convs (3 -> 16 -> 32 -> 64 -> 128 -> 256 channels) are
    interleaved with C2f stages. Every C2f has CSP splitting and residual
    connections disabled. A (B, 3, H, W) input yields a
    (B, 256, ceil(H/32), ceil(W/32)) feature map.

    There is no classification head: forward returns the last C2f feature map.

    Args:
        device, dtype: device / floating dtype for every parameter and buffer.
        num_classes: currently unused (no head is built).
    """

    def __init__(self, device=None, dtype=None, num_classes=1000):
        super().__init__()
        factory = dict(device=device, dtype=dtype)

        def downsample(cin, cout):
            # Stride-2 3x3 conv: halves H and W (rounding up). No bias, since
            # BatchNorm follows inside Conv.
            return Conv(cin, out_channels=cout, kernel_size=(3, 3), stride=(2, 2),
                        padding=(1, 1), bias=False, **factory)

        def plain_c2f(channels, n):
            # Channel-preserving C2f with CSP / residuals / hidden concat disabled.
            return C2f(channels, out_channels=channels, n=n,
                       residual_connection=False, CSP=False,
                       add_hidden=False, bottleneck=1.0)

        # Attribute names and registration order are kept, so state_dict keys
        # stay the same.
        self.conv1 = downsample(3, 16)
        self.conv2 = downsample(16, 32)
        self.c2f1 = plain_c2f(32, n=1)
        self.conv3 = downsample(32, 64)
        self.c2f2 = plain_c2f(64, n=2)
        self.conv4 = downsample(64, 128)
        self.c2f3 = plain_c2f(128, n=2)
        self.conv5 = downsample(128, 256)
        self.c2f4 = plain_c2f(256, n=1)

        # Not every submodule receives device/dtype at construction (C2f has no
        # such parameters), so move the finished model once to honour them.
        if device is not None or dtype is not None:
            self.to(device=device, dtype=dtype)

    def forward(self, x):
        """Run the stem convs and C2f stages in order; return the final feature map."""
        stages = (self.conv1, self.conv2,
                  self.c2f1, self.conv3,
                  self.c2f2, self.conv4,
                  self.c2f3, self.conv5,
                  self.c2f4)
        for stage in stages:
            x = stage(x)
        return x

In [62]:
# Dummy input batch: 1 image, 3 channels, 640x640, values uniform in [0, 1).
inp = torch.rand(size=(1, 3, 640, 640))

In [63]:
# Build the experimental backbone, with its constructor defaults spelled out.
model = Model1(device=None, dtype=None, num_classes=1000)

In [64]:
# Shape smoke test. eval() keeps BatchNorm from folding this random input into
# its running statistics, so re-running the cell does not change the model.
# no_grad() avoids keeping an autograd graph for this throwaway forward pass.
model.eval()
with torch.no_grad():
    out = model(inp)

In [65]:
# Five stride-2 convs shrink H and W by 32 (exact here since 640 % 32 == 0),
# and the last stage has 256 channels.
expected_shape = (inp.shape[0], 256, inp.shape[2] // 32, inp.shape[3] // 32)
assert out.shape == expected_shape, f"unexpected output shape {tuple(out.shape)}"
out.shape

torch.Size([1, 256, 20, 20])