In [48]:
# import libraries
import torch
import torchvision
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.utils.data as dataloader
import torch.nn as nn

In [49]:
#transformation of PIL data into tensor format
transformation_operation=transforms.Compose([transforms.ToTensor()])

In [50]:
train_dataset=torchvision.datasets.MNIST(root='./data', train =True, download=True, transform=transformation_operation)

In [51]:
val_dataset=torchvision.datasets.MNIST(root='./data', train =False, download=True, transform=transformation_operation)

In [52]:
batch_size=64
num_classes=10
img_size=28
patch_size=7
patch_number = (img_size//patch_size) * (img_size//patch_size)
attention_heads= 4
embed_dim = 20
transformer_blocks = 4
mlp_nodes= 64
num_channels = 1 # black and white image
learning_rate= 0.001
epochs=5

In [53]:
# using dataloader to prep data for nueral n/w

train_data= dataloader.DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_data=dataloader.DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

In [54]:
# class for PatchEmbedding - part 1
 # inherits from nn.Module
class PatchEmbedding(nn.Module):
  def __init__(self):
    super().__init__()
    # stride size is same as patch size as they are non overlapping
    # kernel size is also same as patch size
    self.patch_embed = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)


  def forward(self,x): # x is the input dataset -> variable generated from training data
    x=self.patch_embed(x)
    x=x.flatten(2)
    x=x.transpose(1,2)
    return x



In [55]:
# to print dims

images, labels = next(iter(train_data))
print(images.shape)

#torch.Size([64, 1, 28, 28])
            #batch size,no of channels, size of image

patch_embed = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
print("This is the shape of input image tensor", images.shape)
print("This is the shape of output patch embed image tensor", patch_embed(images).shape)

'''
This is the shape of input image tensor torch.Size([64, 1, 28, 28])
This is the shape of output patch embed image tensor torch.Size([64, 20, 4, 4])

'''
embedded_image= patch_embed(images)
print("This is the shape of flattened image tensor",embedded_image.flatten(2).shape)
print("This is the shape of  transposed image tensor",embedded_image.flatten(2).transpose(1,2).shape)


torch.Size([64, 1, 28, 28])
This is the shape of input image tensor torch.Size([64, 1, 28, 28])
This is the shape of output patch embed image tensor torch.Size([64, 20, 4, 4])
This is the shape of flattened image tensor torch.Size([64, 20, 16])
This is the shape of  transposed image tensor torch.Size([64, 16, 20])


In [56]:
# class for Transformer encoder - part 2
# Layer Norm
# Multi-Head Attention
# Layer Norm
# Residuals
# MLP - activation function
class TransformerEncoder(nn.Module):
  def __init__(self):
    super().__init__()
    self.layer_norm1=nn.LayerNorm(embed_dim)
    self.multi_head_attention=nn.MultiheadAttention(embed_dim,attention_heads, batch_first=True)
    self.layer_norm2=nn.LayerNorm(embed_dim)
    self.mlp=nn.Sequential(
        nn.Linear(embed_dim, mlp_nodes),
        nn.GELU(),
        nn.Linear(mlp_nodes, embed_dim),
        #nn.GELU(),
        #nn.Linear(embed_dim)
    )
  def forward(self,x):
    residual1=x
    x=self.layer_norm1(x)
    x=self.multi_head_attention(x, x, x)[0] # key, query and value
    x=x+residual1
    residual2=x
    x=self.layer_norm2(x)
    x=self.mlp(x)
    x=x+residual2
    return x





In [57]:
# class for MLP head for classification encoder - part 3
class MLP_Head(nn.Module):
  def __init__(self):
    super().__init__()
    # to prev overfitting
    self.layer_norm1 = nn.LayerNorm(embed_dim)
    self.mlphead=nn.Sequential(
        #nn.Linear(embed_dim),
        nn.Linear(embed_dim,num_classes)
    )

  def forward(self,x):
    # we need only vector associated with CLS for classificatio
    #x = x[:,0]
    x=self.layer_norm1(x)
    x=self.mlphead(x)

    return x


In [58]:
class VisionTransformer(nn.Module):
  def __init__(self):
    super().__init__()
    self.patch_embedding = PatchEmbedding()
    self.cls_token = nn.Parameter(torch.randn(1,1,embed_dim))
    self.position_embedding = nn.Parameter(torch.randn(1,patch_number+1, embed_dim))
    # for n number of transformer blocks
    self.transformer_blocks= nn.Sequential(*[TransformerEncoder() for _ in range(transformer_blocks)])
    self.mlp_head = MLP_Head()

  def forward(self,x):
    x = self.patch_embedding(x)
    B = x.size(0) # cannot be hardcoded as last batch may have 16 images or any number
    cls_tokens = self.cls_token.expand(B,-1,-1)
    x=torch.cat((cls_tokens,x),1)
    x= x+self.position_embedding
    x = x[:,0]
    x = self.mlp_head(x)

    return x


In [59]:
# optimizer
# cross entropy loss

# device

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = VisionTransformer().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()


In [61]:
for epoch in range(5):
    model.train()
    total_loss = 0
    correct_epoch = 0
    total_epoch = 0
    print(f"\nEpoch {epoch+1}")

    for batch_idx, (images, labels) in enumerate(train_data):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss+=loss.item()
        preds = outputs.argmax(dim=1)

        correct = (preds == labels).sum().item()
        accuracy = 100.0 * correct / labels.size(0)

        correct_epoch += correct
        total_epoch += labels.size(0)

        if batch_idx % 100 == 0:
            print(f"  Batch {batch_idx+1:3d}: Loss = {loss.item():.4f}, Accuracy = {accuracy:.2f}%")

    epoch_acc = 100.0 * correct_epoch / total_epoch
    print(f"==> Epoch {epoch+1} Summary: Total Loss = {total_loss:.4f}, Accuracy = {epoch_acc:.2f}%")


Epoch 1
  Batch   1: Loss = 2.3044, Accuracy = 15.62%
  Batch 101: Loss = 2.3000, Accuracy = 12.50%
  Batch 201: Loss = 2.3059, Accuracy = 7.81%
  Batch 301: Loss = 2.3026, Accuracy = 3.12%
  Batch 401: Loss = 2.3029, Accuracy = 10.94%
  Batch 501: Loss = 2.3228, Accuracy = 7.81%
  Batch 601: Loss = 2.3054, Accuracy = 12.50%
  Batch 701: Loss = 2.3005, Accuracy = 6.25%
  Batch 801: Loss = 2.3182, Accuracy = 9.38%
  Batch 901: Loss = 2.3055, Accuracy = 9.38%
==> Epoch 1 Summary: Total Loss = 2159.9498, Accuracy = 10.91%

Epoch 2
  Batch   1: Loss = 2.3088, Accuracy = 3.12%
  Batch 101: Loss = 2.3024, Accuracy = 10.94%
  Batch 201: Loss = 2.2938, Accuracy = 14.06%
  Batch 301: Loss = 2.3054, Accuracy = 7.81%
  Batch 401: Loss = 2.3095, Accuracy = 9.38%
  Batch 501: Loss = 2.2920, Accuracy = 21.88%
  Batch 601: Loss = 2.2893, Accuracy = 12.50%
  Batch 701: Loss = 2.3021, Accuracy = 7.81%
  Batch 801: Loss = 2.3100, Accuracy = 3.12%
  Batch 901: Loss = 2.3031, Accuracy = 7.81%
==> Epoch 2