This code was written while following along with [notes from a Harvard lecture](https://docs.google.com/document/u/0/d/1VnNYGEmVgvl5p8w2xzypGySajaRv6qvzqw7E7LEwQKI/mobilebasic).

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import pytorch_lightning as pl
import os
import matplotlib.pyplot as plt
import torchvision

from torchvision import transforms
from torchvision.datasets import CIFAR10


# Root directory for the CIFAR-10 download/cache. Falls back to "data/"
# when the DATA_PATH environment variable is unset.
DATASET_PATH = os.getenv("DATA_PATH", "data/")

# pl.seed_everything seeds random, np.random, torch, and cuda, so runs are
# reproducible.
pl.seed_everything(42)

# Load Data

- After importing the libraries, we read the data path from the `DATA_PATH` environment variable.
    - If that environment variable is not set, we fall back to `"data/"` (the variable itself is not modified).

In [None]:
# Per-channel normalization statistics shared by every pipeline (these appear
# to be the standard CIFAR-10 training-set mean / std).
CIFAR10_MEAN = [0.49139968, 0.48215841, 0.44653091]
CIFAR10_STD = [0.24703223, 0.24348513, 0.26158784]

# Training pipeline: light random augmentation (horizontal flip plus a mild
# random resized crop back to 32x32), then tensor conversion and normalization.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((32, 32), scale=(0.8, 1), ratio=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no random augmentation, only tensor conversion and the
# same normalization used for training.
test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

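- The test transforms apply no random augmentation, so evaluation is deterministic.
There is a competition strategy called **test-time augmentation** (TTA), where one
can boost model performance by averaging the predictions over multiple augmented versions of each image; we do not use it here.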

In [None]:
# Three views of CIFAR-10. The training split is loaded twice: once with the
# augmenting transforms (for training) and once with the deterministic test
# transforms (for validation); the two are split into train/val in the next cell.
_cifar10_common = dict(root=DATASET_PATH, download=True)
train_dataset = CIFAR10(train=True, transform=train_transforms, **_cifar10_common)
val_dataset = CIFAR10(train=True, transform=test_transforms, **_cifar10_common)
test_dataset = CIFAR10(train=False, transform=test_transforms, **_cifar10_common)

In [None]:
# Split the 50,000 CIFAR-10 training images into 45,000 train / 5,000 val.
# train_dataset and val_dataset wrap the same images (with different
# transforms), so both splits MUST use the same permutation; otherwise
# validation images leak into the training set. Each call gets its own
# generator seeded identically, which guarantees an identical permutation
# (the global RNG would have advanced between the two calls).
_SPLIT_SEED = 42
train_set, _ = torch.utils.data.random_split(
    train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(_SPLIT_SEED)
)
_, val_set = torch.utils.data.random_split(
    val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(_SPLIT_SEED)
)

In [None]:
NUM_IMGS = 4
# Take the first NUM_IMGS validation images (normalized by test_transforms);
# each dataset item is (image, label), so [0] selects the image.
# torch.stack inserts a new leading dimension: result is (B, C, H, W).
CIFAR_images = torch.stack([val_set[idx][0] for idx in range(NUM_IMGS)], dim=0)

# make_grid returns (C, H, W); matplotlib's imshow expects (H, W, C).
# normalize=True asks make_grid to rescale the normalized pixels for display.
img_grid = torchvision.utils.make_grid(
    CIFAR_images, nrow=NUM_IMGS, normalize=True, pad_value=0.9
).permute(1, 2, 0)

plt.figure(figsize=(8, 8))
plt.imshow(img_grid)
plt.axis('off')
plt.show()

# DataLoaders

- A `DataLoader` combines a dataset with a sampler, iterates over the dataset, and yields batches of a specified size.
- `drop_last=True` removes the last incomplete batch if the number of samples is not divisible by the batch size.
- `num_workers` is best found by trial and error: start with the number of available CPU cores and adjust until data-loading throughput stops improving. Too many workers can run out of memory.

In [None]:
# Training batches must be reshuffled every epoch (DataLoader defaults to
# shuffle=False). The final partial batch is dropped so every step sees a full
# batch of 128. pin_memory uses page-locked host memory, which can speed up
# copies to the GPU.
train_dataloader = data.DataLoader(
    train_set,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=4,
)
# Evaluation loaders keep a fixed order and keep every sample.
val_dataloader = data.DataLoader(
    val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4
)
# The keyword is num_workers; the previous `workers=4` raised a TypeError.
test_dataloader = data.DataLoader(
    test_dataset, batch_size=128, shuffle=False, drop_last=False, num_workers=4
)

# Tokenization

- ViT requires the image to be broken into a sequence of smaller patches.
    - For preprocessing, we split each 32 x 32 image into 4 x 4 pixel patches, giving an 8 x 8 grid of patches (64 patches per image).
        - Batch and channel dimensions are untouched, we're only transforming the height and width.

In [None]:
def img_to_patch(
    x: torch.Tensor, patch_size: int, flatten_channels: bool = True
) -> torch.Tensor:
    """Split a batch of images into a sequence of non-overlapping square patches.

    Args:
        x (torch.Tensor[B, C, H, W]): Batch of images. H and W must both be
            divisible by ``patch_size``.
        patch_size (int): Height and width of each square patch.
        flatten_channels (bool, optional): If True, flatten each patch into a
            feature vector of length C * patch_size**2. If False, keep each
            patch as a small (C, patch_size, patch_size) image grid.
            Defaults to True.

    Returns:
        torch.Tensor: ``[B, H'*W', C*p*p]`` when ``flatten_channels`` is True,
        otherwise ``[B, H'*W', C, p, p]``, where ``p = patch_size``,
        ``H' = H // p`` and ``W' = W // p``. Patches are ordered row by row
        (left to right, then top to bottom).

    Raises:
        ValueError: If H or W is not divisible by ``patch_size``.
    """
    # Dimension order is (B, C, H, W); H and W were previously swapped, which
    # produced a wrong patch layout for non-square images.
    B, C, H, W = x.shape
    if H % patch_size or W % patch_size:
        raise ValueError(
            f"Image size ({H}x{W}) must be divisible by patch_size={patch_size}"
        )
    # H and W are plain Python ints; torch.div requires a Tensor as its first
    # argument, so use integer floor division instead.
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)  # B, H', W', C, p_H, p_W
    x = x.flatten(1, 2)  # B, H'*W', C, p_H, p_W
    if flatten_channels:
        x = x.flatten(2, 4)  # B, H'*W', C*p_H*p_W
    return x


# Keep each 4x4 patch as a small (C, 4, 4) image so it can be displayed below.
img_patches = img_to_patch(CIFAR_images, 4, flatten_channels=False)

In [None]:
n_imgs = CIFAR_images.shape[0]
n_patches = img_patches.shape[1]
# squeeze=False keeps `ax` a 2-D array even when n_imgs == 1, so the
# ax[i, 0] indexing below always works.
fig, ax = plt.subplots(n_imgs, 1, figsize=(14, 3), squeeze=False)
fig.suptitle("Images as sequences of patches")
for i in range(n_imgs):
    # Lay out every patch of image i in a single row (nrow = patches per row),
    # derived from the data instead of a hard-coded 64.
    img_grid = torchvision.utils.make_grid(
        img_patches[i], nrow=n_patches, normalize=True, pad_value=0.9
    )
    img_grid = img_grid.permute(1, 2, 0)  # (C, H, W) -> (H, W, C) for imshow
    ax[i, 0].imshow(img_grid)
    ax[i, 0].axis("off")
plt.show()
plt.close()

In [None]:
class AttentionBlock(nn.Module):
    """Pre-LayerNorm Transformer encoder block.

    Multi-head self-attention followed by a two-layer GELU MLP. Each sub-layer
    sees a LayerNorm-ed input and is added back to its input as a residual.

    Args:
        embed_dim: Feature size of each input/output token.
        hidden_dim: Width of the hidden layer of the feed-forward MLP.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights and
            inside the MLP. Defaults to 0.0.
    """

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        # Pass dropout through so attention weights are regularized like the
        # MLP; previously the argument was silently ignored here.
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the block to ``x``.

        ``nn.MultiheadAttention`` is built with its default ``batch_first=False``,
        so ``x`` is expected as (seq_len, batch, embed_dim); the output has the
        same shape.
        """
        inp_x = self.layer_norm_1(x)
        # [0] keeps the attention output and discards the attention weights.
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer image classifier.

    Images are cut into patches with ``img_to_patch``, and each flattened
    patch is linearly projected to ``embed_dim``. A learnable [CLS] token is
    prepended, learned positional embeddings are added, and the sequence runs
    through ``num_layers`` AttentionBlocks. The final [CLS] representation is
    classified by a LayerNorm + Linear head.

    Args:
        embed_dim: Token feature size.
        hidden_dim: Hidden width of each block's feed-forward MLP.
        num_channels: Number of image channels.
        num_heads: Attention heads per block.
        num_layers: Number of stacked AttentionBlocks.
        num_classes: Number of output logits.
        patch_size: Side length of each square patch.
        num_patches: Maximum number of patches per image; sizes the positional
            embedding table (num_patches + 1 rows, one extra for [CLS]).
        dropout: Dropout probability used throughout. Defaults to 0.0.
    """

    def __init__(
        self,
        embed_dim,
        hidden_dim,
        num_channels,
        num_heads,
        num_layers,
        num_classes,
        patch_size,
        num_patches,
        dropout=0.0,
    ):
        super().__init__()
        self.patch_size = patch_size

        # Layers (registration order is kept so parameter order/init is unchanged).
        patch_dim = num_channels * patch_size**2
        self.input_layer = nn.Linear(patch_dim, embed_dim)
        blocks = [
            AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout)
            for _ in range(num_layers)
        ]
        self.transformer = nn.Sequential(*blocks)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes),
        )
        self.dropout = nn.Dropout(dropout)

        # Learnable [CLS] token and positional embeddings.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

    def forward(self, x):
        """Map an image batch (B, C, H, W) to class logits (B, num_classes)."""
        tokens = self.input_layer(img_to_patch(x, self.patch_size))
        batch_size, num_tokens, _ = tokens.shape

        # Prepend one [CLS] token per image, then add positional embeddings.
        cls = self.cls_token.repeat(batch_size, 1, 1)
        tokens = torch.cat((cls, tokens), dim=1)
        tokens = tokens + self.pos_embedding[:, : num_tokens + 1]

        # AttentionBlock's attention is sequence-first: (T + 1, B, E).
        tokens = self.dropout(tokens).transpose(0, 1)
        encoded = self.transformer(tokens)

        # Position 0 along the sequence axis is the [CLS] token.
        return self.mlp_head(encoded[0])

In [None]:
from torch.optim import lr_scheduler


class ViT(pl.LightningModule):
    """LightningModule that trains/evaluates a VisionTransformer classifier.

    Args:
        model_kwargs (dict): Keyword arguments forwarded to VisionTransformer.
        lr (float): Learning rate for AdamW.
    """

    def __init__(self, model_kwargs, lr):
        super().__init__()
        # Records model_kwargs and lr in self.hparams (lr is read in
        # configure_optimizers).
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        # `parameters` is a method and must be called to get the tensors to optimize.
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        # Multiply lr by 0.1 at scheduler steps 100 and 150 (epochs under
        # Lightning's default per-epoch scheduler interval). Named `scheduler`
        # so it does not shadow the imported `lr_scheduler` module.
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[100, 150], gamma=0.1
        )
        return [optimizer], [scheduler]

    def _calculate_loss(self, batch, mode="train"):
        """Compute cross-entropy loss and accuracy for ``batch``, log both
        under ``{mode}_loss`` / ``{mode}_acc``, and return the loss."""
        imgs, labels = batch
        preds = self(imgs)  # class logits
        loss = F.cross_entropy(preds, labels)
        # `.float` is a method: call it to turn the boolean matches into 0./1.
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        self.log(f"{mode}_loss", loss, prog_bar=True)
        self.log(f"{mode}_acc", acc, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._calculate_loss(batch, mode="train")

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

In [None]:
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/VisionTransformers/")


def train_model(*, fast_dev_run=5, **kwargs):
    """Fit a ViT on the CIFAR-10 loaders, then evaluate it on the test loader.

    Args:
        fast_dev_run: Forwarded to ``pl.Trainer``. The default of 5 preserves
            the previous hard-coded value (Lightning's quick debugging mode
            that runs only a few batches). Pass ``False`` for a real training run.
        **kwargs: Forwarded to ``ViT`` (e.g. ``model_kwargs`` and ``lr``).

    Returns:
        tuple: ``(model, test_result)`` where ``test_result`` is whatever
        ``trainer.test`` returns.
    """
    trainer = pl.Trainer(
        default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
        fast_dev_run=fast_dev_run,
    )

    # Re-seed so model initialization is reproducible regardless of earlier cells.
    pl.seed_everything(42)
    model = ViT(**kwargs)
    trainer.fit(model, train_dataloader, val_dataloader)
    test_result = trainer.test(model, dataloaders=test_dataloader, verbose=False)
    return model, test_result

In [None]:
# Hyperparameters forwarded to VisionTransformer via ViT.
model_kwargs = dict(
    embed_dim=256,  # token feature size
    hidden_dim=512,  # feed-forward hidden width in each AttentionBlock
    num_heads=8,
    num_layers=6,  # number of stacked AttentionBlocks
    patch_size=4,  # 4x4 pixel patches
    num_channels=3,
    num_patches=64,  # (32 // 4) ** 2 patches per 32x32 image
    num_classes=10,
    dropout=0.2,
)

model, results = train_model(model_kwargs=model_kwargs, lr=3e-4)