# A PyTorch Implementation of [Vision Transformer](https://arxiv.org/pdf/2010.11929)

Vision Transformer (ViT) extracts patches from an image and feeds them, as a sequence of tokens, into a Transformer encoder to obtain a global representation, which is finally passed to a classification head.

In [None]:
import copy, math
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from typing import Tuple, Optional, Callable

In [None]:
def clones(module: nn.Module, N: int) -> nn.ModuleList:
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

# PatchEmbedding

The standard Transformer receives as input a 1D sequence of token embeddings. To handle 2D images, ViT reshapes the image $x \in \mathbb{R}^{H \times W \times C}$ into a sequence of flattened 2D patches $x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}$, where $(H, W)$ is the resolution of the original image, $C$ is the number of channels, $(P, P)$ is the resolution of each image patch, and $N = HW / P^2$ is the resulting number of patches, which also serves as the effective input sequence length for the Transformer. The Transformer uses a constant latent vector size $D$ (`d_model` in the code) through all of its layers, so ViT flattens the patches and maps them to $D$ dimensions with a trainable linear projection. The output of this projection is called the patch embeddings.
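
To make the reshaping concrete, here is a minimal sketch of the flattening described above, written directly with tensor reshapes. The helper name `patchify` is hypothetical (it is not part of this notebook), and it assumes square patches with $H$ and $W$ divisible by $P$; each patch is flattened in $(P, P, C)$ order to match $x_p \in \mathbb{R}^{N \times (P^2 \cdot C)}$.

```python
import torch


def patchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split [b, C, H, W] images into N = HW / P^2 flattened patches of size P^2 * C."""
    b, c, h, w = x.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "H and W must be divisible by the patch size"
    # [b, C, H/P, P, W/P, P]: separate the patch-grid index from the in-patch offset
    x = x.reshape(b, c, h // p, p, w // p, p)
    # [b, H/P, W/P, P, P, C]: grid indices first, then pixels (and channels) within a patch
    x = x.permute(0, 2, 4, 3, 5, 1)
    # [b, N, P^2 * C]: patches in row-major grid order
    return x.reshape(b, (h // p) * (w // p), p * p * c)


# The notebook's training setup: 96x96 single-channel images, 16x16 patches
images = torch.randn(2, 1, 96, 96)
patches = patchify(images, 16)
print(patches.shape)  # torch.Size([2, 36, 256]): N = 96*96/16^2 = 36, P^2*C = 256
```

The patches come out in the same row-major grid order as `Conv2d(...).flatten(2)` in `PatchEmbedding`; the order of values inside each patch differs from the convolution's weight layout, which does not matter because the projection is learned.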

Splitting an image into patches and linearly projecting the flattened patches can be expressed as a single convolution whose kernel size and stride are both set to the patch size $P$, so that the patches do not overlap.
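
A quick numerical check of this equivalence, as a sketch with arbitrary sizes: the convolution's weight, reshaped to `[D, C*P*P]`, acts as the linear projection applied to patches extracted with `F.unfold`, which flattens each patch in channel-major `(C, P, P)` order and therefore matches the weight layout.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
P, C, D = 16, 3, 64
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
x = torch.randn(2, C, 64, 64)

# Path 1: convolution, then flatten the spatial grid -> [b, N, D]
out_conv = conv(x).flatten(2).transpose(1, 2)

# Path 2: explicit non-overlapping patches -> [b, N, C*P*P], then a linear map
patches = F.unfold(x, kernel_size=P, stride=P).transpose(1, 2)
weight = conv.weight.reshape(D, -1)  # [D, C*P*P]
out_linear = patches @ weight.T + conv.bias

print(out_conv.shape)  # torch.Size([2, 16, 64]): a 4x4 grid of patches
print(torch.allclose(out_conv, out_linear, atol=1e-4))
```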

ViT prepends a learnable embedding (`cls_token` in the following code snippet) to the sequence of embedded patches; its state at the output of the Transformer encoder serves as the image representation.
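
A small sketch of the prepend step, using the same `expand` + `torch.cat` pattern as `PatchEmbedding.forward` below (sizes are arbitrary, and `cls_token` starts at zeros here only for readability). `expand` creates a view rather than a copy, so every image in the batch sees the same parameter, and its gradient accumulates contributions from the whole batch.

```python
import torch
from torch import nn

b, n_p, d = 4, 36, 8
cls_token = nn.Parameter(torch.zeros(1, 1, d))
patch_embeddings = torch.randn(b, n_p, d)

# [1, 1, d] -> [b, 1, d] view (no copy), then prepend along the sequence axis
tokens = torch.cat([cls_token.expand(b, -1, -1), patch_embeddings], dim=1)
print(tokens.shape)  # torch.Size([4, 37, 8])

# Only position 0 of each sequence depends on cls_token; its gradient sums over the batch
tokens[:, 0].sum().backward()
print(cls_token.grad.flatten())  # every entry equals the batch size, 4
```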

In [None]:
class PatchEmbedding(nn.Module):
    def __init__(self,
                 patch_size: int = 16,
                 channels_in: int = 3,
                 d_model: int = 512) -> None:
        super(PatchEmbedding, self).__init__()
        # self.num_patches = (img_size // patch_size) ** 2
        self.conv = nn.Conv2d(channels_in, d_model, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        # shape of x: [b, c, h, w]
        # shape of output: [b, d, h // p, w // p] -> [b, d, n_p] -> [b, n_p, d] -> [b, n_p + 1, d]
        x = self.conv(x).flatten(2).transpose(1, 2)
        return torch.cat([self.cls_token.expand(x.size(0), -1, -1), x], dim=1)

Position embeddings are added to the patch embeddings to retain positional information. ViT uses standard learnable 1D position embeddings.

When fine-tuning on images of higher resolution than those used in pre-training, ViT keeps the patch size the same, which results in a longer effective sequence length. ViT can handle arbitrary sequence lengths (up to memory constraints); however, the pre-trained position embeddings may then no longer be meaningful. ViT therefore performs 2D interpolation of the pre-trained position embeddings, according to their location in the original image.
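
A minimal sketch of that interpolation for this notebook's layout (the `[class]` token first, then a square grid of patch positions in row-major order). The function name `resize_pos_embedding` is hypothetical, and bicubic interpolation via `F.interpolate` is an assumption; the official implementation may use a different resampling scheme. Because `PositionalEmbedding` below stores a fixed number of positions, feeding this notebook's model a different resolution would require a resize like this one.

```python
import math
import torch
import torch.nn.functional as F


def resize_pos_embedding(pos: torch.Tensor, new_grid: int) -> torch.Tensor:
    """Resize [1, 1 + g*g, d] position embeddings to [1, 1 + new_grid**2, d]."""
    cls_pos, grid_pos = pos[:, :1], pos[:, 1:]
    g = math.isqrt(grid_pos.size(1))
    assert g * g == grid_pos.size(1), "expected a square patch grid"
    d = grid_pos.size(-1)
    # [1, g*g, d] -> [1, d, g, g]: treat the grid as a d-channel image
    grid_pos = grid_pos.reshape(1, g, g, d).permute(0, 3, 1, 2)
    grid_pos = F.interpolate(grid_pos, size=(new_grid, new_grid),
                             mode="bicubic", align_corners=False)
    # back to [1, new_grid**2, d] in row-major patch order
    grid_pos = grid_pos.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, d)
    # the [class] token has no spatial location, so it is kept unchanged
    return torch.cat([cls_pos, grid_pos], dim=1)


# Trained at 96x96 with 16x16 patches (6x6 grid); fine-tune at 192x192 (12x12 grid)
pos = torch.randn(1, 1 + 6 * 6, 512)
new_pos = resize_pos_embedding(pos, 12)
print(new_pos.shape)  # torch.Size([1, 145, 512])
```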

In [None]:
class PositionalEmbedding(nn.Module):
    def __init__(self,
                 num_steps: int,
                 d_model: int = 512,
                 dropout: float = 0.1) -> None:
        super(PositionalEmbedding, self).__init__()
        self.pos_embedding = nn.Parameter(torch.randn(1, num_steps, d_model))
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        return self.dropout(x + self.pos_embedding)

# Scaled Dot-Product Attention and Multi-Head Attention
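
For reference, the cell below implements the standard formulas from the original Transformer, with $D$ = `d_model`, $h$ heads and $d_k = D / h$. The first three entries of `self.linears` compute $QW_i^Q$, $KW_i^K$ and $VW_i^V$ for all heads at once (the per-head split is the `view`/`transpose`), and the last entry is $W^O$:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\,W^O,
\qquad \mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$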

In [None]:
def subsequent_mask(size: int) -> torch.Tensor:
    # causal mask from the original (decoder) Transformer; ViT's encoder does not use it (mask=None)
    attn_shape = (1, size, size)
    return torch.triu(torch.ones(attn_shape), 1).type(torch.uint8) == 0

def attention(query: torch.Tensor,
              key: torch.Tensor,
              value: torch.Tensor,
              mask: Optional[torch.Tensor] = None,
              dropout: Optional[nn.Dropout] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    # shape of query: [b, nq, d_k] or [b, h, nq, d_k]
    # shape of key: [b, n, d_k] or [b, h, n, d_k]
    # shape of value: [b, n, d_v] or [b, h, n, d_v]
    d_k = query.size(-1)
    # shape of scores: [b, nq, n] or [b, h, nq, n]
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # shape of mask: [b, 1 or nq, n] or [b, 1, 1 or nq, n]
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    # shape of output: [b, nq, d_v] or [b, h, nq, d_v]
    return torch.matmul(p_attn, value), p_attn

class MultiHeadAttention(nn.Module):
    def __init__(self,
                 h: int,
                 d_model: int,
                 dropout: float = 0.1) -> None:
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)
        
    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # add a dimension for Multi-Head attention
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        query, key, value = [lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                             for lin, x in zip(self.linears, (query, key, value))]
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
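
As a sanity check, the scaled dot-product formula can be compared against PyTorch's fused `F.scaled_dot_product_attention` (available in PyTorch 2.0 and later). The sketch re-states the computation as a standalone function, `naive_attention` (a hypothetical name mirroring `attention` above, without dropout), so that it runs on its own. The mask convention matches: in `attention`, nonzero/`True` entries may be attended to, which is also what a boolean `attn_mask` means for the fused function.

```python
import math
import torch
import torch.nn.functional as F


def naive_attention(q, k, v, mask=None):
    # softmax(Q K^T / sqrt(d_k)) V, with masked-out scores pushed to -1e9
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    return p_attn @ v, p_attn


torch.manual_seed(0)
b, h, n, d_k = 2, 8, 37, 64
q, k, v = (torch.randn(b, h, n, d_k) for _ in range(3))
# Boolean key mask broadcast over heads and queries: hide the last 5 tokens
mask = torch.ones(b, 1, 1, n, dtype=torch.bool)
mask[..., -5:] = False

out, p_attn = naive_attention(q, k, v, mask)
ref = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(torch.allclose(out, ref, atol=1e-5))
print(p_attn[..., -5:].abs().max())  # masked keys receive zero attention weight
```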

# Add & Norm

In [None]:
class SublayerConnection(nn.Module):
    def __init__(self,
                 size: int,
                 dropout: float) -> None:
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self,
                x: torch.Tensor,
                sublayer: Callable) -> torch.Tensor:
        return x + self.dropout(sublayer(self.norm(x)))
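
`SublayerConnection` computes $x + \mathrm{Dropout}(\mathrm{Sublayer}(\mathrm{LN}(x)))$, so LayerNorm is applied *before* each sublayer ("pre-norm"). This is what ViT does: the paper applies LayerNorm before every block and a residual connection after every block, whereas the original Transformer normalized after the residual addition ("post-norm"). The sketch below contrasts the two, with dropout omitted and illustrative function names. One visible consequence: with pre-norm the residual path is an identity, so a sublayer that outputs zeros leaves $x$ untouched.

```python
import torch
from torch import nn

torch.manual_seed(0)
d = 16
norm = nn.LayerNorm(d)
x = torch.randn(2, 5, d) * 3 + 1  # deliberately not normalized


def pre_norm(x, sublayer):
    # ViT / this notebook: x + Sublayer(LN(x))
    return x + sublayer(norm(x))


def post_norm(x, sublayer):
    # original Transformer: LN(x + Sublayer(x))
    return norm(x + sublayer(x))


def zero_sublayer(t):
    return torch.zeros_like(t)


# pre-norm: the residual stream passes through unchanged
print(torch.equal(pre_norm(x, zero_sublayer), x))
# post-norm: the output is re-normalized per token (mean close to 0)
print(post_norm(x, zero_sublayer).mean(-1).abs().max())
```

The `Encoder` cell below applies one extra LayerNorm after the last layer, which pre-norm stacks need because the residual stream itself is never normalized inside the layers.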

# Position-wise FFN

In [None]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self,
                 d_model: int,
                 d_ff: int,
                 dropout: float=0.1) -> None:
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        return self.w_2(self.dropout(F.gelu(self.w_1(x))))
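
"Position-wise" means the same two-layer MLP, $\mathrm{FFN}(x) = W_2\,\mathrm{GELU}(W_1 x + b_1) + b_2$, is applied to every token independently; tokens only exchange information in the attention sublayer. A small check of that property, as a sketch with arbitrary sizes and dropout omitted:

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff = 32, 128
w_1, w_2 = nn.Linear(d_model, d_ff), nn.Linear(d_ff, d_model)


def ffn(x):
    # W_2 GELU(W_1 x + b_1) + b_2, acting on the last dimension only
    return w_2(F.gelu(w_1(x)))


x = torch.randn(2, 37, d_model)
whole = ffn(x)
# Apply the FFN to one token at a time and stack the results along the sequence axis
per_token = torch.stack([ffn(x[:, i]) for i in range(x.size(1))], dim=1)
print(torch.allclose(whole, per_token, atol=1e-5))
```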

# Encoder

In [None]:
class EncoderLayer(nn.Module):
    def __init__(self,
                 size: int,
                 self_attn: MultiHeadAttention,
                 feed_forward: PositionwiseFeedForward,
                 dropout: float) -> None:
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size
        
    def forward(self,
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)

In [None]:
class Encoder(nn.Module):
    def __init__(self,
                 layer: EncoderLayer,
                 N: int) -> None:
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = nn.LayerNorm(layer.size)
        
    def forward(self,
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)
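
`Encoder` builds its `N` layers with `clones`, which deep-copies the prototype layer, so every layer (and each of the two `SublayerConnection`s inside `EncoderLayer`) owns its own parameters instead of sharing one set. A small sketch of that behavior, repeating the one-line `clones` helper so the block runs standalone:

```python
import copy
import torch
from torch import nn


def clones(module: nn.Module, N: int) -> nn.ModuleList:
    # same helper as at the top of the notebook
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


layers = clones(nn.Linear(4, 4), 3)
# each copy owns a distinct parameter tensor ...
print(len({id(layer.weight) for layer in layers}))  # 3
# ... starting from the same values, but updated independently
with torch.no_grad():
    layers[0].weight.add_(1.0)
print(torch.equal(layers[0].weight, layers[1].weight))  # False
print(sum(p.numel() for p in layers.parameters()))  # 3 * (4*4 + 4) = 60
```

In `make_model`, the `xavier_uniform_` loop then re-initializes every weight matrix separately, so the cloned layers do not even start from identical weights.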

# Vision Transformer

In [None]:
class Generator(nn.Module):
    def __init__(self,
                 d_model: int,
                 num_classes: int) -> None:
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, num_classes)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

In [None]:
class ViT(nn.Module):
    def __init__(self,
                 encoder: Encoder,
                 embed: nn.Module,
                 generator: Generator) -> None:
        super(ViT, self).__init__()
        self.embed = embed
        self.encoder = encoder
        self.generator = generator

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        # the encoder output at index 0 (the [class] token) is the image representation
        return self.generator(self.encode(x)[:, 0])

    def encode(self,
               x: torch.Tensor) -> torch.Tensor:
        # output shape: [b, n + 1, d]
        return self.encoder(self.embed(x), mask=None)

In [None]:
def make_model(img_size: int=96,
               patch_size: int=16,
               num_channels: int=3,
               d_model: int=512,
               num_classes: int=10,
               h: int=8,
               N: int=6,
               d_ff: int=2048,
               dropout: float=0.1) -> ViT:
    attn = MultiHeadAttention(h, d_model, dropout)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    p = PositionalEmbedding((img_size // patch_size) ** 2 + 1, d_model, dropout)
    model = ViT(Encoder(EncoderLayer(d_model, attn, ff, dropout), N),
                nn.Sequential(PatchEmbedding(patch_size, num_channels, d_model), p),
                Generator(d_model, num_classes))
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
    return model
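
As a worked example, the parameter count of the model built in the Train section below (`img_size=96`, `patch_size=16`, one input channel, `d_model=512`, `d_ff=2048`, `N=2`, 10 classes) can be tallied by hand. The helper `count_vit_params` is hypothetical and simply mirrors the layers defined above; its result should agree with `sum(p.numel() for p in model.parameters())`.

```python
def count_vit_params(img_size=96, patch_size=16, channels=3, d_model=512,
                     d_ff=2048, N=6, num_classes=10):
    num_patches = (img_size // patch_size) ** 2
    patch_embed = channels * patch_size ** 2 * d_model + d_model  # Conv2d weight + bias
    cls_token = d_model
    pos_embed = (num_patches + 1) * d_model                      # +1 for the [class] token
    attn = 4 * (d_model * d_model + d_model)                     # Q, K, V and output projections
    ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)   # w_1 and w_2 with biases
    norms = 2 * (2 * d_model)                                    # two LayerNorms (weight + bias)
    per_layer = attn + ffn + norms
    encoder = N * per_layer + 2 * d_model                        # plus the final LayerNorm
    head = d_model * num_classes + num_classes                   # Generator
    return patch_embed + cls_token + pos_embed + encoder + head


print(count_vit_params(channels=1, N=2))  # 6461962
```

Almost all of the roughly 6.5M parameters sit in the two encoder layers (about 3.15M each), with the FFN alone accounting for about two thirds of every layer.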

# Train

In [13]:
batch_size = 128
img_size = 96
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor()
])
data_train = datasets.FashionMNIST(root="../data", train=True, download=True, transform=transform)
data_val = datasets.FashionMNIST(root="../data", train=False, download=True, transform=transform)
loader_train = DataLoader(data_train, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(data_val, batch_size=batch_size, shuffle=False)

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
patch_size = 16
d_model, d_ff, h, N = 512, 2048, 8, 2
dropout = 0.1
num_classes = 10
model = make_model(img_size, patch_size, 1, d_model, num_classes, h, N, d_ff, dropout).to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
loss = F.cross_entropy

In [16]:
max_epochs = 50
for epoch in range(max_epochs):
    model.train()
    train_loss = train_count = train_acc = 0
    for i, (x, y) in enumerate(loader_train):
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        l = loss(y_pred, y, reduction="sum")
        optim.zero_grad()
        l.backward()
        optim.step()
        with torch.no_grad():
            train_loss += l.item()
            train_count += y.size(0)
            train_acc += (y_pred.argmax(1) == y).sum().item()
    model.eval()
    val_loss = val_count = val_acc = 0
    with torch.no_grad():
        for x, y in loader_val:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            l = loss(y_pred, y, reduction="sum")
            val_loss += l.item()
            val_count += y.size(0)
            val_acc += (y_pred.argmax(1) == y).sum().item()
    print(f"Epoch: {epoch + 1:03d}, Train Loss: {train_loss / train_count:.4f}, Train Acc: {train_acc / train_count:.4f}, Val Loss: {val_loss / val_count:.4f}, Val Acc: {val_acc / val_count:.4f}")

Epoch: 001, Train Loss: 0.0715, Train Acc: 0.9727, Val Loss: 0.5317, Val Acc: 0.8864
Epoch: 002, Train Loss: 0.0698, Train Acc: 0.9737, Val Loss: 0.4823, Val Acc: 0.8867
Epoch: 003, Train Loss: 0.0691, Train Acc: 0.9733, Val Loss: 0.4955, Val Acc: 0.8841
Epoch: 004, Train Loss: 0.0655, Train Acc: 0.9756, Val Loss: 0.5458, Val Acc: 0.8887
Epoch: 005, Train Loss: 0.0653, Train Acc: 0.9755, Val Loss: 0.5375, Val Acc: 0.8890
Epoch: 006, Train Loss: 0.0616, Train Acc: 0.9776, Val Loss: 0.5268, Val Acc: 0.8831
Epoch: 007, Train Loss: 0.0592, Train Acc: 0.9784, Val Loss: 0.5192, Val Acc: 0.8875
Epoch: 008, Train Loss: 0.0547, Train Acc: 0.9797, Val Loss: 0.5741, Val Acc: 0.8850
Epoch: 009, Train Loss: 0.0536, Train Acc: 0.9799, Val Loss: 0.5692, Val Acc: 0.8895
Epoch: 010, Train Loss: 0.0528, Train Acc: 0.9803, Val Loss: 0.5776, Val Acc: 0.8864
Epoch: 011, Train Loss: 0.0566, Train Acc: 0.9793, Val Loss: 0.5722, Val Acc: 0.8854
Epoch: 012, Train Loss: 0.0529, Train Acc: 0.9800, Val Loss: 0.56
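
The validation pass in the training loop can be factored into a reusable function. A minimal sketch; `evaluate` is a hypothetical helper, and the toy model and data in the usage lines exist only to make the block runnable on its own. Like the loop above, it sums per-sample losses with `reduction="sum"` and divides by the sample count, so the result is a true per-sample mean even when the last batch is smaller.

```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device):
    """Return (mean cross-entropy loss, accuracy) over every sample in `loader`."""
    model.eval()
    total_loss = total_correct = total_count = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        total_loss += F.cross_entropy(logits, y, reduction="sum").item()
        total_correct += (logits.argmax(1) == y).sum().item()
        total_count += y.size(0)
    return total_loss / total_count, total_correct / total_count


# Toy usage: the "model" passes logits through unchanged, and labels are their argmax
logits = torch.randn(100, 10)
labels = logits.argmax(1)
loader = DataLoader(TensorDataset(logits, labels), batch_size=32)
val_loss, val_acc = evaluate(nn.Identity(), loader, torch.device("cpu"))
print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")  # accuracy is 1.0 here
```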

In [None]:
torch.save(model.state_dict(), "model.pth")
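
To reuse the saved weights later, rebuild the model with the same hyperparameters and load the state dict. A minimal sketch of the round trip, with a small stand-in module and a temporary path instead of the notebook's ViT and `model.pth`; `map_location` lets a checkpoint saved on a GPU be loaded on a CPU-only machine.

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)
trained = nn.Linear(8, 3)  # stand-in for the trained ViT
path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(trained.state_dict(), path)

# Later: construct the same architecture, then copy the saved tensors into it
restored = nn.Linear(8, 3)
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # disable dropout before inference

x = torch.randn(4, 8)
print(torch.equal(trained(x), restored(x)))
```

For this notebook's model, the same pattern would be `model = make_model(img_size, patch_size, 1, d_model, num_classes, h, N, d_ff, dropout)` followed by `model.load_state_dict(torch.load("model.pth", map_location=device))`.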

## Reference
1. [Official implementation](https://github.com/google-research/vision_transformer)
2. [d2l](https://d2l.ai)