In [176]:

import math

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2

from models import Patchify

batch_size = 128
model_dim = 64

transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])
data = datasets.MNIST(root="data", train=False, download=True, transform=transform)
dl = DataLoader(data, batch_size=batch_size, shuffle=False)
len(data)
    


10000

In [177]:
for batch, (images, labels) in enumerate(dl):
    # batch size of 128 yields 78 batches from 10_000 test images
    print(f"Batch {batch}: images.shape = {images.shape}, labels.shape = {labels.shape}")
    break

# get first batch only for play
x, y = next(iter(dl))
x.shape

Batch 0: images.shape = torch.Size([128, 1, 28, 28]), labels.shape = torch.Size([128])


torch.Size([128, 1, 28, 28])

In [178]:

patchify = Patchify(patch_size=7, model_dim=model_dim)
patched = patchify(x)
patched.shape  # 128x1x28x28 -> 128x16x64

torch.Size([128, 16, 64])

In [179]:
# build simplest/smallest positional encoding from scratch
max_len = 16
pe = torch.zeros(max_len, model_dim)
pe.shape
# note that pe.shape here is one image as 16 patches in 64 dims, i.e. same as one item from the patch batch


torch.Size([16, 64])

In [180]:
position = torch.arange(0, max_len, dtype=torch.float)
position  # 0, 1, 2, ..., 15

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15.])

In [181]:
# turn each position int into its own row
position = position.unsqueeze(1)
position.shape

torch.Size([16, 1])

In [182]:
magic = math.log(10_000.0) / model_dim  # 10_000 is a magic number worked out by Vaswani et al.
magic

0.14391156831212787

In [183]:
# we want to alternate between adding sine and cosine values to even/odd indices
even_positions = torch.arange(0, model_dim, 2)
even_positions.shape

torch.Size([32])

In [184]:
# apply magic scalar to div tensor
div_terms = even_positions * - magic
div_terms

tensor([-0.0000, -0.2878, -0.5756, -0.8635, -1.1513, -1.4391, -1.7269, -2.0148,
        -2.3026, -2.5904, -2.8782, -3.1661, -3.4539, -3.7417, -4.0295, -4.3173,
        -4.6052, -4.8930, -5.1808, -5.4686, -5.7565, -6.0443, -6.3321, -6.6199,
        -6.9078, -7.1956, -7.4834, -7.7712, -8.0590, -8.3469, -8.6347, -8.9225])

In [None]:
# here we essentially compute the denominator for the sine/cosine functions (10000^(2i/model_dim))
div_terms = torch.exp(div_terms)
div_terms  # shape: 32

tensor([1.0000, 1.0423, 1.0864, 1.1323, 1.1802, 1.2301, 1.2821, 1.3364, 1.3929,
        1.4518, 1.5132, 1.5772, 1.6439, 1.7134, 1.7859, 1.8614, 1.9401, 2.0221,
        2.1077, 2.1968, 2.2897, 2.3865, 2.4875, 2.5927, 2.7023, 2.8166, 2.9357,
        3.0599, 3.1893, 3.3241, 3.4647, 3.6112])

In [186]:
half_len = position * div_terms  # basically applies each position int as a scalar factor to div_terms
half_len  # 16x1 * 32 -> 16x32

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.0000,  1.0423,  1.0864,  1.1323,  1.1802,  1.2301,  1.2821,  1.3364,
          1.3929,  1.4518,  1.5132,  1.5772,  1.6439,  1.7134,  1.7859,  1.8614,
          1.9401,  2.0221,  2.1077,  2.1968,  2.2897,  2.3865,  2.4875,  2.5927,
          2.7023,  2.8166,  2.9357,  3.0599,  3.1893,  3.3241,  3.4647,  3.6112],
        [ 2.0000,  2.0846,  2.1727,  2.2646,  2.3604,  2.4602,  2.5643,  2.6727,
          2.7857,  2.9036,  3.0264,  3.1543,  3.2877,  3.4268,  3.5717,  3.7228,
          3.8802,  4.0443,  4.2153,  4.3936,  4.5794,  4.7731,  4.9749,  5.1853,
          5.4046,  5.6332,  5.8714,  6.1197,  6.3785,  6.6483,  6.9294,  7.2225],
        [ 3.0000,  3.1269

In [187]:
# pe was formally a tensor of zeroes, now we fill it 
pe[:, 0::2] = torch.sin(position * div_tensor)
self.pe[:, 1::2] = torch.cos(position * div_term)
pe

NameError: name 'self' is not defined