## Transformer 难点理解与实现

In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

### word embedding

以序列建模为例, 考虑source sentence 和 target sentence

参数设置

In [2]:
# Number of sentences per batch
batch_size = 2

# Vocabulary sizes for the source and target languages
max_num_src_words = max_num_tgt_words = 8
# Embedding width (the original paper uses 512)
model_dim = 8

# Maximum sequence lengths (source, target, and positional table)
max_src_seq_len = max_tgt_seq_len = max_position_len = 5

### 构建序列, 序列的字符以其在词表中索引的形式表示

In [3]:
# Sentence lengths are fixed for this walkthrough; random lengths could be drawn instead with:
# src_len = torch.randint(2, 5, (batch_size, ))
# tgt_len = torch.randint(2, 5, (batch_size, ))
# Both length tensors are built the same way, directly as int32.
src_len = torch.tensor([2, 4], dtype=torch.int32)
tgt_len = torch.tensor([4, 3], dtype=torch.int32)

# Longest sentence of each batch as a plain int, computed once rather than on every loop iteration.
batch_max_src_len = int(src_len.max())
batch_max_tgt_len = int(tgt_len.max())

# Each sentence is a run of random word indices drawn from [1, vocabulary size).
# Shorter sentences are right-padded to the batch maximum; F.pad fills with 0 by default,
# so index 0 acts as the padding index. The padded sentences are stacked into one batch tensor.
src_seq = torch.stack([
    F.pad(torch.randint(1, max_num_src_words, (length,)), (0, batch_max_src_len - length))
    for length in src_len.tolist()
])
tgt_seq = torch.stack([
    F.pad(torch.randint(1, max_num_tgt_words, (length,)), (0, batch_max_tgt_len - length))
    for length in tgt_len.tolist()
])

print(src_seq)
print(tgt_seq)

tensor([[4, 2, 0, 0],
        [7, 3, 6, 2]])
tensor([[3, 2, 5, 7],
        [2, 1, 1, 0]])


### 构造 word embedding: 由于 padding 使用了索引 0, 所以词表大小需要 + 1

In [4]:
# One lookup table per language; each has vocabulary size + 1 rows.
src_embedding_table = nn.Embedding(num_embeddings=max_num_src_words + 1, embedding_dim=model_dim)
tgt_embedding_table = nn.Embedding(num_embeddings=max_num_tgt_words + 1, embedding_dim=model_dim)

# Map every word index in the batch to its model_dim-sized vector.
src_embedding, tgt_embedding = src_embedding_table(src_seq), tgt_embedding_table(tgt_seq)

print(src_embedding, tgt_embedding, sep="\n")

tensor([[[-0.3332, -0.9937,  1.5486,  0.2025,  0.8946, -0.5197, -1.2958,
          -0.3221],
         [ 0.5120, -0.7522, -1.6029,  0.4899, -2.6923,  1.3810,  0.7255,
           1.3636],
         [ 0.7471, -0.7118,  0.3205,  0.8930, -1.0299, -1.8320, -1.2730,
           1.7782],
         [ 0.7471, -0.7118,  0.3205,  0.8930, -1.0299, -1.8320, -1.2730,
           1.7782]],

        [[-1.4368, -1.1851,  0.4859,  0.1804,  0.8603, -1.5882,  0.0233,
           0.2402],
         [-1.2841, -0.2302, -1.6006, -0.0709,  0.8760, -0.1979,  0.5812,
          -0.7197],
         [ 0.2651,  0.6535, -1.0530, -0.9495,  0.3578, -0.4908, -0.4830,
           0.2835],
         [ 0.5120, -0.7522, -1.6029,  0.4899, -2.6923,  1.3810,  0.7255,
           1.3636]]], grad_fn=<EmbeddingBackward0>)
tensor([[[-0.3910,  0.1137, -2.5097,  0.2471,  1.3806,  0.4929, -0.2335,
           0.5202],
         [ 0.0641, -1.0894,  1.0505,  0.2188, -0.5664,  2.8832, -1.3882,
          -2.7417],
         [-0.2198,  0.1703, -0.6364,

### 构造 positional embedding

In [5]:
# Sinusoidal position table:
#   PE[pos, 2i]   = sin(pos / 10000^(2i / model_dim))
#   PE[pos, 2i+1] = cos(pos / 10000^(2i / model_dim))
# pos_mat is a column of positions, i_mat a row of per-dimension denominators,
# so pos_mat / i_mat broadcasts to (max_position_len, model_dim / 2).
pos_mat = torch.arange(max_position_len).reshape(-1, 1)
# The even dimension indices are derived from model_dim (not a hard-coded 8)
# so the table stays correct when the configured width changes.
i_mat = torch.pow(10000, torch.arange(0, model_dim, 2).reshape(1, -1) / model_dim)
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat)  # even columns
pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)  # odd columns

# Wrap the fixed table in an Embedding whose weight is frozen (requires_grad=False).
pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)

# One row of position indices 0..max_position_len-1 per sentence in the batch.
# NOTE(review): these rows have max_position_len entries, while the padded token
# sequences are only as long as the longest sentence in the batch — confirm the
# lengths agree before adding positional and word embeddings together.
src_pos = torch.stack([torch.arange(max_position_len) for _ in src_len]).to(torch.int32)
tgt_pos = torch.stack([torch.arange(max_position_len) for _ in tgt_len]).to(torch.int32)

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)
print(src_pe_embedding)
print(tgt_pe_embedding)

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01]],

        [[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00]

### Scaled Dot-product Attention

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

除以 $\sqrt{d_k}$ 进行缩放 (scaled), 是为了避免点积结果的方差过大: 方差过大时 softmax 的输出会接近 one-hot 分布, 此时其雅可比矩阵趋近于 0, 容易导致梯度消失

softmax 演示

In [6]:
# Two scale factors applied to the same random scores before softmax.
alpha1, alpha2 = 0.1, 10
score = torch.randn(5)

# Softmax over the last dimension of the scaled scores.
prob1 = torch.softmax(score * alpha1, dim=-1)
prob2 = torch.softmax(score * alpha2, dim=-1)

print(prob1, prob2, sep="\n")

tensor([0.1975, 0.2074, 0.1904, 0.2071, 0.1976])
tensor([3.9781e-03, 5.2531e-01, 1.0590e-04, 4.6640e-01, 4.2113e-03])


jacobian 演示

In [7]:
def softmax_func(score):
    """Return the softmax of ``score`` taken over its last dimension."""
    return score.softmax(dim=-1)

# Jacobian of softmax evaluated at the same scores under each of the two scale factors.
jaco_mat1, jaco_mat2 = (
    torch.autograd.functional.jacobian(softmax_func, score * scale)
    for scale in (alpha1, alpha2)
)

print(jaco_mat1, jaco_mat2, sep="\n")

tensor([[ 0.1585, -0.0410, -0.0376, -0.0409, -0.0390],
        [-0.0410,  0.1644, -0.0395, -0.0429, -0.0410],
        [-0.0376, -0.0395,  0.1542, -0.0394, -0.0376],
        [-0.0409, -0.0429, -0.0394,  0.1642, -0.0409],
        [-0.0390, -0.0410, -0.0376, -0.0409,  0.1585]])
tensor([[ 3.9622e-03, -2.0897e-03, -4.2129e-07, -1.8554e-03, -1.6753e-05],
        [-2.0897e-03,  2.4936e-01, -5.5632e-05, -2.4500e-01, -2.2122e-03],
        [-4.2129e-07, -5.5632e-05,  1.0589e-04, -4.9393e-05, -4.4599e-07],
        [-1.8554e-03, -2.4500e-01, -4.9393e-05,  2.4887e-01, -1.9641e-03],
        [-1.6753e-05, -2.2122e-03, -4.4599e-07, -1.9641e-03,  4.1936e-03]])


### 构造encoder的self-attention mask

mask的shape: [batch_size, max_src_len, max_src_len], 值为布尔型: True 表示需要屏蔽的位置 (对应的 score 会被填充为 -1e9, 用来近似 -inf), False 表示有效位置

In [8]:
# Validity vector per sentence: ones for real tokens, right-padded with zeros up to the
# longest sentence, then given a trailing axis of size 1 so it can be batch-multiplied.
valid_encoder_pos = torch.stack(
    [F.pad(torch.ones(n), (0, max(src_len) - n)) for n in src_len]
).unsqueeze(2)

# Batched outer product of the validity vector with itself: entry (i, j) is 1 only when
# both position i and position j hold real tokens.
valid_encoder_pos_matrix = valid_encoder_pos @ valid_encoder_pos.transpose(-2, -1)
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix
# Boolean mask: True marks position pairs that involve padding.
mask_encoder_self_attention = invalid_encoder_pos_matrix.bool()

print(valid_encoder_pos_matrix, src_len, mask_encoder_self_attention, sep="\n")

tensor([[[1., 1., 0., 0.],
         [1., 1., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
tensor([2, 4], dtype=torch.int32)
tensor([[[False, False,  True,  True],
         [False, False,  True,  True],
         [ True,  True,  True,  True],
         [ True,  True,  True,  True]],

        [[False, False, False, False],
         [False, False, False, False],
         [False, False, False, False],
         [False, False, False, False]]])


构造score进行测试

In [9]:
# Random scores, one square matrix per sentence sized by the longest source sentence.
score = torch.randn(batch_size, max(src_len), max(src_len))
# Where the mask is True the score is replaced by -1e9, so softmax gives those entries
# ~0 probability. A row that is masked everywhere has all-equal entries and therefore
# comes out as a uniform distribution.
masked_score = torch.where(mask_encoder_self_attention, torch.full_like(score, -1e9), score)
prob = masked_score.softmax(dim=-1)

print(masked_score, prob, sep="\n")

tensor([[[-7.6442e-01,  2.8342e-01, -1.0000e+09, -1.0000e+09],
         [-2.9085e-01,  1.7306e+00, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[ 7.6844e-01, -4.3427e-01,  9.5924e-01,  2.9930e-01],
         [-9.0837e-01, -1.7956e+00, -2.7881e-01,  2.2242e-01],
         [ 7.2883e-01,  8.1568e-01,  8.9879e-01, -1.0591e-01],
         [ 1.1777e+00,  1.3912e+00, -4.2889e-01, -6.4911e-01]]])
tensor([[[0.2596, 0.7404, 0.0000, 0.0000],
         [0.1170, 0.8830, 0.0000, 0.0000],
         [0.2500, 0.2500, 0.2500, 0.2500],
         [0.2500, 0.2500, 0.2500, 0.2500]],

        [[0.3189, 0.0958, 0.3859, 0.1995],
         [0.1566, 0.0645, 0.2939, 0.4851],
         [0.2695, 0.2940, 0.3195, 0.1170],
         [0.3847, 0.4762, 0.0772, 0.0619]]])
