In [40]:
import torch.nn.functional as F
import torch as t
import torch.nn as nn

def image2emb_naive(image, patch_size, weight):
    """Embed non-overlapping square patches of ``image`` with a plain matmul.

    Args:
        image: tensor of shape [bs, c, h, w].
        patch_size: side length of each square patch; also used as the stride,
            so the patches do not overlap.
        weight: projection matrix whose first dimension must equal
            c * patch_size * patch_size (the flattened patch length) and whose
            second dimension is the embedding size.

    Returns:
        Tensor of shape [bs, num_patches, embedding size]. The patch and
        embedding shapes are also printed as a side effect.
    """
    # F.unfold yields [bs, c*patch_size*patch_size, num_patches]; swap the last
    # two axes so that every row is one flattened patch.
    patches = F.unfold(image, kernel_size=patch_size, stride=patch_size)
    patches = patches.transpose(-1, -2)
    embeddings = t.matmul(patches, weight)

    print(patches.shape, "# patch.shape")
    print(embeddings.shape, "# patch_embedding.shape")

    return embeddings


def test_image2emb_naive():
    """Smoke-test image2emb_naive on a random 8x8 RGB image with 4x4 patches."""
    batch, channels, height, width = 1, 3, 8, 8
    patch = 4
    embed_dim = 8
    # Length of one flattened patch: every channel contributes patch*patch pixels.
    flat_patch_len = channels * patch ** 2
    sample = t.randn(batch, channels, height, width)
    proj = t.randn(flat_patch_len, embed_dim)
    image2emb_naive(sample, patch, proj)


test_image2emb_naive()



torch.Size([1, 4, 48]) # patch.shape
torch.Size([1, 4, 8]) # patch_embedding.shape


In [None]:
def image2emb_conv(image, kernel, stride):
    """Embed image patches with a strided convolution.

    Args:
        image: tensor of shape [bs, ic, ih, iw].
        kernel: conv weight of shape [oc, ic, kh, kw]; oc is the embedding size.
        stride: convolution stride (equal to the patch size for
            non-overlapping patches).

    Returns:
        Tensor of shape [bs, oh * ow, oc]: one embedding row per output position.
    """
    feature_map = F.conv2d(image, kernel, stride=stride)
    # Unpacking also enforces a 4-D [bs, oc, oh, ow] conv result.
    batch, channels, grid_h, grid_w = feature_map.shape
    # Flatten the spatial grid into a token axis, then move it ahead of channels.
    tokens = feature_map.reshape(batch, channels, grid_h * grid_w)
    return tokens.permute(0, 2, 1)


def test_image2emb_conv():
    """Run image2emb_conv on a random 8x8 RGB image; print and return the result."""
    batch, channels, height, width = 1, 3, 8, 8
    patch = 4
    embed_dim = 8
    sample = t.randn(batch, channels, height, width)
    # One conv filter per embedding dimension, each spanning a full patch.
    filters = t.randn(embed_dim, channels, patch, patch)
    embeddings = image2emb_conv(sample, filters, stride=patch)
    print(embeddings.shape)
    return embeddings


test_image2emb_conv()

In [51]:
def append_cls_token(patch_embedding, cls_token=None):
    """Prepend a [CLS] token to every sequence in ``patch_embedding``.

    Args:
        patch_embedding: tensor of shape [bs, num_patches, model_dim].
        cls_token: optional tensor broadcastable to [bs, 1, model_dim]
            (e.g. a persistent learnable tensor of shape [1, 1, model_dim]).
            It is expanded over the batch, so all sequences share one token.
            When omitted, a fresh random [bs, 1, model_dim] token with
            ``requires_grad=True`` is drawn on every call. Because that token
            is re-created each call it cannot be learned; pass ``cls_token``
            for training.

    Returns:
        Tensor of shape [bs, num_patches + 1, model_dim] with the CLS token at
        sequence position 0.
    """
    bs, _, model_dim = patch_embedding.shape
    if cls_token is None:
        # Match dtype/device so that t.cat also works for GPU / reduced-precision input.
        cls_token_embedding = t.randn(
            bs, 1, model_dim,
            dtype=patch_embedding.dtype,
            device=patch_embedding.device,
            requires_grad=True,
        )
    else:
        cls_token_embedding = cls_token.to(
            dtype=patch_embedding.dtype, device=patch_embedding.device
        ).expand(bs, 1, model_dim)
    # Put the CLS token at the first sequence position.
    return t.cat([cls_token_embedding, patch_embedding], dim=1)


def append_position_embedding(max_num_token, token_embedding, position_embedding_table=None):
    """Add a position embedding to ``token_embedding`` and return a new tensor.

    Args:
        max_num_token: maximum supported sequence length, i.e. the number of
            rows in the position-embedding table.
        token_embedding: tensor of shape [bs, seq_len, model_dim]. It is not
            modified.
        position_embedding_table: optional tensor of shape
            [max_num_token, model_dim] (e.g. a persistent learnable tensor).
            When omitted, a fresh random table with ``requires_grad=True`` is
            drawn on every call.

    Returns:
        New tensor of shape [bs, seq_len, model_dim].

    Raises:
        ValueError: if ``seq_len`` exceeds ``max_num_token``.
    """
    _, seq_len, model_dim = token_embedding.shape
    if seq_len > max_num_token:
        raise ValueError(
            f"sequence length {seq_len} exceeds max_num_token={max_num_token}"
        )
    if position_embedding_table is None:
        # shape = [max_num_token, model_dim], on the input's dtype/device.
        position_embedding_table = t.randn(
            max_num_token, model_dim,
            dtype=token_embedding.dtype,
            device=token_embedding.device,
            requires_grad=True,
        )
    else:
        position_embedding_table = position_embedding_table.to(
            dtype=token_embedding.dtype, device=token_embedding.device
        )
    # [seq_len, model_dim] broadcasts over the batch axis, so no tiling is needed.
    # The add is deliberately out of place: `+=` would mutate the caller's tensor
    # and raises on leaf tensors that require grad.
    return token_embedding + position_embedding_table[:seq_len]


def pass_embedding_to_encoder(token_embedding, nhead=8, num_layers=6):
    """Run ``token_embedding`` through a freshly constructed Transformer encoder.

    Args:
        token_embedding: tensor of shape [bs, seq_len, model_dim].
        nhead: number of attention heads; model_dim must be divisible by it.
        num_layers: number of stacked encoder layers.

    Returns:
        Tensor of shape [bs, seq_len, model_dim].

    Note:
        The encoder is constructed, and therefore randomly initialized, on
        every call, so nothing here is trained. It demonstrates data flow only.
    """
    _, _, model_dim = token_embedding.shape
    # batch_first=True because the input is [bs, seq_len, model_dim]. With the
    # default batch_first=False, PyTorch reads it as [seq, batch, dim], so
    # attention would mix different images and never mix tokens within one.
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=model_dim, nhead=nhead, batch_first=True
    )
    encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    encoder = encoder.to(device=token_embedding.device, dtype=token_embedding.dtype)
    return encoder(token_embedding)


def do_classification(encoder_output, num_class, model_dim, label):
    """Classify each sequence from its position-0 output and return the loss.

    Args:
        encoder_output: tensor of shape [bs, seq_len, model_dim]. Position 0 is
            treated as the CLS token (where ``append_cls_token`` places it).
        num_class: number of output classes.
        model_dim: feature size of ``encoder_output``.
        label: class indices of shape [bs], with values in [0, num_class).

    Returns:
        Scalar cross-entropy loss tensor.

    Note:
        The linear head is constructed, and therefore randomly initialized, on
        every call.
    """
    cls_token_output = encoder_output[:, 0, :]
    # Match the head to the input's dtype/device; the default float32 CPU
    # layer would fail on float64 or GPU encoder outputs.
    linear_layer = nn.Linear(model_dim, num_class).to(
        device=encoder_output.device, dtype=encoder_output.dtype
    )
    logits = linear_layer(cls_token_output)
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn(logits, label)

def test_full():
    """End-to-end smoke test: patchify -> CLS token -> positions -> encoder -> loss."""
    bs, ic, ih, iw = 1, 3, 8, 8
    patch_size = 4
    model_dim = 8
    max_num_token = 16
    num_class = 10
    # Draw labels from [0, num_class) so they always match the classifier width.
    label = t.randint(num_class, (bs,))
    image = t.randn(bs, ic, ih, iw)
    kernel = t.randn(model_dim, ic, patch_size, patch_size)
    patch_embedding = image2emb_conv(image, kernel, stride=patch_size)
    token_embedding = append_cls_token(patch_embedding)
    token_embedding = append_position_embedding(max_num_token, token_embedding)
    encoder_output = pass_embedding_to_encoder(token_embedding)
    loss = do_classification(encoder_output, num_class, model_dim, label)
    print(loss)


test_full()

tensor(2.7252, grad_fn=<NllLossBackward0>)
