In [73]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk the Kaggle input directory and print the full path of
# every file found (sub-directory names themselves are not printed).
for folder, _subdirs, files in os.walk('/kaggle/input'):
    for name in files:
        full_path = os.path.join(folder, name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [74]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR100
from einops import rearrange
from einops.layers.torch import Rearrange
import numpy as np

In [75]:
# Side length, in pixels, of the square images fed to the model.
IMG_SIZE = 32
# Number of samples per mini-batch for both the training and validation loaders.
BATCH_SIZE = 128
# Number of worker processes each DataLoader uses to load batches.
NUM_WORKERS = 4

In [76]:
# Per-channel (R, G, B) statistics used to normalize inputs. Defined once so
# the training and validation pipelines cannot drift apart.
# NOTE(review): presumably the CIFAR-100 training-set statistics — confirm.
NORM_MEAN = [0.5071, 0.4867, 0.4408]
NORM_STD = [0.2675, 0.2565, 0.2761]

# Training pipeline: light augmentation (padded random crop, horizontal flip,
# colour jitter) followed by tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(IMG_SIZE, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

# Validation pipeline: no augmentation, only tensor conversion and the same
# normalization as training.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

# Number of target classes (used in the dataset summary printout).
NUM_CLASSES = 100

In [77]:
# Directory the CIFAR-100 archives are downloaded to (download=True) and read from.
root = "./data"

# train=True selects the training split; train=False is used here as the
# validation split.
train_ds = CIFAR100(
    root=root,
    train=True,
    download=True,
    transform=train_transform,
)
val_ds = CIFAR100(
    root=root,
    train=False,
    download=True,
    transform=val_transform,
)

# Only the training loader shuffles; validation order is kept fixed.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

print(f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}, Classes: {NUM_CLASSES}")


Train samples: 50000, Val samples: 10000, Classes: 100


In [78]:
def get_positional_embeddings(sequence_length, d):
    """Build a sinusoidal positional-encoding table.

    Entry ``[i, j]`` is ``sin(i / 10000 ** (j / d))`` for even ``j`` and
    ``cos(i / 10000 ** ((j - 1) / d))`` for odd ``j``, so each odd column
    shares its frequency with the even column just before it.

    Args:
        sequence_length: Number of positions (rows) in the table.
        d: Embedding dimension (columns). Odd values are supported.

    Returns:
        A tensor of shape ``(sequence_length, d)`` in torch's default dtype.
    """
    # Work in float64 so the values match a double-precision reference
    # before the final cast down to the default dtype.
    positions = torch.arange(sequence_length, dtype=torch.float64).unsqueeze(1)  # (L, 1)
    columns = torch.arange(d)                                                    # (d,)
    is_even = columns % 2 == 0

    # Map every column to the even index that defines its frequency:
    # j -> j for even j, j -> j - 1 for odd j.
    paired = (columns - columns % 2).to(torch.float64)
    denominators = torch.pow(10000.0, paired / d)                                # (d,)

    angles = positions / denominators                                            # (L, d)
    table = torch.where(is_even, torch.sin(angles), torch.cos(angles))
    return table.to(torch.get_default_dtype())


In [79]:
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch tokens plus a class token.

    A strided convolution (kernel = stride = ``patch_size``) projects every
    non-overlapping patch to an ``emb_size`` vector; the resulting grid is
    flattened to a sequence, a class token is prepended, and a positional
    table is added.

    Args:
        in_channels: Number of channels of the input images.
        patch_size: Side length of each square patch.
        emb_size: Dimension of each output token.
        img_size: Side length of the (square) input images; determines the
            number of patches and therefore the positional-table length.
    """

    def __init__(self, in_channels=3, patch_size=16, emb_size=256, img_size=224):
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2

        patch_conv = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)
        # Flatten the (h, w) patch grid into one sequence axis, channels last.
        to_sequence = Rearrange('b e h w -> b (h w) e')
        self.projection = nn.Sequential(patch_conv, to_sequence)

        # Class token, randomly initialised and shared across the batch.
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # One position per patch plus one for the class token. Wrapped in
        # nn.Parameter, so the sinusoidal values are only the initialisation.
        self.position_embeddings = nn.Parameter(get_positional_embeddings(self.num_patches + 1, emb_size))

    def forward(self, x):
        """Embed ``x`` and return tokens of shape ``(batch, num_patches + 1, emb_size)``."""
        batch_size = x.shape[0]
        patch_tokens = self.projection(x)
        # expand() broadcasts the single class token over the batch without copying.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
        return tokens + self.position_embeddings


In [80]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values come from one fused bias-free linear layer.
    Attention weights and the final projection each pass through dropout.

    Args:
        emb_size: Token embedding dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads (positive).
        dropout: Dropout probability for the attention weights and the output.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``emb_size`` evenly.
    """

    def __init__(self, emb_size, num_heads, dropout=0.1):
        super().__init__()
        # Fail early with a clear message; otherwise the mismatch only shows
        # up as an opaque reshape error inside forward().
        if num_heads <= 0 or emb_size % num_heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.head_dim = emb_size // num_heads

        self.qkv = nn.Linear(emb_size, emb_size * 3, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(batch, tokens, emb_size)``.

        Returns a tensor with the same shape as ``x``.
        """
        b, n, e = x.shape
        # (b, n, 3*e) -> (b, n, 3, heads, head_dim) -> (3, b, heads, n, head_dim)
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product scores: (b, heads, n, n), normalised over keys.
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back: (b, heads, n, head_dim) -> (b, n, e).
        out = (attn @ v).transpose(1, 2).reshape(b, n, e)
        out = self.projection(out)
        out = self.proj_drop(out)
        return out


In [81]:
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        emb_size: Input and output feature dimension.
        expansion: Multiplier giving the hidden width (``expansion * emb_size``).
        dropout: Dropout probability applied after the activation and after
            the second linear layer.
    """

    def __init__(self, emb_size, expansion=4, dropout=0.1):
        super().__init__()
        hidden_size = expansion * emb_size
        self.fc1 = nn.Linear(emb_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, emb_size)
        self.gelu = nn.GELU()
        # A single Dropout module is reused at both dropout points.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Transform the last dimension of ``x``; the output shape equals the input shape."""
        hidden = self.dropout(self.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

In [82]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Each sub-layer (self-attention, then MLP) is applied to a layer-normalised
    copy of its input and added back through a residual connection.

    Args:
        emb_size: Token embedding dimension.
        num_heads: Number of attention heads.
        dropout: Dropout probability passed to both sub-layers.
    """

    def __init__(self, emb_size, num_heads, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_size)
        self.attn = MultiHeadAttention(emb_size, num_heads, dropout)
        self.ln2 = nn.LayerNorm(emb_size)
        self.mlp = MLP(emb_size, dropout=dropout)

    def forward(self, x):
        """Run attention and MLP sub-layers with residuals; shape is preserved."""
        attended = self.attn(self.ln1(x))
        x = x + attended
        fed_forward = self.mlp(self.ln2(x))
        return x + fed_forward


In [83]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier built from the blocks defined above.

    Pipeline: patch embedding (with class token and positions) -> stack of
    ``TransformerBlock`` -> LayerNorm on the class token -> dropout -> linear
    classification head.

    Args:
        img_size: Side length of the square input images.
        patch_size: Side length of each square patch.
        num_classes: Number of output logits.
        emb_size: Token embedding dimension.
        num_layers: Number of transformer blocks.
        num_heads: Attention heads per block.
        dropout: Dropout probability inside the transformer blocks.
        in_channels: Number of input image channels (default 3, the value
            previously fixed by ``PatchEmbedding``'s default).
        head_dropout: Dropout probability applied to the class token before
            the head (default 0.2, the value previously hard-coded).
    """

    def __init__(self, img_size=224, patch_size=16, num_classes=37,
                 emb_size=256, num_layers=6, num_heads=8, dropout=0.1,
                 in_channels=3, head_dropout=0.2):
        super().__init__()

        self.patch_embed = PatchEmbedding(
            in_channels=in_channels,
            patch_size=patch_size,
            emb_size=emb_size,
            img_size=img_size
        )

        self.transformer = nn.Sequential(*[
            TransformerBlock(emb_size, num_heads, dropout) for _ in range(num_layers)
        ])

        self.ln = nn.LayerNorm(emb_size)
        self.dropout = nn.Dropout(head_dropout)
        self.head = nn.Linear(emb_size, num_classes)

    def forward(self, x):
        """Classify an image batch and return logits of shape ``(batch, num_classes)``."""
        x = self.patch_embed(x)
        x = self.transformer(x)
        # Only the class token (sequence index 0) feeds the classifier.
        cls_token = x[:, 0]
        cls_token = self.ln(cls_token)
        cls_token = self.dropout(cls_token)
        logits = self.head(cls_token)
        return logits

In [84]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Image size and class count come from the config constants so the model
# cannot silently disagree with the data pipeline.
model = VisionTransformer(
    img_size=IMG_SIZE,
    patch_size=4,  # IMG_SIZE // 4 patches per side
    num_classes=NUM_CLASSES,
    emb_size=256,
    num_layers=6,
    num_heads=8,
    dropout=0.15
).to(device)

In [85]:
# Cross-entropy with label smoothing as the training/validation objective.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.05,
)

# Cosine learning-rate schedule with a 100-step period (stepped once per epoch
# by the training loop).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=100,
)

In [86]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``dataloader``.

    Batches whose loss is NaN or infinite are skipped *before* the backward
    pass, so they neither cost a wasted backward nor contribute to the
    reported metrics. Gradients are clipped to a total norm of 1.0.

    Args:
        model: Module to train; switched to training mode.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable returning a scalar tensor.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over the batches actually used.
        If no batch was used (empty loader or every batch skipped) this is
        ``(nan, 0.0)`` instead of raising ``ZeroDivisionError``.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Skip unusable losses before backward: stepping on NaN/inf
        # gradients would corrupt the weights.
        if not torch.isfinite(loss):
            print("WARNING: non-finite loss detected, skipping batch!")
            continue

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch average is
        # per sample even when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        return float("nan"), 0.0
    return running_loss / total, 100. * correct / total

In [87]:
def eval_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` over ``dataloader`` without computing gradients.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen. For an empty
        dataloader this is ``(nan, 0.0)`` instead of raising
        ``ZeroDivisionError``.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Weight the batch-mean loss by batch size for a per-sample average.
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    if total == 0:
        return float("nan"), 0.0
    return running_loss / total, 100. * correct / total

In [88]:
# Banner marking the start of the run.
print("\n" + "=" * 50)
print("Starting Training")
print("=" * 50)

# Early-stopping state: stop once `patience` consecutive epochs pass without
# a new best validation accuracy.
patience = 5
best_val_acc = 0
patience_counter = 0

for epoch in range(1, 101):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = eval_epoch(model, val_loader, criterion, device)
    scheduler.step()  # one scheduler step per epoch

    if val_acc > best_val_acc:
        # New best: reset the counter and checkpoint the weights.
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), 'best_vit_cifar100.pth')
        marker = " ✓ NEW BEST"
    else:
        patience_counter += 1
        marker = ""

    # Log every one of the first ten epochs, then every fifth epoch.
    if epoch <= 10 or epoch % 5 == 0:
        print(f"Epoch {epoch:3d} | Train: {train_acc:5.2f}% | Val: {val_acc:5.2f}% | "
              f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}{marker}")

    # Keep training while the patience budget is not exhausted.
    if patience_counter < patience:
        continue

    print(f"\n Early stopping triggered after {epoch} epochs")
    print(f"No improvement for {patience} consecutive epochs")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")
    break

# Final summary, printed whether the loop ran to completion or stopped early.
print(f"\n{'='*50}")
print(f"Training Complete! Best Val Accuracy: {best_val_acc:.2f}%")
print(f"{'='*50}")


Starting Training
Epoch   1 | Train:  8.27% | Val: 16.45% | Train Loss: 4.1481 | Val Loss: 3.7151 ✓ NEW BEST
Epoch   2 | Train: 17.55% | Val: 23.57% | Train Loss: 3.6556 | Val Loss: 3.4059 ✓ NEW BEST
Epoch   3 | Train: 22.70% | Val: 27.58% | Train Loss: 3.4313 | Val Loss: 3.2179 ✓ NEW BEST
Epoch   4 | Train: 26.01% | Val: 31.46% | Train Loss: 3.2878 | Val Loss: 3.0772 ✓ NEW BEST
Epoch   5 | Train: 29.08% | Val: 33.57% | Train Loss: 3.1635 | Val Loss: 2.9912 ✓ NEW BEST
Epoch   6 | Train: 31.35% | Val: 35.46% | Train Loss: 3.0659 | Val Loss: 2.8928 ✓ NEW BEST
Epoch   7 | Train: 33.45% | Val: 36.93% | Train Loss: 2.9671 | Val Loss: 2.8355 ✓ NEW BEST
Epoch   8 | Train: 35.92% | Val: 39.69% | Train Loss: 2.8815 | Val Loss: 2.7544 ✓ NEW BEST
Epoch   9 | Train: 38.21% | Val: 41.20% | Train Loss: 2.7960 | Val Loss: 2.6830 ✓ NEW BEST
Epoch  10 | Train: 39.78% | Val: 43.32% | Train Loss: 2.7224 | Val Loss: 2.6080 ✓ NEW BEST
Epoch  15 | Train: 48.49% | Val: 49.15% | Train Loss: 2.4103 | Val Loss