In [None]:
# @title Notes
# https://miro.com/app/board/uXjVGekFzXg=/

In [None]:
# @title Imports
import torch.nn as nn
import torch
import torchvision
import numpy as np
from torchvision.ops.boxes import box_area
from scipy.optimize import linear_sum_assignment

In [None]:
# @title Global variables
# Placeholder hyper-parameters: -1 means "not set yet". No cell visible in this
# part of the notebook reads them.
batch_size=-1
embed_dim=-1
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device='cuda'
else:
  device='cpu'

In [None]:
# @title CNN Backbone(resnet 50)
class CNNBackbone(nn.Module):
  """Frozen ResNet-50 feature extractor.

  Loads torchvision's pretrained ResNet-50, disables gradients on every
  parameter and keeps all child modules except the last two, so the forward
  pass returns a spatial feature map instead of class scores.
  """
  def __init__(self):
    super().__init__()
    resnet=torchvision.models.resnet50(pretrained=True)
    # Freeze the backbone: none of its weights receive gradients.
    for weight in resnet.parameters():
      weight.requires_grad_(False)
    # NOTE(review): freezing parameters does not stop BatchNorm layers from
    # updating their running statistics while the module is in train mode -
    # confirm whether the backbone should be kept in eval().
    feature_layers=list(resnet.children())[:-2]
    self.model=nn.Sequential(*feature_layers)

  def forward(self,x):
    """Run `x` through the truncated ResNet-50 and return its output."""
    features=self.model(x)
    return features

In [None]:
# Smoke test: one random 3x224x224 image through the backbone; the cell
# displays the shape of the resulting feature map.
demo_img=torch.rand(1,3,224,224)
backbone=CNNBackbone()
backbone(demo_img).shape



Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 184MB/s]


torch.Size([1, 2048, 7, 7])

In [None]:
# @title Positional encoding
def positional_encoding(dim,no_of_patches):
  """Build a table of fixed sinusoidal positional encodings.

  Follows the formulation of "Attention Is All You Need": columns come in
  (sin, cos) pairs that share one frequency. With k = column // 2:

      PE[pos, 2k]   = sin(pos / 10000**(2k/dim))
      PE[pos, 2k+1] = cos(pos / 10000**(2k/dim))

  Args:
    dim: length of each encoding vector (number of columns).
    no_of_patches: number of positions to encode (number of rows).

  Returns:
    float32 tensor of shape (no_of_patches, dim).
  """
  column=torch.arange(dim)
  # column - column%2 == 2*(column//2): both columns of a pair get the same
  # exponent, and the exponent stays below 1 instead of running up to 2.
  exponent=(column-column%2).to(torch.float32)/dim
  denominator=10000.0**exponent                       # (dim,), always >= 1
  position=torch.arange(no_of_patches,dtype=torch.float32).unsqueeze(1)
  angle=position/denominator.unsqueeze(0)             # (no_of_patches, dim)
  # Even columns carry the sine, odd columns the cosine.
  return torch.where(column%2==0,torch.sin(angle),torch.cos(angle))

In [None]:
# Encoding table for 49 positions with 2048 features each; the cell displays
# its shape.
pe=positional_encoding(dim=2048,no_of_patches=49)
pe.shape

  ang=[np.sin(a) if i%2==0 else np.cos(a) for i,a in enumerate(ang)]


torch.Size([49, 2048])

In [None]:
# @title EncoderLayer
class EncoderLayer(nn.Module):
  """One post-norm transformer encoder layer with additive positional encodings.

  Self-attention uses `x + positional encoding` for queries and keys and the
  raw `x` for values. A residual connection followed by LayerNorm is applied
  after the attention and again after the feed-forward network.

  Args:
    embed_dim: feature size of every token.
    n: number of tokens per sample (H*W of the feature map).
    device: device on which the positional-encoding table is first placed.
  """
  # here n is H*W
  def __init__(self,embed_dim,n,device):
    super().__init__()
    # Registered as a buffer so that `module.to(...)` moves it together with
    # the parameters; non-persistent so the state_dict keys stay unchanged.
    self.register_buffer('pos_encoding',positional_encoding(embed_dim,n).to(device),persistent=False)
    self.layer_norm1=nn.LayerNorm(embed_dim)
    self.layer_norm2=nn.LayerNorm(embed_dim)
    self.multihead_attention=nn.MultiheadAttention(embed_dim,num_heads=8,batch_first=True)
    self.ffn=nn.Sequential(nn.Linear(embed_dim,embed_dim*4),nn.ReLU(),nn.Linear(embed_dim*4,embed_dim))


  def forward(self,x):
    """Apply self-attention and the feed-forward block to `x`.

    `x` is added to the (n, embed_dim) positional table, so its last two
    dimensions must be (n, embed_dim). The result has the same shape as `x`.
    """
    # Follow the input's device/dtype: the table may have been created on a
    # different device than the one the caller is feeding data from.
    pos=self.pos_encoding.to(device=x.device,dtype=x.dtype).unsqueeze(0)
    q=x+pos
    k=x+pos
    v=x
    out=self.multihead_attention(q,k,v)[0]
    out=self.layer_norm1(out+x)
    out1=self.ffn(out)
    out1=self.layer_norm2(out+out1)
    return out1

In [None]:
# Smoke test for a single encoder layer. The layer and its input are placed on
# the same device, so the cell also runs when CUDA is available. A dedicated
# name is used because `enc` is assigned to the full Encoder further down.
demo_enc_input=torch.rand((1,49,2048),device=device)
enc_layer=EncoderLayer(2048,49,device).to(device)
enc_layer(demo_enc_input).shape

  ang=[np.sin(a) if i%2==0 else np.cos(a) for i,a in enumerate(ang)]


torch.Size([1, 49, 2048])

In [None]:
class Encoder(nn.Module):
  """1x1 convolution followed by a stack of `EncoderLayer`s.

  The convolution maps `input_channel` channels to `embed_dim`; the feature
  map is then flattened into a sequence of H*W tokens and passed through
  `enc_layers` encoder layers. The layers take the notebook-level `device`.
  """
  # n is H*W
  def __init__(self,input_channel,embed_dim,enc_layers,n):
    super().__init__()
    self.base_conv=nn.Conv2d(input_channel,embed_dim,kernel_size=1)
    stack=[]
    for _ in range(enc_layers):
      stack.append(EncoderLayer(embed_dim,n,device))
    self.encoder_layers=nn.ModuleList(stack)

  def forward(self,x):
    """Encode a (B, C, H, W) feature map into (B, H*W, embed_dim) tokens."""
    projected=self.base_conv(x)
    batch,channels,height,width=projected.shape
    # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
    tokens=projected.reshape(batch,channels,height*width).transpose(1,2)
    for block in self.encoder_layers:
      tokens=block(tokens)
    return tokens

In [None]:
# End-to-end smoke test: backbone features -> encoder. The encoder layers keep
# their positional table on `device`, so the encoder and the features are
# moved there as well; otherwise the cell fails when CUDA is available.
enc_inp=backbone(demo_img)
enc=Encoder(2048,768,12,49).to(device)
enc_out=enc(enc_inp.to(device))
enc_out.shape

  ang=[np.sin(a) if i%2==0 else np.cos(a) for i,a in enumerate(ang)]


torch.Size([1, 49, 768])

In [None]:
# @title DecoderLayer -> not completed will be done after the vizura video because it is so confusing
class DecoderLayer(nn.Module):
  """One post-norm DETR-style decoder layer.

  Three stages, each followed by a residual connection and LayerNorm:
    1. self-attention among the object queries,
    2. cross-attention from the object queries to the encoder output,
    3. a position-wise feed-forward network (4x expansion, ReLU).

  Positional information is added to queries and keys only, never to values:
  object queries get a learned per-slot embedding, encoder tokens get the
  fixed sinusoidal table.

  Args:
    obj_query_dim: feature size of every object query.
    no_of_obj_queries: number of object-query slots.
    n_heads: number of attention heads.
    n_dim: feature size of the encoder output tokens.
      NOTE(review): interpreted here as the encoder channel size - confirm.
    no_of_patches: number of encoder tokens (H*W).
  """
  def __init__(self,obj_query_dim,no_of_obj_queries,n_heads,n_dim,no_of_patches) :
    super().__init__()
    # The sinusoidal table is added to the encoder tokens, so it must have one
    # row per patch and n_dim columns (it used to be sized by the queries).
    self.register_buffer('pos_embedding',positional_encoding(n_dim,no_of_patches),persistent=False)
    # One learned vector per object-query slot, separately for the self- and
    # the cross-attention stage.
    self.pos_embedding_learnable=nn.Embedding(no_of_obj_queries,obj_query_dim)
    self.pos_embedding_learnable2=nn.Embedding(no_of_obj_queries,obj_query_dim)
    self.mha1=nn.MultiheadAttention(obj_query_dim,num_heads=n_heads,batch_first=True)
    # kdim/vdim let the encoder tokens have a different width than the queries.
    self.mha2=nn.MultiheadAttention(obj_query_dim,num_heads=n_heads,kdim=n_dim,vdim=n_dim,batch_first=True)
    self.layernorm1=nn.LayerNorm(obj_query_dim)
    self.layernorm2=nn.LayerNorm(obj_query_dim)
    self.layernorm3=nn.LayerNorm(obj_query_dim)
    # Slot indices used to look up the learned embeddings; a buffer so that it
    # moves with the module.
    self.register_buffer('embed_input1',torch.arange(no_of_obj_queries),persistent=False)
    self.mlp=nn.Sequential(nn.Linear(obj_query_dim,obj_query_dim*4),nn.ReLU(),nn.Linear(obj_query_dim*4,obj_query_dim))

  def forward(self,obj_queries,enc_output):
    """Refine the object queries using the encoder output.

    Args:
      obj_queries: tensor whose last two dims are (no_of_obj_queries, obj_query_dim).
      enc_output: tensor whose last two dims are (no_of_patches, n_dim).

    Returns:
      Tensor with the same shape as `obj_queries`.
    """
    # self attn
    self_pos=self.pos_embedding_learnable(self.embed_input1).unsqueeze(0)
    query=obj_queries+self_pos
    key=obj_queries+self_pos
    out=self.mha1(query,key,obj_queries)[0]
    out1_normed=self.layernorm1(out+obj_queries)

    # cross attn: the query is built from the self-attention output so that
    # stage 1 actually feeds stage 2.
    cross_pos=self.pos_embedding_learnable2(self.embed_input1).unsqueeze(0)
    query=out1_normed+cross_pos
    memory_pos=self.pos_embedding.to(device=enc_output.device,dtype=enc_output.dtype)
    key=enc_output+memory_pos.unsqueeze(0)
    out2=self.mha2(query,key,enc_output)[0]
    out2=out2+out1_normed

    out2_normed=self.layernorm2(out2)
    out2_mlp=self.mlp(out2_normed)
    out2_mlp=out2_normed+out2_mlp

    out2_mlp_normed=self.layernorm3(out2_mlp)
    return out2_mlp_normed

In [None]:
# demo_obj_query=torch.rand((1,10,768))
# dec_layer=DecoderLayer(768,10,8,768,49)
# dec_layer(demo_obj_query,enc_out).shape

In [None]:
# @title Decoder

In [None]:
# @title utility function
def box_cxcywh_to_xyxy(x):
  """Convert boxes from (cx, cy, w, h) to (x0, y0, x1, y1) along the last dim."""
  center_x,center_y,width,height=x.unbind(-1)
  half_w=0.5*width
  half_h=0.5*height
  corners=(center_x-half_w,center_y-half_h,center_x+half_w,center_y+half_h)
  return torch.stack(corners,dim=-1)

def box_xyxy_to_cxcywh(x):
  """Convert boxes from (x0, y0, x1, y1) to (cx, cy, w, h) along the last dim."""
  left,top,right,bottom=x.unbind(-1)
  center_x=(left+right)/2
  center_y=(top+bottom)/2
  width=right-left
  height=bottom-top
  return torch.stack((center_x,center_y,width,height),dim=-1)

def box_iou(boxes1, boxes2):
    """Pairwise IoU between two sets of boxes in [x0, y0, x1, y1] format.

    Returns a tuple `(iou, union)` of [N, M] matrices, where N = len(boxes1)
    and M = len(boxes2).
    """
    areas_first = box_area(boxes1)
    areas_second = box_area(boxes2)

    # Corners of the intersection rectangle for every (box1, box2) pair.
    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Non-overlapping pairs give negative extents; clamp them to zero.
    overlap = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    inter = overlap[..., 0] * overlap[..., 1]  # [N,M]

    union = areas_first[:, None] + areas_second - inter

    return inter / union, union


def generalized_box_iou(boxes1, boxes2):
    """
    Pairwise Generalized IoU (https://giou.stanford.edu/).

    Both inputs must hold boxes in [x0, y0, x1, y1] format.

    The result is an [N, M] matrix with N = len(boxes1) and
    M = len(boxes2).
    """
    # Degenerate boxes (x1 < x0 or y1 < y0) lead to inf / nan results,
    # so reject them up front.
    for boxes in (boxes1, boxes2):
        assert (boxes[:, 2:] >= boxes[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    # Smallest axis-aligned rectangle enclosing each (box1, box2) pair.
    enclose_lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclose_rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    enclose_wh = (enclose_rb - enclose_lt).clamp(min=0)  # [N,M,2]
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]

    # Penalise the part of the enclosing rectangle not covered by the union.
    penalty = (enclose_area - union) / enclose_area
    return iou - penalty

In [None]:
# @title Hungarian matcher
class hungarianMatcher(nn.Module):
  """Optimal one-to-one assignment between predictions and ground-truth boxes.

  The matching cost of a (prediction, target) pair is a weighted sum of the
  negative class probability, the L1 distance between the boxes and the
  negative generalized IoU.

  Args:
    cost_class: weight of the classification term.
    cost_bbox: weight of the L1 box term.
    cost_giou: weight of the generalized-IoU term.
    All weights default to 1, which gives the plain unweighted sum.
  """
  def __init__(self,cost_class=1.0,cost_bbox=1.0,cost_giou=1.0):
    super().__init__()
    self.cost_class=cost_class
    self.cost_bbox=cost_bbox
    self.cost_giou=cost_giou

  @torch.no_grad()
  def forward(self,outputs,targets):
    """ Performs the matching

        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates

            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates

            Boxes are passed through box_cxcywh_to_xyxy before the GIoU term, i.e. they are
            expected in (cx, cy, w, h) format.

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
    bs,num_queries=outputs['pred_logits'].shape[:2]
    out_prob=outputs['pred_logits'].flatten(0,1).softmax(-1)  # B*Q,C (c is no of classes)
    out_bbox=outputs['pred_boxes'].flatten(0,1)  # B*Q,4

    # Targets of the whole batch are concatenated; they are split per image
    # again below. The box key is "boxes" everywhere, matching the docstring.
    tgt_ids=torch.cat([tgt['labels'] for tgt in targets])  # sum(gt in each batch)
    tgt_bbox=torch.cat([tgt['boxes'] for tgt in targets]) # sum(gt in each batch),4

    # Higher probability of the target class -> lower cost.
    class_term=-out_prob[:,tgt_ids]  # B*Q,sum(gt in each batch)
    # l1 dist
    bbox_term=torch.cdist(out_bbox,tgt_bbox,p=1)  # B*Q,sum(gt in each batch)

    giou_term=-generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))

    C=self.cost_bbox*bbox_term+self.cost_class*class_term+self.cost_giou*giou_term
    C=C.view(bs,num_queries,-1).cpu()

    # Chunk i of the split holds the columns of image i's targets; only row i
    # of that chunk (image i's own predictions) is relevant.
    sizes=[len(v['boxes']) for  v in targets]
    indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
    return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
