In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class SELayer(nn.Module):
    """Squeeze-and-Excitation channel gating.

    Global-average-pools each channel, passes the pooled vector through a
    two-layer bottleneck MLP (ReLU then sigmoid), and rescales the input
    channels by the resulting per-channel gate in (0, 1).

    Args:
        channels: Number of channels of the input tensor.
        reduction: Ratio between ``channels`` and the bottleneck width.
    """

    def __init__(self, channels, reduction=4):
        super(SELayer, self).__init__()
        # Clamp to 1 so small channel counts (channels < reduction) do not
        # produce a zero-width bottleneck, which would make the gate a constant 0.5.
        hidden_channels = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, hidden_channels, bias=False)
        self.fc2 = nn.Linear(hidden_channels, channels, bias=False)

    def forward(self, x):
        """Return ``x`` scaled channel-wise; ``x`` must be 4-D (batch, channels, H, W)."""
        batch_size, channels, _, _ = x.size()
        # Squeeze: one descriptor per channel.
        y = F.adaptive_avg_pool2d(x, (1, 1)).view(batch_size, channels)
        # Excite: bottleneck MLP producing a gate per channel.
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y)).view(batch_size, channels, 1, 1)
        return x * y

In [None]:
class ConvBlock(nn.Module):
    """Mobile inverted bottleneck (MBConv) block.

    Pipeline: 1x1 expansion -> BN -> Swish -> depthwise KxK -> BN -> Swish
    -> optional Squeeze-and-Excitation -> 1x1 projection -> BN (no activation).
    A residual connection is added only when the block preserves the tensor
    shape (``stride == 1`` and ``in_channels == out_channels``).

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
        kernel_size: Depthwise kernel size; padding is ``kernel_size // 2``.
        stride: Stride of the depthwise convolution.
        expand_ratio: Multiplier giving the hidden width ``in_channels * expand_ratio``.
        use_se: If True, apply ``SELayer`` to the expanded features before projection.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, expand_ratio, use_se=False):
        super(ConvBlock, self).__init__()
        self.expand_ratio = expand_ratio
        self.use_se = use_se
        # The identity shortcut is only shape-compatible when the block neither
        # changes the channel count nor downsamples.
        self.use_residual = stride == 1 and in_channels == out_channels

        hidden_channels = in_channels * expand_ratio

        self.expand = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(hidden_channels)
        self.relu = nn.SiLU()  # Swish activation (attribute name kept for compatibility)

        # groups == channels makes this a depthwise convolution.
        self.depthwise = nn.Conv2d(hidden_channels, hidden_channels,
                                   kernel_size=kernel_size, stride=stride,
                                   padding=kernel_size // 2, groups=hidden_channels)
        self.bn2 = nn.BatchNorm2d(hidden_channels)

        self.project = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(out_channels)

        if self.use_se:
            self.se = SELayer(hidden_channels)

    def forward(self, x):
        """Apply the block; output has ``out_channels`` channels."""
        identity = x

        x = self.relu(self.bn1(self.expand(x)))
        # MBConv applies the activation after the depthwise BN as well.
        x = self.relu(self.bn2(self.depthwise(x)))

        if self.use_se:
            x = self.se(x)

        # Linear bottleneck: no activation after the projection.
        x = self.bn3(self.project(x))

        if self.use_residual:
            x = x + identity
        return x

In [None]:
class EfficientNetBackbone(nn.Module):
    """EfficientNet-style feature extractor returning multi-scale feature maps.

    The stem downsamples by 2; blocks with ``stride=2`` downsample further.
    With the default ``feature_indices`` the returned list holds the outputs of
    blocks 1, 3, 5 and 10 (24, 40, 80 and 192 channels, at strides 4, 8, 16
    and 32 relative to the input), followed by the ``final_conv`` output
    (1280 channels, stride 32).

    Args:
        feature_indices: Indices into ``self.blocks`` whose outputs are
            collected. Maps are returned in block order, regardless of the
            order given here.

    Raises:
        ValueError: If any index is outside ``range(len(self.blocks))``.
    """

    def __init__(self, feature_indices=(1, 3, 5, 10)):
        super(EfficientNetBackbone, self).__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU()
        )

        self.blocks = nn.ModuleList([
            # MBConv Block: (in_channels, out_channels, kernel_size, stride, expand_ratio, use_se)
            ConvBlock(32, 16, kernel_size=3, stride=1, expand_ratio=1, use_se=False),
            ConvBlock(16, 24, kernel_size=3, stride=2, expand_ratio=6, use_se=False),
            ConvBlock(24, 24, kernel_size=3, stride=1, expand_ratio=6, use_se=False),
            ConvBlock(24, 40, kernel_size=5, stride=2, expand_ratio=6, use_se=True),
            ConvBlock(40, 40, kernel_size=5, stride=1, expand_ratio=6, use_se=True),
            ConvBlock(40, 80, kernel_size=3, stride=2, expand_ratio=6, use_se=False),
            ConvBlock(80, 80, kernel_size=3, stride=1, expand_ratio=6, use_se=False),
            ConvBlock(80, 80, kernel_size=3, stride=1, expand_ratio=6, use_se=False),
            ConvBlock(80, 112, kernel_size=5, stride=1, expand_ratio=6, use_se=True),
            ConvBlock(112, 112, kernel_size=5, stride=1, expand_ratio=6, use_se=True),
            ConvBlock(112, 192, kernel_size=3, stride=2, expand_ratio=6, use_se=False),
            ConvBlock(192, 192, kernel_size=3, stride=1, expand_ratio=6, use_se=False),
            ConvBlock(192, 192, kernel_size=3, stride=1, expand_ratio=6, use_se=False),
            ConvBlock(192, 192, kernel_size=3, stride=1, expand_ratio=6, use_se=False),
            ConvBlock(192, 192, kernel_size=3, stride=1, expand_ratio=6, use_se=False),
            ConvBlock(192, 320, kernel_size=1, stride=1, expand_ratio=6, use_se=True),
        ])

        self.final_conv = nn.Conv2d(320, 1280, kernel_size=1)

        self.feature_indices = frozenset(feature_indices)
        invalid = sorted(i for i in self.feature_indices if not 0 <= i < len(self.blocks))
        if invalid:
            raise ValueError(
                "feature_indices out of range for %d blocks: %s" % (len(self.blocks), invalid)
            )

    def forward(self, x):
        """Return the collected intermediate maps followed by the final_conv output."""
        x = self.stem(x)

        feature_maps = []
        # Select by block index (not by how many maps were collected so far).
        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.feature_indices:
                feature_maps.append(x)

        x = self.final_conv(x)
        feature_maps.append(x)

        return feature_maps

In [None]:
class TopDownPathway(nn.Module):
    """Top-down half of a BiFPN layer.

    Starting from the coarsest map, each level is nearest-upsampled to the
    next finer level's spatial size, added to it, and refined by a 3x3 conv.
    A single ``fusion_conv`` is shared by all three fusion steps, so every
    input must already have ``out_channels`` channels.
    """

    def __init__(self, out_channels):
        super(TopDownPathway, self).__init__()
        self.fusion_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def _merge(self, coarse, fine):
        """Upsample ``coarse`` to ``fine``'s size, add it to ``fine``, then apply the shared conv."""
        upsampled = F.interpolate(coarse, size=fine.shape[2:], mode='nearest')
        return self.fusion_conv(fine + upsampled)

    def forward(self, P3, P4, P5, P6):
        """Return the fused ``(P3, P4, P5)``; ``P6`` only seeds the pass."""
        fused_p5 = self._merge(P6, P5)
        fused_p4 = self._merge(fused_p5, P4)
        fused_p3 = self._merge(fused_p4, P3)
        return fused_p3, fused_p4, fused_p5

In [None]:
class BottomUpPathway(nn.Module):
    """Bottom-up half of a BiFPN layer.

    Each finer map is downsampled with a 3x3 / stride-2 max-pool, added to the
    next coarser map, and refined by a 3x3 conv. One ``fusion_conv`` is shared
    by both fusion steps, so all inputs need ``out_channels`` channels.
    """

    def __init__(self, out_channels):
        super(BottomUpPathway, self).__init__()
        self.fusion_conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def _merge(self, fine, coarse):
        """Max-pool ``fine`` by 2, add it to ``coarse``, then apply the shared conv."""
        downsampled = F.max_pool2d(fine, kernel_size=3, stride=2, padding=1)
        return self.fusion_conv(coarse + downsampled)

    def forward(self, P3, P4, P5):
        """Return the fused ``(P4, P5)``; ``P3`` is only used as a source."""
        fused_p4 = self._merge(P3, P4)
        fused_p5 = self._merge(fused_p4, P5)
        return fused_p4, fused_p5

In [None]:
class BiFPNLayer(nn.Module):
    """One BiFPN layer: a top-down pass followed by a bottom-up pass.

    ``forward`` takes a sequence of exactly four maps ``(P3, P4, P5, P6)``
    ordered from finest to coarsest, each with ``out_channels`` channels.
    It returns the refined ``(P3, P4, P5)``. P3 comes from the top-down pass;
    P4 and P5 come from the bottom-up pass. P6 only feeds the top-down pass.
    """

    def __init__(self, out_channels):
        super(BiFPNLayer, self).__init__()
        self.top_down_pathway = TopDownPathway(out_channels)
        self.bottom_up_pathway = BottomUpPathway(out_channels)

    def forward(self, inputs):
        p3, p4, p5, p6 = inputs

        td_p3, td_p4, td_p5 = self.top_down_pathway(p3, p4, p5, p6)
        bu_p4, bu_p5 = self.bottom_up_pathway(td_p3, td_p4, td_p5)

        return td_p3, bu_p4, bu_p5

In [None]:
class DetectionHead(nn.Module):
    """Per-level prediction head producing box, class and objectness maps.

    A shared trunk of two (3x3 conv -> BN -> ReLU) stages feeds three 1x1
    conv branches. No activation is applied to the branch outputs.

    Args:
        in_channels: Channels of the incoming feature map (kept by the trunk).
        num_classes: Number of classes predicted per anchor.
        num_anchors: Number of anchors per spatial location.

    ``forward`` returns ``(box_preds, class_preds, obj_preds)`` with
    ``num_anchors * 4``, ``num_anchors * num_classes`` and ``num_anchors``
    channels respectively, at the input's spatial resolution.
    """

    def __init__(self, in_channels, num_classes, num_anchors):
        super(DetectionHead, self).__init__()

        def conv_bn_relu():
            # 3x3 conv with padding 1 keeps spatial size and channel count.
            return [
                nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(in_channels),
                nn.ReLU(),
            ]

        # Trunk shared by all three prediction branches.
        self.shared_conv = nn.Sequential(*conv_bn_relu(), *conv_bn_relu())

        self.box_head = nn.Conv2d(in_channels, 4 * num_anchors, kernel_size=1)
        self.class_head = nn.Conv2d(in_channels, num_classes * num_anchors, kernel_size=1)
        self.obj_head = nn.Conv2d(in_channels, num_anchors, kernel_size=1)

    def forward(self, x):
        features = self.shared_conv(x)
        return (
            self.box_head(features),
            self.class_head(features),
            self.obj_head(features),
        )

In [None]:
class EfficientDet(nn.Module):
    """EfficientDet-style detector: backbone -> 1x1 lateral projections -> BiFPN -> heads.

    The last four backbone feature maps are projected to ``fpn_channels`` with
    1x1 convs and fed to the BiFPN as ``(P3, P4, P5, P6)``. Each of the three
    BiFPN outputs gets its own ``DetectionHead``.

    Args:
        num_classes: Number of classes predicted per anchor.
        num_anchors: Number of anchors per spatial location.
        fpn_channels: Channel width of the BiFPN and detection heads.
        backbone_channels: Channel counts of the last four backbone feature
            maps. The default assumes the backbone collects the outputs of
            blocks 1, 3, 5, 10 plus ``final_conv`` (24, 40, 80, 192, 1280
            channels), of which the last four are used.

    Raises:
        ValueError: If ``backbone_channels`` does not have exactly four entries,
            or (in ``forward``) if the backbone yields fewer than four maps.
    """

    def __init__(self, num_classes, num_anchors, fpn_channels=256,
                 backbone_channels=(40, 80, 192, 1280)):
        super(EfficientDet, self).__init__()
        if len(backbone_channels) != 4:
            raise ValueError(
                "backbone_channels must have 4 entries (P3..P6), got %d" % len(backbone_channels)
            )

        self.backbone = EfficientNetBackbone()

        # The BiFPN adds levels elementwise, so every backbone map must first be
        # projected to the common fpn_channels width.
        self.lateral_convs = nn.ModuleList([
            nn.Conv2d(channels, fpn_channels, kernel_size=1) for channels in backbone_channels
        ])

        # BiFPN
        self.bifpn = BiFPNLayer(out_channels=fpn_channels)

        # Detection Heads: one for each BiFPN output level (P3, P4, P5).
        self.detection_heads = nn.ModuleList([
            DetectionHead(in_channels=fpn_channels, num_classes=num_classes, num_anchors=num_anchors)
            for _ in range(3)
        ])

    def forward(self, x):
        """Return a list of ``(box_preds, class_preds, obj_preds)`` tuples, one per level."""
        feature_maps = self.backbone(x)

        num_levels = len(self.lateral_convs)
        if len(feature_maps) < num_levels:
            raise ValueError(
                "backbone returned %d feature maps; at least %d are required"
                % (len(feature_maps), num_levels)
            )

        # Use the coarsest num_levels maps and bring them to a common width.
        selected = feature_maps[-num_levels:]
        pyramid = [lateral(fm) for lateral, fm in zip(self.lateral_convs, selected)]

        bifpn_outputs = self.bifpn(pyramid)

        return [head(level) for head, level in zip(self.detection_heads, bifpn_outputs)]