In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ConvBlock(nn.Module):
    """
    Basic building unit of CSPDarknet53.

    Applies, in order: a 2D convolution created without a bias term,
    2D batch normalization over ``out_channels`` features, and an in-place
    LeakyReLU.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by the convolution.
        kernel_size: convolution kernel size (forwarded to ``nn.Conv2d``).
        stride: convolution stride (forwarded to ``nn.Conv2d``).
        padding: convolution padding (forwarded to ``nn.Conv2d``).
        leaky_slope: negative slope of the LeakyReLU (default ``0.1``).
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, leaky_slope=0.1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.leaky_relu = nn.LeakyReLU(leaky_slope, inplace=True)

    def forward(self, x):
        """Return ``leaky_relu(bn(conv(x)))``."""
        return self.leaky_relu(self.bn(self.conv(x)))

In [None]:
class ResidualBlock(nn.Module):
    """
    Residual block for Darknet53.

    The main path is two 3x3 ``ConvBlock`` layers (the first one carries the
    requested ``stride``); its output is added to a shortcut of the input.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 convolution (default ``1``).
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvBlock(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv2 = ConvBlock(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # The identity shortcut is only valid when the main path preserves both
        # the channel count AND the spatial size. With stride != 1 the main
        # path is downsampled, so the shortcut needs a strided 1x1 projection
        # too; otherwise `x + residual` fails with a shape mismatch.
        if in_channels == out_channels and stride == 1:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        """Return ``conv2(conv1(x)) + shortcut(x)``."""
        # Compute the shortcut first, before the main path runs.
        residual = self.shortcut(x)
        x = self.conv1(x)
        x = self.conv2(x)
        return x + residual

In [None]:
class CSPResidualBlock(nn.Module):
    """
    Cross-Stage Partial (CSP) residual block.

    The input is split along the channel dimension into two parts. Each part
    goes through its own stack of ``num_blocks`` residual blocks, the results
    are concatenated back along the channel dimension, and the concatenation
    is passed through a 3x3 ``ConvBlock`` (with ``stride``) followed by a 1x1
    ``ConvBlock``.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by the block.
        num_blocks: number of residual blocks applied to each part (default ``1``).
        stride: stride of the 3x3 convolution applied after concatenation
            (default ``1``).
    """
    def __init__(self, in_channels, out_channels, num_blocks=1, stride=1):
        super(CSPResidualBlock, self).__init__()
        self.split_channels = in_channels // 2
        # The second part receives whatever channels are left after the first
        # `split_channels`. For odd `in_channels` this is one more than
        # `split_channels`, so block2 must be sized from the remainder rather
        # than from `split_channels`.
        remaining_channels = in_channels - self.split_channels
        self.block1 = nn.Sequential(*[ResidualBlock(self.split_channels, self.split_channels) for _ in range(num_blocks)])
        self.block2 = nn.Sequential(*[ResidualBlock(remaining_channels, remaining_channels) for _ in range(num_blocks)])

        # The concatenated parts add up to `in_channels` channels again.
        self.conv1 = ConvBlock(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv2 = ConvBlock(out_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Split ``x`` by channel, process both parts, merge, and fuse."""
        split1 = x[:, :self.split_channels, :, :]
        split2 = x[:, self.split_channels:, :, :]

        split1 = self.block1(split1)
        split2 = self.block2(split2)

        out = torch.cat([split1, split2], dim=1)
        out = self.conv1(out)
        out = self.conv2(out)
        return out

In [None]:
class CSPDarknet53(nn.Module):
    """
    CSPDarknet53 backbone for YOLOv4.

    Layout: a stride-1 3x3 stem convolution (3 -> 32 channels), five
    ``CSPResidualBlock`` stages that each use stride 2, and a closing 3x3
    ``ConvBlock`` (1024 -> 1024 channels).
    """

    # (in_channels, out_channels, num_blocks) for stages csp1 .. csp5.
    _STAGE_SPECS = (
        (32, 64, 1),
        (64, 128, 2),
        (128, 256, 8),
        (256, 512, 8),
        (512, 1024, 4),
    )

    def __init__(self):
        super().__init__()

        # Stem convolution applied to the 3-channel input.
        self.conv1 = ConvBlock(3, 32, kernel_size=3, stride=1, padding=1)

        # Register the stages as csp1 .. csp5, in that order.
        for index, (c_in, c_out, depth) in enumerate(self._STAGE_SPECS, start=1):
            setattr(self, f"csp{index}", CSPResidualBlock(c_in, c_out, num_blocks=depth, stride=2))

        # Closing convolution applied to the output of the last stage.
        self.final_conv = ConvBlock(1024, 1024, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """
        Run the backbone and collect intermediate feature maps.

        Returns:
            A list of six tensors: the stem output, the outputs of csp1
            through csp4, and ``final_conv`` applied to the csp5 output
            (the raw csp5 output itself is not returned).
        """
        features = [self.conv1(x)]
        stages = (self.csp1, self.csp2, self.csp3, self.csp4, self.csp5)
        for stage in stages:
            # Each stage consumes the previous stage's output.
            features.append(stage(features[-1]))

        # Replace the last stage's raw output by its refined version.
        features[-1] = self.final_conv(features[-1])
        return features

In [None]:
class PANetBlock(nn.Module):
    """
    One PANet block.

    Projects the input with a 1x1 ``ConvBlock``, optionally adds a skip
    tensor, refines the result with a 3x3 ``ConvBlock``, and returns both a
    2x nearest-neighbour upsampled copy and a stride-2 convolved copy of the
    refined features.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels used after the 1x1 projection.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 projection to `out_channels`, then a 3x3 refinement.
        self.conv1 = ConvBlock(in_channels, out_channels, 1, 1, 0)
        self.conv2 = ConvBlock(out_channels, out_channels, 3, 1, 1)

        # Resolution-changing heads applied to the refined features.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x, skip=None):
        """
        Args:
            x: input feature map.
            skip: optional tensor added element-wise to the projected input;
                it must be addable to the output of the 1x1 projection.

        Returns:
            A tuple ``(up, down)``: the refined features upsampled by 2 and
            the refined features passed through the stride-2 convolution.
        """
        projected = self.conv1(x)
        fused = projected if skip is None else projected + skip
        refined = self.conv2(fused)
        return self.upsample(refined), self.downsample(refined)

In [None]:
class PANet(nn.Module):
    """
    PANet (Path Aggregation Network) to aggregate multi-scale features.
    Used in YOLOv4's Neck to refine features extracted from the backbone.

    Args:
        in_channels_list: channel count of each input feature map, in the
            same order as the feature maps passed to ``forward``.
        out_channels: channel count every ``PANetBlock`` projects to.
    """
    def __init__(self, in_channels_list, out_channels):
        super(PANet, self).__init__()
        self.panet_blocks = nn.ModuleList([
            PANetBlock(in_channels, out_channels) for in_channels in in_channels_list
        ])

    def forward(self, x_list):
        """
        Forward pass through PANet (top-down aggregation).

        Args:
            x_list: feature maps from the backbone, one per entry of
                ``in_channels_list`` and in the same order. The loop walks
                the list from its last element to its first.

        Returns:
            A list with one tensor per input, ordered from the last input to
            the first, all resized to the spatial size of the first output.

        Raises:
            ValueError: if ``x_list`` does not have one entry per block.
        """
        if len(x_list) != len(self.panet_blocks):
            raise ValueError(
                f"PANet expected {len(self.panet_blocks)} feature maps, got {len(x_list)}"
            )

        output_list = []
        # Upsampled output of the previously processed level. Previously the
        # skip list was never populated, so no PANetBlock was ever invoked and
        # the raw inputs were returned; every level now runs through its block.
        skip = None

        for i in reversed(range(len(x_list))):  # From the last level to the first
            x = x_list[i]
            if skip is not None and skip.shape[2:] != x.shape[2:]:
                # The 2x upsample only lines up when consecutive levels differ
                # by exactly a factor of two; resize otherwise so the
                # element-wise add inside the block is well-defined.
                skip = F.interpolate(skip, size=x.shape[2:], mode='nearest')

            # The block's stride-2 output is not used by this top-down pass.
            up, _ = self.panet_blocks[i](x, skip)
            output_list.append(up)
            skip = up

        if not output_list:
            return []

        # Resize all outputs to the same size (they can then be combined by the caller)
        target_size = output_list[0].shape[2:]
        upsampled_outputs = [F.interpolate(o, size=target_size, mode='nearest') for o in output_list]

        return upsampled_outputs