In [51]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [52]:
# 步骤1：得到图像分块的embedding
def image2embedding_naive(image, patch_size, weight):
    """
    通过分块的方法对图片embedding，将图像按照patch大小进行分块，将每一块编码成一个model_dim长的向量
    :param image: Batch * C * H * W
    :param patch_size: 每一个块的大小
    :param weight: 线性变换的权重，将二维图像映射到embedding的维度
    :return: img_patch_embedding # bs * block_num * model_dim
    """
    # F.unfold 可以取出当前的区域块所有channel flatten后的结果，可以用来分块
    # shape: torch.Size([1, 4, 48])
    patch = F.unfold(image, patch_size, stride=patch_size).transpose(-1, -2)
    # 将patch与weight相乘，使用@
    # torch.Size([1, 4, 48]) @ torch.Size([48 * 8]) = torch.Size([1, 4, 8])
    img_patch_embedding = patch @ weight
    return img_patch_embedding


In [53]:
def image2embedding_conv(img, kernel, stride):
    """
    用卷积的方法实现image embedding，将图像按照patch大小进行分块，将每一块编码成一个model_dim长的向量
    :param img: shape: Batch * C * H * W
    :param kernel: shape: output_channel * input_channel * patch_size * patch_size
    :param stride: patch_size
    :return: img_patch_embedding: shape: # bs * block_num * model_dim
    """
    # 通过卷积得到每一个块的向量，每一个块的embedding矩阵的深度方向
    conv_output = F.conv2d(img, kernel, stride=stride)  # 输出的shape: batch * output_channel * oh * ow
    # 将每个块变成一个向量
    bs, oc, oh, ow = conv_output.shape  # oc就是model_dim
    img_patch_embedding = conv_output.reshape(bs, oc, oh*ow).transpose(-1, -2) # bs * block_num * model_dim
    return img_patch_embedding


In [54]:
# 生成一张图
bs, input_channel, img_h, img_w = 1, 3, 8, 8
img = torch.randn(bs, input_channel, img_h, img_w)
# 定义patch_size, model_dim
patch_size = 4  # 将4*4的区域看作一个块
model_dim = 8
# 定义weight 和 对应的卷积kernel形式
weight = torch.randn(patch_size*patch_size*3, 8)
kernel = weight.transpose(0,1).reshape((model_dim, input_channel, patch_size, patch_size))

# 得到输出
# 维度：bs * block_num * model_dim = 1 * 4 * 8
img_patch_embedding_naive = image2embedding_naive(img, patch_size, weight)
img_patch_embedding_conv = image2embedding_conv(img, kernel, patch_size)

print(img_patch_embedding_naive)
print(img_patch_embedding_conv)

# 输出结果完全一致，得到结论：这两种方式是完全等价的！ naive的实现更容易理解，卷积的实现方式更加优雅

tensor([[[  2.9146,   1.2449,  -8.3008,   1.1948,  -9.3845,   7.7470,   8.4259,
           -3.8728],
         [-11.5311,   0.0789,  15.8432, -11.2865,   7.4376,  -2.6591,  10.3798,
           10.2890],
         [ -4.9996,  -7.7250,  -0.1250,   7.6475,   1.9382,  -2.9877,  -4.3252,
           -2.0763],
         [ -6.4553,  -9.6855,  -1.0354,  -0.4439,  -8.5694,  -7.6878,  17.1478,
           -0.1359]],

        [[ -0.5331,   7.2156,  -7.2066,  -0.4019,   4.8610,   7.1247,  19.2128,
            5.3889],
         [  1.2114, -10.5133,  -6.6776,  -1.9796,  -0.0932,   4.4562,  -0.3438,
           -1.1885],
         [ -0.4752,  -5.7226,  -3.3776,  -5.4405,   1.2669,  -6.6287,   0.9310,
            0.2563],
         [ -0.0282,   3.2105,  -2.6039,   2.1494, -17.4719,   6.6827,   7.2661,
            5.3899]]])
tensor([[[  2.9146,   1.2449,  -8.3008,   1.1948,  -9.3845,   7.7470,   8.4259,
           -3.8728],
         [-11.5311,   0.0789,  15.8432, -11.2865,   7.4376,  -2.6591,  10.3798,
       

In [55]:
# 步骤2：在分块的embedding的向量前面加一个类别编码，随机初始化，可学习的
cls_embedding = torch.randn(bs, 1, model_dim,requires_grad=True)
token_embedding = torch.cat([cls_embedding, img_patch_embedding_naive], dim=1)
print(token_embedding)
# token_embedding shape: # bs * (block_num+1) * model_dim

tensor([[[  0.1408,  -0.4919,   0.9294,   0.9404,   0.5707,   0.9665,   0.3508,
            0.9826],
         [  2.9146,   1.2449,  -8.3008,   1.1948,  -9.3845,   7.7470,   8.4259,
           -3.8728],
         [-11.5311,   0.0789,  15.8432, -11.2865,   7.4376,  -2.6591,  10.3798,
           10.2890],
         [ -4.9996,  -7.7250,  -0.1250,   7.6475,   1.9382,  -2.9877,  -4.3252,
           -2.0763],
         [ -6.4553,  -9.6855,  -1.0354,  -0.4439,  -8.5694,  -7.6878,  17.1478,
           -0.1359]],

        [[  1.6528,   0.6923,  -0.2269,   0.1733,   0.3345,   2.0581,  -0.1078,
            1.5345],
         [ -0.5331,   7.2156,  -7.2066,  -0.4019,   4.8610,   7.1247,  19.2128,
            5.3889],
         [  1.2114, -10.5133,  -6.6776,  -1.9796,  -0.0932,   4.4562,  -0.3438,
           -1.1885],
         [ -0.4752,  -5.7226,  -3.3776,  -5.4405,   1.2669,  -6.6287,   0.9310,
            0.2563],
         [ -0.0282,   3.2105,  -2.6039,   2.1494, -17.4719,   6.6827,   7.2661,
         

In [56]:
# 步骤3：增加位置编码，随机初始化的向量，可学习的，与token_embedding利用广播机制相加
# 作者原文试用了比较多不同的位置编码，发现效果也都差不多
# position_embedding shape: (block_num+1) * model_dim
# token_embedding shape: bs * (block_num+1) * model_dim
position_embedding = torch.randn(token_embedding.shape[1], model_dim)
token_embedding = token_embedding + position_embedding
print(token_embedding)

tensor([[[  0.6907,   0.1893,   0.7013,   0.4093,  -0.2047,   2.6669,   1.0116,
            2.2497],
         [  3.3494,  -1.4127,  -8.1215,   2.8935,  -9.0807,   6.5611,   7.4750,
           -4.8476],
         [-10.9377,  -0.2273,  16.5251, -12.0980,   7.8397,  -3.7294,  10.8651,
           11.1782],
         [ -2.8673, -10.2491,  -0.6166,   8.4247,   1.5956,  -3.5139,  -3.0859,
           -3.6573],
         [ -6.0559,  -9.7175,  -2.1906,  -1.0624,  -7.7265,  -7.6621,  17.3032,
           -0.9626]],

        [[  2.2027,   1.3735,  -0.4550,  -0.3577,  -0.4409,   3.7584,   0.5530,
            2.8016],
         [ -0.0984,   4.5580,  -7.0272,   1.2967,   5.1649,   5.9389,  18.2619,
            4.4142],
         [  1.8047, -10.8194,  -5.9957,  -2.7912,   0.3089,   3.3859,   0.1415,
           -0.2993],
         [  1.6570,  -8.2466,  -3.8692,  -4.6633,   0.9242,  -7.1550,   2.1702,
           -1.3247],
         [  0.3712,   3.1785,  -3.7592,   1.5308, -16.6290,   6.7085,   7.4216,
         

In [57]:
# 步骤4：将token_embedding送入transformer的encoder api
# encoder_output shape: bs * (block_num+1) * model_dim
encoder_layer = nn.TransformerEncoderLayer(d_model=8, nhead=4)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)

print(encoder_output)

tensor([[[-0.6900, -0.0034,  1.2284,  1.9044, -0.6718, -1.0218, -0.9272,
           0.1814],
         [ 1.4902,  0.2536, -1.4676, -0.1845, -0.0246,  1.0804,  0.3484,
          -1.4958],
         [ 1.3073, -0.4110,  0.8933, -1.0905, -0.5548, -1.1668, -0.4744,
           1.4968],
         [-1.9412,  0.1280,  1.4771,  0.1171,  0.7911,  0.7049, -0.4364,
          -0.8406],
         [-0.5711,  0.1480, -0.1121,  2.5456, -0.3921, -0.8193, -0.3127,
          -0.4863]],

        [[-0.9976,  0.1132,  0.8136,  1.8301, -1.1369, -0.8182, -0.6089,
           0.8048],
         [ 1.0171,  0.5179, -1.8139, -1.4191,  0.1019,  0.0668,  1.0875,
           0.4419],
         [ 1.5718, -0.3661, -0.1785, -1.2371, -0.2869, -1.3946,  0.8534,
           1.0381],
         [ 1.1251,  0.2179, -1.1357, -1.1453,  0.0185, -1.2853,  1.0709,
           1.1340],
         [-1.6715,  0.4870,  0.4338,  1.8707, -0.9553,  0.4189, -0.3990,
          -0.1846]]], grad_fn=<NativeLayerNormBackward0>)


In [58]:
num_classes =10
gt = torch.empty(bs, dtype=torch.long).random_(10)  # shape: bs 值：0-9十个类别中的一个
# 步骤5：取出经过encoder后cls位置的向量，映射到类别，与标签做损失
encoder_output_cls = encoder_output[:, 0, :]
linear_layer = nn.Linear(model_dim, num_classes)
cls_logits = linear_layer(encoder_output_cls)  # shape: bs * num_classes
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(cls_logits, gt)
print(loss)


tensor(2.3554, grad_fn=<NllLossBackward0>)
