In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

In [None]:
class ASPPModule(nn.Module):
    """Atrous Spatial Pyramid Pooling (ASPP) block.

    Runs five parallel branches over the same input feature map and
    concatenates their outputs along the channel axis:

    * a 1x1 convolution,
    * three 3x3 convolutions with dilations ``rates[0]``, ``rates[1]`` and
      ``rates[2]`` (padding equals the dilation, so the spatial size is kept),
    * global average pooling followed by a 1x1 convolution, upsampled back
      to the input's spatial size.

    Every branch is followed by BatchNorm and ReLU. No projection layer is
    applied after the concatenation, so the output has ``5 * out_channels``
    channels.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by *each* branch.
        rates: Dilation rates of the three 3x3 branches. Defaults to
            ``(6, 12, 18)``. The DeepLabV3 paper doubles these to
            ``(12, 24, 36)`` when the backbone runs at output stride 8.

    Raises:
        ValueError: If ``rates`` does not contain exactly three values.
    """

    def __init__(self, in_channels, out_channels, *, rates=(6, 12, 18)):
        super().__init__()

        rates = tuple(rates)
        if len(rates) != 3:
            raise ValueError(
                f"ASPPModule expects exactly 3 dilation rates, got {len(rates)}: {rates}"
            )
        self.rates = rates
        rate2, rate3, rate4 = rates

        # NOTE: every conv keeps its default bias even though BatchNorm follows
        # it (making the bias redundant). Removing it would change the
        # state_dict keys and break loading existing checkpoints.
        # Layers are created in the original order so that, for a fixed RNG
        # seed, parameter initialization matches the original implementation.

        # 1x1 convolution
        self.aspp1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(out_channels)

        # 3x3 dilated convolutions; padding == dilation keeps H and W unchanged
        self.aspp2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                               padding=rate2, dilation=rate2)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.aspp3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                               padding=rate3, dilation=rate3)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.aspp4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                               padding=rate4, dilation=rate4)
        self.bn4 = nn.BatchNorm2d(out_channels)

        # Image-level branch: global average pooling + 1x1 convolution
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn5 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply all ASPP branches to ``x`` and concatenate the results.

        Args:
            x: Feature map of shape ``(N, in_channels, H, W)`` (BatchNorm2d
                requires a 4-D input).

        Returns:
            Tensor of shape ``(N, 5 * out_channels, H, W)``.
        """
        # Convolutional branches, each followed by BatchNorm and ReLU
        x1 = F.relu(self.bn1(self.aspp1(x)))
        x2 = F.relu(self.bn2(self.aspp2(x)))
        x3 = F.relu(self.bn3(self.aspp3(x)))
        x4 = F.relu(self.bn4(self.aspp4(x)))

        # Image-level branch. The pooled tensor is (N, out_channels, 1, 1), so
        # in training mode bn5 sees only N values per channel; PyTorch's
        # BatchNorm raises when N == 1 during training.
        x5 = F.relu(self.bn5(self.conv1x1(self.global_avg_pool(x))))

        # Upsampling a 1x1 map bilinearly just broadcasts it to (H, W)
        x5 = F.interpolate(x5, size=x4.shape[2:], mode='bilinear', align_corners=False)

        # Concatenate all branches along the channel axis
        return torch.cat([x1, x2, x3, x4, x5], dim=1)