In [None]:
#!pip install torch==1.11.0

In [1]:
import numpy as np
import pandas as pd
import os
import PIL
from PIL import Image
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from sklearn import model_selection, metrics
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor

In [2]:
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7f3f39434db0>

## Data Download

In [3]:
transform = ToTensor()
train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, shuffle=True, batch_size=16)
test_loader = DataLoader(test_set, shuffle=False, batch_size=16)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./../datasets/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./../datasets/MNIST/raw/train-images-idx3-ubyte.gz to ./../datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./../datasets/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./../datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ./../datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./../datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./../datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ./../datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./../datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./../datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./../datasets/MNIST/raw



In [4]:
class ViT(nn.Module):
    def __init__(self, input_shape, n_patches=14, hidden=8, num_heads=2, n_classes=10):
        # Super constructor
        super(ViT, self).__init__()

        # Input and patches sizes
        self.input_shape = input_shape
        self.n_patches = n_patches
        self.num_heads = num_heads
        self.patch_size = (input_shape[1] / n_patches, input_shape[2] / n_patches)
        self.hidden = hidden
        #single patch input
        self.input = int(input_shape[0] * self.patch_size[0] * self.patch_size[1])
        self.linear = nn.Linear(self.input, self.hidden)

        # Classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden))
        
        #Layer norms
        self.ln1 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden))
        self.ln2 = nn.LayerNorm((self.n_patches ** 2 + 1, self.hidden))
        #Attention layer
        self.attn = Attention(self.hidden, num_heads)
        #MLP Encoder
        self.mlpenc = nn.Sequential(
            nn.Linear(self.hidden, self.hidden),
            nn.ReLU()
        )
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden, n_classes),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
        # Dividing images into patches
        B, C, W, H = images.shape
        patches = images.reshape(B, self.n_patches ** 2, self.input)
        tokens = self.linear(patches)

        # Adding classification token to the tokens
        tokens = torch.stack([torch.vstack((self.class_token, tokens[i])) for i in range(len(tokens))])

        # Adding positional embedding
        tokens += get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden).repeat(B, 1, 1)
        
        output = tokens + self.attn(self.ln1(tokens))
        output = output + self.mlpenc(self.ln2(output))

        # Getting the classification token only
        output = output[:, 0]

        return self.mlp(output)

In [5]:
class Attention(nn.Module):
    def __init__(self, d, num_heads=2):
        super(Attention, self).__init__()
        self.d = d
        self.num_heads = num_heads

        d_head = int(d / num_heads)
        self.q_mappings = [nn.Linear(d_head, d_head) for _ in range(self.num_heads)]
        self.k_mappings = [nn.Linear(d_head, d_head) for _ in range(self.num_heads)]
        self.v_mappings = [nn.Linear(d_head, d_head) for _ in range(self.num_heads)]
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.num_heads):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.cat([torch.unsqueeze(r, dim=0) for r in result])

In [6]:
def get_positional_embeddings(sequence_length, d):
    result = torch.ones(sequence_length, d)
    for i in range(sequence_length):
        for j in range(d):
            result[i][j] = np.sin(i / (10000 ** (j / d))) if j % 2 == 0 else np.cos(i / (10000 ** ((j - 1) / d)))
    return result

In [7]:
def main():

    # Defining model and training options
    n_channels = 1
    IMG_SIZE = 28
    model = ViT((n_channels, IMG_SIZE, IMG_SIZE), n_patches=14, hidden=20, num_heads=2, n_classes=10)
    N_EPOCHS = 5
    LR = 0.01

    # Training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in range(N_EPOCHS):
        train_loss = 0.0
        accuracy, total = 0, 0
        for batch in train_loader:
            image, label = batch
            pred = model(image)
            loss = criterion(pred, label) / len(image)

            train_loss += loss.item()
            accuracy += torch.sum(torch.argmax(pred, dim=1) == label).item()
            total += len(image)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}  Train Accuracy :{accuracy / total * 100:.2f}%")

    # Test loop
    accuracy, total = 0, 0
    test_loss = 0.0
    for batch in test_loader:
        image, label = batch
        pred = model(image)
        loss = criterion(pred, label) / len(image)
        test_loss += loss

        accuracy += torch.sum(torch.argmax(pred, dim=1) == label).item()
        total += len(image)
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {accuracy / total * 100:.2f}%")

In [8]:
if __name__ == '__main__':
    main()

Epoch 1/5 loss: 412.42  Train Accuracy :70.32%
Epoch 2/5 loss: 383.57  Train Accuracy :82.50%
Epoch 3/5 loss: 371.42  Train Accuracy :87.74%
Epoch 4/5 loss: 368.81  Train Accuracy :88.80%
Epoch 5/5 loss: 366.77  Train Accuracy :89.61%
Test loss: 61.28
Test accuracy: 89.31%
