# In [30]:
import timm
import torch
import torchvision
from timm.layers import resample_abs_pos_embed     

class Detector(torch.nn.Module):
    """Object detector that appends learnable detection tokens to a timm ViT.

    A fixed number of detection ("det") tokens is inserted between the
    backbone's own prefix tokens (class / register tokens) and the patch
    tokens.  After the transformer blocks, each det token is decoded by two
    MLP heads into class logits and a 4-value box.

    Args:
        num_classes: Number of object classes; the classification head emits
            ``num_classes + 1`` logits per det token.
        det_token_num: Number of detection tokens (predictions per image).
    """

    def __init__(self, num_classes, det_token_num=100):
        super().__init__()
        self.backbone = timm.create_model('vit_small_patch14_dinov2', pretrained=True, dynamic_img_size=True)
        self.det_token_num = det_token_num
        self.add_det_tokens()

        # Size the heads from the backbone instead of hard-coding the width, so
        # they always agree with the det tokens created in add_det_tokens().
        embed_dim = self.backbone.embed_dim
        self.class_embed = torchvision.ops.MLP(embed_dim, [embed_dim, embed_dim, num_classes + 1])
        self.bbox_embed = torchvision.ops.MLP(embed_dim, [embed_dim, embed_dim, 4])

    def add_det_tokens(self):
        """Create the det tokens and a position embedding that has room for them.

        Registers ``self.det_token`` and ``self.pos_embed`` as parameters and
        increases ``self.backbone.num_prefix_tokens`` by ``det_token_num`` so
        that the det tokens are treated as prefix (non-patch) tokens when the
        position embedding is resampled.
        """
        embed_dim = self.backbone.embed_dim

        self.det_token = torch.nn.Parameter(torch.zeros(1, self.det_token_num, embed_dim))
        torch.nn.init.trunc_normal_(self.det_token, std=.02)

        det_pos_embed = torch.zeros(1, self.det_token_num, embed_dim)
        det_pos_embed = torch.nn.init.trunc_normal_(det_pos_embed, std=.02)

        # Split the backbone position embedding at the prefix/patch boundary so
        # the new layout [prefix | det | patch] matches the token order built
        # in _pos_embed_with_det().
        # NOTE(review): this assumes pos_embed carries one entry per prefix
        # token (i.e. backbone.no_embed_class is False) — confirm before using
        # a backbone that sets no_embed_class=True.
        num_prefix = self.backbone.num_prefix_tokens
        prefix_pos_embed = self.backbone.pos_embed[:, :num_prefix, :]  # 1 x num_prefix x embed_dim
        patch_pos_embed = self.backbone.pos_embed[:, num_prefix:, :]  # 1 x num_patch x embed_dim
        self.pos_embed = torch.nn.Parameter(
            torch.cat((prefix_pos_embed, det_pos_embed, patch_pos_embed), dim=1)
        )

        self.backbone.num_prefix_tokens += self.det_token_num

    def _pos_embed_with_det(self, x: torch.Tensor) -> torch.Tensor:
        """Prepend prefix and det tokens to the patch tokens and add positions.

        Args:
            x: Output of ``backbone.patch_embed``.  When
                ``backbone.dynamic_img_size`` is set it is unpacked as
                ``(B, H, W, C)`` and flattened to ``(B, H*W, C)``.

        Returns:
            The token sequence ordered as class token (if any), register
            tokens (if any), det tokens, patch tokens, after position
            embedding and ``backbone.pos_drop``.
        """
        if self.backbone.dynamic_img_size:
            B, H, W, C = x.shape
            # Resample only the patch part of the embedding to the current
            # grid; the leading prefix entries are passed through unchanged.
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                (H, W),
                num_prefix_tokens=0 if self.backbone.no_embed_class else self.backbone.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
        else:
            pos_embed = self.pos_embed

        batch_size = x.shape[0]
        to_cat = []
        if self.backbone.cls_token is not None:
            to_cat.append(self.backbone.cls_token.expand(batch_size, -1, -1))
        if self.backbone.reg_token is not None:
            to_cat.append(self.backbone.reg_token.expand(batch_size, -1, -1))
        # Det tokens always go last among the prefix tokens; forward() relies
        # on this ordering to slice them back out.
        to_cat.append(self.det_token.expand(batch_size, -1, -1))

        if self.backbone.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            x = torch.cat(to_cat + [x], dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed

        return self.backbone.pos_drop(x)

    def backbone_forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone on ``x`` with det tokens inserted.

        Applies the backbone's patch embedding, the det-aware position
        embedding, the pre-norm / dropout layers, the transformer blocks and
        the final norm, and returns the full normalised token sequence.
        """
        x = self.backbone.patch_embed(x)
        x = self._pos_embed_with_det(x)
        x = self.backbone.patch_drop(x)
        x = self.backbone.norm_pre(x)
        if self.backbone.grad_checkpointing and not torch.jit.is_scripting():
            # Referenced through the already-imported timm package so this
            # branch no longer depends on a separately imported name.
            x = timm.models.checkpoint_seq(self.backbone.blocks, x)
        else:
            x = self.backbone.blocks(x)
        x = self.backbone.norm(x)
        return x

    def forward(self, x):
        """Predict class logits and boxes for every det token.

        Args:
            x: Batch of images accepted by ``backbone.patch_embed``.

        Returns:
            A dict with ``'pred_logits'`` of shape
            ``(B, det_token_num, num_classes + 1)`` and ``'pred_boxes'`` of
            shape ``(B, det_token_num, 4)``; box values pass through a sigmoid.
        """
        x = self.backbone_forward_features(x)

        # Det tokens are the last det_token_num prefix tokens, so locate them
        # from the prefix count instead of assuming exactly one token precedes
        # them.
        det_end = self.backbone.num_prefix_tokens
        det_start = det_end - self.det_token_num
        x = x[:, det_start:det_end, ...]

        outputs_class = self.class_embed(x)
        outputs_coord = self.bbox_embed(x).sigmoid()
        out = {'pred_logits': outputs_class, 'pred_boxes': outputs_coord}
        return out
    
# Smoke test: push a random batch of two 224x224 RGB images through the
# detector and report the keys and tensor shapes of the prediction dict.
detector = Detector(num_classes=10, det_token_num=100)
x = torch.rand((2, 3, 224, 224))
out = detector(x)
print(
    f"{out.keys() = }",
    f"{[v.shape for v in out.values()] = }",
    sep="\n",
)

# out.keys() = dict_keys(['pred_logits', 'pred_boxes'])
# [v.shape for v in out.values()] = [torch.Size([2, 100, 11]), torch.Size([2, 100, 4])]
