In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [2]:
from einops.layers.torch import Rearrange

In [3]:
# Hyperparameters
batch_size = 64
learning_rate = 1e-3
num_epochs = 20

transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load CIFAR10 Dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

Files already downloaded and verified
Files already downloaded and verified
cuda


<a id="5"></a> <br>
### Vision Transformer

In [25]:
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels = 3, patch_size = 4, emb_size = 128):
        self.patch_size = patch_size
        super().__init__()
        self.projection = nn.Sequential(
            # break-down the image in s1 x s2 patches and flat them
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_size * patch_size * in_channels, emb_size)
        )

    def forward(self, x):
        x = self.projection(x)
        return x


class AttentionHead(nn.Module):
    def __init__(self, input_emb_size=128, latent_emb_size=64):
        super().__init__()

        self.wk = nn.Linear(input_emb_size, latent_emb_size)
        self.wq = nn.Linear(input_emb_size, latent_emb_size)
        self.wv = nn.Linear(input_emb_size, latent_emb_size)
    
    def forward(self, x):
        k = self.wk(x)
        q = self.wq(x)
        v = self.wv(x)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)

        return torch.matmul(attn_weights, v)


class MultiAttentionHead(nn.Module):
    def __init__(self, n_head=3, input_emb_dim=128, latent_emb_size=64):
        super().__init__()

        self.n_head = n_head
        self.multihead = [AttentionHead(input_emb_dim, latent_emb_size) for i in range(n_head)]
        [attn_head.to(device) for attn_head in self.multihead]

        self.proj = nn.Linear(n_head * latent_emb_size, input_emb_dim)

    def forward(self, x):
        multi_head_output = [attn_head(x) for attn_head in self.multihead]
        multi_head_output = torch.cat(multi_head_output, dim=-1)  

        return self.proj(multi_head_output)


class TransformerBlock(nn.Module):
    def __init__(self, n_head=3, input_emb_dim=128, attn_head_latent_emb_dim=64, mlp_latent_emb_dim=128):
        super().__init__()

        self.multi_head_attn = MultiAttentionHead(n_head, input_emb_dim, attn_head_latent_emb_dim)
    
        self.mlp = nn.Sequential(
            nn.LayerNorm(input_emb_dim),
            nn.Linear(input_emb_dim, mlp_latent_emb_dim),
            nn.ReLU(),
            nn.Linear(mlp_latent_emb_dim, input_emb_dim)
        )

    def forward(self, x):
        x = self.multi_head_attn(x)
        x = self.mlp(x)
        return x


# 🚀 Fixed Vision Transformer Model for CIFAR-10
class ViT(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, num_classes=10, 
                 emb_dim=128):
        super(ViT, self).__init__()

        assert img_size % patch_size == 0, "Image size must be divisible by patch size"
        
        self.patch_embedding = PatchEmbedding(in_channels=in_channels,
                                              patch_size=patch_size,
                                              emb_size=emb_dim)

        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2

        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))        
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, emb_dim))

        self.ln = nn.LayerNorm(emb_dim)

        self.transformer_block_1 = TransformerBlock(n_head=6, input_emb_dim=emb_dim)
        
        self.transformer_block_2 = TransformerBlock(n_head=6, input_emb_dim=emb_dim)
        
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        B = x.shape[0]
        
        x = self.patch_embedding(x) 

        cls_tokens = self.cls_token.expand(B, -1, -1)  
        x = torch.cat((cls_tokens, x), dim=1)  
        x = x + self.pos_embed

        x = self.ln(x)

        dx = self.transformer_block_1(x)
        x = x + dx

        dx = self.transformer_block_2(x)
        x = x + dx

        x = self.mlp_head(x)
        
        x = x[:, 0]  

        return x

model = ViT().to(device)
print(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
criterion = nn.CrossEntropyLoss()

ViT(
  (patch_embedding): PatchEmbedding(
    (projection): Sequential(
      (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4)
      (1): Linear(in_features=48, out_features=128, bias=True)
    )
  )
  (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (transformer_block_1): TransformerBlock(
    (multi_head_attn): MultiAttentionHead(
      (proj): Linear(in_features=384, out_features=128, bias=True)
    )
    (mlp): Sequential(
      (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (1): Linear(in_features=128, out_features=128, bias=True)
      (2): ReLU()
      (3): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (transformer_block_2): TransformerBlock(
    (multi_head_attn): MultiAttentionHead(
      (proj): Linear(in_features=384, out_features=128, bias=True)
    )
    (mlp): Sequential(
      (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (1): Linear(in_features=128, out_features=128, bias=True)
    

In [26]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())

# Example
total_params = count_parameters(model)
print(f"Total Parameters: {total_params}")

Total Parameters: 189258


In [27]:
num_epochs = 100

# Traning the Model
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for (images, labels) in train_loader:
        images = images.to(device).view(-1, 3, 32, 32)
        labels = labels.to(device)
        
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    running_loss += loss.item()
        
    # for name, param in model.named_parameters():
    #     if param.grad is not None:
    #         print(f'{name} grad mean: {param.grad.mean()}')
    # print(outputs)

    # Calculate Accuracy         
    correct = 0
    total = 0
    # Predict test dataset
    for images, labels in test_loader: 
        images = images.to(device).view(-1, 3, 32, 32)
        labels = labels.to(device)

        outputs = model(images)
        predicted = torch.max(outputs.data, 1)[1]
        total += len(labels)
        correct += (predicted == labels).sum()
    
    accuracy = 100 * correct / float(total)
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}, Acc: {accuracy:.4f}%")

Epoch [1/100], Loss: 0.0022, Acc: 30.7300%
Epoch [2/100], Loss: 0.0026, Acc: 36.3600%
Epoch [3/100], Loss: 0.0024, Acc: 40.3100%
Epoch [4/100], Loss: 0.0016, Acc: 44.0200%
Epoch [5/100], Loss: 0.0028, Acc: 45.1200%
Epoch [6/100], Loss: 0.0014, Acc: 46.8300%
Epoch [7/100], Loss: 0.0019, Acc: 47.8700%
Epoch [8/100], Loss: 0.0022, Acc: 48.7600%
Epoch [9/100], Loss: 0.0019, Acc: 49.1000%
Epoch [10/100], Loss: 0.0016, Acc: 49.5400%
Epoch [11/100], Loss: 0.0011, Acc: 51.6500%
Epoch [12/100], Loss: 0.0011, Acc: 50.6500%
Epoch [13/100], Loss: 0.0016, Acc: 51.4000%
Epoch [14/100], Loss: 0.0016, Acc: 52.5900%
Epoch [15/100], Loss: 0.0019, Acc: 53.1200%
Epoch [16/100], Loss: 0.0022, Acc: 52.7600%
Epoch [17/100], Loss: 0.0012, Acc: 52.9600%
Epoch [18/100], Loss: 0.0018, Acc: 53.4700%
Epoch [19/100], Loss: 0.0020, Acc: 53.6000%
Epoch [20/100], Loss: 0.0022, Acc: 54.3300%
Epoch [21/100], Loss: 0.0016, Acc: 54.2800%
Epoch [22/100], Loss: 0.0021, Acc: 54.1300%
Epoch [23/100], Loss: 0.0017, Acc: 54.360

In [29]:
# # visualization loss 
# plt.plot(iteration_list,loss_list)
# plt.xlabel("Number of iteration")
# plt.ylabel("Loss")
# plt.title("CNN: Loss vs Number of iteration")
# plt.show()

# # visualization accuracy 
# plt.plot(iteration_list,accuracy_list,color = "red")
# plt.xlabel("Number of iteration")
# plt.ylabel("Accuracy")
# plt.title("CNN: Accuracy vs Number of iteration")
# plt.show()