In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ConvBlock(nn.Module):
    """
    Apply Conv2d, then BatchNorm2d, then LeakyReLU.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced (also the BatchNorm width).
        kernel_size: Size of the convolution kernel.
        stride: Convolution stride.
        padding: Zero-padding added to each spatial side before convolving.
        leaky_slope: Negative slope of the LeakyReLU activation (default 0.1).
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, leaky_slope=0.1):
        super().__init__()
        # No conv bias: the BatchNorm that follows has its own learnable shift.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        # In-place is safe: it only overwrites the freshly produced BatchNorm output.
        self.leaky_relu = nn.LeakyReLU(leaky_slope, inplace=True)

    def forward(self, x):
        return self.leaky_relu(self.bn(self.conv(x)))

In [None]:
class ResidualBlock(nn.Module):
    """
    Residual block: two 3x3 ConvBlocks on the main path plus a skip connection.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution; values > 1 downsample.

    The skip path is an identity only when the main path preserves the input
    shape (same channel count and stride == 1). Otherwise it is a bias-free
    1x1 convolution with the same stride, so both paths produce tensors of
    the same shape and can be summed.
    """
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = ConvBlock(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv2 = ConvBlock(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # An identity shortcut is only valid when neither the channel count nor
        # the spatial size changes. A strided block with equal channel counts
        # still shrinks H/W on the main path, so it needs the strided projection
        # too; otherwise the addition in forward() fails with a shape mismatch.
        if in_channels == out_channels and stride == 1:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)

    def forward(self, x):
        # Compute the skip path from the untouched input first.
        residual = self.shortcut(x)
        out = self.conv2(self.conv1(x))
        return out + residual

In [None]:
class CSPResidualBlock(nn.Module):
    """
    Split-channel residual block followed by a fusing convolution stage.

    The input is split along the channel axis (dim 1) into two parts. The
    first part gets ``in_channels // 2`` channels and the second gets the
    remainder. Each part runs through its own stack of ``num_blocks``
    ResidualBlocks. The parts are then concatenated back to ``in_channels``
    channels and fused by a 3x3 ConvBlock (which applies ``stride``), followed
    by a 1x1 ConvBlock.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the block.
        num_blocks: ResidualBlocks per partition; 0 makes each partition an identity.
        stride: Stride of the fusing 3x3 convolution.

    NOTE(review): in the original CSPNet design one partition bypasses the
    residual stack. Here both partitions are processed. This behavior is kept
    as-is; confirm it is intended.
    """
    def __init__(self, in_channels, out_channels, num_blocks=1, stride=1):
        super().__init__()
        self.split_channels = in_channels // 2
        # The second partition takes whatever is left, so odd channel counts
        # (where the two halves differ by one) are built with matching widths.
        remaining_channels = in_channels - self.split_channels
        self.block1 = nn.Sequential(
            *[ResidualBlock(self.split_channels, self.split_channels) for _ in range(num_blocks)]
        )
        self.block2 = nn.Sequential(
            *[ResidualBlock(remaining_channels, remaining_channels) for _ in range(num_blocks)]
        )

        self.conv1 = ConvBlock(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv2 = ConvBlock(out_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Channel-wise split: first split_channels channels vs. the rest.
        split1 = x[:, :self.split_channels, :, :]
        split2 = x[:, self.split_channels:, :, :]

        split1 = self.block1(split1)
        split2 = self.block2(split2)

        # Residual blocks keep their widths, so the concat restores in_channels.
        out = torch.cat([split1, split2], dim=1)
        out = self.conv1(out)
        out = self.conv2(out)
        return out