# 16 - [ViT] An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale

# In [2]:
import torch
import ml_collections

from torch import nn
from torch.nn.modules.utils import _pair

# In [4]:
def get_b16_config():
    """Return the ViT-B/16 hyper-parameters as an ``ml_collections.ConfigDict``.

    "B/16" is the Base-size model operating on 16x16-pixel patches.

    Fields:
        patches.size: patch (height, width) in pixels.
        hidden_size: width of every token embedding.
        transformer.mlp_dim: hidden width of the encoder MLP blocks.
        transformer.num_heads: attention heads per encoder layer.
        transformer.num_layers: number of encoder layers.
        transformer.attention_dropout_rate / transformer.dropout_rate:
            dropout probabilities used by the encoder.
        classifier: 'token' -- presumably the head classifies from the class
            token; confirm against the classification head.
        representation_size: None -- presumably means no extra pre-logits
            layer; confirm against the head.
    """
    # Nested plain dicts are converted to nested ConfigDicts, so attribute
    # access such as cfg.transformer.num_heads works the same as before.
    return ml_collections.ConfigDict({
        'patches': {'size': (16, 16)},
        'hidden_size': 768,
        'transformer': {
            'mlp_dim': 3072,
            'num_heads': 12,
            'num_layers': 12,
            'attention_dropout_rate': 0.0,
            'dropout_rate': 0.1,
        },
        'classifier': 'token',
        'representation_size': None,
    })

# In [None]:
class Embeddings(nn.Module):
    """Patch, class-token and position embeddings for a ViT encoder.

    Each input image is cut into non-overlapping patches by a strided
    convolution, and every patch is projected to ``config.hidden_size``
    channels. A learnable class token is prepended to the patch sequence,
    learnable position embeddings are added, and dropout is applied.

    Args:
        config: model configuration. Reads ``config.patches["size"]``,
            ``config.hidden_size`` and ``config.transformer["dropout_rate"]``.
        img_size: expected input height/width, an int or an ``(h, w)`` pair
            (e.g. 224 -> (224, 224)).
        in_channels: number of channels in the input images.

    Shapes (example: 224x224 input with 16x16 patches -> 14 * 14 = 196 patches):
        input:  ``(batch, in_channels, H, W)``, e.g. (B, 3, 224, 224)
        output: ``(batch, n_patches + 1, hidden_size)``, e.g. (B, 197, hidden_size)

    The input's spatial size must produce the same patch grid as ``img_size``.
    Otherwise adding the position embeddings fails with a shape mismatch.
    """

    def __init__(self, config, img_size, in_channels=3):
        super().__init__()
        image_hw = _pair(img_size)                  # e.g. 224 -> (224, 224)
        patch_hw = _pair(config.patches["size"])    # e.g. (16, 16)
        embed_dim = config.hidden_size

        # Size of the patch grid. Leftover pixels at the bottom/right edges are
        # ignored, which matches what the strided convolution below produces.
        grid_rows = image_hw[0] // patch_hw[0]
        grid_cols = image_hw[1] // patch_hw[1]
        n_patches = grid_rows * grid_cols           # e.g. 14 * 14 = 196

        # kernel_size == stride, so each output position covers exactly one
        # patch: the output is (batch, embed_dim, grid_rows, grid_cols).
        self.patch_embeddings = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_hw,
            stride=patch_hw,
        )

        # One learnable vector per token position (class token + patches):
        # shape (1, n_patches + 1, embed_dim). The leading 1 broadcasts over
        # the batch. It is registered before cls_token to keep the original
        # parameter / state_dict ordering.
        self.position_embeddings = nn.Parameter(
            torch.zeros(1, n_patches + 1, embed_dim)
        )

        # Learnable class token, shape (1, 1, embed_dim), later used for classification.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.dropout = nn.Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """Embed images ``x`` of shape (batch, in_channels, H, W) into tokens."""
        # (B, C, H, W) -> (B, embed_dim, rows, cols) -> (B, embed_dim, n_patches)
        # -> (B, n_patches, embed_dim)
        patch_tokens = self.patch_embeddings(x).flatten(2).transpose(-1, -2)

        # Expand the single class token to one per sample: (B, 1, embed_dim).
        cls = self.cls_token.expand(x.shape[0], -1, -1)

        # Class token goes first: (B, n_patches + 1, embed_dim).
        tokens = torch.cat([cls, patch_tokens], dim=1)

        # position_embeddings broadcasts over the batch dimension.
        return self.dropout(tokens + self.position_embeddings)
