In [7]:
#from farabio.models.classification.conv import resnet
import torch
from torch import nn
from typing import Type, Any, Callable, Union, List, Optional, Dict
from collections import OrderedDict
from torch.nn import functional as F
from torchvision import transforms
from torch import Tensor
from farabio.models.classification.conv import resnet #######???
from farabio.models._utils import IntermediateLayerGetter

In [8]:
class _SimpleSegmentationModel(nn.Module):
    __constants__ = ['aux_classifier']

    def __init__(self, backbone, classifier, aux_classifier=None):
        super(_SimpleSegmentationModel, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x):
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.backbone(x)

        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        result["out"] = x

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
            result["aux"] = x

        return result

In [9]:
class DeepLabV3(_SimpleSegmentationModel):
    """DeepLabV3 segmentation model.

    Adds no attributes or methods of its own: construction and the forward
    pass are inherited unchanged from ``_SimpleSegmentationModel``. The
    subclass only gives the assembled model a distinct class name.
    """


class DeepLabHead(nn.Sequential):
    """Classification head: ASPP, a 3x3 conv refinement, then a 1x1 class projection.

    Args:
        in_channels: channel count of the feature map fed to the ASPP module.
        num_classes: number of output channels of the final 1x1 convolution.
    """

    def __init__(self, in_channels, num_classes):
        # Sub-modules are created in the same order they are registered.
        aspp = ASPP(in_channels, [12, 24, 36])
        refine_conv = nn.Conv2d(256, 256, 3, padding=1, bias=False)
        refine_norm = nn.BatchNorm2d(256)
        activation = nn.ReLU()
        class_projection = nn.Conv2d(256, num_classes, 1)
        super().__init__(aspp, refine_conv, refine_norm, activation, class_projection)


class ASPPConv(nn.Sequential):
    """One atrous branch: dilated 3x3 convolution -> batch norm -> ReLU.

    Args:
        in_channels: input channel count of the convolution.
        out_channels: output channel count of the convolution and batch norm.
        dilation: dilation rate; also used as the padding so the spatial size
            of the input is preserved.
    """

    def __init__(self, in_channels, out_channels, dilation):
        super().__init__(
            nn.Conv2d(
                in_channels,
                out_channels,
                3,
                padding=dilation,
                dilation=dilation,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )


class ASPPPooling(nn.Sequential):
    """Image-level branch: global average pool -> 1x1 conv -> batch norm -> ReLU.

    Args:
        in_channels: input channel count of the 1x1 convolution.
        out_channels: output channel count of the 1x1 convolution and batch norm.
    """

    def __init__(self, in_channels, out_channels):
        layers = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        super().__init__(*layers)

    def forward(self, x):
        """Apply the pooled branch and upsample it back to the input's spatial size."""
        spatial_size = x.shape[-2:]
        pooled = x
        for layer in self:
            pooled = layer(pooled)
        # The pooled map is 1x1; stretch it back so it can be concatenated
        # with the other branches.
        return F.interpolate(pooled, size=spatial_size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling: parallel branches concatenated and projected.

    The branches are, in order: a 1x1 convolution, one ``ASPPConv`` per entry
    of ``atrous_rates``, and an ``ASPPPooling`` branch.

    Args:
        in_channels: channel count of the input feature map.
        atrous_rates: iterable of dilation rates, one ``ASPPConv`` per rate.
        out_channels: channel count produced by every branch and by the
            final projection.
    """

    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super().__init__()

        pointwise_branch = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
        branches = [pointwise_branch]
        branches.extend(
            ASPPConv(in_channels, out_channels, rate) for rate in atrous_rates
        )
        branches.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(branches)

        # Every branch emits `out_channels` maps, concatenated along dim 1.
        concat_channels = len(branches) * out_channels
        self.project = nn.Sequential(
            nn.Conv2d(concat_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        """Run every branch on ``x``, concatenate on the channel axis and project."""
        branch_outputs = [branch(x) for branch in self.convs]
        stacked = torch.cat(branch_outputs, dim=1)
        return self.project(stacked)

In [10]:
def _load_model(arch_type, backbone, num_classes, aux_loss, **kwargs):
    """Build a segmentation model by delegating to ``_segm_model``.

    Args:
        arch_type: architecture name, passed through as the first argument.
        backbone: backbone name, passed through as the second argument.
        num_classes: number of output classes.
        aux_loss: passed through as the auxiliary-head switch.
        **kwargs: forwarded unchanged to ``_segm_model``.

    Returns:
        Whatever ``_segm_model`` returns.
    """
    return _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)

In [11]:
def deeplabv3_resnet50(num_classes=21, aux_loss=None, **kwargs):
    """Construct a DeepLabV3 model on a ``resnet50`` backbone.

    Args:
        num_classes: number of output classes (default 21).
        aux_loss: passed through to ``_load_model`` as the auxiliary-head switch.
        **kwargs: forwarded unchanged to ``_load_model``.
    """
    return _load_model(
        'deeplabv3',
        'resnet50',
        num_classes,
        aux_loss,
        **kwargs
    )

In [12]:
def deeplabv3_resnet101(num_classes=21, aux_loss=None, **kwargs):
    """Construct a DeepLabV3 model on a ``resnet101`` backbone.

    Args:
        num_classes: number of output classes (default 21).
        aux_loss: passed through to ``_load_model`` as the auxiliary-head switch.
        **kwargs: forwarded unchanged to ``_load_model``.
    """
    return _load_model(
        'deeplabv3',
        'resnet101',
        num_classes,
        aux_loss,
        **kwargs
    )

In [13]:
def _segm_model(name, backbone_name, num_classes, aux):
    if 'resnet' in backbone_name:
        backbone = resnet.__dict__[backbone_name](
            replace_stride_with_dilation=[False, True, True])
        out_layer = 'layer4'
        out_inplanes = 2048
        aux_layer = 'layer3'
        aux_inplanes = 1024
    
    return_layers = {out_layer: 'out'}
    if aux:
        return_layers[aux_layer] = 'aux'
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None
    if aux:
        aux_classifier = FCNHead(aux_inplanes, num_classes)

    model_map = {
        'deeplabv3': (DeepLabHead, DeepLabV3)
        #'fcn': (FCNHead, FCN),
    }
    classifier = model_map[name][0](out_inplanes, num_classes)
    base_model = model_map[name][1]

    model = base_model(backbone, classifier, aux_classifier)
    return model

In [14]:
# Smoke test: push one random batch through DeepLabV3 / ResNet-101.
model = deeplabv3_resnet101()
# torch.Tensor(4, 3, 256, 256) only allocates memory without initialising it,
# so the values are arbitrary (possibly NaN/inf) and differ between runs.
# Draw a reproducible batch in [0, 1) instead; the local generator leaves the
# global RNG state untouched.
example = torch.rand(4, 3, 256, 256, generator=torch.Generator().manual_seed(0))
# Per-channel mean/std normalisation applied to the batch before the forward pass.
preprocess = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
input_tensor = preprocess(example)
y = model(input_tensor)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [16]:
# Display the module tree of the model built in the previous cell.
model

DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Se

In [10]:
#from torchvision.models.segmentation import deeplabv3_resnet50

In [13]:
#deeplabv3_resnet50

# Smoke test: push one random batch through DeepLabV3 / ResNet-50.
model = deeplabv3_resnet50()
# torch.Tensor(4, 3, 256, 256) only allocates memory without initialising it,
# so the values are arbitrary (possibly NaN/inf) and differ between runs.
# Draw a reproducible batch in [0, 1) instead; the local generator leaves the
# global RNG state untouched.
example = torch.rand(4, 3, 256, 256, generator=torch.Generator().manual_seed(0))
# Per-channel mean/std normalisation applied to the batch before the forward pass.
preprocess = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
input_tensor = preprocess(example)
y = model(input_tensor)

In [14]:
# Display the raw forward-pass result (an OrderedDict holding the 'out' tensor).
# NOTE(review): this prints the full tensors; inspecting the shapes would be
# easier to read.
y

OrderedDict([('out',
              tensor([[[[ 0.3482,  0.3482,  0.3482,  ..., -0.5609, -0.5609, -0.5609],
                        [ 0.3482,  0.3482,  0.3482,  ..., -0.5609, -0.5609, -0.5609],
                        [ 0.3482,  0.3482,  0.3482,  ..., -0.5609, -0.5609, -0.5609],
                        ...,
                        [ 0.1470,  0.1470,  0.1470,  ..., -0.3380, -0.3380, -0.3380],
                        [ 0.1470,  0.1470,  0.1470,  ..., -0.3380, -0.3380, -0.3380],
                        [ 0.1470,  0.1470,  0.1470,  ..., -0.3380, -0.3380, -0.3380]],
              
                       [[ 0.5416,  0.5416,  0.5416,  ...,  0.0526,  0.0526,  0.0526],
                        [ 0.5416,  0.5416,  0.5416,  ...,  0.0526,  0.0526,  0.0526],
                        [ 0.5416,  0.5416,  0.5416,  ...,  0.0526,  0.0526,  0.0526],
                        ...,
                        [-0.2868, -0.2868, -0.2868,  ...,  0.2606,  0.2606,  0.2606],
                        [-0.2868, -0.2868, -0