In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
def count_params(model):
    """
    Summarise the number of trainable parameters of ``model`` in millions.

    Only parameters whose ``requires_grad`` flag is set are counted.

    Parameters
    ----------
    model : object exposing ``parameters()``
        Typically a ``torch.nn.Module``.

    Returns
    -------
    str
        The count divided by 1,000,000 followed by ``"M params"``,
        e.g. ``'21.670272M params'``.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()

    return f"{trainable / 1000000}M params"

___

## DINO ViT-S/8 (original implementation)

In [157]:
vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')

Using cache found in /home/goswami.p/.cache/torch/hub/facebookresearch_dino_main


In [12]:
count_params(vits8)

'21.670272M params'

In [139]:
# vits8

In [13]:
# Load the checkpoint onto the CPU explicitly: without map_location, tensors are
# restored to the device they were saved from, which fails on a CPU-only host
# if the checkpoint holds CUDA tensors.
pretrained_ckpt = torch.load(
    "../../dino_vit_pretrained/dino_deitsmall8_pretrain.pth",
    map_location="cpu",
)

In [14]:
vits8.load_state_dict(pretrained_ckpt)

<All keys matched successfully>

In [59]:
img = torch.randn(2, 3, 256, 256)

In [100]:
features = vits8(img)
features.shape

torch.Size([2, 384])

In [110]:
len(vits8.blocks)

12

In [115]:
len(vits8.blocks) - 11

1

In [132]:
feature_maps = vits8.get_intermediate_layers(img, n=1)
feature_maps[0].shape

torch.Size([2, 3073, 384])

In [137]:
feature = feature_maps[0][:,1:,:]
feature.shape

torch.Size([2, 3072, 384])

___

## EzFlow Dino ViT Feature Extractor

In [138]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from ezflow.encoder import ENCODER_REGISTRY
from ezflow.config import configurable

In [144]:
# @ENCODER_REGISTRY.register()
class DinoVITS8(nn.Module):
    """
    Dense feature extractor built on the DINO ViT-S/8 backbone.

    Wraps the ``dino_vits8`` model from ``facebookresearch/dino`` (no
    classification head) and returns the patch tokens of its last block
    as a 2D feature map, with the CLS token removed.

    Parameters
    ----------
    freeze : bool, default=True
        If True, gradients are disabled for every backbone parameter and the
        backbone is kept in eval mode, even when ``.train()`` is called on
        this module.
    pretrained_ckpt_path : str, optional
        Path to a state dict that is loaded into the backbone. When None,
        the weights returned by ``torch.hub.load`` are used as-is.
    """

    # Side length, in pixels, of the square patches used by ViT-S/8.
    PATCH_SIZE = 8

    @configurable
    def __init__(
        self,
        freeze=True,
        pretrained_ckpt_path=None
    ):

        super(DinoVITS8, self).__init__()

        self.freeze = freeze
        self.feature_extractor = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')

        if pretrained_ckpt_path is not None:
            # map_location keeps the load independent of the device the
            # checkpoint was saved from; load_state_dict copies onto the
            # backbone's own device afterwards.
            self.feature_extractor.load_state_dict(
                torch.load(pretrained_ckpt_path, map_location="cpu")
            )
            print(f"Loaded Dino ViT S/8 pretrained checkpoint from {pretrained_ckpt_path}\n")

        if self.freeze:
            # eval() alone does not stop training: gradients must be disabled
            # so that an optimizer cannot update the backbone.
            self.feature_extractor.requires_grad_(False)
            self.feature_extractor.eval()

    @classmethod
    def from_config(cls, cfg):
        """
        Translate a config node into the keyword arguments of ``__init__``.

        Reads ``cfg.FREEZE`` and ``cfg.PRETRAINED_CKPT_PATH``.
        """
        return {
            "freeze": cfg.FREEZE,
            "pretrained_ckpt_path": cfg.PRETRAINED_CKPT_PATH
        }

    def train(self, mode=True):
        """
        Set the training mode, keeping a frozen backbone in eval mode.

        Replaces the previous approach of calling ``eval()`` inside
        ``forward``, which changed the module's mode as a side effect of
        running inference.
        """
        super(DinoVITS8, self).train(mode)
        if self.freeze:
            self.feature_extractor.eval()
        return self

    def forward(self, input):
        """
        Forward pass

        Parameters
        ----------
        input : torch.Tensor
            Image batch of shape (B, C, H, W).

        Returns
        -------
        torch.Tensor
            Feature map of shape (B, D, H // 8, W // 8), where D is the
            token dimension of the backbone (384 in the example run of
            this notebook).
        """
        _, _, h, w = input.shape

        # Output of the last block: (B, 1 + num_patches, D)
        tokens = self.feature_extractor.get_intermediate_layers(input, n=1)[0]

        # remove cls token
        tokens = tokens[:, 1:, :]
        b, _, d = tokens.shape

        # Integer division: the patch grid has whole rows and columns only.
        grid_h, grid_w = h // self.PATCH_SIZE, w // self.PATCH_SIZE

        # (B, N, D) -> (B, D, N) -> (B, D, grid_h, grid_w)
        return tokens.permute(0, 2, 1).reshape(b, d, grid_h, grid_w)

In [145]:
encoder = DinoVITS8(
    freeze=True, 
    pretrained_ckpt_path="../../dino_vit_pretrained/dino_deitsmall8_pretrain.pth"
)

Using cache found in /home/goswami.p/.cache/torch/hub/facebookresearch_dino_main


Loaded Dino ViT S/8 pretrained checkpoint from ../../dino_vit_pretrained/dino_deitsmall8_pretrain.pth



In [152]:
img = torch.randn(2,3,368,496)

In [153]:
feat_map = encoder(img)
feat_map.shape

torch.Size([2, 384, 46, 62])

___
## HuggingFace DINO ViT S/8

In [70]:
from transformers import ViTModel

In [107]:
feature_extractor = ViTModel.from_pretrained("facebook/dino-vits8", add_pooling_layer=False)

In [103]:
count_params(feature_extractor)

'21.818112M params'

In [104]:
img = torch.randn(2,3,384, 512)

In [105]:
feats = feature_extractor(img, interpolate_pos_encoding=True).last_hidden_state

In [106]:
feats.shape

torch.Size([2, 3073, 384])

___

## Remove CLS Token 

In [133]:
patch_embed = nn.Conv2d(3, 384, kernel_size=(8, 8), stride=(8, 8))

In [85]:
embeddings = patch_embed(img)
embeddings.shape

torch.Size([2, 384, 48, 64])

In [81]:
cls_token = nn.Parameter(
            nn.init.trunc_normal_(torch.zeros(1, 1, 384), mean=0.0, std=0.02)
        )

In [83]:
cls_token.shape

torch.Size([1, 1, 384])

In [84]:
cls_tokens = cls_token.expand(2, -1, -1)
cls_tokens.shape

torch.Size([2, 1, 384])

In [88]:
# embeddings is (B, C, H/8, W/8) = (2, 384, 48, 64). Move channels last while
# keeping rows before columns, so the flatten in the next cell orders tokens
# row by row -- the layout DinoVITS8.forward assumes when it reshapes tokens
# back to (b, c, h, w). permute(0, 3, 2, 1) would order them column by column.
embeddings = embeddings.permute(0, 2, 3, 1)
embeddings.shape

torch.Size([2, 48, 64, 384])

In [90]:
embeddings = embeddings.reshape(2, 64*48, 384)
embeddings.shape

torch.Size([2, 3072, 384])

In [91]:
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
embeddings.shape

torch.Size([2, 3073, 384])

In [119]:
x = torch.tensor([[[1,1,1],[2,2,2]],[[3,3,3],[4,4,4]]])
x.shape

torch.Size([2, 2, 3])

In [120]:
x

tensor([[[1, 1, 1],
         [2, 2, 2]],

        [[3, 3, 3],
         [4, 4, 4]]])

In [121]:
c = torch.zeros(2,1,3)
c

tensor([[[0., 0., 0.]],

        [[0., 0., 0.]]])

In [122]:
cx = torch.cat((c,x), dim=1)
cx

tensor([[[0., 0., 0.],
         [1., 1., 1.],
         [2., 2., 2.]],

        [[0., 0., 0.],
         [3., 3., 3.],
         [4., 4., 4.]]])

In [123]:
cx.shape

torch.Size([2, 3, 3])

In [127]:
_cx = cx[:,1:,:]
_cx.shape

torch.Size([2, 2, 3])

In [128]:
_cx

tensor([[[1., 1., 1.],
         [2., 2., 2.]],

        [[3., 3., 3.],
         [4., 4., 4.]]])

In [130]:
_cx.shape == x.shape

True