In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

以离散符号的分类任务为例：所谓离散符号，指的是编码器的输入和解码器的输出都是离散值，解码器在每个时刻都是对离散值做预测。传统的seq2seq模型将编码器的输入压缩成一个固定的表征；而注意力机制会在每个解码时刻对编码器的输出加权求和，得到一个随时刻变化的上下文向量。

## 编码器

编码器用于得到输入序列的上下文表征：它将给定的输入序列转换成一个上下文相关的新序列。

In [2]:
class Seq2SeqEncoder(nn.Module):
    """Embed source token ids and encode them with a single-layer LSTM.

    Args:
        embedding_dim: size of each source token embedding.
        hidden_size: size of the LSTM hidden state.
        source_vocab_size: number of rows in the source embedding table.
    """

    def __init__(self, embedding_dim, hidden_size, source_vocab_size):
        super().__init__()

        # Sub-modules are created in the same order as before (embedding, then
        # LSTM) so that a seeded parameter initialisation stays the same.
        self.embedding_table = nn.Embedding(
            num_embeddings=source_vocab_size,
            embedding_dim=embedding_dim,
        )
        self.lstm_layer = nn.LSTM(embedding_dim, hidden_size, batch_first=True)

    def forward(self, input_ids):
        """Encode a batch of source sequences.

        Args:
            input_ids: integer tensor of token ids, [bs, source_len].

        Returns:
            A tuple ``(output_states, final_h)``: the LSTM output at every
            source position, [bs, source_len, hidden_size], and the final
            hidden state as returned by ``nn.LSTM`` (the final cell state is
            discarded).
        """
        embedded = self.embedding_table(input_ids)  # [bs, source_len, embedding_dim]
        per_step_states, last_state = self.lstm_layer(embedded)
        # last_state is the (h_n, c_n) pair; only h_n is handed back.
        return per_step_states, last_state[0]

## 注意力机制

这里的注意力机制采用`dot-product`的方式：解码器第$t$时刻的隐状态与编码器输出序列中每个位置的状态分别做内积，得到的分数经过`softmax`归一化后即为注意力权重。

最后输出的`context`就是第$t$时刻用注意力权重对编码器输出加权求和得到的上下文向量。

In [3]:
class Seq2SeqAttentionMechanism(nn.Module):
    """Parameter-free dot-product attention of one decoder state over the encoder states."""

    def __init__(self):
        super(Seq2SeqAttentionMechanism, self).__init__()

    def forward(self, decoder_state_t, encoder_states):
        """Attend over the encoder states with a single decoder state.

        Args:
            decoder_state_t: decoder hidden state at step t, [bs, hidden_size].
            encoder_states: encoder outputs, [bs, source_len, hidden_size].

        Returns:
            A tuple ``(attn_prob, context)``: the softmax-normalised attention
            weights, [bs, source_len], and the attention-weighted sum of the
            encoder states, [bs, hidden_size].

        Raises:
            ValueError: if ``encoder_states`` is not a 3-D tensor.
        """
        if encoder_states.dim() != 3:
            raise ValueError(
                "encoder_states must be 3-D [bs, source_len, hidden_size], got shape "
                f"{tuple(encoder_states.shape)}"
            )

        # Batched matrix products give the dot-product scores and the weighted
        # sum directly, without first tiling the decoder state to
        # [bs, source_len, hidden_size] and building full element-wise products.
        query = decoder_state_t.unsqueeze(-1)  # [bs, hidden_size, 1]
        score = torch.matmul(encoder_states, query).squeeze(-1)  # [bs, source_len]
        attn_prob = F.softmax(score, dim=-1)  # [bs, source_len]

        context = torch.matmul(attn_prob.unsqueeze(1), encoder_states).squeeze(1)  # [bs, hidden_size]
        return attn_prob, context

## 解码器

`self.proj_layer`中`Linear`的输入是`hidden_size * 2`维度的，接收的是注意力机制得到的上下文向量`context`与解码器在时刻$t$的隐状态`h_t`拼接后的向量。参考文章[Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)。

`start_id`和`end_id`用于处理训练过程中的标签偏移。

训练过程调用的是`forward`函数，需要传入两个参数`shifted_target_id`和`encoder_states`，`shifted_target_id`的第一个字符一般是`start_id`，它是一个特殊字符。

在训练阶段，我们知道目标长度是多少，所以可以用`for`循环去做；在推理阶段，不清楚目标长度是多少，就可以用`while`循环去做。

在推理阶段，用的是解码器预测的输出作为下一个时刻的输入，用于之后的预测。

In [4]:
class Seq2SeqDecoder(nn.Module):
    """LSTM-cell decoder with dot-product attention over the encoder states.

    At every step the embedded input token is fed to an ``nn.LSTMCell``; the
    new hidden state attends over the encoder states, and the concatenation
    ``[context; h_t]`` is projected to ``num_classes`` logits.

    Args:
        embedding_dim: size of each target token embedding.
        hidden_size: LSTM hidden size; must equal the last dimension of the
            encoder states for the dot-product attention.
        num_classes: number of output classes of the projection layer.
        target_vocab_size: number of rows in the target embedding table.
        start_id: token id used as the first decoder input during inference.
        end_id: token id that marks the end of a decoded sequence.

    Note:
        During inference the argmax over ``num_classes`` logits is fed back into
        the target embedding table, so predicted ids must be valid indices of
        that table (``num_classes <= target_vocab_size``).
    """

    def __init__(self, embedding_dim, hidden_size, num_classes, target_vocab_size, start_id, end_id):
        super(Seq2SeqDecoder, self).__init__()

        self.lstm_cell = nn.LSTMCell(embedding_dim, hidden_size)
        # Input is the concatenation of the attention context and h_t, hence hidden_size * 2.
        self.proj_layer = nn.Linear(hidden_size * 2, num_classes)
        self.attention_mechanism = Seq2SeqAttentionMechanism()
        self.num_classes = num_classes
        self.embedding_table = nn.Embedding(target_vocab_size, embedding_dim)
        self.start_id = start_id
        self.end_id = end_id

    def _step(self, decoder_input_t, state, encoder_states):
        """Run one decoding step: LSTM cell, attention, then projection to logits."""
        # state is None on the first step, in which case LSTMCell starts from zeros.
        state = self.lstm_cell(decoder_input_t, state)
        h_t = state[0]
        attn_prob, context = self.attention_mechanism(h_t, encoder_states)
        logits_t = self.proj_layer(torch.cat((context, h_t), -1))
        return state, attn_prob, logits_t

    def forward(self, shifted_target_id, encoder_states):
        """Teacher-forced decoding, used during training.

        Args:
            shifted_target_id: integer tensor of decoder input ids,
                [bs, target_len].
            encoder_states: encoder outputs, [bs, source_len, hidden_size].

        Returns:
            A tuple ``(probs, logits)``: attention weights for every target step,
            [bs, target_len, source_len], and class logits,
            [bs, target_len, num_classes].
        """
        shifted_target = self.embedding_table(shifted_target_id)  # [bs, target_len, embedding_dim]

        bs, target_len, _ = shifted_target.shape
        _, source_len, _ = encoder_states.shape

        if target_len == 0:
            # torch.stack cannot handle an empty list; return empty results instead.
            return (
                encoder_states.new_zeros(bs, 0, source_len),
                encoder_states.new_zeros(bs, 0, self.num_classes),
            )

        state = None
        step_probs = []
        step_logits = []
        for t in range(target_len):
            decoder_input_t = shifted_target[:, t, :]  # [bs, embedding_dim]
            state, attn_prob, logits_t = self._step(decoder_input_t, state, encoder_states)
            step_probs.append(attn_prob)
            step_logits.append(logits_t)

        # Stacking keeps the results on the device and in the dtype of the
        # computation instead of writing into CPU tensors allocated up front.
        probs = torch.stack(step_probs, dim=1)
        logits = torch.stack(step_logits, dim=1)
        return probs, logits

    @torch.no_grad()
    def inference(self, encoder_states, max_len=100):
        """Greedy decoding, used at inference time.

        Each step feeds the previously predicted id back in as the next input.
        Decoding stops once every sequence in the batch has produced ``end_id``
        or after ``max_len`` steps, whichever comes first.

        Args:
            encoder_states: encoder outputs, [bs, source_len, hidden_size].
            max_len: upper bound on the number of decoding steps.

        Returns:
            Long tensor of predicted ids, [bs, steps] with ``steps <= max_len``.
            Positions after a sequence's first ``end_id`` are filled with
            ``end_id``.
        """
        bs = encoder_states.shape[0]
        device = encoder_states.device

        # nn.Embedding needs a tensor of indices, so start_id is expanded to one id per sequence.
        target_id = torch.full((bs,), self.start_id, dtype=torch.long, device=device)
        finished = torch.zeros(bs, dtype=torch.bool, device=device)
        state = None
        result = []

        while len(result) < max_len:
            decoder_input_t = self.embedding_table(target_id)
            state, _attn_prob, logits = self._step(decoder_input_t, state, encoder_states)

            target_id = torch.argmax(logits, -1)
            # Sequences that already ended keep emitting end_id as padding.
            target_id = target_id.masked_fill(finished, self.end_id)
            result.append(target_id)

            finished = finished | (target_id == self.end_id)
            if bool(finished.all()):
                break

        if not result:
            return torch.empty(bs, 0, dtype=torch.long, device=device)

        # Stack along a new time axis so batch and time stay separate dimensions.
        predited_ids = torch.stack(result, dim=1)
        return predited_ids
            

## 模型

In [5]:
class Model(nn.Module):
    """Sequence-to-sequence model: a ``Seq2SeqEncoder`` followed by a ``Seq2SeqDecoder``.

    Args:
        embedding_dim: embedding size shared by the encoder and the decoder.
        hidden_size: hidden size shared by the encoder and the decoder.
        num_classes: number of output classes of the decoder.
        source_vocab_size: size of the source embedding table.
        target_vocab_size: size of the target embedding table.
        start_id: start token id passed to the decoder.
        end_id: end token id passed to the decoder.
    """

    def __init__(self, embedding_dim, hidden_size, num_classes, 
                source_vocab_size, target_vocab_size, start_id, end_id):
        super(Model, self).__init__()
        
        self.encoder = Seq2SeqEncoder(embedding_dim, hidden_size, source_vocab_size)
        
        self.decoder = Seq2SeqDecoder(embedding_dim, hidden_size, num_classes, target_vocab_size, start_id, end_id)
    
    def forward(self, input_seqence_ids, shifted_target_id):
        """Encode the source and run the decoder's training-time forward pass.

        Args:
            input_seqence_ids: source token ids passed to the encoder.
            shifted_target_id: decoder input ids passed to the decoder.

        Returns:
            The ``(probs, logits)`` pair returned by the decoder.
        """
        # Only the per-position encoder states are used; the encoder's final
        # hidden state is not passed on to the decoder.
        encoder_states, _ = self.encoder(input_seqence_ids)

        probs, logits = self.decoder(shifted_target_id, encoder_states)
        return probs, logits
    
    def inference(self, input_seqence_ids=None):
        """Encode the source and decode with ``self.decoder.inference``.

        Args:
            input_seqence_ids: source token ids passed to the encoder. When
                omitted the method does nothing and returns ``None``, which is
                what a call without arguments has always returned.

        Returns:
            Whatever ``self.decoder.inference`` returns for the encoded source,
            or ``None`` when no input is given.
        """
        if input_seqence_ids is None:
            return None

        # No gradients are needed for decoding.
        with torch.no_grad():
            encoder_states, _ = self.encoder(input_seqence_ids)
            return self.decoder.inference(encoder_states)

In [6]:
if __name__ == "__main__":
    # Smoke test: build random ids, run one teacher-forced forward pass and
    # print the shapes of the attention weights and the logits.
    torch.manual_seed(0)  # make the random ids and the parameter init reproducible

    source_length = 3
    target_length = 4
    embedding_dim = 8
    hidden_size = 16
    num_classes = 10
    bs = 2
    start_id = end_id = 0
    source_vocab_size = 100
    target_vocab_size = 100
    # NOTE: target ids are drawn from target_vocab_size (100) while the logits
    # have num_classes (10) columns; only shapes are checked here, but a loss
    # over these targets would need the two sizes to agree.

    # Ids are created as integer (long) tensors directly, rather than being
    # concatenated with float tensors and cast afterwards.
    input_seqence_ids = torch.randint(source_vocab_size, size=(bs, source_length))

    # Targets: random tokens followed by end_id -> [bs, target_length + 1]
    target_ids = torch.randint(target_vocab_size, size=(bs, target_length))
    end_column = torch.full((bs, 1), end_id, dtype=target_ids.dtype)
    target_ids = torch.cat((target_ids, end_column), dim=1)

    # Decoder inputs are the targets shifted right by one: start_id followed by
    # every target except the last, so step t is fed target t-1 and asked to
    # predict target t.
    start_column = torch.full((bs, 1), start_id, dtype=target_ids.dtype)
    shifted_target_ids = torch.cat((start_column, target_ids[:, :-1]), dim=1)

    model = Model(embedding_dim, hidden_size, num_classes, source_vocab_size, target_vocab_size, start_id, end_id)

    probs, logits = model(input_seqence_ids, shifted_target_ids)
    print(probs.shape)
    print(logits.shape)

torch.Size([2, 5, 3])
torch.Size([2, 5, 10])
