<a href="https://colab.research.google.com/github/xiaofangZH/Segment-Anything-2-AGV-Test/blob/main/Detection_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import nn
from torchvision.models import resnet50

In [2]:
class DetectionTransformer(nn.Module):
  def __init__(self,num_classes,hidden_dim,nheads,num_encoder_layers,num_decoder_layers):
    super().__init__()
    self.backbone=nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
    self.conv=nn.Conv2d(2048,hidden_dim,1)
    self.transformer=nn.Transformer(hidden_dim,nheads,num_encoder_layers,num_decoder_layers)
    self.linear_classes=nn.Linear(hidden_dim,num_classes+1)
    self.linear_bbox=nn.Linear(hidden_dim,4)
    self.query_pos=nn.Parameter(torch.rand(100,hidden_dim))
    self.row_embed=nn.Parameter(torch.rand(50,hidden_dim//2))
    self.col_embed=nn.Parameter(torch.rand(50,hidden_dim//2))
  def forward(self,inputs):
    x=self.backbone(inputs)
    h=self.conv(x)
    H,W=h.shape[-2:]
    pos=torch.cat([self.col_embed[:W].unsqueeze(0).repeat(H,1,1),self.row_embed[:H].unsqueeze(1).repeat(1,W,1),],dim=-1).flatten(0,1).unsqueeze(1)
    h=self.transformer(pos+h.flatten(2).permute(2,0,1),self.query_pos.unsqueeze(1))
    return self.linear_classes(h),self.linear_bbox(h).sigmoid()

In [3]:
detr=DetectionTransformer(num_classes=91,hidden_dim=256,nheads=8,num_encoder_layers=6,num_decoder_layers=6)
detr.eval()
inputs=torch.randn(1,3,800,1200)
logits,bboxes=detr(inputs)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 188MB/s]


In [4]:
logits

tensor([[[ 1.0615,  0.0071,  0.8348,  ...,  0.8110, -0.2461, -0.3349]],

        [[ 1.0017,  0.0475,  0.8626,  ...,  0.7535, -0.2110, -0.3625]],

        [[ 1.0322, -0.0029,  0.9326,  ...,  0.7583, -0.2704, -0.3947]],

        ...,

        [[ 1.0033,  0.0952,  0.8987,  ...,  0.7963, -0.2344, -0.2576]],

        [[ 1.0073,  0.0876,  0.8287,  ...,  0.7513, -0.1801, -0.2971]],

        [[ 1.0857,  0.0343,  0.9029,  ...,  0.7263, -0.3136, -0.2596]]],
       grad_fn=<ViewBackward0>)

In [5]:
bboxes

tensor([[[0.6727, 0.6039, 0.3110, 0.4066]],

        [[0.6688, 0.6089, 0.3133, 0.4052]],

        [[0.6791, 0.5699, 0.3051, 0.3970]],

        [[0.6924, 0.5963, 0.3181, 0.4136]],

        [[0.7000, 0.5963, 0.3255, 0.4049]],

        [[0.6758, 0.6039, 0.3074, 0.4032]],

        [[0.6757, 0.5878, 0.3112, 0.4013]],

        [[0.6985, 0.5867, 0.3062, 0.3822]],

        [[0.6830, 0.5809, 0.3213, 0.4113]],

        [[0.6846, 0.5848, 0.3157, 0.3862]],

        [[0.6861, 0.5923, 0.2921, 0.3845]],

        [[0.6788, 0.6046, 0.3151, 0.4045]],

        [[0.6891, 0.6084, 0.3145, 0.3902]],

        [[0.6829, 0.6077, 0.3183, 0.4083]],

        [[0.6608, 0.5787, 0.3266, 0.3952]],

        [[0.6738, 0.6100, 0.3154, 0.4061]],

        [[0.6931, 0.5890, 0.3246, 0.4186]],

        [[0.6859, 0.5845, 0.3046, 0.4097]],

        [[0.6725, 0.5881, 0.3290, 0.3964]],

        [[0.6852, 0.5915, 0.3087, 0.3902]],

        [[0.6932, 0.5946, 0.3192, 0.3865]],

        [[0.6795, 0.5659, 0.3090, 0.3845]],

        [[