# Transformer Encoder精讲

三种Mask:
1. encoder里面的padding mask
2. masked multi-head attention 的序列因果mask
3. decoder 对 encoder memory 做交叉注意力时的 mask：原理与 encoder 的 padding mask 类似，但涉及源序列和目标序列两个不同的序列，长度可能不同

现在很多时候不一定会用到完整的 Transformer 模型（例如只用编码器），但是这三种 mask 仍然经常用到

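以下代码的前4步主要包括以下内容（之后的 step5–step8 依次是交叉注意力 mask、decoder 因果 mask、scaled dot-product attention 和 masked loss）：
1. 由单词索引构成句子，组成 batch，并做 padding，填充值默认为 0
2. 构造 word embedding
3. 构建 position embedding
4. 构造 encoder 的 self-attention padding mask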

In [16]:
import torch
import numpy as np

import torch.nn as nn
import torch.nn.functional as F

# Word-embedding walkthrough, using sequence modelling as the example.
# We consider a source sentence and a target sentence; every token in a
# sequence is represented by its index in the vocabulary.

batch_size = 2
max_num_src_words = 8  # source vocabulary size
max_num_tgt_words = 8  # target vocabulary size

model_dim = 8  # embedding width (512 in the original paper)

# Maximum sequence lengths
max_src_seq_len = 5
max_tgt_seq_len = 5
max_position_len = 5

# Number of real (non-padding) tokens in each sample of the batch.
# The lengths are fixed so the printed results are reproducible; the previous
# random draws were immediately overwritten and have been removed. To
# randomise instead, use torch.randint(2, 5, (batch_size,)) -- note that the
# upper bound is exclusive, so that yields lengths in 2..4.
src_len = torch.tensor([2, 4], dtype=torch.int32)
tgt_len = torch.tensor([4, 3], dtype=torch.int32)

# step1. Build a batch of sentences made of word indices. Every sentence is
# right-padded with 0 (the default fill value of F.pad) up to the longest
# sentence in the batch.
padded_src_len = int(max(src_len))
padded_tgt_len = int(max(tgt_len))

# Real tokens are drawn from 1..max_num_*_words-1, so 0 only ever appears as padding.
src_seq = torch.stack([
    F.pad(torch.randint(1, max_num_src_words, (n,)), (0, padded_src_len - n))
    for n in src_len.tolist()
])
tgt_seq = torch.stack([
    F.pad(torch.randint(1, max_num_tgt_words, (n,)), (0, padded_tgt_len - n))
    for n in tgt_len.tolist()
])
print(src_seq)
# step2. Build the word embeddings.
# Each table gets one row more than the vocabulary size; the sequences built
# in step1 use index 0 for padding.
# NOTE(review): the padding row is an ordinary trainable row here; pass
# padding_idx=0 if the padding embedding is meant to stay zero -- confirm intent.
src_embedding_table = nn.Embedding(
    num_embeddings=max_num_src_words + 1,
    embedding_dim=model_dim,
)
tgt_embedding_table = nn.Embedding(
    num_embeddings=max_num_tgt_words + 1,
    embedding_dim=model_dim,
)

# Calling a module instance like a function goes through __call__, which
# runs forward(): here, a row lookup in the embedding table per word index.
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding_table.weight)
print(src_embedding)
# step3. Build the sinusoidal position embedding.
#   PE(pos, 2i)   = sin(pos / 10000^(2i / model_dim))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / model_dim))
# pos indexes rows and i indexes columns, so the table is written as the
# broadcast of a column vector (varies with pos) against a row vector
# (varies with i).
pos_mat = torch.arange(max_position_len).reshape(-1, 1)  # [max_position_len, 1]
# The even column indices are derived from model_dim rather than a literal 8,
# so the table stays consistent when model_dim is changed.
i_mat = torch.pow(
    10000, torch.arange(0, model_dim, 2).reshape(1, -1) / model_dim
)  # [1, model_dim / 2]

pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat)  # even columns
pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)  # odd columns

# Wrap the fixed table in an nn.Embedding by overwriting its weight with a
# frozen (requires_grad=False) parameter.
pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)

# Position indices 0..max_len-1, repeated for every sample in the batch.
# NOTE(review): every slot, padded ones included, receives a position index
# here, so padded positions do NOT stay zero after the embeddings are added;
# the padding is presumably neutralised by the attention masks built later.
src_pos = torch.arange(int(max(src_len)), dtype=torch.int32).repeat(len(src_len), 1)
tgt_pos = torch.arange(int(max(tgt_len)), dtype=torch.int32).repeat(len(tgt_len), 1)
src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)
# print(pe_embedding_table)
# print(src_pe_embedding)
# print(tgt_pe_embedding)

'''
#softmax演示,scale的重要性
alpha1 = 0.1
alpha2 = 10
score = torch.randn(5) #我们认为是QK的结果
prob = F.softmax(score, -1)
prob1 = F.softmax(score * alpha1, -1)
prob2 = F.softmax(score * alpha2, -1)

def softmax_func(score):
    return F.softmax(score)

jaco_mat1 = torch.autograd.functional.jacobian(softmax_func, score * alpha1)
jaco_mat2 = torch.autograd.functional.jacobian(softmax_func, score * alpha2)
print(score)
print(prob)
print(jaco_mat1)
print(jaco_mat2)
'''

# step4. Build the encoder self-attention padding mask.
# Mask shape: [batch_size, max_src_len, max_src_len]; True marks a
# (query, key) pair in which at least one side is padding.
padded_src_len = int(max(src_len))

# One column vector per sample: 1 for real tokens, 0 for padding.
valid_encoder_pos = torch.stack(
    [F.pad(torch.ones(n), (0, padded_src_len - n)) for n in src_len.tolist()]
).unsqueeze(2)  # [batch_size, max_src_len, 1]

# The outer product is 1 only where both the query and the key are real tokens.
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)

score = torch.randn(batch_size, padded_src_len, padded_src_len)
# Fill with a large negative number rather than -inf: the rows belonging to
# padded queries are masked entirely, and softmax over a row of only -inf
# yields NaN. With -1e9 those rows become a harmless uniform distribution,
# matching the fill value used for the decoder mask and the attention function.
masked_score = score.masked_fill(mask_encoder_self_attention, -1e9)
prob = F.softmax(masked_score, -1)
# print(masked_score)
# print(prob)

# step5. Mask for the encoder-decoder (cross) attention.
# Q @ K^T has shape [batch_size, tgt_seq_len, src_seq_len].
padded_src_len = int(max(src_len))
padded_tgt_len = int(max(tgt_len))

# Column vectors flagging real tokens (1) versus padding (0) on each side.
valid_encoder_pos = torch.stack(
    [F.pad(torch.ones(n), (0, padded_src_len - n)) for n in src_len.tolist()]
).unsqueeze(2)  # [batch_size, max_src_len, 1]
valid_decoder_pos = torch.stack(
    [F.pad(torch.ones(n), (0, padded_tgt_len - n)) for n in tgt_len.tolist()]
).unsqueeze(2)  # [batch_size, max_tgt_len, 1]

# Decoder positions index the rows (queries), encoder positions the columns (keys).
valid_cross_pos_matrix = torch.bmm(valid_decoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_cross_pos_matrix = 1 - valid_cross_pos_matrix

# True wherever the query or the key is a padded position.
mask_cross_attention = invalid_cross_pos_matrix.to(torch.bool)


# step6. Decoder self-attention (causal) mask built from lower-triangular matrices.
padded_tgt_len = int(max(tgt_len))

# For each sample: an n x n lower-triangular block of ones (tril = lower,
# triu = upper), zero-padded on the right and bottom to the batch maximum.
valid_decoder_tri_matrix = torch.stack([
    F.pad(
        torch.tril(torch.ones(n, n)),
        (0, padded_tgt_len - n, 0, padded_tgt_len - n),
    )
    for n in tgt_len.tolist()
])

# True marks future positions and padding, i.e. everything that must be hidden.
invalid_decoder_tri_matrix = (1 - valid_decoder_tri_matrix).to(torch.bool)

score = torch.randn(batch_size, padded_tgt_len, padded_tgt_len)
masked_score = score.masked_fill(invalid_decoder_tri_matrix, -1e9)
prob = F.softmax(masked_score, dim=-1)
# print(tgt_len)
# print(prob)


# step7. Scaled dot-product attention.
# Q, K and V are expected as [batch_size * num_head, seq_len, model_dim / num_head];
# the three tensors do not have to share the same sequence length.
def scaled_dot_product_attentin(Q, K, V, attn_mask):
    """Compute softmax(Q @ K^T / sqrt(d_k)) @ V with masking.

    Args:
        Q: query tensor of shape [batch, q_len, d_k].
        K: key tensor of shape [batch, k_len, d_k].
        V: value tensor of shape [batch, k_len, d_v].
        attn_mask: boolean tensor broadcastable to [batch, q_len, k_len];
            True marks positions that must not be attended to.

    Returns:
        Context tensor of shape [batch, q_len, d_v].
    """
    # Scale by the width of the keys actually passed in. The previous
    # torch.sqrt(model_dim) raised a TypeError because torch.sqrt needs a
    # tensor, and model_dim is the full model width rather than the
    # per-head key width d_k.
    d_k = Q.size(-1)
    score = torch.bmm(Q, K.transpose(-1, -2)) / (d_k ** 0.5)
    # A large negative fill (instead of -inf) keeps fully masked rows finite.
    masked_score = score.masked_fill(attn_mask, -1e9)
    prob = F.softmax(masked_score, -1)
    context = torch.bmm(prob, V)
    return context



tensor([[2, 5, 0, 0],
        [7, 7, 4, 7]])
Parameter containing:
tensor([[ 0.6581,  0.0056,  0.7941,  0.6703, -0.4859,  0.1202,  0.2014,  1.2839],
        [-0.1706,  0.9963,  1.3472, -0.7502,  0.5489, -0.3132,  0.6457,  0.2714],
        [-1.3726,  1.0089, -0.2086,  2.0319,  0.1979, -1.6870, -0.4533,  0.4642],
        [ 0.8020, -1.1535, -1.0685,  0.0804,  1.9402, -1.0805,  0.5183, -0.0307],
        [-0.1627, -1.3019,  0.0285, -0.5413,  0.7970,  1.1493, -1.4249, -0.9690],
        [-1.3795, -0.1544,  0.3975,  1.6061, -0.6370, -0.7719, -0.1954, -0.5534],
        [ 0.8148,  0.8484, -0.4045, -1.0128, -2.3077, -0.5313,  0.6219, -1.1346],
        [-0.9171,  0.5305, -0.8556,  0.8228,  0.1133,  2.2393,  0.1156, -0.0067],
        [-0.4423, -1.8498,  1.7460,  0.7961,  0.8232, -0.3799, -1.6039, -0.5970]],
       requires_grad=True)
tensor([[[-1.3726,  1.0089, -0.2086,  2.0319,  0.1979, -1.6870, -0.4533,
           0.4642],
         [-1.3795, -0.1544,  0.3975,  1.6061, -0.6370, -0.7719, -0.1954,
 

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F 

# step8. Masked loss.
# Model predictions of shape [batch_size=2, seq_len=3, vocab_size=4], moved to
# [batch_size, vocab_size, seq_len] because F.cross_entropy expects the class
# (vocabulary) dimension second.
logits = torch.randn(2, 3, 4).transpose(1, 2)
# One label per position: an index into the vocabulary.
label = torch.randint(0, 4, (2, 3))

mean_loss = F.cross_entropy(logits, label)  # averaged over all tokens (default reduction)
per_token_loss = F.cross_entropy(logits, label, reduction="none")  # one loss per token

# The first sample has only 2 real tokens, so its last position is padding.
tgt_len = torch.tensor([2, 3], dtype=torch.int32)
longest = int(max(tgt_len))
mask = torch.stack(
    [F.pad(torch.ones(n), (0, longest - n)) for n in tgt_len.tolist()]
)

# Manual masking: zero out the loss at padded positions.
masked_loss = per_token_loss * mask

# PyTorch can also do this without a hand-written mask: mark the padded
# labels with the value of ignore_index (default -100). The string below
# quotes the signature from the official CrossEntropyLoss documentation.
'''
可以看一下Pytorch crossentropy官方文档
torch.nn.CrossEntropyLoss(weight=None, size_average=None, 
ignore_index=- 100, reduce=None, reduction='mean', label_smoothing=0.0)
'''
label[0, 2] = -100
# Equivalent to masking: the ignored position contributes a loss of 0.
F.cross_entropy(logits, label, reduction="none")

tensor([[2.3083, 2.2482, 0.0000],
        [1.2322, 1.4493, 1.8367]])