# ViT model

> Putting together patch embeddings and transformer encoder

In [None]:
#| default_exp model

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch
from torch import nn
import torch.functional as F
from torchvision import datasets
import numpy as np

import yaml
from fastcore.basics import Path

In [None]:
# Both locations are relative to the directory the notebook is run from.
CONFIG_PATH: str = '../config.yml'  # YAML config parsed with yaml.safe_load
DATA_PATH: Path = Path('../input')  # root directory handed to datasets.CIFAR10

Load parameters from the config file. 

In [None]:
# Parse the YAML config; the context manager closes the file handle afterwards
# (the bare open() call previously left it open until garbage collection).
with open(CONFIG_PATH, encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

In [None]:
# CIFAR-10 dataset rooted at DATA_PATH, with download enabled.
dset = datasets.CIFAR10(root=DATA_PATH, download=True)

In [None]:
# Raw image array and label list exposed by the dataset object.
images = dset.data
targets = dset.targets
len(images), len(targets)

Prepare a small batch of images to test the image processing.

In [None]:
# Shape of the full raw image array, before any sampling or preprocessing.
images.shape

Sample a few random indices and use them to select a small batch of images (and their labels) for a quick test forward pass.

In [None]:
# Three independent random indices in [0, len(images)); duplicates are possible.
image_idx = np.random.randint(0, len(images), 3)

In [None]:
# Labels of the sampled images. Note: this rebinds `targets` from the full
# label list to just the batch labels, so the cell is not safe to re-run alone.
targets = [targets[idx] for idx in image_idx]
targets

In [None]:
# Number of output classes configured for the classification head.
n_classes = config['model']['n_classes']
n_classes

# Putting together PatchEmbedding and TransformerEncoder

In [None]:
#| export
from vit_pytorch.patch import PatchEmbedding
from vit_pytorch.encoder import TransformerEncoder
import torchvision.transforms as T

In [None]:
# Build a float batch from the sampled images, scale to [0, 1] and resize.
# Assumes the raw array is (N, H, W, C) with 0..255 values (hence the /255 and
# the permute that moves channels to dim 1) -- confirm against dset.data.
batch = torch.Tensor(images[image_idx]) / 255.
batch = batch.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
hw = config['data']['hw']
augs = T.Resize(hw)
images = augs(batch)
images.shape

In [None]:
# | export


class VisionTransformer(nn.Module):
    """ViT classifier: patch embedding -> transformer encoder -> classification head.

    The head reads the encoder output at the class-token position (token 0),
    so the prediction depends on what that token attended to in the image.

    Args:
        config: parsed config dict. Uses ``config["model"]`` keys ``n_classes``,
            ``training``, ``clf_dropout`` and ``clf_hidden_units``, and
            ``config["patch"]["out_ch"]`` as the embedding width; the whole dict
            is also handed to ``PatchEmbedding`` and ``TransformerEncoder``.

    Attributes:
        representation_: patch-token outputs (tokens 1..) of the last forward call.
        class_token_: class-token output (token 0, token axis kept) of the last
            forward call.
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        n_classes = config["model"]["n_classes"]
        training = config["model"]["training"]
        emb_dim = config["patch"]["out_ch"]
        dropout = config["model"]["clf_dropout"]
        hidden_units = config["model"]["clf_hidden_units"]
        self.patch_embedding = PatchEmbedding(config)
        self.transformer_encoder = TransformerEncoder(config)
        # Classification head: LayerNorm, then a one-hidden-layer GELU MLP when
        # config["model"]["training"] is truthy, otherwise a single linear layer.
        self.ln = nn.LayerNorm(emb_dim)
        mlp_layers = (
            [
                nn.Linear(emb_dim, hidden_units),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_units, n_classes),
            ]
            if training
            else [nn.Linear(emb_dim, n_classes)]
        )
        # self.ln is also registered as self.mlp[0]; kept that way so existing
        # state_dict keys ("ln.*" and "mlp.0.*") still load.
        self.mlp = nn.Sequential(self.ln, *mlp_layers)
        self.representation_ = None
        self.class_token_ = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of images.

        Args:
            x: image batch accepted by ``PatchEmbedding`` (presumably
                ``(bs, C, H, W)`` -- confirm against PatchEmbedding).

        Returns:
            Logits computed from the encoded class token, shape
            ``(bs, 1, n_classes)``; use ``.squeeze(1)`` for a ``(bs, n_classes)``
            loss input.
        """
        x = self.patch_embedding(x)
        x = self.transformer_encoder(x)
        # Token 0 is the class token prepended by PatchEmbedding; the remaining
        # tokens are the patch tokens, kept as the learned representation.
        self.representation_ = x[:, 1:, :]
        # Classify the class token as it comes OUT of the encoder. Feeding the raw
        # patch_embedding.class_token parameter here would skip the encoder
        # entirely, so the logits would not depend on the image tokens.
        # Slicing with :1 (not 0) keeps the token axis.
        self.class_token_ = x[:, :1, :]
        return self.mlp(self.class_token_)


In [None]:
# Instantiate the model from the same config dict loaded above.
vit = VisionTransformer(config=config)

In [None]:
# Forward pass on the preprocessed test batch; the cell displays the logits.
vit(images) 

In [None]:
# Shape of the patch-token outputs stored by the last forward call.
vit.representation_.shape

In [None]:
# Class-token tensor stored by the last forward call (the input to the head).
vit.class_token_

In [None]:
#| hide
import nbdev

# Export this notebook's `#| export` cells with nbdev.
nbdev.nbdev_export()