In [128]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

## transformer 实现

### Encoder part

In [129]:
# Model hyper-parameters read (as globals) by the modules defined below.
max_len = 512          # NOTE(review): not passed to PositionEncoder in the cells shown
d_model = 128          # embedding / hidden width
d_k = d_q = 36         # per-head key/query width (must be equal for q @ k^T)
d_v = 36               # per-head value width
n_head = 8             # number of attention heads
d_fc = 256             # hidden width of the position-wise feed-forward net
n_layers = 3           # depth of the encoder and decoder stacks
vocab_size = 30        # vocabulary size used by both embeddings and the output projection

In [130]:
class PositionEncoder(nn.Module):
    """Adds fixed sinusoidal positional encodings to a sequence-first input.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: embedding width (odd widths are supported).
        dropout: dropout probability applied after the addition.
        max_len: longest sequence the pre-computed table can cover.
    """

    def __init__(self, d_model, dropout=0.0, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model, dtype=torch.float)
        # (max_len, 1): unsqueeze(1) turns the positions into a column vector
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i / d_model) for each even index 2i: the log of 10000 is
        # divided by d_model (not the log of 10000 / d_model).
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # There are floor(d_model / 2) odd columns: one fewer than even ones when d_model is odd.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        # Store as (max_len, 1, d_model) so it broadcasts over the batch axis.
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [seq_len, batch_size, embedding_size]

        Returns x plus the encodings of positions 0..seq_len-1, after dropout.
        Raises ValueError if seq_len exceeds the max_len given at construction.
        '''
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        # Slice the first seq_len positions (not index a single row).
        x = x + self.pe[:seq_len]
        return self.dropout(x)

In [131]:
def get_attn_pad_mask(q, k, pad_idx=0):
    '''Mask that hides the padding tokens of ``k`` from every query position.

    Args:
        q: token ids, (batch_size, len_q).
        k: token ids, (batch_size, len_k).
        pad_idx: token id treated as padding (default 0).
    Returns:
        bool tensor (batch_size, len_q, len_k); True where the key position is padding.
    '''
    batch_size, len_q = q.size()
    _, len_k = k.size()
    # (batch_size, 1, len_k): one flag per key position, shared by all queries
    pad = k.eq(pad_idx).unsqueeze(1)
    return pad.expand(batch_size, len_q, len_k)

In [132]:
# q = torch.tensor([[1,1,1,1,0,0], [1,1,0,0,0,0]])
# get_attn_pad_mask(q, q)

In [133]:
def get_attn_subsequence_mask(seq):
    '''Causal (look-ahead) mask for decoder self-attention.

    Args:
        seq: token ids, (batch_size, seq_len); only its shape and device are used.
    Returns:
        float tensor (batch_size, seq_len, seq_len) on seq's device, with 1.0
        strictly above the diagonal (future positions) and 0.0 elsewhere.
    '''
    batch_size, seq_len = seq.size(0), seq.size(1)
    # Build in torch on the input's device: a numpy-built mask always lives on
    # the CPU and cannot be combined with a CUDA padding mask.
    ones = torch.ones(batch_size, seq_len, seq_len, device=seq.device)
    # torch.triu applies to the last two dims of a batched tensor.
    return torch.triu(ones, diagonal=1)

In [134]:
class ScaledDotProductAttention(nn.Module):
    """Computes softmax(q @ k^T / sqrt(d_k)) @ v with optional masking."""

    def __init__(self):
        super().__init__()

    def forward(self, q, k, v, mask=None):
        '''
        Args:
            q: [bs, n_head, len_q, d_q]
            k: [bs, n_head, len_k, d_k]  (d_k must equal d_q)
            v: [bs, n_head, len_k, d_v]
            mask: bool, broadcastable to [bs, n_head, len_q, len_k]; True blocks
                that query/key pair. None disables masking.
        Returns:
            context [bs, n_head, len_q, d_v], attention weights [bs, n_head, len_q, len_k]
        '''
        # Scale by the actual key width rather than a global constant.
        attn_score = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(k.size(-1))
        if mask is not None:
            attn_score = attn_score.masked_fill(mask, -1e9)
        attn_score = F.softmax(attn_score, dim=-1)
        context = torch.matmul(attn_score, v)
        return context, attn_score

In [135]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and LayerNorm.

    Inputs are (bs, len, d_model). Returns the normalized output and the
    per-head attention weights.
    """

    def __init__(self):
        super().__init__()
        self.w_q = nn.Linear(d_model, n_head * d_q, bias=False)
        self.w_k = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_v = nn.Linear(d_model, n_head * d_v, bias=False)
        self.w_fc = nn.Linear(n_head * d_v, d_model, bias=False)
        # Registered once so its affine parameters are trained, saved in the
        # state_dict and moved by .to(device); constructing it inside forward()
        # yields a fresh, untrainable, CPU-only LayerNorm on every call.
        self.layer_norm = nn.LayerNorm(d_model)
        self.attention = ScaledDotProductAttention()

    def forward(self, input_q, input_k, input_v, attn_mask):
        '''
        Args:
            input_q: (bs, len_q, d_model); also used as the residual.
            input_k, input_v: (bs, len_k, d_model).
            attn_mask: bool (bs, len_q, len_k); True blocks a query/key pair.
        Returns:
            (LayerNorm(w_fc(context) + input_q), attention weights (bs, n_head, len_q, len_k))
        '''
        residual, bs = input_q, input_q.size(0)
        # Project, then split into heads: (bs, n_head, len, d_head)
        q = self.w_q(input_q).view(bs, -1, n_head, d_q).transpose(1, 2)
        k = self.w_k(input_k).view(bs, -1, n_head, d_k).transpose(1, 2)
        v = self.w_v(input_v).view(bs, -1, n_head, d_v).transpose(1, 2)
        # (bs, len_q, len_k) -> (bs, 1, len_q, len_k); broadcasts across heads without copying
        attn_mask = attn_mask.unsqueeze(1)
        context, attn_score = self.attention(q, k, v, attn_mask)
        # Merge heads back: (bs, len_q, n_head * d_v)
        context = context.transpose(1, 2).reshape(bs, -1, n_head * d_v)
        output = self.w_fc(context)
        return self.layer_norm(output + residual), attn_score

In [136]:
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise two-layer ReLU MLP with a residual connection and LayerNorm."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_fc, bias=False),
            nn.ReLU(),
            nn.Linear(d_fc, d_model, bias=False)
        )
        # Registered once so its parameters are learned, saved and moved with
        # the module (creating it inside forward() discards all of that).
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        '''x: (..., d_model). Returns LayerNorm(fc(x) + x) with the same shape.'''
        residual = x
        output = self.fc(x)
        return self.layer_norm(output + residual)

In [137]:
class EncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention, then the position-wise FFN."""

    def __init__(self):
        super().__init__()
        self.enc_attn = MultiHeadAttention()
        self.ffn = PoswiseFeedForwardNet()

    def forward(self, x, mask):
        '''
        Args:
            x: encoder hidden states, used as query, key and value.
            mask: bool attention mask; True marks pairs to block.
        Returns:
            (hidden states after the FFN, self-attention weights)
        '''
        attended, weights = self.enc_attn(x, x, x, mask)
        return self.ffn(attended), weights

In [138]:
class Encoder(nn.Module):
    """Token embedding + sinusoidal positions + a stack of EncoderLayer blocks."""

    def __init__(self):
        super().__init__()
        # vocab_size is the total number of tokens in the vocabulary
        self.word_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionEncoder(d_model)
        self.layers = nn.ModuleList(EncoderLayer() for _ in range(n_layers))

    def forward(self, x):
        '''
        Args:
            x: token ids, [bs, seq_len]; id 0 is masked out as padding.
        Returns:
            (final hidden states, list with one attention map per layer)
        '''
        pad_mask = get_attn_pad_mask(x, x)
        # PositionEncoder works on (seq_len, bs, d_model): swap the first two axes in and back out.
        hidden = self.pos_emb(self.word_emb(x).transpose(0, 1)).transpose(0, 1)
        attn_maps = []
        for layer in self.layers:
            hidden, attn = layer(hidden, pad_mask)
            attn_maps.append(attn)
        return hidden, attn_maps


In [139]:
# x = torch.rand(4, 24).type(torch.LongTensor)
# a = Encoder()
# print(a)
# print(a(x)[0].shape)
# print(len(a(x)[1]))

### Decoder part

In [140]:
class DecoderLayer(nn.Module):
    """Decoder block: masked self-attention, decoder-encoder attention, then the FFN."""

    def __init__(self):
        super().__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.ffn = PoswiseFeedForwardNet()

    def forward(self, dec_input, enc_output, dec_mask, enc_mask):
        '''
        Args:
            dec_input: decoder hidden states.
            enc_output: encoder hidden states (keys/values of the cross-attention).
            dec_mask: bool mask for self-attention; True blocks a pair.
            enc_mask: bool mask for decoder-encoder attention; True blocks a pair.
        Returns:
            (hidden states, self-attention weights, decoder-encoder attention weights)
        '''
        self_out, self_weights = self.dec_self_attn(dec_input, dec_input, dec_input, dec_mask)
        cross_out, cross_weights = self.dec_enc_attn(self_out, enc_output, enc_output, enc_mask)
        return self.ffn(cross_out), self_weights, cross_weights

In [141]:
class Decoder(nn.Module):
    """Target embedding + sinusoidal positions + a stack of DecoderLayer blocks."""

    def __init__(self):
        super().__init__()
        self.word_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionEncoder(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        Args:
            dec_inputs: target token ids, [bs, tgt_len].
            enc_inputs: source token ids, [bs, src_len]; only used to mask source padding.
            enc_outputs: encoder hidden states attended to by every layer.
        Returns:
            (decoder hidden states, per-layer self-attention maps,
             per-layer decoder-encoder attention maps)
        '''
        dec_outputs = self.word_emb(dec_inputs)
        # PositionEncoder works on (seq_len, bs, d_model): swap the first two axes in and back out.
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1)

        dec_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)
        # Put the causal mask on the padding mask's device before combining them.
        dec_seq_mask = get_attn_subsequence_mask(dec_inputs).to(dec_pad_mask.device)
        # A key position is blocked if it is padding OR lies in the future.
        dec_self_mask = torch.logical_or(dec_pad_mask, dec_seq_mask.bool())
        # Queries come from the target, keys from the source: hide source padding.
        dec_enc_mask = get_attn_pad_mask(dec_inputs, enc_inputs)

        dec_attn_scores, enc_attn_scores = [], []
        for layer in self.layers:
            dec_outputs, dec_attn_score, enc_attn_score = layer(
                dec_outputs, enc_outputs, dec_self_mask, dec_enc_mask
            )
            dec_attn_scores.append(dec_attn_score)
            enc_attn_scores.append(enc_attn_score)
        return dec_outputs, dec_attn_scores, enc_attn_scores
            

In [142]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer whose decoder states are projected onto the vocabulary."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.projection = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, enc_inputs, dec_inputs):
        '''
        Args:
            enc_inputs: [batch_size, src_len] source token ids.
            dec_inputs: [batch_size, tgt_len] target token ids.
        Returns:
            logits flattened to [batch_size * tgt_len, vocab_size],
            encoder self-attention maps, decoder self-attention maps,
            decoder-encoder attention maps (one entry per layer each).
        '''
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        logits = self.projection(dec_outputs)  # [batch_size, tgt_len, vocab_size]
        flat_logits = logits.view(-1, logits.size(-1))
        return flat_logits, enc_self_attns, dec_self_attns, dec_enc_attns

In [143]:
# a =Transformer()
# a