In [1]:
import os
import sys
import types

sys.path.append('../..')

from collections import OrderedDict
import math

import torch
from torch import nn

from torchsummary import summary

from segmentation_models_pytorch.unet import Unet
from segmentation_models_pytorch.unet.decoder import DecoderBlock
from segmentation_models_pytorch import encoders

from lib.models.seresnet_mini_custom import se_resnet18_mini_backbone, SEResNetBasicBlock

In [2]:
#from segmentation_models_pytorch.encoders.resnet import ResNetEncoder
#import torchvision

In [3]:
#tmp_model = ResNetEncoder(out_channels=(3, 64, 64, 128, 256, 512), 
#                          depth=5, block=torchvision.models.resnet.BasicBlock, 
#                          layers=[2, 2, 2, 2])

In [4]:
#fs = tmp_model(X)

In [5]:
# [tuple(f.shape) for f in fs]
#[(16, 3, 256, 256),
# (16, 64, 128, 128),
# (16, 64, 64, 64),
# (16, 128, 32, 32),
# (16, 256, 16, 16),
# (16, 512, 8, 8)]

In [6]:
# Standalone instance of the custom SE-ResNet18-mini backbone; used in the
# following cells to inspect the feature maps it returns for a dummy batch.
my_model = se_resnet18_mini_backbone()

In [7]:
# my_model = models.se_resnet50(num_classes=1000, pretrained=None)

In [8]:
# summary(my_model, (3, 256, 256), -1, 'cpu')

In [9]:
# Dummy input batch for shape checks: 16 three-channel 256x256 images with
# values in [0, 1). A dedicated, seeded generator makes the batch identical
# across kernel restarts without disturbing the global torch RNG state.
X = torch.rand((16, 3, 256, 256), generator=torch.Generator().manual_seed(0))

In [10]:
# Forward pass used only to inspect feature shapes: disable autograd so the
# intermediate activations are not retained for a backward pass.
with torch.no_grad():
    fs = my_model(X)

In [11]:
# Shape of every feature map returned by the backbone, as plain tuples.
[tuple(feature.size()) for feature in fs]

[(16, 3, 256, 256),
 (16, 64, 128, 128),
 (16, 64, 64, 64),
 (16, 64, 32, 32),
 (16, 64, 16, 16),
 (16, 64, 8, 8)]

In [12]:
# my_model

In [13]:
# Model configuration. The flag names mirror the keyword arguments of `get_model`.
classes = 6                                # number of output classes
decoder = labels = segmentation = True     # enable auto-decoder, label head, mask head
mask_activation = label_activation = None  # no activation passed to either head

In [14]:
# Reference: the smp registry entry for the stock resnet18 encoder, displayed
# for comparison with the custom 'seresnet18_mini' registration below.
encoders.encoders['resnet18']['params']

{'out_channels': (3, 64, 64, 128, 256, 512),
 'block': torchvision.models.resnet.BasicBlock,
 'layers': [2, 2, 2, 2]}

In [15]:
# Register the custom backbone under a new name so that
# `Unet('seresnet18_mini', ...)` can construct it. Unlike the resnet18 entry,
# no constructor params are supplied here.
encoders.encoders['seresnet18_mini'] = dict(
    encoder=se_resnet18_mini_backbone,
    params=dict(),
)

In [16]:
# model.encoder.out_channels

In [17]:
# Prototype U-Net on the custom encoder with a uniform 64-channel decoder.
# The auxiliary classification head is built only when `labels` is set.
model = Unet(
    'seresnet18_mini',
    encoder_weights=None,
    decoder_channels=(64,) * 5,
    activation=mask_activation,
    classes=classes,
    aux_params=(dict(classes=classes, activation=label_activation)
                if labels else None),
)

In [18]:
# Forward pass only to inspect output shapes; no_grad avoids building and
# holding the autograd graph for this full-resolution batch.
with torch.no_grad():
    tmp = model(X)

torch.Size([16, 64, 16, 16]) torch.Size([16, 64, 16, 16])
torch.Size([16, 64, 32, 32]) torch.Size([16, 64, 32, 32])
torch.Size([16, 64, 64, 64]) torch.Size([16, 64, 64, 64])
torch.Size([16, 64, 128, 128]) torch.Size([16, 64, 128, 128])


In [None]:
# Shape of the second forward output. NOTE(review): presumably the auxiliary
# classification output, since `aux_params` was given; confirm.
tmp[1].size()

In [None]:
def get_model(classes, decoder=True, labels=True, segmentation=True,
              mask_activation=None, label_activation=None,
              encoder_name='seresnet18_mini', decoder_channels=None):
    """Build a U-Net with optional mask, label and auto-decoder outputs.

    The returned model is a ``Unet`` whose ``forward`` is replaced on this
    instance only. The new ``forward`` returns a tuple of the enabled outputs
    in a fixed order (see ``forward`` below).

    Args:
        classes: Number of classes for the segmentation head and for the
            auxiliary classification head.
        decoder: If True, attach an ``AutoDecoder`` as ``model.autodecoder``.
            It decodes the deepest encoder feature map without skip
            connections. Otherwise ``model.autodecoder`` is None.
        labels: If True, pass ``aux_params`` so a classification head is
            built. Otherwise ``aux_params=None``.
        segmentation: If True, ``forward`` runs the U-Net decoder and the
            segmentation head. The modules are still constructed when False.
        mask_activation: ``activation`` argument for the segmentation head.
        label_activation: ``activation`` argument for the classification head.
        encoder_name: Registered encoder name passed to ``Unet``.
        decoder_channels: Optional U-Net decoder channel sizes. When None,
            the argument is not passed and ``Unet``'s default is used.

    Returns:
        The configured ``Unet`` instance.
    """
    class AutoDecoder(nn.Module):
        """Skip-free decoder stack applied to the deepest encoder features.

        One ``DecoderBlock`` (with 0 skip channels) is built per consecutive
        pair of entries in ``channels``, read from the last entry to the
        first. The final block therefore emits ``channels[0]`` channels. The
        result is passed through a sigmoid.
        """

        def __init__(self, channels, use_batchnorm=True, attention_type=None):
            super().__init__()

            # Block k maps channels[-1 - k] -> channels[-2 - k].
            in_channels = channels[:0:-1]
            out_channels = channels[-2::-1]

            block_kwargs = dict(use_batchnorm=use_batchnorm,
                                attention_type=attention_type)
            self.blocks = nn.ModuleList([
                DecoderBlock(in_ch, 0, out_ch, **block_kwargs)
                for in_ch, out_ch in zip(in_channels, out_channels)
            ])

        def forward(self, features):
            """Run ``features`` through every block, then apply a sigmoid."""
            x = features
            for decoder_block in self.blocks:
                x = decoder_block(x)
            # Alternative output activation kept from the original: torch.tanh(x)
            return torch.sigmoid(x)

    def forward(self, x, return_features=False):
        """Sequentially pass `x` through the model's encoder, decoder and heads.

        Returns a tuple containing, in this order and only when enabled:
        the deepest encoder feature map (``return_features``), the
        segmentation masks (``self.segmentation``), the classification-head
        output (``self.classification_head is not None``), and the
        auto-decoder output (``self.autodecoder is not None``).
        """
        features = self.encoder(x)
        deepest = features[-1]

        out = ()
        if return_features:
            out += (deepest,)

        if self.segmentation:
            decoder_output = self.decoder(*features)
            out += (self.segmentation_head(decoder_output),)

        if self.classification_head is not None:
            out += (self.classification_head(deepest),)

        if self.autodecoder is not None:
            out += (self.autodecoder(deepest),)

        return out

    unet_kwargs = dict(
        encoder_weights=None,
        activation=mask_activation,
        classes=classes,
        aux_params=({'classes': classes, 'activation': label_activation}
                    if labels else None),
    )
    # Only forward decoder_channels when given, so Unet's own default applies otherwise.
    if decoder_channels is not None:
        unet_kwargs['decoder_channels'] = decoder_channels
    model = Unet(encoder_name, **unet_kwargs)

    # Override forward on this instance only; Unet.forward itself is untouched.
    model.forward = types.MethodType(forward, model)

    # A Module assigned here is registered as a submodule, so its parameters
    # are trained and saved with the model. None is kept as a plain attribute
    # for the check in `forward`.
    if decoder:
        model.autodecoder = AutoDecoder(model.encoder.out_channels)
    else:
        model.autodecoder = None

    model.segmentation = segmentation

    return model