In [99]:
'''
  code by Tae Hwan Jung(Jeff Jung) @graykode, Derek Miller @dmmiller612, modify by wmathor
  Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
              https://github.com/JayParks/transformer
'''
import math
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

# S: Symbol that marks the start of the decoder input
# E: Symbol that marks the end of the decoder output
# P: Padding symbol used to fill a sequence that is shorter than the maximum length
sentences = [
        # input                   target
        ['ich mochte ein bier . P', 'i want a beer .'],
        ['ich mochte ein cola . P', 'i want a coke .']
]

# The padding index must be zero: the attention pad mask and the loss both treat index 0 as padding.
src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5, 'S' : 6, 'E' : 7, '.' : 8 }
src_vocab_size = len(src_vocab)

tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
# Invert the mapping through the stored indices instead of enumerate(), so the
# lookup stays correct even if the insertion order ever differs from the index values.
idx2word = {idx: word for word, idx in tgt_vocab.items()}
tgt_vocab_size = len(tgt_vocab)

# Sequence lengths are derived from the data so they cannot drift from it.
src_len = max(len(pair[0].split()) for pair in sentences)      # enc_input max sequence length (6 tokens incl. padding)
tgt_len = max(len(pair[1].split()) for pair in sentences) + 1  # dec_input(=dec_output) max sequence length (+1 for 'S' / 'E')

In [100]:
def make_data(sentences):
    """Turn [source, target] sentence pairs into index tensors for teacher forcing.

    Tokens are split on whitespace and looked up in the module-level
    `src_vocab` / `tgt_vocab` dictionaries.

    Args:
        sentences: sequence of pairs; element 0 is the source sentence and
            element 1 is the target sentence. All sources (and all targets)
            must have the same token count so the rows stack into a tensor.

    Returns:
        (enc_inputs, dec_inputs, dec_outputs) as LongTensors:
          enc_inputs  - source token ids,          e.g. [1, 2, 3, 4, 8, 0]
          dec_inputs  - 'S' followed by target ids, e.g. [6, 1, 2, 3, 4, 8]
          dec_outputs - target ids followed by 'E', e.g. [1, 2, 3, 4, 8, 7]
    """
    enc_rows, dec_in_rows, dec_out_rows = [], [], []
    for pair in sentences:
        source_ids = [src_vocab[token] for token in pair[0].split()]
        target_ids = [tgt_vocab[token] for token in pair[1].split()]

        enc_rows.append(source_ids)
        dec_in_rows.append([tgt_vocab['S']] + target_ids)   # shifted right by the start symbol
        dec_out_rows.append(target_ids + [tgt_vocab['E']])  # terminated by the end symbol

    return torch.LongTensor(enc_rows), torch.LongTensor(dec_in_rows), torch.LongTensor(dec_out_rows)

# Index the whole toy corpus once; each tensor has one row per sentence pair.
enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

class MyDataSet(Data.Dataset):
    """Map-style dataset over three row-aligned tensors.

    Item `idx` is the tuple (enc_inputs[idx], dec_inputs[idx], dec_outputs[idx]);
    the dataset length is the number of rows in `enc_inputs`.
    """

    def __init__(self, enc_inputs, dec_inputs, dec_outputs):
        super().__init__()
        self.enc_inputs, self.dec_inputs, self.dec_outputs = enc_inputs, dec_inputs, dec_outputs

    def __len__(self):
        # The encoder inputs define how many examples there are.
        return self.enc_inputs.shape[0]

    def __getitem__(self, idx):
        aligned = (self.enc_inputs, self.dec_inputs, self.dec_outputs)
        return tuple(rows[idx] for rows in aligned)

# Shuffled mini-batches of two sentence pairs (the whole toy corpus per batch).
loader = Data.DataLoader(
    dataset=MyDataSet(enc_inputs, dec_inputs, dec_outputs),
    batch_size=2,
    shuffle=True,
)

In [101]:
# Sanity check: show the first example and the batch shapes of the first batch only.
for batch in loader:
    enc_batch, dec_batch = batch[0], batch[1]
    print(enc_batch[0], dec_batch[0])
    print(enc_batch.shape, dec_batch.shape)
    break

tensor([1, 2, 3, 4, 8, 0]) tensor([6, 1, 2, 3, 4, 8])
torch.Size([2, 6]) torch.Size([2, 6])


In [102]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to a sequence-first input.

    Even feature columns hold sin(position * freq) and odd feature columns hold
    cos(position * freq), with freq = 10000^(-2i / d_model). The table is stored
    as the non-trainable buffer `pe` of shape [max_len, 1, d_model].

    Args:
        d_model: feature dimension of the input (even or odd).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions pre-computed in the table.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # [ceil(d_model / 2)]
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one cosine column fewer than sine columns,
        # so trim the angles to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(1)  # [max_len, 1, d_model] so it broadcasts over the batch axis
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]
        returns: tensor of the same shape with positions added and dropout applied
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


In [103]:
# max_len_ = 10
# d_model_ = 6
# pe = torch.zeros(max_len_, d_model_)
# position = torch.arange(0, max_len_, dtype=torch.float).unsqueeze(1)
# print('position', position.shape)
# div_term = torch.exp(torch.arange(0, d_model_, 2).float() * (-math.log(10000.0) / d_model_))
# print('div_term', div_term.shape)

# pe[:, 0::2] = torch.sin(position * div_term)
# pe[:, 1::2] = torch.cos(position * div_term)
# print('pe', pe.shape)
# print('pe', pe[:, 0])
# print('pe', pe[:, 1])
# x = torch.ones((2, 4, d_model_))
# print(x)
# pe_x = pe[:4, :]
# x = x + pe_x
# print(x)

In [104]:
def get_attn_pad_mask(seq_q, seq_k, pad_idx=0):
    '''
    Build the attention mask that hides padding positions of the key sequence.

    seq_q: [batch_size, len_q] token ids; only its length is used
    seq_k: [batch_size, len_k] token ids whose padding positions are masked
    pad_idx: token id that represents padding (defaults to 0)
    len_q and len_k may differ (e.g. decoder queries over encoder keys)

    returns: bool tensor [batch_size, len_q, len_k]; True marks a key position
             that is padding and must NOT be attended to
    '''
    _, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # True where the key token is padding; the same row applies to every query position.
    pad_attn_mask = seq_k.eq(pad_idx).unsqueeze(1)  # [batch_size, 1, len_k]
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

def get_attn_subsequence_mask(seq):
    '''
    Build the causal ("no peeking at the future") mask for decoder self-attention.

    seq: [batch_size, tgt_len]; only its shape and device are used
    returns: uint8 tensor [batch_size, tgt_len, tgt_len] on seq's device, where
             entry [b, i, j] is 1 when j > i (position j lies in the future of i)
    '''
    attn_shape = (seq.size(0), seq.size(1), seq.size(1))  # [batch_size, tgt_len, tgt_len]
    # Built directly in torch on the input's device, so the mask can be combined
    # with other masks without a NumPy round trip or a device mismatch.
    ones = torch.ones(attn_shape, dtype=torch.uint8, device=seq.device)
    subsequence_mask = torch.triu(ones, diagonal=1)  # strictly upper triangular per batch item
    return subsequence_mask  # [batch_size, tgt_len, tgt_len]

In [105]:
# get_attn_subsequence_mask
# batch_size = 2
# seq = torch.ones((batch_size, 3))
# asm = get_attn_subsequence_mask(seq)
# print(asm)

# console:
# tensor([[[0, 1, 1],
#          [0, 0, 1],
#          [0, 0, 0]],

#         [[0, 1, 1],
#          [0, 0, 1],
#          [0, 0, 0]]], dtype=torch.uint8)

In [106]:
# batch_size = 4
# seq_q = torch.ones((batch_size, 3))
# seq_q[:, -1] = 0
# print('seq_q', seq_q.shape)
# print(seq_q)
# seq_k = torch.ones((batch_size, 4))
# seq_k[:, -1] = 0
# print('seq_k', seq_k.shape)
# print(seq_k)
# apm = get_attn_pad_mask(seq_q, seq_k)
# print('apm', apm.shape)
# print(apm)

# console:
# seq_q torch.Size([4, 3])
# tensor([[1., 1., 0.],
#         [1., 1., 0.],
#         [1., 1., 0.],
#         [1., 1., 0.]])
# seq_k torch.Size([4, 4])
# tensor([[1., 1., 1., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 0.]])
# apm torch.Size([4, 3, 4])
# tensor([[[False, False, False,  True],
#          [False, False, False,  True],
#          [False, False, False,  True]],

#         [[False, False, False,  True],
#          [False, False, False,  True],
#          [False, False, False,  True]],

#         [[False, False, False,  True],
#          [False, False, False,  True],
#          [False, False, False,  True]],

#         [[False, False, False,  True],
#          [False, False, False,  True],
#          [False, False, False,  True]]])

In [107]:
class ScaledDotProductAttention(nn.Module):
    """Computes softmax(Q K^T / sqrt(d_k)) V with masked positions suppressed."""

    def __init__(self, d_k):
        super().__init__()
        self.d_k = d_k  # per-head key dimension, used only for the score scaling

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: bool, [batch_size, n_heads, len_q, len_k]; True = do not attend
        returns: (context [batch_size, n_heads, len_q, d_v],
                  attn    [batch_size, n_heads, len_q, len_k])
        '''
        raw_scores = (Q @ K.transpose(-1, -2)) / np.sqrt(self.d_k)  # [batch_size, n_heads, len_q, len_k]
        # A very negative score turns into a ~0 weight after the softmax.
        scores = raw_scores.masked_fill(attn_mask, -1e9)

        attn = scores.softmax(dim=-1)
        context = attn @ V  # [batch_size, n_heads, len_q, d_v]
        return context, attn

class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by an output projection, residual add and LayerNorm.

    Args:
        d_model: model (input and output) feature dimension.
        d_k: per-head dimension of queries and keys.
        d_v: per-head dimension of values (may differ from d_k).
        n_heads: number of attention heads.
    """
    def __init__(self, d_model, d_k, d_v, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.scaled_dot_attn = ScaledDotProductAttention(d_k)
        self.layer_norm = nn.LayerNorm(d_model)
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: bool, [batch_size, len_q, len_k]; True = do not attend
        returns: (output [batch_size, len_q, d_model],
                  attn   [batch_size, n_heads, len_q, len_k])
        '''
        residual, batch_size = input_Q, input_Q.size(0)
        # The projection matrices of all heads are stored in one Linear layer and
        # applied together; the result is then split into heads (an implementation trick).
        # B: batch_size, S: seq_len, D: dim
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, Head, W) -trans-> (B, Head, S, W)
        #        linear projection          split into heads

        Q = self.W_Q(input_Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # Q: [batch_size, n_heads, len_q, d_k]
        K = self.W_K(input_K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # K: [batch_size, n_heads, len_k, d_k]
        V = self.W_V(input_V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2)  # V: [batch_size, n_heads, len_v(=len_k), d_v]

        # The same mask applies to every head; expand() shares memory instead of copying it per head.
        attn_mask = attn_mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1) # attn_mask : [batch_size, n_heads, len_q, len_k]

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = self.scaled_dot_attn(Q, K, V, attn_mask)
        # Merge the heads back together. The merged width is n_heads * d_v (the value
        # dimension), not n_heads * d_k; using d_k breaks as soon as d_k != d_v.
        context = context.transpose(1, 2).reshape(batch_size, -1, self.n_heads * self.d_v) # context: [batch_size, len_q, n_heads * d_v]
        output = self.fc(context) # [batch_size, len_q, d_model]
        return self.layer_norm(output + residual), attn # [batch_size, len_q, d_model]

In [108]:
def count_parameters(model):
    """Print the number of trainable scalar parameters of `model`.

    Only parameters with `requires_grad=True` are counted. Nothing is returned.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    params = sum(map(lambda p: p.numel(), trainable))
    print(f'The model has {params:,} trainable parameters')

In [109]:
# batch_size, seq_len, d_model, d_k, d_v, n_heads = 4, 3, 6, 8, 8, 2
# enc_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
# q = torch.randn(batch_size, seq_len, d_model)
# k = torch.randn(batch_size, seq_len, d_model)
# v = torch.randn(batch_size, seq_len, d_model)
# attn_mask = torch.ones(batch_size, seq_len, seq_len)
# attn_mask[:, :, -1] = 0
# outputs, attn = enc_self_attn(q, k, v, attn_mask)
# print(outputs.shape, attn.shape) # [batch_size, len_q, d_model] [batch_size, n_heads, len_q, len_k]


In [110]:
# batch_size, seq_len, d_model, d_k, d_v, n_heads = 4, 3, 512, 64, 64, 8
# enc_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
# count_parameters(enc_self_attn)
# torch.save(enc_self_attn.state_dict(), 'enc_self_attn.pt')

In [111]:
# implement layer norm with numpy
# class LayerNorm:
#     def __init__(self, d_model, eps=1e-6):
#         self.a_2 = np.ones(d_model)
#         self.b_2 = np.zeros(d_model)
#         self.eps = eps
#     def __call__(self, x):
#         '''
#         x: [batch_size, seq_len, d_model]
#         '''
#         mean = np.mean(x, axis=-1, keepdims=True)
#         std = np.std(x, axis=-1, keepdims=True)
#         return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
    
# layer_norm = LayerNorm(3)
# x = np.random.randn(1, 2, 3)
# print(layer_norm(x))

# console:
# [[[ 0.68685879 -1.41401887  0.72716008]
#   [ 1.35468178 -0.32574514 -1.02893663]]]


In [112]:
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Linear, plus residual and LayerNorm.

    Args:
        d_model: input and output feature dimension.
        d_ff: hidden dimension of the inner layer.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        widen = nn.Linear(d_model, d_ff, bias=False)
        narrow = nn.Linear(d_ff, d_model, bias=False)
        self.fc = nn.Sequential(widen, nn.ReLU(), narrow)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]
        returns: [batch_size, seq_len, d_model]
        '''
        transformed = self.fc(inputs)
        # Residual connection around the feed-forward stack, then normalisation.
        return self.layer_norm(transformed + inputs)


In [113]:

class EncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention followed by the position-wise feed-forward net."""

    def __init__(self, d_model, d_k, d_v, n_heads, d_ff):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        returns: (enc_outputs [batch_size, src_len, d_model],
                  attn        [batch_size, n_heads, src_len, src_len])
        '''
        # Self-attention: the same tensor serves as query, key and value.
        attended, attn = self.enc_self_attn(
            enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask
        )
        return self.pos_ffn(attended), attn


class Encoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of `n_layers` encoder layers."""

    def __init__(self, d_model, n_layers, src_vocab_size, d_k, d_v, n_heads, d_ff):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        stack = []
        for _ in range(n_layers):
            stack.append(EncoderLayer(d_model, d_k, d_v, n_heads, d_ff))
        self.layers = nn.ModuleList(stack)

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        returns: (outputs [batch_size, src_len, d_model],
                  list with one attention map per layer,
                  each [batch_size, n_heads, src_len, src_len])
        '''
        token_vectors = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]
        # The positional encoding works sequence-first, so transpose in and back out.
        hidden = self.pos_emb(token_vectors.transpose(0, 1)).transpose(0, 1)  # [batch_size, src_len, d_model]
        pad_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]

        attn_per_layer = []
        for enc_layer in self.layers:
            hidden, layer_attn = enc_layer(hidden, pad_mask)
            attn_per_layer.append(layer_attn)
        return hidden, attn_per_layer

In [114]:
# batch_size, seq_len, d_model, n_layers, src_vocab_size, d_k, d_v, n_heads, d_ff = 4, 3, 12, 2, 50, 6, 6, 2, 5
# encoder = Encoder(d_model, n_layers, src_vocab_size, d_k, d_v, n_heads, d_ff)

# inputs = torch.randint(1, 50, (batch_size, seq_len))
# inputs[:, -1] = 0 # mask last token
# outputs, enc_self_attns = encoder(inputs)
# print(outputs.shape, len(enc_self_attns)) # [batch_size, src_len, d_model] n_layers

In [115]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention, then feed-forward."""

    def __init__(self, d_model, d_k, d_v, n_heads, d_ff):
        super().__init__()
        self.dec_self_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
        self.dec_enc_attn = MultiHeadAttention(d_model, d_k, d_v, n_heads)
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, dec_inputs, encoder_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        '''
        dec_inputs: [batch_size, tgt_len, d_model]
        encoder_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        returns: (dec_outputs   [batch_size, tgt_len, d_model],
                  dec_self_attn [batch_size, n_heads, tgt_len, tgt_len],
                  dec_enc_attn  [batch_size, n_heads, tgt_len, src_len])
        '''
        # Step 1: the target sequence attends to itself.
        self_attended, dec_self_attn = self.dec_self_attn(
            dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask
        )
        # Step 2: decoder states are the queries; encoder outputs provide keys and values.
        cross_attended, dec_enc_attn = self.dec_enc_attn(
            self_attended, encoder_outputs, encoder_outputs, dec_enc_attn_mask
        )
        # Step 3: position-wise feed-forward network.
        return self.pos_ffn(cross_attended), dec_self_attn, dec_enc_attn

class Decoder(nn.Module):
    """Target embedding + positional encoding followed by a stack of `n_layers` decoder layers."""

    def __init__(self, d_model, n_layers, tgt_vocab_size, d_k, d_v, n_heads, d_ff):
        super().__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        stack = []
        for _ in range(n_layers):
            stack.append(DecoderLayer(d_model, d_k, d_v, n_heads, d_ff))
        self.layers = nn.ModuleList(stack)

    def forward(self, decoder_inputs, encoder_inputs, encoder_outputs):
        '''
        decoder_inputs: [batch_size, tgt_len]
        encoder_inputs: [batch_size, src_len]
        encoder_outputs: [batch_size, src_len, d_model]
        returns: (decoder_outputs [batch_size, tgt_len, d_model],
                  list of per-layer self-attention maps    [batch_size, n_heads, tgt_len, tgt_len],
                  list of per-layer encoder-decoder maps   [batch_size, n_heads, tgt_len, src_len])
        '''
        token_vectors = self.tgt_emb(decoder_inputs)  # [batch_size, tgt_len, d_model]
        # The positional encoding works sequence-first, so transpose in and back out.
        hidden = self.pos_emb(token_vectors.transpose(0, 1)).transpose(0, 1)  # [batch_size, tgt_len, d_model]

        pad_mask = get_attn_pad_mask(decoder_inputs, decoder_inputs)  # [batch_size, tgt_len, tgt_len]
        future_mask = get_attn_subsequence_mask(decoder_inputs)       # [batch_size, tgt_len, tgt_len]
        # Add the two masks: a position is hidden when it is padding OR lies in the future.
        self_attn_mask = (pad_mask + future_mask) > 0  # [batch_size, tgt_len, tgt_len]

        # Mask for the encoder-decoder attention. Only the padding of the encoder
        # inputs matters here (they supply K and V, so padded source positions must
        # receive zero weight); decoder_inputs merely provides the query length
        # that the mask is expanded to.
        cross_attn_mask = get_attn_pad_mask(decoder_inputs, encoder_inputs)  # [batch_size, tgt_len, src_len]

        self_attn_per_layer, cross_attn_per_layer = [], []
        for dec_layer in self.layers:
            hidden, layer_self_attn, layer_cross_attn = dec_layer(
                hidden, encoder_outputs, self_attn_mask, cross_attn_mask
            )
            self_attn_per_layer.append(layer_self_attn)
            cross_attn_per_layer.append(layer_cross_attn)
        return hidden, self_attn_per_layer, cross_attn_per_layer

In [116]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with a linear projection onto the target vocabulary.

    Args:
        d_model: model feature dimension.
        n_layers: number of encoder layers and of decoder layers.
        src_vocab_size: size of the source vocabulary.
        tgt_vocab_size: size of the target vocabulary.
        d_k: per-head query/key dimension.
        d_v: per-head value dimension.
        n_heads: number of attention heads.
        d_ff: hidden dimension of the feed-forward sub-layers.
    """
    def __init__(self, d_model, n_layers, src_vocab_size, tgt_vocab_size, d_k, d_v, n_heads, d_ff):
        super(Transformer, self).__init__()
        self.encoder = Encoder(d_model, n_layers, src_vocab_size, d_k, d_v, n_heads, d_ff)
        self.decoder = Decoder(d_model, n_layers, tgt_vocab_size, d_k, d_v, n_heads, d_ff)
        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)

    def forward(self, enc_inputs, dec_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        returns: (logits flattened to [batch_size * tgt_len, tgt_vocab_size],
                  enc_self_attns, dec_self_attns, dec_enc_attns) where each
                  attention entry is a list with one tensor per layer
        '''
        # enc_outputs: [batch_size, src_len, d_model]
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        # dec_outputs: [batch_size, tgt_len, d_model]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs) # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        # Flattened so the result can be fed straight into a token-level loss.
        return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns

    def interface(self, enc_input, start_symbol, tgt_eos, max_len=100):
        '''
        Greedy decoding of a single source sentence.

        enc_input: [1, src_len] (the decoder input is built with batch size 1)
        start_symbol: target-vocabulary index fed to the decoder first
        tgt_eos: target-vocabulary index that stops decoding once it is predicted
        max_len: upper bound on the number of decoding steps, so decoding also
                 terminates when tgt_eos is never predicted
        returns: LongTensor [n_steps, 1] with the predicted token index for every
                 decoder position (the last entry is tgt_eos unless max_len was hit)

        Decoding runs in eval mode (dropout off) and without autograd; the
        module's previous train/eval mode is restored afterwards.
        '''
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                enc_outputs, _ = self.encoder(enc_input)
                dec_input = torch.empty((1, 0), dtype=enc_input.dtype, device=enc_input.device)
                step_predictions = torch.empty((0,), dtype=torch.long, device=enc_input.device)
                next_symbol = int(start_symbol)
                for _ in range(max_len):
                    # Append the most recent symbol and re-run the decoder on the whole prefix.
                    next_token = torch.tensor([[next_symbol]], dtype=enc_input.dtype, device=enc_input.device)
                    dec_input = torch.cat([dec_input, next_token], dim=-1)
                    dec_outputs, _, _ = self.decoder(dec_input, enc_input, enc_outputs)
                    projected = self.projection(dec_outputs)
                    # Arg-max token for every decoder position: [current_len]
                    step_predictions = projected.squeeze(0).max(dim=-1, keepdim=False)[1]

                    # Incremental update: only the prediction at the last position is
                    # new; it becomes the next decoder input symbol.
                    next_symbol = int(step_predictions[-1])
                    if next_symbol == tgt_eos:
                        break
        finally:
            self.train(was_training)

        # The last step already holds the prediction for every position of the final
        # decoder input, so no additional forward pass is needed.
        return step_predictions.unsqueeze(1)
            

In [117]:
# Transformer Parameters
# d_model = 512  # Embedding Size
# d_ff = 2048 # FeedForward dimension
# d_k = d_v = 64  # dimension of K(=Q), V
# n_layers = 6  # number of Encoder of Decoder Layer
# n_heads = 8  # number of heads in Multi-Head Attention

# Deliberately tiny configuration for the toy corpus.
d_model = 12   # Embedding Size
d_ff = 24      # FeedForward dimension
d_k = 6        # dimension of K(=Q)
d_v = 6        # dimension of V
n_layers = 2   # number of Encoder / Decoder layers
n_heads = 3    # number of heads in Multi-Head Attention

model = Transformer(
    d_model=d_model,
    n_layers=n_layers,
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_k=d_k,
    d_v=d_v,
    n_heads=n_heads,
    d_ff=d_ff,
)
# Targets equal to 0 (the padding index) do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.99)

In [118]:
# Report the number of trainable parameters of the model built above.
count_parameters(model)

The model has 8,052 trainable parameters


In [119]:
# Make sure dropout is active, whatever mode an earlier cell left the model in.
model.train()
for epoch in range(100):
    # Batch-local names are used on purpose: reusing enc_inputs / dec_inputs /
    # dec_outputs here would overwrite the full dataset tensors with the last batch.
    for enc_batch, dec_in_batch, dec_out_batch in loader:
        # enc_batch:     [batch_size, src_len]
        # dec_in_batch:  [batch_size, tgt_len]
        # dec_out_batch: [batch_size, tgt_len]
        # outputs: [batch_size * tgt_len, tgt_vocab_size]
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_batch, dec_in_batch)
        loss = criterion(outputs, dec_out_batch.view(-1))
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss.item()))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Epoch: 0001 loss = 2.469160
Epoch: 0002 loss = 2.361213
Epoch: 0003 loss = 2.376862
Epoch: 0004 loss = 2.334359
Epoch: 0005 loss = 2.253593
Epoch: 0006 loss = 2.423046
Epoch: 0007 loss = 2.451132
Epoch: 0008 loss = 2.291907
Epoch: 0009 loss = 2.268042
Epoch: 0010 loss = 2.138998
Epoch: 0011 loss = 2.286157
Epoch: 0012 loss = 2.173348
Epoch: 0013 loss = 2.152604
Epoch: 0014 loss = 2.155665
Epoch: 0015 loss = 2.084164
Epoch: 0016 loss = 2.089861
Epoch: 0017 loss = 2.074740
Epoch: 0018 loss = 2.066415
Epoch: 0019 loss = 1.928258
Epoch: 0020 loss = 1.877596
Epoch: 0021 loss = 1.908153
Epoch: 0022 loss = 1.879880
Epoch: 0023 loss = 1.843954
Epoch: 0024 loss = 1.838381
Epoch: 0025 loss = 1.839962
Epoch: 0026 loss = 1.776155
Epoch: 0027 loss = 1.668398
Epoch: 0028 loss = 1.692537
Epoch: 0029 loss = 1.696647
Epoch: 0030 loss = 1.620262
Epoch: 0031 loss = 1.596744
Epoch: 0032 loss = 1.504871
Epoch: 0033 loss = 1.428884
Epoch: 0034 loss = 1.472486
Epoch: 0035 loss = 1.395396
Epoch: 0036 loss = 1

In [120]:
# Translate the first sentence of one (shuffled) batch with greedy decoding.
sample_enc_inputs = next(iter(loader))[0]  # keep the dataset tensors and `_` untouched
enc_input = sample_enc_inputs[0].view(1, -1)  # [1, src_len]
predict = model.interface(enc_input, start_symbol=tgt_vocab["S"], tgt_eos=tgt_vocab["."])
# reshape(-1) instead of squeeze(): a one-token prediction would squeeze to a
# 0-d tensor, which cannot be iterated.
print(enc_input, '->', [idx2word[idx] for idx in predict.reshape(-1).tolist()])

tensor([[1, 2, 3, 4, 8, 0]]) -> ['i', 'want', 'a', 'coke', '.']
