## Transformer 难点理解与实现

In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

### word embedding

以序列建模为例, 考虑source sentence 和 target sentence

参数设置

In [2]:
batch_size = 2

# 单词表大小
max_num_src_words = 8
max_num_tgt_words = 8
model_dim = 8  # 原文中是512

# 序列的最大长度
max_src_seq_len = 5
max_tgt_seq_len = 5
max_position_len = 5

### 构建序列, 序列的字符以其在词表中索引的形式表示

In [3]:
# src_len = torch.randint(2, 5, (batch_size, ))
# tgt_len = torch.randint(2, 5, (batch_size, ))
src_len = torch.Tensor([2, 4]).to(torch.int32)
tgt_len = torch.tensor([4, 3], dtype=torch.int32)

# 单词索引构成源句子和目标句子, 构建batch, 并且做了padding 默认值为0
src_seq = torch.stack([F.pad(torch.randint(1, max_num_src_words, (L, )), (0, max(src_len) - L)) for L in src_len])
tgt_seq = torch.stack([F.pad(torch.randint(1, max_num_tgt_words, (L, )), (0, max(tgt_len) - L)) for L in tgt_len])

print(src_seq)
print(tgt_seq)

tensor([[3, 2, 0, 0],
        [6, 5, 2, 4]])
tensor([[3, 3, 5, 7],
        [7, 6, 7, 0]])


### 构造 word embedding  由于padding中默认加了0 所以词表数量 + 1

In [4]:
src_embedding_table = nn.Embedding(max_num_src_words + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words + 1, model_dim)

src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)

print(src_embedding)
print(tgt_embedding)

tensor([[[-7.5659e-01,  6.6300e-01,  6.7041e-01, -1.0780e+00, -1.1394e+00,
           2.9642e-01, -3.6435e-02,  7.7421e-01],
         [-1.4598e-01,  1.2388e-01,  1.2290e+00,  1.7662e-01,  3.8443e-02,
           1.4883e+00,  3.7352e-01, -6.5173e-01],
         [-2.0152e-01,  1.6985e-01,  3.4408e-01,  2.7922e+00,  1.8004e+00,
           4.6192e-01, -4.6728e-01, -2.7255e-03],
         [-2.0152e-01,  1.6985e-01,  3.4408e-01,  2.7922e+00,  1.8004e+00,
           4.6192e-01, -4.6728e-01, -2.7255e-03]],

        [[ 1.6643e+00,  2.2295e-01, -1.2752e-01,  6.4693e-01,  1.7282e+00,
           5.4279e-01, -3.2137e-01,  1.0861e+00],
         [-3.7940e-01, -1.6459e+00, -1.7783e+00, -3.3335e-02, -3.2558e-01,
          -2.4711e-01,  1.2158e+00,  1.1820e-02],
         [-1.4598e-01,  1.2388e-01,  1.2290e+00,  1.7662e-01,  3.8443e-02,
           1.4883e+00,  3.7352e-01, -6.5173e-01],
         [-5.4304e-01, -4.2037e-01,  1.3804e-01, -1.3266e+00, -2.0213e-02,
           1.2087e+00, -3.0175e-01, -5.5616e-01]

### 构造 positional embedding

In [5]:
pos_mat = torch.arange(max_position_len).reshape(-1, 1)
i_mat = torch.pow(10000, torch.arange(0, 8, 2).reshape(1, -1) / model_dim)
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat)
pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)

pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)

src_pos = torch.stack([torch.arange(max_position_len) for _ in src_len]).to(torch.int32)
tgt_pos = torch.stack([torch.arange(max_position_len) for _ in tgt_len]).to(torch.int32)

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)
print(src_pe_embedding)
print(tgt_pe_embedding)

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00]

### Scaled Dot-product Attention

$$Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$$

进行scaled是为了概率方差不要太大, 雅可比矩阵不要为0, 避免梯度消失

softmax 演示

In [6]:
alpha1 = 0.1
alpha2 = 10
score = torch.randn(5)
prob1 = F.softmax(score*alpha1, -1)
prob2 = F.softmax(score*alpha2, -1)

print(prob1)
print(prob2)

tensor([0.2003, 0.1796, 0.1906, 0.2249, 0.2046])
tensor([9.1889e-06, 1.7265e-10, 6.3001e-08, 9.9991e-01, 7.6683e-05])


jacobian 演示

In [7]:
def softmax_func(score):
    return F.softmax(score, -1)

jaco_mat1 = torch.autograd.functional.jacobian(softmax_func, score * alpha1)
jaco_mat2 = torch.autograd.functional.jacobian(softmax_func, score * alpha2)

print(jaco_mat1)
print(jaco_mat2)

tensor([[ 0.1602, -0.0360, -0.0382, -0.0451, -0.0410],
        [-0.0360,  0.1474, -0.0342, -0.0404, -0.0368],
        [-0.0382, -0.0342,  0.1542, -0.0429, -0.0390],
        [-0.0451, -0.0404, -0.0429,  0.1743, -0.0460],
        [-0.0410, -0.0368, -0.0390, -0.0460,  0.1627]])
tensor([[ 9.1888e-06, -1.5865e-15, -5.7891e-13, -9.1881e-06, -7.0464e-10],
        [-1.5865e-15,  1.7265e-10, -1.0877e-17, -1.7263e-10, -1.3239e-14],
        [-5.7891e-13, -1.0877e-17,  6.3001e-08, -6.2995e-08, -4.8311e-12],
        [-9.1881e-06, -1.7263e-10, -6.2995e-08,  8.5943e-05, -7.6677e-05],
        [-7.0464e-10, -1.3239e-14, -4.8311e-12, -7.6677e-05,  7.6677e-05]])


### 构造encoder的self-attention mask

mask的shape: [batch_size, max_src_len, max_src_len], 值为1或者-inf

mask可以理解为一个邻接矩阵, 每一行反应的是一个单词对所有单词的是否有效性

In [8]:
valid_encoder_pos = torch.unsqueeze(torch.stack([F.pad(torch.ones(L), (0, max(src_len) - L)) for L in src_len]), 2)

valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)

print(valid_encoder_pos_matrix)
print(src_len)
print(mask_encoder_self_attention)

tensor([[[1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
tensor([2, 4], dtype=torch.int32)
tensor([[[False, False,  True,  True],
         [False, False,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True]],

        [[False, False, False, False],
         [False, False, False, False],
         [False, False, False, False],
         [False, False, False, False]]])


构造score进行测试

In [9]:
score = torch.randn(batch_size, max(src_len), max(src_len))
masked_score = score.masked_fill(mask_encoder_self_attention, -1e9)
prob = F.softmax(masked_score, -1)

print(masked_score)
print(prob)

tensor([[[-3.3873e-01, -1.3879e+00, -1.0000e+09, -1.0000e+09],
         [ 8.6916e-01, -8.8602e-01, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[-1.5867e+00,  2.5560e+00, -5.1003e-01, -1.4995e-01],
         [-4.3451e-01, -1.0473e+00, -2.0167e+00,  3.6452e-01],
         [-2.4086e+00, -1.5413e+00,  4.0948e-02,  3.8118e-01],
         [ 6.9932e-01, -6.4664e-02,  3.4607e-02, -3.5933e-01]]])
tensor([[[0.7406, 0.2594, 0.0000, 0.0000],
         [0.8526, 0.1474, 0.0000, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]],

        [[0.0141, 0.8855, 0.0413, 0.0592],
         [0.2518, 0.1365, 0.0518, 0.5599],
         [0.0320, 0.0762, 0.3708, 0.5210],
         [0.4297, 0.2002, 0.2211, 0.1491]]])


### 构造intra-attention的mask

$$Q @ K^T shape:$$ [batch_size, tgt_seq_len, src_seq_len]

In [12]:
valid_decoder_pos = torch.unsqueeze(torch.stack([F.pad(torch.ones(L), (0, max(tgt_len) - L)) for L in tgt_len]), 2)
valid_cross_pos_matrix = torch.bmm(valid_decoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_cross_pos_matrix = 1 - valid_cross_pos_matrix
mask_cross_attention = invalid_cross_pos_matrix.to(torch.bool)

print(valid_decoder_pos)
print(valid_encoder_pos.transpose(1, 2))
print(valid_cross_pos_matrix)
print(mask_cross_attention)

tensor([[[1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [0.]]])
tensor([[[1., 1., 0., 0.]],

        [[1., 1., 1., 1.]]])
tensor([[[1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.]]])
tensor([[[False, False,  True,  True],
         [False, False,  True,  True],
         [False, False,  True,  True],
         [False, False,  True,  True]],

        [[False, False, False, False],
         [False, False, False, False],
         [False, False, False, False],
         [ True,  True,  True,  True]]])


### decoder self-attention的mask

In [18]:
valid_decoder_tri_matrix = torch.stack([F.pad(torch.tril(torch.ones(L, L)), (0, max(tgt_len) - L, 0, max(tgt_len) - L)) for L in tgt_len])
print(valid_decoder_tri_matrix)

invalid_decoder_tri_matrix = (1 - valid_decoder_tri_matrix).to(torch.bool)
print(invalid_decoder_tri_matrix)

tensor([[[1., 0., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 1., 0.],
         [1., 1., 1., 1.]],

        [[1., 0., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 1., 0.],
         [0., 0., 0., 0.]]])
tensor([[[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [False, False, False, False]],

        [[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [ True,  True,  True,  True]]])


In [21]:
score = torch.randn(batch_size, max(tgt_len), max(tgt_len))
masked_score = score.masked_fill(invalid_decoder_tri_matrix, -1e9)
prob = F.softmax(masked_score, -1)

print(masked_score)
print(prob)

tensor([[[-5.4839e-01, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-2.8266e+00, -1.9330e+00, -1.0000e+09, -1.0000e+09],
         [-1.2761e+00,  3.1302e-01, -1.2184e-02, -1.0000e+09],
         [-1.7509e+00, -2.6608e-01,  1.3023e+00,  2.7448e-01]],

        [[ 3.9567e-01, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [ 7.6016e-02,  2.2431e-01, -1.0000e+09, -1.0000e+09],
         [ 1.2955e-01, -1.4103e+00, -3.8267e-01, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]]])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.2903, 0.7097, 0.0000, 0.0000],
         [0.1059, 0.5191, 0.3750, 0.0000],
         [0.0293, 0.1292, 0.6198, 0.2218]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4630, 0.5370, 0.0000, 0.0000],
         [0.5514, 0.1182, 0.3304, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500]]])


### 构建self-attention

构建scaled self-attention

In [None]:
def scaled_dot_product_attention(Q, K, V, attn_mask):
    # Shape of Q, K, V: (batch_size * num_head, seq_len, model_dim / num_head)
    score = torch.bmm(Q, K.transpose(-2, -1)) / torch.sqrt(model_dim)
    masked_score = score.masked_fill(attn_mask, -1e9)
    prob = F.softmax(masked_score, -1)
    context = torch.bmm(prob, V)
    return context