# 模型架构

大部分序列到序列（seq2seq）模型都使用编码器-解码器结构 [(引用)](https://arxiv.org/abs/1409.0473)。编码器把一个输入序列$(x_1, \ldots, x_n)$映射为一个连续表示序列$z=(z_1, \ldots, z_n)$。给定$z$，解码器逐个时间步生成输出序列$(y_1, \ldots, y_m)$，每个时间步生成一个元素。在每一步中，模型都是自回归的[(引用)](https://arxiv.org/abs/1308.0850)：生成下一个符号时，会把先前已经生成的符号作为额外输入，一起用于预测。下面我们先依次实现各个组件（基础层、编码器/解码器块、嵌入层），最后把它们组装成一个`Transformer`类，搭建完整的seq2seq架构：

In [12]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

d_model = 512  # model width shared by the LayerNorm / embedding demo cells in this notebook

## layers

In [13]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias.

    Computes ``a * (x - mean) / (std + eps) + b``, where ``mean`` and ``std``
    are taken over the last dimension. ``torch.std`` uses the unbiased (N-1)
    estimator by default, and ``eps`` is added to the standard deviation rather
    than to the variance. Results therefore differ slightly from
    ``torch.nn.LayerNorm``, and the last dimension must have size > 1
    (otherwise the unbiased std is NaN).

    Args:
        d_model: size of the last (normalised) dimension.
        eps: small constant added to the standard deviation for stability.
    """

    def __init__(self, d_model, eps=1e-6) -> None:
        super(LayerNorm, self).__init__()

        self.a = nn.Parameter(torch.ones(d_model))   # per-feature gain
        self.b = nn.Parameter(torch.zeros(d_model))  # per-feature bias
        self.eps = eps

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Normalise ``input`` over its last dimension; the output has the same shape."""
        mean = torch.mean(input, dim=-1, keepdim=True)
        std = torch.std(input, dim=-1, keepdim=True)
        normed = (input - mean) / (std + self.eps)
        return self.a * normed + self.b
        

# Smoke test: LayerNorm keeps the (rows, d_model) shape of its input.
net = LayerNorm(d_model)
input = torch.ones((10, d_model))
output = net(input)
print(input.shape, output.shape)

torch.Size([10, 512]) torch.Size([10, 512])


In [14]:
class FeedForwardNet(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    ``nn.Linear`` acts only on the last dimension of its input, so the same
    two-layer MLP is applied independently at every sequence position, which
    is why it is called "position-wise".

    Args:
        d_model: input and output width.
        d_mid: hidden width.
        drop_prob: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, d_mid, drop_prob=0.1) -> None:
        super().__init__()

        # Built in the same order as before so seeded initialisation is unchanged.
        layers = [
            nn.Linear(d_model, d_mid),
            nn.ReLU(),
            nn.Dropout(drop_prob),
            nn.Linear(d_mid, d_model),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        projected = self.net(input)
        return projected

In [15]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries may have a different sequence length from keys/values. Decoder
    cross-attention needs this, because target-side queries attend to
    source-side keys/values.

    Args:
        d_model: model width; must be divisible by ``n_head``.
        n_head: number of heads; each head has width ``d_model // n_head``.
        prob_dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self,
                 d_model,
                 n_head,
                 prob_dropout
                 ) -> None:
        super(MultiHeadAttention, self).__init__()

        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )
        self.n_head = n_head
        self.d_head = d_model // n_head

        # Layer creation order is unchanged so seeded initialisation is identical.
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(prob_dropout)
        self.W_O = nn.Linear(d_model, d_model)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (bs, length, d_model) -> (bs, n_head, length, d_head)."""
        bs, length, _ = x.shape
        return x.reshape(bs, length, self.n_head, self.d_head).transpose(1, 2)

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: torch.Tensor = None
                ) -> torch.Tensor:
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: (bs, q_len, d_model).
            key: (bs, k_len, d_model).
            value: same shape as ``key``.
            mask: optional tensor broadcastable to (bs, n_head, q_len, k_len);
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (bs, q_len, d_model).

        Raises:
            ValueError: if ``key``/``value`` shapes differ or batch sizes mismatch.
        """
        if key.shape != value.shape:
            raise ValueError(
                f"key and value must have the same shape, got {tuple(key.shape)} "
                f"and {tuple(value.shape)}"
            )
        if query.shape[0] != key.shape[0]:
            raise ValueError(
                f"batch size mismatch: query {query.shape[0]} vs key {key.shape[0]}"
            )
        bs, q_len, _ = query.shape

        Q_heads = self._split_heads(self.W_Q(query))  # (bs, nh, q_len, dh)
        K_heads = self._split_heads(self.W_K(key))    # (bs, nh, k_len, dh)
        V_heads = self._split_heads(self.W_V(value))  # (bs, nh, k_len, dh)

        # (bs, nh, q_len, k_len)
        attention_score = torch.matmul(Q_heads, K_heads.transpose(-1, -2)) / math.sqrt(self.d_head)

        if mask is not None:
            # Use the dtype's most negative finite value; a literal like -1e10
            # cannot be represented in float16 and makes masked_fill fail.
            fill_value = torch.finfo(attention_score.dtype).min
            attention_score = attention_score.masked_fill(mask == 0, fill_value)

        attention_weight = torch.softmax(attention_score, dim=-1)
        attention_weight = self.dropout(attention_weight)

        # (bs, nh, q_len, dh) -> (bs, q_len, nh, dh) -> (bs, q_len, d_model)
        weighted_value = torch.matmul(attention_weight, V_heads).transpose(1, 2)
        weighted_value = weighted_value.reshape(bs, q_len, -1)

        return self.W_O(weighted_value)
        
    
# Smoke test: self-attention on a (batch=10, seq=30, width=128) tensor keeps its shape.
mha = MultiHeadAttention(d_model=128, n_head=8, prob_dropout=0.1)
input = torch.randn((10, 30, 128))
print(input.shape)
output = mha(query=input, key=input, value=input)
print(output.shape)


torch.Size([10, 30, 128])
torch.Size([10, 30, 128])


In [16]:
class Gnerator(nn.Module):
    """Final linear projection from model width to target-vocabulary scores.

    The (misspelled) class name is kept because ``Transformer`` refers to it.
    The returned scores are raw logits; no softmax / log-softmax is applied.

    Args:
        d_model: input width.
        vocab_size: number of output classes (target vocabulary size).
    """

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, input):
        logits = self.proj(input)
        return logits

## blocks

In [17]:
class EncoderBlock(nn.Module):
    """One post-norm Transformer encoder layer: self-attention followed by an FFN.

    Each sub-layer is wrapped as ``LayerNorm(sublayer(x) + x)``, and each
    sub-layer has its own LayerNorm (``norm1`` for attention, ``norm2`` for
    the FFN).

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        prob_dropout: dropout probability passed to attention and FFN.
        d_mid: hidden width of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, prob_dropout, d_mid) -> None:
        super(EncoderBlock, self).__init__()

        self.attention = MultiHeadAttention(d_model, n_heads, prob_dropout)
        self.norm1 = LayerNorm(d_model)

        self.ffn = FeedForwardNet(d_model, d_mid, prob_dropout)
        self.norm2 = LayerNorm(d_model)

    def forward(self, input, mask):
        """Apply the layer; ``mask`` is forwarded unchanged to self-attention."""
        # Sub-layer 1: multi-head self-attention + residual + norm1.
        _input = input
        input = self.attention(input, input, input, mask)
        input = self.norm1(input + _input)

        # Sub-layer 2: position-wise FFN + residual + its own norm (norm2).
        _input = input
        input = self.ffn(input)
        input = self.norm2(input + _input)

        return input
    
class DecoderBlock(nn.Module):
    """One post-norm Transformer decoder layer.

    It has three sub-layers, each wrapped as ``LayerNorm(sublayer(x) + x)``:
    masked self-attention over the target, cross-attention from the target
    (queries) to the encoder output (keys/values), and a position-wise FFN.

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        prob_dropout: dropout probability passed to attention and FFN.
        d_mid: hidden width of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, prob_dropout, d_mid) -> None:
        super().__init__()

        # Creation order is kept so seeded initialisation is unchanged.
        self.self_attention = MultiHeadAttention(d_model, n_heads, prob_dropout)
        self.norm1 = LayerNorm(d_model)

        self.cross_attention = MultiHeadAttention(d_model, n_heads, prob_dropout)
        self.norm2 = LayerNorm(d_model)

        self.ffn = FeedForwardNet(d_model, d_mid, prob_dropout)
        self.norm3 = LayerNorm(d_model)

    def forward(self, dec, enc, src_mask, trg_mask):
        # 1) Self-attention over the target sequence, masked by trg_mask.
        attended = self.self_attention(query=dec, key=dec, value=dec, mask=trg_mask)
        hidden = self.norm1(attended + dec)

        # 2) Cross-attention: target positions attend to the encoder output.
        attended = self.cross_attention(query=hidden, key=enc, value=enc, mask=src_mask)
        hidden = self.norm2(attended + hidden)

        # 3) Position-wise feed-forward network.
        fed = self.ffn(hidden)
        return self.norm3(fed + hidden)
        

## embedding

In [18]:
class TokenEmbedding(nn.Embedding):
    """Learned token lookup table without positional information.

    This is a thin ``nn.Embedding`` subclass with a fixed argument order.

    Args:
        vocab_size: number of rows (vocabulary size).
        d_model: width of each embedding vector.
        padding_idx: index forwarded to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(self, vocab_size, d_model, padding_idx) -> None:
        super(TokenEmbedding, self).__init__(
            num_embeddings=vocab_size,
            embedding_dim=d_model,
            padding_idx=padding_idx,
        )


In [19]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    ``PE[pos, 2i] = sin(pos / 10000^(2i/d_model))`` and
    ``PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))``.

    The table is registered as a non-persistent buffer. It therefore follows
    the module through ``.to()`` / ``.cuda()`` / ``.double()``, never receives
    gradients, and is not written to ``state_dict``. Odd ``d_model`` values
    are supported.

    Args:
        d_model: encoding width.
        max_len: number of positions the table covers.
        device: device on which the table is initially built.
    """

    def __init__(self, d_model, max_len, device) -> None:
        super(PositionalEmbedding, self).__init__()

        # (max_len, 1): one row per position.
        pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)
        # Even feature indices 0, 2, 4, ...; there are ceil(d_model / 2) of them.
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        angle = pos / (10000 ** (_2i / d_model))  # (max_len, ceil(d_model / 2))

        encoding = torch.zeros(max_len, d_model, device=device)
        encoding[:, 0::2] = torch.sin(angle)
        # The odd columns number floor(d_model / 2), one fewer than the even ones when d_model is odd.
        encoding[:, 1::2] = torch.cos(angle[:, : d_model // 2])

        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, input):
        """Return the first ``seq_len`` rows of the table, shape (seq_len, d_model).

        ``seq_len`` is ``input.size(1)``; ``input`` is expected to be
        (batch, seq_len, ...) such as token ids. The result broadcasts over
        the batch when added to (batch, seq_len, d_model) token embeddings.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = input.size(1)
        max_len = self.encoding.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )
        return self.encoding[:seq_len, :]
        
# Smoke test: the returned slice depends only on the sequence length (dim 1) of x.
net = PositionalEmbedding(d_model=d_model, max_len=256, device='cpu')
x = torch.ones((2, 30))
y = net(x)
print(x.shape, y.shape)

torch.Size([2, 30]) torch.Size([30, 512])


In [20]:
class TransformerEmbedding(nn.Module):
    """Sum of a learned token embedding and a fixed sinusoidal positional encoding.

    Args:
        vocab_size: vocabulary size for the token table.
        padding_idx: padding index forwarded to the token embedding.
        d_model: embedding width.
        max_len: number of positions covered by the positional table.
        device: device on which the positional table is built.
    """

    def __init__(self, vocab_size, padding_idx, d_model, max_len, device) -> None:
        super().__init__()
        # Token table is created first, as before, so seeded init is unchanged.
        self.tok_emb = TokenEmbedding(vocab_size, d_model, padding_idx)
        self.pos_emb = PositionalEmbedding(d_model, max_len, device)

    def forward(self, input):
        # Assumes input is (batch, seq_len) token ids, as in the demo cells.
        tokens = self.tok_emb(input)      # (batch, seq_len, d_model)
        positions = self.pos_emb(input)   # (seq_len, d_model); broadcasts over batch
        return tokens + positions
    

# Smoke test: token ids (2, 30) -> embeddings (2, 30, d_model).
net = TransformerEmbedding(vocab_size=500, padding_idx=0, d_model=d_model, max_len=256, device='cpu')
x = torch.ones((2, 30), dtype=torch.long)
y = net(x)
print(x.shape, y.shape)

torch.Size([2, 30]) torch.Size([2, 30, 512])


## transformer

In [21]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence tasks.

    ``block_num`` encoder blocks and ``block_num`` decoder blocks are stacked
    in ``nn.ModuleList``s. Their parameters therefore appear in ``state_dict``
    under ``enc_blocks.<i>.`` / ``dec_blocks.<i>.``.

    Args:
        src_vocab_size, trg_vocab_size: source / target vocabulary sizes.
        src_padding_idx, trg_padding_idx: padding token ids used for masking.
        d_model: model width.
        n_heads: number of attention heads.
        prob_dropout: dropout probability inside the blocks.
        d_mid: hidden width of each feed-forward network.
        max_len: maximum sequence length for the positional tables.
        device: device on which the positional tables are built.
        block_num: number of stacked encoder blocks (and of decoder blocks).
    """

    def __init__(self, src_vocab_size, src_padding_idx, trg_vocab_size, trg_padding_idx, d_model, n_heads, prob_dropout, d_mid, max_len, device, block_num) -> None:
        super(Transformer, self).__init__()

        self.src_padding_idx = src_padding_idx
        self.trg_padding_idx = trg_padding_idx
        self.device = device
        self.n_heads = n_heads

        self.src_emb = TransformerEmbedding(src_vocab_size, src_padding_idx, d_model, max_len, device)
        self.trg_emb = TransformerEmbedding(trg_vocab_size, trg_padding_idx, d_model, max_len, device)

        self.enc_blocks = nn.ModuleList(
            EncoderBlock(d_model, n_heads, prob_dropout, d_mid) for _ in range(block_num)
        )
        self.dec_blocks = nn.ModuleList(
            DecoderBlock(d_model, n_heads, prob_dropout, d_mid) for _ in range(block_num)
        )

        self.gen = Gnerator(d_model, trg_vocab_size)

    def forward(self, src, trg):
        """Map (bs, src_len) source ids and (bs, trg_len) target ids to
        (bs, trg_len, trg_vocab_size) logits."""
        src_mask = self.make_src_mask(src)  # (bs, nh, src_len, src_len)
        trg_mask = self.make_trg_mask(trg)  # (bs, nh, trg_len, trg_len)
        # Cross-attention scores are (bs, nh, trg_len, src_len). Every row of
        # src_mask is identical, so a single row broadcasts over all target
        # positions even when trg_len != src_len.
        cross_mask = src_mask[:, :, :1, :]

        enc = self.src_emb(src)
        for block in self.enc_blocks:
            enc = block(enc, src_mask)

        dec = self.trg_emb(trg)
        for block in self.dec_blocks:
            dec = block(dec, enc, cross_mask, trg_mask)

        return self.gen(dec)

    def make_src_mask(self, src):
        """Return a uint8 mask of shape (bs, n_heads, src_len, src_len), on src's device.

        Columns for padded source tokens are 0, so no query attends to them.
        """
        # .to(torch.uint8) keeps the tensor on src's device (ByteTensor forced CPU).
        src_mask = (src != self.src_padding_idx).unsqueeze(1).unsqueeze(1).to(torch.uint8)
        src_mask = src_mask.repeat(1, self.n_heads, src_mask.shape[-1], 1)
        return src_mask

    def make_trg_mask(self, trg):
        """Return a uint8 mask of shape (bs, n_heads, trg_len, trg_len), on trg's device.

        It combines a lower-triangular (causal) mask with a padding mask that
        zeroes the rows of padded target positions.
        """
        trg_pad_mask = (trg != self.trg_padding_idx).unsqueeze(1).unsqueeze(3)  # (bs, 1, trg_len, 1)
        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len, device=trg.device)).to(torch.uint8)
        trg_mask = trg_pad_mask & trg_sub_mask
        trg_mask = trg_mask.repeat(1, self.n_heads, 1, 1)
        return trg_mask
    

# End-to-end smoke test: two padded (pad id 0) source/target pairs -> target-vocab logits.
model = Transformer(
    src_vocab_size=500, src_padding_idx=0,
    trg_vocab_size=500, trg_padding_idx=0,
    d_model=512, n_heads=8, prob_dropout=0.1, d_mid=1024,
    max_len=30, device='cpu', block_num=2,
)

x = torch.tensor([[1, 2, 3, 0], [1, 2, 0, 0]])  # source token ids
y = torch.tensor([[1, 2, 3, 4], [1, 2, 4, 0]])  # target token ids
y_ = model(x, y)

In [22]:
y_.shape  # notebook cell output: (batch=2, trg_len=4, trg_vocab_size=500)

torch.Size([2, 4, 500])