In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

In [None]:
class Backbone(nn.Module):
    """ResNet-50 feature extractor that returns three intermediate feature maps.

    The torchvision ResNet-50 is truncated by dropping its last two children
    (global average pooling and the fully connected classifier). The remaining
    children are stored in an ``nn.Sequential`` and are therefore named by
    position: '0'..'3' form the stem (conv1, bn1, relu, maxpool) and
    '4'..'7' are ``layer1``..``layer4``.
    """

    # Positional names of the children whose outputs are returned by forward():
    # layer1, layer2 and layer3 of the ResNet.
    _FEATURE_STAGES = ('4', '5', '6')

    def __init__(self, pretrained=True):
        """
        :param pretrained: Forwarded to ``torchvision.models.resnet50``; when true the
                           ImageNet weights are loaded.
        """
        super(Backbone, self).__init__()

        # NOTE(review): recent torchvision releases deprecate `pretrained=` in favour of
        # `weights=`; the call is kept as-is for compatibility with the installed
        # version -- confirm before upgrading torchvision.
        resnet = models.resnet50(pretrained=pretrained)

        # Keep everything except the trailing avgpool and fc layers.
        # layer4 stays in the container so existing checkpoints keep loading.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

    def forward(self, x):
        """
        :param x: Input image batch.
        :return: List with the outputs of the stages listed in ``_FEATURE_STAGES``,
                 ordered from shallowest to deepest.
        """
        features = []
        wanted = len(self._FEATURE_STAGES)

        for name, module in self.backbone.named_children():
            x = module(x)
            if name in self._FEATURE_STAGES:
                features.append(x)
                # Every requested stage has been collected: stop here so that the
                # remaining children (layer4), whose output would be discarded
                # anyway, are not evaluated.
                if len(features) == wanted:
                    break

        return features

In [None]:
class TopDownPathway(nn.Module):
    """Top-down feature aggregation with 1x1 lateral projections.

    Every input level is projected to ``out_channels`` by its own 1x1 convolution.
    Starting from the deepest level, the running map is resized to the next
    shallower level's spatial size and summed with that level's projection.
    """

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: List of channel counts, one per feature level, in the same
                            order as the feature list passed to ``forward``
                            (shallowest to deepest).
        :param out_channels: Number of channels of every aggregated feature map.
        """
        super(TopDownPathway, self).__init__()

        # One 1x1 lateral convolution per level; index i projects features[i].
        self.lateral_convs = nn.ModuleList([nn.Conv2d(in_ch, out_channels, kernel_size=1)
                                            for in_ch in in_channels])

    def forward(self, features):
        """
        :param features: List of feature maps from the backbone network, ordered from
                         shallowest to deepest. Each feature map in the list is of shape
                         (batch_size, channels, height, width).
        :return: List of ``len(features) - 1`` aggregated feature maps, ordered from
                 shallowest to deepest. The deepest input only seeds the pathway and
                 is not returned itself.
        """
        assert len(features) == len(self.lateral_convs), \
            "expected one feature map per lateral convolution"

        # Project the deepest map to out_channels before it is propagated downwards.
        # Without this the running map keeps the backbone's channel count and cannot
        # be added to the lateral projections below.
        x = self.lateral_convs[-1](features[-1])

        top_down_features = []

        # Walk from the second-deepest level to the shallowest one.
        for i in range(len(features) - 2, -1, -1):
            # Resize the running map to this level's spatial resolution.
            x = F.interpolate(x, size=features[i].shape[2:], mode='bilinear', align_corners=False)

            # Merge with this level's lateral projection.
            x = x + self.lateral_convs[i](features[i])

            top_down_features.append(x)

        # Results were collected deepest-first; return them shallowest-first.
        top_down_features.reverse()

        return top_down_features

In [None]:
class BottomUpPathway(nn.Module):
    """Bottom-up feature aggregation with 1x1 lateral projections.

    Every input level is projected to ``out_channels`` by its own 1x1 convolution.
    Starting from the shallowest level, the running map is resized to the next
    deeper level's spatial size and summed with that level's projection.
    """

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: List of channel sizes for each level of feature maps, in the
                            same order as the feature list passed to ``forward``
                            (shallowest to deepest).
        :param out_channels: The desired number of output channels for the aggregated features.
        """
        super(BottomUpPathway, self).__init__()

        # One 1x1 lateral convolution per level; index i projects features[i].
        self.lateral_convs = nn.ModuleList([nn.Conv2d(in_ch, out_channels, kernel_size=1)
                                            for in_ch in in_channels])

    def forward(self, features):
        """
        :param features: List of feature maps (from shallow to deep) passed through the network.
        :return: List of ``len(features) - 1`` aggregated feature maps, ordered from
                 shallow to deep. The shallowest input only seeds the pathway and is
                 not returned itself.
        """
        assert len(features) == len(self.lateral_convs), \
            "expected one feature map per lateral convolution"

        # Project the shallowest map to out_channels before it is propagated upwards,
        # so the running map can be added to the lateral projections for any
        # combination of in_channels / out_channels.
        x = self.lateral_convs[0](features[0])

        bottom_up_features = []

        # Walk from the second-shallowest level to the deepest one.
        for i in range(1, len(features)):
            # Resize the running map to this level's spatial size (the target size is
            # taken from features[i], i.e. the next deeper level).
            x = F.interpolate(x, size=features[i].shape[2:], mode='bilinear', align_corners=False)

            # Merge with this level's lateral projection.
            x = x + self.lateral_convs[i](features[i])

            bottom_up_features.append(x)

        return bottom_up_features

In [None]:
class AdaptiveFeaturePooling(nn.Module):
    """Adaptive average pooling to a fixed spatial size, wrapped as a module."""

    def __init__(self, output_size):
        """
        :param output_size: Spatial size produced by the pooling, e.g. (1, 1) for a
                            global average or (7, 7) to keep a coarse spatial grid.
        """
        super().__init__()
        self.output_size = output_size

    def forward(self, x):
        """
        :param x: Input tensor (batch_size, channels, height, width)
        :return: The input average-pooled to ``self.output_size``; batch and channel
                 dimensions are left untouched.
        """
        pooled = F.adaptive_avg_pool2d(x, output_size=self.output_size)
        return pooled

In [None]:
class BoxBranch(nn.Module):
    """Convolutional head that regresses four box offsets per anchor and location."""

    def __init__(self, in_channels, num_anchors):
        """
        :param in_channels: Channel count of the incoming feature map.
        :param num_anchors: How many anchor boxes are predicted at each spatial location.
        """
        super().__init__()

        # Two 3x3 convolutions (spatial size preserved by padding=1) form the tower.
        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        # Output layer: 4 regression values (dx, dy, dw, dh) for every anchor.
        self.box_pred = nn.Conv2d(256, num_anchors * 4, kernel_size=3, padding=1)

    def forward(self, x):
        """
        :param x: Feature map of shape (batch_size, channels, height, width).
        :return: Box offsets of shape (batch_size, height, width, num_anchors, 4).
        """
        hidden = F.relu(self.conv2(F.relu(self.conv1(x))))

        # Raw regression map: (batch_size, num_anchors * 4, height, width).
        raw = self.box_pred(hidden)
        batch, _, height, width = raw.shape

        # Move channels last so the 4 values of one anchor are contiguous, then
        # split the channel axis into (num_anchors, 4).
        channels_last = raw.permute(0, 2, 3, 1).contiguous()
        return channels_last.view(batch, height, width, -1, 4)

In [None]:
class ClassBranch(nn.Module):
    """Convolutional head that predicts class scores per anchor and location."""

    def __init__(self, in_channels, num_anchors, num_classes):
        """
        Classifies the object in each anchor box location.

        :param in_channels: The number of input channels from the feature map.
        :param num_anchors: The number of anchor boxes per spatial location in the feature map.
                            For example, there might be 9 anchors for each grid location.
        :param num_classes: The number of scores predicted for every anchor (one per class).
        """
        super(ClassBranch, self).__init__()

        # Kept so that forward() can split the channel axis into (num_anchors, num_classes).
        self.num_anchors = num_anchors
        self.num_classes = num_classes

        # Two 3x3 convolutions (each followed by ReLU in forward) form the tower.
        self.conv1 = nn.Conv2d(in_channels, 256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        # Output layer: `num_classes` scores for each of the `num_anchors` anchors
        # (e.g., 9 anchors * 80 classes = 720 output channels).
        self.class_pred = nn.Conv2d(256, num_anchors * num_classes, kernel_size=3, padding=1)

    def forward(self, x):
        """
        Forward pass through the classification branch.

        :param x: The input feature map of shape (batch_size, channels, height, width).

        :return: Class scores of shape (batch_size, height, width, num_anchors, num_classes).
                 These are the raw outputs of the last convolution; no softmax or
                 sigmoid is applied here.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))

        # Raw score map: (batch_size, num_anchors * num_classes, height, width).
        class_pred = self.class_pred(x)
        batch, _, height, width = class_pred.shape

        # Move channels last so the scores of one anchor are contiguous.
        class_pred = class_pred.permute(0, 2, 3, 1).contiguous()

        # Split the channel axis using the stored sizes. Deriving the class count from
        # the permuted tensor's dimensions (as was done before) yields 0 or 1 instead
        # of num_classes.
        class_pred = class_pred.view(batch, height, width, self.num_anchors, self.num_classes)

        return class_pred


In [None]:
class PANet(nn.Module):
    """Detection model combining a ResNet-50 backbone, top-down and bottom-up
    feature aggregation, adaptive pooling and box / class prediction heads."""

    def __init__(self, num_classes, num_anchors=9, pretrained_backbone=True):
        """
        PANet (Path Aggregation Network) for object detection.

        :param num_classes: The number of class scores predicted per anchor.
        :param num_anchors: The number of anchor boxes per spatial location.
        :param pretrained_backbone: Forwarded to ``Backbone(pretrained=...)``. Defaults to
                                    True (the previous hard-coded behaviour); pass False
                                    to build the model without loading pretrained weights.
        """
        super(PANet, self).__init__()

        # Backbone: returns three feature maps ordered shallowest -> deepest
        # (ResNet-50 layer1, layer2, layer3: 256, 512 and 1024 channels).
        self.backbone = Backbone(pretrained=pretrained_backbone)

        # Both pathways index their lateral convolutions by feature position, so the
        # channel lists must follow the backbone's shallow -> deep order.
        self.top_down = TopDownPathway([256, 512, 1024], 256)
        self.bottom_up = BottomUpPathway([256, 512, 1024], 256)

        # Pools every aggregated map to a fixed 1x1 spatial size.
        self.adaptive_pool = AdaptiveFeaturePooling(output_size=(1, 1))

        # Box Branch: Predict bounding box offsets
        self.box_branch = BoxBranch(256, num_anchors)

        # Class Branch: Predict object class scores
        self.class_branch = ClassBranch(256, num_anchors, num_classes)

    def forward(self, x):
        """
        Forward pass through the PANet network.

        :param x: The input tensor (batch_size, 3, height, width) representing the input image.
        :return: A tuple containing:
                 - `box_preds`: List with the box-branch output for every aggregated
                   feature map (top-down maps first, then bottom-up maps).
                 - `class_preds`: List with the class-branch output for the same maps,
                   in the same order.
        """
        # Step 1: Feature extraction from the backbone
        features = self.backbone(x)

        # Steps 2-3: Aggregate the same backbone features along both pathways;
        # top-down results come first in the combined list.
        aggregated = self.top_down(features) + self.bottom_up(features)

        # Step 4: Pool every aggregated map once and share the result between both heads.
        pooled = [self.adaptive_pool(f) for f in aggregated]

        # Step 5: Predict bounding box offsets (Box Branch)
        box_preds = [self.box_branch(f) for f in pooled]

        # Step 6: Predict class scores (Class Branch)
        class_preds = [self.class_branch(f) for f in pooled]

        return box_preds, class_preds