#### Transformer架构
<div align=center><img decoding="async" src=img/transformerStruct.webp width="30%"><img decoding="async" src=img/transformer.webp width="54.8%"></div>

 [Transformer从零开始(一)](https://zhuanlan.zhihu.com/p/451150316)  
[Transformer从零开始(二)](https://zhuanlan.zhihu.com/p/451182425)

In [32]:
import copy
import math

import numpy as np
import torch
import torch.nn as nn

# PyTorch's built-in reference model, for comparison with the hand-written one
# below: 16 attention heads and a 12-layer encoder; every other hyper-parameter
# keeps nn.Transformer's defaults.
transformer_model = nn.Transformer(num_encoder_layers=12, nhead=16)

Encoder  
> 从图中可以看出，在Encoder部分，inputs经过Embedding处理后，再加上Positional Encoding，形成Encoder的输入部分。输入部分进入由Multi-Head Attention和Feed Forward组成的Encoder Layer后，得到Encoder部分的输出。其中，Multi-Head Attention和Feed Forward层之后都会接残差连接和Layer Norm层；Encoder Layer会堆叠N次，形成深层次的网络结构。

In [34]:
# The Encoder consists of an Embedding, a Positional Encoding and N stacked encoder layers.
# Input x has shape [batch_size, seq_len]; the output has shape [batch_size, seq_len, hidden_size].
# enc_mask records the padding positions of the input; it is optional.

class Encoder(nn.Module):
    """Transformer encoder: token embedding + positional encoding + N layers.

    Args:
        layer: prototype encoder layer; it is deep-copied N times so each
            stacked layer owns independent parameters.
        vocab_size: size of the input vocabulary.
        hidden_size: model width (embedding dimension).
        N: number of stacked layers.
        dropout: dropout probability used by the positional encoding
            (defaults to 0.1, the previously hard-coded value).
    """

    def __init__(self, layer, vocab_size, hidden_size, N=6, dropout=0.1):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pe = PositionalEncoding(dropout)
        # deepcopy (requires the module-level `import copy`) so the N layers
        # do not share weights.
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(N)])

    def forward(self, x, enc_mask=None):
        """Encode token ids ``x`` ([batch_size, seq_len]).

        ``enc_mask`` is forwarded unchanged to every layer; its expected
        shape is whatever the layers' attention accepts.
        Returns a tensor of shape [batch_size, seq_len, hidden_size].
        """
        x = self.embedding(x)  # [batch_size, seq_len, hidden_size]
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, enc_mask)
        return x

class PositionalEncoding(nn.Module):
    """Add the fixed sinusoidal positional encoding to ``x``, then apply dropout.

        PE(pos, 2i)   = sin(pos / 10000^(2i / hidden_size))
        PE(pos, 2i+1) = cos(pos / 10000^(2i / hidden_size))

    The table is rebuilt on every call for the actual sequence length, so there
    is no maximum length. It is created on ``x``'s device (so CUDA inputs work),
    and odd ``hidden_size`` values are supported.
    """

    def __init__(self, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x = [batch_size, seq_len, hidden_size]
        _, seq_len, hidden_size = x.shape
        device = x.device
        position = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        # 10000^(2i / hidden_size) for every even column index 2i
        div_term = torch.exp(
            torch.arange(0, hidden_size, 2, dtype=torch.float32, device=device)
            * (math.log(10000.0) / hidden_size)
        ).unsqueeze(0)
        angles = position / div_term  # [seq_len, ceil(hidden_size / 2)]
        pe = torch.zeros((1, seq_len, hidden_size), device=device)
        pe[:, :, 0::2] = torch.sin(angles)
        # With an odd hidden_size there is one fewer odd column than even column.
        pe[:, :, 1::2] = torch.cos(angles[:, : hidden_size // 2])
        x = x + pe
        return self.dropout(x)

# EncoderLayer consists of two sublayers, attention and feed-forward; each is followed
# by a residual connection and a LayerNorm.
# Input shape is [batch_size, seq_len, hidden_size]; the output has the same shape.

class EncoderLayer(nn.Module):
    """One post-norm encoder layer: LayerNorm(x + Sublayer(x)) for each sublayer.

    Each sublayer has its own LayerNorm, so the two normalisations learn
    independent scale/shift parameters.
    """

    def __init__(self, heads, hidden_size):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(heads, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)  # after the self-attention sublayer
        self.ffn = FeedForward(hidden_size)
        self.ffn_norm = nn.LayerNorm(hidden_size)  # after the feed-forward sublayer

    def forward(self, x, enc_mask=None):
        attn = self.norm(self.attention(x, x, x, enc_mask) + x)
        # The FFN must consume the attention sublayer's output; feeding it the
        # raw input x would throw the attention result away.
        out = self.ffn_norm(self.ffn(attn) + attn)
        return out

# MultiHeadAttention is a feature extractor.
# It takes query, key and value and outputs word representations fused with contextual
# information; the output has the same shape as query.
# It covers all three attention types used in the Transformer:
#   encoder self-attention: no sequence mask, query = key = value;
#   decoder self-attention: with a sequence mask so each word only sees earlier words,
#     query = key = value;
#   encoder-decoder attention: query is the decoder input, key = value = encoder output.

class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention over ``heads`` parallel heads.

    Computes Concat(head_1, ..., head_h) W^O, where
    head_i = softmax(Q_i K_i^T / sqrt(d_k)) V_i and d_k = hidden_size // heads.
    """

    def __init__(self, heads, hidden_size):
        super(MultiHeadAttention, self).__init__()
        assert hidden_size % heads == 0, "hidden_size must be divisible by heads"
        self.hidden_size = hidden_size
        self.heads = heads
        self.wq = nn.Linear(hidden_size, hidden_size)
        self.wk = nn.Linear(hidden_size, hidden_size)
        self.wv = nn.Linear(hidden_size, hidden_size)
        # Output projection W^O that mixes the concatenated heads.
        self.wo = nn.Linear(hidden_size, hidden_size)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: [batch_size, seq2_len, hidden_size].
            key, value: [batch_size, seq1_len, hidden_size].
            mask: optional boolean tensor; True marks scores to hide (set to
                -inf before softmax). A mask whose rank differs from the 4-D
                score tensor (e.g. [batch_size, seq2_len, seq1_len]) gets a
                head axis inserted at dim 1. The caller's tensor is never
                modified. A query row whose keys are all masked yields NaN.

        Returns:
            [batch_size, seq2_len, hidden_size].
        """
        batch_size, seq2_len, _ = query.shape
        seq1_len = key.shape[1]
        d_k = self.hidden_size // self.heads
        # Split into heads: [batch_size, heads, seq_len, d_k]
        q = self.wq(query).view(batch_size, seq2_len, self.heads, d_k).permute(0, 2, 1, 3)
        k = self.wk(key).view(batch_size, seq1_len, self.heads, d_k).permute(0, 2, 1, 3)
        v = self.wv(value).view(batch_size, seq1_len, self.heads, d_k).permute(0, 2, 1, 3)
        # [batch_size, heads, seq2_len, seq1_len]
        attention = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(d_k)

        if mask is not None:
            if mask.dim() != attention.dim():
                # Out-of-place: an in-place unsqueeze_ would reshape the
                # caller's mask (and the one every later layer receives).
                mask = mask.unsqueeze(1)  # [batch_size, 1, seq2_len, seq1_len]
            attention = attention.masked_fill(mask, float('-inf'))
        score = nn.functional.softmax(attention, dim=-1)
        output = torch.matmul(score, v)  # [batch_size, heads, seq2_len, d_k]
        # [batch_size, seq2_len, heads, d_k] -> [batch_size, seq2_len, hidden_size]
        output = output.permute(0, 2, 1, 3).reshape(batch_size, seq2_len, self.hidden_size)
        return self.wo(output)

class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        model_size: input and output width.
        dropout: dropout probability applied to the hidden activations.
        ff_size: hidden width; defaults to ``4 * model_size`` (the previous
            hard-coded value).
    """

    def __init__(self, model_size, dropout=0.1, ff_size=None):
        super(FeedForward, self).__init__()
        if ff_size is None:
            ff_size = 4 * model_size
        self.linear1 = nn.Linear(model_size, ff_size)
        self.linear2 = nn.Linear(ff_size, model_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.relu(self.linear1(x))
        return self.linear2(self.dropout(hidden))


Decoder的构成与Encoder基本类似，不同的地方在于以下几个部分。
* masked_attn：在decoder部分的self-attention中，需要实现sequence mask，保证在decoder self-attention过程中，当前单词只能看到之前序列的单词，看不到之后序列的单词，以防止信息泄露。
* dec_attn：这是encoder和decoder之间进行交互的attention，query来自decoder的输入，key和value来自encoder的输出，从而能够保证decoder可以读取到encoder的相关信息。
* linear：即decoder的输出层，需要将大小为hidden_size的向量映射为大小为vocab_size的向量。这是因为decoder的输出在经过softmax处理之后，需要形成每个单词生成的概率。


In [39]:
# Because a probability for every vocabulary word must eventually be produced, the
# Decoder has an extra linear layer projecting to vocab_size.

class Decoder(nn.Module):
    """Transformer decoder: embedding + positional encoding + N layers + output projection.

    Args:
        layer: prototype decoder layer; it is deep-copied N times so each
            stacked layer owns independent parameters.
        vocab_size: size of the target vocabulary (width of the output).
        hidden_size: model width.
        N: number of stacked layers.
        dropout: dropout probability used by the positional encoding
            (defaults to 0.1, the previously hard-coded value).
    """

    def __init__(self, layer, vocab_size, hidden_size, N=6, dropout=0.1):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # deepcopy (requires the module-level `import copy`) so the N layers
        # do not share weights.
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(N)])
        self.pe = PositionalEncoding(dropout)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, enc_output, dec_mask=None):
        """Decode target token ids ``x`` attending to ``enc_output``.

        ``dec_mask`` is forwarded unchanged to every layer. Returns
        unnormalised scores of width vocab_size; no softmax is applied here.
        """
        x = self.embedding(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, enc_output, dec_mask)
        return self.linear(x)

class DecoderLayer(nn.Module):
    """One post-norm decoder layer with three sublayers.

    1. masked self-attention (causal sequence mask),
    2. encoder-decoder attention (query from the decoder, key/value from
       ``enc_output``, optionally masked by ``dec_mask``),
    3. position-wise feed-forward.

    Each sublayer is followed by a residual connection and its own LayerNorm.
    """

    def __init__(self, heads, hidden_size):
        super(DecoderLayer, self).__init__()
        self.heads = heads
        self.masked_attn = MultiHeadAttention(heads, hidden_size)
        self.dec_attn = MultiHeadAttention(heads, hidden_size)
        self.ffn = FeedForward(hidden_size)
        self.norm = nn.LayerNorm(hidden_size)  # after masked self-attention
        self.cross_norm = nn.LayerNorm(hidden_size)  # after encoder-decoder attention
        self.ffn_norm = nn.LayerNorm(hidden_size)  # after the feed-forward sublayer

    def forward(self, x, enc_output, dec_mask=None):
        # Build the causal mask on x's device so CUDA inputs do not hit a
        # device mismatch in masked_fill.
        seq_mask = self.get_seq_mask(x.shape[1], device=x.device)
        x = self.norm(self.masked_attn(x, x, x, seq_mask) + x)
        x = self.cross_norm(self.dec_attn(x, enc_output, enc_output, dec_mask) + x)
        x = self.ffn_norm(self.ffn(x) + x)
        return x

    @staticmethod
    def get_seq_mask(seq_len, device=None):
        """Return the causal mask of shape [1, seq_len, seq_len] as a bool tensor.

        Entries strictly above the diagonal are True (future positions to
        hide); the diagonal and below are False. ``device`` defaults to the
        CPU.
        """
        ones = torch.ones((1, seq_len, seq_len), device=device)
        return torch.triu(ones, diagonal=1) == 1

最后，将Encoder部分和Decoder部分整合起来，即为完整的Transformer模型架构。

In [None]:
class Transformer(nn.Module):
    """Glue an encoder and a decoder into one sequence-to-sequence model.

    ``vocab_size`` and ``hidden_size`` are accepted for API symmetry but are
    not used; the sizes live inside the supplied encoder/decoder.
    """

    def __init__(self, encoder, decoder, vocab_size, hidden_size):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, y, enc_mask=None, dec_mask=None):
        # Encode the source once, then let the decoder attend to that memory.
        memory = self.encoder(x, enc_mask)
        return self.decoder(y, memory, dec_mask)