## Transformer 难点理解与实现

In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

### word embedding

以序列建模为例, 考虑source sentence 和 target sentence

参数设置

In [2]:
# Number of sentence pairs per batch.
batch_size = 2

# Vocabulary sizes for the source and target languages.
max_num_src_words, max_num_tgt_words = 8, 8

# Embedding / model width (512 in the original paper).
model_dim = 8

# Maximum sequence lengths and size of the positional-encoding table.
max_src_seq_len = max_tgt_seq_len = max_position_len = 5

### 构建序列, 序列的字符以其在词表中索引的形式表示

In [3]:
# Sentence lengths are fixed here; to sample them randomly instead use
#   src_len = torch.randint(2, 5, (batch_size, ))
#   tgt_len = torch.randint(2, 5, (batch_size, ))
src_len = torch.tensor([2, 4], dtype=torch.int32)
tgt_len = torch.tensor([4, 3], dtype=torch.int32)

# Each sentence is a run of random word indices drawn from [1, vocab_size),
# right-padded with 0 up to the longest sentence in the batch, then stacked.
src_seq = torch.stack([
    F.pad(torch.randint(1, max_num_src_words, (n, )), (0, int(src_len.max()) - n))
    for n in src_len.tolist()
])
tgt_seq = torch.stack([
    F.pad(torch.randint(1, max_num_tgt_words, (n, )), (0, int(tgt_len.max()) - n))
    for n in tgt_len.tolist()
])

print(src_seq, tgt_seq, sep="\n")

tensor([[5, 1, 0, 0],
        [7, 3, 4, 3]])
tensor([[5, 7, 3, 7],
        [1, 2, 5, 0]])


### 构造 word embedding  由于padding中默认加了0 所以词表数量 + 1

In [4]:
# Index 0 is used for padding, so each table has one row more than the vocabulary size.
src_embedding_table = nn.Embedding(num_embeddings=max_num_src_words + 1, embedding_dim=model_dim)
tgt_embedding_table = nn.Embedding(num_embeddings=max_num_tgt_words + 1, embedding_dim=model_dim)

# Look up every word index (padding included) -> (batch, seq_len, model_dim).
src_embedding, tgt_embedding = src_embedding_table(src_seq), tgt_embedding_table(tgt_seq)

print(src_embedding, tgt_embedding, sep="\n")

tensor([[[-7.1059e-01, -1.4507e-01, -2.9596e-01,  7.5678e-01, -2.3224e-01,
          -1.5946e+00, -5.0547e-01, -3.3037e-01],
         [ 5.7515e-01, -1.5071e-01, -4.2411e-02, -3.9152e-01, -1.0968e+00,
           1.5585e+00,  5.6138e-01, -1.2438e-01],
         [ 2.0126e+00, -5.6317e-01, -6.5935e-01, -6.2743e-01, -1.8382e-01,
          -1.1710e+00,  8.7818e-01, -1.9528e+00],
         [ 2.0126e+00, -5.6317e-01, -6.5935e-01, -6.2743e-01, -1.8382e-01,
          -1.1710e+00,  8.7818e-01, -1.9528e+00]],

        [[ 7.3068e-01,  1.4668e-01, -4.6573e-01,  8.8666e-01,  9.1251e-02,
           3.1018e-01,  4.9053e-01, -3.5768e-01],
         [ 1.0594e+00,  1.7390e-01,  7.5945e-01,  4.3819e-01, -4.1232e-02,
           2.1988e-03, -1.7437e+00,  2.7523e-01],
         [ 1.1412e-01, -2.7924e+00, -5.5918e-01, -1.3004e+00, -3.2612e-01,
           7.7508e-01, -1.3587e+00,  1.4015e+00],
         [ 1.0594e+00,  1.7390e-01,  7.5945e-01,  4.3819e-01, -4.1232e-02,
           2.1988e-03, -1.7437e+00,  2.7523e-01]

### 构造 positional embedding

In [5]:
# Sinusoidal position encoding:
#   PE(pos, 2i)   = sin(pos / 10000^(2i / model_dim))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / model_dim))
# pos_mat is a column of positions, i_mat a row of per-frequency divisors;
# broadcasting pos_mat / i_mat yields a (max_position_len, model_dim / 2) grid.
pos_mat = torch.arange(max_position_len).reshape(-1, 1)
# Step over the even feature indices of model_dim (was hard-coded as 8, which
# only worked while model_dim happened to be 8).
i_mat = torch.pow(10000, torch.arange(0, model_dim, 2).reshape(1, -1) / model_dim)
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat)  # even columns
pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)  # odd columns

# Wrap the fixed table in an embedding layer whose weights are not trained.
pe_embedding = nn.Embedding.from_pretrained(pe_embedding_table, freeze=True)

# One row of position ids per sentence in the batch.
# NOTE(review): the ids span max_position_len positions, while src_seq / tgt_seq
# are padded only to the longest sentence in the batch; slice to the sequence
# length before adding these to the word embeddings.
src_pos = torch.stack([torch.arange(max_position_len) for _ in src_len]).to(torch.int32)
tgt_pos = torch.stack([torch.arange(max_position_len) for _ in tgt_len]).to(torch.int32)

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)
print(src_pe_embedding)
print(tgt_pe_embedding)

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00]

### Scaled Dot-product Attention

$$Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$$

除以 $\sqrt{d_k}$ 进行缩放, 是为了让点积的方差不要太大: 否则 softmax 的输出会趋近 one-hot, 其雅可比矩阵趋近于 0, 从而导致梯度消失

softmax 演示

In [6]:
# Two multipliers applied to the same random scores: a small one and a large one,
# to compare how the scale of the logits changes the softmax output.
alpha1, alpha2 = 0.1, 10
score = torch.randn(5)

prob1 = F.softmax(alpha1 * score, dim=-1)
prob2 = F.softmax(alpha2 * score, dim=-1)

print(prob1, prob2, sep="\n")

tensor([0.1979, 0.2085, 0.2066, 0.2037, 0.1833])
tensor([3.7366e-03, 6.6135e-01, 2.7055e-01, 6.4361e-02, 1.7657e-06])


jacobian 演示

In [7]:
def softmax_func(score):
    """Return the softmax of ``score`` taken over its last dimension."""
    probabilities = F.softmax(score, dim=-1)
    return probabilities

# Jacobian of softmax evaluated at the scores scaled by alpha1 and by alpha2.
jaco_mat1, jaco_mat2 = (
    torch.autograd.functional.jacobian(softmax_func, score * alpha)
    for alpha in (alpha1, alpha2)
)

print(jaco_mat1, jaco_mat2, sep="\n")

tensor([[ 0.1588, -0.0413, -0.0409, -0.0403, -0.0363],
        [-0.0413,  0.1650, -0.0431, -0.0425, -0.0382],
        [-0.0409, -0.0431,  0.1639, -0.0421, -0.0379],
        [-0.0403, -0.0425, -0.0421,  0.1622, -0.0373],
        [-0.0363, -0.0382, -0.0379, -0.0373,  0.1497]])
tensor([[ 3.7227e-03, -2.4712e-03, -1.0109e-03, -2.4049e-04, -6.5977e-09],
        [-2.4712e-03,  2.2397e-01, -1.7893e-01, -4.2565e-02, -1.1677e-06],
        [-1.0109e-03, -1.7893e-01,  1.9735e-01, -1.7413e-02, -4.7770e-07],
        [-2.4049e-04, -4.2565e-02, -1.7413e-02,  6.0218e-02, -1.1364e-07],
        [-6.5977e-09, -1.1677e-06, -4.7770e-07, -1.1364e-07,  1.7657e-06]])


### 构造encoder的self-attention mask

mask的shape: [batch_size, max_src_len, max_src_len], 值为 True (需要屏蔽的位置) 或 False (有效位置); 使用时通过 masked_fill 把 True 的位置填充为 -1e9 (近似 -inf)

mask可以理解为一个邻接矩阵, 每一行反映的是一个单词与所有单词之间的配对是否有效

In [8]:
# Per-sentence validity column: 1.0 at real tokens, 0.0 at padding.
# Comparing position ids with each length gives (batch, max_len); the trailing
# unsqueeze turns it into (batch, max_len, 1).
valid_encoder_pos = (
    (torch.arange(int(src_len.max())).unsqueeze(0) < src_len.unsqueeze(1))
    .to(torch.float32)
    .unsqueeze(2)
)

# Outer product per sentence: entry (i, j) is 1 only when tokens i and j are both real.
valid_encoder_pos_matrix = valid_encoder_pos @ valid_encoder_pos.transpose(1, 2)
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix
# True marks the pairs that must be masked out.
mask_encoder_self_attention = invalid_encoder_pos_matrix.bool()

print(valid_encoder_pos_matrix, src_len, mask_encoder_self_attention, sep="\n")

tensor([[[1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
tensor([2, 4], dtype=torch.int32)
tensor([[[False, False,  True,  True],
         [False, False,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True]],

        [[False, False, False, False],
         [False, False, False, False],
         [False, False, False, False],
         [False, False, False, False]]])


构造score进行测试

In [9]:
# Random scores standing in for Q @ K^T, one (max_len, max_len) matrix per sentence.
score = torch.randn(batch_size, int(src_len.max()), int(src_len.max()))

# Masked pairs get a very large negative score so that softmax sends them to ~0.
# Rows in which every entry is masked end up uniform, since all scores are equal.
masked_score = torch.masked_fill(score, mask_encoder_self_attention, -1e9)
prob = F.softmax(masked_score, dim=-1)

print(masked_score, prob, sep="\n")

tensor([[[ 1.8227e+00, -1.6084e+00, -1.0000e+09, -1.0000e+09],
         [ 7.6147e-01, -4.9629e-01, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[ 1.6609e+00,  1.9705e+00,  3.4850e-02,  1.0079e+00],
         [ 3.9220e-01, -2.8047e-01,  1.2071e+00,  1.1806e+00],
         [-6.2551e-01,  1.0623e+00, -1.4678e-01, -1.6200e+00],
         [-2.2957e-01, -1.9798e+00,  1.6098e-01, -3.6392e-01]]])
tensor([[[0.9687, 0.0313, 0.0000, 0.0000],
         [0.7786, 0.2214, 0.0000, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]],

        [[0.3247, 0.4425, 0.0639, 0.1690],
         [0.1675, 0.0855, 0.3784, 0.3685],
         [0.1192, 0.6444, 0.1923, 0.0441],
         [0.2836, 0.0493, 0.4191, 0.2480]]])


### 构造cross-attention (encoder-decoder attention) 的mask

$QK^T$ 的 shape: [batch_size, tgt_seq_len, src_seq_len]

In [10]:
# Validity column for the target side: 1.0 at real tokens, 0.0 at padding,
# shaped (batch, max_tgt_len, 1).
valid_decoder_pos = (
    (torch.arange(int(tgt_len.max())).unsqueeze(0) < tgt_len.unsqueeze(1))
    .to(torch.float32)
    .unsqueeze(2)
)

# Entry (t, s) is 1 only when target token t and source token s are both real.
valid_cross_pos_matrix = valid_decoder_pos @ valid_encoder_pos.transpose(1, 2)
invalid_cross_pos_matrix = 1 - valid_cross_pos_matrix
# True marks the (target, source) pairs that must be masked out.
mask_cross_attention = invalid_cross_pos_matrix.bool()

print(
    valid_decoder_pos,
    valid_encoder_pos.transpose(1, 2),
    valid_cross_pos_matrix,
    mask_cross_attention,
    sep="\n",
)

tensor([[[1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [0.]]])
tensor([[[1., 1., 0., 0.]],

        [[1., 1., 1., 1.]]])
tensor([[[1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.]]])
tensor([[[False, False,  True,  True],
         [False, False,  True,  True],
         [False, False,  True,  True],
         [False, False,  True,  True]],

        [[False, False, False, False],
         [False, False, False, False],
         [False, False, False, False],
         [ True,  True,  True,  True]]])


### decoder self-attention的mask

In [11]:
# Per sentence: a lower-triangular block of ones (each token may look at itself and
# at earlier tokens), zero-padded on the right and bottom to the longest target length.
valid_decoder_tri_matrix = torch.stack([
    F.pad(
        torch.ones(n, n).tril(),
        (0, int(tgt_len.max()) - n, 0, int(tgt_len.max()) - n),
    )
    for n in tgt_len.tolist()
])
print(valid_decoder_tri_matrix)

# True marks the positions that must be masked out (future tokens and padding).
invalid_decoder_tri_matrix = (1 - valid_decoder_tri_matrix).bool()
print(invalid_decoder_tri_matrix)

tensor([[[1., 0., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 1., 0.],
         [1., 1., 1., 1.]],

        [[1., 0., 0., 0.],
         [1., 1., 0., 0.],
         [1., 1., 1., 0.],
         [0., 0., 0., 0.]]])
tensor([[[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [False, False, False, False]],

        [[False,  True,  True,  True],
         [False, False,  True,  True],
         [False, False, False,  True],
         [ True,  True,  True,  True]]])


In [12]:
# Random scores standing in for the decoder's Q @ K^T.
score = torch.randn(batch_size, int(tgt_len.max()), int(tgt_len.max()))

# Apply the causal + padding mask, then normalise each row.
masked_score = torch.masked_fill(score, invalid_decoder_tri_matrix, -1e9)
prob = F.softmax(masked_score, dim=-1)

print(masked_score, prob, sep="\n")

tensor([[[-4.4884e-01, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-3.6396e-01, -2.9015e-01, -1.0000e+09, -1.0000e+09],
         [ 4.6684e-01,  4.3242e-01,  8.5617e-03, -1.0000e+09],
         [ 9.1782e-01, -1.1736e+00, -4.5122e-01, -4.6846e-01]],

        [[ 1.3073e+00, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-8.6327e-01, -9.7686e-01, -1.0000e+09, -1.0000e+09],
         [ 4.1243e-02,  7.0956e-01, -9.7122e-01, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]]])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000],
         [0.4816, 0.5184, 0.0000, 0.0000],
         [0.3848, 0.3718, 0.2434, 0.0000],
         [0.6143, 0.0759, 0.1562, 0.1536]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [0.5284, 0.4716, 0.0000, 0.0000],
         [0.3017, 0.5887, 0.1096, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500]]])


### 构建self-attention

构建scaled self-attention

In [13]:
def scaled_dot_product_attention(Q, K, V, attn_mask=None):
    """Compute softmax(Q @ K^T / sqrt(d_k)) @ V for a batch of heads.

    Args:
        Q: queries, shape (batch_size * num_head, tgt_len, d_k).
        K: keys, shape (batch_size * num_head, src_len, d_k).
        V: values, shape (batch_size * num_head, src_len, d_v).
        attn_mask: optional bool tensor broadcastable to
            (batch_size * num_head, tgt_len, src_len); True marks positions
            that must not be attended to. ``None`` applies no masking.

    Returns:
        Context tensor of shape (batch_size * num_head, tgt_len, d_v).
    """
    # Scale by the per-head key width taken from the tensor itself. The previous
    # torch.sqrt(model_dim) raised TypeError (torch.sqrt needs a Tensor, not an int)
    # and used the full model width rather than d_k.
    d_k = Q.size(-1)
    score = torch.bmm(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
    if attn_mask is not None:
        # A large negative score makes softmax assign ~0 probability.
        score = score.masked_fill(attn_mask, -1e9)
    prob = F.softmax(score, dim=-1)
    context = torch.bmm(prob, V)
    return context