In [62]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.models import ResNet50_Weights  
import utils
import math

In [20]:
# ImageNet-pretrained ResNet-50 used as a convolutional feature extractor.
pretrained_resnet = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
# Keep every child module except the last two (in torchvision's ResNet these
# are the global average pool and the fc classifier), so the output stays a
# spatial feature map instead of class logits.
trunk_modules = list(pretrained_resnet.children())[:-2]
backbone = nn.Sequential(*trunk_modules)

In [209]:
# --- Image -> token sequence ------------------------------------------------
C, H0, W0 = 3, 442, 500
dmodel = 256

# Random stand-in image with a leading batch axis: (1, C, H0, W0).
img = torch.randn((C, H0, W0)).unsqueeze(0)
img = backbone(img)  # backbone feature map with 2048 channels (see proj below)

# 1x1 convolution shrinks the 2048 backbone channels down to dmodel.
proj = nn.Conv2d(2048, dmodel, 1)
img = proj(img)
_, _, H, W = img.shape

# Collapse the spatial grid into a token axis: (B, dmodel, H, W) -> (B, ntokens, dmodel).
img = img.flatten(-2, -1).transpose(-2, -1)

# --- Positional encoding ----------------------------------------------------
# Crop the precomputed 2D table to the feature-map size, then flatten it the
# same way as the tokens and add a batch axis.
# NOTE(review): assumes the helper returns a (max_dim, max_dim, d_model) tensor — confirm in utils.
pos_encode = utils.sinusoidal_pos_encode_2d(max_dim = 100, d_model = dmodel)
pos_encode = pos_encode[:H, :W].flatten(0, 1).unsqueeze(0)  # (1, ntokens, dmodel)

# --- Single-head self-attention ---------------------------------------------
query = nn.Linear(dmodel, dmodel, bias = False)
key = nn.Linear(dmodel, dmodel, bias = False)
value = nn.Linear(dmodel, dmodel, bias = False)

# Position information goes into queries and keys only; values use raw tokens.
tokens_with_pos = img + pos_encode
q = query(tokens_with_pos)
k = key(tokens_with_pos)
v = value(img)

# Scaled dot-product attention: (B, ntokens, ntokens) weights applied to v.
scores = (q @ k.transpose(-2, -1)) / math.sqrt(dmodel)
scores = F.softmax(scores, dim = -1)
out = scores @ v
out.shape

torch.Size([1, 224, 256])

In [212]:
class EncoderHead(nn.Module):
    """A single self-attention head using scaled dot-product attention.

    The positional encoding is added to the input of the query and key
    projections only; the value projection sees the raw tokens.

    Args:
        hidden_dim: size of the last dimension of the incoming tokens.
        nhead: number of heads the model is split into; this head projects
            to ``hidden_dim // nhead`` features.
    """

    def __init__(self, hidden_dim = 256, nhead = 8):
        super().__init__()
        # Width of this head's subspace (32 with the defaults).
        self.head_size = hidden_dim // nhead
        self.query = nn.Linear(hidden_dim, self.head_size, bias=False)
        self.key = nn.Linear(hidden_dim, self.head_size, bias=False)
        self.value = nn.Linear(hidden_dim, self.head_size, bias=False)

    def forward(self, x, positional_encoding):
        """Attend over the token axis.

        Args:
            x: tokens whose last dimension is ``hidden_dim``; the notebook
                passes (B, H*W, hidden_dim).
            positional_encoding: tensor broadcastable against ``x``.

        Returns:
            Tensor shaped like ``x`` except the last dimension is ``head_size``.
        """
        # Queries and keys are position-aware, values are not.
        x_with_pos = x + positional_encoding
        queries = self.query(x_with_pos)
        keys = self.key(x_with_pos)
        values = self.value(x)

        # Scale the logits by sqrt(head_size) before normalising over the key axis.
        attn_logits = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.head_size)
        attn_weights = torch.softmax(attn_logits, dim=-1)
        return attn_weights @ values

In [224]:
# Multi-head attention by hand: run independent heads over the same tokens,
# concatenate their outputs along the feature axis, then mix with a linear layer.
nhead = 8
eheads = [EncoderHead(hidden_dim = dmodel, nhead = nhead) for _ in range(nhead)]
# Call the loop variable `head`. The previous `eh(...)` referenced a name that
# is not defined anywhere in this notebook, so it failed on a fresh kernel
# (or silently reused a single stale head left in kernel state).
vals = [head(img, pos_encode) for head in eheads]
vals = torch.cat(vals, dim = -1)  # nhead chunks of head_size -> dmodel features
final_proj = nn.Linear(dmodel, dmodel)
vals = final_proj(vals)
vals.shape



torch.Size([1, 224, 256])

In [226]:
class EncoderLayer(nn.Module):
    """Multi-head self-attention built from independent ``EncoderHead`` modules.

    Every head attends over the same input; the head outputs are concatenated
    along the last dimension and passed through a bias-free linear projection.

    Args:
        hidden_dim: size of the last dimension of the incoming tokens.
        nhead: number of attention heads; must divide ``hidden_dim`` so the
            concatenated head outputs are exactly ``hidden_dim`` wide.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``nhead``.
    """

    def __init__(self, hidden_dim = 256, nhead = 8):
        super().__init__()
        if hidden_dim % nhead != 0:
            # Otherwise nhead * (hidden_dim // nhead) != hidden_dim and the
            # concatenation would not match cat_proj's input width.
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by nhead ({nhead})"
            )
        # Forward the configuration to every head. Building heads with their
        # defaults only worked when hidden_dim == 256 and nhead == 8.
        self.encoder_heads = nn.ModuleList(
            [EncoderHead(hidden_dim = hidden_dim, nhead = nhead) for _ in range(nhead)]
        )
        self.cat_proj = nn.Linear(hidden_dim, hidden_dim, bias = False)

    def forward(self, x, positional_encoding):
        """Run all heads on ``x`` and merge their outputs.

        Args:
            x: tokens whose last dimension is ``hidden_dim``.
            positional_encoding: passed unchanged to every head.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Each head yields (..., head_size); concatenation restores (..., hidden_dim).
        cat = torch.cat([head(x, positional_encoding) for head in self.encoder_heads], dim = -1)
        out = self.cat_proj(cat)
        return out

In [232]:
# Building blocks for one encoder block, assembled by hand.
ffn1 = nn.Linear(dmodel, dmodel*4)
relu = nn.ReLU()
ffn2 = nn.Linear(dmodel*4, dmodel)
ln1 = nn.LayerNorm(dmodel)
ln2 = nn.LayerNorm(dmodel)
el = EncoderLayer()

# Attention sub-layer: normalise first, attend, then add the residual.
x = img
x_res = el(ln1(x), pos_encode)
x = x + x_res

# Feed-forward sub-layer: normalise, expand 4x, ReLU, project back, add the residual.
x_res = ffn2(relu(ffn1(ln2(x))))
x = x_res + x

# The block keeps the token shape unchanged.
img.shape, x.shape



(torch.Size([1, 224, 256]), torch.Size([1, 224, 256]))

In [233]:
class EncoderBlock(nn.Module):
    """One encoder block: self-attention followed by a feed-forward MLP.

    Both sub-layers normalise their input first and add the result back onto
    the unnormalised input (residual connection). The attention sub-layer uses
    a hand-written layer norm (``gamma``/``beta``/``eps``); the feed-forward
    sub-layer uses ``nn.LayerNorm``.

    Args:
        hidden_dim: size of the last dimension of the tokens.
        nhead: number of attention heads, forwarded to ``EncoderLayer``.
    """

    def __init__(self, hidden_dim = 256, nhead = 8):
        super().__init__()
        self.encoder_layer = EncoderLayer(hidden_dim=hidden_dim, nhead=nhead)

        # Learnable scale/shift and stabiliser for the manual pre-attention norm.
        self.gamma = nn.Parameter(torch.ones(1, 1, hidden_dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.eps = 1e-5

        # Feed-forward sub-layer: norm -> 4x expansion -> ReLU -> projection back.
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim * 4)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim)

    def _pre_attention_norm(self, tokens):
        """Normalise over the last dimension using the biased variance, then scale and shift."""
        mean = tokens.mean(dim=-1, keepdim=True)
        var = tokens.var(dim=-1, correction=0, keepdim=True)
        normalized = (tokens - mean) / torch.sqrt(var + self.eps)
        return normalized * self.gamma + self.beta

    def forward(self, x, positional_encoding):
        """Apply the attention and feed-forward sub-layers.

        Args:
            x: tokens whose last dimension is ``hidden_dim``; the notebook
                passes (B, H*W, hidden_dim).
            positional_encoding: forwarded to the attention layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention sub-layer with residual connection.
        attended = self.encoder_layer(self._pre_attention_norm(x), positional_encoding)
        x = x + attended

        # Feed-forward sub-layer with residual connection.
        mlp_out = self.fc2(self.relu(self.fc1(self.layer_norm(x))))
        return x + mlp_out

In [235]:
# Stack six encoder blocks and thread the tokens through them in order.
blocks = [EncoderBlock() for _ in range(6)]
x = img
for block in blocks:
    # Update `x`, not `img`. Rebinding `img` left `x` as the un-encoded tokens
    # and overwrote the cell's own input, so re-running the cell changed its result.
    x = block(x, pos_encode)

x.shape
    

torch.Size([1, 224, 256])

In [None]:
class Encoder(nn.Module):
    """A stack of ``EncoderBlock`` modules applied one after another.

    Args:
        num_encoder_layers: how many blocks to stack.
        hidden_dim: token feature size, forwarded to every block.
        nhead: number of attention heads, forwarded to every block.
    """

    def __init__(self, num_encoder_layers = 6, hidden_dim = 256, nhead = 8):
        super().__init__()
        stacked_blocks = [
            EncoderBlock(hidden_dim=hidden_dim, nhead=nhead)
            for _ in range(num_encoder_layers)
        ]
        self.layers = nn.ModuleList(stacked_blocks)

    def forward(self, x, positional_encoding):
        """Feed ``x`` through every block; each block gets the same positional encoding."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, positional_encoding)
        return hidden
        

In [278]:
# Multi-head cross-attention by hand: random query vectors attend over the tokens in `x`.
nqueries = 100
nhead = 8
hs = dmodel // nhead  # per-head width

q = torch.randn((1, nqueries, dmodel))
k = v = x

query = nn.Linear(dmodel, dmodel, bias = False)
key = nn.Linear(dmodel, dmodel, bias = False)
value = nn.Linear(dmodel, dmodel, bias = False)

# Project: q stays (B, nqueries, dmodel); k and v stay (B, ntokens, dmodel).
q, k, v = query(q), key(k), value(v)
ntokens = k.shape[1]

# Split the feature axis into heads and move the head axis in front of the
# sequence axis: (B, seq, dmodel) -> (B, seq, nhead, hs) -> (B, nhead, seq, hs).
# The batch size is hard-coded to 1 here.
q = q.view(1, nqueries, nhead, hs).transpose(1, 2)
k = k.view(1, ntokens, nhead, hs).transpose(1, 2)
v = v.view(1, ntokens, nhead, hs).transpose(1, 2)

# (B, nhead, nqueries, hs) @ (B, nhead, hs, ntokens) -> (B, nhead, nqueries, ntokens)
scores = (q @ k.transpose(-2, -1)) / math.sqrt(hs)
scores = F.softmax(scores, dim = -1)

# (B, nhead, nqueries, ntokens) @ (B, nhead, ntokens, hs) -> (B, nhead, nqueries, hs)
out = scores @ v

# Merge the heads back: (B, nhead, nqueries, hs) -> (B, nqueries, nhead, hs) -> (B, nqueries, dmodel).
out = out.transpose(1, 2).flatten(-2, -1)
final_transf = nn.Linear(dmodel, dmodel, bias = False)
out = final_transf(out)
out.shape





torch.Size([1, 100, 256])

In [268]:
# Sanity check: `view` lays the flat range out row by row as a 2x2 grid.
a = torch.arange(0, 4)
(a, a.view(2, -1))

(tensor([0, 1, 2, 3]),
 tensor([[0, 1],
         [2, 3]]))