In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


In [9]:
# Load the CIFAR-10 dataset
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader

# Convert images to tensors, then normalise each of the three colour channels
# with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Build the training split first, then the test split; both share the same
# root directory and preprocessing pipeline.
train_data_cifar, test_data_cifar = (
    datasets.CIFAR10('data_cifar', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)



Files already downloaded and verified
Files already downloaded and verified


In [10]:
# Load the MNIST dataset
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler

# Single-channel preprocessing: tensor conversion followed by normalisation
# with mean 0.5 and standard deviation 0.5.
# NOTE: this rebinds the module-level name `transform`; datasets created
# earlier keep the pipeline object they were constructed with.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Training split first, then test split, both stored under 'data_mnist'.
train_data_mnist, test_data_mnist = (
    datasets.MNIST('data_mnist', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)


In [11]:
class General_Learnable_RPE(nn.Module):
    """Learnable relative positional bias derived from pairwise patch distances.

    Each head owns a learnable vector of size ``embed_dim // num_heads``.
    The bias for a pair of patches is the distance between them (scaled by
    ``max_len``) multiplied by every component of the head's vector and summed
    over that vector.

    NOTE(review): the previous implementation repeated ``self.rpe`` to shape
    ``(num_heads * batch, num_patches, 1, head_dim)`` and multiplied it with
    distances of shape ``(batch, num_heads, num_patches, num_patches)``; those
    shapes do not broadcast for general sizes, so ``forward`` raised a
    ``RuntimeError``. The formulation below is the apparent intent — confirm
    that a ``(batch, num_heads, num_patches, num_patches)`` bias is what the
    consuming attention code expects.

    Args:
        embed_dim: Total embedding width of the model.
        num_heads: Number of attention heads.
        max_len: Divisor used to scale the raw distances.
    """

    def __init__(self, embed_dim, num_heads, max_len=512):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.max_len = max_len
        # One learnable vector per head, shape (num_heads, embed_dim // num_heads).
        self.rpe = nn.Parameter(torch.randn(num_heads, embed_dim // num_heads))

    def forward(self, distances):
        """Compute the per-head positional bias.

        Args:
            distances: Tensor of shape ``(batch, num_patches, num_patches)``
                holding pairwise distances.

        Returns:
            Tensor of shape ``(batch, num_heads, num_patches, num_patches)``.

        Raises:
            ValueError: If ``distances`` is not 3-dimensional.
        """
        if distances.dim() != 3:
            raise ValueError(
                f"distances must be 3-D (batch, num_patches, num_patches), "
                f"got {distances.dim()}-D"
            )

        scaled = distances / self.max_len  # (batch, N, N)

        # sum_d(rpe[h, d] * dist) == dist * sum_d(rpe[h, d]); reducing the
        # parameter first avoids materialising a (batch, heads, N, N, head_dim)
        # intermediate.
        head_weight = self.rpe.sum(dim=-1)  # (num_heads,)

        # Broadcast (batch, 1, N, N) against (1, num_heads, 1, 1).
        return scaled.unsqueeze(1) * head_weight.view(1, -1, 1, 1)

In [4]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and embed each patch.

    A single convolution whose kernel size and stride both equal
    ``patch_size`` performs the split and the linear projection in one step.
    ``image_size`` is stored on the module but not used by ``forward``.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.image_size, self.patch_size = image_size, patch_size
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch embeddings with the embedding axis last."""
        # The convolution yields (batch, embed_dim, patch_rows, patch_cols);
        # the permute moves embed_dim to the end, giving
        # (batch, patch_rows, patch_cols, embed_dim).
        return self.projection(x).permute(0, 2, 3, 1)

In [5]:
class TransformerEncoderLayer(nn.Module):
    """One post-norm Transformer encoder block: self-attention then an MLP.

    Inputs and outputs are batch-first, ``(batch, sequence, embed_dim)``,
    matching ``ClassificationHead``, which pools over ``dim=1``.

    Args:
        embed_dim: Width of the token embeddings.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward network.
        dropout: Dropout probability used inside attention and after the MLP.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        # batch_first=True is essential: nn.MultiheadAttention defaults to
        # (sequence, batch, embed) layout, which would make a batch-first
        # input attend across *samples in the batch* instead of across tokens.
        self.attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, sequence, embed_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention with a residual connection, normalised afterwards.
        # The module returns (output, attention_weights); only output is used.
        x_att = self.attention(x, x, x)[0]
        x = self.norm1(x + x_att)

        # Position-wise feed-forward network with its own residual + norm.
        x_mlp = self.mlp(x)
        x = self.norm2(x + x_mlp)
        return x

class TransformerEncoder(nn.Module):
    """A stack of ``num_layers`` identically configured encoder layers.

    The layers are applied one after another; with ``num_layers == 0`` the
    input is returned untouched.
    """

    def __init__(self, num_layers, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                TransformerEncoderLayer(embed_dim, num_heads, mlp_dim, dropout)
            )

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

In [6]:
class ClassificationHead(nn.Module):
    """Average over dimension 1 and map the result to class scores."""

    def __init__(self, embed_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Pool ``x`` over ``dim=1`` and apply the linear classifier."""
        pooled = x.mean(dim=1)  # global average pooling over dimension 1
        return self.fc(pooled)

In [7]:
class ViT(nn.Module):
    """Vision Transformer: patch embedding, position embedding, encoder, head.

    NOTE(review): the previous version instantiated ``PositionalEncoding``,
    which is not defined anywhere in this notebook (NameError on a fresh
    kernel), and handed the 4-D patch grid from ``PatchEmbedding`` straight to
    the encoder. This version flattens the grid into a token sequence and uses
    a learnable absolute position embedding owned by the model itself. If a
    different positional scheme was intended, swap it in here.

    Args:
        image_size: Side length of the input images (forwarded to
            ``PatchEmbedding``).
        patch_size: Side length of each square patch.
        in_channels: Number of image channels.
        embed_dim: Width of the token embeddings.
        num_patches: Number of patches per image; must equal the number of
            patches ``PatchEmbedding`` produces for the actual inputs.
        num_classes: Number of output classes.
        num_layers: Number of encoder layers.
        num_heads: Number of attention heads per layer.
        mlp_dim: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability passed to the encoder.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_patches, num_classes,
                 num_layers, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        # One learnable vector per patch position, broadcast over the batch.
        # Small initial scale so it does not swamp the patch embeddings.
        self.pos_encoding = nn.Parameter(torch.randn(1, num_patches, embed_dim) * 0.02)
        self.transformer_encoder = TransformerEncoder(num_layers, embed_dim, num_heads, mlp_dim, dropout)
        self.classification_head = ClassificationHead(embed_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch of shape ``(batch, in_channels, height, width)``.

        Returns:
            Class scores of shape ``(batch, num_classes)``.

        Raises:
            ValueError: If the images yield a different number of patches than
                the ``num_patches`` the model was built with.
        """
        x = self.patch_embed(x)  # (batch, patch_rows, patch_cols, embed_dim)
        x = x.flatten(1, 2)      # (batch, num_patches, embed_dim)

        expected_patches = self.pos_encoding.size(1)
        if x.size(1) != expected_patches:
            raise ValueError(
                f"input produced {x.size(1)} patches but the model was built "
                f"for num_patches={expected_patches}"
            )

        x = x + self.pos_encoding
        x = self.transformer_encoder(x)
        return self.classification_head(x)

In [8]:
# Example hyperparameters for a ViT configuration.
# The training cell passes its own values to the model rather than reading
# these names.

# Input geometry
image_size, patch_size, in_channels = 224, 16, 3
num_patches = (image_size // patch_size) ** 2  # patches per side, squared

# Model capacity
embed_dim, mlp_dim = 256, 512
num_layers, num_heads = 6, 8
dropout = 0.1

# Output
num_classes = 1000


In [None]:
# --- Training configuration -------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epochs = 10
batch_size = 64
learning_rate = 0.001

# CIFAR-10 images are 32x32 and the transform applies no resizing, so the
# model geometry has to be derived from that size. (The earlier 224 / 16 / 196
# settings describe 224x224 inputs and do not match these tensors.)
cifar_image_size = 32
cifar_patch_size = 4
cifar_num_patches = (cifar_image_size // cifar_patch_size) ** 2  # 8 x 8 = 64

train_loader = DataLoader(train_data_cifar, batch_size=batch_size, shuffle=True)

# --- Model, loss function, and optimizer ------------------------------------
model = ViT(image_size=cifar_image_size, patch_size=cifar_patch_size, in_channels=3, embed_dim=256,
            num_patches=cifar_num_patches, num_classes=10,
            num_layers=6, num_heads=8, mlp_dim=512, dropout=0.1).to(device)
criterion = nn.CrossEntropyLoss()
# `optim` was never imported in this notebook; reach Adam through `torch`.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# --- Training loop ----------------------------------------------------------
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate statistics. The loss is a per-batch mean, so weight it by
        # the batch size to get a correct per-sample average at epoch end
        # (the last batch may be smaller than batch_size).
        running_loss += loss.item() * images.size(0)
        predicted = outputs.argmax(dim=1)
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    epoch_loss = running_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}')
