In [370]:
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision import datasets, transforms as T
from torchvision.datasets import wrap_dataset_for_transforms_v2
from torchvision.models import ResNet50_Weights
from torchvision.transforms import v2
from torchvision.utils import draw_bounding_boxes



In [275]:
# ImageNet-pretrained ResNet-50 (torchvision's IMAGENET1K_V2 checkpoint). Its
# convolutional trunk is reused further down as the image backbone.
# NOTE: like every nn.Module it is constructed in training mode; call .eval()
# before using it as a frozen feature extractor so BatchNorm uses its
# pretrained running statistics.
resnet50 = models.resnet50(
    weights=ResNet50_Weights.IMAGENET1K_V2,
)

In [276]:
# haven't checked if these are official. website down
itol = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]
itol[13-1], itol[15-1]

('horse', 'person')

In [277]:
def plot(sample, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), show=True):
    """Draw a dataset sample's bounding boxes and class names on its image.

    Args:
        sample: ``(img, target)`` pair. ``img`` is a normalized ``(3, H, W)``
            float image; ``target`` is a dict with ``'boxes'`` and 1-based
            ``'labels'``, mapped to names via ``itol[label - 1]``.
        mean, std: per-channel normalization statistics to undo. The defaults
            are the ImageNet values used by the ``transform`` cell.
        show: if True (default), open the result with ``PIL.Image.show()``.

    Returns:
        The annotated image as a PIL image (so it can also render inline).
    """
    img, target = sample
    mean = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    # Undo Normalize, then clamp. Float rounding (or any later augmentation) can
    # leave pixels slightly outside [0, 1], and the uint8/PIL conversions
    # downstream assume that range.
    img = (img.data * std + mean).clamp(0, 1)
    labels = [itol[i - 1] for i in target['labels'].tolist()]
    annotated = draw_bounding_boxes(img, target['boxes'].data, width=3, labels=labels)
    pil_img = v2.ToPILImage()(annotated)
    if show:
        pil_img.show()
    return pil_img

In [278]:
# Preprocessing for the ImageNet-pretrained ResNet-50:
# PIL -> Image, uint8 [0, 255] -> float32 [0, 1], then ImageNet mean/std normalization.
# imagenet stats here: https://docs.pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Data root. Set the NANODETR_DATA environment variable to point elsewhere
# instead of editing this machine-specific default.
data_filepath = os.environ.get("NANODETR_DATA", "/Users/veb/ms/nanoDETR/data")
# NOTE: `transform=` is applied to the image only. Geometric transforms added
# later (resize/flip/crop) must be passed via `transforms=` so the boxes follow.
dataset = datasets.VOCDetection(
    root=data_filepath,
    year='2012',
    image_set='train',
    download=False,
    transform=transform,
)  # len 5717, consistent with the local data
# Wrap so targets come back as dicts with 'boxes' and 'labels' tensors.
dataset = wrap_dataset_for_transforms_v2(dataset)

In [279]:
# can now put through resnet
# original img: (3, H0, W0)
# backbone output: (C, H, W), where typically C = 2048, H, W = H0 / 32, W0 / 32

# The input images are batched together, applying 0-padding adequately to ensure
# they all have the same dimensions (H0,W0) as the largest image of the batch.

In [288]:
torch.manual_seed(5550)
in_channels = 2048  # channels of the ResNet-50 feature map fed to `downsample`
hidden_dim = 256 # = d_model in AttentionIsAllYouNeed
img, target = dataset[0]

# ResNet-50 minus its last two children (global avgpool + fc): the conv trunk.
backbone = nn.Sequential(*list(resnet50.children())[:-2])
# Run BatchNorm with its pretrained running statistics. In train mode it would
# normalize with this single image's statistics and overwrite running_mean/var
# on every forward (even under no_grad), so re-running this cell would silently
# change the backbone.
backbone.eval()
# 1x1 conv projecting in_channels -> hidden_dim (remaining Conv2d args are defaults).
downsample = nn.Conv2d(in_channels, hidden_dim, kernel_size=1, bias=False)

with torch.no_grad():
    features = backbone(img.unsqueeze(0))  # (1, 2048, H, W)
assert not features.requires_grad

down = downsample(features)  # (B, 2048, 14, 16) -> (B, hidden_dim, 14, 16) for dataset[0]

flattened = down.flatten(2).permute(0, 2, 1)  # (B, hidden_dim, H*W) -> (B, H*W, hidden_dim)

# Learnable positional embeddings, one per feature-map location (C = hidden_dim).
# NOTE(review): sized to this image's H*W, so it only fits feature maps of exactly
# this size; a size-independent scheme (e.g. the 2D sine encodings from the DETR
# paper) is needed once images of different sizes are batched.
B, HW, C = flattened.shape
pos_embed = nn.Parameter(torch.randn(1, HW, hidden_dim, dtype=flattened.dtype))
x = flattened + pos_embed  # (B, H*W, hidden_dim), broadcast over B
assert x.shape == (B, HW, hidden_dim)


In [406]:
# Single-head scaled dot-product self-attention over the H*W feature tokens in x.
torch.manual_seed(5550)

# Bias-free projections of x into separate query / key / value spaces
# (created in this order, so the seeded init matches query -> key -> value).
query, key, value = (nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(3))

q, k, v = query(x), key(x), value(x)  # each (B, H*W, hidden_dim)

# Logits scaled by 1/sqrt(d). The scale normalizes the logits when q and k
# entries have unit variance. Here std(scores) comes out around 0.4 rather than 1.
# A plausible cause (TODO confirm via x.var()): nn.Linear's default init,
# U(-1/sqrt(in), 1/sqrt(in)), has variance 1/(3*in), which shrinks each
# projection to roughly Var(x)/3. So a small std is expected here rather than a
# broken assumption.
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(hidden_dim)
scores = scores.softmax(dim=-1)  # attention weights: every row sums to 1
out = torch.matmul(scores, v)  # (B, H*W, hidden_dim)

q.shape, k.transpose(-2, -1).shape, scores.shape

tensor(0.4047, grad_fn=<StdBackward0>)


(torch.Size([1, 224, 256]),
 torch.Size([1, 256, 224]),
 torch.Size([1, 224, 224]),
 <function Tensor.std>)