<a href="https://colab.research.google.com/github/ra1ph2/Vision-Transformer/blob/main/VisionTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Libraries Imported

In [1]:
import torch
from torch import nn
from torch import functional as F
from torch import optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
if torch.cuda.is_available():    
    device = torch.device("cuda")
    print('GPU: ', torch.cuda.get_device_name(0))

else:
    device = torch.device("cpu")
    print('No GPU available')

GPU:  Tesla T4


#### Model Architecture

In [3]:
class Attention(nn.Module):
    def __init__(self, embed_dim, heads=8, activation=None, dropout=0.1):
        super(Attention, self).__init__()
        self.heads = heads
        self.embed_dim = embed_dim
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)
        if activation == 'relu':
            self.activation = nn.ReLU()
        else:
            self.activation = nn.Identity()
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, inp):
        # inp: (batch_size, seq_len, embed_dim)
        batch_size, seq_len, embed_dim = inp.size()
        assert embed_dim == self.embed_dim

        query = self.activation(self.query(inp))
        key   = self.activation(self.key(inp))
        value = self.activation(self.value(inp))

        # output of _reshape_heads(): (batch_size * heads, seq_len, reduced_dim) | reduced_dim = embed_dim // heads
        query = self._reshape_heads(query)
        key   = self._reshape_heads(key)
        value = self._reshape_heads(value)

        # attention_scores: (batch_size * heads, seq_len, seq_len) | Softmaxed along the last dimension
        attention_scores = self.softmax(torch.matmul(query, key.transpose(1, 2)))

        # out: (batch_size * heads, seq_len, reduced_dim)
        out = torch.matmul(self.dropout(attention_scores), value)
        
        # output of _reshape_heads_back(): (batch_size, seq_len, embed_size)
        out = self._reshape_heads_back(out)

        return out, attention_scores

    def _reshape_heads(self, inp):
        # inp: (batch_size, seq_len, embed_dim)
        batch_size, seq_len, embed_dim = inp.size()

        reduced_dim = self.embed_dim // self.heads
        assert reduced_dim * self.heads == self.embed_dim
        out = inp.reshape(batch_size, seq_len, self.heads, reduced_dim)
        out = out.permute(0, 2, 1, 3)
        out = out.reshape(-1, seq_len, reduced_dim)

        # out: (batch_size * heads, seq_len, reduced_dim)
        return out

    def _reshape_heads_back(self, inp):
        # inp: (batch_size * heads, seq_len, reduced_dim) | reduced_dim = embed_dim // heads
        batch_size_mul_heads, seq_len, reduced_dim = inp.size()
        batch_size = batch_size_mul_heads // self.heads

        out = inp.reshape(batch_size, self.heads, seq_len, reduced_dim)
        out = out.permute(0, 2, 1, 3)
        out = out.reshape(batch_size, seq_len, self.embed_dim)

        # out: (batch_size, seq_len, embed_dim)
        return out

In [4]:
attention = Attention(embed_dim=4, heads=2, activation=None)
attention_lib = nn.MultiheadAttention(embed_dim=4, num_heads=2)
inp = torch.ones((1, 2, 4))
print(inp)
out, wts = attention(inp)
print(out)
print(wts)
out_lib, out_wts = attention_lib(inp.permute(1, 0, 2), inp.permute(1, 0, 2), inp.permute(1, 0, 2))
print(out_lib.permute(1, 0, 2))
print(out_wts)

tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
tensor([[[-0.9729,  0.5344,  1.2441,  0.7006],
         [-0.9729,  0.5344,  1.2441,  0.7006]]], grad_fn=<UnsafeViewBackward>)
tensor([[[0.5000, 0.5000],
         [0.5000, 0.5000]],

        [[0.5000, 0.5000],
         [0.5000, 0.5000]]], grad_fn=<SoftmaxBackward>)
tensor([[[-0.6603,  0.1291, -0.0158,  0.9985],
         [-0.6603,  0.1291, -0.0158,  0.9985]]], grad_fn=<PermuteBackward>)
tensor([[[0.5000, 0.5000],
         [0.5000, 0.5000]]], grad_fn=<DivBackward0>)


In [5]:
# Check if Dropout should be used after second Linear Layer
class FeedForward(nn.Module):
    def __init__(self, embed_dim, forward_expansion=1, dropout=0.1):
        super(FeedForward, self).__init__()
        self.embed_dim = embed_dim
        self.fc1 = nn.Linear(embed_dim, embed_dim * forward_expansion)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(embed_dim * forward_expansion, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inp):
        # inp: (batch_size, seq_len, embed_dim)
        batch_size, seq_len, embed_dim = inp.size()
        assert embed_dim == self.embed_dim

        out = self.dropout(self.activation(self.fc1(inp)))
        # out = self.dropout(self.fc2(out))
        out = self.fc2(out)

        # out: (batch_size, seq_len, embed_dim)
        return out 

In [None]:
ff = FeedForward(8, 1)
inp = torch.ones((1, 2, 8))
print(inp)
out = ff(inp)
print(out)

tensor([[[1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.]]])
tensor([[[-0.0409, -0.0025, -0.0943, -0.0105, -0.0891, -0.0134, -0.1431,
           0.3287],
         [-0.0265, -0.0158, -0.0895,  0.0139, -0.0681, -0.0181, -0.1407,
           0.3108]]], grad_fn=<AddBackward0>)


In [6]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, heads=8, activation=None, forward_expansion=1, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.embed_dim = embed_dim
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attention = Attention(embed_dim, heads, activation, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.feed_forward = FeedForward(embed_dim, forward_expansion, dropout)

    def forward(self, inp):
        # inp: (batch_size, seq_len, embed_dim)
        batch_size, seq_len, embed_dim = inp.size()
        assert embed_dim == self.embed_dim

        res = inp
        out = self.norm1(inp)
        out, _ = self.attention(out)
        out = out + res
        
        res = out
        out = self.norm2(out)
        out = self.feed_forward(out)
        out = out + res

        # out: (batch_size, seq_len, embed_dim)
        return out

In [7]:
tb = TransformerBlock(8, 2)
inp = torch.ones((1, 2, 8))
print(inp)
out = tb(inp)
print(out)

tensor([[[1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.]]])
tensor([[[0.8778, 1.1134, 0.7952, 1.3798, 0.6171, 0.7217, 0.9616, 1.3443],
         [1.0044, 1.2758, 0.9557, 1.3417, 0.5727, 0.7423, 0.9724, 1.3402]]],
       grad_fn=<AddBackward0>)


In [8]:
class Transformer(nn.Module):
    def __init__(self, embed_dim, layers, heads=8, activation=None, forward_expansion=1, dropout=0.1):
        super(Transformer, self).__init__()
        self.embed_dim = embed_dim
        self.trans_blocks = nn.ModuleList(
            [TransformerBlock(embed_dim, heads, activation, forward_expansion, dropout) for i in range(layers)]
        )

    def forward(self, inp):
        # inp: (batch_size, seq_len, embed_dim)

        out = inp
        for block in self.trans_blocks:
            out = block(out)

        # out: (batch_size, seq_len, embed_dim)
        return out

In [9]:
tb = TransformerBlock(8, 2)
trans = Transformer(8, 1, 2)
inp = torch.ones((1, 2, 8))
print(inp)
out = tb(inp)
print(out)
out = trans(inp)
print(out)

tensor([[[1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1.]]])
tensor([[[ 1.1834,  0.9178,  0.3529,  1.4619,  0.8418,  1.1810, -0.0198,
           0.8822],
         [ 1.1457,  0.9306,  0.3712,  1.4903,  0.8227,  1.1471, -0.0410,
           0.9110]]], grad_fn=<AddBackward0>)
tensor([[[1.5750, 0.7187, 1.4488, 0.9184, 1.4199, 0.7167, 1.2511, 0.4901],
         [1.6162, 0.7048, 1.4036, 0.7740, 1.2926, 0.6715, 1.1664, 0.6135]]],
       grad_fn=<AddBackward0>)


In [10]:
# Not Exactly Same as Paper | Check if Dropout should be used here
class ClassificationHead(nn.Module):
    def __init__(self, embed_dim, classes, dropout=0.1):
        super(ClassificationHead, self).__init__()
        self.embed_dim = embed_dim
        self.classes = classes
        self.fc1 = nn.Linear(embed_dim, embed_dim // 2)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(embed_dim // 2, classes)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, inp):
        # inp: (batch_size, embed_dim)
        batch_size, embed_dim = inp.size()
        assert embed_dim == self.embed_dim

        out = self.dropout(self.activation(self.fc1(inp)))
        out = self.softmax(self.fc2(out))

        # out: (batch_size, embed_dim) | SoftMaxed along the last dimension
        return out

In [11]:
ch = ClassificationHead(8, 2)
inp = torch.ones((2, 8))
print(inp)
out = ch(inp)
print(out)

tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[0.6247, 0.3753],
        [0.6247, 0.3753]], grad_fn=<SoftmaxBackward>)


In [13]:
#TODO: Remove to.device from class_token
class VisionTransformer(nn.Module):
    def __init__(self, patch_size, max_len, embed_dim, classes, layers, heads=8, activation=None, forward_expansion=1, dropout=0.1):
        super(VisionTransformer, self).__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_to_embed = nn.Linear(patch_size * patch_size * 3, embed_dim)
        self.position_embed = nn.Parameter(torch.randn((max_len, embed_dim)))
        self.transformer = Transformer(embed_dim, layers, heads, activation, forward_expansion, dropout)
        self.classification_head = ClassificationHead(embed_dim, classes)

    def forward(self, inp):
        # inp: (batch_size, 3, width, height)
        batch_size, channels, width, height = inp.size()
        assert channels == 3

        out = inp.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size).contiguous()
        out = out.view(batch_size, channels, -1, self.patch_size, self.patch_size)
        out = out.permute(0, 2, 3, 4, 1)
        # out: (batch_size, seq_len, patch_size, patch_size, 3) | seq_len would be (width*height)/(patch_size**2)
        batch_size, seq_len, patch_size, _, channels = out.size()
        
        out = out.reshape(batch_size, seq_len, -1)
        out = self.patch_to_embed(out)
        # out: (batch_size, seq_len, embed_dim)

        class_token = torch.randn((batch_size, 1, self.embed_dim)).to(device)
        out = torch.cat([class_token, out], dim=1)
        # out: (batch_size, seq_len+1, embed_dim)

        position_embed = self.position_embed[:seq_len+1]
        position_embed = position_embed.unsqueeze(0).expand(batch_size, seq_len+1, self.embed_dim)
        out = out + position_embed
        # out: (batch_size, seq_len+1, embed_dim) | Added Positional Embeddings

        out = self.transformer(out)
        # out: (batch_size, seq_len+1, embed_dim) 
        class_token = out[:, 0]
        # class_token: (batch_size, embed_dim)

        class_out = self.classification_head(class_token)
        # class_out: (batch_size, classes)
        
        return class_out, out

In [15]:
vit = VisionTransformer(2, 3, 8, 2, 2).to(device)
inp = torch.ones((2, 3, 2, 2)).to(device)
print(inp)
class_out, out = vit(inp)
print(class_out)
print(class_out.shape)
print(out)
print(out.shape)

tensor([[[[1., 1.],
          [1., 1.]],

         [[1., 1.],
          [1., 1.]],

         [[1., 1.],
          [1., 1.]]],


        [[[1., 1.],
          [1., 1.]],

         [[1., 1.],
          [1., 1.]],

         [[1., 1.],
          [1., 1.]]]], device='cuda:0')
tensor([[0.3354, 0.6646],
        [0.3787, 0.6213]], device='cuda:0', grad_fn=<SoftmaxBackward>)
torch.Size([2, 2])
tensor([[[-1.6120,  0.9858, -1.3521,  1.2709,  0.7807,  0.7141, -1.5630,
           3.0290],
         [ 1.4925,  3.4565,  0.7653,  0.1952, -0.0691, -0.6835, -1.5203,
           1.6126]],

        [[ 0.9667,  0.9441,  1.0435, -0.2500, -0.4524,  0.7600, -2.3929,
           0.6099],
         [ 1.7517,  3.9360,  1.0152, -1.0907, -1.7946, -0.9177, -1.7162,
           0.7655]]], device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([2, 2, 8])


#### Data Loading

In [21]:
def CIFAR100DataLoader(split, batch_size=8, num_workers=2, shuffle=True):

    CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
    
    mean = CIFAR100_TRAIN_MEAN
    std = CIFAR100_TRAIN_STD

    if split == 'train':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        
        cifar100 = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
        dataloader = DataLoader(cifar100, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
    
    elif split == 'test':
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

        cifar100 = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)
        dataloader = DataLoader(cifar100, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)

    return dataloader

#### Training

In [17]:
lr = 0.001
batch_size = 8
num_workers = 2
shuffle = True
patch_size = 4
max_len = ((32//patch_size) * (32//patch_size)) + 1 # +1 for the class token
embed_dim = 128
classes = 100
layers = 6
heads = 8
epochs = 100

In [24]:
dataloader = CIFAR100DataLoader('train', batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
model = VisionTransformer(patch_size=patch_size, max_len=max_len, embed_dim=embed_dim, classes=classes, layers=layers, heads=heads).to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.AdamW(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(dataloader), epochs=epochs)

for epoch in range(epochs):

    running_loss = 0.0
    running_accuracy = 0.0

    for data, target in tqdm(dataloader):
        data = data.to(device)
        target = target.to(device)

        output, _ = model(data)
        loss = criterion(output, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        acc = (output.argmax(dim=1) == target).float().mean()
        running_accuracy += acc / len(dataloader)
        running_loss += loss.item() / len(dataloader)
    
    print(f"Epoch : {epoch+1} - loss : {running_loss:.4f} - acc: {running_accuracy:.4f}\n")

Files already downloaded and verified


100%|██████████| 6250/6250 [03:13<00:00, 32.37it/s]
  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch : 1 - loss : 4.5867 - acc: 0.0338



100%|██████████| 6250/6250 [03:09<00:00, 33.02it/s]
  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch : 2 - loss : 4.5719 - acc: 0.0501



100%|██████████| 6250/6250 [03:10<00:00, 32.85it/s]
  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch : 3 - loss : 4.5658 - acc: 0.0570



100%|██████████| 6250/6250 [03:10<00:00, 32.78it/s]
  0%|          | 0/6250 [00:00<?, ?it/s]

Epoch : 4 - loss : 4.5614 - acc: 0.0607



  6%|▌         | 361/6250 [00:10<02:58, 33.06it/s]

KeyboardInterrupt: ignored