In [None]:
%pip install torch

In [3]:
import torch
import torch.nn as nn

In [None]:
class PatchEmbed(nn.Module):
    """
    Split a square image into non-overlapping patches and embed each patch.

    A single Conv2d whose kernel size and stride both equal ``patch_size`` performs
    the patch splitting and the linear projection of every patch in one step.

    Parameters:
    - img_size: Side length of the input image (the image has to be square)
    - in_channels: Number of input channels
    - patch_size: Side length of each square patch
    - emb_size: Dimension of the embedding produced for every patch

    Attributes:
    - img_size, patch_size: The constructor arguments, stored as given
    - n_patches: Number of patches per image, (img_size // patch_size) ** 2
    - projection: Conv2d that does both the splitting into patches and the embedding
    """

    def __init__(self, img_size, in_channels=3, patch_size=16, emb_size=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.projection = nn.Conv2d(
            in_channels,
            emb_size,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """
        Embed every patch of a batch of images.

        Parameters:
        x : torch.Tensor of shape (n_samples, in_channels, img_size, img_size)

        Returns:
        torch.Tensor of shape (n_samples, n_patches, emb_size)
        """
        # (n_samples, emb_size, img_size // patch_size, img_size // patch_size)
        patch_grid = self.projection(x)
        # Collapse the spatial grid into one token axis, then put tokens before features.
        return patch_grid.flatten(start_dim=2).transpose(1, 2)
    
class Attention(nn.Module):
    """
    Multi-head self-attention module.

    Parameters:
    - emb_size: Size of the embedding (input and output dimension of the per-token features)
    - n_heads: Number of attention heads; emb_size must be divisible by n_heads
    - drop_prob_o: Dropout rate applied to the output of the final projection (proj_qkv)
    - drop_prob_qkv: Dropout rate applied to the attention weights, right after the softmax
    - qkv_bias: Whether to include a bias in the joint qkv projection layer

    Attributes:
    - head_dim: Feature size of each head, emb_size // n_heads
    - scales: Precomputed scaling factor 1 / sqrt(head_dim) applied to the q.k dot products
    - qkv: Linear layer producing queries, keys and values for all heads in one pass
    - proj_qkv: Linear layer mapping the concatenated outputs of all heads back to emb_size
    - dropout_qkv: Dropout on the attention weights
    - dropout_o: Dropout on the projected output
    """

    def __init__(self, emb_size, n_heads=12, drop_prob_o=0., drop_prob_qkv=0., qkv_bias=True):
        super().__init__()
        self.emb_size = emb_size
        self.n_heads = n_heads
        self.head_dim = emb_size // n_heads
        assert self.head_dim * n_heads == emb_size, "Embedding size should be divisible by the number of heads"
        self.scales = self.head_dim ** -0.5
        self.qkv = nn.Linear(emb_size, emb_size * 3, bias=qkv_bias)  # q, k, v stacked on the last dim
        self.dropout_qkv = nn.Dropout(drop_prob_qkv)
        self.proj_qkv = nn.Linear(emb_size, emb_size)
        self.dropout_o = nn.Dropout(drop_prob_o)

    def forward(self, x):
        """
        Perform the forward pass of the Attention module.

        Parameters:
        x : torch.Tensor of shape (n_samples, n_tokens, emb_size); in the ViT n_tokens = n_patches + 1

        Returns:
        torch.Tensor of shape (n_samples, n_tokens, emb_size)

        Raises:
        ValueError: if the last dimension of x is not emb_size
        """
        n_samples, n_tokens, emb_size = x.shape
        if emb_size != self.emb_size:
            raise ValueError(f"Attention expects {self.emb_size} input features, but got {emb_size}")

        # (n_samples, n_tokens, 3 * emb_size) -> (n_samples, n_tokens, 3, n_heads, head_dim)
        qkv = self.qkv(x).reshape(n_samples, n_tokens, 3, self.n_heads, self.head_dim)
        # -> (3, n_samples, n_heads, n_tokens, head_dim) so q, k, v can be taken off the first axis
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product scores: (n_samples, n_heads, n_tokens, n_tokens)
        dots = (q @ k.transpose(-2, -1)) * self.scales
        attn = dots.softmax(dim=-1)
        attn = self.dropout_qkv(attn)

        # Weighted sum of values per head, then concatenate the heads:
        # (n_samples, n_heads, n_tokens, head_dim) -> (n_samples, n_tokens, emb_size)
        weighted_avg = (attn @ v).transpose(1, 2).flatten(2)

        x = self.proj_qkv(weighted_avg)
        x = self.dropout_o(x)
        return x
        
class MLP(nn.Module):
    """
    Position-wise feed-forward network used inside each transformer block.

    Applies fc1 -> GELU -> dropout -> fc2 -> dropout to the last dimension of the input.

    Parameters:
    - in_features: Number of input features
    - hidden_features: Number of hidden-layer features
    - out_features: Number of output features
    - drop_prob: Dropout rate, used after the activation and again after fc2

    Attributes:
    - fc1: nn.Linear(in_features, hidden_features)
    - act_fn: nn.GELU()
    - fc2: nn.Linear(hidden_features, out_features)
    - dropout: nn.Dropout(drop_prob), called for both dropout steps (it has no parameters)
    """

    def __init__(self, in_features, hidden_features, out_features, drop_prob=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act_fn = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        """
        Perform the forward pass of the MLP module.

        Parameters:
        x : torch.Tensor of shape (..., in_features), e.g. (n_samples, n_patches + 1, in_features)

        Returns:
        torch.Tensor of shape (..., out_features). Inside Block, out_features == in_features == emb_size.
        """
        # hidden_features after fc1/act/dropout, out_features after fc2/dropout
        pipeline = (self.fc1, self.act_fn, self.dropout, self.fc2, self.dropout)
        for stage in pipeline:
            x = stage(x)
        return x
    
class Block(nn.Module):
    """
    Transformer encoder block with pre-normalization and residual connections.

    Computes x = x + attn(norm1(x)) followed by x = x + mlp(norm2(x)).

    Parameters:
    - emb_size: Size of the input embedding
    - n_heads: Number of attention heads
    - mlp_ratio: Multiplier for the hidden dim of the MLP wrt the input embedding
    - qkv_bias: Whether to include bias in the qkv projection layer
    - drop_prob: Dropout rate for the MLP and for the attention output projection
    - attn_drop_prob: Dropout rate applied to the attention weights (after the softmax)

    Attributes:
    - norm1, norm2: LayerNorms (eps=1e-6) over the last dimension. Each normalizes every
      token's features to zero mean / unit variance, then applies a learnable scale and shift.
    - attn: Attention module
    - mlp: MLP module (emb_size -> int(emb_size * mlp_ratio) -> emb_size)
    """

    def __init__(self, emb_size, n_heads, mlp_ratio=4., qkv_bias=True, drop_prob=0., attn_drop_prob=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_size, eps=1e-6)
        # attn_drop_prob only drops attention weights; the attention output projection uses the
        # block-wide drop_prob, following the standard ViT block (timm: attn_drop vs proj_drop=drop).
        self.attn = Attention(
            emb_size,
            n_heads,
            drop_prob_o=drop_prob,
            drop_prob_qkv=attn_drop_prob,
            qkv_bias=qkv_bias,
        )
        self.norm2 = nn.LayerNorm(emb_size, eps=1e-6)
        hidden_features = int(emb_size * mlp_ratio)
        self.mlp = MLP(emb_size, hidden_features, emb_size, drop_prob)

    def forward(self, x):
        """
        Perform the forward pass of the Block module.

        Parameters:
        x : torch.Tensor of shape (n_samples, n_patches + 1, emb_size)

        Returns:
        torch.Tensor of shape (n_samples, n_patches + 1, emb_size)
        """
        x = x + self.attn(self.norm1(x))  # residual around attention
        x = x + self.mlp(self.norm2(x))  # residual around the MLP
        return x
    
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) image classifier.

    The image is cut into patches and each patch is embedded. A learnable cls token is
    prepended and positional embeddings are added. The sequence then passes through
    `depth` transformer blocks. The final cls token is layer-normalized and classified.

    Parameters:
    - img_size: Size of the image (image has to be square)
    - patch_size: Size of the patch
    - in_channels: Number of input channels
    - n_classes: Number of output classes
    - emb_size: Size of the token/patch embedding
    - depth: Number of transformer blocks
    - n_heads: Number of attention heads
    - mlp_ratio: Multiplier for the hidden dim of the MLP wrt the input embedding
    - qkv_bias: Whether to include bias in the qkv projection layers
    - drop_prob: Dropout rate to apply
    - attn_drop_prob: Dropout rate to apply to the attention module

    Attributes:
    - patch_embed: PatchEmbed module
    - cls_token: Learnable (1, 1, emb_size) parameter prepended to every sequence; its final
      state is what gets classified
    - pos_embed: Learnable (1, 1 + n_patches, emb_size) positional embedding for the cls token
      and all the patches
    - pos_drop: Dropout applied right after adding the positional embedding
    - blocks: ModuleList of `depth` transformer blocks
    - norm: Final LayerNorm (eps=1e-6)
    - head: Linear layer mapping the final cls token to n_classes logits
    """

    def __init__(self, img_size=384, patch_size=16, in_channels=3, n_classes=1000, emb_size=768, depth=12, n_heads=12, mlp_ratio=4., qkv_bias=True, drop_prob=0., attn_drop_prob=0.):
        super().__init__()
        # Keep this construction order: it fixes the sequence of random initialization draws.
        self.patch_embed = PatchEmbed(img_size, in_channels=in_channels, patch_size=patch_size, emb_size=emb_size)
        n_tokens = 1 + self.patch_embed.n_patches
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.pos_embed = nn.Parameter(torch.randn(1, n_tokens, emb_size))
        self.pos_drop = nn.Dropout(drop_prob)
        self.blocks = nn.ModuleList(
            Block(
                emb_size,
                n_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_prob=drop_prob,
                attn_drop_prob=attn_drop_prob,
            )
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(emb_size, eps=1e-6)
        self.head = nn.Linear(emb_size, n_classes)

    def forward(self, x):
        """
        Perform the forward pass of the VisionTransformer module.

        Parameters:
        x : torch.Tensor of shape (n_samples, in_channels, img_size, img_size)

        Returns:
        torch.Tensor of shape (n_samples, n_classes): logits over all the classes
        """
        n_samples = x.shape[0]
        patches = self.patch_embed(x)  # (n_samples, n_patches, emb_size)
        cls_tokens = self.cls_token.expand(n_samples, -1, -1)  # (n_samples, 1, emb_size)
        # (n_samples, 1 + n_patches, emb_size); pos_embed broadcasts over the batch
        tokens = torch.cat([cls_tokens, patches], dim=1) + self.pos_embed
        tokens = self.pos_drop(tokens)
        for block in self.blocks:
            tokens = block(tokens)
        tokens = self.norm(tokens)
        return self.head(tokens[:, 0])  # classify from the cls token only

# Verification

In [None]:
%pip install timm

In [66]:
%pip install .

Note: you may need to restart the kernel to use updated packages.


ERROR: Directory '.' is not installable. Neither 'setup.py' nor 'pyproject.toml' found.

[notice] A new release of pip is available: 23.1.2 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [None]:
import numpy as np
import torch
import timm 
from custom import VisionTransformer

In [None]:
# helper function to count number of parameters
def count_parameters(model):
    """Return the total number of elements across all trainable (requires_grad) parameters of *model*."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

def assert_tensors_equal(t1, t2):
    """Assert that *t1* and *t2* have equal shapes and match element-wise (torch.allclose, atol=1e-4)."""
    shapes_match = t1.shape == t2.shape
    assert shapes_match, f"Shapes do not match: {t1.shape} and {t2.shape}"
    values_match = torch.allclose(t1, t2, atol=1e-4)
    assert values_match, "Tensors do not match"

model_name = "vit_base_patch16_384"
# Reference model from timm; pretrained=True fetches its pretrained weights.
model = timm.create_model(model_name, pretrained=True)
model.eval()

# Hyperparameters of the from-scratch implementation, chosen to mirror vit_base_patch16_384
# (same image/patch size, width, depth and heads) so the two models can be compared.
custom_config = dict(
    img_size=384,
    patch_size=16,
    in_channels=3,
    n_classes=1000,
    emb_size=768,
    depth=12,
    n_heads=12,
    mlp_ratio=4.0,
    qkv_bias=True,
    drop_prob=0.0,
    attn_drop_prob=0.0,
)

custom_model = VisionTransformer(**custom_config)
custom_model.eval()

In [61]:
# expand() returns a view with size-1 dims repeated to the requested size, without copying
# (-1 keeps a dim as is). VisionTransformer.forward uses the same call to repeat cls_token per sample.
torch.randn(1, 1, 5).expand(2, -1, -1)

tensor([[[-0.8924, -1.2784, -0.6579, -1.6928, -0.0461]],

        [[-0.8924, -1.2784, -0.6579, -1.6928, -0.0461]]])

In [13]:
# nn.Dropout has no learnable parameters.
module = torch.nn.Dropout(0.4)

In [53]:
# Count the trainable parameter elements of `module`.
# NOTE(review): for the Dropout above this is 0. The recorded output (6) appears to come from
# running this after In [48], which rebinds `module` to LayerNorm(3) (weight 3 + bias 3) — confirm.
sum(p.numel() for p in module.parameters() if p.requires_grad)

6

In [None]:
# elementwise_affine=False means the layer has no learnable affine parameters (per-element weight/scale and bias)
# only the trailing dimension(s) given by normalized_shape are normalized (here: just the last one)

In [48]:
# Toy batch: 2 samples x 3 features.
inp = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
n_samples, n_features = inp.shape
# Normalize over the last dim (3 features) with a learnable per-feature weight and bias.
module = torch.nn.LayerNorm(n_features, elementwise_affine=True)

In [58]:
# Per-row mean of the output. Each normalized row has (near-)zero mean, so with a uniform weight
# the row mean equals the (uniform) bias value.
# NOTE(review): the recorded 2. implies bias == 2 when this ran, i.e. the In [57] increment below
# had already been applied, apparently more than once — confirm.
module(inp).mean(-1)

tensor([2., 2.], grad_fn=<MeanBackward1>)

In [57]:
# Shift every bias entry by +1 and every weight entry by +10, in place via .data (not tracked by autograd).
module.bias.data += 1
module.weight.data += 10

In [40]:
# 4-D dummy tensor of shape (3, 2, 4, 5).
a = torch.ones(3,2,4,5)

In [None]:
a

In [None]:
# flatten(2) merges dims 2 onward: (3, 2, 4, 5) -> (3, 2, 20), as done in PatchEmbed.forward.
a.flatten(2)

In [20]:
# NOTE(review): the recorded output is stale. In [20] ran before In [40] (this `a`) and In [48]
# (the LayerNorm). At that point `module` was presumably still the Dropout from In [13], already
# in eval mode via In [18], i.e. an identity. Run top to bottom, `module` is LayerNorm(3) here
# while `a` has last dim 5, so this call would raise.
module(a)

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

In [18]:
# eval() puts the module in inference mode and returns the module itself, hence the repr below.
# NOTE(review): that repr is from In [18], when `module` was still the Dropout from In [13].
module.eval()

Dropout(p=0.4, inplace=False)

In [19]:
# `training` is False after eval(); train() sets it back to True.
module.training

False