# ViT model

> Putting together patch embeddings and transformer encoder

In [None]:
#| default_exp model

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch
from torch import nn
import torch.functional as F
from torchvision import datasets
import numpy as np

import yaml
from fastcore.basics import Path

In [None]:
# Locations of the YAML config and the dataset root, relative to the
# directory the notebook is run from.
CONFIG_PATH = "../config.yml"
DATA_PATH = Path("..") / "input"

Load the model and data parameters from the YAML config file.

In [None]:
# Parse the YAML config into plain Python objects (`safe_load` does not build
# arbitrary objects). The context manager closes the file handle as soon as
# parsing finishes instead of leaving it to the garbage collector.
with open(CONFIG_PATH, encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

In [None]:
# CIFAR-10 training split (train=True is torchvision's default). It is
# downloaded into DATA_PATH if missing; otherwise the existing files are verified.
dset = datasets.CIFAR10(root=DATA_PATH, train=True, download=True)

Files already downloaded and verified


In [None]:
# Full dataset arrays: raw images and their integer class labels.
images = dset.data
targets = dset.targets
len(images), len(targets)

(50000, 50000)

Check the raw image array layout before preparing a small batch of images to test the preprocessing.

In [None]:
# Raw layout is (N, H, W, C): channels last.
np.shape(images)

(50000, 32, 32, 3)

Sample a few random indices to select a small batch of images for a quick forward-pass test (not for training).

In [None]:
# Draw 3 distinct image indices for the demo batch.
# - The seeded Generator makes the batch reproducible between runs.
# - `replace=False` rules out picking the same image twice.
# - Sizing from `dset` rather than `images` keeps this cell safe to re-run:
#   `images` is rebound to a 3-image tensor in a later cell.
rng = np.random.default_rng(seed=0)
image_idx = rng.choice(len(dset), size=3, replace=False)

In [None]:
# Labels of the sampled images. Index the dataset's own label list rather
# than `targets`, which this cell rebinds; otherwise re-running it would index
# the 3-element sample with full-dataset indices and raise IndexError.
targets = [dset.targets[t] for t in image_idx]
targets

[3, 6, 2]

In [None]:
# Number of classifier outputs, taken from the model section of the config.
model_cfg = config["model"]
n_classes = model_cfg["n_classes"]
n_classes

10

# Putting together PatchEmbedding and TransformerEncoder

In [None]:
#| export
from vit_pytorch.patch import PatchEmbedding
from vit_pytorch.encoder import TransformerEncoder
import torchvision.transforms as T
from einops import repeat

In [None]:
# Build a small float32 batch scaled to [0, 1] from the raw (N, H, W, C)
# arrays. Read from `dset.data` rather than `images`, because `images` is
# rebound to the resized tensor below. Re-running this cell then still works.
batch = torch.from_numpy(dset.data[image_idx]).float() / 255.
hw = config['data']['hw']
augs = T.Resize(hw)

# torchvision transforms expect channels-first (N, C, H, W) tensors.
images = augs(batch.permute(0, 3, 1, 2))
images.shape



torch.Size([3, 3, 224, 224])

In [None]:
# | export


class VisionTransformer(nn.Module):
    """Vision Transformer classifier: patch embedding -> encoder -> MLP head.

    Logits are computed from the first token of the encoder output. That
    token is assumed to be the class token prepended by `PatchEmbedding`
    (confirm in `vit_pytorch.patch`).

    Config keys read directly here:
        model.n_classes: number of output logits.
        model.training: truthy -> head is LayerNorm, Linear, GELU, Dropout,
            Linear; falsy -> head is LayerNorm, Linear.
        model.clf_dropout: dropout probability in the hidden-layer head.
        model.clf_hidden_units: width of the hidden-layer head.
        patch.out_ch: token embedding dimension.
    The whole `config` is also forwarded to `PatchEmbedding` and
    `TransformerEncoder`.

    Attributes set by `forward` (None before the first call):
        embeddings_: encoder outputs for every token after the first.
        cls_tokens_: encoder output for the first token, shape
            (batch, emb_dim).
    """

    def __init__(self, config: dict) -> None:
        super().__init__()
        n_classes = config["model"]["n_classes"]
        training = config["model"]["training"]
        emb_dim = config["patch"]["out_ch"]
        dropout = config["model"]["clf_dropout"]
        hidden_units = config["model"]["clf_hidden_units"]
        self.patch_embedding = PatchEmbedding(config)
        self.transformer_encoder = TransformerEncoder(config)
        # Classification head. `self.ln` is also the first layer of
        # `self.mlp`, so the same LayerNorm is registered twice and appears
        # under both `ln.*` and `mlp.0.*` in the state dict. Removing either
        # registration would change the state-dict keys.
        self.ln = nn.LayerNorm(normalized_shape=emb_dim)
        if training:
            mlp_layers = [
                nn.Linear(emb_dim, hidden_units),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_units, n_classes),
            ]
        else:
            mlp_layers = [nn.Linear(emb_dim, n_classes)]
        self.mlp = nn.Sequential(self.ln, *mlp_layers)
        # Representations captured on the most recent forward pass.
        self.embeddings_ = None
        self.cls_tokens_ = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, n_classes) for the image batch `x`.

        `x` is passed unchanged to `PatchEmbedding`. The demo cells in this
        notebook use a (batch, channels, H, W) float tensor.
        """
        x = self.patch_embedding(x)
        x = self.transformer_encoder(x)  # (batch, n_tokens, emb_dim)
        # All tokens after the first: the learned per-patch embeddings.
        self.embeddings_ = x[:, 1:, :]
        # The first token is the one concatenated in the patch embedding
        # (same as the `cls_token` attribute of PatchEmbedding). Integer
        # indexing drops the token axis, so this is (batch, emb_dim).
        self.cls_tokens_ = x[:, 0, :]
        # NOTE: both cached tensors stay attached to the autograd graph
        # until the next forward call overwrites them.
        return self.mlp(self.cls_tokens_)


In [None]:
# Instantiate the full model from the loaded config.
vit = VisionTransformer(config=config)

In [None]:
# Forward the demo batch; the result should be (batch, n_classes) logits.
outs = vit(images)
outs.size()

torch.Size([3, 10])

In [None]:
# Patch-token embeddings cached by the last forward pass.
vit.embeddings_.size()

torch.Size([3, 196, 768])

In [None]:
# Class-token outputs cached by the last forward pass: (batch, emb_dim).
vit.cls_tokens_.size()

torch.Size([3, 768])

In [None]:
#| hide
# Regenerate the library modules from the notebooks' exported cells.
import nbdev

nbdev.nbdev_export()