<a href="https://colab.research.google.com/github/torrhen/pytorch_miscellaneous/blob/main/ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
from torch import nn

In [27]:
class PatchEmbedding(nn.Module):
  '''
  Turn a batch of images into a token sequence for a Vision Transformer (section 3.1).

  Each image is cut into non-overlapping ``patch_size`` x ``patch_size`` patches.
  A Conv2d whose kernel size equals its stride (with no padding) projects every
  patch to ``embedding_dim`` features. This is equivalent to flattening each patch
  and applying one shared linear layer. A learnable class token is prepended to the
  patch sequence, and a learnable position embedding is added to every token.

  Args:
    in_channels: number of channels of the input images.
    patch_size: side length (in pixels) of each square patch.
    embedding_dim: size D of every output token.
    img_size: optional side length of the square input images. It determines the
      length of the position embedding: (img_size // patch_size) ** 2 + 1.
      If None (the default), the legacy length of 197 positions (196 patches plus
      the class token) is used, e.g. 224x224 images with 16x16 patches.

  Input:  [B, in_channels, H, W], with H and W multiples of patch_size.
  Output: [B, (H / patch_size) * (W / patch_size) + 1, embedding_dim].
  '''
  def __init__(self, in_channels=3, patch_size=16, embedding_dim=768, img_size=None):
    super().__init__()
    self.in_channels = in_channels
    self.patch_size = patch_size
    self.embedding_dim = embedding_dim
    self.img_size = img_size
    if img_size is None:
      # Legacy sizing: a 14 x 14 patch grid (e.g. 224x224 input, 16x16 patches).
      num_patches = 196
    else:
      if img_size % patch_size != 0:
        raise ValueError(
            f'img_size ({img_size}) must be divisible by patch_size ({patch_size})')
      num_patches = (img_size // patch_size) ** 2
    # Split the image into patches and project each one to the embedding dimension.
    # The spatial grid is then flattened into a sequence.
    self.embedding = nn.Sequential(
        # [B, C, H, W] -> [B, D, H / P, W / P]   (e.g. [B, 3, 224, 224] -> [B, 768, 14, 14])
        nn.Conv2d(in_channels=self.in_channels, out_channels=self.embedding_dim,
                  kernel_size=self.patch_size, stride=self.patch_size, padding=0),
        # [B, D, H / P, W / P] -> [B, D, N] with N = (H / P) * (W / P)
        nn.Flatten(start_dim=2, end_dim=3)
    )
    # A single class token, shared by every sample in the batch.
    self.class_token = nn.Parameter(torch.randn(1, 1, self.embedding_dim))
    # One position vector per token (patches + class token). It broadcasts over the batch.
    self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, self.embedding_dim))

  def forward(self, x):
    '''
    Embed a batch of images as a token sequence.

    Args:
      x: image tensor of shape [B, C, H, W].

    Returns:
      Tensor of shape [B, N + 1, embedding_dim], where N = (H / P) * (W / P) and
      index 0 along dim 1 is the class token.

    Raises:
      ValueError: if x is not 4-D, if H or W is not a multiple of patch_size, or
        if N + 1 differs from the length of the position embedding.
    '''
    if x.dim() != 4:
      raise ValueError(f'expected a 4-D input [B, C, H, W], got shape {tuple(x.shape)}')
    batch_size = x.shape[0]
    height, width = x.shape[-2:]
    # Explicit check instead of `assert` so it still runs under `python -O`.
    if height % self.patch_size != 0 or width % self.patch_size != 0:
      raise ValueError(
          f'input size {height}x{width} is not divisible into '
          f'{self.patch_size}x{self.patch_size} patches')
    num_tokens = (height // self.patch_size) * (width // self.patch_size) + 1
    if num_tokens != self.position_embedding.shape[1]:
      raise ValueError(
          f'input size {height}x{width} yields {num_tokens} tokens, but the position '
          f'embedding holds {self.position_embedding.shape[1]}')
    # Patch embedding: [B, D, N] -> [B, N, D]
    x = self.embedding(x)
    x = x.permute(0, 2, 1)
    # The class token has batch dim 1; expand it (a view, no copy) so it can be
    # concatenated with a batch of any size.
    class_tokens = self.class_token.expand(batch_size, -1, -1)  # [B, 1, D]
    x = torch.cat((class_tokens, x), dim=1)  # [B, N + 1, D]
    # Add the position embedding, broadcast over the batch dimension.
    x = x + self.position_embedding
    return x

