In [1]:
import torch
import torch.nn as nn

# We will start with the most important part first

class PatchEmbedding(nn.Module):
    ''' Split Image into patches and then embed them
    
    Parameters
    ----------
    img_size : int
        size of the image (square)

    patch_size : int
        size of the patch (square)

    in_chans : int
        Number of input channels

    embed_dim : int
        The embedding dimension


    Attributes
    ----------
    n_patches : int
        Number of patches inside our image

    proj : nn.Conv2d
        Conv layer that does both the splitting into patches and their embedding
    '''
    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=768):

        super().__init__()
        self.img_size = img_size 
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        self.proj = nn.Conv2d(
                in_chans,
                embed_dim,
                kernel_size=patch_size,
                stride=patch_size
        )

    def forward(self, x):
        ''' Run forward pass
        
        Parameters
        ----------
        x : torch.tensor
            Shape '(n_samples, in_chans, img_size, img_size)'

        Returns
        -------
        torch.tensor
            Shape '(n_samples, n_patches, embed_dim)'  Returns a shape 3 tensor
        '''

        x = self.proj(
                x
        )  # (n_samples, embed_dim, n_patches ** 0.5, n_patches ** 0.5)

        x = x.flatten(2) # (n_samples, embed_dim, n_patches)

        x = x.transpose(1, 2) # (n_samples, n_patches, embed_dim)

        return x

In [None]:
class Attention(nn.Module):
    ''' Attention mechanism
    
    Parameters
    ----------
    dim : int
        The input and output dimension of per token features

    n_heads : int
        Number of attention heads

    qkv_bias : bool
        If True, then we include bias to the query, key and value tensors

    attn_p : float
        Dropout probability applied to the query, key and value tensors

    proj_p : float
        Dropout probability applied to the output tensor

    Attributes
    ----------
    scale : float
        Normalising constant for the dot product
    
    qkv : nn.Linear
        Linear projection for the query, key and value

    proj : nn.Linear
        Linear mapping that takes in the concatenated output of all attention heads and maps it into a new space
    
    attn_drop, proj_drop : nn.Dropout
        Dropout layers
    '''

    def __init__(self, dim, n_heads = 12, qkv_bias=True, attn_p = 0, proj_p = 0):
        super().__init__()

        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim*3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        ''' Run forward pass
        
        Parameters
        ----------
        x : torch.tensor
            Shape '(n_samples, n_patches + 1, dim)'     # dim = embedding dimensions that we receive from the above PatchEmbedding Class

        Returns
        -------
        torch.tensor
            Shape '(n_samples, n_patches + 1, dim)

        NOTE - n_patches + 1, the extra +1 here comes from the class token that is being added every time
        '''                
        
        n_samples, n_tokens, dim = x.shape    # n_samples = n_samples, n_tokens = n_patches + 1, dim = embedding_dimension (768)

        if dim != self.dim:
            raise ValueError 

        qkv = self.qkv(x)   # (n_samples, n_patches + 1, 3*dim)     
        # See Example how a linear layer works for 3D input

        qkv = qkv.reshape(
            n_samples, n_tokens, 3, self.n_heads, self.head_dim
        )  # (n_samples, n_patches + 1, 3, n_heads, head_dim)

        '''
         EXPLANATION 
        ----------
        qkv = torch.randn([4, 257, 3*768])
        print(qkv.shape)

        x = qkv.reshape(4, 257, 3, 12, 64)
        print(x.shape)
        '''   
 
        qkv = qkv.permute(
            2, 0, 3, 1, 4
        ) # (3, n_samples, n_heads, n_patches + 1, head_dim)

        q, k, v = qkv[0], qkv[1], qkv[2]

        '''
        EXPLANATION
        -----------
        qkv = qkv.permute(
            2, 0, 3, 1, 4
        ) 
        print(qkv.shape) --- > torch.Size([3, 4, 12, 257, 64])

        q, k, v = qkv[0], qkv[1], qkv[2] 
        q.shape, k.shape, v.shape --- > torch.Size([4, 12, 257, 64]) - ALL 3
        '''

        key_transpose = k.transpose(-2, -1) # (n_samples, n_heads, head_dim, n_patches + 1) ---> torch.Size([4, 12, 64, 257])

        dot_product = (
            q @ key_transpose
        ) * self.scale   # (n_samples, n_heads, n_patches + 1, n_patches + 1)

        '''
        q = torch.Size([4, 12, 257, 64]) 
        k.transpose = torch.Size([4, 12, 64, 257])
        scale = (768/12) ** -0.5

        dot_product ---> torch.Size([4, 12, 257, 257])
        
        '''
        # attention = Softmax(Q.K_transpose)/ sqrt(dim of key vector)

        attention = dot_product.softmax(dim = -1) 
        attention = self.attn_drop(attention)

        weighted_avg = attention @ v  # (n_samples, n_headsm n_patches+1, head_dim)

        weighted_avg = weighted_avg.transpose(1, 2) # (n_samples, n_patches +1, n_heads, head_dim)      

        x = self.proj(weighted_avg)
        x = self.proj_drop(x)   

        return x 

In [None]:
class MLP(nn.Module):
    """
    Multilayer Perceptron

    Parameters
    ----------
    in_features : int
    hidden_features : int
    out_features : int
    p : float

    Attribute
    ---------
    fc : nn.Linear
    act : nn.GELU
    fc2 : nn.Linear
    drop : nn.Dropout
    """

    def __init__(self, in_features, hidden_features, out_features, p = 0):
        super().__init__()

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        """
        Run forward pass

        Parameters
        ----------
        x : torch.tensor
            Shape '(n_samples, n_patches + 1, in_features)

        Returns
        -------
        torch.tensor
            Shape '(n_samples, n_patches + 1, out_features)'  Returns a shape 3 tensor
        """

        x = self.fc1(x) # (n_samples, n_patches + 1, hidden_features)
        x = self.act(x) 
        x = self.drop(x)

        x = self.fc2(x) # (n_samples, n_patches + 1, out_features)
        x = self.drop(x)
        
        return x

In [15]:
class Block(nn.Module):
    '''
    Transformer Block

    Parameters
    ----------
    dim : int
        Embedding dimension

    n_heads : int
        Number of attention heads

    mlp_ratio : float
        Determines the hidden dimension size of the 'MLP' module with respect to 'dim'

    qkv_bias : bool
        If true then we include bias to the q,k,v projections

    p, attn_p : float

    Attributes
    ----------
    norm1, norm2 : LayerNorm

    attn : Attention Module

    mlp : MLP module
    '''

    def __init__(self, dim, n_heads, mlp_ratio=4.0, qkv_bias=True, p=0., attn_p=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps = 1e-6)

        self.attn = Attention(
            dim,
            n_heads=n_heads,
            qkv_bias=qkv_bias,
            attn_p=attn_p,
            proj_p=p
        )

        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        hidden_features = int(dim * mlp_ratio)

        self.mlp = MLP(
            in_features=dim,
            hidden_features=hidden_features,
            out_features=dim
        )

    def forward(self, x):
        """
        Run forward pass

        Parameters
        ----------
        x : torch.tensor
            Shape '(n_samples, n_patches + 1, dim)

        Returns
        -------
        torch.tensor
            Shape '(n_samples, n_patches + 1, dim)'  Returns a shape 3 tensor
        """

        x = x + self.attn(self.norm1(x))

        x = x + self.mlp(self.norm2(x))

        return x


In [16]:
class VisionTransformer(nn.Module):
    """
    Simplified implementation of Vision Transformer

    Parameters
    ----------
    img_size : int
    patch_size : int
    in_chans : int
    n_classes : int

    embed_dim : int
        Dimensionality of the token/patch embeddings

    depth : int
        Number of Blocks

    n_heads : int
        Number of attention heads

    mlp_ratio : float
        Determines the hidden dimension of the 'MLP' module

    qkv_bias : bool

    p, attn_p : float
        Dropout Probability

    Attributes
    ----------
    patch_embed : PatchEmbedding
        Instance of 'PatchEmbedding' layer

    cls_token : nn.Parameter
        Learnable parameter that will represent the first token in the sequence
        It has embed_dim elements

    pos_embed : nn.Parameter
        positional embedding of the cls token + all the patches
        It has '(n_patches + 1) * embed_dim' elements

    pos_drop = nn.Dropout
        Dropout Layer

    blocks : nn.ModuleList
        List of 'Block' modules

    norm : nn.LayerNorm
        Layer Normalisation

    """

    def __init__(self,
                img_size=384,
                patch_size=16,
                in_chans=3,
                n_classes=1000,
                embed_dim=768,
                depth=12,
                n_heads=12,
                mlp_ratio=4,
                qkv_bias=True,
                p=0.,
                attn_p=0.,
    ):

        super().__init__()

        self.patch_embed = PatchEmbedding(
                    img_size=img_size,
                    patch_size=patch_size,
                    in_chans=in_chans,
                    embed_dim=embed_dim
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        self.pos_embed = nn.Parameter(
                torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim)
        )

        self.pos_drop = nn.Dropout(p=p)

        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    n_heads=n_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    p=p,
                    attn_p=attn_p
                    )

                for _ in range(depth)
            ]
        )

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        ''' Run the forward pass
        
        Parameters
        ----------
        x : torch.tensor
            Shape '(n_samples, in_chans, img_size, img_size)'

        Returns
        -------
        logits : torch.tensor
        '''

        n_samples = x.shape[0]
        x = self.patch_embed(x)

        cls_token = self.cls_token.expand(n_samples, -1, -1) # (n_samples, 1, embed_dim)

        x = torch.cat((cls_token, x), dim=1) # (n_samples, 1 + n_patches, embed_dim)

        x = x + self.pos_embed # (n_samples, 1 + n_patches, embed_dim)
        x = self.pos_drop(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        cls_token_final = x[:, 0] # Just the CLS token

        x = self.head(cls_token_final)

        return x

In [17]:
def get_n_params(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

In [8]:
import timm

model_name = "vit_base_patch16_384"
model_official = timm.create_model(model_name, pretrained=True)
model_official.eval()
print(type(model_official))

<class 'timm.models.vision_transformer.VisionTransformer'>


In [None]:
custom_config = {
    "img_size":384,
    "in_chans" : 3,
    "patch_size" : 16,
    "embed_dim" : 768, 
    "depth" : 12,
    "n_heads" : 12,
    "qkv_bias" : True,
    "mlp_ratio" : 4
}

model_custom = VisionTransformer(**custom_config)
model_custom.eval()

In [31]:
import numpy as np

def assert_tensors_equal(t1, t2):
    a1, a2 = t1.detach().cpu().numpy(), t2.detach().cpu().numpy()

    np.testing.assert_allclose(a1, a2)

for (n_o, p_o) , (n_c, p_c) in zip(
    model_official.named_parameters(), model_custom.named_parameters()
):
    assert p_o.numel() == p_c.numel()
    print(f"{n_o} | {n_c}")

    if (n_o != n_c):
        print("Variable names are diff for - ", n_o , n_c)

    p_c.data[:] = p_o.data

    assert_tensors_equal(p_c.data, p_o.data)

inp = torch.rand(1, 3, 384, 384)
res_c = model_custom(inp.cuda())
res_o = model_official(inp.cuda())

cls_token | cls_token
pos_embed | pos_embed
patch_embed.proj.weight | patch_embed.proj.weight
patch_embed.proj.bias | patch_embed.proj.bias
blocks.0.norm1.weight | blocks.0.norm1.weight
blocks.0.norm1.bias | blocks.0.norm1.bias
blocks.0.attn.qkv.weight | blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias | blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight | blocks.0.attn.proj.weight
blocks.0.attn.proj.bias | blocks.0.attn.proj.bias
blocks.0.norm2.weight | blocks.0.norm2.weight
blocks.0.norm2.bias | blocks.0.norm2.bias
blocks.0.mlp.fc1.weight | blocks.0.mlp.fc1.weight
blocks.0.mlp.fc1.bias | blocks.0.mlp.fc1.bias
blocks.0.mlp.fc2.weight | blocks.0.mlp.fc2.weight
blocks.0.mlp.fc2.bias | blocks.0.mlp.fc2.bias
blocks.1.norm1.weight | blocks.1.norm1.weight
blocks.1.norm1.bias | blocks.1.norm1.bias
blocks.1.attn.qkv.weight | blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias | blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight | blocks.1.attn.proj.weight
blocks.1.attn.proj.bias | blocks.1.attn.proj.b

RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`

In [None]:
from torchsummary import summary

summary(model_official.cuda(), (3, 384, 384))

In [3]:
import torch
qkv = torch.randn([4, 257, 3*768])
print(qkv.shape)

qkv = qkv.reshape(4, 257, 3, 12, 64)

qkv = qkv.permute(
            2, 0, 3, 1, 4
        ) 
print(qkv.shape)
q, k, v = qkv[0], qkv[1], qkv[2] 
q.shape, k.shape, v.shape

torch.Size([4, 257, 2304])
torch.Size([3, 4, 12, 257, 64])


(torch.Size([4, 12, 257, 64]),
 torch.Size([4, 12, 257, 64]),
 torch.Size([4, 12, 257, 64]))

In [4]:
key_transpose = k.transpose(-2, -1)
key_transpose.shape

torch.Size([4, 12, 64, 257])

In [10]:
scale = 64 ** -0.5
dot_product = (
            q @ key_transpose
        ) * scale 

print(q.shape, key_transpose.shape)
dot_product.shape

torch.Size([4, 12, 257, 64]) torch.Size([4, 12, 64, 257])


torch.Size([4, 12, 257, 257])

In [18]:
x = torch.ones([3, 3])
y = torch.ones([5, 3])

z = x @ y.transpose(-1, 0)
print(z.shape)
z

torch.Size([3, 5])


tensor([[3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3.],
        [3., 3., 3., 3., 3.]])

In [15]:
y = y.transpose(-1, 0)
y.shape

torch.Size([3, 5])