The Vision Transformer (ViT) uses only the encoder part of the Transformer and is framed as an image classification task.

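Applying a Transformer directly to raw pixels has the following problems:

1. The computational cost is very high. Self-attention compares every token with every other token, so its cost grows quadratically with the sequence length, and an image contains a very large number of pixels.
2. A single pixel carries far less information than a single word does.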
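To make the cost concrete, the sketch below counts tokens and attention-score entries for one image. It assumes a 224×224 input (a common ImageNet resolution, not specified in this notebook) and the 16×16 patch size from the paper's title. Since every token attends to every token, the score matrix of one head has (sequence length)² entries.

```python
def attention_cost(image_h, image_w, patch):
    """Return (sequence length, attention-matrix entries per head) when every
    patch x patch block becomes one token (patch=1 means one token per pixel)."""
    assert image_h % patch == 0 and image_w % patch == 0, "image must tile exactly"
    seq_len = (image_h // patch) * (image_w // patch)
    return seq_len, seq_len * seq_len


# Assumed example sizes: a 224x224 image, pixel tokens vs 16x16 patch tokens
pixel_len, pixel_attn = attention_cost(224, 224, patch=1)
patch_len, patch_attn = attention_cost(224, 224, patch=16)
print(pixel_len, pixel_attn)  # 50176 2517630976
print(patch_len, patch_attn)  # 196 38416
```

Grouping pixels into 16×16 patches shortens the sequence 256-fold, which shrinks the attention matrix by a factor of 256² = 65536.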

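ViT splits the image into patches and feeds each patch into the Transformer as one token. The patch embeddings can be produced in two equivalent ways:

1. Cut the image into patches (image2patch), then apply an affine transformation to each flattened patch to get its embedding (patch2embedding).
2. Treat patch extraction as a convolution whose kernel size equals its stride, then flatten the spatial dimensions of the convolution output to get the token embeddings.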
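The image2patch step does not strictly need a dedicated API. Below is a sketch using only `reshape` and `permute`. The helper `image2patch_reshape` is hypothetical and not part of the original notebook. It yields the same patches, in the same order, as `torch.nn.functional.unfold` with `kernel_size == stride`, which the notebook uses below.

```python
import torch
import torch.nn.functional as F


def image2patch_reshape(image, patch_size):
    """Hypothetical helper: split (bs, c, h, w) images into non-overlapping patches.
    Returns (bs, num_patches, c * patch_size * patch_size), patches in row-major order."""
    bs, c, h, w = image.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "h and w must be multiples of patch_size"
    x = image.reshape(bs, c, h // p, p, w // p, p)  # split h and w into (blocks, p)
    x = x.permute(0, 2, 4, 1, 3, 5)                 # (bs, nh, nw, c, p, p)
    return x.reshape(bs, (h // p) * (w // p), c * p * p)


image = torch.randn(2, 3, 8, 8)
patches = image2patch_reshape(image, 4)
reference = F.unfold(image, kernel_size=4, stride=4).transpose(-1, -2)
print(patches.shape)                      # torch.Size([2, 4, 48])
print(torch.equal(patches, reference))   # True
```

Each flattened patch is ordered channel first, then row, then column. That is also the layout of a `conv2d` weight of shape `(oc, ic, kh, kw)`, and it is why the two views in the list above compute the same thing.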

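For classification, ViT borrows the class token placeholder from BERT: one extra trainable embedding is prepended to the start of the sequence. The class token can be thought of as playing the role of a query. Through self-attention it gathers from the patch tokens the information the model needs to classify the image, and its final encoder output is what the classifier sees.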

ViT also uses position embeddings, which are added to the token embeddings.


[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)


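However, ViT needs to be pre-trained on a large dataset. Images may be somewhat larger or smaller, but the patch size stays the same, so different image sizes give sequences of different lengths. This is why the position-embedding table below is allocated for `max_num_token` tokens and then sliced to the actual sequence length.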
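A learned position-embedding table trained for one patch grid has the wrong number of rows for another grid. For example, 8×8 images with 4×4 patches give 4 patch tokens, while 16×16 images give 16. The ViT paper handles this during fine-tuning by 2D-interpolating the pre-trained position embeddings according to their location in the image. Below is a minimal sketch. It assumes a square patch grid and a class-token row at index 0. `resize_pos_embedding` is a hypothetical helper, and bilinear mode is our choice.

```python
import torch
import torch.nn.functional as F


def resize_pos_embedding(pos_emb, old_grid, new_grid):
    """Hypothetical helper. pos_emb: (1 + old_grid**2, dim), row 0 is the class token.
    Returns (1 + new_grid**2, dim) with the patch rows bilinearly resized as a 2D grid."""
    cls_pos, patch_pos = pos_emb[:1], pos_emb[1:]
    dim = pos_emb.shape[-1]
    # (old_grid*old_grid, dim) -> (1, dim, old_grid, old_grid): treat it like an image
    grid = patch_pos.reshape(old_grid, old_grid, dim).permute(2, 0, 1).unsqueeze(0)
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bilinear", align_corners=False)
    patch_pos = grid.squeeze(0).permute(1, 2, 0).reshape(new_grid * new_grid, dim)
    return torch.cat([cls_pos, patch_pos], dim=0)  # class-token row is kept unchanged


patch_size, dim = 4, 8
pos_emb = torch.randn(1 + (8 // patch_size) ** 2, dim)  # table for 8x8 images: 4 patches
new_pos = resize_pos_embedding(pos_emb, old_grid=8 // patch_size, new_grid=16 // patch_size)
print(pos_emb.shape, new_pos.shape)  # torch.Size([5, 8]) torch.Size([17, 8])
```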

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## convert image to embedding vector sequence

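The first step is to convert the `image` into a sequence of `embedding`s. This can be implemented in two ways: by cutting the image into patches directly, or by using a convolution.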

For the cutting approach, we can use `torch.nn.functional.unfold` to extract the regions a convolution would slide over. Its signature is:


```python
torch.nn.functional.unfold(input, kernel_size, dilation=1, padding=0, stride=1)
```
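With `kernel_size` equal to `stride` the blocks do not overlap, so each output column is exactly one flattened patch. `unfold` returns shape `(batch, C × kh × kw, L)`, where `L` is the number of blocks, which is why `image2emb_naive` below transposes the last two dimensions before multiplying by `weight`. A small check on a toy input (the sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

x = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
blocks = F.unfold(x, kernel_size=4, stride=4)
# (batch, channels * kh * kw, number of blocks) = (2, 3*4*4, (8/4)*(8/4))
print(blocks.shape)  # torch.Size([2, 48, 4])
# Column 0 is the top-left 4x4 block of every channel, flattened channel by channel
print(torch.equal(blocks[0, :, 0], x[0, :, :4, :4].reshape(-1)))  # True
```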

In [2]:
def image2emb_naive(image, patch_size, weight):
    # image shape: (batch_size, channel, h, w)
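    # unfold -> (bs, patch_depth, num_patches); transpose -> (bs, num_patches, patch_depth)
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    # linear map of each flattened patch -> (bs, num_patches, model_dim)
    patch_embedding = patch @ weight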
    return patch_embedding

Set some hyperparameters:

In [3]:
bs, ic, image_h, image_w = 1, 3, 8, 8
patch_size = 4
model_dim = 8
patch_depth = patch_size * patch_size * ic  # number of values in each patch after splitting (pixels x channels)

In [4]:
max_num_token = 16  # rows in the position-embedding table; must be >= num_patches + 1 (class token)

We also need a `weight`: the matrix that maps each flattened patch to its embedding (patch2embedding), so that the output dimension is `model_dim`.

In [5]:
weight = torch.randn(patch_depth, model_dim)

In [6]:
image = torch.randn(bs, ic, image_h, image_w)
patch_embedding_naive = image2emb_naive(image, patch_size, weight)
print(patch_embedding_naive.shape)

torch.Size([1, 4, 8])


Next, we implement the patch extraction with a convolution instead.

The number of output channels is the embedding size, and `oh * ow` is the sequence length:

In [7]:
def image2emb_conv(image, kernel, stride):
    conv_output = F.conv2d(image, kernel, stride=stride)  # bs * oc * oh * ow
    bs, oc, oh, ow = conv_output.shape
    # flatten the oh x ow grid into a sequence: (bs, oh * ow, oc)
    patch_embedding = conv_output.reshape((bs, oc, oh * ow)).transpose(-1, -2)
    return patch_embedding

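In `weight`, `model_dim` corresponds to the number of output channels, and `patch_depth` equals the kernel area times the number of input channels. We can therefore reshape it into a convolution kernel of shape `oc * ic * kh * kw`. This works because `unfold` flattens each patch channel first, then row, then column, which matches the layout of a `conv2d` weight.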

In [8]:
kernel = weight.transpose(0, 1).reshape((-1, ic, patch_size, patch_size))
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)

print(patch_embedding_naive)
print(patch_embedding_conv)
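The two printed tensors agree, but comparing them by eye is fragile. Below is a sketch of a programmatic check on CPU. This time it uses the `nn.Conv2d` module, the form typically seen in ViT code, with its weight copied from `weight`. We set `bias=False` so both paths compute exactly the same linear map. The sizes are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

bs, ic, h, w, p, model_dim = 2, 3, 8, 8, 4, 8
image = torch.randn(bs, ic, h, w)
weight = torch.randn(ic * p * p, model_dim)

# Cutting view: unfold into flattened patches, then a matrix multiply
naive = F.unfold(image, kernel_size=p, stride=p).transpose(-1, -2) @ weight

# Convolution view: an nn.Conv2d whose kernel is the same weight, reshaped to (oc, ic, kh, kw)
conv = nn.Conv2d(ic, model_dim, kernel_size=p, stride=p, bias=False)
with torch.no_grad():
    conv.weight.copy_(weight.t().reshape(model_dim, ic, p, p))
conv_out = conv(image).flatten(2).transpose(-1, -2)  # (bs, oh*ow, model_dim)

print(torch.allclose(naive, conv_out, atol=1e-5))  # True
```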

tensor([[[ -0.5641,  -0.7897,   1.6222,  -5.9904,   6.5429,   0.4905,  -2.6371,
            9.2284],
         [-16.0236,  -9.0467,  -0.8464,  -2.0507,  -3.1349,  -3.3175,   5.4241,
           -0.7327],
         [ -2.3553,   1.5403,  -3.5653,   5.9545,   0.6324,  -1.4762,  -2.7189,
           -2.2232],
         [ 12.6431,   4.5958,  -1.0657,  -7.8709,   0.7856,  13.6237,  -0.6142,
            1.6147]]])
tensor([[[ -0.5641,  -0.7897,   1.6222,  -5.9904,   6.5429,   0.4905,  -2.6371,
            9.2284],
         [-16.0236,  -9.0467,  -0.8464,  -2.0507,  -3.1349,  -3.3175,   5.4241,
           -0.7327],
         [ -2.3553,   1.5403,  -3.5653,   5.9545,   0.6324,  -1.4762,  -2.7189,
           -2.2232],
         [ 12.6431,   4.5958,  -1.0657,  -7.8709,   0.7856,  13.6237,  -0.6142,
            1.6147]]])


## class token embedding

In [9]:
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)
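Here `cls_token_embedding` is created with a batch dimension of `bs`, which gives each image in the batch its own token. In a model, the class token is normally a single learnable vector shared by all images. Below is a sketch, assuming an `nn.Module` wrapper (`ClassToken` is a hypothetical name) that stores one parameter and expands it over the batch:

```python
import torch
import torch.nn as nn


class ClassToken(nn.Module):
    """Hypothetical module: prepend one shared learnable class token to each sequence."""

    def __init__(self, model_dim):
        super().__init__()
        # A single (1, 1, model_dim) parameter shared by all images
        self.cls = nn.Parameter(torch.randn(1, 1, model_dim))

    def forward(self, patch_embedding):
        bs = patch_embedding.shape[0]
        cls = self.cls.expand(bs, -1, -1)  # a view, no copy: (bs, 1, model_dim)
        return torch.cat([cls, patch_embedding], dim=1)


tokens = ClassToken(8)(torch.randn(3, 4, 8))
print(tokens.shape)  # torch.Size([3, 5, 8])
```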

## add position embedding

In [10]:
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)

In [11]:
seq_len = token_embedding.shape[1]
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])

In [12]:
token_embedding += position_embedding
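`torch.tile` works, but PyTorch broadcasting already repeats a `(seq_len, model_dim)` slice across the batch dimension. Below is a sketch that compares the two on arbitrary sizes. The out-of-place `+` also avoids modifying `token_embedding` in place.

```python
import torch

bs, seq_len, model_dim, max_num_token = 2, 5, 8, 16
token_embedding = torch.randn(bs, seq_len, model_dim)
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)

# (seq_len, model_dim) broadcasts against (bs, seq_len, model_dim), so no tile is needed
with_tile = token_embedding + torch.tile(position_embedding_table[:seq_len], [bs, 1, 1])
with_broadcast = token_embedding + position_embedding_table[:seq_len]
print(torch.equal(with_tile, with_broadcast))  # True
```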

## pass embedding to transformer encoder

In [13]:
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)  # input is (bs, seq_len, model_dim)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)
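`nn.TransformerEncoderLayer` defaults to `batch_first=False`, which means it reads its input as `(seq_len, batch, model_dim)`. A `(bs, seq_len, model_dim)` tensor with `bs = 1` would then be treated as a batch of length-1 sequences, and the tokens would never attend to each other. Below is a sketch that checks this by perturbing only the last token and asking whether the first token's output changes. The sizes match this notebook, and eval mode turns off dropout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
tokens = torch.randn(1, 5, 8)  # (bs, seq_len, model_dim), as in this notebook
perturbed = tokens.clone()
perturbed[0, 4] += 1.0         # change only the last token

changed = {}
for batch_first in (False, True):
    layer = nn.TransformerEncoderLayer(d_model=8, nhead=8, batch_first=batch_first).eval()
    with torch.no_grad():
        out_a, out_b = layer(tokens), layer(perturbed)
    # Token 0's output can depend on token 4 only if attention runs over dim 1
    changed[batch_first] = not torch.allclose(out_a[0, 0], out_b[0, 0])
print(changed)  # expected: {False: False, True: True}
```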

## do classification

In [14]:
num_classes = 10
label = torch.randint(num_classes, (bs,))

In [15]:
cls_token_output = encoder_output[:, 0, :]

linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)

loss_fn = nn.CrossEntropyLoss()

loss = loss_fn(logits, label)
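Below is a sketch of the whole pipeline as one `nn.Module`, with the steps in the same order as above. `TinyViT` is a hypothetical name, and the defaults reuse this notebook's toy hyperparameters. This is not the paper's exact architecture. For example, `nn.TransformerEncoderLayer` applies LayerNorm after each sub-layer by default, whereas the paper applies it before each block (`norm_first=True` in PyTorch).

```python
import torch
import torch.nn as nn


class TinyViT(nn.Module):
    """Hypothetical minimal ViT-style classifier: conv patch embedding -> class token ->
    learned position embedding -> nn.TransformerEncoder -> linear head on the class token."""

    def __init__(self, image_size=8, patch_size=4, in_channels=3,
                 model_dim=8, nhead=8, num_layers=6, num_classes=10):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(in_channels, model_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.randn(1, 1, model_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, model_dim))
        layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = nn.Linear(model_dim, num_classes)

    def forward(self, image):
        x = self.patch_embed(image).flatten(2).transpose(1, 2)  # (bs, num_patches, model_dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)         # (bs, 1, model_dim)
        x = torch.cat([cls, x], dim=1) + self.pos_embed         # (bs, num_patches + 1, model_dim)
        x = self.encoder(x)
        return self.head(x[:, 0])                               # logits from the class token


model = TinyViT()
image = torch.randn(2, 3, 8, 8)
label = torch.randint(10, (2,))
logits = model(image)
loss = nn.CrossEntropyLoss()(logits, label)
loss.backward()
print(logits.shape)  # torch.Size([2, 10])
```

Keeping the class token and position table as `nn.Parameter`s means that `model.parameters()` exposes them to an optimizer together with the convolution, encoder and head weights.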