In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def patchify(imgs, side_length):
  """
  Cut a batch of images into square, non-overlapping patches.

  imgs: tensor of shape (B, C, H, W)
  side_length: edge length of each square patch; must divide both H and W

  Returns a tensor of shape (B, P, C, side_length, side_length) with
  P = (H // side_length) * (W // side_length). Patches are ordered row by
  row: left to right within a row of patches, then top to bottom.
  """
  _, _, height, width = imgs.shape
  assert height % side_length == 0
  assert width % side_length == 0

  # Top-left corner offsets of every patch along each spatial axis.
  row_starts = range(0, height, side_length)
  col_starts = range(0, width, side_length)

  # Each slice has shape (B, C, side_length, side_length).
  tiles = [
      imgs[:, :, top:top + side_length, left:left + side_length]
      for top in row_starts
      for left in col_starts
  ]

  # Stack along a new patch axis right after the batch axis.
  return torch.stack(tiles, dim=1)



In [None]:
class SingleHeadAttention(nn.Module):
  """
  One head of unmasked scaled dot-product self-attention.

  The input is projected (without bias) to queries, keys and values of width
  head_size; every position then receives the softmax-weighted mix of all
  value vectors in the same sequence.
  """

  def __init__(self, n_embd, head_size):
    super().__init__()
    self.n_embd = n_embd
    self.head_size = head_size
    self.q = nn.Linear(n_embd, head_size, bias=False)
    self.k = nn.Linear(n_embd, head_size, bias=False)
    self.v = nn.Linear(n_embd, head_size, bias=False)

  def forward(self, x):
    """
    x: B, P, n_embd
    returns: B, P, head_size
    """
    _, _, width = x.shape
    assert width == self.n_embd

    queries = self.q(x)  # B, P, head_size
    keys = self.k(x)     # B, P, head_size
    values = self.v(x)   # B, P, head_size

    # Scale the logits by 1/sqrt(head_size) before the softmax.
    scale = self.head_size ** (-0.5)
    scores = torch.matmul(queries, keys.transpose(1, 2)) * scale  # B, P, P
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1 over the keys
    return torch.matmul(weights, values)  # B, P, head_size







In [None]:
class MultiHeadAttention(nn.Module):
  """
  Several SingleHeadAttention heads run on the same input.

  The head outputs are concatenated along the feature axis and mapped back to
  n_embd by a final linear projection.
  """

  def __init__(self, n_embd, no_heads, head_size):
    super().__init__()
    head_modules = []
    for _ in range(no_heads):
      head_modules.append(SingleHeadAttention(n_embd, head_size))
    self.heads = nn.ModuleList(head_modules)
    self.proj = nn.Linear(no_heads * head_size, n_embd)

  def forward(self, x):
    """
    x: B, P, n_embd
    returns: B, P, n_embd
    """
    per_head = [head(x) for head in self.heads]  # no_heads x (B, P, head_size)
    merged = torch.cat(per_head, dim=-1)  # B, P, no_heads * head_size
    return self.proj(merged)

In [None]:
class AttentionBlock(nn.Module):
  """
  Pre-norm transformer encoder block.

  Applies LayerNorm -> multi-head self-attention -> residual add, followed by
  LayerNorm -> two-layer GELU MLP -> residual add.

  n_embd: embedding width; must be divisible by no_heads
  no_heads: number of attention heads (each of width n_embd // no_heads)
  mlp_ratio: positive integer; the hidden layer of the MLP has
             mlp_ratio * n_embd units (default 4)
  """

  def __init__(self, n_embd, no_heads, mlp_ratio=4):
    super().__init__()
    assert n_embd % no_heads == 0
    head_size = n_embd // no_heads
    hidden_dim = mlp_ratio * n_embd

    self.msa = MultiHeadAttention(n_embd, no_heads, head_size)
    self.ff = nn.Sequential(
        nn.Linear(n_embd, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, n_embd)
    )

    self.ln_1 = nn.LayerNorm(n_embd)
    self.ln_2 = nn.LayerNorm(n_embd)

  def forward(self, x):
    """
    x: B, P, n_embd
    returns: B, P, n_embd
    """
    # Normalisation is applied before each sub-layer; the residual path
    # carries the un-normalised activations.
    z_dash = self.msa(self.ln_1(x)) + x
    z = self.ff(self.ln_2(z_dash)) + z_dash

    return z


In [None]:
class ViT(nn.Module):
  """
  Minimal Vision Transformer classifier operating on pre-flattened patches.

  patch_size: number of values in one flattened patch (input feature width)
  no_patches: size of the learned positional-embedding table, i.e. the
              maximum number of patches per image the model accepts
  n_embd: embedding width used throughout the model
  no_blocks: number of stacked AttentionBlock layers
  no_heads: attention heads per block
  no_classes: number of output logits
  """

  def __init__(self, patch_size, no_patches, n_embd, no_blocks, no_heads, no_classes):
    super().__init__()
    self.pos_embd_tabl = nn.Embedding(no_patches, n_embd)
    self.patch_proj = nn.Linear(patch_size, n_embd, bias=False)
    # Learnable classification token that is prepended to every sequence.
    self.zero_embd = nn.Parameter(torch.rand(n_embd))

    self.blocks = nn.Sequential(*[AttentionBlock(n_embd, no_heads) for _ in range(no_blocks)])
    self.classifier = nn.Linear(n_embd, no_classes)

  def forward(self, x):
    """
    x: B, P, patch_size -- the images must already be split into patches and
       each patch flattened before calling the model.
    returns: B, no_classes (unnormalised logits)

    Raises IndexError if P exceeds the no_patches the model was built with.
    """
    B, P, _ = x.shape

    # Check up front: an out-of-range lookup in the embedding table gives an
    # opaque error on CPU and a device-side assert on CUDA.
    max_patches = self.pos_embd_tabl.num_embeddings
    if P > max_patches:
      raise IndexError(
          f"got {P} patches per image but the model was built with no_patches={max_patches}"
      )

    pos_embd = self.pos_embd_tabl(torch.arange(P, device=x.device)) # P, n_embd
    input_proj = self.patch_proj(x) # B, P, n_embd
    x = input_proj + pos_embd # positions broadcast over the batch axis
    # The class token is inserted at index 0 after the positional embeddings
    # were added, so it carries no positional term.
    zero_embd_tiled = self.zero_embd[None, None, :].expand(B, 1, -1)
    x = torch.cat([zero_embd_tiled, x], dim=1) # B, (P+1), n_embd
    x = self.blocks(x) # B, (P+1), n_embd
    x_0 = x[:, 0, :] # B, n_embd -- only the class token feeds the classifier

    x = self.classifier(x_0)

    return x

In [None]:
import numpy as np
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm, trange

# Use one seed for every random number generator so runs are repeatable.
_SEED = 0
np.random.seed(_SEED)
torch.manual_seed(_SEED)

<torch._C.Generator at 0x7cb29ee9ccf0>

In [None]:
train_set = MNIST("./datasets", train=True, download=True, transform=ToTensor())
test_set = MNIST("./datasets", train=False, download=True, transform=ToTensor())

B = 128
train_loader = DataLoader(train_set, shuffle=True, batch_size=B)
test_loader = DataLoader(test_set, shuffle=False, batch_size=B)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")

patch_side_length = 4

# Derive the model's input sizes from one sample instead of hard-coding them,
# so they stay consistent with patch_side_length and the dataset.
n_channels, img_h, img_w = train_set[0][0].shape
no_patches = (img_h // patch_side_length) * (img_w // patch_side_length)
patch_dim = n_channels * patch_side_length**2


def to_patch_tokens(imgs):
  """Turn a (B, C, H, W) image batch into flattened patches of shape (B, P, patch_dim)."""
  return patchify(imgs, patch_side_length).flatten(2, -1)


model = ViT(patch_size=patch_dim, no_patches=no_patches, no_blocks=2, n_embd=8, no_heads=2, no_classes=10).to(device)

N_EPOCHS = 5
LR = 0.005

optimizer = Adam(model.parameters(), lr=LR)
criterion = CrossEntropyLoss()

for epoch in trange(N_EPOCHS, desc="Training"):
  train_loss = 0.0
  for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
    imgs, labels = batch
    imgs, labels = imgs.to(device), labels.to(device)

    preds = model(to_patch_tokens(imgs))
    loss = criterion(preds, labels)

    # Running mean of the per-batch losses over the epoch.
    train_loss += loss.item() / len(train_loader)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")


# Test loop
model.eval()
with torch.no_grad():
  correct, total = 0,0
  test_loss = 0.0
  for batch in tqdm(test_loader, desc="Testing"):
    x,y = batch
    x, y = x.to(device), y.to(device)
    # The model expects (B, P, patch_dim), so evaluation batches go through
    # the same patchify + flatten step as the training batches.
    y_hat = model(to_patch_tokens(x))
    loss = criterion(y_hat, y)
    test_loss += loss.item() / len(test_loader)
    correct += torch.sum(torch.argmax(y_hat, dim=1) == y).item()

    total += len(x)
  print(f"Test loss: {test_loss:.2f}")
  print(f"Test accuracy: {correct / total * 100:.2f}%")





Using device cuda


Training:   0%|          | 0/5 [00:00<?, ?it/s]
Epoch 1 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 1 in training:   0%|          | 2/469 [00:00<00:25, 18.38it/s][A
Epoch 1 in training:   1%|▏         | 7/469 [00:00<00:13, 33.54it/s][A
Epoch 1 in training:   3%|▎         | 12/469 [00:00<00:11, 39.78it/s][A
Epoch 1 in training:   4%|▎         | 17/469 [00:00<00:10, 43.29it/s][A
Epoch 1 in training:   5%|▍         | 22/469 [00:00<00:10, 41.20it/s][A
Epoch 1 in training:   6%|▌         | 27/469 [00:00<00:11, 39.13it/s][A
Epoch 1 in training:   7%|▋         | 32/469 [00:00<00:10, 40.33it/s][A
Epoch 1 in training:   8%|▊         | 37/469 [00:00<00:10, 42.27it/s][A
Epoch 1 in training:   9%|▉         | 42/469 [00:01<00:09, 43.49it/s][A
Epoch 1 in training:  10%|█         | 47/469 [00:01<00:09, 44.94it/s][A
Epoch 1 in training:  11%|█         | 52/469 [00:01<00:09, 44.13it/s][A
Epoch 1 in training:  12%|█▏        | 57/469 [00:01<00:09, 44.74it/s][A
Epoch 1 in tra

Epoch 1/5 loss: 1.20



Epoch 2 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 2 in training:   1%|          | 5/469 [00:00<00:09, 47.29it/s][A
Epoch 2 in training:   2%|▏         | 10/469 [00:00<00:10, 45.38it/s][A
Epoch 2 in training:   3%|▎         | 15/469 [00:00<00:11, 38.52it/s][A
Epoch 2 in training:   4%|▍         | 20/469 [00:00<00:10, 41.88it/s][A
Epoch 2 in training:   5%|▌         | 25/469 [00:00<00:10, 43.88it/s][A
Epoch 2 in training:   6%|▋         | 30/469 [00:00<00:09, 45.55it/s][A
Epoch 2 in training:   7%|▋         | 35/469 [00:00<00:09, 46.07it/s][A
Epoch 2 in training:   9%|▊         | 40/469 [00:00<00:09, 47.04it/s][A
Epoch 2 in training:  10%|▉         | 45/469 [00:00<00:09, 47.07it/s][A
Epoch 2 in training:  11%|█         | 50/469 [00:01<00:08, 46.73it/s][A
Epoch 2 in training:  12%|█▏        | 55/469 [00:01<00:08, 46.41it/s][A
Epoch 2 in training:  13%|█▎        | 60/469 [00:01<00:09, 42.63it/s][A
Epoch 2 in training:  14%|█▍        | 65/469 [00:01<00:09, 4

Epoch 2/5 loss: 0.51



Epoch 3 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 3 in training:   1%|          | 4/469 [00:00<00:12, 37.24it/s][A
Epoch 3 in training:   2%|▏         | 8/469 [00:00<00:12, 36.49it/s][A
Epoch 3 in training:   3%|▎         | 12/469 [00:00<00:13, 33.03it/s][A
Epoch 3 in training:   3%|▎         | 16/469 [00:00<00:13, 32.52it/s][A
Epoch 3 in training:   4%|▍         | 20/469 [00:00<00:14, 31.45it/s][A
Epoch 3 in training:   5%|▌         | 24/469 [00:00<00:14, 29.85it/s][A
Epoch 3 in training:   6%|▌         | 28/469 [00:00<00:15, 28.93it/s][A
Epoch 3 in training:   7%|▋         | 31/469 [00:01<00:15, 28.34it/s][A
Epoch 3 in training:   7%|▋         | 34/469 [00:01<00:15, 27.66it/s][A
Epoch 3 in training:   8%|▊         | 37/469 [00:01<00:15, 27.34it/s][A
Epoch 3 in training:   9%|▉         | 42/469 [00:01<00:13, 32.76it/s][A
Epoch 3 in training:  10%|█         | 47/469 [00:01<00:11, 36.79it/s][A
Epoch 3 in training:  11%|█         | 52/469 [00:01<00:10, 39

Epoch 3/5 loss: 0.38



Epoch 4 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 4 in training:   1%|          | 5/469 [00:00<00:09, 47.12it/s][A
Epoch 4 in training:   2%|▏         | 10/469 [00:00<00:09, 47.19it/s][A
Epoch 4 in training:   3%|▎         | 15/469 [00:00<00:09, 47.15it/s][A
Epoch 4 in training:   4%|▍         | 20/469 [00:00<00:09, 47.57it/s][A
Epoch 4 in training:   5%|▌         | 25/469 [00:00<00:10, 43.09it/s][A
Epoch 4 in training:   6%|▋         | 30/469 [00:00<00:11, 37.50it/s][A
Epoch 4 in training:   7%|▋         | 34/469 [00:00<00:12, 35.14it/s][A
Epoch 4 in training:   8%|▊         | 38/469 [00:01<00:13, 32.98it/s][A
Epoch 4 in training:   9%|▉         | 42/469 [00:01<00:12, 33.36it/s][A
Epoch 4 in training:  10%|▉         | 46/469 [00:01<00:12, 33.72it/s][A
Epoch 4 in training:  11%|█         | 50/469 [00:01<00:12, 33.49it/s][A
Epoch 4 in training:  12%|█▏        | 54/469 [00:01<00:12, 33.56it/s][A
Epoch 4 in training:  12%|█▏        | 58/469 [00:01<00:12, 3

Epoch 4/5 loss: 0.33



Epoch 5 in training:   0%|          | 0/469 [00:00<?, ?it/s][A
Epoch 5 in training:   1%|          | 5/469 [00:00<00:09, 46.45it/s][A
Epoch 5 in training:   2%|▏         | 10/469 [00:00<00:10, 45.35it/s][A
Epoch 5 in training:   3%|▎         | 15/469 [00:00<00:09, 46.19it/s][A
Epoch 5 in training:   4%|▍         | 20/469 [00:00<00:09, 46.56it/s][A
Epoch 5 in training:   5%|▌         | 25/469 [00:00<00:09, 46.69it/s][A
Epoch 5 in training:   6%|▋         | 30/469 [00:00<00:09, 45.52it/s][A
Epoch 5 in training:   7%|▋         | 35/469 [00:00<00:10, 42.96it/s][A
Epoch 5 in training:   9%|▊         | 40/469 [00:00<00:10, 41.98it/s][A
Epoch 5 in training:  10%|▉         | 45/469 [00:01<00:09, 43.55it/s][A
Epoch 5 in training:  11%|█         | 50/469 [00:01<00:09, 45.09it/s][A
Epoch 5 in training:  12%|█▏        | 55/469 [00:01<00:09, 45.98it/s][A
Epoch 5 in training:  13%|█▎        | 60/469 [00:01<00:08, 47.03it/s][A
Epoch 5 in training:  14%|█▍        | 65/469 [00:01<00:08, 4

Epoch 5/5 loss: 0.31





Testing: 100%|██████████| 79/79 [00:01<00:00, 63.14it/s]

Test loss: 0.25
Test accuracy: 92.01%



