In [8]:
import torch
import d2l

class PatchEmbedding(torch.nn.Module):
    def __init__(self, image_size, patch_size, hidden_size, image_channels=3):
        super().__init__()
        def make_tuple(x):
            if not isinstance(x, (list, tuple)):
                return (x,x)
            else:
                return x
        image_size = make_tuple(image_size)
        patch_size = make_tuple(patch_size)

        self.num_patches = (image_size[0] // patch_size[0])*(image_size[1] // patch_size[1])
        self.conv = torch.nn.Conv2d(image_channels, hidden_size, patch_size,  patch_size)
    
    def forward(self, x):
        return self.conv(x).flatten(2).transpose(1,2)
        

In [6]:
image_size, patch_size, hidden_size, batch_size = 96, 16, 512, 4
patch_emb = PatchEmbedding(image_size, patch_size, hidden_size)
x = torch.zeros(batch_size, 3, image_size, image_size)
y = patch_emb(x)
x.shape,y.shape

(torch.Size([4, 3, 96, 96]), torch.Size([4, 36, 512]))

In [7]:
class ViTMLP(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout=0.5):
        super().__init__()
        self.dense1 = torch.nn.Linear(input_size, hidden_size)
        self.gelu = torch.nn.GELU()
        self.dropout1 = torch.nn.Dropout(dropout)
        self.dense2 = torch.nn.Linear(hidden_size, output_size)
        self.dropout2 = torch.nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout2(self.dense2(self.dropout1(self.gelu(self.dense1(x)))))

In [9]:
class ViTBlock(torch.nn.Module):
    def __init__(self, hidden_size, num_heads, norm_shape, mlp_hidden_size, dropout, use_bias=False):
        super().__init__()
        self.ln1 = torch.nn.LayerNorm(norm_shape)
        self.attention = d2l.MultiHeadAttention(hidden_size,num_heads,dropout,use_bias)
        self.ln2 = torch.nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(hidden_size, mlp_hidden_size, hidden_size, dropout)

    def forward(self, x, valid_lens=None):
        x = x + self.attention(*([self.ln1(x)]*3), valid_lens)
        return x + self.mlp(self.ln2(x))
    

In [10]:
x = torch.ones((2, 100, 24))
encoder_blk = ViTBlock(24, 8, 24, 48, 0.5)
encoder_blk.eval()
y=encoder_blk(x)
x.shape, y.shape

(torch.Size([2, 100, 24]), torch.Size([2, 100, 24]))

In [31]:
class ViT(torch.nn.Module):
    def __init__(self, image_size, patch_size, hidden_size, num_heads, mlp_hidden_size, num_blocks, emb_dropout, block_dropout, use_bias=False, num_classes=10):
        super().__init__()
        self.embedding = PatchEmbedding(image_size, patch_size, hidden_size)
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1 , hidden_size))
        num_steps = self.embedding.num_patches + 1
        self.pos_embedding = torch.nn.Parameter(torch.randn(1, num_steps, hidden_size))
        self.dropout = torch.nn.Dropout(emb_dropout) 
        self.blocks = torch.nn.Sequential()
        for i in range(num_blocks):
            self.blocks.add_module("block"+str(i), ViTBlock(hidden_size, num_heads, hidden_size, mlp_hidden_size, block_dropout, use_bias))
        self.head = torch.nn.Sequential(torch.nn.LayerNorm(hidden_size), torch.nn.Linear(hidden_size, num_classes))

    def forward(self, x):
        x = self.embedding(x)
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), 1)
        x = self.dropout(x + self.pos_embedding)
        for block in self.blocks:
            x = block(x)
        print(x.shape,x[:,0].shape) 
        return self.head(x[:,0])

In [32]:
image_size, patch_size = 96, 16
hidden_size, num_heads, mlp_hidden_size,  num_blocks = 512, 8, 2048, 2
emb_dropout, blk_dropout  = 0.1, 0.1
vit = ViT(image_size, patch_size, hidden_size, num_heads, mlp_hidden_size,  num_blocks, emb_dropout, blk_dropout)
x = torch.ones(1,3,image_size,image_size)
y = vit(x)
x.shape, y.shape

torch.Size([1, 37, 512]) torch.Size([1, 512])


(torch.Size([1, 3, 96, 96]), torch.Size([1, 10]))

In [38]:
a = torch.ones(1,1,8)
b = torch.ones(2,10,8)

a+b
#torch.cat((a,b),1)

tensor([[[2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.]],

        [[2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.],
         [2., 2., 2., 2., 2., 2., 2., 2.]]])