In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class CBS(nn.Module):
    """Convolution -> batch normalisation -> SiLU activation, applied in that order.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size passed to the convolution (default 3).
        stride: Stride passed to the convolution (default 1).
        padding: Padding passed to the convolution (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        self.silu = nn.SiLU()

    def forward(self, x):
        """Return ``silu(bn(conv(x)))`` for a 4-D input ``x``."""
        return self.silu(self.bn(self.conv(x)))

In [None]:
class SPPBlock(nn.Module):
    """Spatial pyramid pooling followed by a 1x1 fusion convolution.

    The input is max-pooled once per entry in ``pool_sizes`` (stride 1,
    padding ``size // 2`` so the spatial size is preserved for odd sizes).
    The input and all pooled maps are concatenated along the channel axis
    and projected to ``out_channels`` by a 1x1 convolution.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the fusion convolution.
        pool_sizes: Iterable of odd, positive max-pool kernel sizes
            (default ``(5, 9, 13)``).

    Raises:
        ValueError: If any pool size is even or not positive. An even size
            with padding ``size // 2`` changes the spatial size, which would
            make the channel-wise concatenation in ``forward`` fail.
    """

    def __init__(self, in_channels, out_channels, pool_sizes=(5, 9, 13)):
        super().__init__()
        pool_sizes = tuple(pool_sizes)
        for size in pool_sizes:
            if size < 1 or size % 2 == 0:
                raise ValueError(
                    f"pool sizes must be odd positive integers, got {size}"
                )

        # forward() concatenates the input with one pooled copy per pool
        # size, so the fusion conv sees (len(pool_sizes) + 1) * in_channels.
        fused_channels = in_channels * (len(pool_sizes) + 1)
        self.conv1 = nn.Conv2d(fused_channels, out_channels, kernel_size=1)
        self.pools = nn.ModuleList(
            [nn.MaxPool2d(size, stride=1, padding=size // 2) for size in pool_sizes]
        )

    def forward(self, x):
        """Return the fused multi-scale features, spatially the same size as ``x``."""
        # Original input first, then each pooled scale, stacked on the channel axis.
        pyramid = [x] + [pool(x) for pool in self.pools]
        return self.conv1(torch.cat(pyramid, dim=1))

In [None]:
class C3(nn.Module):
    """Two-branch block: a 1x1 branch and a 3x3 branch, concatenated and fused.

    Each branch maps the input to ``in_channels // 2`` channels. The two
    branch outputs are concatenated on the channel axis and a 1x1
    convolution projects the result back to ``in_channels`` channels, so
    the output has the same shape as the input.

    Args:
        in_channels: Number of input (and output) channels; must be >= 2.

    Raises:
        ValueError: If ``in_channels`` is smaller than 2, because each branch
            would then have zero channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        branch_channels = in_channels // 2
        if branch_channels < 1:
            raise ValueError(f"in_channels must be at least 2, got {in_channels}")

        self.conv1 = nn.Conv2d(in_channels, branch_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels, branch_channels, kernel_size=3, padding=1)
        # The concatenation carries 2 * branch_channels channels, which is
        # in_channels - 1 when in_channels is odd; size the fusion conv from
        # the branches so odd channel counts work too.
        self.conv3 = nn.Conv2d(2 * branch_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """Return the fused branch outputs; same shape as ``x``."""
        branches = torch.cat([self.conv1(x), self.conv2(x)], dim=1)
        return self.conv3(branches)

In [None]:
class YOLOv5Backbone(nn.Module):
    """Feature extractor built from five stride-2 CBS stages, C3 blocks and an SPP block.

    Channel widths of the intermediate stages are fixed at
    64 -> 128 -> 256 -> 512 -> 1024.

    Args:
        in_channels: Number of channels of the input image tensor (default 3).
        out_channels: Number of output channels requested from the final SPP
            block, i.e. of the deepest returned feature map (default 1024).
    """

    def __init__(self, in_channels=3, out_channels=1024):
        super().__init__()

        # Stem and first downsampling stage
        self.cbs1 = CBS(in_channels, 64, kernel_size=6, stride=2, padding=2)
        self.cbs2 = CBS(64, 128, kernel_size=3, stride=2)
        self.c3_1 = C3(128)

        # Stage feeding the first returned feature map
        self.cbs3 = CBS(128, 256, kernel_size=3, stride=2)
        self.c3_2 = C3(256)

        # Stage feeding the second returned feature map
        self.cbs4 = CBS(256, 512, kernel_size=3, stride=2)
        self.c3_3 = C3(512)

        # Deepest stage
        self.cbs5 = CBS(512, 1024, kernel_size=3, stride=2)
        self.c3_4 = C3(1024)

        # Spatial Pyramid Pooling (SPP) block. `out_channels` was previously
        # accepted but ignored; it now sets the SPP output width. The default
        # of 1024 reproduces the original configuration.
        self.spp = SPPBlock(1024, out_channels, pool_sizes=[5, 9, 13])

    def forward(self, x):
        """Run the backbone and return three feature maps, shallowest first.

        Returns:
            Tuple ``(x_c3_2, x_c3_3, x_spp)``: the outputs of ``c3_2``
            (after three stride-2 stages), ``c3_3`` (after four) and ``spp``
            (after five). For input sizes divisible by 32 these are at 1/8,
            1/16 and 1/32 of the input resolution.
        """
        # Stem
        x = self.c3_1(self.cbs2(self.cbs1(x)))

        # Output from c3_2
        x_c3_2 = self.c3_2(self.cbs3(x))

        # Output from c3_3
        x_c3_3 = self.c3_3(self.cbs4(x_c3_2))

        # Output from spp
        x_spp = self.spp(self.c3_4(self.cbs5(x_c3_3)))

        return x_c3_2, x_c3_3, x_spp

In [None]:
class PANetBlock(nn.Module):
    """Three stacked convolutions (3x3, 3x3, 1x1) with an optional upsample-and-concat front end.

    Args:
        in_channels: Channels seen by the first convolution. When ``forward``
            is called with a ``skip`` tensor this must equal the channel count
            of ``x`` plus that of ``skip``.
        out_channels: Channels produced by every convolution in the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        # Nearest-neighbour 2x upsampling, used only when a skip tensor is given.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x, skip=None):
        """Apply the conv stack, optionally merging ``skip`` first.

        Args:
            x: Input feature map. If ``skip`` is given, ``x`` is upsampled 2x
                before being merged.
            skip: Optional feature map concatenated after the upsampled ``x``
                along the channel axis.

        Returns:
            Feature map with ``out_channels`` channels.
        """
        features = x
        if skip is not None:
            features = torch.cat([self.upsample(features), skip], dim=1)

        for layer in (self.conv1, self.conv2, self.conv3):
            features = layer(features)
        return features

In [None]:
class YOLOv5Neck(nn.Module):
    """PANet-style neck: a top-down pass followed by a bottom-up pass.

    Args:
        in_channels_3: Channels of ``c3``, the finest input feature map.
        in_channels_4: Channels of ``c4``, the middle input feature map.
        in_channels_5: Channels of ``c5``, the coarsest input feature map.
        out_channels: Channels of every returned feature map (default 1024).

    Note:
        ``c4`` must have exactly twice the spatial size of ``c5``, and ``c3``
        twice that of ``c4``; otherwise the channel-wise concatenations fail.
    """

    def __init__(self, in_channels_3, in_channels_4, in_channels_5, out_channels=1024):
        super().__init__()

        # Top-down path. PANetBlock upsamples its first argument and
        # concatenates the skip tensor, so blocks 2 and 3 see the previous
        # (coarser) top-down output plus the lateral backbone map.
        self.panet_top_down_1 = PANetBlock(in_channels_5, out_channels)
        self.panet_top_down_2 = PANetBlock(out_channels + in_channels_4, out_channels)
        self.panet_top_down_3 = PANetBlock(out_channels + in_channels_3, out_channels)

        # Bottom-up path. Stride-2 convolutions halve the resolution so the
        # finer map can be concatenated with the matching top-down map.
        self.downsample_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.downsample_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.panet_bottom_up_1 = PANetBlock(out_channels, out_channels)
        self.panet_bottom_up_2 = PANetBlock(2 * out_channels, out_channels)
        self.panet_bottom_up_3 = PANetBlock(2 * out_channels, out_channels)

        # Final convolutions for P3, P4, P5 outputs
        self.final_conv_p3 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.final_conv_p4 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.final_conv_p5 = nn.Conv2d(out_channels, out_channels, kernel_size=1)

    def forward(self, c3, c4, c5):
        """Fuse three backbone feature maps into three output maps.

        Args:
            c3: Finest feature map, with ``in_channels_3`` channels.
            c4: Middle feature map, with ``in_channels_4`` channels.
            c5: Coarsest feature map, with ``in_channels_5`` channels.

        Returns:
            Tuple ``(p3, p4, p5)`` with ``out_channels`` channels each, at the
            spatial sizes of ``c3``, ``c4`` and ``c5`` respectively.
        """
        # Top-down: start at the coarsest map and merge into finer ones.
        top_down_1 = self.panet_top_down_1(c5)               # size of c5
        top_down_2 = self.panet_top_down_2(top_down_1, c4)   # size of c4
        top_down_3 = self.panet_top_down_3(top_down_2, c3)   # size of c3

        # Bottom-up: start at the finest map and merge into coarser ones.
        bottom_up_1 = self.panet_bottom_up_1(top_down_3)     # size of c3
        bottom_up_2 = self.panet_bottom_up_2(
            torch.cat([self.downsample_1(bottom_up_1), top_down_2], dim=1)
        )                                                    # size of c4
        bottom_up_3 = self.panet_bottom_up_3(
            torch.cat([self.downsample_2(bottom_up_2), top_down_1], dim=1)
        )                                                    # size of c5

        # Final convolutions to get the P3, P4, P5 feature maps
        p3 = self.final_conv_p3(bottom_up_1)  # P3 (highest resolution)
        p4 = self.final_conv_p4(bottom_up_2)  # P4 (medium resolution)
        p5 = self.final_conv_p5(bottom_up_3)  # P5 (lowest resolution)

        return p3, p4, p5

In [None]:
class YOLOv5DetectionHead(nn.Module):
    """Per-scale 1x1 prediction convolutions merged into one prediction tensor.

    Args:
        num_classes: Number of object classes (default 80).
        num_anchors: Number of anchors predicted per spatial location (default 3).
        in_channels: Channels of each incoming feature map (default 1024).

    Attributes set here:
        num_anchors: Copy of the constructor argument, used when reshaping.
        num_preds: Values predicted per anchor, ``4 + 1 + num_classes``
            (4 bbox values, 1 objectness, class scores).
    """

    def __init__(self, num_classes=80, num_anchors=3, in_channels=1024):
        super().__init__()

        self.num_anchors = num_anchors
        # 4 bbox (x, y, w, h) + 1 objectness + num_classes class scores
        self.num_preds = 4 + 1 + num_classes

        # One prediction conv per scale
        pred_channels = num_anchors * self.num_preds
        self.conv_p3 = nn.Conv2d(in_channels, pred_channels, kernel_size=1)
        self.conv_p4 = nn.Conv2d(in_channels, pred_channels, kernel_size=1)
        self.conv_p5 = nn.Conv2d(in_channels, pred_channels, kernel_size=1)

    def _flatten_scale(self, pred):
        """Reshape ``(B, A * num_preds, H, W)`` into ``(B, A * H * W, num_preds)``."""
        batch, _, height, width = pred.shape
        pred = pred.reshape(batch, self.num_anchors, self.num_preds, height, width)
        # Move the per-anchor prediction vector to the last axis before
        # flattening, so each output row is one (anchor, y, x) prediction.
        pred = pred.permute(0, 1, 3, 4, 2)
        return pred.reshape(batch, self.num_anchors * height * width, self.num_preds)

    def forward(self, p3, p4, p5):
        """Predict on each scale and concatenate all predictions.

        The three inputs may have different spatial sizes. Concatenating the
        raw conv outputs along the channel axis would require equal sizes, so
        each scale is flattened first.

        Args:
            p3: Feature map ``(B, in_channels, H3, W3)``.
            p4: Feature map ``(B, in_channels, H4, W4)``.
            p5: Feature map ``(B, in_channels, H5, W5)``.

        Returns:
            Tensor of shape ``(B, A * (H3*W3 + H4*W4 + H5*W5), num_preds)``
            with the P3 rows first, then P4, then P5. Within a scale, rows are
            ordered by anchor, then row, then column.
        """
        per_scale = [
            self._flatten_scale(self.conv_p3(p3)),
            self._flatten_scale(self.conv_p4(p4)),
            self._flatten_scale(self.conv_p5(p5)),
        ]
        return torch.cat(per_scale, dim=1)

In [None]:
class YOLOv5Model(nn.Module):
    """Full detector assembled from a backbone, a neck and a detection head.

    Args:
        num_classes: Number of object classes, forwarded to the detection
            head (default 80).
    """

    def __init__(self, num_classes=80):
        super().__init__()

        # Channel counts handed to the neck for the three backbone outputs.
        neck_inputs = dict(in_channels_3=256, in_channels_4=512, in_channels_5=1024)

        self.backbone = YOLOv5Backbone()
        self.neck = YOLOv5Neck(**neck_inputs)
        self.head = YOLOv5DetectionHead(num_classes=num_classes)

    def forward(self, x):
        """Run ``x`` through backbone, neck and head; return the head's output."""
        backbone_features = self.backbone(x)            # three feature maps
        neck_features = self.neck(*backbone_features)   # three fused maps
        return self.head(*neck_features)