In [1]:
# basic imports
import random
import numpy as np

# DL library imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class pyramid_pooling_module(nn.Module):
    """Pyramid pooling block.

    For every entry of ``bin_sizes`` the input feature map is adaptively
    average-pooled to that grid size, projected to ``out_channels`` with a
    1x1 conv + BatchNorm + ReLU, and bilinearly resized back to the input's
    spatial size. The input and all resized branches are concatenated along
    the channel axis.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by each pooling branch.
        bin_sizes: iterable of output sizes handed to ``nn.AdaptiveAvgPool2d``,
            one branch per entry.

    The forward output therefore has
    ``in_channels + len(bin_sizes) * out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, bin_sizes):
        super().__init__()

        def build_branch(grid):
            # pool -> 1x1 projection -> norm -> activation
            return nn.Sequential(
                nn.AdaptiveAvgPool2d(grid),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        # one branch per pyramid level, registered in the order given
        self.pyramid_pool_layers = nn.ModuleList([build_branch(grid) for grid in bin_sizes])

    def forward(self, x):
        spatial = x.shape[2:]
        # every branch is brought back to the input resolution before concat
        resized = [
            F.interpolate(branch(x), spatial, mode='bilinear', align_corners=True)
            for branch in self.pyramid_pool_layers
        ]
        return torch.cat([x] + resized, 1)

In [None]:
class PSPNet(nn.Module):
    """PSPNet segmentation network on top of a dilated, pretrained ResNet-50.

    The backbone is followed by a pyramid pooling module and a small conv
    classifier; logits are bilinearly resized back to the input resolution.

    Args:
        in_channels: number of channels of the input image. When it matches
            the backbone's stem conv the pretrained stem is reused unchanged;
            otherwise the first conv is replaced by a freshly initialised
            conv with the same geometry (its pretrained weights are lost).
        num_classes: number of output channels of the logits.
        use_aux: if truthy, build an auxiliary classifier on the ``layer3``
            features.

    Returns (from ``forward``):
        ``main_output`` of shape ``(N, num_classes, H, W)``; in training mode
        with ``use_aux`` enabled, the tuple ``(main_output, aux_output)`` with
        both tensors at the input's spatial size.
    """

    def __init__(self, in_channels, num_classes, use_aux=False):
        super(PSPNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes

        # backbone layers
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; kept as-is so older installs keep working.
        backbone = resnet50(pretrained=True, replace_stride_with_dilation=[False, True, True])

        # stem = conv1, bn1, relu, maxpool (the first four children)
        stem = list(backbone.children())[:4]
        stem_conv = stem[0]
        if in_channels != stem_conv.in_channels:
            # honour `in_channels`: swap the stem conv for one that accepts the
            # requested channel count, keeping kernel/stride/padding identical
            stem[0] = nn.Conv2d(
                in_channels,
                stem_conv.out_channels,
                kernel_size=stem_conv.kernel_size,
                stride=stem_conv.stride,
                padding=stem_conv.padding,
                bias=stem_conv.bias is not None,
            )
        self.initial = nn.Sequential(*stem)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Pyramid pooling module components
        ppm_in_channels = int(backbone.fc.in_features)
        ppm_branch_channels = 512
        ppm_bin_sizes = [1, 2, 3, 6]
        self.ppm = pyramid_pooling_module(in_channels=ppm_in_channels,
                                          out_channels=ppm_branch_channels,
                                          bin_sizes=ppm_bin_sizes)
        # the PPM concatenates its input with one branch per bin size
        ppm_cat_channels = ppm_in_channels + ppm_branch_channels * len(ppm_bin_sizes)

        # classifier head
        self.cls = nn.Sequential(
            nn.Conv2d(ppm_cat_channels, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.1),
            nn.Conv2d(512, self.num_classes, kernel_size=1)
        )

        # main branch is composed of PPM + Classifier; it wraps the same module
        # instances as self.ppm / self.cls, so the parameters are shared
        self.main_branch = nn.Sequential(self.ppm, self.cls)

        # Define auxiliary branch if specified. A module is always in training
        # mode right after construction, so only `use_aux` decides here;
        # forward() still gates the aux output on self.training.
        self.use_aux = bool(use_aux)
        if self.use_aux:
            # presumably layer3 emits half as many channels as layer4 (true for
            # torchvision's ResNet-50) - verify if the backbone is swapped
            aux_in_channels = ppm_in_channels // 2
            self.aux_branch = nn.Sequential(
                nn.Conv2d(aux_in_channels, 256, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=0.1),
                nn.Conv2d(256, self.num_classes, kernel_size=1)
            )

    def forward(self, x):
        input_size = x.shape[-2:]

        # Pass input through backbone layers; layer3 features feed the aux head
        x = self.initial(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x_aux = self.layer3(x)
        x = self.layer4(x_aux)

        # Get main branch output, resized to the input resolution.
        # align_corners=False is the library default, spelled out explicitly.
        main_output = self.main_branch(x)
        main_output = F.interpolate(main_output, size=input_size,
                                    mode='bilinear', align_corners=False)

        # If needed, get auxiliary branch output (training only)
        if self.training and self.use_aux:
            aux_output = F.interpolate(self.aux_branch(x_aux), size=input_size,
                                       mode='bilinear', align_corners=False)
            return main_output, aux_output
        return main_output

In [None]:
# Smoke test: a freshly built model is in training mode, so with use_aux=True
# the forward pass returns (main_output, aux_output).
model = PSPNet(in_channels=3, num_classes=2, use_aux=True)
# torch.randn gives well-defined values; torch.Tensor(2, 3, 180, 320) would
# allocate uninitialised memory that may contain NaN/Inf.
test_input = torch.randn(2, 3, 180, 320)
# `aux_input` holds the auxiliary branch *output* (name kept for later cells)
main_branch_output, aux_input = model(test_input)
# both heads must return per-class maps at the input resolution
assert main_branch_output.shape == (2, 2, 180, 320)
assert aux_input.shape == (2, 2, 180, 320)
print(main_branch_output.shape, aux_input.shape)