In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

In [None]:
class AtrousConv(nn.Module):
    """3x3 dilated (atrous) convolution followed by batch norm and ReLU.

    The convolution's padding is set equal to its dilation rate; with a 3x3
    kernel at the default stride of 1 this leaves the spatial size unchanged.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        dilation_rate: Spacing between kernel taps; also used as the padding.
    """

    def __init__(self, in_channels, out_channels, dilation_rate):
        super().__init__()
        # Sub-module names and creation order are part of the contract:
        # they determine state_dict keys and seeded weight initialisation.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            dilation=dilation_rate,
            padding=dilation_rate,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))`` for the input tensor ``x``."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.relu(normalised)

In [None]:
class DeepLabV1(nn.Module):
    """Segmentation head on a VGG16 backbone with parallel atrous branches.

    The input is passed through the convolutional part of a pre-trained
    VGG16, then through three parallel 3x3 atrous convolutions with dilation
    rates 6, 12 and 18. The branch outputs are concatenated along the
    channel axis and mapped to per-class scores by a 1x1 convolution.

    No upsampling is applied here: the returned score map has the spatial
    size of the backbone's output, not of the input image.

    Args:
        num_classes: Number of output channels (classes) of the classifier.
    """

    def __init__(self, num_classes):
        super(DeepLabV1, self).__init__()

        backbone_channels = 512  # channels consumed by each atrous branch
        branch_channels = 512    # channels produced by each atrous branch
        dilation_rates = (6, 12, 18)

        # Load a pre-trained VGG16 model (use it as a feature extractor)
        vgg16 = models.vgg16(pretrained=True)
        self.features = vgg16.features

        # Atrous convolution layers (with different dilation rates)
        self.atrous1 = AtrousConv(
            backbone_channels, branch_channels, dilation_rate=dilation_rates[0]
        )
        self.atrous2 = AtrousConv(
            backbone_channels, branch_channels, dilation_rate=dilation_rates[1]
        )
        self.atrous3 = AtrousConv(
            backbone_channels, branch_channels, dilation_rate=dilation_rates[2]
        )

        # Final 1x1 convolution for pixel-wise classification. forward()
        # concatenates the three branches along the channel axis, so the
        # classifier must accept 3 * branch_channels inputs, not a single
        # branch's worth.
        self.classifier = nn.Conv2d(
            branch_channels * len(dilation_rates), num_classes, kernel_size=1
        )

    def forward(self, x):
        """Return per-class score maps for the image batch ``x``.

        Args:
            x: Input tensor accepted by the VGG16 feature extractor
                (presumably ``(N, 3, H, W)`` -- confirm against the caller).

        Returns:
            Tensor with ``num_classes`` channels at the backbone's output
            resolution.
        """
        # Forward pass through VGG16 feature extractor
        x = self.features(x)

        # Apply atrous convolutions for multi-scale feature extraction;
        # every branch sees the same backbone features.
        x1 = self.atrous1(x)
        x2 = self.atrous2(x)
        x3 = self.atrous3(x)

        # Combine the multi-scale features by concatenating along the
        # channel dimension (3 * 512 = 1536 channels).
        x = torch.cat([x1, x2, x3], dim=1)

        # Apply the final classifier (1x1 convolution to predict classes)
        x = self.classifier(x)

        return x