In [18]:
import torch
import torchvision.models as models
from torchvision import datasets, transforms as T
from torchvision.datasets import wrap_dataset_for_transforms_v2
from torchvision.models import ResNet50_Weights  # <-- import this
from torchvision.transforms import v2
from torchvision.utils import draw_bounding_boxes
import torch.nn as nn
import torch.nn.functional as F
import math
import utils



In [19]:
# ImageNet-pretrained ResNet50 (IMAGENET1K_V2 weights); its convolutional stages are reused as the DETR backbone below.
resnet50 = models.resnet50(weights = ResNet50_Weights.IMAGENET1K_V2)

In [20]:
# PASCAL VOC object class names (20, alphabetical). Not yet verified against the
# official devkit. `itol` maps a 0-based index to a name; dataset labels are
# treated as 1-based below (hence the `- 1`), presumably because label 0 is
# background in torchvision's VOC wrapper — TODO confirm.
itol = (
    "aeroplane bicycle bird boat bottle "
    "bus car cat chair cow "
    "diningtable dog horse motorbike person "
    "pottedplant sheep sofa train tvmonitor"
).split()
itol[13 - 1], itol[15 - 1]

('horse', 'person')

In [21]:
def plot(sample, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show a dataset sample with its ground-truth boxes and class names drawn on.

    Args:
        sample: ``(img, target)`` pair as returned by ``dataset[i]``. ``img`` is a
            normalized float image of shape (C, H, W). ``target`` has ``'boxes'``
            and ``'labels'`` entries.
        mean, std: per-channel statistics that ``v2.Normalize`` applied; they are
            inverted here to recover displayable pixel values.

    Opens the rendered image with ``PIL.Image.show()``; returns None.
    """
    img, target = sample
    img = img.as_subclass(torch.Tensor).detach()
    # new_tensor keeps the image's dtype/device; view broadcasts over H, W
    mean_t = img.new_tensor(mean).view(-1, 1, 1)
    std_t = img.new_tensor(std).view(-1, 1, 1)

    # Undo Normalize, then clamp + quantize to uint8: float round-off can push
    # values slightly outside [0, 1], and some torchvision versions of
    # draw_bounding_boxes only accept uint8 images.
    pixels = img * std_t + mean_t
    pixels = (pixels.clamp(0.0, 1.0) * 255).round().to(torch.uint8)

    # labels are 1-based indices into `itol`
    labels = [str(itol[int(i) - 1]) for i in target['labels']]
    boxes = target['boxes'].as_subclass(torch.Tensor).detach()
    drawn = draw_bounding_boxes(pixels, boxes, width=3, labels=labels)
    v2.ToPILImage()(drawn).show()

In [22]:
# ImageNet normalization stats matching the ResNet50 weights, see:
# https://docs.pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights
transform = v2.Compose([
        v2.ToImage(),                           # PIL image -> tv_tensors.Image
        v2.ToDtype(torch.float32, scale=True),  # integer pixels -> float in [0, 1]
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
# NOTE: no resizing, so each image keeps its original (H0, W0).

# Data root: set the NANODETR_DATA environment variable to point elsewhere.
# It falls back to the original local path so the existing setup keeps working.
import os
data_filepath = os.environ.get('NANODETR_DATA', '/Users/veb/ms/nanoDETR/data')
dataset = datasets.VOCDetection(root = data_filepath,
                                year = '2012',
                                image_set = 'train',
                                download = False,
                                transform = transform) # len 5717, consistent with the data on disk
# Wrap so each target is a dict with 'boxes' and 'labels' tensors (as used by `plot`).
dataset = wrap_dataset_for_transforms_v2(dataset)

In [23]:
# can now put through resnet
# original img: (3, H0, W0)
# backbone output: (C, H, W), where typically C = 2048 and H, W = ceil(H0 / 32), ceil(W0 / 32) (e.g. 442x500 -> 14x16)

# The input images are batched together, applying 0-padding adequately to ensure
# they all have the same dimensions (H0,W0) as the largest image of the batch.

In [24]:
# Load one sample and check its size. The transform does no resizing, so this is
# the full-resolution (3, H0, W0) image, already normalized.
img, target = dataset[0]
print(img.shape)

# NOTE(review): `utils.v2show` is a local helper not shown here; presumably it
# displays an image tensor — confirm whether it unnormalizes first.
utils.v2show(img)

torch.Size([3, 442, 500])


In [25]:
torch.manual_seed(5550)
in_channels = 2048  # channels of ResNet50's final feature map
hidden_dim = 256 # = d_model in AttentionIsAllYouNeed
img, target = dataset[0]


# ResNet50 without its final avgpool + fc -> spatial feature map
backbone = nn.Sequential(*list(resnet50.children())[:-2])
# In train mode, BatchNorm would normalize with this single image's batch
# statistics and overwrite the pretrained running stats, even under no_grad.
# eval() uses the stored ImageNet statistics instead.
# NOTE: modules are shared with `resnet50`, so this also puts it in eval mode.
backbone.eval()
# 1x1 conv projecting the 2048 backbone channels down to the transformer width
downsample = nn.Conv2d(in_channels, hidden_dim, kernel_size = 1, stride = 1, padding = 0, bias = False)

with torch.no_grad():
    out = backbone(img.unsqueeze(0))  # (1, 2048, 14, 16) for the 442x500 image
assert not out.requires_grad

down = downsample(out) # (B, 2048, 14, 16) -> (B, hidden_dim, 14, 16)

flattened = down.flatten(2)              # (B, hidden_dim, H*W)
flattened = flattened.permute(0, 2, 1)   # (B, H*W, hidden_dim)


# learnable positional embeddings, C = hidden_dim
B, HW, C = flattened.shape
pos_embed = nn.Parameter(torch.randn(size = (1, HW, hidden_dim), dtype = flattened.dtype))
x = flattened + pos_embed # (B, H*W, hidden_dim)


In [26]:
# Single-head scaled dot-product self-attention over the H*W image tokens:
# out = softmax(q k^T / sqrt(d)) v, with d = hidden_dim.
torch.manual_seed(5550)

query, key, value = (nn.Linear(hidden_dim, hidden_dim, bias = False) for _ in range(3))

# project the tokens into separate query / key / value spaces, each (B, H*W, hidden_dim)
q, k, v = query(x), key(x), value(x)

# Dividing by sqrt(d) gives unit-variance logits when q and k entries are
# unit-variance. Note: the author observed scores.std() ~ 0.4 here after scaling.
scores = F.softmax((q @ k.transpose(-2, -1)) / math.sqrt(hidden_dim), dim = -1)  # (B, H*W, H*W)
out = scores @ v  # (B, H*W, hidden_dim)


q.shape, k.transpose(-2, -1).shape, scores.shape

(torch.Size([1, 224, 256]),
 torch.Size([1, 256, 224]),
 torch.Size([1, 224, 224]))

In [27]:
# Re-seed and restate the attention sizes for this cell. The class defaults below
# use the same values (hidden_dim=256, nhead=8).
torch.manual_seed(5550)
nhead = 8
hidden_dim = 256 # = d_model in AttentionIsAllYouNeed


class EncoderHead(nn.Module):
    """One head of the encoder's multi-head self-attention.

    Queries and keys are computed from the tokens plus the positional encoding.
    Values come from the raw tokens.

    Args:
        positional_encoding: tensor broadcastable to the input x, e.g.
            (1, T, hidden_dim) or (T, hidden_dim).
        hidden_dim: width of the input tokens.
        nhead: number of heads in the enclosing layer; this head outputs
            ``hidden_dim // nhead`` features.
    """

    def __init__(self, positional_encoding, hidden_dim = 256, nhead = 8):
        super().__init__()
        self.positional_encoding = positional_encoding
        self.head_size = hidden_dim // nhead # e.g. 256 // 8 = 32
        self.query = nn.Linear(hidden_dim, self.head_size, bias = False)
        self.key = nn.Linear(hidden_dim, self.head_size, bias = False)
        self.value = nn.Linear(hidden_dim, self.head_size, bias = False)

    def forward(self, x):
        """Map x of shape (B, T, hidden_dim) to (B, T, head_size)."""
        x_pos = x + self.positional_encoding
        q = self.query(x_pos) # (B, T, head_size)
        k = self.key(x_pos)
        v = self.value(x)

        # scaled dot-product attention; 1/sqrt(head_size) keeps the logits'
        # scale independent of the head width
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_size) # (B, T, T)
        scores = F.softmax(scores, dim = -1)
        return scores @ v # (B, T, head_size)


class EncoderLayer(nn.Module):
    """Multi-head self-attention: ``nhead`` EncoderHeads, concatenated and projected.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``nhead``, since the
            concatenated heads would then not be ``hidden_dim`` wide.
    """

    def __init__(self, pos_encode, hidden_dim = 256, nhead = 8):
        super().__init__()
        if hidden_dim % nhead != 0:
            raise ValueError(f'hidden_dim ({hidden_dim}) must be divisible by nhead ({nhead})')
        # pass this layer's sizes through so each head is hidden_dim // nhead wide
        self.encoder_heads = nn.ModuleList(
            [EncoderHead(pos_encode, hidden_dim = hidden_dim, nhead = nhead) for _ in range(nhead)]
        )
        self.cat_proj = nn.Linear(hidden_dim, hidden_dim, bias = False)

    def forward(self, x):
        """(B, T, hidden_dim) -> (B, T, hidden_dim)."""
        # nhead x (B, T, head_size) -> (B, T, hidden_dim)
        cat = torch.cat([head(x) for head in self.encoder_heads], dim = -1)
        return self.cat_proj(cat)


class EncoderBlock(nn.Module):
    """Pre-norm transformer encoder block: x + MHA(LN(x)), then x + FFN(LN(x)).

    The LayerNorm in front of the attention is hand-rolled (``gamma``, ``beta``,
    ``eps``). The one in front of the FFN is ``nn.LayerNorm``. Both normalize
    over the last (hidden_dim) axis with biased variance.
    """

    def __init__(self, pos_encode, hidden_dim = 256, nhead = 8):
        super().__init__()
        self.pos_encode = pos_encode
        self.encoder_layer = EncoderLayer(self.pos_encode, hidden_dim = hidden_dim, nhead = nhead)

        # hand-rolled layer norm (affine parameters) for the attention sub-layer
        self.gamma = nn.Parameter(torch.ones(1, 1, hidden_dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.eps = 1e-5

        # ffn, with its own pre-norm
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim*4)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim*4, hidden_dim)

    def _attn_norm(self, x):
        """Hand-rolled LayerNorm over the last axis, scaled by gamma and shifted by beta."""
        mean = x.mean(dim = -1, keepdim = True)
        var = x.var(dim = -1, correction = 0, keepdim = True)
        return (x - mean) / torch.sqrt(var + self.eps) * self.gamma + self.beta

    def forward(self, x):
        """(B, T, hidden_dim) -> (B, T, hidden_dim)."""
        # attention sub-layer + residual connection
        x = x + self.encoder_layer(self._attn_norm(x))

        # feed-forward sub-layer + residual connection
        x_res = self.layer_norm(x)
        x_res = self.fc2(self.relu(self.fc1(x_res)))
        return x + x_res


class Encoder(nn.Module):
    """Stack of ``num_encoder_layers`` EncoderBlocks sharing one positional encoding."""

    def __init__(self, pos_encode, num_encoder_layers = 6, hidden_dim = 256, nhead = 8):
        super().__init__()
        self.layers = nn.Sequential(*[
            EncoderBlock(pos_encode, hidden_dim = hidden_dim, nhead = nhead)
            for _ in range(num_encoder_layers)
        ])

    def forward(self, x):
        """(B, T, hidden_dim) -> (B, T, hidden_dim)."""
        return self.layers(x)
        
# Smoke test: build each level of the encoder with one learnable positional
# encoding sized for this image's HW tokens, and check the encoder keeps x's shape.
pos_encode = nn.Parameter(torch.randn(size = (1, HW, hidden_dim), dtype = flattened.dtype))
elayer = EncoderLayer(pos_encode)
block = EncoderBlock(pos_encode)
encoder = Encoder(pos_encode)
encoder(x).shape, x.shape

    

(torch.Size([1, 224, 256]), torch.Size([1, 224, 256]))

In [28]:
# Scratch: how split chunks the last axis. Splitting a (1, 2, 4) tensor into
# size-2 pieces gives a tuple of two (1, 2, 2) tensors. Stacking them adds a new
# leading axis -> (2, 1, 2, 2).
t = torch.randn(1, 2, 4)
print(t)
t = torch.split(t, 2, dim=-1)
print(*t, sep='\n')
print(torch.stack(t).shape)

tensor([[[ 0.7841,  0.2449,  1.3215,  0.3390],
         [ 0.1875,  1.0745,  0.8846, -0.1295]]])
tensor([[[0.7841, 0.2449],
         [0.1875, 1.0745]]])
tensor([[[ 1.3215,  0.3390],
         [ 0.8846, -0.1295]]])
torch.Size([2, 1, 2, 2])


In [29]:
class MHAttention(nn.Module):
    """Multi-head scaled dot-product attention with separate query/key/value inputs.

    Serves both as self-attention (query/key/value from the same tokens) and
    cross-attention (queries attending to a memory). All projections are bias-free.

    Args:
        embed_dim: model width; inputs and the output have this last dimension.
        nhead: number of heads; each head works in ``embed_dim // nhead`` dims.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``nhead``.
    """

    def __init__(self, embed_dim, nhead):
        super().__init__()
        if embed_dim % nhead != 0:
            raise ValueError(f'embed_dim ({embed_dim}) must be divisible by nhead ({nhead})')
        self.nhead = nhead
        self.head_size = embed_dim // nhead
        self.query = nn.Linear(embed_dim, embed_dim, bias = False)
        self.key = nn.Linear(embed_dim, embed_dim, bias = False)
        self.value = nn.Linear(embed_dim, embed_dim, bias = False)
        self.projection = nn.Linear(embed_dim, embed_dim, bias = False)

    def forward(self, query, key, value, key_padding_mask = None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: (B, qT, embed_dim).
            key, value: (B, kT, embed_dim).
            key_padding_mask: optional bool tensor of shape (B, kT). True marks
                key positions to ignore, e.g. zero-padding added when batching
                images of different sizes. A query whose keys are all masked
                produces NaN.

        Returns:
            Tensor of shape (B, qT, embed_dim).
        """
        B, qT, _ = query.shape
        kT = key.shape[1]

        # project, then split the embedding into heads:
        # (B, T, embed_dim) -> (B, T, nhead, head_size) -> (B, nhead, T, head_size)
        q = self.query(query).reshape(B, qT, self.nhead, self.head_size).transpose(1, 2)
        k = self.key(key).reshape(B, kT, self.nhead, self.head_size).transpose(1, 2)
        v = self.value(value).reshape(B, kT, self.nhead, self.head_size).transpose(1, 2)

        # (B, nhead, qT, head_size) @ (B, nhead, head_size, kT) -> (B, nhead, qT, kT)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_size)
        if key_padding_mask is not None:
            # broadcast (B, kT) over heads and queries; -inf gives zero attention weight
            mask = key_padding_mask.to(torch.bool)[:, None, None, :]
            scores = scores.masked_fill(mask, float('-inf'))
        scores = F.softmax(scores, dim = -1)

        out = scores @ v # (B, nhead, qT, kT) @ (B, nhead, kT, head_size) -> (B, nhead, qT, head_size)
        # merge heads: (B, nhead, qT, head_size) -> (B, qT, nhead, head_size) -> (B, qT, embed_dim)
        out = out.transpose(1, 2).flatten(2)
        return self.projection(out)



# batch_size = 1
# encoder = Encoder(pos_encode)
# mem = encoder(x)
# q = torch.zeros((batch_size,100, hidden_dim))
# mhattn = MHAttention(256, 8)
# mhattn(q, mem, mem).shape

# q.shape, mem.shape

# pos_encode = nn.Parameter(torch.randn(size = (1, HW, hidden_dim), dtype = flattened.dtype))
# mhattn = MHAttention(256, 8)
# q = x + pos_encode
# k = x + pos_encode
# v = x
# mhattn(q,k,v).shape

    

In [30]:
# building the decoder
class DecoderBlock(nn.Module):
    """Pre-norm DETR-style decoder block: self-attention, cross-attention, FFN.

    Each sub-layer normalizes its input, applies the sub-layer and adds the
    result back as a residual. Positional encodings are added to queries and
    keys, never to values.
    """

    def __init__(self, hidden_dim = 256, nhead = 8):
        super().__init__()
        # construction order unchanged so seeded initialization stays identical
        self.layer_norm_1 = nn.LayerNorm(hidden_dim)
        self.multihead_selfattention = MHAttention(embed_dim = hidden_dim, nhead = nhead)
        self.layer_norm_2 = nn.LayerNorm(hidden_dim)
        self.multihead_crossattention = MHAttention(embed_dim = hidden_dim, nhead = nhead)
        self.layer_norm_3 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim*4),
            nn.ReLU(),
            nn.Linear(hidden_dim*4, hidden_dim),
        )

    def forward(self, x, memory, memory_pos_encoding, query_pos_encoding):
        """Update the object-query features ``x`` (B, nqueries, hidden_dim) from ``memory``.

        Args:
            x: current object-query features.
            memory: encoder output, (B, T, hidden_dim); used as cross-attention values.
            memory_pos_encoding: added to ``memory`` to form the cross-attention keys.
            query_pos_encoding: added to the normalized queries, both for
                self-attention queries/keys and for cross-attention queries.
        """
        # 1) self-attention among the object queries
        normed = self.layer_norm_1(x)
        positioned = normed + query_pos_encoding
        x = x + self.multihead_selfattention(query = positioned, key = positioned, value = normed)

        # 2) cross-attention from the queries to the encoder memory
        normed = self.layer_norm_2(x)
        x = x + self.multihead_crossattention(
            query = normed + query_pos_encoding,
            key = memory + memory_pos_encoding,
            value = memory,
        )

        # 3) position-wise feed-forward
        return x + self.ffn(self.layer_norm_3(x))

# batch_size = 1
# encoder = Encoder(pos_encode)

# mem = encoder(x)
# q_pos = nn.Parameter(torch.zeros((batch_size,100, hidden_dim)))

# decoderblock = DecoderBlock(memory = mem, memory_pos_encoding = pos_encode, query_pos_encoding = q_pos)
# q = torch.zeros(100, hidden_dim)
# decoderblock(q).shape




In [42]:
class nanoDETR(nn.Module):
    """Minimal DETR: frozen ResNet50 backbone -> transformer encoder -> decoder.

    Currently wired for overfitting a single image. ``forward`` only accepts an
    unbatched (3, 442, 500) image, whose backbone feature map has 14 * 16 = 224
    tokens (the default ``ntokens``, i.e. the length of the learned encoder
    positional encoding).

    Args:
        resnet50: a torchvision ResNet50. Its children up to the last two
            (avgpool, fc) are reused, shared rather than copied, as the backbone.
        ntokens: number of encoder tokens the positional encoding is sized for.
        nlayers: number of encoder blocks and of decoder blocks.
        nhead: attention heads per layer.
        hidden_dim: transformer width.
        nqueries: number of object queries.
    """

    def __init__(self, resnet50, ntokens = 224, nlayers = 6, nhead = 8, hidden_dim = 256, nqueries = 100):
        super().__init__()
        self.nlayers = nlayers
        self.nhead = nhead
        self.hidden_dim = hidden_dim

        # backbone
        # NOTE(review): registering the full resnet50 also keeps its unused
        # avgpool/fc in the module tree; kept so state_dict keys don't change.
        self.resnet50 = resnet50
        self.backbone = nn.Sequential(*list(resnet50.children())[:-2])
        # The backbone is frozen (run under no_grad in forward). Keep its
        # BatchNorms in eval mode so they use, and don't overwrite, the
        # pretrained running statistics. This also affects the passed-in
        # resnet50, since the modules are shared.
        self.backbone.eval()
        self.project = nn.Conv2d(in_channels = 2048, out_channels = hidden_dim, kernel_size = 1, bias=False)

        # build encoder
        self.encoder_pos_encode = nn.Parameter(torch.randn(ntokens, hidden_dim))
        self.encoder = Encoder(self.encoder_pos_encode, num_encoder_layers = nlayers, hidden_dim = hidden_dim, nhead = nhead)

        # build decoder
        self.nqueries = nqueries
        self.query_pos_encoding = nn.Parameter(torch.randn((nqueries, hidden_dim)))
        # Built once here, not in forward, so the decoder weights are registered
        # submodules: seen by the optimizer, moved by .to(), saved in state_dict.
        self.decoder_layers = nn.ModuleList(
            [DecoderBlock(hidden_dim = hidden_dim, nhead = nhead) for _ in range(nlayers)]
        )

    def train(self, mode = True):
        """Set train/eval mode as usual, but always keep the frozen backbone in eval mode."""
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Run a single image through backbone, encoder and decoder.

        Args:
            x: one normalized image of shape (3, 442, 500), without a batch axis.

        Returns:
            Decoder output of shape (1, nqueries, hidden_dim). No prediction
            heads are applied yet.
        """
        # ntokens (and so the positional encoding) is tied to this image size while
        # overfitting one image to check for bugs. Generalizing requires a
        # size-independent positional encoding.
        assert x.shape == (3, 442, 500), f'image shape is {x.shape}, but should be (3, 442, 500)'

        # backbone (frozen): (1, 3, 442, 500) -> (1, 2048, 14, 16)
        with torch.no_grad():
            x = self.backbone(x.unsqueeze(0))
        x = self.project(x)      # (B, hidden_dim, 14, 16)
        x = x.flatten(2)         # (B, hidden_dim, T)
        x = x.transpose(-2, -1)  # (B, T, hidden_dim)

        # encoder
        memory = self.encoder(x)

        # decoder: queries start at zero; what distinguishes them is the
        # query_pos_encoding each DecoderBlock adds
        queries = x.new_zeros((x.shape[0], self.nqueries, self.hidden_dim))
        for layer in self.decoder_layers:
            queries = layer(x = queries,
                            memory = memory,
                            memory_pos_encoding = self.encoder_pos_encode,
                            query_pos_encoding = self.query_pos_encoding)

        # prediction
        return queries
        

# Smoke test: one forward pass of the full model on the first training image.
img,_ = dataset[0]
detr = nanoDETR(resnet50 = resnet50)
detr(img).shape
    
        
        


        
    

torch.Size([1, 100, 256])