In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

step1: convert image to embedding vector sequence

In [2]:
def image2emb_naive(image, patch_size, weight):
    # image size: bs * channel * h * w
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)
    patch_embedding = patch @ weight
    return patch_embedding

def image2emb_conv(image, kernel, stride):
    conv_output = F.conv2d(image, kernel, stride=stride)  # bs * oc * oh * ow
    bs, oc, oh, ow = conv_output.shape
    patch_embedding = conv_output.reshape(bs, oc, oh * ow).transpose(-1, -2)
    return patch_embedding

定义常量

In [3]:
bs, ic, image_h, image_w = 1, 3, 8, 8
patch_size = 4
model_dim = 8
max_num_token = 16
num_classes = 10
label = torch.randint(10, (bs, ))

In [4]:
# test code for image2emb
image = torch.randn(bs, ic, image_h, image_w)

patch_depth = patch_size * patch_size * ic
weight = torch.randn(patch_depth, model_dim)  # model_dim是输出通道数目 patch_size是卷积核的面积乘以输入通道数目
print(weight.shape)

image2emb_naive(image, patch_size, weight)

torch.Size([48, 8])


tensor([[[  4.0152,   1.2908,  -0.6719,   4.3084,   8.2290,   7.2539,  -0.7970,
            0.5109],
         [ -3.3614,  -5.3569,  -0.3916,   6.0683,   4.3965, -13.0564,   7.3109,
          -11.3383],
         [ -9.2427,  -5.5954,   2.7279,   4.6565,  -6.5556,   4.1334,   1.7786,
            4.3553],
         [  8.1741,   0.5207,  -2.3741,  11.8238,  -6.7830,   2.6585,  -0.2237,
           13.3919]]])

In [5]:
kernel = weight.transpose(0, 1).reshape(-1, ic, patch_size, patch_size) # oc * ic * kh * kw
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)
patch_embedding_conv

tensor([[[  4.0152,   1.2908,  -0.6719,   4.3084,   8.2290,   7.2539,  -0.7970,
            0.5109],
         [ -3.3614,  -5.3569,  -0.3916,   6.0683,   4.3965, -13.0564,   7.3109,
          -11.3383],
         [ -9.2427,  -5.5954,   2.7279,   4.6565,  -6.5556,   4.1334,   1.7786,
            4.3553],
         [  8.1741,   0.5207,  -2.3741,  11.8238,  -6.7830,   2.6585,  -0.2237,
           13.3919]]])

step2: prepend CLS token embedding

In [6]:
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)
token_embedding

tensor([[[ -0.2484,   1.4410,   0.1289,  -0.1829,   1.1918,  -0.3986,  -0.3584,
            1.8053],
         [  4.0152,   1.2908,  -0.6719,   4.3084,   8.2290,   7.2539,  -0.7970,
            0.5109],
         [ -3.3614,  -5.3569,  -0.3916,   6.0683,   4.3965, -13.0564,   7.3109,
          -11.3383],
         [ -9.2427,  -5.5954,   2.7279,   4.6565,  -6.5556,   4.1334,   1.7786,
            4.3553],
         [  8.1741,   0.5207,  -2.3741,  11.8238,  -6.7830,   2.6585,  -0.2237,
           13.3919]]], grad_fn=<CatBackward0>)

step3: add position embedding

In [7]:
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])

token_embedding += position_embedding
token_embedding

tensor([[[  0.9585,   4.2727,  -0.4708,   0.5940,   1.1600,  -0.3619,  -1.0689,
            0.5173],
         [  6.4486,   1.2571,   0.5272,   4.1966,   8.2855,   6.6908,  -0.7029,
            1.8371],
         [ -3.8875,  -6.2062,  -0.1072,   6.3478,   2.7142, -11.3760,   8.5360,
          -10.9376],
         [ -9.0394,  -6.0852,   1.3953,   3.9682,  -8.1176,   3.6017,   1.8497,
            5.1600],
         [  9.4518,   0.7019,  -2.4703,  11.5458,  -7.5297,   2.9255,  -0.3044,
           13.1077]]], grad_fn=<AddBackward0>)

step4: pass embedding to transformer encoder

In [8]:
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)
encoder_output

tensor([[[-0.8060, -1.7401,  1.5659, -0.1192,  1.2076, -0.4879,  0.3993,
          -0.0196],
         [-1.0634, -1.1550, -0.4834, -0.2341,  2.2232, -0.0312,  0.2568,
           0.4870],
         [-0.5729,  0.5682, -0.2679,  1.7192, -1.7254, -0.8179,  0.3561,
           0.7406],
         [-0.7607,  0.3504,  1.0961,  1.4312, -1.3287, -1.4237,  0.1522,
           0.4831],
         [-0.6054,  0.4418,  0.0546,  1.5891, -1.8297, -0.6505, -0.0654,
           1.0653]]], grad_fn=<NativeLayerNormBackward0>)

step5: do classification

In [9]:
cls_token_output = encoder_output[:, 0, :]
linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)

print(logits)
print(label)
print(loss)

tensor([[-0.3001, -0.6782, -0.6632, -0.0236, -1.0029, -0.1697,  0.6739, -0.6440,
          0.3561,  0.6160]], grad_fn=<AddmmBackward0>)
tensor([6])
tensor(1.6000, grad_fn=<NllLossBackward0>)
