In [28]:
import torch
import torch.nn as nn
from ultralytics import YOLO

In [29]:
%run block.ipynb

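The per-variant scaling constants (depth multiple `d`, width multiple `w`, max channels `mc`) are taken from the Ultralytics YOLOv8 config: https://github.com/ultralytics/ultralytics/blob/07a5ff9ddca487581035b61ff7678c0f7e0f40d9/ultralytics/cfg/models/v8/yolov8.yaml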

In [30]:
class Model(nn.Module):
    """YOLOv8-style classification network with a selectable scale variant.

    The network is five stride-2 ``Conv`` stages interleaved with four ``C2f``
    stages, followed by a classification head.  ``Conv``, ``C2f``,
    ``ClassifyV2`` and ``ClassifyV8`` are not defined in this cell; they must
    already be in scope (presumably provided by ``%run block.ipynb`` — confirm).

    Args:
        device: Forwarded to every submodule constructor.
        dtype: Forwarded to every submodule constructor.
        residual_connection: Forwarded to every ``C2f`` stage.
        CSP: Forwarded to every ``C2f`` stage.
        add_hidden: Forwarded to every ``C2f`` stage.
        bottleneck: Forwarded to every ``C2f`` stage.
        num_classes: Number of outputs of the classification head.
        variant: Key of ``Model.variants`` selecting the scaling constants.
        classifyV8: Use ``ClassifyV8`` as head when true, else ``ClassifyV2``.

    Raises:
        ValueError: If ``variant`` is not a key of ``Model.variants``.
    """

    # Scaling constants per variant:
    #   'd'  - multiplier applied to the nominal C2f repeat count (see _d)
    #   'w'  - multiplier applied to the nominal channel count (see _ch)
    #   'mc' - cap applied to the nominal channel count before scaling (see _ch)
    # NOTE(review): the upstream config appears to use 0.33 for 'n'/'s'; 0.34
    # keeps int(3 * d) == 1 under the truncation in _d — confirm.
    variants = {'n': {'d': 0.34, 'w': 0.25, 'mc': 1024},
                's': {'d': 0.34, 'w': 0.50, 'mc': 1024},
                'm': {'d': 0.67, 'w': 0.75, 'mc': 768},
                'l': {'d': 1.00, 'w': 1.00, 'mc': 512},
                'xl': {'d': 1.00, 'w': 1.25, 'mc': 512}}

    def __init__(self, device=None, dtype=None,
                 residual_connection=False, CSP=False, add_hidden=False, bottleneck=1.0,
                 num_classes=1000, variant='n', classifyV8=False):
        super().__init__()

        if variant not in Model.variants:
            # ValueError is a subclass of Exception, so existing handlers still work.
            raise ValueError(
                f"Invalid variant {variant!r}; expected one of {sorted(Model.variants)}."
            )
        scale = Model.variants[variant]
        self.variant = variant
        self.mc = scale['mc']
        self.w = scale['w']
        self.d = scale['d']

        # Scaled channel widths of the five resolution stages.
        c64, c128, c256, c512, c1024 = (self._ch(c) for c in (64, 128, 256, 512, 1024))

        factory = dict(device=device, dtype=dtype)
        # Every Conv stage is a 3x3, stride-2, padding-1 convolution without bias.
        conv_args = dict(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False,
                         **factory)
        # `bottleneck` is passed through so the constructor argument takes effect.
        c2f_args = dict(residual_connection=residual_connection, CSP=CSP,
                        add_hidden=add_hidden, bottleneck=bottleneck, **factory)

        # Submodules are registered in forward order; parameters() follows this order.
        self.conv1 = Conv(3, out_channels=c64, **conv_args)
        self.conv2 = Conv(c64, out_channels=c128, **conv_args)
        self.c2f1 = C2f(c128, out_channels=c128, n=self._d(3), **c2f_args)
        self.conv3 = Conv(c128, out_channels=c256, **conv_args)
        self.c2f2 = C2f(c256, out_channels=c256, n=self._d(6), **c2f_args)
        self.conv4 = Conv(c256, out_channels=c512, **conv_args)
        self.c2f3 = C2f(c512, out_channels=c512, n=self._d(6), **c2f_args)
        self.conv5 = Conv(c512, out_channels=c1024, **conv_args)
        self.c2f4 = C2f(c1024, out_channels=c1024, n=self._d(3), **c2f_args)

        head_cls = ClassifyV8 if classifyV8 else ClassifyV2
        self.classify = head_cls(c1024, num_classes=num_classes, **factory)

    # _ch and _d are regular methods rather than lambdas stored on the instance:
    # lambda attributes cannot be pickled, which would break pickling the module.
    def _ch(self, ch):
        """Return the scaled channel count: cap ``ch`` at ``mc``, multiply by ``w``, truncate."""
        return int(min(ch, self.mc) * self.w)

    def _d(self, d):
        """Return the scaled repeat count: ``d`` times the depth multiplier, truncated."""
        return int(d * self.d)

    def forward(self, x):
        """Apply the Conv/C2f stages in order, then the classification head."""
        out = self.conv1(x)
        out = self.conv2(out)

        out = self.c2f1(out)
        out = self.conv3(out)

        out = self.c2f2(out)
        out = self.conv4(out)

        out = self.c2f3(out)
        out = self.conv5(out)

        out = self.c2f4(out)

        out = self.classify(out)
        return out

### Path:

In [68]:
# Random test batch: one 3-channel 640x640 image with values drawn uniformly from [0, 1).
input_shape = (1, 3, 640, 640)
inp = torch.rand(input_shape)

In [19]:
# NOTE(review): `model1` is not defined by any cell above this one, so a
# fresh "Restart & Run All" would fail here with a NameError. A
# default-constructed Model is assumed (its default num_classes=1000 matches
# the output shape recorded below) — confirm that it also reproduces the
# recorded parameter count.
model1 = Model()
out = model1(inp)

In [20]:
out.shape

torch.Size([1, 1000])

In [20]:
# Count every scalar entry across all parameter tensors of model1.
total_params = sum(map(torch.numel, model1.parameters()))
total_params

2850202

In [21]:
# Default 'n' variant with residual connections and CSP enabled, 10 output classes.
model2 = Model(
    residual_connection=True,
    CSP=True,
    num_classes=10,
)

In [22]:
# Run the forward pass of model2 on the random test batch `inp`.
out = model2(inp)

In [23]:
out.shape

torch.Size([1, 10])

In [24]:
# Count every scalar entry across all parameter tensors of model2.
total_params = sum(map(torch.numel, model2.parameters()))
total_params

1397338

In [25]:
# Same configuration as the residual + CSP model, additionally with add_hidden enabled.
model3 = Model(
    residual_connection=True,
    CSP=True,
    add_hidden=True,
    num_classes=10,
)

In [26]:
# Run the forward pass of model3 on the random test batch `inp`.
out = model3(inp)

In [27]:
# Count every scalar entry across all parameter tensors of model3.
total_params = sum(map(torch.numel, model3.parameters()))
total_params

1451098

## Do the parameters match?

### YOLOv8n-cls

In [44]:
# Candidate for the parameter-by-parameter comparison against the
# yolov8n-cls checkpoint: 'n' variant with the ClassifyV8 head.
model1 = Model(
    variant='n',
    num_classes=1000,
    classifyV8=True,
    residual_connection=True,
    CSP=True,
    add_hidden=True,
    bottleneck=1.0,
    dtype=torch.float32,
)

In [45]:
# SECURITY: torch.load unpickles arbitrary Python objects; only load
# checkpoints obtained from a trusted source.
# The checkpoint holds a full pickled model object (accessed via
# ['model'].model), so weights-only loading cannot be used. weights_only=False
# is passed explicitly because newer PyTorch releases default to
# weights_only=True. map_location='cpu' makes the load independent of the
# device the checkpoint was saved from.
reference = torch.load('yolov8n-cls.pt', map_location='cpu', weights_only=False)['model'].model

In [46]:
# Compare the parameter tensor shapes of our model and the reference, in
# registration order. Comparing the complete shape lists (instead of zipping
# the two iterators) also detects a differing number of parameter tensors,
# which zip() would silently truncate. The result is stored in
# `all_parameters_match`, the name displayed by the next cell.
model_shapes = [tuple(p.shape) for p in model1.parameters()]
reference_shapes = [tuple(p.shape) for p in reference.parameters()]
all_parameters_match = model_shapes == reference_shapes

In [47]:
all_parameters_match

True

In [48]:
# Count every scalar entry across all parameter tensors of model1.
total_params = sum(map(torch.numel, model1.parameters()))
total_params

2719288

In [49]:
# Count every scalar entry across all parameter tensors of the reference model.
total_params_reference = sum(map(torch.numel, reference.parameters()))
total_params_reference

2719288

In [50]:
total_params == total_params_reference

True

### YOLOv8s-cls

In [61]:
# SECURITY: torch.load unpickles arbitrary Python objects; only load
# checkpoints obtained from a trusted source.
# The checkpoint holds a full pickled model object (accessed via
# ['model'].model), so weights-only loading cannot be used. weights_only=False
# is passed explicitly because newer PyTorch releases default to
# weights_only=True. map_location='cpu' makes the load independent of the
# device the checkpoint was saved from.
reference = torch.load('yolov8s-cls.pt', map_location='cpu', weights_only=False)['model'].model

In [62]:
# Candidate for the parameter-by-parameter comparison against the
# yolov8s-cls checkpoint: 's' variant with the ClassifyV8 head.
model1 = Model(
    variant='s',
    num_classes=1000,
    classifyV8=True,
    residual_connection=True,
    CSP=True,
    add_hidden=True,
    bottleneck=1.0,
    dtype=torch.float32,
)

In [63]:
all_parameters_match = True
for i, (p1, p2) in enumerate(zip(model1.parameters(), reference.parameters())):
    all_parameters_math = all_parameters_match and p1.shape == p2.shape

In [64]:
all_parameters_math

True

In [65]:
# Count every scalar entry across all parameter tensors of model1.
total_params = sum(map(torch.numel, model1.parameters()))
total_params

6361736

In [66]:
# Count every scalar entry across all parameter tensors of the reference model.
total_params_reference = sum(map(torch.numel, reference.parameters()))
total_params_reference

6361736

In [69]:
total_params == total_params_reference

True