In [73]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
import torch.optim as optim

In [74]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
batch_size = 64

# CIFAR-10 dataset preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

# Data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [75]:
class MultiheadAttentionEinsum(nn.Module):
    def __init__(self, embedding_dim, num_heads):
        super(MultiheadAttentionEinsum, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads

        assert self.head_dim * num_heads == embedding_dim, "embedding_dim must be divisible by num_heads"

        self.q_linear = nn.Linear(embedding_dim, embedding_dim)
        self.k_linear = nn.Linear(embedding_dim, embedding_dim)
        self.v_linear = nn.Linear(embedding_dim, embedding_dim)
        self.fc_out = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, query, key, value):
        batch_size = query.size(0)

        # Linear projections
        Q = self.q_linear(query)
        K = self.k_linear(key)
        V = self.v_linear(value)

        # Reshape and permute for multi-head attention
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Scaled dot-product attention
        energy = torch.einsum("bnqd,bnkd->bnqk", Q, K)
        scaling_factor = self.head_dim ** 0.5
        scaled_energy = energy / scaling_factor
        attention = F.softmax(scaled_energy, dim=-1)

        # Attention values
        attended_values = torch.einsum("bnqk,bnvd->bnqd", attention, V)

        # Concatenate heads and put through final linear layer
        attended_values = attended_values.permute(0, 2, 1, 3).contiguous()
        attended_values = attended_values.view(batch_size, -1, self.embedding_dim)
        out = self.fc_out(attended_values)

        return out
    
class TransformerEncoderLayer(nn.Module):
    def __init__(self, embedding_dim, num_heads):
        super(TransformerEncoderLayer, self).__init__()
        self.multihead_attention = MultiheadAttentionEinsum(embed_dim=embedding_dim, num_heads=num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(embedding_dim, 2048),
            nn.ReLU(),
            nn.Linear(2048, embedding_dim)
        )
        self.layer_norm1 = nn.LayerNorm(embedding_dim)
        self.layer_norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        residual = x
        x = self.layer_norm1(x)
        x = x.permute(1, 0, 2)  # (seq_len, batch_size, embedding_dim)
        attn_output = self.multihead_attention(x, x, x)[0]  # self-attention
        x = attn_output + residual
        x = x.permute(1, 0, 2)  # (batch_size, seq_len, embedding_dim)

        residual = x
        x = self.layer_norm2(x)
        x = self.feed_forward(x)
        x = x + residual

        return x
    
class VisionTransformer(nn.Module):
    def __init__(self, num_classes, patch_size, embedding_dim, num_heads, num_layers):
        super(VisionTransformer, self).__init__()
        self.patch_embedding = nn.Conv2d(3, embedding_dim, kernel_size=patch_size, stride=patch_size)
        self.positional_encoding = nn.Parameter(torch.randn(1, 14 * 14 + 1, embedding_dim))
        self.transformer_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=num_heads) for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        batch_size = x.size(0)
        x = self.patch_embedding(x)
        x = x.flatten(2).transpose(1, 2)
        x = torch.cat((x, self.positional_encoding.repeat(batch_size, 1, 1)), dim=1)
        for layer in self.transformer_layers:
            x = layer(x)
        x = x.mean(dim=1)
        x = self.fc(x)
        return x

In [76]:
# Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_epochs = 10
learning_rate = 0.001
num_classes = 10
patch_size = 16
embedding_dim = 128
num_heads = 8
num_layers = 3

model = VisionTransformer(num_classes, patch_size, embedding_dim, num_heads, num_layers).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

losses=[]
# Training loop
total_steps = len(train_loader)
for epoch in range(num_epochs):
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{total_steps}], Loss: {loss.item():.4f}')


Epoch [1/10], Step [100/782], Loss: 1.9912
Epoch [1/10], Step [200/782], Loss: 2.1571
Epoch [1/10], Step [300/782], Loss: 2.2335
Epoch [1/10], Step [400/782], Loss: 2.2162
Epoch [1/10], Step [500/782], Loss: 2.1070
Epoch [1/10], Step [600/782], Loss: 2.1525
Epoch [1/10], Step [700/782], Loss: 2.0390
Epoch [2/10], Step [100/782], Loss: 2.2897
Epoch [2/10], Step [200/782], Loss: 2.2710
Epoch [2/10], Step [300/782], Loss: 2.0789
Epoch [2/10], Step [400/782], Loss: 2.3112
Epoch [2/10], Step [500/782], Loss: 2.2923
Epoch [2/10], Step [600/782], Loss: 2.2128
Epoch [2/10], Step [700/782], Loss: 2.1280
Epoch [3/10], Step [100/782], Loss: 2.0410
Epoch [3/10], Step [200/782], Loss: 2.1790
Epoch [3/10], Step [300/782], Loss: 2.0591
Epoch [3/10], Step [400/782], Loss: 2.2117
Epoch [3/10], Step [500/782], Loss: 2.1006
Epoch [3/10], Step [600/782], Loss: 2.1199
Epoch [3/10], Step [700/782], Loss: 2.2086
Epoch [4/10], Step [100/782], Loss: 2.1662
Epoch [4/10], Step [200/782], Loss: 2.1268
Epoch [4/10

In [77]:

# Testing phase
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy of the model on the {total} test images: {accuracy:.2f}%')

Test Accuracy of the model on the 10000 test images: 19.18%
