<a href="https://colab.research.google.com/github/tusharvatsa32/VisTransformers/blob/main/Code/ViT_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import Libraries

In [1]:
!nvidia-smi

Thu Apr 29 20:08:00 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   41C    P0    27W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install -q einops

In [3]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch import einsum
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [4]:
# Record the PyTorch build in use so the run can be reproduced later.
print(f"Torch: {torch.__version__}")

Torch: 1.8.1+cu101


In [5]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick different convolution algorithms from run
    # to run; switch it off so the deterministic flag above is not undermined.
    torch.backends.cudnn.benchmark = False


seed_everything(11785)

In [6]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
# Target size every image is resized to before it reaches the model.
img_size = (64, 64)

# Per-channel mean / std handed to Normalize in both pipelines.
_norm_mean = (0.485, 0.456, 0.406)
_norm_std = (0.229, 0.224, 0.225)

# Training pipeline: resize, random colour and geometric augmentation,
# then tensor conversion and normalisation.
_train_steps = [
    transforms.Resize(img_size),
    transforms.ColorJitter(hue=.05, saturation=.05),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(10, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.RandomCrop(img_size, fill=0),
    transforms.RandomAffine(10, scale=(0.8, 1.2)),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]
transforms_train = transforms.Compose(_train_steps)

# Validation pipeline: deterministic, no augmentation.
_val_steps = [
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
]
transforms_val = transforms.Compose(_val_steps)

## Load Data

In [9]:
# Cap the worker count at the number of CPUs the runtime actually has:
# requesting more than that triggers a DataLoader warning and oversubscribes
# the machine without speeding up loading.
num_workers = min(8, os.cpu_count() or 1)

# CIFAR-10 training split, augmented with `transforms_train` and shuffled
# every epoch.
train_data = torchvision.datasets.CIFAR10(train=True,download=True,root= "./cifar10/train_data", transform=transforms_train)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=128,
                                          shuffle=True, num_workers=num_workers)

# CIFAR-10 test split, used here for validation; kept in a fixed order.
valid_data = torchvision.datasets.CIFAR10(train=False,download=True,root= "./cifar10/test_data", transform=transforms_val)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=32,
                                         shuffle=False, num_workers=num_workers)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar10/train_data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ./cifar10/train_data/cifar-10-python.tar.gz to ./cifar10/train_data


  cpuset_checked))


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar10/test_data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ./cifar10/test_data/cifar-10-python.tar.gz to ./cifar10/test_data


In [10]:
# Sanity check: number of training samples and number of batches per epoch.
print(len(train_data), len(train_loader))

50000 391


In [11]:
# Sanity check: number of validation samples and number of validation batches.
print(len(valid_data), len(valid_loader))

10000 313


### Vision Transformer (ViT)

In [12]:
class PreNorm(nn.Module):
    """Apply layer normalisation to the input before calling a wrapped callable."""

    def __init__(self, dim, fn):
        """
        Args:
            dim: Size of the last input dimension, handed to ``nn.LayerNorm``.
            fn: Callable (typically an ``nn.Module``) invoked on the
                normalised input.
        """
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and pass it, with any keyword arguments, to ``fn``."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        """
        Args:
            dim: Width of the input and of the output.
            hidden_dim: Width of the hidden layer.
            dropout: Dropout probability used after the activation and after
                the second linear layer.
        """
        super(FeedForward, self).__init__()
        # Created in the same order as they appear in the stack so seeded
        # weight initialisation is unchanged.
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)
        stages = [expand, nn.GELU(), nn.Dropout(dropout), contract, nn.Dropout(dropout)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension keeps size ``dim``."""
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence."""

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        """
        Args:
            dim: Width of each input (and output) token.
            heads: Number of attention heads.
            dim_head: Width of each head.
            dropout: Dropout probability applied after the output projection.
        """
        super(Attention, self).__init__()
        inner_dim = heads * dim_head
        # With a single head whose width equals ``dim`` the head output
        # already has the input width, so no output projection is built.
        single_matching_head = heads == 1 and dim_head == dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        # One bias-free projection yields queries, keys and values at once.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if single_matching_head:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

    def forward(self, x):
        """Attend over the tokens of ``x`` and return a tensor of the same layout."""
        # Unpacking exactly three sizes also rejects inputs that are not 3-D.
        _batch, _tokens, _width = x.shape
        num_heads = self.heads

        def split_heads(t):
            # Move the head axis in front of the token axis.
            return rearrange(t, 'b n (h d) -> b h n d', h = num_heads)

        q, k, v = [split_heads(part) for part in self.to_qkv(x).chunk(3, dim = -1)]

        # Scaled pairwise query/key similarities, softmax-normalised over keys.
        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = self.attend(scores)

        # Weighted sum of the values, then merge the heads back together.
        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm layers, each an attention block followed by
    a feed-forward block, both wrapped in a residual connection."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        """
        Args:
            dim: Token width, preserved by every layer.
            depth: Number of (attention, feed-forward) layers.
            heads: Attention heads per layer.
            dim_head: Width of each attention head.
            mlp_dim: Hidden width of the feed-forward block.
            dropout: Dropout probability passed to both sub-blocks.
        """
        super(Transformer, self).__init__()

        def build_layer():
            # Attention is built before the feed-forward block so seeded
            # weight initialisation follows the same order in every layer.
            attention = PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            return nn.ModuleList([attention, feed_forward])

        self.layers = nn.ModuleList([build_layer() for _ in range(depth)])

    def forward(self, x):
        """Apply every layer in order and return the transformed tokens."""
        for layer in self.layers:
            attention, feed_forward = layer
            x = attention(x) + x
            x = feed_forward(x) + x
        return x

class ViT(nn.Module):
    """Vision Transformer classifier for square images.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each patch is linearly embedded, a learnable class token is
    prepended, learnable position embeddings are added, and the sequence is
    processed by a ``Transformer``. The pooled representation is mapped to
    ``num_classes`` logits.

    Args (keyword-only):
        image_size: Side length of the input images; must be a multiple of
            ``patch_size``.
        patch_size: Side length of one patch.
        num_classes: Number of output logits.
        dim: Token embedding width.
        depth: Number of transformer layers.
        heads: Attention heads per layer.
        mlp_dim: Hidden width of the feed-forward blocks.
        pool: ``'cls'`` to read the class-token output, ``'mean'`` to average
            over all tokens (class token included).
        channels: Number of image channels.
        dim_head: Width of each attention head.
        dropout: Dropout probability inside the transformer.
        emb_dropout: Dropout probability applied to the embedded tokens.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super(ViT, self).__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        patches_per_side = image_size // patch_size
        num_patches = patches_per_side ** 2
        patch_dim = channels * patch_size ** 2
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # Flatten every patch into a vector and project it to ``dim``.
        to_patches = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)
        patch_projection = nn.Linear(patch_dim, dim)
        self.to_patch_embedding = nn.Sequential(to_patches, patch_projection)

        # One position embedding per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Return class logits for a batch of images."""
        tokens = self.to_patch_embedding(img)
        batch, num_tokens, _ = tokens.shape

        # Prepend a copy of the class token to every sequence in the batch.
        cls = repeat(self.cls_token, '() n d -> b n d', b = batch)
        tokens = torch.cat((cls, tokens), dim=1)
        tokens += self.pos_embedding[:, :(num_tokens + 1)]
        tokens = self.dropout(tokens)

        tokens = self.transformer(tokens)

        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        return self.mlp_head(self.to_latent(pooled))

In [15]:
# Hyper-parameters of the ViT instance trained below.
vit_config = dict(
    image_size=64,
    patch_size=8,
    num_classes=10,
    dim=512,
    depth=6,
    heads=8,
    mlp_dim=1024,
    dropout=0.2,
    emb_dropout=0.1,
)

# Build the network and move its parameters to the selected device.
model = ViT(**vit_config).to(device)

### Training

In [16]:
# Training schedule and dataset facts.
numEpochs = 100
in_features = 3 # RGB channels
num_classes = len(train_data.classes)

# Optimiser hyper-parameters.
learningRate = 0.03
weightDecay = 5e-5

criterion = nn.CrossEntropyLoss()

# SGD with Nesterov momentum and L2 weight decay over all model parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=learningRate,
    momentum=0.9,
    nesterov=True,
    weight_decay=weightDecay,
)

# mode='max': the monitored metric is expected to increase. When it fails to
# improve by more than the threshold for `patience` epochs, the learning rate
# is multiplied by `factor`.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.3,
    patience=3,
    threshold=0.002,
    verbose=True,
)

In [17]:
# Per-epoch history: validation accuracy and average training loss.
my_acc, my_loss = [], []

In [18]:
# Train!
# Each epoch makes one optimisation pass over `train_loader`, then measures
# accuracy on `valid_loader` and feeds that accuracy to the LR scheduler.
for epoch in range(numEpochs):

    # Train
    model.train()
    train_loss = 0.0
    correct = 0

    for batch_num, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()

        x, y = x.to(device), y.to(device)

        outputs = model(x)

        # Running training accuracy; measured in train mode while the weights
        # are still being updated, so it is only a rough indicator.
        correct += (torch.argmax(outputs, dim=1) == y).sum().item()

        loss = criterion(outputs, y.long())
        loss.backward()
        optimizer.step()

        del(outputs)

        train_loss += loss.item()

        if batch_num % 100 == 0:
            print('Epoch: {}\tBatch: {}\tAvg-Loss: {:.4f}'.format(epoch, batch_num+1, train_loss/(batch_num+1)))

    train_accuracy = correct / len(train_data)
    # Average over the number of batches explicitly instead of relying on the
    # loop variable that leaks out of the training loop above.
    avg_train_loss = train_loss / len(train_loader)

    # Validate
    model.eval()
    num_correct = 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every validation batch and saves memory and time.
    with torch.no_grad():
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)

            num_correct += (torch.argmax(outputs, dim=1) == y).sum().item()

    val_accuracy = num_correct / len(valid_data)
    my_acc.append(val_accuracy)
    my_loss.append(avg_train_loss)
    print('Epoch: {}\t Training Accuracy: {:.4f}\t Validation Accuracy: {:.4f}\t Avg-Loss: {:.4f}'.format(epoch, train_accuracy*100, val_accuracy * 100, avg_train_loss))
    # The scheduler was created with mode='max', so it is stepped on accuracy.
    scheduler.step(val_accuracy)

    #torch.save(network.state_dict(),'/content/drive/MyDrive/DL_CMU/HW2_P2/ResNet_Plateau_d3/Net_'+str(epoch)+'_'+str(val_accuracy)+'_checkpoint.t7')

  cpuset_checked))


Epoch: 0	Batch: 1	Avg-Loss: 2.4225
Epoch: 0	Batch: 101	Avg-Loss: 3.2220
Epoch: 0	Batch: 201	Avg-Loss: 2.6791
Epoch: 0	Batch: 301	Avg-Loss: 2.4711
Epoch: 0	 Training Accuracy: 21.6080	 Validation Accuracy: 30.4100	 Avg-Loss: 2.3525
Epoch: 1	Batch: 1	Avg-Loss: 1.8711
Epoch: 1	Batch: 101	Avg-Loss: 1.8781
Epoch: 1	Batch: 201	Avg-Loss: 1.8455
Epoch: 1	Batch: 301	Avg-Loss: 1.8156
Epoch: 1	 Training Accuracy: 34.1600	 Validation Accuracy: 41.2900	 Avg-Loss: 1.7890
Epoch: 2	Batch: 1	Avg-Loss: 1.7566
Epoch: 2	Batch: 101	Avg-Loss: 1.6584
Epoch: 2	Batch: 201	Avg-Loss: 1.6391
Epoch: 2	Batch: 301	Avg-Loss: 1.6276
Epoch: 2	 Training Accuracy: 40.8920	 Validation Accuracy: 46.0900	 Avg-Loss: 1.6157
Epoch: 3	Batch: 1	Avg-Loss: 1.6257
Epoch: 3	Batch: 101	Avg-Loss: 1.5394
Epoch: 3	Batch: 201	Avg-Loss: 1.5356
Epoch: 3	Batch: 301	Avg-Loss: 1.5287
Epoch: 3	 Training Accuracy: 44.5060	 Validation Accuracy: 48.1000	 Avg-Loss: 1.5235
Epoch: 4	Batch: 1	Avg-Loss: 1.6012
Epoch: 4	Batch: 101	Avg-Loss: 1.4790
Epoc