In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

# Pick the fastest available backend: NVIDIA CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
class PatchEmbedding(nn.Module):
    """Split images into non-overlapping square patches and linearly project each patch.

    Args:
        image_size: side length of the square input image; only used to compute ``num_patches``.
        patch_size: side length of each square patch.
        in_chans: number of input channels.
        embed_dim: size of the embedding produced for every patch.
    """

    def __init__(self, image_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        # Patch count for an image_size x image_size input.
        self.num_patches = (image_size // patch_size) ** 2
        patch_dim = in_chans * patch_size * patch_size
        self.proj = nn.Linear(patch_dim, embed_dim)

    def forward(self, x):
        """Map (B, C, H, W) images to (B, (H/p)*(W/p), embed_dim) patch embeddings."""
        batch, chans, height, width = x.shape
        p = self.patch_size
        assert height % p == 0 and width % p == 0
        # (B, C, H, W) -> (B, C, H/p, W/p, p, p): tile both spatial axes into p x p windows.
        tiles = x.unfold(2, p, p).unfold(3, p, p)
        # -> (B, H/p, W/p, C, p, p): patches in row-major grid order, each flattened as (C, p, p).
        tiles = tiles.permute(0, 2, 3, 1, 4, 5)
        flat = tiles.reshape(batch, -1, chans * p * p)
        return self.proj(flat)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an output projection.

    Computes ``softmax(Q K^T / sqrt(head_dim)) V`` independently per head, concatenates the
    heads and mixes them with a final linear projection (the W^O of "Attention Is All You Need",
    also used by ViT's MSA block).

    Args:
        embed_dim: token embedding size D; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights and to the projected output.
    """

    def __init__(self, embed_dim=768, num_heads=12, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # One fused projection producing queries, keys and values side by side.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.attn_drop = nn.Dropout(dropout)
        # Output projection mixing information across the concatenated heads.
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x):
        """Attend over a (B, N, D) token sequence; returns a tensor of the same shape."""
        B, N, D = x.shape
        H, Hd = self.num_heads, self.head_dim

        # (B, N, 3D) -> three (B, N, D) tensors -> (B, H, N, Hd) each.
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.reshape(B, N, H, Hd).transpose(1, 2)
        k = k.reshape(B, N, H, Hd).transpose(1, 2)
        v = v.reshape(B, N, H, Hd).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, N, N)
        attn = attn.softmax(dim=-1)
        # Previously constructed but never used, which silently ignored `dropout`.
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, N, D)  # concatenate heads -> (B, N, D)
        out = self.proj(out)
        return self.proj_drop(out)

In [None]:
class MLP(nn.Module):
    """Token-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        embed_dim: input and output feature size.
        mlp_ratio: hidden layer width as a multiple of ``embed_dim``.
        dropout: dropout probability applied after the activation and after the output layer.
    """

    def __init__(self, embed_dim=768, mlp_ratio=4, dropout=0.0):
        super().__init__()
        hidden_dim = embed_dim * mlp_ratio
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        # A single Dropout module is reused at both positions (it holds no per-call state).
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the two-layer MLP along the last dimension."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``h = x + Attn(LN1(x))`` followed by ``h + MLP(LN2(h))``.

    Args:
        embed_dim: token embedding width.
        num_heads: attention heads in the self-attention sublayer.
        mlp_ratio: hidden-width multiplier of the MLP sublayer.
        dropout: dropout probability forwarded to both sublayers.
    """

    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim=embed_dim, mlp_ratio=mlp_ratio, dropout=dropout)

    def forward(self, x):
        """Apply the attention and MLP residual sublayers to a token sequence."""
        h = x + self.attn(self.norm1(x))
        return h + self.mlp(self.norm2(h))

In [None]:
class ViT(nn.Module):
    """Vision Transformer image classifier.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add learned position
    embeddings -> dropout -> ``depth`` transformer blocks -> LayerNorm -> linear head applied
    to the [CLS] token.

    Args:
        image_size: side length of the square input image.
        n_classes: number of output logits.
        depth: number of stacked transformer blocks.
        embed_dim: token embedding width.
        patch_size: side length of each square patch.
        num_heads: attention heads per block.
        mlp_ratio: hidden-width multiplier of each block's MLP.
        dropout: dropout probability after the position embedding and inside every block.
    """

    def __init__(self, image_size=224, n_classes=1000, depth=12, embed_dim=768, patch_size=16, num_heads=12, mlp_ratio=4, dropout=0.1):
        super().__init__()

        self.patch_embed = PatchEmbedding(image_size=image_size, patch_size=patch_size, embed_dim=embed_dim)
        self.num_patches = self.patch_embed.num_patches
        n_tokens = self.num_patches + 1  # every patch plus the [CLS] token

        # Learnable [CLS] token and position table, both initialised with std 0.02 noise.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
        self.pos_embed = nn.Parameter(torch.randn(1, n_tokens, embed_dim) * 0.02)
        self.pos_dropout = nn.Dropout(dropout)

        self.blocks = nn.Sequential(*(
            TransformerBlock(embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, dropout=dropout)
            for _ in range(depth)
        ))
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        """Return class logits of shape (B, n_classes).

        The input must yield exactly ``num_patches`` patches (e.g. image_size x image_size),
        because ``pos_embed`` has a fixed token count.
        """
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embed
        tokens = self.blocks(self.pos_dropout(tokens))
        cls_out = self.norm(tokens)[:, 0]
        return self.head(cls_out)

In [None]:
# Seed torch's RNG so parameter initialisation (randn-based tokens/embeddings and the default
# nn.Linear init) is reproducible on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
# Module.to returns the module itself; assigning the result also avoids echoing the huge repr.
model = ViT().to(device)

In [None]:
# Count parameters: total number of scalar entries over every parameter tensor of the model.
n_params = sum(map(torch.numel, model.parameters()))
print(f"Model parameters: {n_params:,}")

In [None]:
# Smoke test: one forward pass on a random batch of two 3x224x224 images.
# eval() disables dropout; no_grad() avoids building an autograd graph that is never used.
dummy_input = torch.randn(2, 3, 224, 224).to(device)
model.eval()
with torch.no_grad():
    output = model(dummy_input)
assert output.shape[0] == dummy_input.shape[0]
print(f"Output shape: {output.shape}")

In [None]:
# ImageNet per-channel normalization statistics (applied after ToTensor).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _to_normalized_tensor():
    """Fresh ToTensor + Normalize steps shared by both pipelines (new instances per call)."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]


# Training: random-resized 224x224 crop plus random horizontal flip for augmentation.
transform_train = transforms.Compose(
    [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()] + _to_normalized_tensor()
)

# Evaluation: deterministic resize to 256, then a 224x224 center crop.
transform_test = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224)] + _to_normalized_tensor()
)

In [None]:
from functools import partial

from datasets import load_dataset

# Stream ImageNet-1k from the Hugging Face Hub. The dataset is gated, so authenticate first:
#   huggingface-cli login
train_dataset = load_dataset("ILSVRC/imagenet-1k", split="train", streaming=True)
val_dataset = load_dataset("ILSVRC/imagenet-1k", split="validation", streaming=True)

# Without shuffling, a streamed dataset is read in the same storage order every epoch.
# shuffle() randomizes the shard order and samples examples from a SHUFFLE_BUFFER-sized buffer.
SHUFFLE_SEED = 0
SHUFFLE_BUFFER = 1_000
train_dataset = train_dataset.shuffle(seed=SHUFFLE_SEED, buffer_size=SHUFFLE_BUFFER)


def transform_example(example, transform):
    """Turn one streamed example into ``{"image": tensor, "label": label}``.

    ``convert("RGB")`` forces three channels, so non-RGB source images (e.g. grayscale) match
    the 3-channel normalization inside ``transform``.
    """
    image = example["image"].convert("RGB")
    return {"image": transform(image), "label": example["label"]}


# functools.partial instead of a lambda: a partial pickles by reference to transform_example,
# while a lambda cannot be pickled at all. That matters when DataLoader workers use the "spawn"
# start method. Caveat: functions defined inside a notebook may still not be importable by spawned workers.
train_dataset = train_dataset.map(partial(transform_example, transform=transform_train))
val_dataset = val_dataset.map(partial(transform_example, transform=transform_test))


def collate_fn(batch):
    """Stack a list of transformed examples into an ``(images, labels)`` tensor pair."""
    images = torch.stack([x["image"] for x in batch])
    labels = torch.tensor([x["label"] for x in batch])
    return images, labels


batch_size = 64
NUM_WORKERS = 4
PIN_MEMORY = device == "cuda"  # page-locked host buffers speed up host->GPU copies
train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn,
                         num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print("Streaming ImageNet from Hugging Face")
print("Number of classes: 1000")

In [None]:
def train_epoch(model, train_loader, optimizer, device, steps_per_epoch=20018):
    """Run one training epoch and return ``(mean loss, accuracy %)``.

    Args:
        model: classifier expected to return logits of shape (batch, n_classes).
        train_loader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to.
        steps_per_epoch: stop after this many batches (needed for streaming loaders with no
            natural end); ``None`` consumes the whole loader.

    Returns:
        Tuple of (per-sample mean cross-entropy loss, top-1 accuracy in percent) over the
        batches processed.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()  # built once rather than once per step
    total_loss = 0.0
    correct = 0
    total = 0
    for step, (images, labels) in enumerate(train_loader, start=1):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n = labels.size(0)
        # `loss` is a batch mean; weight it by batch size so a short last batch doesn't skew the average.
        total_loss += loss.item() * n
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += n

        if steps_per_epoch is not None and step >= steps_per_epoch:
            break

    if total == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / total, 100 * correct / total

In [None]:
def evaluate(model, test_loader, device, steps=782):
    """Evaluate ``model`` without gradients and return ``(mean loss, accuracy %)``.

    Args:
        model: classifier expected to return logits of shape (batch, n_classes).
        test_loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to.
        steps: stop after this many batches (needed for streaming loaders); ``None`` consumes
            the whole loader.

    Returns:
        Tuple of (per-sample mean cross-entropy loss, top-1 accuracy in percent) over the
        batches processed.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()  # built once rather than once per step
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for step, (images, labels) in enumerate(test_loader, start=1):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            n = labels.size(0)
            # Weight the batch-mean loss by batch size so a short last batch doesn't skew the average.
            test_loss += criterion(outputs, labels).item() * n
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += n

            if steps is not None and step >= steps:
                break

    if total == 0:
        raise ValueError("test_loader yielded no batches")
    return test_loss / total, 100 * correct / total

In [None]:
import time

n_epochs = 90
train_losses, test_losses = [], []
train_accs, test_accs = [], []

# ImageNet: 1,281,167 train / 50,000 val images.
# With batch_size=64: 1,281,167 // 64 = 20018 full train steps; ceil(50,000 / 64) = 782 val steps.
TRAIN_STEPS = 20018
VAL_STEPS = 782

# NOTE: re-running this cell keeps the current model weights but builds a fresh optimizer and
# scheduler, i.e. it continues training rather than restarting from scratch.
# NOTE(review): lr=1e-3 at batch size 64 with no warmup is aggressive for a ViT trained from
# scratch; consider a linear warmup and/or a smaller, batch-scaled learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
# Cosine decay over n_epochs; stepped once per epoch below.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

for epoch in range(n_epochs):
    # Give a shuffled streaming dataset a different order each epoch (it only affects the
    # order when the dataset was shuffled); the hasattr guard keeps plain datasets working.
    if hasattr(train_dataset, "set_epoch"):
        train_dataset.set_epoch(epoch)

    start_time = time.perf_counter()
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, device, steps_per_epoch=TRAIN_STEPS)
    test_loss, test_acc = evaluate(model, test_loader, device, steps=VAL_STEPS)
    elapsed_time = time.perf_counter() - start_time
    scheduler.step()

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    test_losses.append(test_loss)
    test_accs.append(test_acc)
    print(f"Epoch {epoch+1:2d}/{n_epochs} | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
          f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}% | "
          f"Elapsed Time: {elapsed_time:.2f}s")

In [None]:
import matplotlib.pyplot as plt

# 1-based epoch numbers, matching the "Epoch k/n" numbering printed during training.
epochs = range(1, len(train_losses) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# The evaluation data is the ImageNet validation split, so both panels label it "Validation".
ax1.plot(epochs, train_losses, label="Train")
ax1.plot(epochs, test_losses, label="Validation")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title("Training and Validation Loss")
ax1.legend()

ax2.plot(epochs, train_accs, label="Train")
ax2.plot(epochs, test_accs, label="Validation")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Accuracy (%)")
ax2.set_title("Training and Validation Accuracy")
ax2.legend()

plt.tight_layout()
plt.show()