## Code

In [10]:
%pip install einops
%load_ext tensorboard

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import default_collate
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import einops
from einops.layers.torch import Rearrange
from tqdm import tqdm

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

Note: you may need to restart the kernel to use updated packages.
The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [11]:
class MLP(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.0):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.seq(x)

In [12]:
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, heads, dropout=0.0):
        super().__init__()

        self.dim = dim # size of embeddings
        self.heads = heads

        self.head_len = int(self.dim / self.heads)
        self.norm_scale = self.head_len ** -0.5 # divide by sqrt of query size during normalization

        # pytorch automatically handles the remaining dimensions by expansion. It is equivalent of the neural network receiving one word embedding as input at a time, but vectorized.
        self.q_linear = nn.Linear(self.dim, self.dim, bias = False)
        self.k_linear = nn.Linear(self.dim , self.dim, bias = False)
        self.v_linear = nn.Linear(self.dim, self.dim, bias = False)

        # dim is seq (because will be applied after q_dot_k)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # input and output is the same. Could be different if we increased the head_len as a parameter instead of being dim/num_heads.
        self.dense = nn.Sequential(
          nn.Linear(self.dim, self.dim),
          nn.Dropout(dropout)
        )


    def forward(self, key, query, value):
        # pass through linear layers
        q = self.q_linear(query)
        k = self.k_linear(key)
        v = self.v_linear(value)

        # split heads and reshape. seq, head_len is a 2d matrix that is going to be multiplied. The other two dimensions are static.
        q = einops.rearrange(q, 'b seq (head head_len) -> b head seq head_len', head = self.heads)
        v = einops.rearrange(v, 'b seq (head head_len) -> b head seq head_len', head = self.heads)
        # transpose k for matmul. Just change the last two dimensions around because we need seq,head_len * head_len,seq
        k = einops.rearrange(k, 'b seq (head head_len) -> b head head_len seq', head = self.heads)

        # matmul between q and k
        q_dot_k = torch.matmul(q,k)

        # divide result by sqrt of head len
        q_dot_k *= self.norm_scale

        # softmax of q_dot_k
        attention_scores = self.softmax(q_dot_k)

        # apply dropout
        attention_scores = self.dropout(attention_scores)

        # matmul by value to obtain final result
        result = torch.matmul(attention_scores, v)

        # concatenate all heads. We get back the entire dim size
        result_concat = einops.rearrange(result, 'b head seq head_len -> b seq (head head_len)')

        # pass through final dense layer
        out = self.dense(result_concat)

        return out


In [13]:
class TransformerEncoder(nn.Module):
    def __init__(self, dim, heads, depth, mlp_hidden_dim, dropout=0.0):
        super().__init__()

        self.list = nn.ModuleList([])
        for _ in range(depth):
            entry = nn.ModuleList([
                nn.LayerNorm(dim),
                MultiHeadAttention(dim, heads, dropout=dropout),
                nn.LayerNorm(dim),
                MLP(dim, mlp_hidden_dim, dropout=dropout)
            ])
            self.list.append(entry)

    def forward(self, x):
        for norm1, attention, norm2, mlp in self.list: # type: ignore
            x_res = x # deep copy shouldn't be need
            x = norm1(x)
            x = attention(x, x, x) + x_res

            x_res = x
            x = norm2(x)
            x = mlp(x) + x_res
        return x


In [14]:
class CreatePatches(nn.Module):
    def __init__(self, patch_size, num_channels, emb_dim):
        super().__init__()

        self.flatten_and_project = nn.Sequential(
            # from (b c h w) to (b num_patches patch_data)
            # the (num_patches_h p1) means that we get i patches of size p1 across the h axis, where num_patches_h takes the value i
            # the same goes for (num_patches_w p2), where num_patches_w takes the value j
            # therefore, (num_patches_h num_patches_w) is i*j = num_patches
            Rearrange('b c (num_patches_h p1) (num_patches_w p2) -> b (num_patches_h num_patches_w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            # takes in patch_size*patch_size*num_channels, outputs a emb_dim vector
            nn.Linear(patch_size*patch_size*num_channels, emb_dim)
        )

    def forward(self, x):
        return self.flatten_and_project(x)

In [15]:
class ViT(nn.Module):
    def __init__(self, img_size, patch_size, num_channels, emb_dim, heads, depth, transformer_mlp_hidden_dim, num_classes, dropout=0.0):
        super().__init__()
        # assuming images are square
        assert img_size % patch_size == 0, 'img_size must be divisible by patch_size'

        self.num_patches = (img_size // patch_size) ** 2

        self.create_patches = CreatePatches(patch_size=patch_size, num_channels=num_channels, emb_dim=emb_dim)

        # positional encoding - learned parameters
        # the first dimension is 1 for the batches. Broadcasting makes the pos_encoding the same for all batches
        self.pos_encoding = nn.Parameter(torch.randn(1, self.num_patches, emb_dim))

        self.dropout = nn.Dropout(dropout)

        self.transformer_encoder = TransformerEncoder(dim=emb_dim, heads=heads, depth=depth, mlp_hidden_dim=transformer_mlp_hidden_dim, dropout=dropout)

        self.mlp_linear_head = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, num_classes),
        )


    def forward(self, x):
        x = self.create_patches(x)

        # positional encoding. Broadcasting adds same pos_encoding to all batches
        x += self.pos_encoding

        x = self.dropout(x)

        x = self.transformer_encoder(x)

        # mean - average all embedding vectors to get a single vector
        # we don't use cls. Apparently average pool is be better
        # https://arxiv.org/pdf/2205.01580.pdf
        x = x.mean(dim=1)

        x = self.mlp_linear_head(x)

        return x


## Train

In [None]:
%tensorboard --logdir runs

In [17]:
# Hyperparams
batch_size = 256
patch_size = 4
emb_dim = 256
heads = 8
depth = 8
mlp_hidden_dim = 256
epochs = 100
dropout = 0.1
lr = 3e-4
weight_decay = 0.0001
mixup_p = 0.0
flip_lr_p = 0.5
rand_aug_num_ops = 2
rand_aug_magnitude = 10

In [18]:
train_transform = v2.Compose(
    [
        v2.PILToTensor(),
        v2.RandomCrop(32, padding=4),  # Adding RandomCrop here with padding
        v2.RandomHorizontalFlip(p=flip_lr_p),
        v2.RandAugment(num_ops=rand_aug_num_ops, magnitude=rand_aug_magnitude),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

training_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=train_transform,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:13<00:00, 12935200.60it/s]


Extracting data/cifar-10-python.tar.gz to data


In [20]:
num_channels, img_size, _ = training_data[0][0].shape  # get dimensions from the first image of the training dataset
num_classes = len(training_data.classes)

writer = SummaryWriter()

print(f"Using device: {device}")

def collate_fn(batch):
    mixup = v2.MixUp(num_classes=num_classes)
    mixup = v2.RandomApply(torch.nn.ModuleList([mixup]), p=mixup_p)
    return mixup(*default_collate(batch))

train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

vit = ViT(
    img_size=img_size,
    patch_size=patch_size,
    num_channels=num_channels,
    emb_dim=emb_dim,
    heads=heads,
    depth=depth,
    transformer_mlp_hidden_dim=mlp_hidden_dim,
    num_classes=num_classes,
    dropout=dropout
).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(vit.parameters(), lr=lr, weight_decay=weight_decay)

size = len(train_dataloader.dataset)

vit.train()
for epoch in range(1, epochs+1):
    correct = 0
    total_samples = 0
    with tqdm(train_dataloader, unit="batches") as pbar:
        for batch, (X, y) in enumerate(pbar):
            pbar.set_description(f"Epoch {epoch} of {epochs}")

            X, y = X.to(device), y.to(device)

            pred = vit(X)

            # From the pytorch CrossEntropyLoss documentation:
            # "The performance of this criterion is generally better when target contains class indices, as this allows for optimized computation.
            # Consider providing target as class probabilities only when a single class label per minibatch item is too restrictive."
            # Therefore, we convert to one-hot only if mixup is used.
            if mixup_p > 0.:
                y = torch.nn.functional.one_hot(y, num_classes=num_classes)

            loss = loss_fn(pred, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if mixup_p > 0.:
                correct += (pred.argmax(1) == y.argmax(1)).sum().item()
            else:
                correct += (pred.argmax(1) == y).sum().item()

            total_samples += y.size(0) # batch size. Accounts for last batch which may be smaller

            accuracy = correct / total_samples

            pbar.set_postfix(loss=f"{loss.item():>7f}", accuracy=f"{accuracy:.4f}")

            writer.add_scalar("Loss/train", loss, epoch)

    writer.add_scalar("Accuracy/train", accuracy, epoch)  # Log epoch accuracy

writer.close()

Using device: cuda


Epoch 1 of 100: 100%|██████████| 196/196 [00:57<00:00,  3.39batches/s, accuracy=0.2054, loss=2.107837]
Epoch 2 of 100: 100%|██████████| 196/196 [01:00<00:00,  3.25batches/s, accuracy=0.2858, loss=1.689123]
Epoch 3 of 100: 100%|██████████| 196/196 [01:01<00:00,  3.18batches/s, accuracy=0.3569, loss=1.730898]
Epoch 4 of 100: 100%|██████████| 196/196 [01:01<00:00,  3.18batches/s, accuracy=0.3915, loss=1.490305]
Epoch 5 of 100: 100%|██████████| 196/196 [01:05<00:00,  3.01batches/s, accuracy=0.4229, loss=1.292164]
Epoch 6 of 100: 100%|██████████| 196/196 [01:04<00:00,  3.04batches/s, accuracy=0.4413, loss=1.492599]
Epoch 7 of 100: 100%|██████████| 196/196 [01:02<00:00,  3.15batches/s, accuracy=0.4584, loss=1.370872]
Epoch 8 of 100: 100%|██████████| 196/196 [01:03<00:00,  3.06batches/s, accuracy=0.4753, loss=1.366273]
Epoch 9 of 100: 100%|██████████| 196/196 [01:03<00:00,  3.11batches/s, accuracy=0.4912, loss=1.229915]
Epoch 10 of 100: 100%|██████████| 196/196 [01:03<00:00,  3.07batches/s, a

## Test

In [21]:
test_transform = v2.Compose(
    [
        v2.PILToTensor(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=test_transform,
)

test_dataloader = DataLoader(test_data, batch_size=batch_size)

size = len(test_dataloader.dataset) # type: ignore
num_batches = len(test_dataloader)
test_loss = 0
correct = 0

vit.eval()
with torch.no_grad():
    for X, y in tqdm(test_dataloader, unit="batches"):
        X, y = X.to(device), y.to(device)
        pred = vit(X)
        test_loss += loss_fn(pred, y).item()
        correct += (pred.argmax(1) == y).sum().item()

test_loss /= num_batches
correct /= size

print(f"Accuracy: {100*correct}%, Avg loss: {test_loss:>8f} \n")


Files already downloaded and verified


100%|██████████| 40/40 [00:03<00:00, 13.06batches/s]

Accuracy: 80.74%, Avg loss: 0.565995 






## Save model

In [22]:
torch.save(vit, 'model.pth')

## Load model

In [23]:
vit = torch.load('model.pth').to(device)