In [39]:
import torch
import torch.nn as nn
from torchvision import transforms

## Patch embedding

In [135]:
class PatchEmbed(nn.Module):
    """2D image to patch embedding.

    Cuts an image into non-overlapping ``patch_size`` x ``patch_size`` patches
    and linearly projects every patch to an ``embed_dim``-dimensional vector.

    Args:
        H: Height of the input images, in pixels.
        W: Width of the input images, in pixels.
        patch_size: Side length of each square patch.
        in_chans: Number of channels in the input images.
        embed_dim: Length of the embedding vector produced for each patch.

    Attributes:
        grid_size: ``(H // patch_size, W // patch_size)`` -- number of patches
            along the height and width axes.
        num_patches: Total number of patches for an ``H`` x ``W`` image.
        conv: Strided convolution that performs the per-patch projection.
    """
    def __init__(self, H, W, patch_size=16, in_chans=3, embed_dim=100):
        super().__init__()

        # Count patches per axis. The strided conv below discards trailing
        # rows/columns that do not fill a whole patch, so flooring the total
        # area ((H * W) // patch_size ** 2) over-counts whenever H or W is not
        # a multiple of patch_size (e.g. 24x40 with p=16: area gives 3, conv gives 2).
        self.grid_size = (H // patch_size, W // patch_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # Since kernel_size == stride == patch_size, the conv kernel acts on
        # each individual patch exactly once, i.e. the convolution is a linear
        # embedding of the flattened patch.
        self.conv = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Tensor of shape ``(B, in_chans, H, W)``.

        Returns:
            Tensor of shape ``(B, N, embed_dim)`` where
            ``N = (H // patch_size) * (W // patch_size)``.
        """
        x = self.conv(x)       # (B, C, H, W) -> (B, embed_dim, H/p, W/p)
        x = x.flatten(2)       # (B, embed_dim, H/p, W/p) -> (B, embed_dim, N)
        x = x.transpose(1, 2)  # (B, embed_dim, N) -> (B, N, embed_dim)

        return x


In [141]:
from PIL import Image
import torchvision.transforms.functional as TF

# Smoke test: run one random RGB image through PatchEmbed and inspect the
# shape of the resulting token sequence.
image_tensor = torch.rand((1, 3, 224, 224))  # layout: (B, C, H, W)

# Build the embedder from the image's own spatial dimensions (224 x 224).
patch_embed = PatchEmbed(*image_tensor.shape[-2:])
embedded_patches = patch_embed(image_tensor)

print("Shape of the output from PatchEmbed:", embedded_patches.size())

Shape of the output from PatchEmbed: torch.Size([1, 196, 100])


In [144]:
# Compare how many scalars go into the patch embedding with how many come out.
H, W = 224, 224      # image height and width in pixels
c, p, d = 3, 16, 100 # channels, patch side length, embedding dimension

# Raw image: one value per pixel per channel.
input_size = c * H * W

# Embedded image: d values per patch. Flooring the area by the patch area
# matches the conv output here because H and W are exact multiples of p.
output_size = d * ((H * W) // (p * p))

print(f"{input_size} {output_size}")

150528 19600
