In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np

import os
import tempfile

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torchvision.models.vgg16(pretrained=True)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [3]:
class VGG16_Backbone(nn.Module):
    def __init__(self, dropout=None):
        super(VGG16_Backbone, self).__init__()
        original = torchvision.models.vgg16(pretrained=True)

        layers = []
        block_counter = 1
        for idx, layer in enumerate(original.features.children()):
            if isinstance(layer, nn.Conv2d):
                if block_counter == 4:
                    layer = nn.Conv2d(layer.in_channels, layer.out_channels, layer.kernel_size,
                                      padding=2, dilation=2, bias=layer.bias is not None)
                elif block_counter == 5:
                    layer = nn.Conv2d(layer.in_channels, layer.out_channels, layer.kernel_size,
                                      padding=4, dilation=4, bias=layer.bias is not None)

            if not (isinstance(layer, nn.MaxPool2d) and (block_counter == 4 or block_counter == 5)):
                    layers.append(layer)

            if isinstance(layer, nn.MaxPool2d):
                block_counter += 1

        self.backbone = nn.Sequential(*layers)

        if dropout is not None:
            for idx, layer in enumerate(self.backbone.children()):
                if isinstance(layer, nn.Conv2d):
                    self.backbone[idx] = nn.Sequential(layer, nn.Dropout(dropout))

    def forward(self, x):
        return self.backbone(x)

In [4]:
class ASPPModule(nn.Module):
    def __init__(self, in_channels, out_channels, dilation):
        super(ASPPModule, self).__init__()
        self.padding = dilation
        self.kernel_size = 3
        self.dilation = dilation
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=self.kernel_size, padding=0, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
    
    def forward(self, x):
        padding = ((self.kernel_size - 1) * self.dilation) // 2
        x = F.pad(x, (padding, padding, padding, padding))
        x = self.conv(x)
        x = self.bn(x)
        x = F.relu(x)
        return x

In [5]:
class DeepLabV3(nn.Module):
    def __init__(self, num_classes=5, backbone='vgg16', activation=None):
        super(DeepLabV3, self).__init__()
        assert backbone in ['vgg16']

        if backbone == 'vgg16':
            self.backbone = VGG16_Backbone()

        self.aspp1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )

        self.aspp_modules = nn.ModuleList([
            ASPPModule(512, 256, dilation=12),
            ASPPModule(512, 256, dilation=24),
            ASPPModule(512, 256, dilation=36)
        ])

        self.global_pooling = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(512, 256, kernel_size=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )

        self.concat = nn.Sequential(
            nn.Conv2d(1280, 256, kernel_size=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        self.project = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        
        
        self.post_project = nn.Sequential(
            nn.ZeroPad2d(1),
            nn.Conv2d(256, 256, kernel_size=3, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, kernel_size=1)
        )
        

        self.activation = activation

    def forward(self, x):
        img_size = (x.shape[2], x.shape[3])
        
        x = self.backbone(x)
        x1 = self.aspp1(x)

        aspp_outputs = [x1]
        for aspp_module in self.aspp_modules:
            aspp_outputs.append(aspp_module(x))

        x5 = self.global_pooling(x) #x5: torch.Size([batch, 256, 1, 1])
        x5 = nn.Upsample(size=x.size()[2:], mode='bilinear', align_corners=False)(x5) #x5: torch.Size([batch, 256, 64, 64])
        aspp_outputs.append(x5)
        
        x = torch.cat(aspp_outputs, dim=1)
        x = self.concat(x)
        x = self.project(x)
        x = self.post_project(x)
        
        x = nn.Upsample(size = img_size, mode='bilinear', align_corners=False)(x)
        
        if self.activation is not None:
            x = self.activation(x)

        return x

In [6]:
deeplabv3 = DeepLabV3()

In [7]:
x = torch.rand((2, 3, 512, 512)) # bacth cannot be 1 because of BN2d/3d

In [8]:
deeplabv3(x)

tensor([[[[-0.0890, -0.0890, -0.0890,  ...,  0.8531,  0.8531,  0.8531],
          [-0.0890, -0.0890, -0.0890,  ...,  0.8531,  0.8531,  0.8531],
          [-0.0890, -0.0890, -0.0890,  ...,  0.8531,  0.8531,  0.8531],
          ...,
          [ 0.9449,  0.9449,  0.9449,  ...,  0.8082,  0.8082,  0.8082],
          [ 0.9449,  0.9449,  0.9449,  ...,  0.8082,  0.8082,  0.8082],
          [ 0.9449,  0.9449,  0.9449,  ...,  0.8082,  0.8082,  0.8082]],

         [[-0.0853, -0.0853, -0.0853,  ..., -0.0237, -0.0237, -0.0237],
          [-0.0853, -0.0853, -0.0853,  ..., -0.0237, -0.0237, -0.0237],
          [-0.0853, -0.0853, -0.0853,  ..., -0.0237, -0.0237, -0.0237],
          ...,
          [-0.2633, -0.2633, -0.2633,  ...,  0.2416,  0.2416,  0.2416],
          [-0.2633, -0.2633, -0.2633,  ...,  0.2416,  0.2416,  0.2416],
          [-0.2633, -0.2633, -0.2633,  ...,  0.2416,  0.2416,  0.2416]],

         [[ 0.3049,  0.3049,  0.3049,  ...,  0.4436,  0.4436,  0.4436],
          [ 0.3049,  0.3049,  