In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed them as tokens.

    Each ``patch_size x patch_size`` patch is linearly projected with a
    strided convolution, a learnable [CLS] token is prepended, and learnable
    position embeddings are added.

    Args:
        img_size: Height/width of the square input image (32 for CIFAR-10).
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        embed_dim: Dimension of every output token.
        init_std: Standard deviation of the truncated-normal initialisation
            used for the [CLS] token and the position embeddings.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256,
                 init_std=0.02):
        super().__init__()
        self.patch_size = patch_size
        # Integer division: remainder pixels are dropped, matching the floor
        # behaviour of the strided convolution below.
        self.n_patches = (img_size // patch_size) ** 2
        # A conv with kernel == stride == patch_size is a linear projection of
        # each flattened patch.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        # Learnable classification token, prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Learnable position embeddings (one per patch plus one for [CLS]).
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))
        # Unit-variance (torch.randn) embeddings can dominate the patch
        # projections at initialisation; ViT reference implementations use a
        # small truncated normal instead. The bounds are absolute values, so
        # they are scaled by init_std to truncate at +/- 2 standard deviations.
        nn.init.trunc_normal_(self.cls_token, std=init_std, a=-2 * init_std, b=2 * init_std)
        nn.init.trunc_normal_(self.pos_embed, std=init_std, a=-2 * init_std, b=2 * init_std)

    def forward(self, x):
        """Embed ``x`` of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, n_patches + 1, embed_dim); index 0 along
            the sequence axis is the [CLS] token.
        """
        batch_size = x.shape[0]
        x = self.proj(x)        # (batch, embed_dim, h', w')
        x = x.flatten(2)        # (batch, embed_dim, n_patches)
        x = x.transpose(1, 2)   # (batch, n_patches, embed_dim)

        # Prepend the (shared) classification token to every sample.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Broadcast-add position embeddings over the batch.
        return x + self.pos_embed

class TransformerEncoder(nn.Module):
    """Pre-norm Transformer encoder block: self-attention then MLP.

    Each sub-layer is applied to a LayerNorm-ed copy of its input and added
    back through a residual connection.

    Args:
        embed_dim: Token dimension (input and output).
        num_heads: Number of attention heads; must divide ``embed_dim``.
        mlp_ratio: Hidden width of the MLP as a multiple of ``embed_dim``
            (may be fractional; the hidden width is truncated to an int).
        drop_rate: Dropout probability used in attention and the MLP.
    """

    def __init__(self, embed_dim=256, num_heads=8, mlp_ratio=4, drop_rate=0.1):
        super().__init__()
        # Multi-head self-attention on (batch, seq, dim) tensors.
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=drop_rate, batch_first=True)
        # int() so a fractional mlp_ratio yields a valid nn.Linear width.
        hidden_dim = int(mlp_ratio * embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_rate),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(drop_rate),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        # Normalise once and reuse the same tensor for query/key/value. This
        # avoids recomputing LayerNorm three times, and passing the identical
        # object lets nn.MultiheadAttention consider its self-attention fast path.
        h = self.norm1(x)
        # The attention weights are unused, so skip computing them.
        attn_out, _ = self.attention(h, h, h, need_weights=False)
        x = x + attn_out
        # MLP sub-layer with residual connection.
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    """Vision Transformer classifier built from patch embeddings and encoder blocks.

    Pipeline: patch embedding (+[CLS], +position) -> ``depth`` encoder
    blocks -> final LayerNorm -> linear head applied to the [CLS] token.

    Args:
        img_size, patch_size, in_channels: Forwarded to ``PatchEmbedding``.
        embed_dim: Token dimension used throughout the network.
        depth: Number of ``TransformerEncoder`` blocks.
        num_heads, mlp_ratio, drop_rate: Forwarded to each encoder block.
        num_classes: Size of the output logit vector.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3,
                 embed_dim=256, depth=6, num_heads=8, mlp_ratio=4,
                 num_classes=10, drop_rate=0.1):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        encoder_stack = [
            TransformerEncoder(embed_dim, num_heads, mlp_ratio, drop_rate)
            for _ in range(depth)
        ]
        self.transformer_blocks = nn.ModuleList(encoder_stack)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for images ``x``."""
        tokens = self.patch_embed(x)
        for encoder_block in self.transformer_blocks:
            tokens = encoder_block(tokens)

        normed_tokens = self.norm(tokens)
        # Only the [CLS] token (sequence position 0) feeds the classifier.
        cls_representation = normed_tokens[:, 0]
        return self.head(cls_representation)

# Training setup
def get_data_loaders(batch_size=128, train_subset_size=400, test_subset_size=100):
    """Build CIFAR-10 train/test loaders over the first N samples of each split.

    Images are converted to tensors and normalised per channel with mean 0.5
    and std 0.5. The dataset is downloaded to ``./data`` if missing.

    Args:
        batch_size: Batch size for both loaders.
        train_subset_size: Number of leading training samples to keep.
        test_subset_size: Number of leading test samples to keep.

    Returns:
        (trainloader, testloader); only the training loader shuffles.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    def _make_loader(is_train, subset_size, shuffle):
        # Load one CIFAR-10 split and wrap its first `subset_size` items.
        full_split = torchvision.datasets.CIFAR10(root='./data', train=is_train,
                                                  download=True, transform=preprocess)
        leading = Subset(full_split, torch.arange(subset_size))
        return DataLoader(leading, batch_size=batch_size,
                          shuffle=shuffle, num_workers=2)

    train_loader = _make_loader(True, train_subset_size, shuffle=True)
    test_loader = _make_loader(False, test_subset_size, shuffle=False)
    return train_loader, test_loader

# Training loop
def train_model(model, trainloader, epochs=10, device='cpu', lr=3e-4,
                weight_decay=0.1, log_interval=100):
    """Train ``model`` in place with AdamW and cross-entropy loss.

    Args:
        model: Classifier whose forward returns raw logits for
            ``nn.CrossEntropyLoss``.
        trainloader: Iterable of ``(inputs, labels)`` batches.
        epochs: Number of passes over ``trainloader``.
        device: Device the model and each batch are moved to.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        log_interval: Print the running mean loss every this many batches;
            a falsy value disables the intra-epoch report.

    Returns:
        list[float]: Sample-weighted mean training loss for each epoch
        (``nan`` for an epoch that saw no batches).
    """
    # Move first so the optimizer is built over the final parameters.
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    epoch_losses = []
    for epoch in tqdm(range(epochs), desc='Epochs'):
        model.train()
        running_loss = 0.0
        loss_sum = 0.0
        n_samples = 0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            # Weight by batch size so a short final batch is not over-counted.
            loss_sum += batch_loss * labels.size(0)
            n_samples += labels.size(0)
            if log_interval and (i + 1) % log_interval == 0:
                print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / log_interval:.3f}')
                running_loss = 0.0

        epoch_loss = loss_sum / n_samples if n_samples else float('nan')
        epoch_losses.append(epoch_loss)
        # Always report once per epoch: loaders with fewer than
        # `log_interval` batches would otherwise print nothing.
        print(f'[{epoch + 1}] epoch loss: {epoch_loss:.3f}')
    return epoch_losses

def evaluate_model(model, testloader, device='cpu'):
    """Compute top-1/top-5 accuracy and mean cross-entropy loss of ``model``.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        testloader: Iterable of ``(images, labels)`` batches.
        device: Device the model and each batch are moved to.

    Returns:
        dict with ``'top1_accuracy'`` and ``'top5_accuracy'`` (percentages)
        and ``'loss'`` (cross-entropy averaged over samples, not batches).
        With fewer than 5 classes, top-5 is computed over all classes.

    Raises:
        ValueError: If ``testloader`` yields no samples.
    """
    model = model.to(device)
    model.eval()
    # Sum per-sample losses so the final average is exact even when the
    # last batch is smaller than the others.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    correct_1 = 0
    correct_5 = 0
    total = 0
    loss_total = 0.0

    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_total += criterion(outputs, labels).item()

            # Top-1 accuracy
            _, predicted = outputs.max(1)
            correct_1 += predicted.eq(labels).sum().item()

            # Top-5 accuracy, vectorised; clamp k so topk cannot exceed the
            # number of classes.
            k = min(5, outputs.size(1))
            _, topk_predicted = outputs.topk(k, 1)
            correct_5 += topk_predicted.eq(labels.unsqueeze(1)).any(dim=1).sum().item()

            total += labels.size(0)

    if total == 0:
        raise ValueError('testloader yielded no samples')

    top1_accuracy = 100.0 * correct_1 / total
    top5_accuracy = 100.0 * correct_5 / total
    avg_loss = loss_total / total

    print(f'Top-1 Accuracy: {top1_accuracy:.2f}%')
    print(f'Top-5 Accuracy: {top5_accuracy:.2f}%')
    print(f'Average Loss: {avg_loss:.4f}')

    return {
        'top1_accuracy': top1_accuracy,
        'top5_accuracy': top5_accuracy,
        'loss': avg_loss
    }


# Usage example
if __name__ == '__main__':
    # Pick the best available accelerator: CUDA, then Apple-silicon MPS,
    # then CPU. Each backend is checked with its own availability query.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

    # Initialize model with smaller parameters for MacBook Pro
    model = VisionTransformer(
        img_size=32,        # CIFAR-10 image size
        patch_size=4,       # 4x4 patches
        in_channels=3,      # RGB images
        embed_dim=256,      # Smaller embedding dimension
        depth=6,            # Fewer transformer layers
        num_heads=8,        # Number of attention heads
        num_classes=10      # CIFAR-10 classes
    )
    print('Initialized model')
    trainloader, testloader = get_data_loaders(batch_size=32)
    print('Loaded data')
    train_model(model, trainloader, epochs=1, device=device)

Initialized model
Files already downloaded and verified
Files already downloaded and verified
Loaded data


Epochs: 100%|██████████| 1/1 [00:11<00:00, 11.78s/it]


In [None]:
# Persist only the learned parameters (state_dict) of the trained model.
torch.save(obj=model.state_dict(), f='vit_cifar10.pth')

In [None]:
# Evaluate on the held-out subset; the bare name below displays the dict.
metrics = evaluate_model(model, testloader, device=device)
metrics

Top-1 Accuracy: 19.00%
Top-5 Accuracy: 69.00%
Average Loss: 2.1644


{'top1_accuracy': 19.0, 'top5_accuracy': 69.0, 'loss': 2.164409816265106}