In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

In [None]:
class Backbone(nn.Module):
    """ResNet-50 feature extractor returning intermediate stage outputs.

    The torchvision ResNet-50 is wrapped in an ``nn.Sequential`` with its
    last two children removed, so the remaining sub-modules are addressed
    by their positional names (``'0'``, ``'1'``, ...).
    """

    # Positional names inside ``self.backbone`` whose outputs are collected.
    # In torchvision's ResNet these positions are layer1, layer2 and layer3.
    # NOTE(review): the last remaining child ('7', layer4) is never collected;
    # confirm that skipping the deepest stage is intentional.
    _FEATURE_STAGES = ('4', '5', '6')

    def __init__(self, pretrained=True):
        """
        :param pretrained: If true, load ImageNet weights for the ResNet-50.
        """
        super(Backbone, self).__init__()

        # torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        # ``weights=``; ``IMAGENET1K_V1`` is what ``pretrained=True`` maps to.
        # Fall back to the legacy keyword on older torchvision releases.
        weights_enum = getattr(models, 'ResNet50_Weights', None)
        if weights_enum is not None:
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            resnet = models.resnet50(weights=weights)
        else:
            resnet = models.resnet50(pretrained=pretrained)

        # Keep every child except the last two (the pooling / classifier head)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

    def forward(self, x):
        """
        :param x: Input tensor fed to the first backbone layer.
        :return: List with the outputs of the stages named in
                 ``_FEATURE_STAGES``, in execution order.
        """
        features = []

        for name, module in self.backbone.named_children():
            x = module(x)
            if name in self._FEATURE_STAGES:
                features.append(x)
                # Nothing after the last collected stage contributes to the
                # result, so stop instead of running the remaining layers.
                if len(features) == len(self._FEATURE_STAGES):
                    break

        return features

In [None]:
class TopDownPathway(nn.Module):
    """Top-down feature aggregation with 1x1 lateral convolutions.

    Every input level is projected to ``out_channels`` by its own lateral
    convolution. Starting from the deepest level, the running map is
    resized to the next shallower level and summed with that level's
    lateral projection.
    """

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: List of channel sizes, one per input feature map.
        :param out_channels: Number of channels produced by every lateral convolution.
        """
        super(TopDownPathway, self).__init__()

        # One 1x1 convolution per level to bring it to ``out_channels``
        self.lateral_convs = nn.ModuleList([nn.Conv2d(in_ch, out_channels, kernel_size=1)
                                            for in_ch in in_channels])

    def forward(self, features):
        """
        :param features: List of feature maps from the backbone network, ordered
                         from shallowest to deepest.
                         Each feature map in the list is of shape (batch_size, channels, height, width)
        :return: List of aggregated feature maps, shallowest first. The deepest
                 level only seeds the aggregation and is not part of the output,
                 so the list holds ``len(features) - 1`` maps.
        """
        assert len(features) == len(self.lateral_convs), (
            "expected %d feature maps, got %d" % (len(self.lateral_convs), len(features))
        )

        top_down_features = []

        # Project the deepest map first: without this its channel count would
        # not match the lateral outputs it is added to below.
        x = self.lateral_convs[-1](features[-1])

        # Walk from the second-deepest level up to the shallowest one
        for i in range(len(features) - 2, -1, -1):
            # Resize the running map to the spatial size of level ``i``
            x = F.interpolate(x, size=features[i].shape[2:], mode='bilinear', align_corners=False)

            # Merge with the lateral projection of level ``i``
            x = x + self.lateral_convs[i](features[i])

            top_down_features.append(x)

        # Collected deepest-first; return shallowest-first
        top_down_features.reverse()

        return top_down_features

In [None]:
class BottomUpPathway(nn.Module):
    """Bottom-up feature aggregation with 1x1 lateral convolutions.

    Every input level is projected to ``out_channels`` by its own lateral
    convolution. Starting from the shallowest level, the running map is
    resized to the next deeper level and summed with that level's lateral
    projection.
    """

    def __init__(self, in_channels, out_channels):
        """
        :param in_channels: List of channel sizes for each level of feature maps.
        :param out_channels: The desired number of output channels for the aggregated features.
        """
        super(BottomUpPathway, self).__init__()

        # One 1x1 convolution per level to bring it to ``out_channels``
        self.lateral_convs = nn.ModuleList([nn.Conv2d(in_ch, out_channels, kernel_size=1)
                                            for in_ch in in_channels])

    def forward(self, features):
        """
        :param features: List of feature maps (from shallow to deep) passed through the network.
        :return: List of aggregated feature maps for levels 1..N-1, shallowest
                 first. The shallowest level only seeds the aggregation and is
                 not part of the output, so the list holds ``len(features) - 1`` maps.
        """
        assert len(features) == len(self.lateral_convs), (
            "expected %d feature maps, got %d" % (len(self.lateral_convs), len(features))
        )

        bottom_up_features = []

        # Project the shallowest map first: without this its channel count
        # would not match the lateral outputs it is added to below.
        x = self.lateral_convs[0](features[0])

        # Walk from the second-shallowest level down to the deepest one
        for i in range(1, len(features)):
            # Resize the running map to the spatial size of level ``i``
            x = F.interpolate(x, size=features[i].shape[2:], mode='bilinear', align_corners=False)

            # Merge with the lateral projection of level ``i``
            x = x + self.lateral_convs[i](features[i])

            bottom_up_features.append(x)

        return bottom_up_features