In [5]:
#from farabio.models.classification.conv import resnet
import torch
from torch import nn
from typing import Type, Any, Callable, Union, List, Optional, Dict
from collections import OrderedDict
from torch.nn import functional as F
from torchvision import transforms
from torch import Tensor
from farabio.models.classification.conv import resnet

In [None]:
_path_chest = "/home/data/02_SSD4TB/suzy/datasets/public/chest-xray"
_path_dsb = "/home/data/02_SSD4TB/suzy/datasets/public/data-science-bowl-2018"
_path_histo = "/home/data/02_SSD4TB/suzy/datasets/public/histopathologic-cancer-detection"
_path_ranzcr = "/home/data/02_SSD4TB/suzy/datasets/public/ranzcr-clip-catheter-line-classification"
_path_retina = "/home/data/02_SSD4TB/suzy/datasets/public/aptos2019-blindness-detection"

In [6]:
print("hi!")

hi!


In [7]:
class _SimpleSegmentationModel(nn.Module):
    __constants__ = ['aux_classifier']

    def __init__(self, backbone, classifier, aux_classifier=None):
        super(_SimpleSegmentationModel, self).__init__()
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, x):
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.backbone(x)

        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)
        result["out"] = x

        return result

In [8]:
class FCN(_SimpleSegmentationModel):
    pass


class FCNHead(nn.Sequential):
    def __init__(self, in_channels, channels):
        inter_channels = in_channels // 4
        layers = [
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, channels, 1)
        ]

        super(FCNHead, self).__init__(*layers)

In [9]:
class DeepLabV3(_SimpleSegmentationModel):
    pass


class DeepLabHead(nn.Sequential):
    def __init__(self, in_channels, num_classes):
        super(DeepLabHead, self).__init__(
            ASPP(in_channels, [12, 24, 36]),
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, num_classes, 1)
        )


class ASPPConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ]
        super(ASPPConv, self).__init__(*modules)


class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU())

    def forward(self, x):
        size = x.shape[-2:]
        for mod in self:
            x = mod(x)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)


class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels=256):
        super(ASPP, self).__init__()
        modules = []
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()))

        rates = tuple(atrous_rates)
        for rate in rates:
            modules.append(ASPPConv(in_channels, out_channels, rate))

        modules.append(ASPPPooling(in_channels, out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(0.5))

    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)
        return self.project(res)

In [10]:
class IntermediateLayerGetter(nn.ModuleDict):
    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out

In [11]:
def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    if 'resnet' in backbone_name:
        backbone = resnet.__dict__[backbone_name](
                    replace_stride_with_dilation=[False, True, True])
        out_layer = 'layer4'
        out_inplanes = 2048
        aux_layer = 'layer3'
        aux_inplanes = 1024

    return_layers = {out_layer: 'out'}
    if aux:
        return_layers[aux_layer] = 'aux'
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None

    model_map = {
        'deeplabv3': (DeepLabHead, DeepLabV3),
        'fcn': (FCNHead, FCN),
    }
    classifier = model_map[name][0](out_inplanes, num_classes)
    base_model = model_map[name][1]

    model = base_model(backbone, classifier, aux_classifier)
    return model

In [12]:
def _load_model(arch_type, backbone, num_classes, aux_loss, **kwargs):
    model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)
    return model

In [20]:
def fcn_resnet50(num_classes=21, aux_loss=None, **kwargs):
    return _load_model('fcn', 'resnet50', num_classes, aux_loss, **kwargs)

In [13]:
def deeplabv3_resnet50(num_classes=21, aux_loss=None, **kwargs):
    return _load_model('deeplabv3', 'resnet50', num_classes, aux_loss, **kwargs)

In [14]:
model = deeplabv3_resnet50()

In [21]:
model = fcn_resnet50()

In [15]:
example = torch.Tensor(1,3,256,256)
preprocess = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

In [16]:
input_tensor = preprocess(example)

In [17]:
model(input_tensor)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1])