## Import libraries

In [78]:
from transformers import AutoProcessor, Swinv2Model,Swinv2Config
import torch
from datasets import load_dataset
from torch import nn
import math

## Build Model

In [79]:
# Default-constructed Swin V2 configuration; SwinDetr builds its backbone
# from this object.
CONFIG = Swinv2Config()
# Hub identifier of a Swin V2 checkpoint.
# NOTE(review): not referenced by any of the visible cells -- the backbone is
# constructed from CONFIG rather than loaded from this checkpoint; confirm
# whether loading pretrained weights was intended.
PRE_TRAINED_MODEL = "microsoft/swinv2-tiny-patch4-window8-256"

In [114]:
class SwinDetr(nn.Module):
    """DETR-style detection head on top of a Swin Transformer V2 backbone.

    The backbone turns an image batch into a token sequence of shape
    ``(batch, num_pos_feats, hidden_dim)`` (``[1, 64, 768]`` for a single
    image in the accompanying notebook).  That sequence, offset by a learned
    positional embedding, is the encoder input of an ``nn.Transformer``; a
    learned set of ``num_pos_feats`` queries is the decoder input.  Each
    decoded query is mapped to class logits and a box.

    Args:
        num_classes: Number of object classes.  The classifier emits
            ``num_classes + 1`` logits per query.
        n_batches: Leading dimension of the learned positional/query
            parameters.  With ``1`` (the default) the parameters are shared
            by every sample and any batch size is accepted.  Any other value
            must equal the batch size passed to ``forward``.
        num_pos_feats: Number of backbone tokens, which is also the number
            of decoder queries.
        hidden_dim: Width of the backbone tokens and of the transformer.
        nheads: Number of attention heads in the transformer.
        num_encoder_layers: Number of transformer encoder layers.
        num_decoder_layers: Number of transformer decoder layers.
    """

    def __init__(self, num_classes, n_batches=1, num_pos_feats=64, hidden_dim=768,
                 nheads=8, num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()
        # Learned positional embedding added to the backbone tokens.
        self.pos_embedding = nn.Parameter(
            torch.randn(n_batches, num_pos_feats, hidden_dim)
        )
        # Multiplier applied to the backbone tokens before the positional
        # embedding is added.
        self.scale = 2 * math.pi
        # Only ``last_hidden_state`` is consumed, so the pooling layer is
        # disabled once here instead of being overwritten on every forward.
        self.backbone = Swinv2Model(CONFIG, add_pooling_layer=False)

        # Transformer block.  Every tensor in this model is laid out as
        # (batch, sequence, feature), so batch_first must be True; with the
        # default (False) attention would mix different images of a batch.
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=nheads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            batch_first=True,
        )

        # Prediction heads: class logits and 4 box values per query.
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # Learned decoder queries.
        self.query_pos = nn.Parameter(
            torch.rand(n_batches, num_pos_feats, hidden_dim)
        )

    @staticmethod
    def _for_batch(param, batch_size):
        """Return ``param`` with a leading dimension of ``batch_size``.

        A parameter stored with leading dimension 1 is expanded (as a view,
        without copying); one that already matches is returned unchanged.

        Raises:
            ValueError: If the stored leading dimension is neither 1 nor
                ``batch_size``.
        """
        stored = param.shape[0]
        if stored == batch_size:
            return param
        if stored == 1:
            return param.expand(batch_size, -1, -1)
        raise ValueError(
            f"model was built with n_batches={stored} but received a batch "
            f"of {batch_size}; build it with n_batches=1 to accept any batch size"
        )

    def forward(self, inputs: torch.Tensor) -> dict:
        """Run detection on a batch of images.

        Args:
            inputs: Image batch passed straight to the Swin V2 backbone.

        Returns:
            A dict with ``'pred_logits'`` (``num_classes + 1`` logits per
            query) and ``'pred_boxes'`` (4 values per query, squashed into
            ``(0, 1)`` by a sigmoid).

        Raises:
            ValueError: If the model was built with ``n_batches`` other than
                1 and the batch size of ``inputs`` differs from it.
        """
        features = self.backbone(inputs).last_hidden_state
        batch_size = features.shape[0]

        src = self._for_batch(self.pos_embedding, batch_size) + self.scale * features
        # nn.Transformer requires src and tgt to share the batch dimension.
        tgt = self._for_batch(self.query_pos, batch_size)
        decoded = self.transformer(src, tgt)

        return {'pred_logits': self.linear_class(decoded),
                'pred_boxes': self.linear_bbox(decoded).sigmoid()}



In [None]:
# Smoke-test the model on a random batch of ten 3x256x256 images.
x = torch.randn([10, 3, 256, 256])
n_batches = x.shape[0]
model = SwinDetr(num_classes=2, n_batches=n_batches)
# Put the model in evaluation mode before inference: nn.Transformer uses
# dropout by default, which would otherwise stay active and make the
# predictions below noisy and non-reproducible.
model.eval()
with torch.no_grad():
    outputs = model(x)
print(outputs)
# Author's shape note: 10,64,768 -> 10 64 1 768

In [128]:
# Per-query class probabilities for the first sample of the batch, with the
# last logit column dropped (presumably the extra "no object" class that the
# classifier adds on top of num_classes -- confirm).
probas = torch.softmax(outputs['pred_logits'], dim=-1)[0, :, :-1]
# Boolean mask of the queries whose highest class probability exceeds 0.5.
keep = probas.max(dim=-1)[0] > 0.5

In [131]:
# Show the shape of the classification logits.
print(outputs['pred_logits'].size())

torch.Size([10, 64, 3])


## Run model

In [132]:
# Custom loss and training the model.
# x holds three rows of four values and y a single row of four; subtracting
# broadcasts y across every row of x, so z has the same shape as x.
x = torch.randn(3, 4)
y = torch.randn(1, 4)
z = torch.sub(x, y)

In [134]:
# Log-softmax over each row of z.  The dimension is passed explicitly:
# omitting it is deprecated and emits a warning, and for a 2-D input the
# implicit choice is dim=1, so the result is unchanged.
nn.functional.log_softmax(z, dim=1)

  nn.functional.log_softmax(z)


tensor([[-0.7947, -3.2326, -3.6134, -0.7300],
        [-0.3149, -2.6439, -4.5132, -1.6706],
        [-0.2384, -4.1580, -3.8847, -1.7379]])