In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models

In [None]:
class ImageEncoder(nn.Module):
    """CNN backbone + 1x1 projection + 2-D sine positional encoding.

    Args:
        backbone_name: name of a torchvision model constructor (e.g.
            ``'resnet50'``). The model must expose ``.fc`` and have
            [avgpool, fc] as its last two children, as torchvision ResNets do.
        pretrained: forwarded to the torchvision constructor.
        hidden_dim: number of output channels after the 1x1 projection.

    ``forward(x)`` returns ``conv1x1(backbone(x)) + positional_encoding``,
    shaped (batch, hidden_dim, h, w) where (h, w) is the backbone's output
    feature-map size.
    """

    def __init__(self, backbone_name='resnet50', pretrained=True, hidden_dim=256):
        super().__init__()

        # NOTE(review): newer torchvision releases prefer ``weights=`` and warn
        # on the legacy ``pretrained=`` flag.
        backbone = getattr(models, backbone_name)(pretrained=pretrained)
        # Drop the last two children (avgpool + fc for ResNets) to keep the
        # spatial feature map.
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])

        self.hidden_dim = hidden_dim
        # Project backbone channels (fc.in_features) down to hidden_dim.
        self.conv1x1 = nn.Conv2d(backbone.fc.in_features, hidden_dim, kernel_size=1)

        # Cached 32x32 encoding. A non-persistent buffer follows .to(device)
        # but stays out of state_dict, matching the old plain attribute.
        self.register_buffer('positional_encoding', self._get_positional_encoding(), persistent=False)

    def forward(self, x):
        features = self.conv1x1(self.backbone(x))

        pos = self.positional_encoding
        # The cache only fits one feature-map size; rebuild for any other
        # input resolution instead of failing to broadcast.
        if pos.shape[-2:] != features.shape[-2:]:
            pos = self._get_positional_encoding(
                features.shape[-2], features.shape[-1], device=features.device
            )
        return features + pos.to(dtype=features.dtype)

    @staticmethod
    def _sine_encoding_1d(positions, channels):
        """Sine/cosine encoding of 1-D positions, shaped (channels, len(positions)).

        Even channels hold sin(pos * w_k) and odd channels hold cos(pos * w_k),
        where w_k = 10000 ** (-2k / channels). Odd ``channels`` is handled by
        giving the extra channel to the sine half.
        """
        enc = positions.new_zeros(channels, positions.numel())
        exponents = torch.arange(0, channels, 2, dtype=positions.dtype, device=positions.device)
        freqs = torch.pow(10000.0, -exponents / channels)
        angles = freqs.unsqueeze(1) * positions.unsqueeze(0)  # (ceil(c/2), L)
        enc[0::2] = torch.sin(angles)
        enc[1::2] = torch.cos(angles[: channels // 2])
        return enc

    def _get_positional_encoding(self, height=32, width=32, device=None):
        """Build a 2-D sine positional encoding of shape (1, hidden_dim, height, width).

        The first ``hidden_dim // 2`` channels encode the row (y) index and
        vary only along height. The remaining channels encode the column (x)
        index and vary only along width. The result is float32.
        """
        y_channels = self.hidden_dim // 2
        x_channels = self.hidden_dim - y_channels

        rows = torch.arange(height, dtype=torch.float32, device=device)
        cols = torch.arange(width, dtype=torch.float32, device=device)
        pe_y = self._sine_encoding_1d(rows, y_channels)  # (y_channels, H)
        pe_x = self._sine_encoding_1d(cols, x_channels)  # (x_channels, W)

        pe = torch.cat(
            [
                pe_y.unsqueeze(2).expand(-1, -1, width),
                pe_x.unsqueeze(1).expand(-1, height, -1),
            ],
            dim=0,
        )
        return pe.unsqueeze(0)

In [None]:
class TransformerDecoderLayer(nn.Module):
    """Single post-norm Transformer decoder layer: self-attn -> cross-attn -> FFN.

    Each sub-layer is followed by a residual add and a LayerNorm.

    Args:
        hidden_dim: model width (embedding size of ``tgt`` and ``memory``).
        nhead: number of attention heads.
        dim_feedforward: hidden size of the two-layer ReLU feed-forward block.
        batch_first: layout of ``tgt``/``memory``. With the default ``False``
            they are (seq, batch, hidden_dim), which is ``nn.MultiheadAttention``'s
            default. With ``True`` they are (batch, seq, hidden_dim).
    """

    def __init__(self, hidden_dim=256, nhead=8, dim_feedforward=2048, batch_first=False):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=nhead, batch_first=batch_first)
        self.multihead_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=nhead, batch_first=batch_first)

        self.linear1 = nn.Linear(hidden_dim, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, hidden_dim)

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """Decode ``tgt`` against ``memory``.

        ``tgt_mask`` and ``memory_mask`` are passed as ``attn_mask`` to the
        self- and cross-attention respectively. Returns a tensor shaped like
        ``tgt``.
        """
        attn_out = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)[0]
        tgt = self.norm1(tgt + attn_out)

        attn_out = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask)[0]
        tgt = self.norm2(tgt + attn_out)

        # nn.functional is reached through the already-imported ``nn``; the
        # notebook never imports ``F``.
        ffn_out = self.linear2(nn.functional.relu(self.linear1(tgt)))
        return self.norm3(tgt + ffn_out)

In [None]:
class TransformerModel(nn.Module):
    """DETR-style head: encode a feature map, then decode learned object queries.

    Args:
        hidden_dim: channel count of the input feature map and model width.
        num_queries: number of learned object queries (predictions per image).
        num_classes: size of the per-query classification output.
        nhead: attention heads in the encoder and decoder.
        num_layers: number of encoder layers. The decoder is a single layer.
        dim_feedforward: feed-forward width in the encoder and decoder.

    ``forward(x)`` takes a feature map (batch, hidden_dim, height, width) and
    returns ``(class_logits, bbox_preds)``, shaped (batch, num_queries,
    num_classes) and (batch, num_queries, 4). Both are raw linear outputs;
    no softmax or sigmoid is applied here.
    """

    def __init__(self, hidden_dim=256, num_queries=100, num_classes=91, nhead=8, num_layers=6, dim_feedforward=2048):
        super().__init__()

        self.num_queries = num_queries
        self.detr_output_dim = hidden_dim

        # batch_first=True: forward() feeds the encoder (batch, H*W, C) tensors.
        # The flag adds no parameters, so state_dict keys are unchanged.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True
            ),
            num_layers=num_layers,
        )

        # NOTE(review): only one decoder layer is used; DETR normally stacks several.
        self.decoder_layer = TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward)

        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = nn.Linear(hidden_dim, 4)

    def forward(self, x):
        batch_size = x.shape[0]

        # (B, C, H, W) -> (B, H*W, C): one token per spatial location.
        memory = self.encoder(x.flatten(2).transpose(1, 2))

        # The decoder layer is built with its default sequence-first layout,
        # so queries are (Q, B, C) and memory is transposed to (H*W, B, C).
        queries = self.query_embed.weight.unsqueeze(1).repeat(1, batch_size, 1)
        output = self.decoder_layer(queries, memory.transpose(0, 1))

        output = output.transpose(0, 1)  # back to (B, Q, C)
        class_logits = self.class_embed(output)
        bbox_preds = self.bbox_embed(output)

        return class_logits, bbox_preds

In [None]:
class DETR(nn.Module):
    """End-to-end detector: ``ImageEncoder`` followed by ``TransformerModel``.

    Args:
        num_queries: number of object queries (predictions per image).
        num_classes: size of the per-query classification output.
        hidden_dim: width shared by the encoder's 1x1 projection and the transformer.
        nhead, num_layers, dim_feedforward: forwarded to ``TransformerModel``.
        pretrained: forwarded to ``ImageEncoder``. Set to False to skip loading
            pretrained backbone weights, e.g. offline or in tests.

    ``forward(x)`` returns ``(class_logits, bbox_preds)`` from the transformer
    head. ``x`` is presumably a batch of 3-channel images as the torchvision
    backbone expects; confirm against the data pipeline.
    """

    def __init__(self, num_queries=100, num_classes=91, hidden_dim=256, nhead=8, num_layers=6, dim_feedforward=2048,
                 pretrained=True):
        super().__init__()

        self.backbone = ImageEncoder(hidden_dim=hidden_dim, pretrained=pretrained)
        self.transformer = TransformerModel(
            hidden_dim=hidden_dim,
            num_queries=num_queries,
            num_classes=num_classes,
            nhead=nhead,
            num_layers=num_layers,
            dim_feedforward=dim_feedforward,
        )

    def forward(self, x):
        features = self.backbone(x)
        class_logits, bbox_preds = self.transformer(features)
        return class_logits, bbox_preds