In [2]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from tqdm import tqdm

In [3]:
# Patch Embedding Layer
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=28, patch_size=7, in_channels=1, embed_dim=64):
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        
    def forward(self, x):
        patches = self.proj(x).flatten(2).transpose(1, 2)  # Flatten patches
        return patches

In [4]:
# Vision Transformer (ViT) Model
class MyViT(nn.Module):
    def __init__(self, chw, n_patches=7, hidden_d=64, n_heads=8, out_d=10):
        super(MyViT, self).__init__()
        self.chw = chw
        self.n_patches = n_patches
        self.hidden_d = hidden_d
        self.n_heads = n_heads
        self.out_d = out_d
        
        # Patch Embedding Layer
        self.patch_embed = PatchEmbedding(img_size=chw[1], patch_size=chw[1] // n_patches, in_channels=chw[0], embed_dim=hidden_d)
        
        # Transformer Encoder Blocks
        self.encoder_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=hidden_d, nhead=n_heads) for _ in range(2)
        ])
        
        # Classification head
        self.classification_head = nn.Linear(hidden_d, out_d)
        
    def forward(self, images):
        patches = self.patch_embed(images)
        
        # Transformer Encoder Blocks
        for block in self.encoder_blocks:
            patches = block(patches)
        
        cls_token = patches[:, 0]  # Extract classification token
        output = self.classification_head(cls_token)
        
        return output

In [5]:
# Training
def train(model, train_loader, optimizer, criterion, device):
    model.train()
    train_loss = 0.0
    for images, labels in tqdm(train_loader, desc="Training", leave=False):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * images.size(0)
    
    return train_loss / len(train_loader.dataset)

In [6]:
# Prepare dataset and dataloaders
transform = ToTensor()
train_set = MNIST(root='./data', train=True, download=True, transform=transform)
test_set = MNIST(root='./data', train=False, download=True, transform=transform)

batch_size = 128
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Model Initialization
model = MyViT(chw=(1, 28, 28), n_patches=7).to(device)
optimizer = Adam(model.parameters(), lr=0.001)
criterion = CrossEntropyLoss()

num_epochs = 5

for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    print(f"Epoch [{epoch + 1}/{num_epochs}] - Train Loss: {train_loss:.4f}")

                                                                                                                       

Epoch [1/5] - Train Loss: 2.3144


Training:   6%|████                                                                   | 27/469 [00:17<04:40,  1.58it/s]

In [None]:
# Test loop
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing", leave=False):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Accuracy on test set: {accuracy:.2f}%")
