In [4]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` applies the
    same learned linear projection to every non-overlapping
    ``patch_size x patch_size`` patch, yielding one ``emb_size``-dim token per patch.

    Args:
        in_channels: Number of input image channels (C).
        patch_size: Side length P of each square patch.
        emb_size: Embedding dimension D of each output token.
        img_size: Expected square input side length; used to precompute
            ``num_patches`` = (img_size // patch_size) ** 2.
        verbose: If True (the default, matching the original notebook output),
            print parameter shapes at construction and intermediate tensor
            shapes on every forward pass.
    """

    def __init__(self, in_channels=3, patch_size=16, emb_size=768, img_size=224, verbose=True):
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size
        # img_size used to be accepted but ignored; expose the patch count so
        # later cells can size per-token parameters from it.
        self.num_patches = (img_size // patch_size) ** 2
        self.verbose = verbose
        self.projection = nn.Conv2d(
            in_channels, out_channels=emb_size, kernel_size=patch_size, stride=patch_size
        )
        if self.verbose:
            print("Weight matrix shape:", self.projection.weight.shape)
            print("Bias:", self.projection.bias.shape)

    def _log_shape(self, t):
        """Print a tensor's shape when ``verbose`` is enabled."""
        if self.verbose:
            print("X:", t.shape)

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image tensor of shape [B, C, H, W]; H and W must be multiples
                of ``patch_size``.

        Returns:
            Tensor of shape [B, N, D] with N = (H // P) * (W // P).

        Raises:
            ValueError: if ``x`` is not 4-D, or if H or W is not divisible by
                ``patch_size`` (the strided conv would otherwise silently drop
                the border pixels).
        """
        if x.dim() != 4:
            raise ValueError(f"Expected a 4-D [B, C, H, W] tensor, got shape {tuple(x.shape)}")
        height, width = x.shape[-2:]
        if height % self.patch_size or width % self.patch_size:
            raise ValueError(
                f"Input size ({height}x{width}) must be divisible by "
                f"patch_size={self.patch_size}"
            )
        x = self.projection(x)  # shape: [B, D, H/P, W/P]
        self._log_shape(x)
        x = x.flatten(2)        # shape: [B, D, N]
        self._log_shape(x)
        x = x.transpose(1, 2)   # shape: [B, N, D]
        self._log_shape(x)
        return x


In [11]:
# Patch-embedding hyperparameters: 3-channel 224x224 input, 16x16 patches, 768-dim tokens.
in_channels = 3
patch_size = 16
emb_size = 768
img_size = 224

torch.manual_seed(0)  # reproducible weights and example input on Restart & Run All
patches = PatchEmbedding(in_channels, patch_size, emb_size, img_size)
# Example input tensor with shape [B, C, H, W], derived from the config above
# so changing in_channels / img_size keeps the demo consistent.
input_tensor = torch.randn(1, in_channels, img_size, img_size)  # Batch size of 1
x = patches(input_tensor)
n_patches = (img_size // patch_size) ** 2  # 14 * 14 = 196 for the values above
assert x.shape == (1, n_patches, emb_size), x.shape
print("Output shape:", x.shape)  # Should be [1, 196, 768] for a 224x224 image with 14x14 patches


Weight matrix shape: torch.Size([768, 3, 16, 16])
Bias: torch.Size([768])
X: torch.Size([1, 768, 14, 14])
X: torch.Size([1, 768, 196])
X: torch.Size([1, 196, 768])
Output shape: torch.Size([1, 196, 768])


In [12]:
# LayerNorm over the last (embedding) axis: each token's emb_size-dim vector is
# normalized independently; the result feeds the attention cell below.
layer_norm = nn.LayerNorm(normalized_shape=emb_size)
x_norm = layer_norm(x)  # [B, N, D] -- shape is unchanged by normalization

print("LayerNorm output shape:", x_norm.shape)  # Should be [1, 196, 768]


LayerNorm output shape: torch.Size([1, 196, 768])


In [None]:
# Multi-head self-attention sub-block (pre-norm): x = x + MSA(LayerNorm(x)).
# NOTE(review): this cell rebinds `x`, so re-running it without re-running the
# cells above adds a second attention residual with freshly initialised weights.
num_heads = 8  # emb_size must be divisible by num_heads (e.g. 768 / 8 = 96 per head)
attn = nn.MultiheadAttention(embed_dim=emb_size, num_heads=num_heads, batch_first=True)
# need_weights=False: the attention-weight matrix is discarded below, so skip
# computing and returning it.
x_attn, _ = attn(x_norm, x_norm, x_norm, need_weights=False)
x = x + x_attn  # Residual connection # [1, 196, 768]

print(x.shape)

torch.Size([1, 196, 768])


In [None]:
# Pre-norm MLP sub-block: x = x + MLP(LayerNorm(x)).
# - Uses its own LayerNorm so its learnable scale/shift are separate from the
#   attention norm (reusing `layer_norm` would tie the two norms' parameters).
# - The residual is added to the un-normalized stream `x`; previously `x` was
#   overwritten by LayerNorm(x) first, so the residual skipped the raw stream.
mlp_ratio = 4  # hidden width = emb_size * mlp_ratio
mlp_norm = nn.LayerNorm(emb_size)

mlp = nn.Sequential(
    nn.Linear(emb_size, emb_size * mlp_ratio),
    nn.ReLU(),  # NOTE: the original ViT MLP uses nn.GELU(); ReLU kept as the author's choice
    nn.Linear(emb_size * mlp_ratio, emb_size),
)
x_mlp = mlp(mlp_norm(x))  # Feed forward network on the normalized tokens
x = x + x_mlp  # Residual connection on the pre-norm stream # [1, 196, 768]

print(x.shape)

torch.Size([1, 196, 768])


In [None]:
n_classes = 1000  # Number of classes for ImageNet
# Output head for ImageNet: maps each emb_size-dim vector to class probabilities.
# Softmax runs over the LAST axis, so it normalizes across classes for both a
# pooled [B, D] input and a per-token [B, N, D] input; the previous dim=1 would
# normalize over the token axis N for a [B, N, D] input.
# NOTE: if this head is trained with nn.CrossEntropyLoss, drop the Softmax --
# that loss expects raw logits and applies log-softmax internally.
output_layer = nn.Sequential(
    nn.Linear(emb_size, n_classes),  # from last hidden size to n_classes classes
    nn.Softmax(dim=-1),              # softmax over class dimension
)