In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class CBS(nn.Module):
    """Convolution -> batch normalisation -> SiLU activation building block.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of convolution filters (and BatchNorm features).
        kernel_size: Convolution kernel size. Defaults to 3.
        stride: Convolution stride. Defaults to 1.
        padding: Zero padding applied on each spatial side. Defaults to 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        self.silu = nn.SiLU()

    def forward(self, x):
        """Run ``x`` through the convolution, the batch norm and the SiLU in turn."""
        return self.silu(self.bn(self.conv(x)))

In [None]:
class SPPBlock(nn.Module):
    """Spatial pyramid pooling: multi-scale max pooling followed by a 1x1 conv.

    The input is max-pooled with every kernel size in ``pool_sizes`` (stride 1,
    padding ``size // 2``). The input and all pooled maps are concatenated
    along the channel axis and projected to ``out_channels`` by a 1x1
    convolution.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the final 1x1 convolution.
        pool_sizes: Iterable of max-pool kernel sizes. Each must be a positive
            odd integer so that the pooled maps keep the input's spatial size
            and can be concatenated with it.

    Raises:
        ValueError: If any pool size is not a positive odd integer.
    """

    def __init__(self, in_channels, out_channels, pool_sizes=(5, 9, 13)):
        super().__init__()
        # Materialise once: avoids sharing a mutable default and allows any iterable.
        pool_sizes = tuple(pool_sizes)
        bad_sizes = [size for size in pool_sizes if size < 1 or size % 2 == 0]
        if bad_sizes:
            # With stride 1 and padding size // 2, an even kernel yields H + 1
            # rows/cols, which cannot be concatenated with the input.
            raise ValueError(
                f"pool_sizes must be positive odd integers, got {bad_sizes}"
            )

        # forward() concatenates the input with one pooled map per pool size,
        # so the projection receives in_channels * (len(pool_sizes) + 1) channels.
        concat_channels = in_channels * (len(pool_sizes) + 1)
        self.conv1 = nn.Conv2d(concat_channels, out_channels, kernel_size=1)
        self.pools = nn.ModuleList(
            [nn.MaxPool2d(size, stride=1, padding=size // 2) for size in pool_sizes]
        )

    def forward(self, x):
        """Concatenate ``x`` with its pooled versions and project with the 1x1 conv."""
        # The original input is kept as the first entry of the concatenation.
        pooled_outputs = [x] + [pool(x) for pool in self.pools]
        x = torch.cat(pooled_outputs, dim=1)
        return self.conv1(x)  # Final convolution to match output channels

In [None]:
class C3(nn.Module):
    """Two-branch convolution block that preserves the channel count.

    The input is fed to a 1x1 branch and a 3x3 branch, each producing
    ``in_channels // 2`` channels. The two results are concatenated along the
    channel axis and mapped back to ``in_channels`` by a 1x1 convolution.

    Args:
        in_channels: Number of channels of the input (and of the output).
    """

    def __init__(self, in_channels):
        super().__init__()
        branch_channels = in_channels // 2
        # Split the input into two paths
        self.conv1 = nn.Conv2d(in_channels, branch_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels, branch_channels, kernel_size=3, padding=1)
        # The concatenation has 2 * branch_channels channels, which is
        # in_channels - 1 when in_channels is odd; size the fusion conv from the
        # real concat width so odd channel counts work too.
        self.conv3 = nn.Conv2d(2 * branch_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """Run both branches on ``x``, concatenate them and fuse with the 1x1 conv."""
        branches = [self.conv1(x), self.conv2(x)]
        return self.conv3(torch.cat(branches, dim=1))

In [None]:
class YOLOv5Backbone(nn.Module):
    """YOLOv5-style backbone built from CBS, C3 and SPP blocks.

    Five stride-2 CBS blocks progressively downsample the input while widening
    it to 64, 128, 256, 512 and 1024 channels. Each CBS block from the second
    onward is followed by a C3 block, and the last stage ends with an SPP
    block.

    Args:
        in_channels: Number of channels of the input image. Defaults to 3.
        out_channels: Number of channels of the final (SPP) feature map.
            Defaults to 1024.
    """

    def __init__(self, in_channels=3, out_channels=1024):
        super().__init__()

        # Stem: 6x6 stride-2 convolution
        self.cbs1 = CBS(in_channels, 64, kernel_size=6, stride=2, padding=2)

        # Stage at 128 channels
        self.cbs2 = CBS(64, 128, kernel_size=3, stride=2)
        self.c3_1 = C3(128)

        # Stage at 256 channels (first returned feature map)
        self.cbs3 = CBS(128, 256, kernel_size=3, stride=2)
        self.c3_2 = C3(256)

        # Stage at 512 channels (second returned feature map)
        self.cbs4 = CBS(256, 512, kernel_size=3, stride=2)
        self.c3_3 = C3(512)

        # Stage at 1024 channels
        self.cbs5 = CBS(512, 1024, kernel_size=3, stride=2)
        self.c3_4 = C3(1024)

        # Spatial Pyramid Pooling (SPP) block. Its output width honours the
        # ``out_channels`` argument instead of a hard-coded 1024.
        self.spp = SPPBlock(1024, out_channels, pool_sizes=[5, 9, 13])

    def forward(self, x):
        """Compute three feature maps of increasing depth.

        Args:
            x: Input batch with ``in_channels`` channels.

        Returns:
            Tuple ``(x_c3_2, x_c3_3, x_spp)``: the outputs of the second C3
            block (256 channels, after three stride-2 stages), the third C3
            block (512 channels, after four stride-2 stages) and the SPP block
            (``out_channels`` channels, after five stride-2 stages).
        """
        # Stem and first stage
        stem = self.c3_1(self.cbs2(self.cbs1(x)))

        # Output from c3_2
        x_c3_2 = self.c3_2(self.cbs3(stem))

        # Output from c3_3
        x_c3_3 = self.c3_3(self.cbs4(x_c3_2))

        # Output from spp
        x_spp = self.spp(self.c3_4(self.cbs5(x_c3_3)))

        return x_c3_2, x_c3_3, x_spp