In [1]:
import torch
import torch.nn as nn

In [2]:
class PatchEmbed(nn.Module):
  """Split an image into non-overlapping square patches and embed each one.

  A single convolution whose kernel size equals its stride (= patch_size)
  plays the role of "linear projection of flattened patches".

  Parameters
  img_size: Height/width of the (square) input image
  patch_size: Height/width of each (square) patch; must divide img_size
  in_chans: Number of input channels
  embed_dim: Embedding dimension

  Attributes
  img_size, patch_size: The constructor arguments, kept for reference
  n_patches: Number of patches in one image, (img_size // patch_size) ** 2
  proj: Convolutional layer that performs the patch split + embedding
  """
  def __init__(self, img_size, patch_size, in_chans=3, embed_dim=768):
    super().__init__()
    assert not img_size % patch_size, "Error on patch size"
    self.img_size, self.patch_size = img_size, patch_size
    patches_per_side = img_size // patch_size
    self.n_patches = patches_per_side * patches_per_side

    # kernel == stride, so every output position sees exactly one patch
    self.proj = nn.Conv2d(
      in_chans,
      embed_dim,
      kernel_size=patch_size,
      stride=patch_size,
    )

  def forward(self, x):
    """Embed every patch of a batch of images.

    Parameters
    x: (n_samples, in_chans, img_size, img_size)

    Returns
    torch.Tensor(n_samples, n_patches, embed_dim)
    """
    grid = self.proj(x)                 # (n_samples, embed_dim, sqrt(n_patches), sqrt(n_patches))
    tokens = grid.flatten(start_dim=2)  # (n_samples, embed_dim, n_patches)
    return tokens.permute(0, 2, 1)      # (n_samples, n_patches, embed_dim)

In [3]:
class Attention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  Parameters
  dim: Input and output dimension of per-token features
  n_heads: Number of attention heads; must be positive and divide `dim`
  qkv_bias: Whether to include a bias in the qkv projection
  attn_p: Dropout probability applied to the attention weights (after softmax)
  proj_p: Dropout probability applied to the output tensor

  Attributes
  head_dim: Per-head feature size, dim // n_heads
  scale: Normalizing constant for the dot product, head_dim ** -0.5
  qkv: Linear projection producing queries, keys and values in one matmul
  proj: Linear mapping that inputs the concatenated output of all attention heads
  attn_drop, proj_drop: Dropout layers

  Raises
  ValueError: if `n_heads` is not positive or does not divide `dim`
  """
  def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0., proj_p=0.):
    super().__init__()
    # Fail fast with a clear message; otherwise the reshape in forward()
    # raises an opaque shape error.
    if n_heads <= 0 or dim % n_heads != 0:
      raise ValueError(
        f"dim ({dim}) must be divisible by a positive n_heads ({n_heads})"
      )
    self.n_heads = n_heads
    self.dim = dim # Feature size of attention: same input and output dim
    self.head_dim = dim // n_heads
    # 1/sqrt(head_dim) keeps dot products from growing with head_dim and
    # saturating the softmax.
    self.scale = self.head_dim ** -0.5

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.attn_drop = nn.Dropout(attn_p)
    self.proj = nn.Linear(dim, dim)
    self.proj_drop = nn.Dropout(proj_p)

  def forward(self, x):
    """Run forward pass.

    Parameters
    x: (n_samples, n_patches+1, dim), +1 due to cls token

    Returns
    torch.tensor(n_samples, n_patches+1, dim)
    """
    n_samples, n_tokens, dim = x.shape
    assert dim == self.dim

    qkv = self.qkv(x) # (n_samples, n_patches+1, 3*dim)
    qkv = qkv.reshape(n_samples, n_tokens, 3, self.n_heads, self.head_dim) # (n_samples, n_patches+1, 3, n_heads, head_dim)
    qkv = qkv.permute(2, 0, 3, 1, 4) # (3, n_samples, n_heads, n_patches+1, head_dim)

    q, k, v = qkv[0], qkv[1], qkv[2]
    k_t = k.transpose(-2, -1) # (n_samples, n_heads, head_dim, n_patches+1)
    dp = (q @ k_t) * self.scale # (n_samples, n_heads, n_patches+1, n_patches+1)
    attn = dp.softmax(dim=-1) # each query's weights over keys sum to 1
    attn = self.attn_drop(attn)

    weighted_avg = attn @ v # (n_samples, n_heads, n_patches+1, head_dim)
    weighted_avg = weighted_avg.transpose(1, 2) # (n_samples, n_patches+1, n_heads, head_dim)
    weighted_avg = weighted_avg.flatten(2) # (n_samples, n_patches+1, dim) -- heads concatenated
    x = self.proj(weighted_avg) # (n_samples, n_patches+1, dim)
    x = self.proj_drop(x)

    return x

In [4]:
class MLP(nn.Module):
  """Two-layer feed-forward network applied independently to every token.

  Parameters
  features: Number of input features; the output has the same size
  hidden_features: Number of hidden features
  p: Dropout probability (applied after the activation and after fc2)

  Attributes
  fc1: First linear layer (features -> hidden_features)
  act: GELU activation function
  fc2: Second linear layer (hidden_features -> features)
  drop: Dropout layer (shared by both dropout sites; it has no parameters)
  """
  def __init__(self, features, hidden_features, p=0.):
    super().__init__()
    self.fc1 = nn.Linear(features, hidden_features)
    self.act = nn.GELU()
    self.fc2 = nn.Linear(hidden_features, features)
    self.drop = nn.Dropout(p)

  def forward(self, x):
    """Run forward pass.

    Parameters
    x: (n_samples, n_patches+1, features)

    Returns
    torch.tensor(n_samples, n_patches+1, features)
    """
    x = self.fc1(x) # (n_samples, n_patches+1, hidden_features)
    x = self.act(x)
    x = self.drop(x)
    x = self.fc2(x) # (n_samples, n_patches+1, features)
    x = self.drop(x)

    return x

In [5]:
class Block(nn.Module):
  """Pre-norm transformer encoder block: attention and MLP, each with a residual.

  Parameters
  dim: Embedding dimension
  n_heads: Number of attention heads
  mlp_ratio: Hidden dimension size of the MLP module relative to dim
  qkv_bias: Whether to include bias in the qkv projection
  p: Dropout probability for the attention output projection and the MLP
  attn_p: Dropout probability for the attention weights

  Attributes
  norm1, norm2: Layer normalization
  attn: Attention module
  mlp: MLP module
  """
  def __init__(self, dim, n_heads, mlp_ratio=4.0, qkv_bias=True, p=0., attn_p=0.):
    super().__init__()
    self.norm1 = nn.LayerNorm(dim, eps=1e-6) # normalizes over the last dimension
    self.attn = Attention(dim, n_heads=n_heads, qkv_bias=qkv_bias, attn_p=attn_p, proj_p=p)
    self.norm2 = nn.LayerNorm(dim, eps=1e-6)
    hidden_features = int(dim * mlp_ratio)
    # Pass `p` through; otherwise the MLP ignores the block's dropout setting.
    self.mlp = MLP(features=dim, hidden_features=hidden_features, p=p)

  def forward(self, x):
    """Run forward pass.

    Parameters
    x: (n_samples, n_patches+1, dim)

    Returns
    torch.tensor(n_samples, n_patches+1, dim)
    """
    x = x + self.attn(self.norm1(x)) # residual around normalized attention
    x = x + self.mlp(self.norm2(x))  # residual around normalized MLP

    return x

In [6]:
class VisionTransformer(nn.Module):
  """Vision Transformer (ViT) image classifier.

  Parameters
  img_size: Height/Width of the (square) input image
  patch_size: Height/Width of each (square) patch
  in_chans: Number of input channels
  n_classes: Number of classes
  embed_dim: Dimensionality of the token/patch embeddings
  depth: Number of blocks
  n_heads: Number of attention heads
  mlp_ratio: Hidden dimension size of the MLP module relative to embed_dim
  qkv_bias: Whether to include bias in the qkv projections
  p: Dropout probability (after positional embedding and inside each Block)
  attn_p: Dropout probability for the attention weights inside each Block

  Attributes
  patch_embed: Instance of PatchEmbed layer
  cls_token: Learnable parameter prepended as the first token of the sequence
  pos_embed: Learnable positional embedding for the cls token + all patches,
    shape (1, n_patches + 1, embed_dim)
  pos_drop: Dropout layer applied after adding the positional embedding
  blocks: List of Block modules
  norm: Layer normalization applied to the final token sequence
  head: Linear classifier mapping the final cls token to n_classes logits
  """
  def __init__(self, img_size=384, patch_size=16, in_chans=3, n_classes=1000, embed_dim=768,
               depth=12, n_heads=12, mlp_ratio=4., qkv_bias=True, p=0., attn_p=0.):
    super().__init__()
    self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
    # Leading singleton dims let the token be expanded over the batch in forward()
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(torch.zeros(1, 1+self.patch_embed.n_patches, embed_dim))
    self.pos_drop = nn.Dropout(p=p)

    self.blocks = nn.ModuleList(
        [
        Block(dim=embed_dim, n_heads=n_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, p=p, attn_p=attn_p)
        for _ in range(depth)
        ]
    )

    self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
    self.head = nn.Linear(embed_dim, n_classes)

  def forward(self, x):
    """Run forward pass.

    Parameters
    x: (n_samples, in_chans, img_size, img_size)

    Returns
    torch.tensor(n_samples, n_classes) -- unnormalized class logits
    """
    n_samples = x.shape[0]
    x = self.patch_embed(x) # (n_samples, n_patches, embed_dim)

    cls_token = self.cls_token.expand(n_samples, -1, -1) # (n_samples, 1, embed_dim)
    x = torch.cat((cls_token, x), dim=1) # (n_samples, 1+n_patches, embed_dim)
    x = x + self.pos_embed # (n_samples, 1+n_patches, embed_dim), broadcast over batch
    x = self.pos_drop(x)

    for block in self.blocks:
      x = block(x)

    x = self.norm(x)
    cls_token_final = x[:, 0] # (n_samples, embed_dim) -- only the cls token feeds the head
    x = self.head(cls_token_final) # (n_samples, n_classes)

    return x

In [10]:
# Smoke test with a random input: one 384x384 RGB image should map to one
# logit per class. Uses the GPU when available, otherwise falls back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # reproducible random input and weight init
sample_input = torch.rand(1, 3, 384, 384).to(device)  # avoid shadowing builtin `input`
model = VisionTransformer().to(device).eval()  # eval(): disable dropout for inference
with torch.no_grad():  # forward-only check, no autograd graph needed
  res = model(sample_input)
assert res.shape == (1, 1000)
res.shape

torch.Size([1, 1000])