In [4]:
!pip install timm --upgrade

Collecting timm
  Downloading timm-0.6.7-py3-none-any.whl (509 kB)
[K     |████████████████████████████████| 509 kB 372 kB/s eta 0:00:01
Installing collected packages: timm
  Attempting uninstall: timm
    Found existing installation: timm 0.4.12
    Uninstalling timm-0.4.12:
      Successfully uninstalled timm-0.4.12
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
segmentation-models-pytorch 0.3.0 requires timm==0.4.12, but you have timm 0.6.7 which is incompatible.[0m
Successfully installed timm-0.6.7


In [1]:
import sys
sys.path.append('..')
from utils import load_config
class Arguments: pass
args=Arguments()
args.config = '../configs/config.yaml'
args.config_workspace = '../configs/config_aia.yaml'

cfg=load_config(args)

Config path 
 config: ../configs/config.yaml 
 config_workspace: ../configs/config_aia.yaml


In [2]:
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch

In [6]:
encoder_name=cfg["smp"]["encoder_name"]
num_cls=cfg["num_cls"]
# https://smp.readthedocs.io/en/latest/_modules/segmentation_models_pytorch/decoders/unetplusplus/model.html#UnetPlusPlus
# model = smp.UnetPlusPlus(encoder_name=encoder_name,
#                          encoder_depth=5,
#                          encoder_weights='imagenet',
#                          decoder_use_batchnorm=True,
#                          decoder_channels=(256, 128, 64, 32, 16),
#                          decoder_attention_type=None,
#                          in_channels=3,
#                          classes=num_cls, activation=None, aux_params=None)

In [4]:
encoder_name

'timm-efficientnet-b2'

In [5]:
pred=model(torch.zeros(1,3,320,320))
pred.shape

torch.Size([1, 10, 320, 320])

In [3]:
import torch.nn.functional as F

In [7]:
F.interpolate(pred,scale_factor=pow(0.5,3),mode="bilinear").shape

  "See the documentation of nn.Upsample for details.".format(mode)
  "The default behavior for interpolate/upsample with float scale_factor changed "


torch.Size([1, 10, 40, 40])

In [24]:
[*zip(*model.named_children())][0]

('encoder', 'decoder', 'segmentation_head')

In [15]:
from typing import Optional, Union, List

from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
    SegmentationModel,
    SegmentationHead,
    ClassificationHead,
)


class UnetPlusPlusCGMDecoder(smp.decoders.unetplusplus.decoder.UnetPlusPlusDecoder):
    # output depth      ^         ^            ^
    #                  |          |            |
    #         f0 -> out(0,0) -> out(0,1) -> *out(0,2) ->
    #         ^     ^^          ^^^         ^^^^
    #         f1 -> latent(1,1)-> out(1,2) 
    #         ^     ^^            ^^^
    #         f2 -> latent(2,2) 
    #         ^     ^^
    # input-> f3
    def forward(self, *features):

        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder
        # start building dense connections
        dense_x = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(self.depth - layer_idx):
                if layer_idx == 0:
                    output = self.blocks[f"x_{depth_idx}_{depth_idx}"](features[depth_idx], features[depth_idx + 1])
                    dense_x[f"x_{depth_idx}_{depth_idx}"] = output
                else:
                    dense_l_i = depth_idx + layer_idx                    
                    cat_features = [dense_x[f"x_{didx}_{dense_l_i}"] for didx in range(depth_idx + 1, dense_l_i + 1)]
                    cat_features = torch.cat(cat_features + [features[dense_l_i + 1]], dim=1)
                    dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[f"x_{depth_idx}_{dense_l_i}"](
                        dense_x[f"x_{depth_idx}_{dense_l_i-1}"], cat_features
                    )
                    
        dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"](dense_x[f"x_{0}_{self.depth-1}"])
        return [dense_x[f"x_{0}_{d}"] for d in range(self.depth+1)]


class UnetPlusPlusCGM(SegmentationModel):
    """
    Reference:
        https://arxiv.org/abs/1807.10165
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )
        self.encoder_depth=encoder_depth
        
        self.decoder = UnetPlusPlusCGMDecoder(
            encoder_channels=self.encoder.out_channels,
            decoder_channels=decoder_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            center=True if encoder_name.startswith("vgg") else False,
            attention_type=decoder_attention_type,
        )
        
        self.pooling_heads=[torch.nn.AvgPool3d(
            kernel_size=[chs//decoder_channels[-1],1,1]
        ) for chs in decoder_channels]
        
        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3)

        if aux_params is not None:
            self.classification_head = ClassificationHead(in_channels=self.encoder.out_channels[-1], **aux_params)
        else:
            self.classification_head = None

        self.name = "unetplusplus-{}".format(encoder_name)
        self.initialize()
    def forward(self, x):
        """Sequentially pass `x` trough model`s encoder, decoder and heads"""

        self.check_input_shape(x)

        features = self.encoder(x)
        decoder_output = self.decoder(*features)
        # print([o.shape for o in decoder_output])
        pooled_output = [h(raw_mask) for h,raw_mask in zip(self.pooling_heads,decoder_output)]
        # print([o.shape for o in pooled_output])
        masks=[*map(self.segmentation_head,pooled_output)]
        
        if self.classification_head is not None:
            labels = self.classification_head(features[-1])
            return masks, labels

        return masks

In [16]:
model = UnetPlusPlusCGM(encoder_name=encoder_name,
                         encoder_depth=5,
                         encoder_weights='imagenet',
                         decoder_use_batchnorm=True,
                         decoder_channels=(256, 128, 64, 32, 16),
                         decoder_attention_type=None,
                         in_channels=3,
                         classes=num_cls, activation=None, aux_params=None)

In [17]:
pred=model(torch.zeros(1,3,320,320))

In [18]:
[x.shape for x in pred]

[torch.Size([1, 10, 20, 20]),
 torch.Size([1, 10, 40, 40]),
 torch.Size([1, 10, 80, 80]),
 torch.Size([1, 10, 160, 160]),
 torch.Size([1, 10, 320, 320])]

In [11]:
self.upscore6 = nn.Upsample(scale_factor=32,mode='bilinear')###
self.upscore5 = nn.Upsample(scale_factor=16,mode='bilinear')
self.upscore4 = nn.Upsample(scale_factor=8,mode='bilinear')
self.upscore3 = nn.Upsample(scale_factor=4,mode='bilinear')
self.upscore2 = nn.Upsample(scale_factor=2, mode='bilinear')

In [55]:
class And1:
    def __init__(self):
        for i in range(1,5):
            self.__dict__[f"upscore{i}"]=nn.Upsample(scale_factor=2*i,mode='bilinear')
a1=And1()

In [57]:
a1.upscore1(pred[0])

tensor([[[[ 3.4870e-03, -2.7763e-02, -9.0263e-02,  ...,  1.6601e-01,
            1.6265e-01,  1.6098e-01],
          [-1.6006e-01, -1.1437e-01, -2.2983e-02,  ...,  1.6297e-01,
            1.7167e-01,  1.7603e-01],
          [-4.8715e-01, -2.8757e-01,  1.1158e-01,  ...,  1.5690e-01,
            1.8972e-01,  2.0612e-01],
          ...,
          [-3.2354e-01, -2.5185e-01, -1.0848e-01,  ...,  2.5266e-02,
            1.2078e-01,  1.6854e-01],
          [-5.1421e-01, -3.7356e-01, -9.2257e-02,  ...,  1.5943e-01,
            3.1755e-01,  3.9661e-01],
          [-6.0955e-01, -4.3442e-01, -8.4145e-02,  ...,  2.2651e-01,
            4.1593e-01,  5.1065e-01]],

         [[-3.1131e-02,  1.2800e-01,  4.4626e-01,  ...,  3.0857e-01,
            5.4100e-01,  6.5721e-01],
          [-1.2427e-01, -4.8210e-04,  2.4709e-01,  ...,  2.0827e-01,
            4.8761e-01,  6.2728e-01],
          [-3.1054e-01, -2.5745e-01, -1.5127e-01,  ...,  7.6639e-03,
            3.8083e-01,  5.6742e-01],
          ...,
     

In [12]:
upsample2x(pred[0]).shape

torch.Size([1, 10, 40, 40])

In [6]:
from models import smp

In [7]:
from models import smppp

In [1]:
import timm

In [13]:
timm.list_models("*swin*")

['eca_swinnext26ts_256',
 'swin_base_patch4_window7_224',
 'swin_base_patch4_window7_224_in22k',
 'swin_base_patch4_window12_384',
 'swin_base_patch4_window12_384_in22k',
 'swin_large_patch4_window7_224',
 'swin_large_patch4_window7_224_in22k',
 'swin_large_patch4_window12_384',
 'swin_large_patch4_window12_384_in22k',
 'swin_small_patch4_window7_224',
 'swin_tiny_patch4_window7_224',
 'swinnet26t_256',
 'swinnet50ts_256']

In [14]:
timm.list_models("*vit*")

['convit_base',
 'convit_small',
 'convit_tiny',
 'levit_128',
 'levit_128s',
 'levit_192',
 'levit_256',
 'levit_384',
 'vit_base_patch16_224',
 'vit_base_patch16_224_in21k',
 'vit_base_patch16_224_miil',
 'vit_base_patch16_224_miil_in21k',
 'vit_base_patch16_384',
 'vit_base_patch32_224',
 'vit_base_patch32_224_in21k',
 'vit_base_patch32_384',
 'vit_base_r26_s32_224',
 'vit_base_r50_s16_224',
 'vit_base_r50_s16_224_in21k',
 'vit_base_r50_s16_384',
 'vit_base_resnet26d_224',
 'vit_base_resnet50_224_in21k',
 'vit_base_resnet50_384',
 'vit_base_resnet50d_224',
 'vit_huge_patch14_224_in21k',
 'vit_large_patch16_224',
 'vit_large_patch16_224_in21k',
 'vit_large_patch16_384',
 'vit_large_patch32_224',
 'vit_large_patch32_224_in21k',
 'vit_large_patch32_384',
 'vit_large_r50_s32_224',
 'vit_large_r50_s32_224_in21k',
 'vit_large_r50_s32_384',
 'vit_small_patch16_224',
 'vit_small_patch16_224_in21k',
 'vit_small_patch16_384',
 'vit_small_patch32_224',
 'vit_small_patch32_224_in21k',
 'vit_sma

In [15]:
import segmentation_models_pytorch as smp

In [16]:
smp.__version__

'0.3.0'

In [None]:
swinnet26t_256

In [23]:
model = smp.UnetPlusPlus(encoder_name='tu-tf_efficientnet_b6_ns',
                                     encoder_depth=5,
                                     encoder_weights='imagenet',
                                     decoder_use_batchnorm=True,
                                     decoder_channels=(256, 128, 64, 32, 16),
                                     decoder_attention_type=None,
                                     in_channels=3,
                                     classes=10, activation=None, aux_params=None)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth" to /home/jovyan/.cache/torch/hub/checkpoints/tf_efficientnet_b6_ns-51548356.pth


In [3]:
import sys
sys.path.append('..')
from models import head

In [5]:
out_head=head.MetricLayer(32,10)

In [6]:
import torch

In [8]:
out_head(torch.zeros(1,32,128,128)).shape

torch.Size([1, 10, 128, 128])

In [1]:
import ranger21

In [None]:
ranger21.Ranger21()

In [3]:
import torch.optim

In [4]:
torch.optim.SGD()

torch.optim.sgd.SGD