In [None]:
pip install vit-pytorch linformer

Collecting vit-pytorch
  Downloading vit_pytorch-0.27.0-py3-none-any.whl (52 kB)
[?25l[K     |██████▎                         | 10 kB 14.4 MB/s eta 0:00:01[K     |████████████▋                   | 20 kB 11.1 MB/s eta 0:00:01[K     |██████████████████▉             | 30 kB 8.6 MB/s eta 0:00:01[K     |█████████████████████████▏      | 40 kB 7.8 MB/s eta 0:00:01[K     |███████████████████████████████▍| 51 kB 5.2 MB/s eta 0:00:01[K     |████████████████████████████████| 52 kB 690 kB/s 
[?25hCollecting linformer
  Downloading linformer-0.2.1-py3-none-any.whl (6.1 kB)
Collecting einops>=0.3
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops, vit-pytorch, linformer
Successfully installed einops-0.4.1 linformer-0.2.1 vit-pytorch-0.27.0


In [None]:
pip install wandb

Collecting wandb
  Downloading wandb-0.12.11-py2.py3-none-any.whl (1.7 MB)
[K     |████████████████████████████████| 1.7 MB 5.3 MB/s 
[?25hCollecting GitPython>=1.0.0
  Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)
[K     |████████████████████████████████| 181 kB 44.4 MB/s 
Collecting yaspin>=1.0.0
  Downloading yaspin-2.1.0-py3-none-any.whl (18 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.5.7-py2.py3-none-any.whl (144 kB)
[K     |████████████████████████████████| 144 kB 48.7 MB/s 
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle
  Downloading setproctitle-1.2.2-cp37-cp37m-manylinux1_x86_64.whl (36 kB)
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.8-py3-none-any.whl (9.5 kB)
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-4.0.9-py3-none-any.whl (63 kB)
[K     |████████████████████████████████| 63 kB 1.6 MB/s 
[

In [None]:
from __future__ import print_function

import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# import Linformer

from linformer import Linformer
import glob
from PIL import Image
from itertools import chain
from vit_pytorch.efficient import ViT
from tqdm.notebook import tqdm
# import torch and related libraries

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
import wandb

from torchvision.datasets import CIFAR10



# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) so the
# train/validation split, weight initialisation and augmentation are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# The cuDNN autotuner flag is `benchmark` (singular); the previous
# `benchmarks` spelling is not a flag PyTorch reads, so it had no effect.
# NOTE(review): PyTorch's reproducibility notes recommend `benchmark = False`
# when bit-exact reruns matter more than speed — confirm which is wanted.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

# Never commit API keys to a notebook. The key is read from the
# WANDB_API_KEY environment variable; when it is unset, `key=None` lets wandb
# fall back to its own credential lookup / interactive prompt.
# The key that used to be hardcoded here must be treated as leaked and revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

########################################################################################
################################ HYPERPARAMETERS #######################################
########################################################################################

# Every tunable value of the run lives in this one mapping so it is recorded
# with the W&B run.
config = {
    "IMAGE_SIZE": 224,        # side length images are resized to / ViT image_size
    "BATCH_SIZE": 128,        # batch size of both DataLoaders
    "EPOCHS": 100,            # number of passes of the training loop
    "VAL_SPLIT": 0.2,         # fraction of samples held out for validation
    "LEARNING_RATE": 3e-5,    # Adam learning rate
    "GAMMA": 0.7,             # StepLR decay factor
    "STEP_SIZE": 1,           # StepLR period
    "saved_path": './models/',  # not referenced by the code visible in this notebook
}

# Start the W&B run with the hyperparameters attached.
wandb.init(project='Vision-Transformer-CIFAR10', config=config)

# From here on `config` is the W&B config object, which is why later code
# reads values with attribute access (e.g. `config.IMAGE_SIZE`).
config = wandb.config



# Per-sample preprocessing: upscale to the model's input resolution, take a
# random resized crop, randomly mirror, then convert to a tensor.
augmentation_steps = [
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.RandomResizedCrop(config.IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]
data_transform = transforms.Compose(augmentation_steps)

# CIFAR-10 training split (`train=True` is torchvision's default), downloaded
# into ./data on first use.
dataset = CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=data_transform,
)


def train_log(epoch,train_loss,train_acc,test_acc,test_loss):
    """Send one epoch's metrics to Weights & Biases.

    Note the positional order: ``test_acc`` comes *before* ``test_loss``.

    Args:
        epoch: Epoch index logged under ``"epoch"``.
        train_loss: Value logged under ``"train_loss"``.
        train_acc: Value logged under ``"train_acc"``.
        test_acc: Value logged under ``"test_acc"``.
        test_loss: Value logged under ``"test_loss"``.
    """
    metrics = {
        "epoch": epoch,
        "train_loss": train_loss,
        "train_acc": train_acc,
        "test_loss": test_loss,
        "test_acc": test_acc,
    }
    wandb.log(metrics)


def Split_index(n, val_percent, seed=None):
    """Randomly partition ``range(n)`` into training and validation indices.

    Args:
        n: Total number of samples.
        val_percent: Fraction of samples reserved for validation; must lie in
            ``[0, 1]``. The validation set holds ``int(val_percent * n)``
            indices.
        seed: Optional seed. When given, a dedicated NumPy generator is used,
            so the split is reproducible and the global NumPy RNG state is
            left untouched. When ``None`` (default) the global NumPy RNG is
            used, exactly as before.

    Returns:
        ``(train_indices, val_indices)`` — two disjoint NumPy integer arrays
        that together contain every index in ``range(n)`` exactly once.

    Raises:
        ValueError: If ``val_percent`` is outside ``[0, 1]``.
    """
    if not 0.0 <= val_percent <= 1.0:
        raise ValueError(f"val_percent must be in [0, 1], got {val_percent!r}")

    n_val = int(val_percent * n)
    if seed is None:
        index = np.random.permutation(n)
    else:
        index = np.random.default_rng(seed).permutation(n)

    # The first n_val shuffled indices form the validation set, the rest train.
    return index[n_val:], index[:n_val]



# Randomly partition the dataset's indices into training / validation subsets.
train_indices, val_indices = Split_index(len(dataset), val_percent=config.VAL_SPLIT)
print(len(train_indices),len(val_indices))
print(train_indices,val_indices)

# Validation must be measured on un-augmented images. `dataset` applies random
# crops and flips, so the validation loader gets its own view of the same
# underlying split with a deterministic resize-only pipeline.
eval_transform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor(),
])
val_dataset = CIFAR10(
    root=dataset.root,
    train=dataset.train,
    download=False,  # already fetched when `dataset` was created
    transform=eval_transform,
)

# Each sampler draws (in random order) only from its own index subset, so the
# two loaders never share a sample.
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# Do not request more worker processes than the machine has CPUs; asking for
# more makes DataLoader warn about an excessive worker count.
NUM_WORKERS = min(4, os.cpu_count() or 1)

train_dl = DataLoader(
    dataset,
    batch_size=config.BATCH_SIZE,
    sampler=train_sampler,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
val_dl = DataLoader(
    val_dataset,
    batch_size=config.BATCH_SIZE,
    sampler=val_sampler,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)


# Architecture constants shared by the Linformer encoder and the ViT wrapper.
PATCH_SIZE = 32
EMBED_DIM = 128

# The encoder's sequence length must equal the number of image patches plus
# one class token. Deriving it from the configured image size keeps the two
# in sync (224 / 32 = 7 -> 7 * 7 = 49 patches) instead of hardcoding 49.
if config.IMAGE_SIZE % PATCH_SIZE != 0:
    raise ValueError(
        f"IMAGE_SIZE ({config.IMAGE_SIZE}) must be divisible by PATCH_SIZE ({PATCH_SIZE})"
    )
num_patches = (config.IMAGE_SIZE // PATCH_SIZE) ** 2

efficient_transformer = Linformer(
    dim=EMBED_DIM,
    seq_len=num_patches + 1,  # patches + 1 cls-token
    depth=12,
    heads=8,
    k=64
)

model = ViT(
    dim=EMBED_DIM,
    image_size=config.IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=10,  # CIFAR-10 has ten classes
    transformer=efficient_transformer,
    channels=3,
).cuda()


# Loss, optimizer and learning-rate schedule, all driven by the run config.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=config.LEARNING_RATE)
scheduler = StepLR(
    optimizer=optimizer,
    step_size=config.STEP_SIZE,
    gamma=config.GAMMA,
)


def _run_epoch(model, batches, criterion, optimizer=None):
    """Run one full pass over ``batches`` and return its mean loss and accuracy.

    Args:
        model: Network to train or evaluate; inputs are moved to the device its
            parameters live on.
        batches: Iterable yielding ``(data, label)`` tensor pairs.
        criterion: Loss function called as ``criterion(output, label)``.
        optimizer: When given, the pass trains the model (train mode, gradients
            on, one optimizer step per batch). When ``None`` the pass only
            evaluates (eval mode, gradients off).

    Returns:
        ``(mean_loss, mean_accuracy)`` as plain Python floats, averaged over
        samples so a smaller final batch is weighted correctly.

    Raises:
        ValueError: If ``batches`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    device = next(model.parameters()).device

    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for data, label in batches:
            data = data.to(device)
            label = label.to(device)
            output = model(data)
            loss = criterion(output, label)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # `.item()` turns the metrics into Python numbers. Accumulating the
            # tensors themselves would keep autograd history and GPU memory
            # alive for the whole run.
            batch_size = label.size(0)
            loss_sum += loss.item() * batch_size
            correct += (output.argmax(dim=1) == label).sum().item()
            seen += batch_size

    if seen == 0:
        raise ValueError("cannot compute epoch metrics: the loader produced no samples")
    return loss_sum / seen, correct / seen


# NOTE(review): `scheduler` is constructed before this loop but never stepped,
# so the learning rate stays at LEARNING_RATE for the whole run. It is left
# disabled on purpose: with GAMMA=0.7 and STEP_SIZE=1, stepping every epoch for
# 100 epochs would drive the learning rate to effectively zero. Decide on the
# schedule before adding `scheduler.step()` at the end of each epoch.
History = []
for epoch in range(config.EPOCHS):
    epoch_loss, epoch_accuracy = _run_epoch(model, tqdm(train_dl), criterion, optimizer)
    epoch_val_loss, epoch_val_accuracy = _run_epoch(model, val_dl, criterion)
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
    train_log(epoch,epoch_loss,epoch_accuracy,epoch_val_accuracy,epoch_val_loss)
    History.append({
        'train_loss':epoch_loss,
        'train_acc':epoch_accuracy,
        'val_loss':epoch_val_loss,
        'val_acc':epoch_val_accuracy
    })


# Flatten the per-epoch records into one series per metric for plotting.
# Iterating over History (not range(config.EPOCHS)) also works when training
# was stopped early.
train_acc = [record['train_acc'] for record in History]
train_loss = [record['train_loss'] for record in History]
val_acc = [record['val_acc'] for record in History]
val_loss = [record['val_loss'] for record in History]


def _plot_history(train_values, val_values, legend_tag, ylabel, title, filename):
    """Plot a training and a validation series per epoch, save and show the figure.

    Args:
        train_values: Per-epoch training values.
        val_values: Per-epoch validation values.
        legend_tag: Suffix for the legend entries ("Train <tag>" / "Val <tag>").
        ylabel: Label of the y-axis.
        title: Figure title.
        filename: Path the figure is written to before it is shown.
    """
    fig, ax = plt.subplots()
    # The x-axis follows the data length, so a shorter (interrupted) history
    # still plots instead of failing on a length mismatch.
    ax.plot(range(len(train_values)), train_values, label=f'Train {legend_tag}')
    ax.plot(range(len(val_values)), val_values, label=f'Val {legend_tag}')
    ax.legend()
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    fig.savefig(filename)
    plt.show()


_plot_history(train_acc, val_acc, 'Acc', 'Accuracy',
              "Training and Validation Accuracy", "Training.png")

_plot_history(train_loss, val_loss, 'Loss', 'Loss',
              "Training and Validation Loss", "Loss.png")


[34m[1mwandb[0m: W&B API key is configured (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mshashi7679[0m (use `wandb login --relogin` to force relogin)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
40000 10000
[12967 37593 34521 ...  2799 19744 45141] [34934 30846 28387 ... 46323 18849 43781]


  cpuset_checked))


  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 1 - loss : 2.1714 - acc: 0.1798 - val_loss : 2.0442 - val_acc: 0.2312



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 2 - loss : 2.0024 - acc: 0.2578 - val_loss : 1.9563 - val_acc: 0.2781



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 3 - loss : 1.9100 - acc: 0.3003 - val_loss : 1.8784 - val_acc: 0.3140



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 4 - loss : 1.8625 - acc: 0.3213 - val_loss : 1.8448 - val_acc: 0.3237



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 5 - loss : 1.8273 - acc: 0.3358 - val_loss : 1.8092 - val_acc: 0.3482



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 6 - loss : 1.7937 - acc: 0.3477 - val_loss : 1.7751 - val_acc: 0.3603



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 7 - loss : 1.7583 - acc: 0.3607 - val_loss : 1.7699 - val_acc: 0.3578



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 8 - loss : 1.7279 - acc: 0.3730 - val_loss : 1.7153 - val_acc: 0.3809



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 9 - loss : 1.7056 - acc: 0.3802 - val_loss : 1.7065 - val_acc: 0.3790



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 10 - loss : 1.6863 - acc: 0.3864 - val_loss : 1.6640 - val_acc: 0.3953



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 11 - loss : 1.6625 - acc: 0.3935 - val_loss : 1.6614 - val_acc: 0.3958



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 12 - loss : 1.6427 - acc: 0.4028 - val_loss : 1.6369 - val_acc: 0.4117



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 13 - loss : 1.6341 - acc: 0.4075 - val_loss : 1.6230 - val_acc: 0.4121



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 14 - loss : 1.6189 - acc: 0.4143 - val_loss : 1.6352 - val_acc: 0.4069



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 15 - loss : 1.6086 - acc: 0.4187 - val_loss : 1.6173 - val_acc: 0.4239



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 16 - loss : 1.5987 - acc: 0.4233 - val_loss : 1.5877 - val_acc: 0.4196



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 17 - loss : 1.5917 - acc: 0.4244 - val_loss : 1.5855 - val_acc: 0.4317



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 18 - loss : 1.5794 - acc: 0.4324 - val_loss : 1.6052 - val_acc: 0.4257



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 19 - loss : 1.5706 - acc: 0.4349 - val_loss : 1.5801 - val_acc: 0.4334



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 20 - loss : 1.5566 - acc: 0.4414 - val_loss : 1.5747 - val_acc: 0.4329



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 21 - loss : 1.5526 - acc: 0.4395 - val_loss : 1.5572 - val_acc: 0.4462



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 22 - loss : 1.5382 - acc: 0.4443 - val_loss : 1.5421 - val_acc: 0.4444



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 23 - loss : 1.5293 - acc: 0.4498 - val_loss : 1.5517 - val_acc: 0.4435



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 24 - loss : 1.5235 - acc: 0.4500 - val_loss : 1.5408 - val_acc: 0.4500



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 25 - loss : 1.5232 - acc: 0.4554 - val_loss : 1.5164 - val_acc: 0.4588



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 26 - loss : 1.5096 - acc: 0.4584 - val_loss : 1.5146 - val_acc: 0.4570



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 27 - loss : 1.5032 - acc: 0.4613 - val_loss : 1.5162 - val_acc: 0.4553



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 28 - loss : 1.4927 - acc: 0.4659 - val_loss : 1.5022 - val_acc: 0.4661



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 29 - loss : 1.4851 - acc: 0.4690 - val_loss : 1.4968 - val_acc: 0.4603



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 30 - loss : 1.4748 - acc: 0.4702 - val_loss : 1.4911 - val_acc: 0.4632



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 31 - loss : 1.4687 - acc: 0.4763 - val_loss : 1.4793 - val_acc: 0.4672



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 32 - loss : 1.4585 - acc: 0.4771 - val_loss : 1.4849 - val_acc: 0.4644



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 33 - loss : 1.4536 - acc: 0.4788 - val_loss : 1.4468 - val_acc: 0.4787



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 34 - loss : 1.4494 - acc: 0.4836 - val_loss : 1.4587 - val_acc: 0.4747



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 35 - loss : 1.4400 - acc: 0.4862 - val_loss : 1.4565 - val_acc: 0.4780



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 36 - loss : 1.4274 - acc: 0.4935 - val_loss : 1.4587 - val_acc: 0.4733



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 37 - loss : 1.4277 - acc: 0.4935 - val_loss : 1.4480 - val_acc: 0.4839



  0%|          | 0/313 [00:00<?, ?it/s]

Epoch : 38 - loss : 1.4214 - acc: 0.4924 - val_loss : 1.4611 - val_acc: 0.4778



  0%|          | 0/313 [00:00<?, ?it/s]