In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
def count_params(model):
    """
    Report how many trainable parameters ``model`` has, in millions.

    Only parameters with ``requires_grad=True`` are counted. The result is a
    human-readable string such as ``'21.670272M params'``.
    """
    n_trainable = sum(
        param.numel()
        for param in model.parameters()
        if param.requires_grad
    )
    return f"{n_trainable / 1000000}M params"

In [5]:
# Load the DINO ViT-S/8 backbone from torch.hub (the local hub cache is used when present, see output).
vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')

Using cache found in /home/goswami.p/.cache/torch/hub/facebookresearch_dino_main


In [12]:
count_params(vits8)

'21.670272M params'

In [139]:
# vits8

In [13]:
# Load the DINO ViT-S/8 pretrained weights from a local checkpoint.
# NOTE(review): relative path — assumes the notebook is run from its own directory.
pretrained_ckpt = torch.load("../../dino_vit_pretrained/dino_deitsmall8_pretrain.pth")

In [14]:
vits8.load_state_dict(pretrained_ckpt)

<All keys matched successfully>

In [59]:
# Random dummy batch (2 images, 3 channels, 256x256) used for shape checks.
img = torch.randn(2, 3, 256, 256)

In [100]:
# Full forward pass; the recorded output is a single 384-dim vector per image.
features = vits8(img)
features.shape

torch.Size([2, 384])

In [110]:
len(vits8.blocks)

12

In [115]:
len(vits8.blocks) - 11

1

In [132]:
# Token sequence (CLS + patch tokens) from the backbone; with n=1 presumably only the
# last block's output is returned — confirm against DINO's get_intermediate_layers.
# NOTE(review): the recorded 3073 tokens do not match the 256x256 `img` above (with 8x8
# patches that gives 1 + 32*32 = 1025); this output came from stale kernel state.
feature_maps = vits8.get_intermediate_layers(img, n=1)
feature_maps[0].shape

torch.Size([2, 3073, 384])

In [137]:
# Drop the CLS token (index 0 of the sequence axis), keeping only the patch tokens.
feature = feature_maps[0][:,1:,:]
feature.shape

torch.Size([2, 3072, 384])

In [138]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from ezflow.encoder import ENCODER_REGISTRY
from ezflow.config import configurable

In [158]:
# Module-level DINO ViT-S/8 backbone from torch.hub. DinoVITS8 assigns this one object as
# its feature extractor, so every DinoVITS8 instance shares the same module.
_ENCODER = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')

Using cache found in /home/goswami.p/.cache/torch/hub/facebookresearch_dino_main


In [159]:
# @ENCODER_REGISTRY.register()
class DinoVITS8(nn.Module):
    """
    Wrapper around the DINO ViT-S/8 backbone without the classification head.

    ``forward`` returns the patch tokens of the backbone as a dense feature map
    of shape ``(B, embed_dim, H // 8, W // 8)``.

    Parameters
    ----------
    freeze : bool
        If True, the backbone is kept in eval mode during ``forward`` and its
        parameters are excluded from gradient computation.
    pretrained_ckpt_path : str, optional
        Path to a state dict that is loaded into the backbone.
    """

    # Output stride of the ViT-S/8 patch embedding (8x8 patches).
    PATCH_SIZE = 8

    @configurable
    def __init__(
        self,
        freeze=True,
        pretrained_ckpt_path=None
    ):
        super(DinoVITS8, self).__init__()

        self.freeze = freeze
        # NOTE(review): `_ENCODER` is a module-level torch.hub model, so all
        # DinoVITS8 instances share (and mutate) the same backbone.
        self.feature_extractor = _ENCODER

        if pretrained_ckpt_path is not None:
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
            # load_state_dict copies the values onto the module's own device.
            state_dict = torch.load(pretrained_ckpt_path, map_location="cpu")
            self.feature_extractor.load_state_dict(state_dict)
            print(f"Loaded Dino ViT S/8 pretrained checkpoint from {pretrained_ckpt_path}\n")

        if self.freeze:
            # eval() alone only changes dropout/norm behaviour; gradients would still
            # reach the backbone. Disable them so the backbone is really frozen.
            self.feature_extractor.requires_grad_(False)

    @classmethod
    def from_config(cls, cfg):
        """Map an ezflow config node to the constructor's keyword arguments."""
        return {
            "freeze": cfg.FREEZE,
            "pretrained_ckpt_path": cfg.PRETRAINED_CKPT_PATH
        }

    def forward(self, input):
        """
        Compute a dense feature map for an image batch.

        Parameters
        ----------
        input : torch.Tensor
            Image batch of shape ``(B, C, H, W)``.

        Returns
        -------
        torch.Tensor
            Patch features of shape ``(B, embed_dim, H // 8, W // 8)``.
        """
        _, _, height, width = input.shape

        if self.freeze:
            self.eval()
            self.feature_extractor.eval()

        # Token sequence (CLS + patches) returned by the backbone: (B, 1 + N, embed_dim).
        tokens = self.feature_extractor.get_intermediate_layers(input, n=1)[0]

        # Remove the CLS token and move channels first: (B, embed_dim, N).
        patch_tokens = tokens[:, 1:, :].permute(0, 2, 1)

        batch, channels, _ = patch_tokens.shape
        grid_h, grid_w = height // self.PATCH_SIZE, width // self.PATCH_SIZE

        return patch_tokens.reshape(batch, channels, grid_h, grid_w)

In [160]:
# Frozen DINO ViT-S/8 wrapper with the pretrained checkpoint loaded (see printed message).
encoder = DinoVITS8(
    freeze=True, 
    pretrained_ckpt_path="../../dino_vit_pretrained/dino_deitsmall8_pretrain.pth"
)

Loaded Dino ViT S/8 pretrained checkpoint from ../../dino_vit_pretrained/dino_deitsmall8_pretrain.pth



In [152]:
# 368x496 dummy batch; both sides are multiples of 8, so the map below is 46x62.
img = torch.randn(2,3,368,496)

In [153]:
# NOTE(review): execution counts show this cell (In [153]) ran before the encoder
# cell above (In [160]); re-run the notebook top to bottom to reproduce the output.
feat_map = encoder(img)
feat_map.shape

torch.Size([2, 384, 46, 62])

___
## HuggingFace ViT with DINO ViT-S/8-style 8x8 patches (randomly initialized), plus timm ViT-Tiny

In [4]:
from transformers import ViTModel, ViTConfig

In [5]:
from transformers.models.vit.modeling_vit import ViTEncoder, ViTSelfAttention

In [6]:
# Start from the default HuggingFace ViT config (printed below) and shrink it.
configuration = ViTConfig()
configuration

ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": true,
  "transformers_version": "4.22.2"
}

In [7]:
# 8x8 patches, 128-dim tokens, 12 layers, 4 heads, 768-dim MLP.
# NOTE(review): mutates `configuration` in place, so this cell is not idempotent with the one above.
configuration.patch_size=8
configuration.hidden_size=128
configuration.num_hidden_layers=12
configuration.num_attention_heads=4
configuration.intermediate_size=768

In [8]:
configuration

ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 128,
  "image_size": 224,
  "initializer_range": 0.02,
  "intermediate_size": 768,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 4,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 8,
  "qkv_bias": true,
  "transformers_version": "4.22.2"
}

In [20]:
# ViT built directly from the config (no pretrained weights loaded); no pooling layer.
encoder = ViTModel(configuration, add_pooling_layer=False)

In [21]:
count_params(encoder)

'3.294336M params'

In [22]:
img = torch.randn(2,3,384, 512)

In [23]:
# The config's image_size is 224, so position embeddings must be interpolated for a
# 384x512 input; the line below the cell is a recorded warning fragment.
feats = encoder(img, interpolate_pos_encoding=True).last_hidden_state

  "The default behavior for interpolate/upsample with float scale_factor changed "


In [24]:
feats.shape

torch.Size([2, 3073, 128])

In [4]:
import timm
# List the ViT-Tiny variants available in the installed timm version.
timm.list_models("vit_tiny*")

['vit_tiny_patch16_224',
 'vit_tiny_patch16_224_in21k',
 'vit_tiny_patch16_384',
 'vit_tiny_r_s16_p8_224',
 'vit_tiny_r_s16_p8_224_in21k',
 'vit_tiny_r_s16_p8_384']

In [201]:
model = timm.create_model('vit_tiny_patch16_224')

In [202]:
count_params(model)

'5.717416M params'

In [203]:
model.embed_dim

192

In [9]:
# NOTE(review): `_create_vision_transformer` is a private timm helper; its signature
# may change between timm versions.
from timm.models.vision_transformer import _create_vision_transformer as timm_create_vit

In [226]:
# ViT-Tiny layout with the width reduced from 192 (see above) to 128 and 4 heads.
model = timm_create_vit('vit_tiny_patch16_224',patch_size=16, embed_dim=128, depth=12, num_heads=4)

In [227]:
count_params(model)

'2.632296M params'

In [95]:
# model

___
## Custom ViT Tiny

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from ezflow.encoder import ENCODER_REGISTRY,build_encoder
from ezflow.modules import BaseModule
from ezflow.config import configurable

In [26]:
from transformers import ViTModel, ViTConfig

In [59]:
# @ENCODER_REGISTRY.register()
class ViTEncoder(nn.Module):
    """
    This class implements a plain Vision Transformer (HuggingFace ``ViTModel``)
    without the pooling / classification head.

    ``forward`` returns the last-layer patch tokens reshaped into a dense feature
    map of shape ``(B, embedding_channels, H // patch_size, W // patch_size)``.

    Notes
    -----
    * ``depths`` and ``number_of_heads`` are written to ``ViTConfig`` as
      ``num_hidden_layers`` / ``num_attention_heads`` and should be single ints
      (e.g. ``depths=12, number_of_heads=4``). NOTE(review): the tuple defaults
      look like Swin leftovers — confirm configs always override them.
    * ``input_resolution``, ``ff_feature_ratio``, ``dropout``,
      ``dropout_attention``, ``dropout_path``, ``use_checkpoint`` and
      ``sequential_self_attention`` are accepted for config compatibility but
      are currently not used.
    """

    @configurable
    def __init__(
        self,
        in_channels=3,
        embedding_channels=96,
        depths=(2, 2),
        input_resolution=(256, 256),
        number_of_heads=(3, 6, 12, 24),
        intermediate_size: int = 768,
        patch_size: int = 4,
        ff_feature_ratio: int = 4,
        dropout: float = 0.0,
        dropout_attention: float = 0.0,
        dropout_path: float = 0.2,
        use_checkpoint: bool = False,
        sequential_self_attention: bool = False,
    ) -> None:
        """
        Build the HuggingFace ViT backbone from the given hyper-parameters.

        Parameters
        ----------
        in_channels : int
            Number of input image channels (``ViTConfig.num_channels``).
        embedding_channels : int
            Token embedding dimension (``ViTConfig.hidden_size``).
        depths : int
            Number of transformer layers (``ViTConfig.num_hidden_layers``).
        number_of_heads : int
            Attention heads per layer (``ViTConfig.num_attention_heads``).
        intermediate_size : int
            Hidden size of the transformer MLP.
        patch_size : int
            Side length of the square patches; also the output stride of ``forward``.
        """
        # Call super constructor
        super(ViTEncoder, self).__init__()

        # Remembered so forward() can fold the token sequence back into a grid.
        self.patch_size = patch_size

        configuration = ViTConfig()
        configuration.num_channels = in_channels
        configuration.patch_size = patch_size
        configuration.hidden_size = embedding_channels
        configuration.num_hidden_layers = depths
        configuration.num_attention_heads = number_of_heads
        configuration.intermediate_size = intermediate_size

        self.vit_feature_extractor = ViTModel(configuration, add_pooling_layer=False)

    @classmethod
    def from_config(cls, cfg):
        """Map an ezflow config node to the constructor's keyword arguments."""
        return {
            "in_channels": cfg.IN_CHANNELS,
            "embedding_channels": cfg.EMBEDDING_CHANNELS,
            "depths": cfg.DEPTHS,
            "input_resolution": cfg.INPUT_RESOLUTION,
            "number_of_heads": cfg.NUMBER_OF_HEADS,
            "intermediate_size": cfg.INTERMEDIATE_SIZE,
            "patch_size": cfg.PATCH_SIZE,
            "ff_feature_ratio": cfg.FF_FEATURE_RATIO,
            "dropout": cfg.DROPOUT,
            "dropout_attention": cfg.DROPOUT_ATTENTION,
            "dropout_path": cfg.DROPOUT_PATH,
            "use_checkpoint": cfg.USE_CHECKPOINT,
            "sequential_self_attention": cfg.SEQUENTIAL_SELF_ATTENTION,
        }

    def forward(self, input):
        """
        Compute a dense feature map for an image batch.

        Parameters
        ----------
        input : torch.Tensor
            Image batch of shape ``(B, C, H, W)``.

        Returns
        -------
        torch.Tensor
            Patch features of shape
            ``(B, embedding_channels, H // patch_size, W // patch_size)``.
        """
        _, _, height, width = input.shape

        # interpolate_pos_encoding lets the input size differ from ViTConfig.image_size.
        tokens = self.vit_feature_extractor(
            input, interpolate_pos_encoding=True
        ).last_hidden_state

        # Remove the CLS token and move channels first: (B, C, N).
        patch_tokens = tokens[:, 1:, :].permute(0, 2, 1)

        batch, channels, _ = patch_tokens.shape
        # Use the configured patch size instead of a hard-coded 8 so non-8 patch
        # sizes (including the default of 4) reshape correctly.
        grid_h, grid_w = height // self.patch_size, width // self.patch_size

        return patch_tokens.reshape(batch, channels, grid_h, grid_w)

In [98]:
# Same hyper-parameters as the HuggingFace ViTConfig experiment; the parameter count
# below (3.294336M) matches that experiment's ViTModel.
_encoder = ViTEncoder(
        patch_size=8,
        embedding_channels=128,
        depths=12,
        number_of_heads=4,
        intermediate_size=768
)

In [99]:
count_params(_encoder)

'3.294336M params'

In [95]:
img = torch.randn(2,3,384,512)

In [96]:
# NOTE(review): execution counts (In [95]-[97] before In [98]) show these cells ran
# against an earlier `_encoder`; re-run top to bottom to reproduce.
feat = _encoder(img)

In [97]:
feat.shape

torch.Size([2, 128, 48, 64])

In [133]:
# Patch embedding as a strided conv: each non-overlapping 8x8 patch -> one 384-dim vector.
patch_embed = nn.Conv2d(3, 384, kernel_size=(8, 8), stride=(8, 8))

In [85]:
embeddings = patch_embed(img)
embeddings.shape

torch.Size([2, 384, 48, 64])

In [81]:
# Learnable CLS token initialised from a truncated normal (mean 0, std 0.02).
cls_token = nn.Parameter(
            nn.init.trunc_normal_(torch.zeros(1, 1, 384), mean=0.0, std=0.02)
        )

In [83]:
cls_token.shape

torch.Size([1, 1, 384])

In [84]:
# Broadcast the single CLS token over the batch of 2 (expand returns a view, no copy).
cls_tokens = cls_token.expand(2, -1, -1)
cls_tokens.shape

torch.Size([2, 1, 384])

In [88]:
# Move channels last WITHOUT swapping H and W: (B, C, H, W) -> (B, H, W, C). The
# flatten that follows then orders tokens row-major, matching ViT's
# `flatten(2).transpose(1, 2)` and the `(b, c, h, w)` reshape used by the encoders.
embeddings = embeddings.permute(0,2,3,1)
embeddings.shape

torch.Size([2, 48, 64, 384])

In [90]:
# Flatten the spatial grid into a sequence of 64*48 = 3072 patch tokens.
embeddings = embeddings.reshape(2, 64*48, 384)
embeddings.shape

torch.Size([2, 3072, 384])

In [91]:
# Prepend the CLS token along the sequence axis -> 1 + 3072 tokens.
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
embeddings.shape

torch.Size([2, 3073, 384])

In [119]:
# Toy check: prepending a row with torch.cat(dim=1) and slicing [:, 1:, :] round-trips,
# mirroring how the CLS token is added and later removed.
x = torch.tensor([[[1,1,1],[2,2,2]],[[3,3,3],[4,4,4]]])
x.shape

torch.Size([2, 2, 3])

In [120]:
x

tensor([[[1, 1, 1],
         [2, 2, 2]],

        [[3, 3, 3],
         [4, 4, 4]]])

In [121]:
# Stand-in "CLS" row of zeros, one per batch element.
c = torch.zeros(2,1,3)
c

tensor([[[0., 0., 0.]],

        [[0., 0., 0.]]])

In [122]:
cx = torch.cat((c,x), dim=1)
cx

tensor([[[0., 0., 0.],
         [1., 1., 1.],
         [2., 2., 2.]],

        [[0., 0., 0.],
         [3., 3., 3.],
         [4., 4., 4.]]])

In [123]:
cx.shape

torch.Size([2, 3, 3])

In [127]:
# Strip the prepended row again.
_cx = cx[:,1:,:]
_cx.shape

torch.Size([2, 2, 3])

In [128]:
_cx

tensor([[[1., 1., 1.],
         [2., 2., 2.]],

        [[3., 3., 3.],
         [4., 4., 4.]]])

In [130]:
# Only shapes are compared here; values match by inspection (float vs. int dtype differs).
_cx.shape == x.shape

True