In [39]:
from pathlib import Path

import pandas as pd
import torch
from torchvision.models import resnet50

import ST

# DETR (DEtection TRansformer) model

In [29]:
class DETR(torch.nn.Module):
    """Minimal DETR (DEtection TRansformer) object detector.

    Pipeline: ResNet-50 feature extractor (its last two children, avgpool and
    fc, are dropped) -> 1x1 conv projecting the 2048 backbone channels to
    ``hidden_dim`` -> ``torch.nn.Transformer`` whose decoder turns a fixed set
    of learned object queries into per-query features -> linear class and box
    heads.

    Parameter/attribute names are unchanged from the original, so existing
    state_dicts still load.

    Args:
        num_classes: Number of object classes. The class head emits
            ``num_classes + 1`` logits; the extra one is presumably the DETR
            "no object" class (confirm against the loss used for training).
        hidden_dim: Transformer width. Must be even, because the positional
            encoding concatenates two ``hidden_dim // 2`` halves (column, row).
        nheads: Number of attention heads (``torch.nn.Transformer`` requires
            it to divide ``hidden_dim``).
        num_encoder_layers: Number of transformer encoder layers.
        num_decoder_layers: Number of transformer decoder layers.
        num_queries: Number of learned object queries, i.e. predictions per
            image. Defaults to 100, the previously hard-coded value.
        max_positions: Largest feature-map height/width covered by the learned
            row/column embeddings. Defaults to 50, the previously hard-coded
            value.
        pretrained_backbone: Forwarded to ``resnet50(pretrained=...)``.
            Defaults to True as before; set False to skip the weight download.

    Raises:
        ValueError: if ``hidden_dim`` is odd.
    """

    def __init__(
        self,
        num_classes,
        hidden_dim,
        nheads,
        num_encoder_layers,
        num_decoder_layers,
        num_queries=100,
        max_positions=50,
        pretrained_backbone=True,
    ):
        super().__init__()
        if hidden_dim % 2:
            raise ValueError(
                f"hidden_dim must be even (got {hidden_dim}): the positional "
                "encoding concatenates two hidden_dim // 2 embeddings"
            )

        # Keep the ResNet-50 trunk; drop the final global avgpool and fc layers.
        self.backbone = torch.nn.Sequential(
            *list(resnet50(pretrained=pretrained_backbone).children())[:-2]
        )
        # 1x1 conv: 2048 backbone channels -> transformer width.
        self.conv = torch.nn.Conv2d(in_channels=2048,
                                    out_channels=hidden_dim,
                                    kernel_size=1)
        self.transformer = torch.nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers
        )
        self.linear_class = torch.nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = torch.nn.Linear(hidden_dim, 4)
        # Learned decoder inputs: one row per object query.
        self.query_pos = torch.nn.Parameter(torch.rand(num_queries, hidden_dim))
        # Learned 2-D positional encoding, split into row and column halves.
        self.row_embed = torch.nn.Parameter(torch.rand(max_positions, hidden_dim // 2))
        self.col_embed = torch.nn.Parameter(torch.rand(max_positions, hidden_dim // 2))

    def forward(self, inputs):
        """Predict class logits and boxes for every object query.

        Args:
            inputs: Image batch, assumed (B, 3, H_img, W_img).

        Returns:
            ``(logits, boxes)``. For ``B == 1`` the batch axis is squeezed
            away, giving ``(num_queries, num_classes + 1)`` and
            ``(num_queries, 4)`` exactly as before. For ``B > 1`` the shapes
            are ``(B, num_queries, num_classes + 1)`` and
            ``(B, num_queries, 4)``. Boxes go through a sigmoid, so they lie
            in [0, 1].

        Raises:
            ValueError: if the backbone feature map is larger than the
                positional-embedding tables.
        """
        features = self.conv(self.backbone(inputs))  # (B, hidden_dim, H, W)
        batch_size = features.shape[0]
        H, W = features.shape[-2:]

        max_h, max_w = self.row_embed.shape[0], self.col_embed.shape[0]
        if H > max_h or W > max_w:
            raise ValueError(
                f"feature map {H}x{W} exceeds the {max_h}x{max_w} positional "
                "embedding table; use a smaller input or a larger max_positions"
            )

        # (H*W, 1, hidden_dim): the same encoding is broadcast over the batch.
        pos = torch.cat((
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ), dim=-1).flatten(0, 1).unsqueeze(1)
        # nn.Transformer defaults to batch_first=False: sequences are (S, B, E).
        src = pos + features.flatten(2).permute(2, 0, 1)  # (H*W, B, hidden_dim)
        # src and tgt must share the batch size, so the learned queries are
        # repeated per image; previously any B > 1 raised inside the transformer.
        tgt = self.query_pos.unsqueeze(1).repeat(1, batch_size, 1)
        h = self.transformer(src, tgt).transpose(0, 1)  # (B, num_queries, hidden_dim)

        logits = self.linear_class(h)
        boxes = self.linear_bbox(h).sigmoid()
        if batch_size == 1:
            # Preserve the original single-image output shapes.
            logits, boxes = logits.squeeze(0), boxes.squeeze(0)
        return logits, boxes

In [33]:
# Seed before construction. The transformer and linear heads are randomly
# initialised, and the query/positional embeddings come from torch.rand.
# Without a seed, every Restart & Run All starts from different weights.
torch.manual_seed(0)
detr = DETR(num_classes=1,  # single object class; the model adds one extra logit
            hidden_dim=256,
            nheads=8,
            num_encoder_layers=6,
            num_decoder_layers=6)

In [34]:
# Smoke test on one random 800x1200 image. Run in eval mode without autograd.
# In train mode, this noise input would update the pretrained backbone's
# BatchNorm running statistics and build an autograd graph nobody uses.
was_training = detr.training
detr.eval()
try:
    with torch.no_grad():
        logits, bboxes = detr(torch.randn(1, 3, 800, 1200))
finally:
    detr.train(was_training)  # restore the mode so later training cells are unaffected

In [36]:
# One image -> 100 queries x (num_classes + 1 = 2) logits and 100 x 4 boxes.
assert logits.shape == (100, 2), logits.shape
assert bboxes.shape == (100, 4), bboxes.shape
assert ((bboxes >= 0) & (bboxes <= 1)).all(), "sigmoid boxes must lie in [0, 1]"
logits.shape

torch.Size([100, 2])

### Loading the detection dataset

In [141]:
# Directory holding the pre-processed training tensors (relative to the notebook).
TENSOR_DIR = Path('Saved Tensors')

# map_location='cpu' lets the files load on machines without the GPU they may
# have been saved from.
# NOTE(review): torch.load unpickles its input; only load files you trust.
train_images = torch.load(TENSOR_DIR / 'Augmented images.pth', map_location='cpu')
train_bboxes = torch.load(TENSOR_DIR / 'bboxes.pth', map_location='cpu')
train_classes = torch.load(TENSOR_DIR / 'classes.pth', map_location='cpu')

# Images, boxes and class labels are expected to line up one-to-one.
assert len(train_images) == len(train_bboxes) == len(train_classes), (
    len(train_images), len(train_bboxes), len(train_classes)
)

# Training