In [1]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm

In [2]:
# Device-agnostic setup: use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
# Seed every RNG the notebook touches so that Restart & Run All is repeatable.
# A single constant keeps the four generators in sync if the seed is ever changed.
SEED = 42
random.seed(SEED)                 # Python's built-in RNG
np.random.seed(SEED)              # NumPy's legacy global RNG
torch.manual_seed(SEED)           # PyTorch default generator
torch.cuda.manual_seed_all(SEED)  # every visible GPU, not only the current one; a no-op without CUDA

In [4]:
# hyperparameters

# -- input data -------------------------------------------------------
IMAGE_SIZE = 32     # side length passed to the model as img_size
CHANNELS = 3        # input channels of the patch projection
NUM_CLASSES = 10    # size of the classification head output

# -- model architecture -----------------------------------------------
PATCH_SIZE = 4      # patch side length -> (IMAGE_SIZE // PATCH_SIZE) ** 2 = 64 patches
EMBED_DIM = 256     # token embedding width
NUM_HEADS = 8       # attention heads per encoder layer
DEPTH = 6           # number of stacked encoder layers
MLP_DIM = 512       # hidden width of each encoder layer's MLP
DROP_RATE = 0.1     # dropout probability (attention and MLP)

# -- optimisation -----------------------------------------------------
BATCH_SIZE = 128    # samples per mini-batch in both data loaders
EPOCHS = 10         # passes over the training set
LR = 3e-4           # learning rate handed to the optimizer

In [5]:
# transformations (the same pipeline is applied to the train and test splits)
transform = transforms.Compose(
    [
        # image -> float tensor
        transforms.ToTensor(),
        # (x - 0.5) / 0.5: centre the inputs around zero, which tends to give
        # faster convergence and better numerical stability
        transforms.Normalize(mean=0.5, std=0.5),
    ]
)

In [6]:
# dataset: CIFAR-10 train / test splits stored under ./data (downloaded when missing)
train_data = datasets.CIFAR10(
    root='data',
    train=True,
    transform=transform,
    download=True,
)
test_data = datasets.CIFAR10(
    root='data',
    train=False,
    transform=transform,
    download=True,
)

In [7]:
# Wrap the datasets in mini-batch iterators.
# Only the training split is shuffled; the test split keeps a fixed order.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [8]:
# vit 
class PatchEmbedding(nn.Module):
    """Convert an image batch into a sequence of patch tokens.

    A strided convolution (kernel == stride == ``patch_size``) projects every
    non-overlapping patch to an ``embed_dim``-wide vector. A learnable class
    token is prepended and a learnable positional embedding is added.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        embed_dim: Width of every output token.

    Shape:
        input:  ``(B, in_channels, img_size, img_size)``
        output: ``(B, 1 + (img_size // patch_size) ** 2, embed_dim)``
    """

    def __init__(self, 
                img_size,
                patch_size,
                in_channels,
                embed_dim
    ):
        super().__init__()
        # The batch size is read from the input in forward(); the module keeps
        # no reference to notebook-level globals so it can be reused elsewhere.
        self.proj = nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

    def forward(self, x: torch.Tensor):
        """Return the token sequence (class token first) for the image batch ``x``."""
        B = x.size(0)
        x = self.proj(x)                  # B x E x H/P x W/P
        x = x.flatten(2).transpose(1, 2)  # B x N x E, with N = (H/P) * (W/P)
        # expand() repeats the class token along the batch dimension as a view,
        # without copying the data in memory.
        cls_token = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_token, x), dim=1)  # B x (1 + N) x E
        # Out-of-place add: pos_embed (1 x (1 + N) x E) broadcasts over the batch.
        return x + self.pos_embed


In [9]:
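# Scratch cell for visualising tensor shapes (nothing is executed here).
# Toy tensor for experiments: x = torch.arange(120).reshape(2, 3, 4, 5)
# TODO: illustrate torch.concat and Tensor.expand(B, -1, -1) on the toy tensor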

In [10]:
class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The output has the same feature width as the input (``in_features``);
    ``hidden_features`` is the width between the two linear layers. A single
    dropout module with probability ``drop_rate`` is applied at both points.
    """

    def __init__(self, 
                 in_features,
                 hidden_features,
                 drop_rate) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x):
        """Apply the feed-forward block over the last dimension of ``x``."""
        hidden = self.fc1(x)
        hidden = F.gelu(hidden)
        hidden = self.dropout(hidden)
        projected = self.fc2(hidden)
        return self.dropout(projected)

In [11]:
class TransformerEncoderLayer(nn.Module):
    """Pre-norm Transformer encoder layer.

    Computes ``x + Attention(LayerNorm(x))`` followed by
    ``x + MLP(LayerNorm(x))``; the output has the same shape as the input.

    Args:
        embed_dim: Token width (last dimension of the input).
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward block.
        drop_rate: Dropout probability used in attention and in the MLP.
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 mlp_dim,
                 drop_rate):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # batch_first=True: inputs are laid out as (batch, tokens, embed_dim).
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=drop_rate, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_dim, drop_rate)

    def forward(self, x):
        """Apply self-attention and the MLP, each wrapped in a residual connection."""
        # Normalise once and reuse the same tensor as query, key and value
        # instead of running the LayerNorm three times.
        normed = self.norm1(x)
        # The attention weights are never used, so do not ask for them.
        attn_out, _ = self.attention(normed, normed, normed, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

In [12]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> ``depth`` encoder layers -> LayerNorm on the
    class token -> linear classification head.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        num_classes: Number of output logits.
        embed_dim: Token width used throughout the encoder.
        depth: Number of stacked encoder layers.
        num_heads: Attention heads per encoder layer.
        mlp_dim: Hidden width of each encoder layer's MLP.
        drop_rate: Dropout probability passed to every encoder layer.
    """

    def __init__(self, img_size, patch_size, in_channels, num_classes, embed_dim, depth, num_heads, mlp_dim, drop_rate):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.encoder = nn.Sequential(*[TransformerEncoderLayer(embed_dim, num_heads, mlp_dim, drop_rate) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for the image batch ``x``."""
        x = self.patch_embed(x)
        x = self.encoder(x)
        # Only the class token (index 0) feeds the head. LayerNorm acts on each
        # token independently, so normalising just that token gives the same
        # result as normalising the whole sequence and then slicing.
        cls_token = x[:, 0]
        return self.head(self.norm(cls_token))

In [13]:
# Build the ViT from the hyperparameters, then move it to the selected device.
model = VisionTransformer(
    img_size=IMAGE_SIZE, patch_size=PATCH_SIZE, in_channels=CHANNELS,
    num_classes=NUM_CLASSES,
    embed_dim=EMBED_DIM, depth=DEPTH, num_heads=NUM_HEADS, mlp_dim=MLP_DIM,
    drop_rate=DROP_RATE,
)
model = model.to(device)



In [14]:
# Display the module tree (the cell's last expression is rendered as its output).
model

VisionTransformer(
  (patch_embed): PatchEmbedding(
    (proj): Conv2d(3, 256, kernel_size=(4, 4), stride=(4, 4))
  )
  (encoder): Sequential(
    (0): TransformerEncoderLayer(
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=256, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (1): TransformerEncoderLayer(
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      

In [15]:
# Define the optimizer (AdamW over all model parameters) and the loss (cross-entropy).
optimizer = optim.AdamW(params=model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

In [16]:
def train(model, loader, optimizer, criterion):
    """Run one training epoch.

    Args:
        model: Module to optimise; it is switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss callable taking ``(outputs, labels)`` and returning a
            batch-mean scalar loss.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually
        processed, or ``(0.0, 0.0)`` when the loader yields no batches.
    """
    model.train()
    # Use the device the model lives on instead of a notebook-level global, so
    # the function works for any model regardless of surrounding state.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    total_loss = 0.0
    correct = 0
    seen = 0
    for images, labels in loader:
        if target_device is not None:
            images, labels = images.to(target_device), labels.to(target_device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_count = labels.size(0)
        # criterion returns a batch mean; re-weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        total_loss += loss.item() * batch_count
        correct += (outputs.argmax(1) == labels).sum().item()
        seen += batch_count

    # Divide by the samples seen, not len(loader.dataset): the two differ when
    # the loader uses drop_last=True or a sampler that covers only a subset.
    denominator = max(seen, 1)
    return total_loss / denominator, correct / denominator

def evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without updating any parameters.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)`` and returning a
            batch-mean scalar loss.

    Returns:
        ``(mean_loss, accuracy)`` averaged over the samples actually
        processed, or ``(0.0, 0.0)`` when the loader yields no batches.
    """
    model.eval()
    # Use the device the model lives on instead of a notebook-level global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():  # no autograd bookkeeping is needed for evaluation
        for images, labels in loader:
            if target_device is not None:
                images, labels = images.to(target_device), labels.to(target_device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_count = labels.size(0)
            # Re-weight the batch-mean loss by the batch size.
            total_loss += loss.item() * batch_count
            correct += (outputs.argmax(1) == labels).sum().item()
            seen += batch_count

    # Divide by the samples seen, not len(loader.dataset): the two differ when
    # the loader uses drop_last=True or a sampler that covers only a subset.
    denominator = max(seen, 1)
    return total_loss / denominator, correct / denominator

In [17]:
# Train for EPOCHS epochs and evaluate on the test split after every epoch.
# `tqdm` is imported in the notebook's import cell at the top.
# The history lists are re-created on every run of this cell, so re-running it
# does not append to stale results.
train_accuracies = []
test_accuracies = []

for epoch in tqdm(range(EPOCHS)):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion)
    test_loss, test_acc = evaluate(model, test_loader, criterion)

    train_accuracies.append(train_acc)
    test_accuracies.append(test_acc)

    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

  from .autonotebook import tqdm as notebook_tqdm
 10%|█         | 1/10 [00:25<03:45, 25.03s/it]

Epoch 1/10, Train Loss: 1.7355, Train Acc: 0.3706, Test Loss: 1.4479, Test Acc: 0.4809


 20%|██        | 2/10 [00:50<03:22, 25.28s/it]

Epoch 2/10, Train Loss: 1.3781, Train Acc: 0.5059, Test Loss: 1.2935, Test Acc: 0.5383


 30%|███       | 3/10 [01:16<02:57, 25.43s/it]

Epoch 3/10, Train Loss: 1.2327, Train Acc: 0.5601, Test Loss: 1.2111, Test Acc: 0.5668


 40%|████      | 4/10 [01:41<02:33, 25.61s/it]

Epoch 4/10, Train Loss: 1.1261, Train Acc: 0.5976, Test Loss: 1.1607, Test Acc: 0.5833


 50%|█████     | 5/10 [02:08<02:09, 25.86s/it]

Epoch 5/10, Train Loss: 1.0400, Train Acc: 0.6298, Test Loss: 1.1177, Test Acc: 0.6049


 60%|██████    | 6/10 [02:34<01:44, 26.03s/it]

Epoch 6/10, Train Loss: 0.9649, Train Acc: 0.6572, Test Loss: 1.1133, Test Acc: 0.6043


 70%|███████   | 7/10 [03:01<01:18, 26.16s/it]

Epoch 7/10, Train Loss: 0.8882, Train Acc: 0.6852, Test Loss: 1.0738, Test Acc: 0.6246


 80%|████████  | 8/10 [03:27<00:52, 26.27s/it]

Epoch 8/10, Train Loss: 0.8157, Train Acc: 0.7102, Test Loss: 1.0829, Test Acc: 0.6207


 90%|█████████ | 9/10 [03:54<00:26, 26.34s/it]

Epoch 9/10, Train Loss: 0.7466, Train Acc: 0.7354, Test Loss: 1.1014, Test Acc: 0.6264


100%|██████████| 10/10 [04:20<00:00, 26.06s/it]

Epoch 10/10, Train Loss: 0.6740, Train Acc: 0.7589, Test Loss: 1.0896, Test Acc: 0.6446





In [19]:
test_accuracies

[0.4809,
 0.5383,
 0.5668,
 0.5833,
 0.6049,
 0.6043,
 0.6246,
 0.6207,
 0.6264,
 0.6446]