In [None]:
# Load the TensorBoard notebook extension, then open the dashboard on the experiment logs.
%load_ext tensorboard
%tensorboard --logdir saved_models/experiments

In [79]:
from datasets import ViTDataLoader
from utils import img_to_patch
import torch.nn as nn
import torch
from models.vit import AttentionBlock

In [81]:
# Hyperparameters for the step-by-step ViT forward pass below.
image_size = 32      # side length of the square inputs (batches are [B, 3, 32, 32])
patch_size = 4       # side length of each square patch
embed_dim = 256      # width of every token embedding
num_channels = 3     # channels per image
# Derived instead of hard-coded so the positional-embedding table stays in sync
# with patch_size: (32 // 4) ** 2 == 64.
num_patches = (image_size // patch_size) ** 2
hidden_dim = 512     # hidden size handed to each AttentionBlock
num_heads = 8        # attention heads per block
dropout = 0.2        # dropout probability (keep this name bound to the float rate)
num_layers = 6       # number of stacked AttentionBlocks

In [30]:
# Build the project's data-loader wrapper (dataset_path='data', mini-batches of 4)
# and fetch the training loader from it.
vit_loader = ViTDataLoader(dataset_path='data', batch_size=4)
train = vit_loader.get_train_loader()

Files already downloaded and verified


Global seed set to 42


### Pick a batch from training data

In [58]:
# Pull the first mini-batch from the training loader; element 0 is the image tensor.
first_batch = next(iter(train))
x = first_batch[0]

In [59]:
# Raw image batch: [batch=4, channels=3, 32, 32].
x.shape

torch.Size([4, 3, 32, 32])

### Convert images to patches

In [60]:
# Turn each image into a sequence of flattened patches ([B, 64, 48] for this batch).
# Pass the configured patch_size rather than repeating the literal, so the patching
# step cannot drift away from the value used to size the projection layer.
x = img_to_patch(x, patch_size=patch_size)

In [61]:
# Patch sequence: [batch, number of patches, num_channels * patch_size**2].
x.shape

torch.Size([4, 64, 48])

In [62]:
# B = batch size, T = number of patch tokens; the per-patch feature size is not needed.
# Unpacking three values also checks that x is still a rank-3 tensor.
B, T, _ = x.size()
print(B, T)

4 64


In [63]:
# Flattened size of one patch: patch_size x patch_size pixels for every channel.
in_features = num_channels * patch_size * patch_size
in_features

48

### Input Linear Projection layer

In [64]:
# Linear projection that lifts each flattened patch to the embedding width.
input_layer = nn.Linear(in_features, embed_dim)

In [65]:
# Project every patch to embed_dim; only the last dimension changes (48 -> 256).
# Note: this rebinds x, so re-running the cell on its own feeds the projected tensor back in.
x = input_layer(x)
x.shape

torch.Size([4, 64, 256])

### CLS token

In [66]:
# Learnable classification token, randomly initialised, with a size-1 batch dimension.
cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
cls_token.shape

torch.Size([1, 1, 256])

In [67]:
# Broadcast the single CLS token across the batch. `expand` is used instead of
# `repeat` for two reasons: it creates a view rather than a copy, and the cell is
# safe to re-run -- an already expanded [B, 1, embed_dim] tensor stays that shape
# instead of growing to [B*B, 1, embed_dim].
cls_token = cls_token.expand(B, -1, -1)
cls_token.shape

torch.Size([4, 1, 256])

### Prepending the CLS token to the projected patches

In [None]:
# Concatenate along the sequence dimension with the CLS token first, so it sits at
# position 0 of every sequence.
x = torch.cat([cls_token, x], dim=1)

In [69]:
# Sequence length grows by one: 64 patch tokens + 1 CLS token = 65.
x.shape

torch.Size([4, 65, 256])

### Positional embedding layer

In [72]:
# Learnable positional embedding: one embed_dim vector per position (CLS + num_patches),
# with a size-1 leading dimension so it can broadcast over the batch.
pos_embedding = nn.Parameter(torch.randn(1, 1+num_patches, embed_dim))
pos_embedding.shape

torch.Size([1, 65, 256])

In [74]:
# Slice the table to the actual sequence length (T patches + 1 CLS); here T + 1 == 65,
# so the slice covers the whole table.
pos_embedding[:, :T+1].shape

torch.Size([1, 65, 256])

### Adding the positional embedding to the CLS + patch tokens

In [75]:
# Add position information; the size-1 batch dimension of the embedding broadcasts
# across all B samples, so the shape of x is unchanged.
x = x + pos_embedding[:, :T+1]
x.shape

torch.Size([4, 65, 256])

In [76]:
dropout = nn.Dropout(p=0.2)

In [77]:
x = dropout(x)
x.shape

torch.Size([4, 65, 256])

In [78]:
# Swap the batch and sequence axes: [B, T+1, E] -> [T+1, B, E].
# Presumably AttentionBlock expects sequence-first input -- confirm against models/vit.
x = x.transpose(0, 1)
x.shape

torch.Size([65, 4, 256])

### Transformer encoder (stacked multi-head attention blocks)

In [83]:
# Build num_layers identically configured attention blocks and chain them.
# NOTE(review): `dropout` is expected to be the float rate from the config cell here.
encoder_blocks = [
    AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout)
    for _ in range(num_layers)
]
transformer = nn.Sequential(*encoder_blocks)

In [84]:
# Pass the token sequence through all stacked blocks; the shape stays [T+1, B, E].
x = transformer(x)
x.shape

torch.Size([65, 4, 256])

In [None]:
### Classification head on the CLS token

In [108]:
# x is sequence-first at this point, so index 0 selects the CLS position for every
# sample in the batch: [B, embed_dim].
cls = x[0]
cls.shape

torch.Size([4, 256])

In [109]:
# Classification head for the CLS representation: LayerNorm followed by a linear
# projection to one logit per class. The class count is named rather than left as
# an inline magic number.
num_classes = 10
mlp_head = nn.Sequential(
    nn.LayerNorm(embed_dim),
    nn.Linear(embed_dim, num_classes),
)

In [110]:
# Per-sample class logits: [B, 10].
out = mlp_head(cls)
out.shape

torch.Size([4, 10])