# Transformer Encoder

> Modules for building a Vision Transformer (ViT) encoder block

In [None]:
#| default_exp encoder

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch
from torch import nn
import torch.functional as F
from torchvision import datasets
import numpy as np

import yaml
from fastcore.basics import Path

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Data directory and run configuration, both given relative to the
# notebook's current working directory.
DATA_PATH = Path('../input')
CONFIG_PATH = '../config.yml'

Load parameters from the config file. 

In [None]:
# Parse the YAML run configuration. The context manager closes the file
# handle, which a bare open() call would leave dangling. safe_load only builds
# plain Python objects (dicts/lists/scalars), never arbitrary tagged objects.
with open(CONFIG_PATH, encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

In [None]:
# CIFAR-10 training split, stored under DATA_PATH (downloaded on first run).
dset = datasets.CIFAR10(root=DATA_PATH, train=True, download=True)

Files already downloaded and verified


In [None]:
# Pull the raw image array and the label sequence out of the dataset object.
images = dset.data
targets = dset.targets
len(images), len(targets)

(50000, 50000)

Prepare a small batch of images to test the image processing.

In [None]:
# Shape of the raw array. Later cells treat it as channels-last (N, H, W, C)
# and permute it with (0, 3, 1, 2) before resizing.
images.shape

(50000, 32, 32, 3)

Sample a few random indices and use them to select a small batch of images for testing the pipeline.

In [None]:
# Three indices in [0, len(images)), drawn independently (repeats are possible).
image_idx = np.random.randint(0, len(images), 3)

In [None]:
# Labels of the sampled images. Note: this rebinds `targets` to just these
# few labels, so re-running the cell indexes the short list, not the full one.
targets = [targets[int(i)] for i in image_idx]
targets

[7, 5, 1]

In [None]:
# Input/output channel counts of the patch projection, read from the config.
in_ch, out_ch = config["patch"]["in_ch"], config["patch"]["out_ch"]

In [None]:
# size of each small patch, read from config['patch']['size'] (presumably the
# side length in pixels of a square patch; TODO confirm in PatchEmbedding)
patch_size = config['patch']['size']
patch_size

16

In [None]:
# Shape of a single image, i.e. everything after the leading sample axis.
images.shape[1:]

(32, 32, 3)

In [None]:
# Gather the sampled images into a float tensor and divide by 255. For uint8
# pixels (assumed here) this maps values into [0, 1].
images = torch.Tensor(images[image_idx]) / 255.
images.shape

torch.Size([3, 32, 32, 3])

Upscale the images to $224\times 224$ to match the input resolution used in the ViT paper.

In [None]:
#| export
import torchvision.transforms as T

In [None]:
# Target resolution for the resized images (224 x 224 per the output below).
hw = config['data']['hw']
# Pass `antialias` explicitly. torchvision's default for this argument changed
# across releases (None -> True), so leaving it unset makes results depend on
# the installed version. Here the images are upscaled (32 -> 224), where
# antialiasing should have no visible effect.
augs = T.Resize(hw, antialias=True)
augs

Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)

In [None]:
# Move the channel axis forward, (N, H, W, C) -> (N, C, H, W), so that Resize
# acts on the last two (spatial) axes, as the output shape below shows.
images = images.permute(0, 3, 1, 2)
images = augs(images)
images.shape

torch.Size([3, 3, 224, 224])

# Make Patch Embeddings

In [None]:
from vit_pytorch.patch import PatchEmbedding

In [None]:
# Turn the resized batch into a sequence of token embeddings. The output below
# shows 197 tokens of width 768 per image, while config['patch']['n'] is 196.
# Presumably that is 196 patches + 1 class token (TODO confirm against
# vit_pytorch.patch.PatchEmbedding).
patch_embed = PatchEmbedding(config)(images)
patch_embed.shape

torch.Size([3, 197, 768])

# Prepare Transformer Layer

Apply LayerNorm over the embedding dimension, which in our case is $768$.

In [None]:
# Number of patches and the per-token embedding width, both from the config.
seq_len, embed_dim = config['patch']['n'], config['patch']['out_ch']
seq_len, embed_dim

(196, 768)

In [None]:
# Normalise every token over its last (embed_dim) axis using a freshly built LayerNorm.
x_ln = nn.LayerNorm(embed_dim)(patch_embed)
x_ln.shape

torch.Size([3, 197, 768])

In [None]:
# Number of heads for the multi-head self-attention layer built below.
num_heads = config['encoder']['msa_heads']

In [None]:
# batch_first=True: x_ln is (batch, tokens, embed_dim) = (3, 197, 768). With
# the default layout, nn.MultiheadAttention reads dim 0 as the sequence axis,
# so the 3 images would attend to each other instead of to their own tokens.
attn_output, attn_output_weights = nn.MultiheadAttention(
    embed_dim=embed_dim, num_heads=num_heads, batch_first=True
)(x_ln, x_ln, x_ln)

In [None]:
# The attention output has the same shape as its input x_ln.
attn_output.shape

torch.Size([3, 197, 768])

## Prepare MSA block

In [None]:
class MultiheadSelfAttn(nn.Module):
    """Pre-norm multi-head self-attention sub-block with a residual connection.

    ``forward`` computes ``x + MSA(LayerNorm(x))`` and keeps the input shape.
    Inputs are batch-first: ``(batch, tokens, embed_dim)``.

    Args:
        config: parsed config dict. Reads ``config['patch']['n']``,
            ``config['patch']['out_ch']`` (embedding width) and
            ``config['encoder']['msa_heads']`` (number of attention heads).
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.seq_len = config['patch']['n']  # stored for reference; unused in forward
        self.embed_dim = config['patch']['out_ch']
        self.num_heads = config['encoder']['msa_heads']
        # Build the layers from the config-derived attributes, not from
        # same-named notebook globals. Otherwise the module silently ignores
        # its config and cannot be constructed outside this notebook.
        self.ln = nn.LayerNorm(normalized_shape=self.embed_dim)
        # batch_first=True: tokens arrive as (batch, tokens, embed_dim). The
        # default (tokens, batch, embed_dim) layout would make the images in
        # a batch attend to each other rather than to their own tokens.
        self.msa = nn.MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            batch_first=True,
        )

    def forward(self, x):
        """Return ``x + MSA(LayerNorm(x))``; the output has the shape of ``x``."""
        x_ln = self.ln(x)
        x_msa, _ = self.msa(x_ln, x_ln, x_ln)
        return x_msa + x

In [None]:
# Run the patch embeddings through a freshly initialised attention sub-block;
# its residual connection keeps the input shape, as the output below shows.
x = MultiheadSelfAttn(config)(patch_embed)
x.shape

torch.Size([3, 197, 768])

## Prepare MLP block

In [None]:
class MLPBlock(nn.Module):
    """Pre-norm feed-forward sub-block with a residual connection.

    ``forward`` computes ``x + MLP(LayerNorm(x))``, where MLP is
    Linear(embed_dim -> mlp_size) -> GELU -> Dropout ->
    Linear(mlp_size -> embed_dim) -> Dropout.

    Args:
        config: parsed config dict. Reads ``config["patch"]["out_ch"]``
            (embedding width), ``config["encoder"]["mlp_size"]`` (hidden
            width) and ``config["encoder"]["mlp_dropout"]`` (dropout prob).
    """

    def __init__(self, config) -> None:
        super().__init__()
        width = config["patch"]["out_ch"]
        hidden = config["encoder"]["mlp_size"]
        p_drop = config["encoder"]["mlp_dropout"]
        self.ln = nn.LayerNorm(normalized_shape=width)
        # Attribute names and layer order are kept so state_dict keys
        # (ln.*, linear.0.*, linear.3.*) stay the same.
        layers = [
            nn.Linear(in_features=width, out_features=hidden),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(in_features=hidden, out_features=width),
            nn.Dropout(p_drop),
        ]
        self.linear = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + MLP(LayerNorm(x))``; the output has the shape of ``x``."""
        normed = self.ln(x)
        projected = self.linear(normed)
        return projected + x


In [None]:
# Feed the attention output through a freshly initialised MLP sub-block;
# the residual connection keeps the shape unchanged.
x = MLPBlock(config)(x)
x.shape

torch.Size([3, 197, 768])

## Transformer Encoder

In [None]:
class TransformerEncoder(nn.Module):
    """One transformer encoder layer: self-attention sub-block, then MLP sub-block.

    ``forward`` applies ``MultiheadSelfAttn`` followed by ``MLPBlock``. Each
    adds its own residual connection, so the output shape matches the input.

    Args:
        config: parsed config dict, passed unchanged to both sub-blocks.
    """

    def __init__(self, config) -> None:
        super().__init__()
        # The first sub-block must be the attention block: with an MLP here
        # the layer has no attention at all, and tokens never exchange
        # information.
        self.msa_block = MultiheadSelfAttn(config)
        self.mlp_block = MLPBlock(config)

    def forward(self, x):
        """Apply attention then MLP; the output has the shape of ``x``."""
        x = self.msa_block(x)
        return self.mlp_block(x)

In [None]:
# Run a full (freshly initialised) encoder layer on the patch embeddings;
# the output shape matches the input, as shown below.
out = TransformerEncoder(config)(patch_embed)
out.shape

torch.Size([3, 197, 768])

In [None]:
#| hide
# Write the notebook's `#| export` cells out to the library module.
import nbdev
nbdev.nbdev_export()