In [1]:
import torch
import torch.nn as nn   

In [2]:
class patchEmbed(nn.Module):
    '''Cut a square image into patches and project every patch to a vector.

    The projection is a single convolution whose kernel size and stride are
    both ``patch_size``, so each output location sees exactly one patch.

    parameters:
        img_size: int, side length of the (square) input image
        patch_size: int, side length of each (square) patch
        in_chans: int, number of input channels
        embed_dim: int, length of the embedding produced for each patch

    attributes:
        n_patches: int, ``(img_size // patch_size) ** 2``
        prog: nn.Conv2d, the patch projection layer
    '''
    def __init__(self, img_size: int, patch_size: int, in_chans: int = 3, embed_dim: int = 768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = patches_per_side * patches_per_side
        # kernel == stride: the convolution windows tile the image without overlap.
        self.prog = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Project the patches of ``x`` and lay them out as a token sequence.

        parameters:
            x: torch.Tensor, image batch; expected layout is
               (n_samples, in_chans, img_size, img_size)

        returns:
            torch.Tensor with the two spatial axes of the convolution output
            merged into one and moved in front of the embedding axis, i.e.
            (n_samples, n_patches, embed_dim) for the expected input layout.
        '''
        patch_grid = self.prog(x)
        # Merge the spatial axes, then swap them with the channel axis.
        return patch_grid.flatten(2).transpose(1, 2)

In [3]:
class Attention(nn.Module):
    '''Multi-head scaled dot-product self-attention.

    parameters:
        dim: int, dimension of input (and output) features per token
        num_heads: int, number of heads; must be positive and divide ``dim``
        qkv_bias: bool, whether to include bias in qkv projection
        attn_p: float, dropout probability for attention
        proj_p: float, dropout probability for projection

    raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dim``
    '''
    def __init__(self,dim:int,num_heads:int = 12,qkv_bias:bool = True,attn_p:float = 0.,proj_p:float = 0.):
        super().__init__()
        # Fail at construction time: otherwise head_dim * num_heads != dim and
        # the reshape in forward() dies with an opaque size-mismatch error.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(f'dim {dim} should be divisible by num_heads {num_heads}')
        self.n_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        # 1/sqrt(head_dim) scaling of the query-key dot products.
        self.scale = self.head_dim ** -0.5
        # A single linear layer produces queries, keys and values at once.
        self.qkv = nn.Linear(dim,dim*3,bias = qkv_bias)
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim,dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self,x:torch.Tensor)->torch.Tensor:
        '''Apply self-attention over the token axis.

        parameters:
            x: torch.Tensor of shape (n_samples, n_tokens, dim)

        returns:
            torch.Tensor of shape (n_samples, n_tokens, dim)

        raises:
            ValueError: if the last dimension of ``x`` differs from ``dim``
        '''
        n_samples,n_tokens,dim = x.shape

        if dim != self.dim:
            raise ValueError(f'Input dim {dim} should be {self.dim}')

        qkv = self.qkv(x)                                                   # (n_samples, n_tokens, 3*dim)
        qkv = qkv.reshape(n_samples,n_tokens,3,self.n_heads,self.head_dim)
        qkv = qkv.permute(2,0,3,1,4)                                        # (3, n_samples, n_heads, n_tokens, head_dim)
        q,k,v = qkv[0],qkv[1],qkv[2]
        k_t = k.transpose(-2,-1)                                            # (n_samples, n_heads, head_dim, n_tokens)
        dp = (q @ k_t) * self.scale                                         # (n_samples, n_heads, n_tokens, n_tokens)
        attn = dp.softmax(dim = -1)                                         # each row sums to 1 over the keys
        attn = self.attn_drop(attn)
        weighted_avg = attn @ v                                             # (n_samples, n_heads, n_tokens, head_dim)
        weighted_avg = weighted_avg.transpose(1,2)                          # (n_samples, n_tokens, n_heads, head_dim)
        weighted_avg = weighted_avg.flatten(2)                              # heads concatenated -> (n_samples, n_tokens, dim)
        x = self.proj(weighted_avg)
        x = self.proj_drop(x)
        return x

In [4]:
class MLP(nn.Module):
    '''Two-layer feed-forward network: linear -> GELU -> dropout -> linear -> dropout.

    parameters:
        in_features: int, size of each input vector
        hidden_features: int, width of the hidden layer
        out_features: int, size of each output vector
        p: float, dropout probability (the same dropout layer is applied twice)
    '''
    def __init__(self, in_features: int, hidden_features: int, out_features: int, p: float = 0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Run ``x`` through the network.

        parameters:
            x: torch.Tensor whose last dimension is ``in_features``

        returns:
            torch.Tensor with the last dimension replaced by ``out_features``
        '''
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

In [5]:
class Block(nn.Module):
    '''Pre-norm transformer block: self-attention and an MLP, each wrapped
    in a residual connection.

    parameters:
        dim: int, dimension of input
        num_heads: int, number of heads
        mlp_ratio: int, ratio of hidden to input dimension
        qkv_bias: bool, whether to include bias in qkv projection
        p: float, dropout probability (used for attention, projection and MLP dropout)
    '''
    def __init__(self,dim:int,num_heads:int,mlp_ratio:int = 4,qkv_bias:bool = True,p:float = 0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim,num_heads = num_heads,qkv_bias = qkv_bias,attn_p = p,proj_p = p)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features = dim,hidden_features = dim*mlp_ratio,out_features = dim,p = p)

    def forward(self,x:torch.Tensor)->torch.Tensor:
        '''Apply the attention and MLP sub-layers with residual connections.

        parameters:
            x: torch.Tensor whose last dimension is ``dim``

        returns:
            torch.Tensor with the same shape as ``x``
        '''
        # The skip path must carry the un-normalised input: only the branch
        # fed to the sub-layer is normalised. Overwriting x with norm(x)
        # before the addition would push the residual stream through
        # LayerNorm as well.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

In [6]:
class VisionTransformer(nn.Module):
    '''Vision transformer classifier.

    The image is turned into patch embeddings, a learnable class token is
    placed in front of them, learnable position embeddings are added, and
    the sequence is passed through ``depth`` transformer blocks. The output
    is computed from the class token after a final layer norm.

    parameters:
        img_size: int, side length of the (square) input image
        patch_size: int, side length of each patch
        in_chans: int, number of input channels
        num_classes: int, number of outputs of the classification head
        embed_dim: int, dimension of the token embeddings
        depth: int, number of transformer blocks
        num_heads: int, number of attention heads per block
        mlp_ratio: int, ratio of MLP hidden width to embed_dim
        qkv_bias: bool, whether to include bias in qkv projection
        p: float, dropout probability
    '''
    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int = 3,
        num_classes: int = 1,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: int = 4,
        qkv_bias: bool = True,
        p: float = 0.,
    ):
        super().__init__()
        self.patch_embed = patchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        # One position per patch plus one for the class token.
        n_positions = self.patch_embed.n_patches + 1
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, n_positions, embed_dim))
        self.pos_drop = nn.Dropout(p)
        self.blocks = nn.ModuleList()
        for _ in range(depth):
            self.blocks.append(
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    p=p,
                )
            )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Compute the head output for a batch of images.

        parameters:
            x: torch.Tensor, image batch whose first dimension is the batch size

        returns:
            torch.Tensor, the head applied to the normalised class token of
            every sample
        '''
        batch = x.shape[0]
        patches = self.patch_embed(x)
        # Share the single learnable class token across the batch.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, patches), dim=1) + self.pos_embed
        tokens = self.pos_drop(tokens)
        for layer in self.blocks:
            tokens = layer(tokens)
        # Index 0 along the token axis is the class token.
        cls_final = self.norm(tokens)[:, 0]
        return self.head(cls_final)