In [1]:
import torch
import tqdm
import einops
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

In [2]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

In [3]:
class LayerNormalize(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))

In [4]:
class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(0.1)

        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.init_parameters()
        
    def init_parameters(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc2.bias, std=1e-6)


    def forward(self, x):
        x = self.gelu(self.fc1(x))
        x = self.dropout(x)

        x = self.fc2(x)
        x = self.dropout(x)

        return x

In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5
        self.dropout = nn.Dropout(0.1)

        self.fc1 = nn.Linear(dim, dim*3)
        self.fc2 = nn.Linear(dim, dim)
        self.init_parameters()

    def init_parameters(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        qkv = self.fc1(x)
        q, k, v = einops.rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)
        energy = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        energy = energy.softmax(dim=-1)

        attn = torch.einsum("bhij,bhjd->bhid", energy, v)
        attn = einops.rearrange(attn, "b h n d -> b n (h d)")
        out = self.dropout(self.fc2(attn))

        return out

In [6]:
class Transformer(nn.Module):
    def __init__(self, depth, dim, hidden_dim, heads):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([Residual(LayerNormalize(dim, MultiHeadAttention(dim, heads))), 
                                              Residual(LayerNormalize(dim, MLP(dim, hidden_dim)))]))
            
    def forward(self, x):
        for attention, mlp in self.layers:
            x = attention(x)
            x = mlp(x)

        return x

In [7]:
class VisionTransformer(nn.Module):
    def __init__(self, in_channels, image_size, patch_size, num_classes, dim, depth, heads, hidden_dim):
        super().__init__()
        assert image_size % patch_size == 0, "Image Size must be divisible by Patch Size!!"

        num_patches = (image_size // patch_size) ** 2
        self.dropout = nn.Dropout(0.1)
        self.identity = nn.Identity()
        # self.gelu = nn.GELU()

        self.patch_embedding = nn.Conv2d(in_channels, dim, patch_size, stride=patch_size)
        self.cls_tokens = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.empty(1, (num_patches + 1), dim))
        
        self.transformer = Transformer(depth, dim, hidden_dim, heads)
        self.fc = nn.Linear(dim, num_classes) # self.fc = nn.Linear(dim, hidden_dim) -> self.fc_out

        # self.fc_out = nn.Linear(hidden_dim, num_classes) -> use only for large datasets
        self.init_parameters()

    def init_parameters(self):
        nn.init.normal_(self.pos_embedding, std=0.02)

        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.normal_(self.fc.bias, std=1e-6)

        # nn.init.xavier_uniform_(self.fc_out.weight)
        # nn.init.normal_(self.fc_out.bias)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = einops.rearrange(x, "b c h w -> b (h w) c")

        cls_tokens = self.cls_tokens.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x)
        x = self.identity(x[:, 0])

        x = self.fc(x)
        # x = self.dropout(self.gelu(x))
        # x = self.dropout(self.fc_out(x))

        return x

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 30
batch_size = 100
in_channels = 3
image_size = 32
patch_size = 4
num_classes = 10
dim = 64
depth = 6
heads = 8
hidden_dim = 128
lr = 3e-3

In [9]:
device

device(type='cuda')

In [10]:
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

In [11]:
train_data = datasets.CIFAR10("data/", train=True, download=True, transform=transform)
test_data = datasets.CIFAR10("data/", train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
train_batches = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_batches = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [13]:
len(train_batches), len(test_batches)

(500, 100)

In [14]:
net = VisionTransformer(in_channels, image_size, patch_size, num_classes, dim, depth, heads, hidden_dim).to(device)
net

VisionTransformer(
  (dropout): Dropout(p=0.1, inplace=False)
  (identity): Identity()
  (patch_embedding): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
  (transformer): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): Residual(
          (fn): LayerNormalize(
            (fn): MultiHeadAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (fc1): Linear(in_features=64, out_features=192, bias=True)
              (fc2): Linear(in_features=64, out_features=64, bias=True)
            )
            (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          )
        )
        (1): Residual(
          (fn): LayerNormalize(
            (fn): MLP(
              (gelu): GELU()
              (dropout): Dropout(p=0.1, inplace=False)
              (fc1): Linear(in_features=64, out_features=128, bias=True)
              (fc2): Linear(in_features=128, out_features=64, bias=True)
            )
            (norm): LayerNorm((64,), eps

In [15]:
opt = torch.optim.Adam(net.parameters(), lr)
loss_fn = nn.CrossEntropyLoss(reduction="sum")

In [16]:
input = torch.randn(100, 3, 32, 32).to(device)
output = net(input)
output.shape

torch.Size([100, 10])

In [17]:
def get_accuracy(preds, y):
    preds = preds.argmax(dim=1, keepdim=True)
    correct = preds.squeeze(1).eq(y)
    acc = correct.sum() / torch.FloatTensor([y.shape[0]]).to(device)

    return acc.item()

In [18]:
def loop(net, batches, train):
    batch_losses = []
    batch_accs = []

    if train:
        print("Train Loop:")
        print("")
        net.train()

        for X, y in tqdm.tqdm(batches, total=len(batches)):
            X = X.to(device)
            y = y.to(device)

            preds = net(X)
            loss = loss_fn(preds, y)
            acc = get_accuracy(preds, y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            batch_losses.append(loss.item())
            batch_accs.append(acc)

    else:
        print("Validation Loop:")
        print("")
        net.eval()

        with torch.no_grad():
            for X, y in tqdm.tqdm(batches, total=len(batches)):
                X = X.to(device)
                y = y.to(device)

                preds = net(X)
                loss = loss_fn(preds, y)
                acc = get_accuracy(preds, y)

                batch_losses.append(loss.item())
                batch_accs.append(acc) 

    print("")
    print("")
    
    return sum(batch_losses) / len(batch_losses), sum(batch_accs) / len(batch_accs)

In [19]:
for epoch in range(epochs):
    train_loss, train_acc = loop(net, train_batches, True)
    val_loss, val_acc = loop(net, test_batches, False)
    
    print(f"epoch: {epoch} | train_loss: {train_loss:.4f} | train_acc: {train_acc:.4f} | val_loss: {val_loss:.4f} | val_acc: {val_acc:.4f}")
    print("")

  0%|          | 2/500 [00:00<00:39, 12.58it/s]

Train Loop:



100%|██████████| 500/500 [00:26<00:00, 19.14it/s]
  4%|▍         | 4/100 [00:00<00:02, 32.70it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.89it/s]
  0%|          | 2/500 [00:00<00:29, 16.67it/s]



epoch: 0 | train_loss: 189.6934 | train_acc: 0.3248 | val_loss: 147.6869 | val_acc: 0.4634

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.45it/s]
  4%|▍         | 4/100 [00:00<00:02, 33.24it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.86it/s]
  0%|          | 2/500 [00:00<00:30, 16.26it/s]



epoch: 1 | train_loss: 143.9153 | train_acc: 0.4734 | val_loss: 129.7925 | val_acc: 0.5271

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.36it/s]
  4%|▍         | 4/100 [00:00<00:02, 33.27it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.81it/s]
  0%|          | 2/500 [00:00<00:31, 15.83it/s]



epoch: 2 | train_loss: 128.6261 | train_acc: 0.5327 | val_loss: 125.6824 | val_acc: 0.5484

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.24it/s]
  4%|▍         | 4/100 [00:00<00:02, 32.92it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.42it/s]
  0%|          | 2/500 [00:00<00:30, 16.21it/s]



epoch: 3 | train_loss: 119.1738 | train_acc: 0.5688 | val_loss: 112.9703 | val_acc: 0.5911

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.27it/s]
  3%|▎         | 3/100 [00:00<00:03, 28.42it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.30it/s]
  0%|          | 2/500 [00:00<00:30, 16.22it/s]



epoch: 4 | train_loss: 112.2153 | train_acc: 0.5933 | val_loss: 111.0462 | val_acc: 0.5992

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.24it/s]
  4%|▍         | 4/100 [00:00<00:02, 32.80it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.32it/s]
  0%|          | 2/500 [00:00<00:30, 16.56it/s]



epoch: 5 | train_loss: 106.2068 | train_acc: 0.6172 | val_loss: 106.8088 | val_acc: 0.6180

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.37it/s]
  4%|▍         | 4/100 [00:00<00:02, 33.79it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 32.45it/s]
  0%|          | 2/500 [00:00<00:28, 17.29it/s]



epoch: 6 | train_loss: 102.1125 | train_acc: 0.6317 | val_loss: 100.4614 | val_acc: 0.6393

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.69it/s]
  4%|▍         | 4/100 [00:00<00:02, 33.82it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 32.30it/s]
  0%|          | 2/500 [00:00<00:28, 17.25it/s]



epoch: 7 | train_loss: 97.1243 | train_acc: 0.6504 | val_loss: 100.3125 | val_acc: 0.6465

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.64it/s]
  4%|▍         | 4/100 [00:00<00:02, 33.07it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 32.81it/s]
  0%|          | 2/500 [00:00<00:28, 17.37it/s]



epoch: 8 | train_loss: 93.9959 | train_acc: 0.6622 | val_loss: 95.3670 | val_acc: 0.6674

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.70it/s]
  4%|▍         | 4/100 [00:00<00:02, 33.97it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 32.56it/s]
  0%|          | 2/500 [00:00<00:29, 16.92it/s]



epoch: 9 | train_loss: 90.4290 | train_acc: 0.6746 | val_loss: 96.2096 | val_acc: 0.6610

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.84it/s]
  4%|▍         | 4/100 [00:00<00:02, 33.33it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 32.50it/s]
  0%|          | 2/500 [00:00<00:28, 17.68it/s]



epoch: 10 | train_loss: 87.6519 | train_acc: 0.6888 | val_loss: 90.0948 | val_acc: 0.6820

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.46it/s]
  4%|▍         | 4/100 [00:00<00:02, 33.59it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 32.43it/s]
  0%|          | 2/500 [00:00<00:28, 17.27it/s]



epoch: 11 | train_loss: 85.0043 | train_acc: 0.6953 | val_loss: 90.9816 | val_acc: 0.6842

Train Loop:



100%|██████████| 500/500 [00:26<00:00, 18.88it/s]
  4%|▍         | 4/100 [00:00<00:03, 31.44it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 30.62it/s]
  0%|          | 2/500 [00:00<00:30, 16.36it/s]



epoch: 12 | train_loss: 82.1688 | train_acc: 0.7055 | val_loss: 88.2300 | val_acc: 0.6946

Train Loop:



100%|██████████| 500/500 [00:26<00:00, 18.98it/s]
  4%|▍         | 4/100 [00:00<00:02, 33.66it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.71it/s]
  0%|          | 2/500 [00:00<00:29, 17.12it/s]



epoch: 13 | train_loss: 79.6158 | train_acc: 0.7146 | val_loss: 87.0094 | val_acc: 0.6885

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.26it/s]
  4%|▍         | 4/100 [00:00<00:02, 33.10it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.71it/s]
  0%|          | 2/500 [00:00<00:29, 17.06it/s]



epoch: 14 | train_loss: 77.6083 | train_acc: 0.7206 | val_loss: 84.0653 | val_acc: 0.7055

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.34it/s]
  4%|▍         | 4/100 [00:00<00:02, 32.45it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.86it/s]
  0%|          | 2/500 [00:00<00:30, 16.12it/s]



epoch: 15 | train_loss: 75.4540 | train_acc: 0.7297 | val_loss: 87.3324 | val_acc: 0.6992

Train Loop:



100%|██████████| 500/500 [00:26<00:00, 19.20it/s]
  4%|▍         | 4/100 [00:00<00:03, 31.52it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.42it/s]
  0%|          | 2/500 [00:00<00:30, 16.41it/s]



epoch: 16 | train_loss: 73.0346 | train_acc: 0.7382 | val_loss: 90.8878 | val_acc: 0.6911

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.33it/s]
  4%|▍         | 4/100 [00:00<00:03, 30.48it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.62it/s]
  0%|          | 2/500 [00:00<00:29, 16.75it/s]



epoch: 17 | train_loss: 70.9748 | train_acc: 0.7445 | val_loss: 87.3515 | val_acc: 0.7103

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.23it/s]
  4%|▍         | 4/100 [00:00<00:03, 31.79it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.64it/s]
  0%|          | 2/500 [00:00<00:29, 16.73it/s]



epoch: 18 | train_loss: 70.2105 | train_acc: 0.7487 | val_loss: 86.3680 | val_acc: 0.7074

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.31it/s]
  4%|▍         | 4/100 [00:00<00:03, 30.25it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.36it/s]
  0%|          | 2/500 [00:00<00:31, 15.59it/s]



epoch: 19 | train_loss: 67.9407 | train_acc: 0.7567 | val_loss: 85.3552 | val_acc: 0.7142

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.28it/s]
  4%|▍         | 4/100 [00:00<00:02, 33.33it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.90it/s]
  0%|          | 2/500 [00:00<00:32, 15.47it/s]



epoch: 20 | train_loss: 65.9545 | train_acc: 0.7638 | val_loss: 85.4868 | val_acc: 0.7175

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.25it/s]
  4%|▍         | 4/100 [00:00<00:02, 33.58it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.91it/s]
  0%|          | 2/500 [00:00<00:30, 16.50it/s]



epoch: 21 | train_loss: 65.0255 | train_acc: 0.7668 | val_loss: 87.2401 | val_acc: 0.7147

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.31it/s]
  4%|▍         | 4/100 [00:00<00:03, 31.87it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 32.32it/s]
  0%|          | 2/500 [00:00<00:29, 16.80it/s]



epoch: 22 | train_loss: 63.5421 | train_acc: 0.7738 | val_loss: 84.9009 | val_acc: 0.7204

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.39it/s]
  4%|▍         | 4/100 [00:00<00:02, 32.71it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.89it/s]
  0%|          | 2/500 [00:00<00:29, 16.64it/s]



epoch: 23 | train_loss: 62.1987 | train_acc: 0.7780 | val_loss: 84.1880 | val_acc: 0.7245

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.35it/s]
  4%|▍         | 4/100 [00:00<00:03, 31.13it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 32.16it/s]
  0%|          | 2/500 [00:00<00:30, 16.57it/s]



epoch: 24 | train_loss: 61.2156 | train_acc: 0.7792 | val_loss: 81.7569 | val_acc: 0.7266

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.41it/s]
  4%|▍         | 4/100 [00:00<00:03, 29.81it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 31.38it/s]
  0%|          | 2/500 [00:00<00:29, 17.10it/s]



epoch: 25 | train_loss: 59.7910 | train_acc: 0.7862 | val_loss: 87.5563 | val_acc: 0.7151

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.38it/s]
  4%|▍         | 4/100 [00:00<00:02, 32.74it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 32.26it/s]
  0%|          | 2/500 [00:00<00:29, 17.16it/s]



epoch: 26 | train_loss: 58.4900 | train_acc: 0.7905 | val_loss: 85.4096 | val_acc: 0.7299

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.36it/s]
  4%|▍         | 4/100 [00:00<00:02, 32.73it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 32.06it/s]
  0%|          | 2/500 [00:00<00:30, 16.56it/s]



epoch: 27 | train_loss: 58.0180 | train_acc: 0.7919 | val_loss: 95.1431 | val_acc: 0.7009

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.54it/s]
  4%|▍         | 4/100 [00:00<00:02, 32.71it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 32.27it/s]
  0%|          | 2/500 [00:00<00:29, 17.08it/s]



epoch: 28 | train_loss: 56.6283 | train_acc: 0.7977 | val_loss: 89.6068 | val_acc: 0.7222

Train Loop:



100%|██████████| 500/500 [00:25<00:00, 19.66it/s]
  4%|▍         | 4/100 [00:00<00:02, 34.42it/s]



Validation Loop:



100%|██████████| 100/100 [00:03<00:00, 32.18it/s]



epoch: 29 | train_loss: 55.4899 | train_acc: 0.7998 | val_loss: 89.6653 | val_acc: 0.7192




