In [1]:
import torch
import torch.nn as nn

from utils import img_to_patch
from datasets import ViTDataLoader
from models.vit import AttentionBlock

In [23]:
# Model / patching hyperparameters for the shape walkthrough below.
patch_size = 4
embed_dim = 256
num_channels = 3
image_size = 32  # input images are 32x32 (see the batch shape printed below)
assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
# Derived rather than hard-coded so it stays consistent if patch_size changes.
num_patches = (image_size // patch_size) ** 2
hidden_dim = 512  # presumably the MLP width inside each AttentionBlock — confirm
num_heads = 8
dropout = 0.2  # dropout *rate* (a float), passed to AttentionBlock further down
num_layers = 6

In [3]:
# Single-image batches keep the shape walkthrough below easy to follow.
vit_loader = ViTDataLoader(batch_size=1, dataset_path='data')
train = vit_loader.get_train_loader()

Files already downloaded and verified


Global seed set to 42


### Pick a single-image batch from the training data

In [4]:
# Take the first element of the first batch (presumably the images, with labels discarded).
first_batch = next(iter(train))
x = first_batch[0]
x.shape

torch.Size([1, 3, 32, 32])

### Split the image into patches

![img_to_patch](img_to_patch.png)

In [5]:
# Split the image into patch_size x patch_size patches; judging by the printed shape,
# each patch comes back flattened to num_channels * patch_size**2 values.
# Use the config value rather than a hard-coded 4 so changing patch_size stays consistent.
x = img_to_patch(x, patch_size=patch_size)
x.shape

torch.Size([1, 64, 48])

In [6]:
# Batch size (B) and number of patches, i.e. sequence length (T).
# Use a named throwaway rather than `_`: assigning `_` at notebook top level shadows
# IPython's last-output variable and stops IPython from updating `_`, `__`, `___`.
B, T, _patch_dim = x.shape
print(B, T)

1 64


In [7]:
# Length of one flattened patch: every pixel of a patch_size x patch_size window, per channel.
in_features = patch_size * patch_size * num_channels
print(f"feature vector from each image patch {in_features} ")

feature vector from each image patch 48 


### Input Linear Projection layer

In [8]:
# Linear patch embedding: maps each in_features-long patch vector to embed_dim.
input_layer = nn.Linear(in_features, embed_dim)

Project each patch's 48-dimensional feature vector to `embed_dim` = 256.

In [9]:
# Apply the projection to every patch: last dim goes from in_features to embed_dim.
patch_embeddings = input_layer(x)
x = patch_embeddings
x.shape

torch.Size([1, 64, 256])

### CLS token

In [10]:
# Learnable CLS token: a single embed_dim vector, later broadcast over the batch.
cls_token = nn.Parameter(torch.randn((1, 1, embed_dim)))
cls_token.shape

torch.Size([1, 1, 256])

In [11]:
# Broadcast the token across the batch -> (B, 1, embed_dim).
# `expand` returns a view (no copy) and, unlike `repeat(B, 1, 1)`, gives the same
# result if this cell is re-run (repeat would grow the batch to B*B when B > 1).
# Note: this rebinds `cls_token` from the nn.Parameter to a derived tensor.
cls_token = cls_token.expand(B, -1, -1)
cls_token.shape

torch.Size([1, 1, 256])

### Prepend the CLS token to the projected input sequence

In [12]:
# Prepend the CLS token along the sequence axis: T patches become T + 1 positions.
x = torch.cat((cls_token, x), dim=1)

In [13]:
x.size()

torch.Size([1, 65, 256])

### Positional embedding layer

In [14]:
# Learnable positional embedding: one vector per position, including the CLS slot.
pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
pos_embedding.shape

torch.Size([1, 65, 256])

In [15]:
pos_embedding[:, : T + 1].size()

torch.Size([1, 65, 256])

### Add positional embeddings to the CLS + patch sequence

In [16]:
# Take positional vectors for the current sequence length (CLS + T patches) and add them.
pos = pos_embedding[:, : T + 1]
x = x + pos
x.shape

torch.Size([1, 65, 256])

In [17]:
dropout = nn.Dropout(p=0.2)

In [18]:
x = dropout(x)
x.shape

torch.Size([1, 65, 256])

### Transpose to (sequence, batch, embedding) layout before self-attention

In [19]:
# Swap batch and sequence axes: (B, T + 1, D) -> (T + 1, B, D), as the printed shape shows.
x = x.transpose(1, 0)
x.shape

torch.Size([65, 1, 256])

### Transformer encoder (a stack of `num_layers` attention blocks)

In [25]:
# Stack num_layers AttentionBlocks and apply them one after another.
# NOTE(review): `dropout` must still be the float rate from the config cell at this point;
# presumably AttentionBlock expects a rate, not an nn.Dropout module — confirm.
transformer = nn.Sequential(
    *(AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers))
)

In [26]:
# Run the encoder stack; the printed shape shows the (T + 1, B, D) layout is preserved.
encoded = transformer(x)
x = encoded
x.shape

torch.Size([65, 1, 256])

### The CLS token's output vector serves as the image representation

In [27]:
# Position 0 on the sequence axis is the CLS token -> (B, D).
cls = x.select(0, 0)
cls.shape

torch.Size([1, 256])

### MLP head for output classification

In [28]:
# Classification head on the CLS representation: LayerNorm, then a linear map to class logits.
num_classes = 10  # NOTE(review): assumes a 10-class dataset (the original hard-coded value) — confirm
mlp_head = nn.Sequential(
    nn.LayerNorm(embed_dim),
    nn.Linear(embed_dim, num_classes),
)

In [29]:
# Class scores for the single image. These are unnormalized logits, not probabilities:
# the head ends in nn.Linear with no softmax. Use out.softmax(dim=-1) for probabilities.
out = mlp_head(cls)
out.shape

torch.Size([1, 10])

In [30]:
# Raw logits from the head (no softmax applied), so negative entries are expected.
out

tensor([[ 0.6173,  0.6003, -0.2451, -0.1761,  0.9034,  0.4863,  0.8357, -0.8066,
          0.2395, -0.5247]], grad_fn=<AddmmBackward0>)

In [None]:
# Optional: inspect training runs in TensorBoard. The log directory is relative to the
# notebook's working directory; none of the cells above write to it.
%load_ext tensorboard
%tensorboard --logdir saved_models/experiments