In [17]:
import math
import torch
from torch import nn
from torch.nn import functional
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from GPT import Block
from tqdm.notebook import tqdm
import numpy as np

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
epochs = 20
print_interval = 100
batch_size = 64

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [5]:
trainset = torchvision.datasets.CIFAR10(root='./data/cifar-10-train', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

Files already downloaded and verified


In [6]:
testset = torchvision.datasets.CIFAR10(root='./data/cifar-10-test', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True)

Files already downloaded and verified


In [6]:
len(testset.data)

10000

In [7]:
class PatchEmbeddings(nn.Module):
    """
    Convert the image into patches and then project them into a vector space.
    """

    def __init__(self, image_size, patch_size, num_channels, embed_dim):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        # Calculate the number of patches from the image size and patch size
        self.num_patches = (self.image_size // self.patch_size) ** 2
        # Create a projection layer to convert the image into patches
        # The layer projects each patch into a vector of size hidden_size
        self.projection = nn.Conv2d(self.num_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

    
    def forward(self, x):
        # (batch_size, num_channels, image_size, image_size) -> (batch_size, num_patches, hidden_size)
        x = self.projection(x)
        x = x.flatten(2).transpose(1, 2)
        return x


In [8]:
class VisionTransformer(nn.Module):

    
    def __init__(self, image_size: int, patch_size: int, num_channels: int, embed_dim: int, num_classes: int):
        super().__init__()
       
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.num_classes = num_classes
        self.num_patches = (self.image_size // self.patch_size) ** 2 # Num patches == Context
        self.mask = torch.zeros(size = (self.num_patches + 1, self.num_patches + 1)).bool().to(device=device)

        # Token embedding table is used for token identification encoding
        # Position embedding table is used for token position (in reference to the current context) encoding
        self.patch_embeddings = PatchEmbeddings(image_size, patch_size, num_channels, embed_dim).to(device=device)
        self.cls_token = nn.Parameter(data = torch.randn(size=(1,1, embed_dim), device=device))
        self.position_embedding_table = nn.Embedding(self.num_patches + 1, embed_dim, device=device)

        self.blocks = nn.ModuleList([
            Block(emb_dims=embed_dim, num_heads=4),
            Block(emb_dims=embed_dim, num_heads=4),
            Block(emb_dims=embed_dim, num_heads=4),
        ])

        # Final layer norm
        self.ln_f = nn.LayerNorm(embed_dim) 
        
        # Language model head used for output
        self.lm_head = nn.Linear(embed_dim, self.num_classes)

    def forward(self, x, targets=None):
        B, C, H, W = x.shape

        # x and targets are both (B,C) tensor of integers

        # Getting the Patch embeddings
        patch_emb: torch.Tensor = self.patch_embeddings(x) # (B,C,D)
        cls_token = self.cls_token.expand(B, -1, -1)

        # Added the Class token to the Patch embeddings
        x = torch.concat([cls_token, patch_emb], dim=1) # (B, C+1, D) Added Class token
        
        B, C, D = x.shape

        # Getting the position embedding for all the positions, starting from 0 -> context - 1
        pos_emb = self.position_embedding_table(torch.arange(C, device=device)) # (C,D)

        # Adding the position embedding to the patch embeddings 
        x = x + pos_emb

        for block in self.blocks:
            x = block(x, self.mask)
            
        x = self.ln_f(x) 
        logits = self.lm_head(x)
        cls_logits = logits[:, 0]

        if targets is None:
            loss = None
        else:
            loss = functional.cross_entropy(cls_logits, targets)

        return cls_logits, loss

    def predict(self, x):
        # Get the predictions
        cls_logits, loss = self.forward(x)
        probs = functional.softmax(cls_logits, dim=-1)
        predictions = probs.argmax(dim = -1)
        return predictions

In [9]:
model = VisionTransformer(
    image_size=32,
    patch_size=4,
    num_channels=3,
    embed_dim=128,
    num_classes=10
)

model.to(device)

# Print the number of parameters in the model
print(sum(param.numel() for param in model.parameters()) / 1e6, 'M parameters')



0.60993 M parameters


In [10]:
state_dict = torch.load('vit.pt')
model.load_state_dict(state_dict)

<All keys matched successfully>

In [16]:
# Test function
acc = []
for step, batch in enumerate(testloader):
  
  # every once in a while evaluate the loss on train and val sets
  if step % print_interval == 0 :
      print(f"step {step}: acc {np.array(acc).mean()}")  
  x, y = batch
  x = x.to(device)
  y = y.to(device)
  preds = model.predict(x)
  scores = torch.eq(y, preds).float()
  acc.append(scores.mean())


tensor(0.6250, device='cuda:0')


In [11]:
# Create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [9]:
# Training loop

progress_bar = tqdm(range(epochs * len(trainloader)))

for epoch in range(epochs):
    total_loss = 0
    for step, batch in enumerate(trainloader):
        # every once in a while evaluate the loss on train and val sets
        if step % print_interval == 0 :
            print(f"step {step}: train loss {total_loss / (step + 1)}")

        x, y = batch
        # evaluate the loss
        logits, loss = model.forward(x = x, targets =y)
        total_loss += loss.item()
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        progress_bar.update(1)

  0%|          | 0/6250 [00:00<?, ?it/s]

step 0: train loss 0.0
