In [1]:
import torch
import torch.nn as nn

from utils import img_to_patch
from datasets import ViTDataLoader
from models.vit import AttentionBlock

In [2]:
# --- Input / patch geometry ---
num_channels = 3   # channels per input image
patch_size = 4     # side length of one square patch, in pixels
num_patches = 64   # patches per image; (32 // 4) ** 2 for the 32x32 inputs used below

# --- Transformer hyper-parameters ---
embed_dim = 256    # width of every token embedding
hidden_dim = 512   # second argument of AttentionBlock (presumably its feed-forward width — confirm in models/vit.py)
num_heads = 8      # attention heads per block
num_layers = 6     # number of stacked AttentionBlocks
dropout = 0.2      # dropout rate

In [3]:
# Build the project's data-loader wrapper with one image per batch, then take
# its training loader.
vit_loader = ViTDataLoader(
    dataset_path='data',
    batch_size=1,
)
train = vit_loader.get_train_loader()

Files already downloaded and verified


Global seed set to 42


### Pick a batch with a single image from the training data (1 x 3 x 32 x 32)

In [13]:
# Pull the first batch from the training loader and keep its element 0
# (the image tensor, judging by the shape displayed for this cell).
first_batch = next(iter(train))
x = first_batch[0]
x.shape

torch.Size([1, 3, 32, 32])

### Convert the image to patches

![img_to_patch](img_to_patch.png)

In [14]:
# Split the image batch into patches with the project helper.
# Pass the `patch_size` config value instead of repeating the literal 4, so the
# patching cannot drift from `in_features`, which is derived from the same variable.
x = img_to_patch(x, patch_size=patch_size)
x.shape

torch.Size([1, 64, 48])

In [15]:
# Unpack batch size (B), number of patches (T) and the per-patch feature size.
# The third dimension gets a real name rather than `_`: in IPython/Jupyter `_`
# holds the previous cell's output, and assigning to it stops it from updating.
B, T, patch_dim = x.shape
print(B, T)

1 64


In [19]:
# Length of one flattened patch: patch_size x patch_size pixels for each channel.
in_features = patch_size * patch_size * num_channels
print(f"feature vector from each image patch {in_features} ")

feature vector from each image patch 48 


### Input Linear Projection layer

In [17]:
# Learnable linear map from a flattened patch (in_features) to the model width (embed_dim).
input_layer = nn.Linear(in_features, embed_dim)

Project each patch's feature vector of size 48 to `embed_dim` = 256.

In [20]:
# Apply the projection to every patch; only the last dimension changes
# (in_features -> embed_dim). Note this rebinds `x`, so re-running the cell
# would feed already-projected data back into the layer.
x = input_layer(x)
x.shape

torch.Size([1, 64, 256])

### CLS token

In [21]:
# A single classification token of width embed_dim, randomly initialised and
# wrapped as a trainable parameter.
cls_token = nn.Parameter(torch.randn((1, 1, embed_dim)))
cls_token.shape

torch.Size([1, 1, 256])

In [22]:
# Give every sample in the batch its own copy of the CLS token: (1, 1, D) -> (B, 1, D).
# Slicing to the first entry before repeating keeps this cell idempotent; without
# it, re-running the cell would repeat the already-expanded tensor (B*B rows).
cls_token = cls_token[:1].repeat(B, 1, 1)
cls_token.shape

torch.Size([1, 1, 256])

### Prepending the CLS token to the projected input

In [None]:
# Concatenate along the sequence dimension (dim=1) with the CLS token first,
# so it occupies position 0 of every sequence.
x = torch.cat([cls_token, x], dim=1)

In [69]:
# The sequence length should now be T + 1 (patches plus the CLS token).
# NOTE(review): the recorded output shows a batch dimension of 4 although the
# loader above is created with batch_size=1 — the output looks stale; re-run.
x.shape

torch.Size([4, 65, 256])

### Positional embedding layer

In [72]:
# One learnable position vector per sequence slot: num_patches patch slots
# plus one extra slot for the CLS token.
pos_embedding = nn.Parameter(torch.randn((1, num_patches + 1, embed_dim)))
pos_embedding.shape

torch.Size([1, 65, 256])

In [74]:
# Keep only the first T+1 positions so the embedding matches the actual sequence length.
pos_embedding[:, :T+1].shape

torch.Size([1, 65, 256])

### Adding the positional embedding to CLS + input

In [75]:
# Add the positional embedding to every token; its leading dimension of size 1
# broadcasts over the batch.
x = x + pos_embedding[:, :T+1]
x.shape

torch.Size([4, 65, 256])

In [76]:
# Apply dropout to the embedded sequence.
# The layer is bound to its own name and built from the `dropout` rate defined in
# the config cell. Rebinding `dropout` to the nn.Dropout module (as before) would
# overwrite the float hyper-parameter that is later passed to
# AttentionBlock(..., dropout=dropout), which presumably expects the rate itself.
dropout_layer = nn.Dropout(p=dropout)
x = dropout_layer(x)
x.shape

torch.Size([4, 65, 256])

### Transposing to sequence-first layout for self-attention over the embedded vectors

In [78]:
# Swap the batch and sequence dimensions: (B, T+1, D) -> (T+1, B, D).
# presumably AttentionBlock expects sequence-first input — confirm in models/vit.py
x = x.transpose(0, 1)
x.shape

torch.Size([65, 4, 256])

### Multi-Headed Attention

In [83]:
# Stack num_layers identically configured attention blocks and chain them so
# the output of one block feeds the next.
attention_blocks = [
    AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout)
    for _ in range(num_layers)
]
transformer = nn.Sequential(*attention_blocks)

In [84]:
# Run the sequence through all attention blocks in order; the recorded shape
# is the same as the input shape.
x = transformer(x)
x.shape

torch.Size([65, 4, 256])

### Output feature vector of CLS token

In [108]:
# After the transpose the sequence dimension comes first, so index 0 selects
# the CLS position for every sample in the batch.
cls = x[0]
cls.shape

torch.Size([4, 256])

### MLP head for output classification

In [109]:
# Classification head: normalise the CLS feature, then map it to one score per class.
num_classes = 10
mlp_head = nn.Sequential(
    nn.LayerNorm(embed_dim),
    nn.Linear(embed_dim, num_classes),
)

In [112]:
# The head ends in a plain nn.Linear with no softmax, so these are
# unnormalised class scores (logits), one row per sample.
out = mlp_head(cls)
out.shape # raw logits, not probabilities — apply softmax to obtain probabilities

torch.Size([4, 10])

In [113]:
# Inspect the raw scores for each sample.
out

tensor([[-0.1672, -0.8544,  0.0415, -1.3316, -0.2346,  0.0321,  0.1616,  1.1617,
         -0.3525, -0.3639],
        [ 0.2456, -0.3916, -0.4581, -1.0139,  0.2070,  0.5930, -0.6954,  1.1957,
         -0.5270, -0.5197],
        [-0.0993, -0.9147,  0.2171, -0.6132,  0.2205,  0.2926,  0.0089,  0.9484,
         -0.6022, -0.3287],
        [-0.2900, -0.6046, -0.1202, -1.0657,  0.7453,  0.7413, -0.2312,  0.6247,
         -0.2546, -0.1642]], grad_fn=<AddmmBackward0>)

In [None]:
# Load the TensorBoard notebook extension and open it on the saved experiment logs.
%load_ext tensorboard
%tensorboard --logdir saved_models/experiments