In [2]:
import torch
import random
import timeit
import math
import numpy
import numpy as np
from torch import optim
from torch import nn
!pip install torchinfo
from torchinfo import summary
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
from dataclasses import dataclass
from torch.optim.lr_scheduler import CosineAnnealingLR
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

%matplotlib inline

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [4]:
@dataclass
class vit_config:
    """Hyperparameters for the Vision Transformer built in this notebook.

    The defaults are also readable straight off the class
    (``vit_config.embd_dim``), because every field has a class-level default.

    Derived fields:
        embd_dim: defaults to ``patch_size ** 2 * num_channels`` evaluated with
            the *default* patch size and channel count (768). It is treated as
            an independent hyperparameter and is NOT re-derived when
            ``patch_size`` or ``num_channels`` are overridden on an instance.
        num_patches: number of patches per image,
            ``(image_size // patch_size) ** 2``. When left at its default it is
            re-derived in ``__post_init__`` from the instance's own
            ``image_size`` / ``patch_size``; an explicitly supplied value that
            differs from the default is kept as given.
    """

    num_channels: int = 3
    batch_size: int = 16
    image_size: int = 224
    patch_size: int = 16
    num_heads: int = 8
    dropout: float = 0.0
    layer_norm_eps: float = 1e-6
    num_encoder_layers: int = 12
    random_seed: int = 42
    epochs: int = 30
    num_classes: int = 10
    learning_rate: float = 1e-5
    adam_weight_decay: float = 0.0
    adam_betas: tuple = (0.9, 0.999)
    # The two expressions below run once, in the class body, so they only see
    # the default values declared above.
    embd_dim: int = (patch_size ** 2) * num_channels           # 768
    num_patches: int = (image_size // patch_size) ** 2         # 196
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __post_init__(self):
        """Keep ``num_patches`` in sync with per-instance image/patch sizes."""
        # A value still equal to the class default means the caller did not
        # choose it deliberately (or chose the same number), so recompute it
        # from this instance's geometry instead of keeping the stale 196.
        if self.num_patches == type(self).num_patches:
            self.num_patches = (self.image_size // self.patch_size) ** 2


# Report the selected device once, outside the class body.
print(vit_config.device)

cpu


In [6]:
# Work with an instance rather than the class itself, so that changing a value
# on `config` never rewrites the class-wide defaults of `vit_config`.
config = vit_config()

# Seed every random number generator the notebook relies on.
random.seed(config.random_seed)
numpy.random.seed(config.random_seed)
torch.manual_seed(config.random_seed)
torch.cuda.manual_seed(config.random_seed)
torch.cuda.manual_seed_all(config.random_seed)

# Ask cuDNN for deterministic kernels and switch off its autotuner: with
# benchmark=True cuDNN may pick different algorithms from run to run, which
# defeats the seeding above.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [8]:
# A custom embedding layer that prepends a learnable [CLS] token to the patch embeddings of the input image and adds a learned position embedding to every token.

In [9]:
class VisionEmbedding(nn.Module):
    """Convert a batch of images into a sequence of token embeddings.

    Each image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, every patch is projected to ``embd_dim`` features, a learnable
    class token is prepended, and a learnable position embedding of shape
    ``(1, num_patches + 1, embd_dim)`` is added before dropout.

    Args:
        config: configuration providing ``num_channels``, ``embd_dim``,
            ``patch_size``, ``num_patches`` and ``dropout``.
    """

    def __init__(self, config: vit_config):
        super().__init__()
        self.config = config

        # Kernel size == stride == patch size, so the convolution visits each
        # patch exactly once; Flatten then merges the two spatial axes.
        patch_projection = nn.Conv2d(
            config.num_channels,
            config.embd_dim,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            padding="valid",
        )
        self.patch_embedding = nn.Sequential(patch_projection, nn.Flatten(start_dim=2))

        cls_shape = (1, 1, config.embd_dim)
        pos_shape = (1, config.num_patches + 1, config.embd_dim)  # +1 for the class token
        self.cls_token = nn.Parameter(torch.randn(cls_shape))
        self.pos_embeddings = nn.Parameter(torch.randn(pos_shape))
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` and return the token sequence.

        Args:
            x: image batch; the first axis is used as the batch size and the
                tensor is fed to the patch convolution unchanged.

        Returns:
            Tensor with the class token at sequence index 0 followed by the
            patch tokens, with position embeddings added and dropout applied.
        """
        batch = x.size(0)
        # (batch, embd_dim, patches) -> (batch, patches, embd_dim)
        patches = self.patch_embedding(x).transpose(1, 2)
        # Share the single learned class token across the whole batch.
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, patches), dim=1)
        return self.dropout(tokens + self.pos_embeddings)

In [11]:
# Smoke test: push one random image through the embedding layer and inspect
# the result. The module is moved to the same device as the input; otherwise
# the forward pass fails with a device mismatch whenever config.device is "cuda".
model = VisionEmbedding(config).to(config.device)
dummy_input = torch.randn(1, config.num_channels, config.image_size, config.image_size).to(config.device)
output = model(dummy_input)
print(f"Output shape: {output.shape}")
print(f"Output (first 5 elements): {output[0, :5, :5]}")

Output shape: torch.Size([1, 197, 768])
Output (first 5 elements): tensor([[-0.4828, -2.9331, -1.0430,  0.5191,  1.1593],
        [-0.5916,  0.7174, -0.8135, -0.0548,  0.7490],
        [-0.5815,  0.2315, -0.9726, -1.6018,  0.6856],
        [-0.2447,  0.3341,  0.1944, -0.9270,  0.7581],
        [-0.9925,  0.5160,  0.9445,  0.3006,  0.7511]],
       grad_fn=<SliceBackward0>)
