In [3]:
import torch
import torch.nn as nn
import math
import numpy as np
import torch.optim as optim
import torch.utils.data as data

In [5]:
# Toy parallel corpus; each row is [ enc_input, dec_input, dec_output ].
# S = start-of-sequence, E = end-of-sequence, P = padding (pads every sentence to 5 tokens).
sentences = [ [ '我 是 学 生 P', 'S I am a student', 'I am a student E' ],
              [ '我 喜 欢 学 习', 'S I like learning P', 'I like learning P E' ],
              [ '我 是 女 生 P', 'S I am a girl', 'I am a girl E' ] ]

src_vocab = { 'P': 0, '我': 1, '是': 2, '学': 3, '生': 4, '喜': 5, '欢': 6, '习': 7, '女': 8 }
src_i2w = { idx: word for word, idx in src_vocab.items() }      # index -> source token
src_size = len( src_vocab )                                     # source vocabulary size

tgt_vocab = { 'P': 0, 'S': 1, 'E': 2, 'I': 3, 'am': 4, 'a': 5, 'student': 6, 'like': 7, 'learning': 8, 'girl': 9 }
tgt_i2w = { idx: word for word, idx in tgt_vocab.items() }      # index -> target token
tgt_size = len( tgt_vocab )                                     # target vocabulary size

src_len = len( sentences[0][0].split() )                        # encoder input length (all rows padded to it)
tgt_len = len( sentences[0][1].split() )                        # decoder input/output length (all rows padded to it)

# Every sentence must already be padded to these lengths so the rows can be stacked into tensors.
assert all( len( row[0].split() ) == src_len for row in sentences ), 'encoder inputs must all have src_len tokens'
assert all( len( row[1].split() ) == tgt_len and len( row[2].split() ) == tgt_len for row in sentences ), \
    'decoder inputs/outputs must all have tgt_len tokens'

将 sentences 中的每个句子按词表（字典）转换为索引序列

In [8]:
def make_dict( pairs = None, src_map = None, tgt_map = None ):
    """Convert sentence triples into index tensors.

    Args:
        pairs: sequence of [enc_input, dec_input, dec_output] strings whose tokens
            are separated by whitespace. Defaults to the module-level ``sentences``.
        src_map: source token -> index mapping. Defaults to ``src_vocab``.
        tgt_map: target token -> index mapping. Defaults to ``tgt_vocab``.

    Returns:
        Tuple ``(enc_inputs, dec_inputs, dec_outputs)`` of LongTensors, each of
        shape [num_sentences, seq_len] (rows must already be padded to equal length).

    Raises:
        KeyError: if a token is missing from its vocabulary.
    """
    # Resolve defaults at call time so the globals are only needed when no override is passed.
    if pairs is None:
        pairs = sentences
    if src_map is None:
        src_map = src_vocab
    if tgt_map is None:
        tgt_map = tgt_vocab

    # One flat index list per sentence -> 2-D tensors (no extra singleton dimension).
    enc_inputs = [ [ src_map[w] for w in row[0].split() ] for row in pairs ]
    dec_inputs = [ [ tgt_map[w] for w in row[1].split() ] for row in pairs ]
    dec_outputs = [ [ tgt_map[w] for w in row[2].split() ] for row in pairs ]
    return torch.LongTensor( enc_inputs ), torch.LongTensor( dec_inputs ), torch.LongTensor( dec_outputs )
# enc_inputs, dec_inputs, dec_outputs

([[1, 2, 3, 4, 0], [1, 5, 6, 3, 7], [1, 2, 8, 4, 0]],
 [[1, 3, 4, 5, 6], [1, 3, 7, 8, 0], [1, 3, 4, 5, 9]],
 [[3, 4, 5, 6, 2], [3, 7, 8, 0, 2], [3, 4, 5, 9, 2]])

In [9]:
class MDataset( data.Dataset ):
    """Map-style dataset that zips three aligned tensors.

    Indexing returns the tuple (enc_inputs[idx], dec_inputs[idx], dec_outputs[idx]);
    the length is the size of the first axis of ``enc_inputs``.
    """

    def __init__( self, enc_inputs, dec_inputs, dec_outputs ):
        super().__init__()
        # Keep references to the caller's tensors; nothing is copied.
        self.enc_inputs, self.dec_inputs, self.dec_outputs = enc_inputs, dec_inputs, dec_outputs

    def __getitem__( self, idx ):
        columns = ( self.enc_inputs, self.dec_inputs, self.dec_outputs )
        return tuple( column[idx] for column in columns )

    def __len__( self ):
        num_rows = self.enc_inputs.shape[0]
        return num_rows

In [ ]:
d_model = 512               # token embedding dimension
d_ff = 2048                 # hidden size of the position-wise feed-forward layer
d_q = 64                    # per-head dimension of Q and K (they must match for Q @ K^T)
d_k = d_q                   # alias under the name used in the markdown notes
d_v = 64                    # per-head dimension of V
n_layers = 6                # number of encoder layers and of decoder layers
n_heads = 8                 # number of attention heads in Multi-Head Attention

**位置编码**
pos_bas[1:, 0::2] 表示从第1行开始（跳过第0行，即位置 0 不做 sin/cos 变换），取每一行的偶数列（从第0列开始，步长为2）
pos_bas[1:, 1::2] 表示从第1行开始，取每一行的奇数列（从第1列开始，步长为2）
偶数维度：$PE_{(pos, 2i)} = \sin\left( pos / 10000^{2i/d_{model}} \right)$
奇数维度：$PE_{(pos, 2i+1)} = \cos\left( pos / 10000^{2i/d_{model}} \right)$

enc_inputs -> 输入张量，形状为 [ batch_size, seq_len, d_model ]，其中 seq_len 是每个样本的序列长度
enc_inputs.size(1) -> 获取 enc_inputs 的序列长度 seq_len
pos_bas -> [ max_len, d_model ]，即 [ 最大序列长度, 模型维度 ]

In [ ]:
class PositionalEncoding( nn.Module ):                                                  # positional encoding
    """Add fixed sinusoidal position encodings to a batch of embeddings, then apply dropout.

    For pos >= 1 the table holds
        pos_bas[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        pos_bas[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
    Row 0 (position 0) is skipped and stays all zeros.

    Args:
        d_model: embedding dimension (last axis of the input).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions pre-computed in the table.
    """
    def __init__( self, d_model, dropout = 0.1, max_len = 5000 ):
        super( PositionalEncoding, self ).__init__()
        positions = np.arange( max_len )[ :, None ]                       # [max_len, 1]
        # Columns 2i and 2i+1 share the frequency 1 / 10000^(2i / d_model).
        pair_index = np.arange( d_model ) // 2
        div_term = np.power( 10000.0, 2 * pair_index / d_model )         # [d_model]
        pos_bas = positions / div_term                                   # [max_len, d_model]
        pos_bas[ 1:, 0::2 ] = np.sin( pos_bas[ 1:, 0::2 ] )
        pos_bas[ 1:, 1::2 ] = np.cos( pos_bas[ 1:, 1::2 ] )
        # Row 0 equals 0 / div_term == 0 everywhere and is intentionally left untouched.
        # Registered as a buffer so it follows .to(device) but is not trained;
        # persistent=False keeps it out of state_dict, as before.
        self.register_buffer( 'pos_bas', torch.as_tensor( pos_bas, dtype = torch.float32 ), persistent = False )
        self.max_len = max_len

        self.dropout = nn.Dropout( p = dropout )

    def forward( self, enc_inputs ):
        """Return dropout(enc_inputs + pos_bas[:seq_len]) as a new tensor.

        Args:
            enc_inputs: [batch_size, seq_len, d_model] with seq_len <= max_len.

        Raises:
            ValueError: if seq_len exceeds max_len.
        """
        seq_len = enc_inputs.size( 1 )
        if seq_len > self.max_len:
            raise ValueError( f'sequence length {seq_len} exceeds max_len {self.max_len}' )
        # Out-of-place add: the caller's tensor is not modified.
        return self.dropout( enc_inputs + self.pos_bas[ :seq_len, : ] )
    

input_Q/K/V -> 分别经过线性变换 W_Q/W_K/W_V -> 用 view 重塑为 [ batch_size, seq_len, n_heads, d_k ]（view 中的 -1 表示 seq_len 这一维由其余维度自动推断）-> transpose(1, 2) 交换第 1 维和第 2 维（从 0 开始计数），得到 [ batch_size, n_heads, seq_len, d_k ]

将注意力掩码 attn_mask 用 unsqueeze(1) 在第 1 维（从 0 开始计数）插入一个新维度，并沿该维度重复 n_heads 次，以匹配注意力分数张量的形状 -> [batch_size, n_heads, seq_len, seq_len]

In [ ]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention followed by residual add and LayerNorm.

    Dimensions default to the notebook hyperparameters (d_model, d_q, d_v, n_heads)
    and can be overridden per instance.
    """
    def __init__( self, model_dim = None, key_dim = None, value_dim = None, heads = None ):
        super( MultiHeadAttention, self ).__init__()
        # Resolve defaults lazily so the globals are only consulted when no override is given.
        self.model_dim = d_model if model_dim is None else model_dim
        self.d_k = d_q if key_dim is None else key_dim
        self.d_v = d_v if value_dim is None else value_dim
        self.n_heads = n_heads if heads is None else heads

        self.W_Q = nn.Linear( self.model_dim, self.d_k * self.n_heads, bias = False )
        self.W_K = nn.Linear( self.model_dim, self.d_k * self.n_heads, bias = False )
        self.W_V = nn.Linear( self.model_dim, self.d_v * self.n_heads, bias = False )
        self.fc = nn.Linear( self.d_v * self.n_heads, self.model_dim, bias = False )
        # Created once here (not inside forward) so its parameters are registered and trained.
        self.layer_norm = nn.LayerNorm( self.model_dim )

    def forward( self, input_Q, input_K, input_V, attn_mask ):
        """Attend from input_Q over input_K / input_V.

        Args:
            input_Q: [batch_size, len_q, model_dim]
            input_K: [batch_size, len_k, model_dim]
            input_V: [batch_size, len_k, model_dim]
            attn_mask: [batch_size, len_q, len_k]; True (nonzero) entries are masked out.

        Returns:
            output: [batch_size, len_q, model_dim], LayerNorm(attention + input_Q).
            attn: [batch_size, n_heads, len_q, len_k] attention weights.
        """
        residual, batch_size = input_Q, input_Q.size( 0 )
        # [B, L, model_dim] -> [B, L, n_heads, d_head] -> [B, n_heads, L, d_head]
        Q = self.W_Q( input_Q ).view( batch_size, -1, self.n_heads, self.d_k ).transpose( 1, 2 )
        K = self.W_K( input_K ).view( batch_size, -1, self.n_heads, self.d_k ).transpose( 1, 2 )
        V = self.W_V( input_V ).view( batch_size, -1, self.n_heads, self.d_v ).transpose( 1, 2 )

        # [B, len_q, len_k] -> [B, n_heads, len_q, len_k]; bool() so masked_fill accepts it.
        attn_mask = attn_mask.bool().unsqueeze( 1 ).repeat( 1, self.n_heads, 1, 1 )

        scores = torch.matmul( Q, K.transpose( -1, -2 ) ) / math.sqrt( self.d_k )
        # Large negative value (not -inf) so a fully-masked row yields a uniform softmax, not NaN.
        scores = scores.masked_fill( attn_mask, -1e9 )
        attn = torch.softmax( scores, dim = -1 )
        context = torch.matmul( attn, V )                                # [B, n_heads, len_q, d_v]
        context = context.transpose( 1, 2 ).reshape( batch_size, -1, self.n_heads * self.d_v )
        output = self.fc( context )                                      # [B, len_q, model_dim]
        return self.layer_norm( output + residual ), attn

In [None]:
class Transformer( nn.Module ):
    """Encoder-decoder model: encode the source, decode the target prefix, and
    project decoder states to target-vocabulary logits."""
    def __init__( self ):
        super( Transformer, self ).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.projection = nn.Linear( d_model, tgt_size, bias = False )   # d_model -> tgt vocabulary logits

    def forward( self, enc_input, dec_input ):
        """Run the full model.

        Args:
            enc_input: source token indices (presumably [batch_size, src_len] — TODO confirm).
            dec_input: target-side input token indices (presumably [batch_size, tgt_len]).

        Returns:
            Tuple (logits, enc_self_attns, dec_self_attns, dec_enc_attns), where logits
            are the projected decoder outputs flattened to [-1, tgt_size].
        """
        enc_output, enc_self_attns = self.encoder( enc_input )
        # The submodule is stored as `self.decoder` (lowercase); `self.Decoder` does not exist.
        dec_output, dec_self_attns, dec_enc_attns = self.decoder( dec_input, enc_input, enc_output )
        dec_logits = self.projection( dec_output )
        return dec_logits.view( -1, dec_logits.size( -1 ) ), enc_self_attns, dec_self_attns, dec_enc_attns

In [None]:
# NOTE(review): this draft Encoder is disabled by wrapping it in a string literal, so running
# this cell defines no `Encoder` class (Jupyter just echoes the string as the cell output).
# Issues to fix before enabling it:
#   - `PositionalEmbedding` is not defined in this notebook; the class defined above is
#     `PositionalEncoding`.
#   - `range(len())` raises TypeError; presumably `range(n_layers)` was intended — TODO confirm.
#   - `EncoderLayer` is not defined in the visible notebook, and there is no `forward` yet.
#   - `Transformer.forward` unpacks two values from `self.encoder(...)`, so `forward` must
#     return (enc_outputs, enc_self_attns).
'''
class Encoder( nn.Module ):
    def __init__( self ):
        super( Encoder, self ).__init__()
        self.src_emb = nn.Embedding( src_size, d_model )                # 源语言的vocabulary size
        self.pos_emb = PositionalEmbedding( d_model )
        self.layers = nn.ModuleList( [ EncoderLayer() for i in range(len())])
'''

In [None]:
# NOTE(review): this draft Decoder is disabled by wrapping it in a string literal, so running
# this cell defines no `Decoder` class; `Transformer()` raises NameError unless `Decoder` is
# defined elsewhere. The draft only creates the target embedding (tgt_size x d_model);
# positional encoding, decoder layers and `forward` are still missing. `Transformer.forward`
# calls `self.decoder(dec_input, enc_input, enc_output)` and unpacks three return values.
'''
class Decoder( nn.Module ):
    def __init__( self ):
        super( Decoder, self ).__init__()
        self.tgt_emb = nn.Embedding( tgt_size, d_model )                # 目标语言的vocabulary size
'''