In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class RepConv(nn.Module):
    """Conv2d -> BatchNorm2d -> optional in-place ReLU.

    Despite the name, this module holds a single convolution branch; it has
    no parallel branches to fuse/re-parameterize.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels (also the BatchNorm width).
        kernel_size: Convolution kernel size, forwarded to ``nn.Conv2d``.
        stride: Convolution stride.
        padding: Convolution padding.
        activation: When True, apply ReLU after BatchNorm; otherwise return
            the BatchNorm output unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=True):
        super().__init__()
        self.activation = activation
        # No conv bias: BatchNorm subtracts the per-channel mean right after,
        # which would cancel any constant bias anyway.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``, or ``bn(conv(x))`` when activation is off."""
        normed = self.bn(self.conv(x))
        return self.relu(normed) if self.activation else normed

In [None]:
class DWConv(nn.Module):
    """Depthwise-separable convolution followed by BatchNorm and optional ReLU.

    The depthwise stage (``groups=in_channels``) filters each input channel
    independently using ``kernel_size``/``stride``/``padding``. A 1x1
    pointwise stage then mixes channels into ``out_channels``. BatchNorm, and
    the ReLU when enabled, run only after the pointwise stage.

    Args:
        in_channels: Number of input channels (also the depthwise group count).
        out_channels: Number of output channels produced by the 1x1 conv.
        kernel_size: Kernel size of the depthwise conv.
        stride: Stride of the depthwise conv.
        padding: Padding of the depthwise conv.
        activation: When True, apply ReLU after BatchNorm.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=True):
        super().__init__()
        self.activation = activation
        self.dwconv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=False,
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply depthwise conv, pointwise conv, BatchNorm, then ReLU if enabled."""
        out = self.bn(self.pointwise(self.dwconv(x)))
        if not self.activation:
            return out
        return self.relu(out)

In [None]:
class EfficientRepBlock(nn.Module):
    """Depthwise-separable conv stage followed by a full conv stage.

    The block's ``stride`` is applied exactly once, in the first
    (depthwise-separable) stage. The second stage runs at stride 1, so a
    stride-2 block halves the spatial size rather than quartering it.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels of both stages.
        kernel_size: Kernel size used by both stages.
        stride: Downsampling factor of the block (applied in the DWConv stage).
        padding: Padding used by both stages. With ``padding == kernel_size // 2``
            (e.g. 3 / 1), the second stage preserves spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(EfficientRepBlock, self).__init__()
        # Depthwise-separable convolution; this is where the block downsamples.
        self.dwconv = DWConv(in_channels, out_channels, kernel_size, stride, padding)
        # Second conv at stride 1. Reusing ``stride`` here would downsample a
        # second time, making every stride-2 block shrink the map by 4x.
        self.repconv = RepConv(out_channels, out_channels, kernel_size, 1, padding)

    def forward(self, x):
        """Apply the DWConv stage and then the RepConv stage."""
        x = self.dwconv(x)
        x = self.repconv(x)
        return x

In [None]:
class EfficientRepBackbone(nn.Module):
    """Stem plus a trunk of EfficientRepBlocks with three side branches (C3, C4, C5).

    The stem is ``initial_conv`` (3 -> 64 channels, stride 2) + BN + ReLU.
    The trunk is ``blocks[0..2]`` (64 -> 128 -> 256 -> 512 channels). Each
    returned feature map comes from its own EfficientRepBlock that branches
    off the trunk:

    * ``C3 = stage_C3(blocks[0] output)``  (128 -> 256 channels)
    * ``C4 = stage_C4(blocks[1] output)``  (256 -> 512 channels)
    * ``C5 = stage_C5(blocks[2] output)``  (512 -> 1024 channels)

    Args:
        num_blocks: Currently ignored. The layout above is fixed; the
            parameter is kept only so existing constructor calls still work.

    Forward input: a tensor with 3 channels, presumably an image batch
    ``(N, 3, H, W)`` (TODO confirm with callers).

    Returns:
        Tuple ``(C3, C4, C5)``.
    """

    def __init__(self, num_blocks=4):
        super(EfficientRepBackbone, self).__init__()
        self.initial_conv = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)  # Stem conv
        self.initial_bn = nn.BatchNorm2d(64)
        self.initial_relu = nn.ReLU(inplace=True)

        # Trunk blocks. blocks[3] is not used by forward() (its output was never
        # returned); it stays registered so existing checkpoints with
        # ``blocks.3.*`` keys still load.
        self.blocks = nn.ModuleList([
            EfficientRepBlock(64, 128, kernel_size=3, stride=2, padding=1),
            EfficientRepBlock(128, 256, kernel_size=3, stride=2, padding=1),
            EfficientRepBlock(256, 512, kernel_size=3, stride=2, padding=1),
            EfficientRepBlock(512, 1024, kernel_size=3, stride=2, padding=1)
        ])

        # Side branches producing the returned feature maps.
        self.stage_C3 = EfficientRepBlock(128, 256, kernel_size=3, stride=2, padding=1)  # C3 output
        self.stage_C4 = EfficientRepBlock(256, 512, kernel_size=3, stride=2, padding=1)  # C4 output
        self.stage_C5 = EfficientRepBlock(512, 1024, kernel_size=3, stride=2, padding=1)  # C5 output

    def forward(self, x):
        """Run the stem and trunk, returning the three side-branch outputs."""
        x = self.initial_relu(self.initial_bn(self.initial_conv(x)))

        x = self.blocks[0](x)
        C3 = self.stage_C3(x)  # branch off the first trunk block

        x = self.blocks[1](x)
        C4 = self.stage_C4(x)  # branch off the second trunk block

        x = self.blocks[2](x)
        C5 = self.stage_C5(x)  # branch off the third trunk block

        # blocks[3] is deliberately not evaluated: its result was previously
        # computed and discarded, costing compute and memory for nothing.
        return C3, C4, C5