In [1]:
# Libraries + Config

import time
import random
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as T
from tqdm.auto import tqdm


# -------------------------
# Config
# -------------------------

@dataclass
class Config:
    """Hyper-parameters for the CIFAR-10 ViT experiment defined in this file.

    Values are validated on construction so that an inconsistent setting
    fails immediately with a descriptive ``ValueError`` instead of surfacing
    later as a bare assertion or a shape error during model construction.

    Raises:
        ValueError: if ``patch_size`` is not in ``[1, img_size]``, if
            ``num_heads`` is not positive or does not divide ``d_model``,
            or if ``dropout`` lies outside ``[0, 1]``.
    """

    # Data
    img_size: int = 32        # input image side length in pixels
    patch_size: int = 4       # side length of each square patch
    in_chans: int = 3         # number of input image channels
    num_classes: int = 10     # size of the classification head output

    # ViT
    d_model: int = 64         # token embedding width
    depth: int = 6            # number of encoder blocks
    num_heads: int = 4        # attention heads per block
    mlp_dim: int = 256        # hidden width of the feed-forward MLP
    dropout: float = 0.1      # dropout probability used throughout the model

    # Performer
    performer_eps: float = 1e-6   # added to the attention normaliser

    # Train
    batch_size: int = 128
    epochs: int = 10
    lr: float = 3e-4
    weight_decay: float = 0.05
    seed: int = 0
    num_workers: int = 0      # DataLoader worker processes

    def __post_init__(self):
        # Fail fast on settings the model code cannot work with.
        if not 0 < self.patch_size <= self.img_size:
            raise ValueError(
                f"patch_size must be in [1, img_size]; got "
                f"patch_size={self.patch_size}, img_size={self.img_size}"
            )
        if self.num_heads <= 0 or self.d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by a positive "
                f"num_heads ({self.num_heads})"
            )
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1]; got {self.dropout}")


cfg = Config()


def set_seed(seed: int, *, deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: Value passed to every generator.
        deterministic: When ``False`` (the default, matching the previous
            behaviour) cuDNN autotuning stays enabled and deterministic
            kernels are not requested, which favours speed. When ``True``
            the two cuDNN flags are flipped to request deterministic
            kernels and disable autotuning.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # The two cuDNN switches are always set as a pair so they never disagree.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic


# Seed every RNG once, before the datasets, loaders and model are created.
set_seed(cfg.seed)


# -------------------------
# Data + Device
# -------------------------

# Use the GPU when one is visible, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Device:", device)

# Per-channel normalisation statistics shared by both splits
# (presumably the usual CIFAR-10 training-set mean/std — verify if changed).
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2470, 0.2435, 0.2616)


def _build_transform():
    """Return the tensor-conversion + normalisation pipeline."""
    return T.Compose([
        T.ToTensor(),
        T.Normalize(_NORM_MEAN, _NORM_STD),
    ])


def _build_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader using the batch settings from cfg."""
    return DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        num_workers=cfg.num_workers,
    )


# Train and test currently use identical preprocessing (no augmentation).
train_tfms = _build_transform()
test_tfms = _build_transform()

# Both splits are stored under ./data and downloaded on first use.
train_ds = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=train_tfms
)
test_ds = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=test_tfms
)

# Only the training split is shuffled.
train_loader = _build_loader(train_ds, shuffle=True)
test_loader = _build_loader(test_ds, shuffle=False)


# -------------------------
# Model Components
# -------------------------

class PatchEmbed(nn.Module):
    """Turn an image into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every non-overlapping patch to an ``embed_dim`` vector.

    Attributes:
        num_patches: ``(img_size // patch_size) ** 2``, the number of
            patches for a square ``img_size`` input.
        proj: the patch-projection convolution.
    """

    def __init__(self, img_size, patch_size, in_chans, embed_dim):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Project ``x`` and return the patches as a token sequence.

        For a batched image input the conv output (B, embed_dim, H', W')
        is flattened over its spatial dims and transposed to
        (B, H' * W', embed_dim).
        """
        feature_map = self.proj(x)
        tokens = feature_map.flatten(start_dim=2)
        return tokens.transpose(1, 2)


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    The same dropout module is applied after the activation and again
    after the second linear layer.
    """

    def __init__(self, d_model, mlp_dim, dropout):
        super().__init__()
        self.fc1 = nn.Linear(d_model, mlp_dim)   # expand to hidden width
        self.fc2 = nn.Linear(mlp_dim, d_model)   # project back to d_model
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Return the transformed features; the last dim stays ``d_model``."""
        hidden = self.drop(F.gelu(self.fc1(x)))
        return self.drop(self.fc2(hidden))


# -------------------------
# Strict Performer-ReLU
# -------------------------

class PerformerReLUAttention(nn.Module):
    """
    Multi-head self-attention with the feature map phi(x) = ReLU(x).

    Keys and values are contracted over the token axis first, so the
    N x N attention matrix is never materialised:

        out_i = phi(q_i) @ (sum_j phi(k_j)^T v_j)
                / (phi(q_i) . sum_j phi(k_j) + eps)

    Args:
        d_model: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the projected output.
        eps: constant added to the normaliser before dividing.
    """
    def __init__(self, d_model, num_heads, dropout, eps=1e-6):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.eps = eps

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)

    def _split_heads(self, t):
        """Reshape (B, N, D) into per-head form (B, heads, N, d_head)."""
        batch, tokens, _ = t.shape
        return t.view(batch, tokens, self.num_heads, self.d_head).transpose(1, 2)

    def forward(self, x):
        """Apply attention to ``x`` of shape (B, N, D); output has the same shape."""
        batch, tokens, width = x.shape

        # Non-negative feature maps for queries and keys; values stay raw.
        q_phi = F.relu(self._split_heads(self.q_proj(x)))
        k_phi = F.relu(self._split_heads(self.k_proj(x)))
        values = self._split_heads(self.v_proj(x))

        # Numerator: contract keys with values over tokens, then apply queries.
        context = torch.einsum("b h n d, b h n e -> b h d e", k_phi, values)
        weighted = torch.einsum("b h n d, b h d e -> b h n e", q_phi, context)

        # Normaliser: each query's dot product with the summed key features.
        key_total = k_phi.sum(dim=2)
        normaliser = torch.einsum("b h n d, b h d -> b h n", q_phi, key_total)
        normaliser = normaliser.unsqueeze(-1) + self.eps

        # Merge the heads back into the model width and project.
        merged = (weighted / normaliser).transpose(1, 2).contiguous()
        merged = merged.view(batch, tokens, width)
        return self.drop(self.out_proj(merged))


# -------------------------
# Encoder Block
# -------------------------

class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block with a selectable attention layer.

    LayerNorm is applied before the attention and before the MLP, and each
    sub-layer's output is added back to its input.

    Args:
        d_model: token width.
        num_heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward MLP.
        dropout: dropout probability for attention, MLP and residual paths.
        attn_type: ``"regular"`` selects ``nn.MultiheadAttention``; any
            other value selects ``PerformerReLUAttention``.
        performer_eps: normaliser epsilon forwarded to
            ``PerformerReLUAttention``; ignored for ``"regular"``.
    """

    def __init__(self, d_model, num_heads, mlp_dim, dropout, attn_type,
                 performer_eps=1e-6):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)

        self.is_regular = attn_type == "regular"
        if self.is_regular:
            self.attn = nn.MultiheadAttention(
                d_model, num_heads, dropout=dropout, batch_first=True
            )
        else:
            # Forward eps explicitly so the configured value is honoured
            # rather than always falling back to the attention default.
            self.attn = PerformerReLUAttention(
                d_model, num_heads, dropout, eps=performer_eps
            )

        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, mlp_dim, dropout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply attention and MLP sub-layers; output shape equals input shape."""
        h = self.norm1(x)
        if self.is_regular:
            # nn.MultiheadAttention returns (output, weights); weights unused.
            attn_out, _ = self.attn(h, h, h, need_weights=False)
        else:
            attn_out = self.attn(h)
        # NOTE(review): PerformerReLUAttention and MLP already apply dropout
        # to their own outputs, so self.drop is a second dropout on the same
        # tensor for those paths — confirm this is intended.
        x = x + self.drop(attn_out)

        h = self.norm2(x)
        x = x + self.drop(self.mlp(h))
        return x


# -------------------------
# Vision Transformer
# -------------------------

class ViT(nn.Module):
    """Vision Transformer that alternates Performer-ReLU and regular attention.

    Even-indexed encoder blocks use Performer-ReLU attention and odd-indexed
    blocks use regular multi-head attention.

    Args:
        config: object exposing the hyper-parameter attributes read below
            (``img_size``, ``patch_size``, ``in_chans``, ``num_classes``,
            ``d_model``, ``depth``, ``num_heads``, ``mlp_dim``, ``dropout``,
            ``performer_eps``). Defaults to the module-level ``cfg``.
    """

    def __init__(self, config=None):
        super().__init__()
        config = cfg if config is None else config

        self.patch_embed = PatchEmbed(
            config.img_size, config.patch_size, config.in_chans, config.d_model
        )
        num_patches = self.patch_embed.num_patches

        # Learnable class token and position table (+1 slot for the class
        # token); both start at zero.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.d_model))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, config.d_model)
        )
        self.pos_drop = nn.Dropout(config.dropout)

        self.blocks = nn.ModuleList([
            EncoderBlock(
                config.d_model,
                config.num_heads,
                config.mlp_dim,
                config.dropout,
                attn_type="performer_relu" if (i % 2 == 0) else "regular",
                performer_eps=config.performer_eps,
            )
            for i in range(config.depth)
        ])

        self.norm = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.num_classes)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for an image batch."""
        x = self.patch_embed(x)
        B, N, _ = x.shape

        # Prepend the class token to every sequence in the batch.
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed[:, :N + 1]
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        # Classify from the class token, which sits at position 0.
        return self.head(x[:, 0])


# -------------------------
# Train & Eval
# -------------------------

@torch.no_grad()
def evaluate_accuracy(model, loader=None):
    """Return the top-1 accuracy of ``model`` as a fraction in [0, 1].

    Args:
        model: classifier returning per-class scores along dim 1.
        loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            module-level ``test_loader``.

    Returns:
        Number of correct predictions divided by number of labels seen.

    Raises:
        ValueError: if ``loader`` yields no samples.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    if loader is None:
        loader = test_loader

    # Move batches to wherever the model lives; fall back to the global
    # device only for parameter-free models.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device

    model.eval()
    correct = total = 0
    for x, y in loader:
        x, y = x.to(target), y.to(target)
        pred = model(x).argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.numel()

    if total == 0:
        raise ValueError("evaluate_accuracy received a loader with no samples")
    return correct / total


def train():
    """Train a fresh ViT for ``cfg.epochs`` epochs and print test metrics.

    Builds the model on the global ``device``, optimises it with AdamW and
    cross-entropy over ``train_loader``, then prints the test accuracy and
    the wall-clock training time (evaluation is excluded from the timing).
    """
    model = ViT().to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    criterion = nn.CrossEntropyLoss()

    def _wait_for_device():
        # CUDA kernels are launched asynchronously; drain the queue so the
        # timer brackets the work actually executed.
        if device.type == "cuda":
            torch.cuda.synchronize()

    _wait_for_device()
    # perf_counter is monotonic, so clock adjustments cannot skew the result.
    t0 = time.perf_counter()
    for epoch in range(cfg.epochs):
        model.train()
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.epochs}"):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

    _wait_for_device()
    train_time = time.perf_counter() - t0
    acc = evaluate_accuracy(model)

    print(f"Test Accuracy: {acc*100:.2f}%")
    print(f"Train Time: {train_time:.2f}s")


# Run the full experiment: train, then report test accuracy and train time.
train()


Device: cuda


100%|██████████| 170M/170M [00:03<00:00, 45.5MB/s]


Epoch 1/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 2/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 3/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 4/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 5/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 6/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 7/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 8/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 9/10:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 10/10:   0%|          | 0/391 [00:00<?, ?it/s]

Test Accuracy: 62.16%
Train Time: 198.22s
