In [1]:
import timm

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from timm.models.layers import trunc_normal_

def init_weights(m):
    """Initialise a single module's parameters in place (for use with ``Module.apply``).

    - ``nn.Linear``: weight drawn from a truncated normal (std 0.02), bias set to 0.
    - ``nn.LayerNorm``: weight set to 1, bias set to 0.
    Any other module type is left untouched.

    Args:
        m: the module currently being visited by ``nn.Module.apply``.
    """
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # A LayerNorm built with elementwise_affine=False has weight/bias set to
        # None; skip the missing tensors instead of crashing in constant_().
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)

class DecoderLinear(nn.Module):
    """Token-wise linear classifier that turns encoder tokens into a coarse class map.

    Args:
        n_cls: number of output classes (channels of the returned map).
        patch_size: patch side length of the encoder; used to recover the
            token-grid height from the input image height.
        d_encoder: feature dimension of each encoder token.
    """

    def __init__(self, n_cls, patch_size, d_encoder):
        super().__init__()

        # Remember the configuration for later inspection.
        self.d_encoder = d_encoder
        self.patch_size = patch_size
        self.n_cls = n_cls

        # One linear layer shared by every token: d_encoder -> n_cls.
        self.head = nn.Linear(d_encoder, n_cls)
        self.apply(init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters to exclude from weight decay (none for this head)."""
        return set()

    def forward(self, x, im_size):
        """Classify every token, then fold the token sequence back into a 2-D grid.

        Args:
            x: token features of shape (batch, n_tokens, d_encoder); prefix
                tokens are expected to have been removed by the caller.
            im_size: (height, width) of the input image. Only the height is
                used; the grid width is inferred from the token count.

        Returns:
            Logits of shape (batch, n_cls, grid_h, grid_w).
        """
        height, _width = im_size
        grid_h = height // self.patch_size
        logits = self.head(x)
        return rearrange(logits, "b (h w) c -> b c h w", h=grid_h)
    
import torch
import torch.nn as nn
import torch.nn.functional as F

class Segmenter(nn.Module):
    """Semantic-segmentation wrapper: ViT-style encoder + per-token decoder + upsampling.

    Args:
        encoder: backbone exposing ``patch_size``, ``num_prefix_tokens``,
            ``forward_features(im)`` and ``no_weight_decay()``.
        decoder: module called as ``decoder(tokens, (H, W))`` returning a
            (batch, n_cls, h, w) map, and exposing ``no_weight_decay()``.
        n_cls: number of segmentation classes.
    """

    def __init__(
        self,
        encoder,
        decoder,
        n_cls,
    ):
        super().__init__()
        self.n_cls = n_cls
        self.patch_size = encoder.patch_size
        self.encoder = encoder
        self.decoder = decoder

    @torch.jit.ignore
    def no_weight_decay(self):
        """Collect the no-weight-decay parameter names of both sub-modules.

        Names are prefixed with ``"encoder."`` / ``"decoder."`` so they match
        this module's ``named_parameters()``.
        """
        return {
            f"{prefix}.{name}"
            for prefix, module in (("encoder", self.encoder), ("decoder", self.decoder))
            for name in module.no_weight_decay()
        }

    def forward(self, im):
        """Predict per-pixel class logits at the input resolution.

        Args:
            im: image batch; dims 2 and 3 are taken as height and width.

        Returns:
            Logits of shape (batch, n_cls, H, W), bilinearly upsampled from
            the decoder's patch-level output.
        """
        H, W = im.size(2), im.size(3)

        x = self.encoder.forward_features(im)

        # Drop the leading prefix tokens (e.g. CLS/DIST) so only patch tokens
        # reach the decoder.
        x = x[:, self.encoder.num_prefix_tokens:]

        masks = self.decoder(x, (H, W))

        # align_corners=False is the bilinear default; stated explicitly.
        return F.interpolate(masks, size=(H, W), mode="bilinear", align_corners=False)

#    def get_attention_map_enc(self, im, layer_id):
#        return self.encoder.get_attention_map(im, layer_id)
#
#    def get_attention_map_dec(self, im, layer_id):
#        x = self.encoder(im, return_features=True)
#
#        # remove CLS/DIST tokens for decoding
#        num_extra_tokens = 1 + self.encoder.distilled
#        x = x[:, num_extra_tokens:]
#
#        return self.decoder.get_attention_map(x, layer_id)
    
# Build the segmentation model: a DINO-pretrained ViT-B/8 encoder from timm
# followed by a linear per-patch decoder.
# NOTE: pretrained=True loads weights (may require a download on first run).
N_CLS = 10       # number of segmentation classes, shared by decoder and Segmenter
PATCH_SIZE = 8   # must match the "patch8" in the timm model name below

encoder = timm.create_model('vit_base_patch8_224_dino', pretrained=True)
# Segmenter reads `encoder.patch_size`; set it to the patch size implied by
# the model name so decoder and Segmenter agree.
encoder.patch_size = PATCH_SIZE
decoder = DecoderLinear(n_cls=N_CLS, d_encoder=encoder.embed_dim, patch_size=PATCH_SIZE)
model = Segmenter(encoder, decoder, n_cls=N_CLS)

In [17]:
# Smoke test: a random 3-channel 224x224 batch should yield per-pixel logits
# at full input resolution with one channel per class.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():  # inference only; no autograd graph needed
    masks = model(x)
assert masks.shape == (1, model.n_cls, 224, 224)
masks.shape

torch.Size([1, 10, 224, 224])