### Importing Libraries and Modules

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as T # to resize input images and convert to tensor, make image size divisible by patch size
from torch.optim import Adam
from torchvision.datasets.mnist import MNIST
from torch.utils.data import DataLoader, Subset

import numpy as np # for positional encodings using sine and cosine operation

from sklearn.model_selection import train_test_split

### Patch Embeddings

In [2]:
class PatchEmbedding(nn.Module):
    def __init__(self, model_dim, img_size, patch_size, num_channels):
        super().__init__()

        self.model_dim = model_dim
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_channels = num_channels

        self.linear_project = nn.Conv2d(
            self.num_channels, 
            self.model_dim, 
            kernel_size=self.patch_size, 
            stride=self.patch_size
        )

    def forward(self, x):
        x = self.linear_project(x) # (B, C, H, W) --> (B, model_dim, P_Row, P_Col)
        x = x.flatten(2) # (B, model_dim, P_Row, P_Col) --> (B, model_dim, P)
        x = x.transpose(1,2) # (B, model_dim, P) --> (B, P, model_dim)
        
        return x


### Class Token and Positional Encoding

In [3]:
class PositionalEncoding(nn.Module):
    def __init__(self, model_dim, max_seq_length):
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1, 1, model_dim)) # classification token

        # positional encoding
        pe = torch.zeros(max_seq_length, model_dim)

        for pos in range(max_seq_length):
            for i in range(model_dim):
                if i % 2 == 0:
                    pe[pos][i] = np.sin(pos/(10000 ** (i/model_dim)))
                else:
                    pe[pos][i] = np.cos(pos/(10000 ** ((i-1)/model_dim)))
        
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        tokens_batch = self.cls_token.expand(x.size()[0], -1, -1) # class token for each image
        x = torch.cat((tokens_batch,x), dim=1) # class token + no. of patches for each image
        x = x + self.pe

        return x

### Attention Head

In [4]:
class AttentionHead(nn.Module):
    def __init__(self, model_dim, head_size):
        super().__init__()
        self.head_size = head_size

        self.query = nn.Linear(model_dim, head_size)
        self.key = nn.Linear(model_dim, head_size)
        self.value = nn.Linear(model_dim, head_size)

    def forward(self, x):

        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        attention = Q @ K.transpose(-2,-1)
        attention = attention / (self.head_size ** 0.5) # scale value to control variance at initialization
        attention = torch.softmax(attention, dim=-1)
        attention = attention @ V

        return attention

### Multi-Head Attention

In [5]:
class MultiHeadAttention(nn.Module):
    def __init__(self, model_dim, num_heads):
        super().__init__()
        self.head_size = model_dim // num_heads
        self.W_o = nn.Linear(model_dim, model_dim)
        self.heads = nn.ModuleList([AttentionHead(model_dim, self.head_size) for _ in range(num_heads)])

    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.W_o(out)

        return out

### Transformer Encoder

In [6]:
class TransformerEncoder(nn.Module):
    def __init__(self, model_dim, num_heads, r_mlp=4):
        super().__init__()
        self.model_dim = model_dim
        self.num_heads = num_heads

        self.ln1 = nn.LayerNorm(model_dim)
        self.mha = MultiHeadAttention(model_dim, num_heads)
        self.ln2 = nn.LayerNorm(model_dim)

        self.mlp = nn.Sequential(
            nn.Linear(model_dim, model_dim * r_mlp),
            nn.GELU(),
            nn.Linear(model_dim * r_mlp, model_dim)
        )

    def forward(self, x):
        # residual connections to prevent vanishing gradient problem
        out = x + self.mha(self.ln1(x))
        out = out + self.mlp(self.ln2(out))

        return out

### Vision Transformer

In [7]:
class VisionTransformer(nn.Module):
    def __init__(self, model_dim, num_classes, img_size, patch_size, num_channels, num_heads, num_layers):
        super().__init__()

        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, "img_size dimension must be divisible by patch_size dimensions"
        assert model_dim % num_heads == 0, "model_dim must be divisible by num_heads"

        self.model_dim = model_dim
        self.num_classes = num_classes
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_heads = num_heads

        self.num_patches = (self.img_size[0] * self.img_size[1]) // (self.patch_size[0] * self.patch_size[1])
        self.max_seq_length = self.num_patches + 1
        self.patch_embedding = PatchEmbedding(self.model_dim, self.img_size, self.patch_size, self.num_channels)
        self.positional_encoding = PositionalEncoding(self.model_dim, self.max_seq_length)
        self.transformer_encoder = nn.Sequential(*[TransformerEncoder(self.model_dim, self.num_heads) for _ in range(num_layers)])

        self.classifer = nn.Sequential(
            nn.Linear(self.model_dim, self.num_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, images):
        x = self.patch_embedding(images)
        x = self.positional_encoding(x)
        x = self.transformer_encoder(x)
        x = self.classifer(x[:, 0])

        return x

### Training Parameters

In [8]:
model_dim = 9
num_classes = 10
img_size = (32, 32)
patch_size = (16, 16)
num_channels = 1
num_heads = 3
num_layers = 3
batch_size = 128
epochs = 10
alpha = 0.005

### Loading MNIST Dataset

In [9]:
# transform = T.Compose([
#     T.Resize(img_size),
#     T.ToTensor()
# ])

# train_set = MNIST(
#     root='./../datasets', train= True, download=True, transform=transform
# )

# test_set = MNIST(
#     root='./../datasets', train=False, download=True, transform=transform
# )

# train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
# test_loader = DataLoader(test_set, shuffle=False, batch_size=batch_size)

In [10]:
transform = T.Compose([
    T.Resize(img_size),
    T.ToTensor()
])

train_set = MNIST(
    root='./../datasets', train= True, download=True, transform=transform
)

test_set = MNIST(
    root='./../datasets', train=False, download=True, transform=transform
)

train_indices, valid_indices = train_test_split(list(range(len(train_set))), test_size=0.2, random_state=42)

train_split = Subset(train_set, train_indices)
valid_split = Subset(train_set, valid_indices)


train_loader = DataLoader(train_split, shuffle=True, batch_size=batch_size)
valid_loader = DataLoader(valid_split, shuffle=False, batch_size=batch_size)
test_loader = DataLoader(test_set, shuffle=False, batch_size=batch_size)

### Training

In [11]:
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# print("Using device:", device)

# transformer = VisionTransformer(model_dim, num_classes, img_size, patch_size, num_channels, num_heads, num_layers).to(device)

# optimizer = Adam(transformer.parameters(), lr=alpha)
# criterion = nn.CrossEntropyLoss()

# for epoch in range(epochs):
    
#     training_loss = 0.0
#     correct = 0
#     total = 0
#     for i, data in enumerate(train_loader, 0):
#         inputs, labels = data
#         inputs, labels = inputs.to(device), labels.to(device)

#         optimizer.zero_grad()

#         outputs = transformer(inputs)
#         loss = criterion(outputs, labels)
#         loss.backward()
#         optimizer.step()

#         training_loss += loss.item()

#         _, predicted = torch.max(outputs, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()

#     print(f'Epoch {epoch + 1}/{epochs} accuracy: {correct / total * 100:.3f} % loss: {training_loss / len(train_loader) :.3f}')

In [13]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

transformer = VisionTransformer(model_dim, num_classes, img_size, patch_size, num_channels, num_heads, num_layers).to(device)

optimizer = Adam(transformer.parameters(), lr=alpha)
criterion = nn.CrossEntropyLoss()

for epoch in range(epochs):
    
    transformer.train()
    training_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = transformer(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        training_loss += loss.item()

        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        transformer.eval()
        valid_loss = 0.0
        correct_valid = 0
        total_valid = 0

        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = transformer(inputs)
                loss = criterion(outputs, labels)

                valid_loss += loss.item()

                _, predicted = torch.max(outputs.data, 1)
                total_valid += labels.size(0)
                correct_valid += (predicted == labels).sum().item()

    print(f'Epoch {epoch + 1}/{epochs} train accuracy: {correct / total * 100:.3f} % train loss: {training_loss / len(train_loader) :.3f} valid accuracy: {correct_valid / total_valid * 100:.3f} % valid loss: {valid_loss / len(valid_loader):.3f}')

Using device: mps
Epoch 1/10 train accuracy: 54.612 % train loss: 1.917 valid accuracy: 71.875 % valid loss: 1.735
Epoch 2/10 train accuracy: 75.935 % train loss: 1.704 valid accuracy: 80.469 % valid loss: 1.655
Epoch 3/10 train accuracy: 77.892 % train loss: 1.683 valid accuracy: 80.469 % valid loss: 1.660
Epoch 4/10 train accuracy: 85.173 % train loss: 1.611 valid accuracy: 88.281 % valid loss: 1.582
Epoch 5/10 train accuracy: 87.740 % train loss: 1.584 valid accuracy: 85.156 % valid loss: 1.617
Epoch 6/10 train accuracy: 89.223 % train loss: 1.569 valid accuracy: 91.406 % valid loss: 1.561
Epoch 7/10 train accuracy: 90.062 % train loss: 1.561 valid accuracy: 89.844 % valid loss: 1.568
Epoch 8/10 train accuracy: 90.460 % train loss: 1.557 valid accuracy: 89.844 % valid loss: 1.567
Epoch 9/10 train accuracy: 90.135 % train loss: 1.560 valid accuracy: 91.406 % valid loss: 1.556
Epoch 10/10 train accuracy: 90.417 % train loss: 1.557 valid accuracy: 89.844 % valid loss: 1.562


### Testing

In [14]:
correct = 0
total = 0

with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)

        outputs = transformer(images)

        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(f'\nModel Accuracy: {100 * correct // total} %')


Model Accuracy: 90 %
