In [71]:
! pip install lightning-bolts



In [72]:
!pip install timm



In [73]:
import torch
import torchvision
import torch.nn as nn
from math import sqrt
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import pandas as pd
import matplotlib.pyplot as plt
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from torchvision.transforms.v2 import CutMix, MixUp, RandomChoice, RandAugment
from timm.loss import SoftTargetCrossEntropy

In [74]:

# image pre-processing helper class

class PatchEmbedding(nn.Module):

    def __init__(self, batches=32, in_channels=3, patch_size=16, size=128, embed_dim=768):
        super().__init__()

        assert size % patch_size == 0, "Image size must be divisible by patch size"

        self.batches = batches
        self.in_channels = in_channels # rgb ==> 3 channels
        self.patch_size = patch_size # size of each patch (like a token)
        self.embed_dim = embed_dim # the higher-dimensional space to project the patches to
        self.size = size # size of input image
        self.N = (self.size // self.patch_size) ** 2 # number of patches

        self.proj = nn.Conv2d(in_channels=self.in_channels, # B, C, H, W --> B, D, H_p, W_p
                              out_channels=self.embed_dim, # 3D space --> 768D space to extract more information
                              kernel_size=self.patch_size, # so that the patches don't overlap
                              stride=self.patch_size) # divides input image into patches

        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim)) # token which captures the 'meaning' of the image in a vector
        self.pos_embeddings = nn.Parameter(torch.randn(1, self.N + 1, self.embed_dim)) # the positional embeddings which will be added later

    def forward(self, x):
        x = self.proj(x) # applying conv2d projection
        x = torch.flatten(x, 2) # B, D, N
        x = x.transpose(1, 2) # B, N, D
        B = x.shape[0]
        cls_token = self.cls_token.expand(B, -1, -1) # expanding the cls token along the batch dimension so it can be added later
        x = torch.cat((cls_token, x), dim=1) # adding the cls token to the input tensor
        x = x + self.pos_embeddings # now each vector is aware of the position of the word

        return x

In [75]:

class ManualMultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=768, heads=12):
        super().__init__()

        assert embed_dim % heads == 0, "Embedding dimension must be divisible by heads"

        self.embed_dim = embed_dim
        self.heads = heads
        self.head_dim = embed_dim // heads

        # fully connected NN layers with # of input neurons = embed_dim = # output neurons
        self.Q_proj = nn.Linear(embed_dim, embed_dim)
        self.V_proj = nn.Linear(embed_dim, embed_dim)
        self.K_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(0.15)

        self.output = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, N, D = x.shape

        # sending the input tensor through proj NN layers
        Q = self.Q_proj(x) # batches, patches, embed_dim
        Q = Q.view(B, N, self.heads, self.head_dim).permute(0, 2, 1, 3) # single head --> multihead

        V = self.V_proj(x) # batches, patches, embed_dim
        V = V.view(B, N, self.heads, self.head_dim).permute(0, 2, 1, 3) # single head --> multihead

        K = self.K_proj(x) # batches, patches, embed_dim
        K = K.view(B, N, self.heads, self.head_dim).permute(0, 2, 1, 3) # single head --> multihead

        # computing attention
        x = self.compute_attention(Q, K, V).permute(0, 2, 1, 3).contiguous() # B, heads, N, head_dim --> B, N, heads, head_dim
        x = self.dropout(x)
        x = x.view(B, N, D)
        x = self.output(x)

        return x

    def compute_attention(self, Q, K, V):
        K_T = torch.transpose(K, -2, -1) # transpose so that multiplication is defined
        scaling = sqrt(self.head_dim)
        val = torch.matmul(Q, K_T) / scaling

        return torch.matmul(torch.softmax(val, dim=-1), V)


In [76]:
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim=768, heads=12):
        super().__init__()

        self.mhsa = ManualMultiHeadSelfAttention(embed_dim, heads)
        self.ln1 = nn.LayerNorm(embed_dim)

        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(0.05),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.Dropout(0.05),
        )

        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.ln1(x) # normalizing activation functions to prevent saturation. Pre-norm is expected to lead to faster convergence.
        attn_out = self.mhsa(x) # computing residual
        x = x + attn_out # skip connection
        x = self.ln2(x) # normalizing activation functions to prevent saturation
        ffn_out = self.ffn(x) # computing residual
        x = x + ffn_out # skip connection

        return x


In [77]:
class VisionTransformer(nn.Module):
    def __init__(self,  batches=32, in_channels=3, patch_size=16, size=128, embed_dim=768, heads=12, depth=8, num_classes=10):
        super().__init__()

        self.patch_embedding = PatchEmbedding(batches, in_channels, patch_size, size, embed_dim)
        self.transformer_stack = nn.ModuleList(TransformerEncoder(embed_dim, heads) for _ in range(depth))

        self.mlp_head = nn.Sequential(
            nn.Dropout(0.05),
            nn.Linear(in_features=embed_dim, out_features=num_classes),
        )

    def forward(self, x):
        x = self.patch_embedding(x)

        for t in self.transformer_stack:
            x = t(x)

        cls = x[:, 0]
        x = self.mlp_head(cls)

        return x

In [78]:
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    RandAugment(num_ops=1, magnitude=5),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768)) # normalizing mean helps optims like Adam, scaling ensures uniformity across the inputs
])

test_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768))
])


training_data = datasets.CIFAR10(root="data", train=True, download=True, transform=train_transform)
test_data = datasets.CIFAR10(root="data", train=False, download=False,  transform=test_transform)

In [79]:
batchsize = 256
channels = 3
patchsize = 4
imsize = 32
embeddim = 192
numheads = 12
encoders = 9
numclasses = 10
epochs = 100

In [80]:
train_dataloader = DataLoader(dataset=training_data, batch_size=batchsize, shuffle=True, pin_memory=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=batchsize, shuffle=False)

cutmix = CutMix(num_classes=numclasses)
mixup = MixUp(num_classes=numclasses)

vit = VisionTransformer(
    batches=batchsize, in_channels=channels,
    patch_size=patchsize, embed_dim=embeddim,
    heads=numheads, depth=encoders, size=imsize,
    num_classes=numclasses
)


loss_fn = torch.nn.CrossEntropyLoss()
sft_loss_fn = SoftTargetCrossEntropy()
optim = torch.optim.AdamW(params=vit.parameters(), lr=0.002, weight_decay=0.05)
scheduler = LinearWarmupCosineAnnealingLR(optim, warmup_epochs=5, max_epochs=epochs, eta_min=1e-6, warmup_start_lr=1e-6)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vit.to(device)

  scheduler = LinearWarmupCosineAnnealingLR(optim, warmup_epochs=5, max_epochs=epochs, eta_min=1e-6, warmup_start_lr=1e-6)


VisionTransformer(
  (patch_embedding): PatchEmbedding(
    (proj): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
  )
  (transformer_stack): ModuleList(
    (0-8): 9 x TransformerEncoder(
      (mhsa): ManualMultiHeadSelfAttention(
        (Q_proj): Linear(in_features=192, out_features=192, bias=True)
        (V_proj): Linear(in_features=192, out_features=192, bias=True)
        (K_proj): Linear(in_features=192, out_features=192, bias=True)
        (dropout): Dropout(p=0.15, inplace=False)
        (output): Linear(in_features=192, out_features=192, bias=True)
      )
      (ln1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (ffn): Sequential(
        (0): Linear(in_features=192, out_features=192, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.05, inplace=False)
        (3): Linear(in_features=192, out_features=192, bias=True)
        (4): GELU(approximate='none')
        (5): Dropout(p=0.05, inplace=False)
      )
      (ln2): LayerNorm((19

In [81]:
scaler = torch.amp.GradScaler()

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [82]:
for i in range(epochs):
    correct, total = 0, 0
    vit.train(True)

    loop = tqdm(train_dataloader, desc=f"Epoch {i+1}/{epochs}", leave=True)

    for input, labels in loop:
        input, labels = input.to(device), labels.to(device)
        optim.zero_grad()

        if i + 1 >= 25 and torch.rand(1).item() < 0.5:
            cutmix_or_mixup = RandomChoice([cutmix, mixup], p=[0.7, 0.3])
            input, labels = cutmix_or_mixup(input, labels)
            output = vit(input)
            loss = sft_loss_fn(output, labels)
        elif i + 1 >= 10 and torch.rand(1).item() < 0.5:
            input, labels = mixup(input, labels)
            output = vit(input)
            loss = sft_loss_fn(output, labels)
        else:
            output = vit(input)
            loss = loss_fn(output, labels)

        scaler.scale(loss).backward()
        scaler.step(optim)

        scaler.update()

        loop.set_postfix(loss=loss.item())

        if labels.ndim == 2:
          labels = labels.argmax(dim=1)

        pred = torch.argmax(output, dim=1)
        total += labels.size(0)
        correct += (pred == labels).sum().item()

    vit.eval()
    val_total, val_correct = 0, 0

    with torch.no_grad():
        for image, label in test_dataloader:
            image, label = image.to(device), label.to(device)

            with torch.cuda.amp.autocast():
              output = vit(image)

            pred = torch.argmax(output, dim=1)
            val_total += label.size(0)
            val_correct += (pred == label).sum().item()

        val_accuracy = val_correct / val_total
        accuracy = correct / total


    scheduler.step()
    print(f"Epoch {i + 1}: \t Accuracy: {accuracy} \t Val accuracy: {val_accuracy}")

Epoch 1/100: 100%|██████████| 196/196 [01:02<00:00,  3.15it/s, loss=2.38]
  with torch.cuda.amp.autocast():


Epoch 1: 	 Accuracy: 0.09456 	 Val accuracy: 0.1123


Epoch 2/100:  24%|██▍       | 47/196 [00:15<00:47,  3.12it/s, loss=2.05]


KeyboardInterrupt: 

In [None]:
training_loss_6 = [
    762.87,
    654.60,
    610.51,
    580.71,
    560.44,
    543.26,
    526.28,
    513.83,
    499.53,
    442.34,
    435.16,
    426.84,
    419.30,
    413.18,
    407.64,
    398.95,
    392.66,
    385.60
]

training_loss_7 = [
    751.32,
    643.59,
    599.94,
    571.05,
    552.35,
    501.65,
    488.99,
    477.56,
    466.41,
    456.77,
    447.05,
    419.97,
    412.00,
    397.69,
    383.77,
    368.06,
    351.56,
    336.02
]

plt.plot(training_loss_6, label="Trial 6", color='r')
plt.plot(training_loss_7, label="Trial 7", color='b')

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Per-epoch training Loss")
plt.legend()

In [None]:
vit.eval()
total, correct = 0, 0

with torch.no_grad():
    for image, label in test_dataloader:
        image, label = image.to(device), label.to(device)

        with torch.cuda.amp.autocast():
          output = vit(image)

        pred = torch.argmax(output, dim=1)
        total += label.size(0)
        correct += (pred == label).sum().item()

    accuracy = correct / total

    print(f"Accuracy: {accuracy}")

In [None]:
import pandas as pd

data = {
    'Trial #': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'Embedding dimension': [384, 384, 384, 192, 192, 192, 256, 192, None, None],
    'Number of heads': [12, 12, 12, 12, 12, 8, 8, 12, None, None],
    'Number of encoders': [8, 8, 8,  8, 8, 8, 6, 9, None, None],
    'Batch size': [128, 128, 128, 128, 128, 128, 128, 128, None, None],
    'Epochs': [20, 20, 50,  30, 30, 30, 30, 100, None, None],
    'Patch size': [4, 4, 4,  4, 4, 4, 4, 4, None, None],
    'Optimizer': ['SGD', 'Adam', 'SGD',  "AdamW", "AdamW", "AdamW", "AdamW", "AdamW", None, None],
    'Learning rate': [0.001, 0.001, 0.1,  3e-4, 3e-4, 3e-4, "3e-4 with 1e-5 weight decay", "0.002 with 0.05 weight decay", None, None],
    'Normalization': ['No', 'No', 'Yes',  'Yes', 'Yes', 'Yes', 'Yes', "Yes", None, None],
    'Loss function': ['Cross entropy', 'Cross entropy', 'Cross entropy',  'Cross entropy', 'Cross entropy', 'Cross entropy', 'Cross entropy', 'Cross entropy', None, None],
    'Dropout regularization value': [None, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.10, None, None],
    'Data Augmentation': [None, None, "Jittering + horizontal flip", "Jittering + random crop + horizontal flip", "Jittering + random crop + horizontal flip",  "Jittering + random crop + horizontal flip", "Jittering + random crop + horizontal flip", "Jittering + random crop + horizontal flip + RandAugment" , None, None],
    'Scheduler': [None, None, None, "Cosine Annealing LR", "Cosine Annealing LR", "Cosine Annealing LR", "Cosine Annealing LR", "Cosine Annealing LR + linear warmup", None, None],
    'Accuracy %': [23.24, 23.24, 33.36, 10.00, 37.53,  70.49, 71.44, None, None, None],
}

# Create DataFrame
df = pd.DataFrame(data)
pd.set_option('display.max_colwidth', None)
df