<a href="https://colab.research.google.com/github/tusharvatsa32/VisTransformers/blob/main/Code/vit_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Visual Transformer with Linformer

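Training Visual Transformer on *Dogs vs Cats Data* (note: the data-loading cell in this notebook actually loads **CIFAR-10** through `torchvision.datasets.CIFAR10`).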

* Dogs vs. Cats Redux: Kernels Edition - https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition
* Base Code - https://www.kaggle.com/reukki/pytorch-cnn-tutorial-with-cats-and-dogs/
* Efficient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention

## Import Libraries

In [None]:
!pip install -q einops

In [None]:
import glob
import os
import random
import zipfile
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, models, transforms
from __future__ import print_function
from itertools import chain
from torch import einsum
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [None]:
# Show the installed PyTorch version.
print(f"Torch: {torch.__version__}")

Torch: 1.8.1+cu101


In [None]:
# Training settings
# NOTE(review): apart from `seed`, none of these names is referenced by the
# other cells visible in this notebook, which hard-code their own batch sizes,
# epoch count and learning rate -- confirm which values are meant to apply.
batch_size, epochs = 64, 3
lr, gamma = 3e-5, 0.7  # gamma: presumably an LR-scheduler decay factor -- TODO confirm
seed = 42

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices) and puts cuDNN into a reproducible configuration.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Exported for child processes; the hash seed of the already-running
    # interpreter is fixed at start-up and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every CUDA device (including the current one); silently ignored
    # when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Without this, cuDNN may benchmark and pick different convolution
    # algorithms from run to run even when `deterministic` is set.
    torch.backends.cudnn.benchmark = False

# Apply the global `seed` chosen in the training-settings cell.
seed_everything(seed)

In [None]:
# Prefer the GPU, but fall back to the CPU so that `.to(device)` calls do not
# crash on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
import PIL  # no longer used by this cell; kept so the `PIL` name stays defined for other cells

# Input resolution fed to the network; matches the `image_size=384` the ViT is
# constructed with (32x32 patches -> 12 x 12 = 144 patches per image).
img_size = (384, 384)
# img_size = (256, 256) # For ViT predefined weights

# Per-channel normalization statistics shared by the train and validation
# pipelines (the widely used ImageNet mean / std values).
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)
# Alternative statistics kept from the original notebook:
# NORM_MEAN, NORM_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)

# Training pipeline: resize, light photometric / geometric augmentation,
# conversion to a tensor, then normalization.
transforms_train = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ColorJitter(hue=.05, saturation=.05),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomVerticalFlip(p=0.3),
    # `interpolation=` replaces the deprecated `resample=` keyword.
    transforms.RandomRotation(10, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    # The image is already `img_size`, so this crop keeps the whole image.
    transforms.RandomCrop(img_size, fill=0),
    # The deprecated `fillcolor=None` / `resample=None` keywords were just the
    # defaults; dropping them keeps the behavior and works on newer torchvision.
    transforms.RandomAffine(10, scale=(0.8, 1.2), fill=0),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Validation pipeline: deterministic resize + normalization only.
transforms_val = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

  "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead"


## Load Data

In [None]:
# CIFAR-10 training split with the augmenting transform pipeline.
train_data = torchvision.datasets.CIFAR10(
    train=True, download=True, root="./cifar10/train_data", transform=transforms_train
)
# The loader must wrap the dataset created above (`train_data`).
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=512, shuffle=True, num_workers=8
)

# CIFAR-10 test split, used for validation with the deterministic pipeline.
valid_data = torchvision.datasets.CIFAR10(
    train=False, download=True, root="./cifar10/test_data", transform=transforms_val
)
valid_loader = torch.utils.data.DataLoader(
    valid_data, batch_size=256, shuffle=False, num_workers=8
)

# Short aliases for code that refers to the datasets / loaders by these names,
# so that it never picks up stale objects left over in the kernel.
trainset, testset = train_data, valid_data
trainloader, testloader = train_loader, valid_loader

Files already downloaded and verified


  cpuset_checked))


Files already downloaded and verified


In [None]:
# Number of training samples, and number of batches the training loader yields per epoch.
print(len(train_data), len(train_loader))

50000 98


In [None]:
# Number of validation samples, and number of batches the validation loader yields.
print(len(valid_data), len(valid_loader))

10000 40


### Visual Transformer

In [None]:
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class PreNorm(nn.Module):
    """Apply LayerNorm to the input and hand the result to a wrapped callable."""

    def __init__(self, dim, fn):
        """
        Args:
            dim: Size of the trailing dimension normalized by ``nn.LayerNorm``.
            fn: Callable (typically a module) invoked on the normalized tensor.
        """
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalize ``x`` and forward it, with any keyword arguments, to ``fn``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        """
        Args:
            dim: Input and output feature size.
            hidden_dim: Feature size between the two linear layers.
            dropout: Dropout probability applied after the activation and
                after the second linear layer.
        """
        super().__init__()
        # Built in the same order as they are applied; the positions inside
        # the Sequential determine the parameter names (net.0.*, net.3.*).
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP."""
        return self.net(x)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        """
        Args:
            dim: Feature size of the input (and output) tokens.
            heads: Number of attention heads.
            dim_head: Feature size of each head.
            dropout: Dropout probability applied after the output projection.
        """
        super().__init__()
        inner_dim = dim_head * heads
        # A single head whose width already equals `dim` needs no output projection.
        needs_projection = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        # One bias-free linear layer produces queries, keys and values together.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Attend over the token axis of ``x`` (unpacked below as batch, tokens, features)."""
        # The three-way unpack also rejects inputs that are not 3-D.
        batch, tokens, _ = x.shape
        h = self.heads

        # Split the joint projection into q, k, v and move the head axis forward.
        q, k, v = [
            rearrange(part, 'b n (h d) -> b h n d', h = h)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        ]

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = self.attend(scores)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of pre-norm attention + feed-forward blocks with residual connections."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        """
        Args:
            dim: Token feature size.
            depth: Number of (attention, feed-forward) blocks.
            heads: Attention heads per block.
            dim_head: Feature size of each attention head.
            mlp_dim: Hidden size of each feed-forward sub-layer.
            dropout: Dropout probability passed to both sub-layers.
        """
        super().__init__()

        def build_block():
            # Attention is created before the feed-forward layer, as a pair.
            attention = PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            return nn.ModuleList([attention, feed_forward])

        self.layers = nn.ModuleList([build_block() for _ in range(depth)])

    def forward(self, x):
        """Apply every block in order; each sub-layer output is added to its input."""
        for block in self.layers:
            attention, feed_forward = block
            x = attention(x) + x
            x = feed_forward(x) + x
        return x

class ViT(nn.Module):
    """Vision Transformer: patch embedding, transformer encoder and a linear classification head."""

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        """
        Args:
            image_size: Side length of the (square) input images.
            patch_size: Side length of each square patch; must divide ``image_size``.
            num_classes: Number of output logits.
            dim: Token embedding size.
            depth: Number of transformer blocks.
            heads: Attention heads per block.
            mlp_dim: Hidden size of the feed-forward sub-layers.
            pool: ``'cls'`` to classify from the class token, ``'mean'`` to
                average all tokens.
            channels: Number of image channels.
            dim_head: Feature size of each attention head.
            dropout: Dropout probability inside the transformer.
            emb_dropout: Dropout probability applied to the embedded tokens.
        """
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        patches_per_side = image_size // patch_size
        num_patches = patches_per_side ** 2
        patch_dim = channels * patch_size ** 2
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Cut the image into flattened patches, then project each one to `dim`.
        split_into_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)
        self.to_patch_embedding = nn.Sequential(
            split_into_patches,
            nn.Linear(patch_dim, dim),
        )

        # One learned position vector per patch, plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Return the class logits produced by the head for the images in ``img``."""
        tokens = self.to_patch_embedding(img)
        batch, n_patches, _ = tokens.shape

        # Prepend one copy of the class token to every sequence in the batch.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = batch)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens += self.pos_embedding[:, :(n_patches + 1)]
        tokens = self.dropout(tokens)

        tokens = self.transformer(tokens)

        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        return self.mlp_head(self.to_latent(pooled))

In [None]:
# Build the ViT classifier, then move it to the selected device.
model = ViT(
    # Input geometry: 384 / 32 = 12 patches per side -> 144 patches per image.
    image_size=384, patch_size=32,
    # Ten output logits.
    num_classes=10,
    # Encoder size.
    dim=1024, depth=6, heads=8, mlp_dim=2048,
    # Regularization.
    dropout=0.1, emb_dropout=0.1,
)
model = model.to(device)

### Training

In [None]:
numEpochs = 100
in_features = 3 # RGB channels

learningRate = 0.03
weightDecay = 5e-5

# Read the class count from the dataset object created in the data-loading
# cell (`train_data`), so this cell does not depend on a name that only
# exists as leftover kernel state.
num_classes = len(train_data.classes)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learningRate, weight_decay=weightDecay, momentum=0.9, nesterov=True)

# mode='max' because the scheduler is stepped with validation accuracy:
# the learning rate is multiplied by 0.3 once accuracy stops improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.3, patience=3, threshold=0.002, verbose=True)

In [None]:
# Per-epoch history filled in by the training loop:
# validation accuracy and average training loss.
my_acc, my_loss = [], []

In [None]:
# Train!
# Uses the loaders / datasets defined in the data-loading cell
# (`train_loader`, `valid_loader`, `train_data`, `valid_data`).
numEpochs = 100
for epoch in range(numEpochs):

    # Train
    model.train()
    train_loss = 0.0
    correct = 0

    for batch_num, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()

        x, y = x.to(device), y.to(device)

        outputs = model(x)

        correct += (torch.argmax(outputs, dim=1) == y).sum().item()

        loss = criterion(outputs, y.long())
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        if batch_num % 100 == 0:
            print('Epoch: {}\tBatch: {}\tAvg-Loss: {:.4f}'.format(epoch, batch_num+1, train_loss/(batch_num+1)))

    train_accuracy = correct / len(train_data)
    # Mean loss over all batches of this epoch.
    avg_train_loss = train_loss / len(train_loader)

    # Validate
    model.eval()
    num_correct = 0
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time.
    with torch.no_grad():
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)

            num_correct += (torch.argmax(outputs, dim=1) == y).sum().item()

    val_accuracy = num_correct / len(valid_data)
    my_acc.append(val_accuracy)
    my_loss.append(avg_train_loss)
    print('Epoch: {}\t Training Accuracy: {:.4f}\t Validation Accuracy: {:.4f}\t Avg-Loss: {:.4f}'.format(epoch, train_accuracy*100, val_accuracy * 100, avg_train_loss))
    # The scheduler was created with mode='max', so it is driven by accuracy.
    scheduler.step(val_accuracy)

    #torch.save(model.state_dict(),'/content/drive/MyDrive/DL_CMU/HW2_P2/ResNet_Plateau_d3/Net_'+str(epoch)+'_'+str(val_accuracy)+'_checkpoint.t7')

  cpuset_checked))


Epoch: 0	Batch: 1	Avg-Loss: 2.4815
Epoch: 0	Batch: 101	Avg-Loss: 11.6012
Epoch: 0	Batch: 201	Avg-Loss: 7.0939
Epoch: 0	Batch: 301	Avg-Loss: 5.5329
Epoch: 0	Batch: 401	Avg-Loss: 4.7611
Epoch: 0	Batch: 501	Avg-Loss: 4.2867
Epoch: 0	Batch: 601	Avg-Loss: 3.9646
Epoch: 0	Batch: 701	Avg-Loss: 3.7337
Epoch: 0	Batch: 801	Avg-Loss: 3.5597
Epoch: 0	Batch: 901	Avg-Loss: 3.4184
Epoch: 0	Batch: 1001	Avg-Loss: 3.2955
Epoch: 0	Batch: 1101	Avg-Loss: 3.1959
Epoch: 0	Batch: 1201	Avg-Loss: 3.1108
Epoch: 0	Batch: 1301	Avg-Loss: 3.0399
Epoch: 0	Batch: 1401	Avg-Loss: 2.9780
Epoch: 0	Batch: 1501	Avg-Loss: 2.9220
Epoch: 0	Batch: 1601	Avg-Loss: 2.8748
Epoch: 0	Batch: 1701	Avg-Loss: 2.8317
Epoch: 0	Batch: 1801	Avg-Loss: 2.7937
Epoch: 0	Batch: 1901	Avg-Loss: 2.7593
Epoch: 0	Batch: 2001	Avg-Loss: 2.7266
Epoch: 0	Batch: 2101	Avg-Loss: 2.7012
Epoch: 0	Batch: 2201	Avg-Loss: 2.6758
Epoch: 0	Batch: 2301	Avg-Loss: 2.6541
Epoch: 0	Batch: 2401	Avg-Loss: 2.6335
Epoch: 0	Batch: 2501	Avg-Loss: 2.6139
Epoch: 0	Batch: 2601	Av