# In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

# In [None]:
class DeepLabV3Plus(nn.Module):
    """DeepLabV3+ segmentation model built on a torchvision ResNet backbone.

    The backbone's ``layer4`` output feeds an ASPP head.  Its ``layer1``
    output is passed to the decoder as low-level features.  The decoder
    output is bilinearly resized to the input's spatial size.

    Args:
        n_classes: number of output channels (per-pixel logits).
        backbone: name of a ``torchvision.models`` ResNet constructor.  The
            hard-coded ASPP input width (2048) and decoder width (256)
            assume a Bottleneck ResNet (e.g. ``resnet50``/``resnet101``).
        output_stride: ratio of input resolution to ``layer4`` resolution;
            must be 16 or 8.
        pretrained: forwarded to the backbone constructor (default ``True``).

    Raises:
        ValueError: if ``output_stride`` is not 16 or 8.
    """

    def __init__(self, n_classes=1, backbone='resnet50', output_stride=16, pretrained=True):
        super(DeepLabV3Plus, self).__init__()

        # Validate before building the backbone.
        # Each True entry tells torchvision's ResNet to stop striding in
        # layer2/layer3/layer4 (respectively) and dilate instead.
        # A plain ResNet reaches stride 32:
        #   - dilating layer4 gives output stride 16;
        #   - dilating layer3 and layer4 gives output stride 8.
        # ASPP rates are doubled at stride 8 so each branch keeps the same
        # receptive field measured in input pixels.
        if output_stride == 16:
            replace_stride_with_dilation = [False, False, True]
            atrous_rates = [6, 12, 18]
        elif output_stride == 8:
            replace_stride_with_dilation = [False, True, True]
            atrous_rates = [12, 24, 36]
        else:
            raise ValueError('Output stride should be 16 or 8')

        self.backbone = getattr(models, backbone)(
            pretrained=pretrained,
            replace_stride_with_dilation=replace_stride_with_dilation,
        )

        self.aspp = ASPP(2048, atrous_rates)
        self.decoder = Decoder(256, n_classes)

    def forward(self, x):
        """Return per-pixel logits with the same spatial size as ``x``.

        ``x`` is a 4-D image batch (N, C, H, W).  C is presumably 3, as the
        torchvision ResNet stem expects (TODO confirm).  The output has
        ``n_classes`` channels.
        """
        input_size = x.shape[-2:]

        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        low_level_features = x
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        x = self.aspp(x)
        x = self.decoder(x, low_level_features)
        # Resize to the exact input size rather than a fixed x4 factor.
        # The fixed factor gives a mismatched output when H or W is not a
        # multiple of the backbone stride.
        x = nn.functional.interpolate(x, size=input_size, mode='bilinear', align_corners=True)

        return x

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Runs parallel branches over the same input and concatenates their
    outputs on the channel axis:
      - a 1x1 conv;
      - one 3x3 atrous conv per entry in ``atrous_rates``;
      - an image-level pooling branch.
    The result is projected back to ``out_channels`` with
    1x1 conv -> BatchNorm -> ReLU -> Dropout(0.5).

    Args:
        in_channels: channels of the input feature map.
        atrous_rates: iterable of dilation rates for the 3x3 branches; any
            number of rates is accepted.
        out_channels: channels produced by every branch and by the
            projection (default 256).
    """

    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        branches = [nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),
                                  nn.BatchNorm2d(out_channels),
                                  nn.ReLU())]
        branches.extend(ASPPConv(in_channels, out_channels, rate) for rate in atrous_rates)
        branches.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(branches)

        # Concatenation width follows from the number of branches:
        # 1x1 + one per rate + pooling.
        self.project = nn.Sequential(
            nn.Conv2d(len(branches) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        """Apply every branch to ``x``, concatenate on dim 1, and project."""
        res = torch.cat([conv(x) for conv in self.convs], dim=1)
        return self.project(res)

class ASPPConv(nn.Sequential):
    """ASPP branch: 3x3 atrous conv (no bias) -> BatchNorm -> ReLU.

    Padding equals the dilation, so a stride-1 input keeps its spatial size.
    """

    def __init__(self, in_channels, out_channels, dilation):
        atrous = nn.Conv2d(
            in_channels,
            out_channels,
            3,
            padding=dilation,
            dilation=dilation,
            bias=False,
        )
        super().__init__(atrous, nn.BatchNorm2d(out_channels), nn.ReLU())

class ASPPPooling(nn.Sequential):
    """Image-level ASPP branch.

    Pipeline: global average pool -> 1x1 conv (no bias) -> BatchNorm -> ReLU.
    The pooled result is bilinearly resized back to the input's spatial size.
    """

    def __init__(self, in_channels, out_channels):
        layers = (
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        super().__init__(*layers)

    def forward(self, x):
        spatial_size = x.shape[-2:]
        # nn.Sequential.forward runs the layers above in order.
        pooled = super().forward(x)
        return nn.functional.interpolate(pooled, size=spatial_size, mode='bilinear', align_corners=False)

class Decoder(nn.Module):
    """DeepLabV3+ decoder head.

    Processing steps:
      1. Bilinearly resize the ASPP output to the spatial size of the
         low-level features.
      2. Concatenate the two on the channel axis.
      3. Refine with two 3x3 conv -> BN -> ReLU stages.
      4. Map to ``out_channels`` with a 1x1 conv.

    Args:
        in_channels: channels of the ASPP output passed to ``forward``.
        out_channels: number of output channels (class logits).
        low_level_channels: channels of the low-level feature map
            (default 256, matching the previous hard-coded width).
    """

    def __init__(self, in_channels, out_channels, low_level_channels=256):
        super(Decoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels + low_level_channels, 256, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(256)
        self.relu = nn.ReLU()
        # NOTE(review): dropout here hits the raw low-level features, which
        # are then concatenated at full width.  The DeepLabV3+ paper reduces
        # them with a 1x1 conv (to 48 channels) before fusion instead.  Kept
        # as-is so parameter shapes/keys (and existing checkpoints) don't change.
        self.dropout = nn.Dropout(0.5)
        self.conv2 = nn.Conv2d(256, 256, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(256, out_channels, 1)

    def forward(self, x, low_level_features):
        """Fuse ``x`` (ASPP output) with ``low_level_features``.

        Returns logits at the low-level features' spatial size.
        """
        low_level_features = self.dropout(low_level_features)
        x = nn.functional.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        return x