
# ViT Assignment
Authors: Alexander Wan, Aryan Jain

### Assignment Goals


1. Familiarity with the Vision Transformer architecture
2. Familiarity with the self-attention algorithm
3. Practice with PyTorch matrix operations



### Tasks
1. Implement multi-head self-attention
2. Incorporate that into a ViT

### Runtime Acceleration
Colab limits GPU usage, so while you're developing, set `device` below to `'cpu'` and switch your runtime to CPU (Runtime > Change runtime type). Only change it to `'cuda'` (and your runtime to GPU) when you're ready to train.

In [None]:
#device = 'cpu'
device = 'cuda'
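
If you'd rather not edit this cell every time you switch runtimes, one alternative (an assumption on our part, not part of the assignment's instructions) is to pick the device from whatever the runtime exposes, using `torch.cuda.is_available()`:

```python
import torch

# Use the GPU when the runtime provides one, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
```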

### Multi-head self-attention
Begin by implementing multi-head self-attention. Do **not** use any `for` loops; instead, express all of the calculations as [batch matrix multiplications](https://pytorch.org/docs/stable/generated/torch.bmm.html) or [Linear layers](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html).

Useful references include the lecture slides and the [illustrated transformer](https://jalammar.github.io/illustrated-transformer/).
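
For reference, the standard formulation from the original Transformer is shown below. Here $X$ is the `(max_length, input_dim)` input of one example, $W_i^Q, W_i^K, W_i^V$ are the slices of `Q_embed`, `K_embed`, `V_embed` that belong to head $i$, $W^O$ is `out_embed`, $h$ is `num_heads`, and $d_k$ is the per-head dimension `embed_dim // num_heads`. This mapping onto the notebook's names is our own reading of the skeleton, and the softmax is taken over each row, i.e. over the keys for a given query.

$$
\begin{aligned}
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \\
\mathrm{head}_i &= \mathrm{Attention}\!\left(XW_i^Q,\, XW_i^K,\, XW_i^V\right), \qquad i = 1, \dots, h \\
\mathrm{MSA}(X) &= \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^O
\end{aligned}
$$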


In [None]:
import torch.nn.functional as F
from torch import nn
import torch

class MSA(nn.Module):
  def __init__(self, input_dim, embed_dim, num_heads):
    '''
    input_dim: Dimension of input token embeddings
    embed_dim: Dimension of internal key, query, and value embeddings
    num_heads: Number of self-attention heads
    '''

    super().__init__()

    self.input_dim = input_dim
    self.embed_dim = embed_dim
    self.num_heads = num_heads

    self.K_embed = nn.Linear(input_dim, embed_dim, bias=False)
    self.Q_embed = nn.Linear(input_dim, embed_dim, bias=False)
    self.V_embed = nn.Linear(input_dim, embed_dim, bias=False)
    self.out_embed = nn.Linear(embed_dim, embed_dim, bias=False)

  def forward(self, x):
    '''
    x: input of shape (batch_size, max_length, input_dim)
    return: output of shape (batch_size, max_length, embed_dim)
    '''

    batch_size, max_length, given_input_dim = x.shape
    assert given_input_dim == self.input_dim
    assert self.embed_dim % self.num_heads == 0  # heads split the embedding dimension into equal parts

    # You shouldn't need to initialize any new modules. Everything you need is
    # already in __init__

    # HINT: If you're stuck on how to handle multiple heads without for loops, try to
    # reshape matrix such that the batch_size is num_heads * batch_size
    # e.g. if you have two heads, you'd be doing self-attention twice per instance
    # in the batch, so you essentially have batch_size * 2

    # HINT 2: Feel free to reference: https://d2l.ai/chapter_attention-mechanisms-and-transformers/multihead-attention.html
    # although make sure you understand what each command does

    # this implementation projects KQV before splitting into multiple heads
    # but you can also split into multiple heads first

    # Compute K, Q, and V for all heads at once by projecting every token,
    # then restore the (batch_size, max_length, embed_dim) shape
    x = x.reshape(batch_size * max_length, -1)
    K = self.K_embed(x).reshape(batch_size, max_length, self.embed_dim) # (batch_size, max_length, embed_dim)
    Q = self.Q_embed(x).reshape(batch_size, max_length, self.embed_dim)
    V = self.V_embed(x).reshape(batch_size, max_length, self.embed_dim)

    # Split each of K, Q, V into heads by reshaping into (batch_size, max_length, num_heads, indiv_dim)
    indiv_dim = self.embed_dim // self.num_heads
    K = K.reshape(batch_size, max_length, self.num_heads, indiv_dim)
    Q = Q.reshape(batch_size, max_length, self.num_heads, indiv_dim)
    V = V.reshape(batch_size, max_length, self.num_heads, indiv_dim)

    K = K.permute(0, 2, 1, 3) # (batch_size, num_heads, max_length, embed_dim / num_heads)
    Q = Q.permute(0, 2, 1, 3)
    V = V.permute(0, 2, 1, 3)

    K = K.reshape(batch_size * self.num_heads, max_length, indiv_dim)
    Q = Q.reshape(batch_size * self.num_heads, max_length, indiv_dim)
    V = V.reshape(batch_size * self.num_heads, max_length, indiv_dim)

    # Transpose K so that one batched matrix multiplication (torch.bmm) gives Q K^T
    K_T = K.permute(0, 2, 1) # (batch_size * num_heads, indiv_dim, max_length)
    QK = torch.bmm(Q, K_T) # (batch_size * num_heads, max_length, max_length)

    # Scale by the square root of the per-head dimension d_k = indiv_dim, then take the
    # softmax over the last dimension so each query's weights sum to 1
    weights = QK / (indiv_dim ** 0.5)
    weights = F.softmax(weights, dim=-1) # (batch_size * num_heads, max_length, max_length)

    # Weighted average of the values: weights is (batch_size * num_heads, max_length, max_length)
    # and V is (batch_size * num_heads, max_length, indiv_dim), so w_V = weights @ V
    w_V = torch.bmm(weights, V) # (batch_size * num_heads, max_length, indiv_dim)

    # rejoin heads
    w_V = w_V.reshape(batch_size, self.num_heads, max_length, indiv_dim)
    w_V = w_V.permute(0, 2, 1, 3) # (batch_size, max_length, num_heads, embed_dim / num_heads)
    w_V = w_V.reshape(batch_size, max_length, self.embed_dim)

    out = self.out_embed(w_V)

    return out
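
A quick way to sanity-check a loop-free implementation is to compare it against a naive per-head loop on random inputs. The sketch below is standalone: `batched_heads_attention` and `looped_heads_attention` are hypothetical helpers, not part of the assignment. They take already-projected Q, K, V, so they check only the head splitting, scaling (assumed here to use the per-head dimension), softmax, and head rejoining:

```python
import torch
import torch.nn.functional as F

def batched_heads_attention(Q, K, V, num_heads):
    """Loop-free multi-head attention that folds heads into the batch dimension.
    Q, K, V: (batch_size, length, embed_dim), already projected."""
    B, L, E = Q.shape
    d = E // num_heads

    def split(t):  # (B, L, E) -> (B * num_heads, L, d)
        return t.reshape(B, L, num_heads, d).permute(0, 2, 1, 3).reshape(B * num_heads, L, d)

    q, k, v = split(Q), split(K), split(V)
    weights = F.softmax(torch.bmm(q, k.transpose(1, 2)) / d ** 0.5, dim=-1)
    out = torch.bmm(weights, v)  # (B * num_heads, L, d)
    # Undo the split: (B * num_heads, L, d) -> (B, L, E)
    return out.reshape(B, num_heads, L, d).permute(0, 2, 1, 3).reshape(B, L, E)

def looped_heads_attention(Q, K, V, num_heads):
    """Reference version with an explicit loop over heads (for testing only)."""
    d = Q.shape[-1] // num_heads
    outs = []
    for h in range(num_heads):
        sl = slice(h * d, (h + 1) * d)  # head h owns features h*d .. (h+1)*d - 1
        q, k, v = Q[..., sl], K[..., sl], V[..., sl]
        w = F.softmax(q @ k.transpose(1, 2) / d ** 0.5, dim=-1)
        outs.append(w @ v)
    return torch.cat(outs, dim=-1)

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 5, 12) for _ in range(3))
fast = batched_heads_attention(Q, K, V, num_heads=3)
slow = looped_heads_attention(Q, K, V, num_heads=3)
print(torch.allclose(fast, slow, atol=1e-5))
```

Because `reshape(B, L, num_heads, d)` assigns feature `h * d + j` to head `h`, the batched version lines up with the slicing in the loop; if your own `MSA` splits heads differently, adjust the reference accordingly.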

### Implement the ViT architecture
You will implement the ViT architecture from the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".

Although the ViT and Transformer architectures are very similar, note a few differences:

1. Image patches, instead of discrete tokens, are the input.
2. [GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used as the activation between the linear layers of each transformer layer's MLP (instead of ReLU).
3. LayerNorm is applied before each sublayer instead of after it.
4. Dropout is applied after every linear layer except the KQV projections, and also directly after adding the positional embeddings to the patch embeddings.
5. A learnable [CLS] token is prepended to the input.

A useful reference is Figure 1 in the [paper](https://arxiv.org/pdf/2010.11929.pdf).
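
In the paper's notation (its Eqs. 1 to 4), each image is split into $N$ patches $x_p^1, \dots, x_p^N$, each flattened to length $P^2 \cdot C$, and the model has $L$ layers; with dropout omitted, it computes the equations below. Our reading of how they map onto the notebook: $E$ is `patch_embedding`, $x_{class}$ is `cls_token`, $E_{pos}$ is `position_embedding`, the final $\mathrm{LN}$ is `layernorm`, and $y$ is fed to `mlp_head`. The MLP line uses the standard GELU definition, where $\Phi$ is the standard normal CDF.

$$
\begin{aligned}
z_0 &= [x_{class};\, x_p^1 E;\, x_p^2 E;\, \dots;\, x_p^N E] + E_{pos}, && E \in \mathbb{R}^{(P^2 \cdot C) \times D},\; E_{pos} \in \mathbb{R}^{(N+1) \times D} \\
z'_\ell &= \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}, && \ell = 1, \dots, L \\
z_\ell &= \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell, && \ell = 1, \dots, L \\
y &= \mathrm{LN}(z_L^0) \\
\mathrm{MLP}(z) &= \mathrm{GELU}(z W_1 + b_1)\, W_2 + b_2, && \mathrm{GELU}(u) = u\,\Phi(u)
\end{aligned}
$$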

First, implement a single layer:

In [None]:
class ViTLayer(nn.Module):
  def __init__(self, num_heads, input_dim, embed_dim, mlp_hidden_dim, dropout=0.1):
    '''
    num_heads: Number of heads for multi-head self-attention
    input_dim: Dimension of the input token embeddings (must equal embed_dim for the residual connections)
    embed_dim: Dimension of internal key, query, and value embeddings
    mlp_hidden_dim: Hidden dimension of the MLP
    dropout: Dropout rate
    '''

    super().__init__()

    self.input_dim = input_dim
    self.msa = MSA(input_dim, embed_dim, num_heads)

    self.layernorm1 = nn.LayerNorm(embed_dim)
    self.w_o_dropout = nn.Dropout(dropout)
    self.layernorm2 = nn.LayerNorm(embed_dim)
    self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_hidden_dim),
                              nn.GELU(),
                              nn.Dropout(dropout),
                              nn.Linear(mlp_hidden_dim, embed_dim),
                              nn.Dropout(dropout))

  def forward(self, x):
    '''
    x: input embeddings (batch_size, max_length, input_dim)
    return: output embeddings (batch_size, max_length, embed_dim)
    '''

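    # You shouldn't need to initialize any more modules; everything you need is already
    # in __init__
    # A forward pass consists of:
    # 1) LayerNorm of x
    # 2) Self-attention on the output of 1)
    # 3) Dropout
    # 4) Residual connection with the original x
    # 5) LayerNorm
    # 6) MLP
    # 7) Residual connection
    # The residuals require input_dim == embed_dim, which holds for every layer ViT builds.
    x = x + self.w_o_dropout(self.msa(self.layernorm1(x)))  # steps 1-4
    x = x + self.mlp(self.layernorm2(x))                    # steps 5-7 (the MLP already ends in dropout)
    return x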


A portion of the full network is already implemented for you. Your task is to implement the preprocessing code, converting raw images into patch embeddings + positional embeddings + dropout, with a learnable CLS token at the beginning of the input.

Note that patch embeddings are added to the positional embeddings elementwise, so each input embedding has dimension `embed_dim`.
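
With the default CIFAR-10 settings, it helps to work out the shapes by hand before writing the forward pass. The helper below (`vit_sequence_shapes` is a hypothetical name, not part of the notebook) mirrors the notebook's arithmetic: `(image_dim // patch_dim) ** 2` patches, `patch_dim * patch_dim * 3` values per flattened patch, one extra [CLS] position (which is why `position_embedding` has `(image_dim // patch_dim) ** 2 + 1` rows), and the padding formula used in `forward`:

```python
def vit_sequence_shapes(image_dim, patch_dim, num_heads, channels=3):
    """Token bookkeeping for a square image split into square patches."""
    patches_per_side = image_dim // patch_dim
    num_patches = patches_per_side ** 2                 # tokens coming from image patches
    patch_input_dim = patch_dim * patch_dim * channels  # flattened pixels per patch
    seq_len = num_patches + 1                           # +1 for the [CLS] token
    # Same formula as the notebook's padding step: pad up to a multiple of num_heads
    pad = (num_heads - seq_len) % num_heads
    return {"num_patches": num_patches, "patch_input_dim": patch_input_dim,
            "seq_len": seq_len, "pad": pad, "padded_len": seq_len + pad}

# Settings used by get_vit_small(): 32x32 images, 4x4 patches, 6 heads
print(vit_sequence_shapes(image_dim=32, patch_dim=4, num_heads=6))
# {'num_patches': 64, 'patch_input_dim': 48, 'seq_len': 65, 'pad': 1, 'padded_len': 66}
```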

In [None]:
class ViT(nn.Module):
  def __init__(self, patch_dim, image_dim, num_layers, num_heads, embed_dim, mlp_hidden_dim, num_classes, dropout):
    '''
    patch_dim: patch length and width to split image by
    image_dim: image length and width
    num_layers: number of layers in network
    num_heads: number of heads for multi-head attention
    embed_dim: dimension to project image patches to and dimension to use for position embeddings
    mlp_hidden_dim: hidden dimension of linear layer
    num_classes: number of classes to classify in data
    dropout: dropout rate
    '''

    super().__init__()
    self.num_layers = num_layers
    self.patch_dim = patch_dim
    self.image_dim = image_dim
    self.input_dim = self.patch_dim * self.patch_dim * 3
    self.num_heads = num_heads

    self.patch_embedding = nn.Linear(self.input_dim, embed_dim)
    self.position_embedding = nn.Parameter(torch.zeros(1, (image_dim // patch_dim) ** 2 + 1, embed_dim))
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.embedding_dropout = nn.Dropout(dropout)

    self.encoder_layers = nn.ModuleList([])
    for i in range(num_layers):
      self.encoder_layers.append(ViTLayer(num_heads, embed_dim, embed_dim, mlp_hidden_dim, dropout))

    self.mlp_head = nn.Linear(embed_dim, num_classes)
    self.layernorm = nn.LayerNorm(embed_dim)

  def forward(self, images):
    '''
    images: raw image data (batch_size, channels, rows, cols)
    '''

    # Don't hardcode dimensions (except for maybe channels = 3), use the variables in __init__.
    # You shouldn't need to add anything else to __init__, all of the embeddings,
    # dropout etc. are already initialized for you.

    # Put the preprocessed patches in variable "out" with shape (batch_size, length, embed_dim).

    # HINT: You can make image patches with .reshape
    # e.g.
    # x = torch.ones((100, 100))
    # x_patches = x.reshape(4, 25, 4, 25)
    # where you have 4 * 4 patches with each patch being 25 by 25

    h = w = self.image_dim // self.patch_dim
    N = images.size(0)
    images = images.reshape(N, 3, h, self.patch_dim, w, self.patch_dim)
    images = torch.einsum("nchpwq -> nhwpqc", images)
    patches = images.reshape(N, h * w, self.input_dim) # (batch, num_patches_per_image, patch_size_unrolled)

    patch_embeddings = self.patch_embedding(patches) # (batch, num_patches_per_image, embed_dim)
    patch_embeddings = torch.cat([torch.tile(self.cls_token, (N, 1, 1)),
                                  patch_embeddings], dim=1)
    out = patch_embeddings + torch.tile(self.position_embedding, (N, 1, 1)) # We add positional embeddings to our tokens (not concatenated)
    out = self.embedding_dropout(out)

    # add padding s.t. input length is multiple of num_heads
    # (created on out's device and dtype so it also works on every DataParallel replica)
    add_len = (self.num_heads - out.shape[1]) % self.num_heads
    out = torch.cat([out, torch.zeros(N, add_len, out.shape[2], device=out.device, dtype=out.dtype)], dim=1)

    # Pass through each encoder layer in turn
    for layer in self.encoder_layers:
      out = layer(out)

    # Read off the [CLS] token we prepended (position 0), normalize it, and classify
    cls_head = self.layernorm(out[:, 0]) # (batch_size, embed_dim)
    logits = self.mlp_head(cls_head)
    return logits

def get_vit_tiny(num_classes=10, patch_dim=4, image_dim=32):
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=12, num_heads=3,
               embed_dim=192, mlp_hidden_dim=768, num_classes=num_classes, dropout=0.1)

def get_vit_small(num_classes=10, patch_dim=4, image_dim=32):
    return ViT(patch_dim=patch_dim, image_dim=image_dim, num_layers=12, num_heads=6,
               embed_dim=384, mlp_hidden_dim=1536, num_classes=num_classes, dropout=0.1)
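
To convince yourself that the reshape and `einsum` in `forward` cut the image into the intended patches, you can compare each patch against a direct slice of the image. This standalone sketch (`patchify` is a hypothetical helper) reproduces the same steps, using `permute(0, 2, 4, 3, 5, 1)`, which for this index pattern does the same reordering as the `"nchpwq -> nhwpqc"` einsum:

```python
import torch

def patchify(images, patch_dim):
    """(N, C, H, W) -> (N, num_patches, patch_dim * patch_dim * C), patches in row-major order."""
    N, C, H, W = images.shape
    h, w = H // patch_dim, W // patch_dim
    x = images.reshape(N, C, h, patch_dim, w, patch_dim)
    x = x.permute(0, 2, 4, 3, 5, 1)  # n c h p w q -> n h w p q c
    return x.reshape(N, h * w, patch_dim * patch_dim * C)

images = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
patches = patchify(images, patch_dim=4)
print(patches.shape)  # torch.Size([2, 4, 48])

# Patch k = i * w + j should hold the pixels of block (i, j), flattened channel-last
i, j, p, w = 1, 0, 4, 2
block = images[:, :, i * p:(i + 1) * p, j * p:(j + 1) * p]  # (N, C, p, p)
expected = block.permute(0, 2, 3, 1).reshape(2, -1)         # (N, p * p * C)
print(torch.equal(patches[:, i * w + j], expected))
```

The channel-last flattening only matters for consistency: `patch_embedding` is a learned linear layer, so any fixed ordering of the `patch_dim * patch_dim * 3` values works as long as every patch uses the same one.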

Now let's train the model by running the cell below.

Remember to change the device variable (in the cell at the beginning of the notebook) to 'cuda' and change your runtime to GPU (Runtime > Change runtime type) as well.

Aim for at least 60% test accuracy after 30 epochs.
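
The training cell below computes `total_steps` and `warmup_steps` but does not attach a learning-rate scheduler, so the learning rate stays fixed at `learning_rate`. If you want to experiment with warmup, one common choice (an assumption here, not something the assignment specifies) is linear warmup followed by cosine decay. A minimal sketch of the per-step multiplier, using only `math` (`warmup_cosine_multiplier` is a hypothetical name):

```python
import math

def warmup_cosine_multiplier(step, warmup_steps, total_steps):
    """LR multiplier in [0, 1]: linear warmup, then cosine decay to 0 at total_steps."""
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)  # stay at 0 if stepped past total_steps
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# With the notebook's numbers: 98 steps per epoch * 30 epochs = 2940 steps, 10% warmup
total, warmup = 2940, 294
for step in (0, 147, 294, 1617, 2940):
    print(step, round(warmup_cosine_multiplier(step, warmup, total), 3))
```

To use it, you could wrap the optimizer with `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda s: warmup_cosine_multiplier(s, warmup_steps, total_steps))` and call `scheduler.step()` once after each `optimizer.step()`. With this particular warmup the multiplier is 0 at step 0, so the very first update uses a learning rate of 0.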

In [None]:
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision
import math
import torch.optim as optim
from tqdm.notebook import tqdm

cifar10_mean = torch.tensor([0.49139968, 0.48215827, 0.44653124])
cifar10_std = torch.tensor([0.24703233, 0.24348505, 0.26158768])

class Cifar10Dataset(Dataset):
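    def __init__(self, train):
        # Augment (random crop + flip) only the training set; the test set gets the
        # same resize followed by a deterministic center crop
        crop = transforms.RandomCrop(32) if train else transforms.CenterCrop(32)
        flip = [transforms.RandomHorizontalFlip()] if train else []
        self.transform = transforms.Compose([transforms.Resize(40), crop, *flip,
                                             transforms.ToTensor(),
                                             transforms.Normalize(cifar10_mean, cifar10_std)])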
        self.dataset = torchvision.datasets.CIFAR10(root='./SSL-Vision/data',
                                                    train=train,
                                                    download=True)
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        img = self.transform(img)
        return img, label

batch_size = 512

trainset = Cifar10Dataset(True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = Cifar10Dataset(False)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

vit = get_vit_small().to(device)
vit = torch.nn.DataParallel(vit)

learning_rate = 5e-4 * batch_size / 256
num_epochs = 30
warmup_fraction = 0.1
weight_decay = 0.1

total_steps = math.ceil(len(trainset) / batch_size) * num_epochs
# (warmup_steps only takes effect if an LR scheduler is attached; none is used below)
warmup_steps = total_steps * warmup_fraction
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(vit.parameters(), lr=learning_rate, betas=(0.9, 0.95), weight_decay=weight_decay)

train_losses = []
test_losses = []
for epoch in range(num_epochs):
    train_loss = 0.0
    train_acc = 0.0
    train_total = 0
    vit.train()
    for inputs, labels in tqdm(trainloader):
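        # 1. Move inputs and labels to the device
        inputs = inputs.to(device)
        labels = labels.to(device)
        # 2. Zero out the gradients from the previous step
        optimizer.zero_grad()
        # 3. Forward pass through the ViT
        outputs = vit(inputs)
        # 4. Compute the loss from the outputs and labels
        loss = criterion(outputs, labels.long())
        # 5. Backpropagate
        loss.backward()
        # 6. Step the optimizer
        optimizer.step()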

        train_loss += loss.item() * inputs.shape[0]
        train_acc += torch.sum((torch.argmax(outputs, dim=1) == labels)).item()
        train_total += inputs.shape[0]
    train_loss = train_loss / train_total
    train_acc = train_acc / train_total
    train_losses.append(train_loss)

    test_loss = 0.0
    test_acc = 0.0
    test_total = 0
    vit.eval()
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = vit(inputs)
            loss = criterion(outputs, labels.long())

            test_loss += loss.item() * inputs.shape[0]
            test_acc += torch.sum((torch.argmax(outputs, dim=1) == labels)).item()
            test_total += inputs.shape[0]
    test_loss = test_loss / test_total
    test_acc = test_acc / test_total
    test_losses.append(test_loss)

    print(f'[{epoch + 1:2d}] train loss: {train_loss:.3f} | train accuracy: {train_acc:.3f} | test_loss: {test_loss:.3f} | test_accuracy: {test_acc:.3f}')

print('Finished Training')

Files already downloaded and verified
Files already downloaded and verified


[ 1] train loss: 2.041 | train accuracy: 0.236 | test_loss: 1.822 | test_accuracy: 0.295
[ 2] train loss: 1.725 | train accuracy: 0.334 | test_loss: 1.660 | test_accuracy: 0.365
[ 3] train loss: 1.581 | train accuracy: 0.406 | test_loss: 1.495 | test_accuracy: 0.450
[ 4] train loss: 1.456 | train accuracy: 0.465 | test_loss: 1.471 | test_accuracy: 0.462
[ 5] train loss: 1.380 | train accuracy: 0.498 | test_loss: 1.384 | test_accuracy: 0.498
[ 6] train loss: 1.326 | train accuracy: 0.519 | test_loss: 1.319 | test_accuracy: 0.519
[ 7] train loss: 1.278 | train accuracy: 0.536 | test_loss: 1.245 | test_accuracy: 0.550
[ 8] train loss: 1.248 | train accuracy: 0.548 | test_loss: 1.234 | test_accuracy: 0.556
[ 9] train loss: 1.221 | train accuracy: 0.558 | test_loss: 1.213 | test_accuracy: 0.562
[10] train loss: 1.186 | train accuracy: 0.571 | test_loss: 1.209 | test_accuracy: 0.563
[11] train loss: 1.161 | train accuracy: 0.586 | test_loss: 1.178 | test_accuracy: 0.578
[12] train loss: 1.143 | train accuracy: 0.589 | test_loss: 1.149 | test_accuracy: 0.586
[13] train loss: 1.120 | train accuracy: 0.594 | test_loss: 1.125 | test_accuracy: 0.596
[14] train loss: 1.101 | train accuracy: 0.604 | test_loss: 1.102 | test_accuracy: 0.604
[15] train loss: 1.090 | train accuracy: 0.607 | test_loss: 1.100 | test_accuracy: 0.603
[16] train loss: 1.064 | train accuracy: 0.617 | test_loss: 1.083 | test_accuracy: 0.609
[17] train loss: 1.050 | train accuracy: 0.624 | test_loss: 1.061 | test_accuracy: 0.618
[18] train loss: 1.037 | train accuracy: 0.625 | test_loss: 1.080 | test_accuracy: 0.619
[19] train loss: 1.017 | train accuracy: 0.637 | test_loss: 1.046 | test_accuracy: 0.627
[20] train loss: 1.002 | train accuracy: 0.641 | test_loss: 1.018 | test_accuracy: 0.635
[21] train loss: 0.995 | train accuracy: 0.644 | test_loss: 1.034 | test_accuracy: 0.630
[22] train loss: 0.976 | train accuracy: 0.650 | test_loss: 1.017 | test_accuracy: 0.638
[23] train loss: 0.964 | train accuracy: 0.652 | test_loss: 1.013 | test_accuracy: 0.644
[24] train loss: 0.954 | train accuracy: 0.658 | test_loss: 1.004 | test_accuracy: 0.645
[25] train loss: 0.935 | train accuracy: 0.663 | test_loss: 0.974 | test_accuracy: 0.651
[26] train loss: 0.931 | train accuracy: 0.666 | test_loss: 0.980 | test_accuracy: 0.653
[27] train loss: 0.919 | train accuracy: 0.672 | test_loss: 0.985 | test_accuracy: 0.650
[28] train loss: 0.901 | train accuracy: 0.675 | test_loss: 0.979 | test_accuracy: 0.658
[29] train loss: 0.884 | train accuracy: 0.682 | test_loss: 0.971 | test_accuracy: 0.655
[30] train loss: 0.872 | train accuracy: 0.687 | test_loss: 0.956 | test_accuracy: 0.662
Finished Training
