# "NMT by Jointly Learning to Align and Translate" Implementation

original paper: https://arxiv.org/abs/1409.0473

references
* arichitecture picture: https://arxiv.org/pdf/1703.03906.pdf
* tutorial: https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb
* data source: http://www.statmt.org/wmt14/translation-task.html

In [1]:
# torch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchtext.data import Field, Iterator, BucketIterator, TabularDataset
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# others
import unicodedata
import re
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

In [2]:
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.cuda.current_device()

## Prepare data

use 50000 samples for this project

In [3]:
train_path = 'en_de_50000_train.txt'
valid_path = 'en_de_50000_valid.txt'
test_path = 'en_de_50000_test.txt'
BATCH_SIZE = 32

In [4]:
SOURCE = Field(tokenize=str.split, use_vocab=True, init_token="<s>", eos_token="</s>", lower=True, 
               include_lengths=True, batch_first=True)
TARGET = Field(tokenize=str.split, use_vocab=True, init_token="<s>", eos_token="</s>", lower=True, 
               batch_first=True)

In [5]:
train_data, valid_data, test_data = \
    TabularDataset.splits(path='data/en_de/', train=train_path, validation=valid_path, test=test_path, 
                          format='tsv', fields=[('so',SOURCE), ('ta',TARGET)])

In [6]:
SOURCE.build_vocab(train_data)
TARGET.build_vocab(train_data)

In [7]:
train_loader = BucketIterator(train_data, batch_size=BATCH_SIZE, device=DEVICE,
                              sort_key=lambda x: len(x.so), sort_within_batch=True, repeat=False)
valid_loader = BucketIterator(valid_data, batch_size=BATCH_SIZE, device=DEVICE,
                              sort_key=lambda x: len(x.so), sort_within_batch=True, repeat=False)
test_loader = BucketIterator(test_data, batch_size=BATCH_SIZE, device=DEVICE,
                              sort_key=lambda x: len(x.so), sort_within_batch=True, repeat=False)

In [8]:
for batch in train_loader:
    break

In [9]:
SOURCE.vocab.itos[1]

'<pad>'

In [10]:
batch.so[0]

tensor([[     2,      5,   2729,  ...,    881,      6,      3],
        [     2,     12,     24,  ...,    636,      6,      3],
        [     2,     14,     56,  ...,     24,      6,      3],
        ...,
        [     2,     12,   2321,  ...,      6,      3,      1],
        [     2,     94,    147,  ...,      6,      3,      1],
        [     2,    166,    435,  ...,      6,      3,      1]], device='cuda:0')

## architecture

![](./pics/encoder_decoder_att.png)

## Parameters

In [11]:
V_so = len(SOURCE.vocab)
V_ta = len(TARGET.vocab)
E = 10
H = 7

In [12]:
V_so, V_ta

(57486, 26356)

In [13]:
embed = nn.Embedding(V_so, E).cuda()
rnn = nn.GRU(E, H, 3, batch_first=True, bidirectional=True).cuda()

In [14]:
inputs, lengths = batch.so

In [15]:
inputs.size()

torch.Size([32, 35])

In [16]:
embeded = embed(inputs)

In [17]:
embeded.size()

torch.Size([32, 35, 10])

In [18]:
packed = pack_padded_sequence(embeded, lengths.tolist(), batch_first=True)

In [19]:
packed.data.size()

torch.Size([1115, 10])

In [20]:
outputs, hiddens = rnn(packed)

In [21]:
outputs

PackedSequence(data=tensor([[ 0.2349,  0.0056, -0.1244,  ...,  0.0629,  0.6755,  0.1037],
        [ 0.1816, -0.0245, -0.0812,  ...,  0.1382,  0.6588,  0.1471],
        [ 0.2105,  0.0483, -0.1101,  ...,  0.1284,  0.6618,  0.1982],
        ...,
        [ 0.5814,  0.3795,  0.0413,  ...,  0.0361,  0.2213,  0.0752],
        [ 0.5160,  0.3953,  0.0727,  ..., -0.0105,  0.2171,  0.0390],
        [ 0.4180,  0.3196, -0.0016,  ..., -0.0265,  0.2129,  0.0106]], device='cuda:0'), batch_sizes=tensor([ 32,  32,  32,  32,  32,  32,  32,  32,  32,  32,  32,  32,
         32,  32,  32,  32,  32,  32,  32,  32,  32,  32,  32,  32,
         32,  32,  32,  32,  32,  32,  32,  32,  32,  32,  27]))

In [22]:
outputs, output_lengths = pad_packed_sequence(outputs, batch_first=True)

In [23]:
outputs.size()

torch.Size([32, 35, 14])

In [24]:
hiddens = torch.cat([h for h in hiddens[-2:]], 1).unsqueeze(1)
hiddens.size()

torch.Size([32, 1, 14])

---

In [26]:
embed2 = nn.Embedding(V_ta, E).cuda()
start = torch.LongTensor([0]*BATCH_SIZE).unsqueeze(1).cuda()

In [27]:
embeded2 = embed2(start)

In [28]:
embeded2.size(), hiddens.size(), outputs.size()

(torch.Size([32, 1, 10]), torch.Size([32, 1, 14]), torch.Size([32, 35, 14]))

In [29]:
H, E

(7, 10)

In [30]:
new_H = 8

In [75]:
# paper version
attn_cat = nn.Linear(2*2*H, new_H).cuda()
v_a = nn.Parameter(torch.FloatTensor(1, new_H)).cuda()

In [76]:
a = torch.cat([hiddens.repeat(1, outputs.size(1), 1), outputs], 2)
a.size()

torch.Size([32, 35, 28])

In [77]:
attn_cat(a).size(), v_a.size()

(torch.Size([32, 35, 8]), torch.Size([1, 8]))

In [78]:
new_v_a = v_a.repeat(outputs.size(0), 1).unsqueeze(1)
new_v_a.size()

torch.Size([32, 1, 8])

In [83]:
# result: B, 1, T_x
x = new_v_a.bmm(attn_cat(a).transpose(1, 2))
x.size()

torch.Size([32, 1, 35])

In [103]:
# masking
mask = x.data.new(outputs.size(0), hiddens.size(1), outputs.size(1))
mask.size()

torch.Size([32, 1, 35])

In [104]:
mask[-1]

tensor([[-0.4154, -0.2518, -0.1025, -0.0876,  0.1673,  0.6210,  0.2234,
          0.5909,  0.3567, -0.0082,  0.3077, -0.2060,  0.1237, -0.1775,
         -0.4154, -0.2518, -0.1025, -0.0876,  0.1673,  0.6210,  0.2234,
          0.5909,  0.3567, -0.0082,  0.3077, -0.2060,  0.1237, -0.1775,
         -0.4154, -0.2518, -0.1025, -0.0876,  0.1673,  0.6210,  0.2234]], device='cuda:0')

In [105]:
v_unmask = 0
v_mask = float('-inf')

In [106]:
sizes = output_lengths.tolist()

In [107]:
mask.fill_(v_unmask)
n_context = mask.size(2)
for i, size in enumerate(sizes):
    if size < n_context:
        mask[i,:,size:] = v_mask

In [108]:
mask[-1]

tensor([[  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
           0.,   0.,   0.,   0., -inf.]], device='cuda:0')

In [110]:
x += mask

In [112]:
x[-1]

tensor([[ 9.2128e+14,  8.6409e+14,  7.5001e+14,  6.7749e+14,  7.3942e+14,
          7.8272e+14,  7.5244e+14,  6.5769e+14,  5.5612e+14,  4.6153e+14,
          3.7618e+14,  3.1627e+14,  2.4388e+14,  2.7045e+14,  2.9572e+14,
          3.3994e+14,  4.4921e+14,  4.3587e+14,  4.5809e+14,  4.3460e+14,
          4.4689e+14,  4.6073e+14,  4.3762e+14,  3.5159e+14,  2.7960e+14,
          2.0945e+14,  1.2590e+14,  6.7760e+13,  1.2335e+14,  2.7290e+14,
          3.9272e+14,  4.5216e+14,  5.5895e+14,  5.8419e+14,        -inf]], device='cuda:0')

In [113]:
weights = F.softmax(x, 2)
weights.size()

torch.Size([32, 1, 35])

In [115]:
weights[-1]

tensor([[ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]], device='cuda:0')

In [117]:
weights.bmm(outputs).size()

torch.Size([32, 1, 14])

In [68]:
# cat version
attn = nn.Linear(2*2*H, 2*H).cuda()
v_a = nn.Parameter(torch.FloatTensor(1, 2*H)).cuda()

In [69]:
a = torch.cat([hiddens.repeat(1, outputs.size(1), 1), outputs], 2)
a.size()

torch.Size([32, 35, 28])

In [70]:
v_a.repeat(outputs.size(0), 1).unsqueeze(1).size()

torch.Size([32, 1, 14])

In [71]:
attn(a).size()

torch.Size([32, 35, 14])

In [73]:
F.tanh(attn(a)).size()

torch.Size([32, 35, 14])

In [55]:
v_a.repeat(outputs.size(0), 1).unsqueeze(1).bmm(attn(a).transpose(1,2)).size()

torch.Size([32, 1, 35])

In [60]:
# general version
attn = nn.Linear(2*H, 2*H).cuda()

In [61]:
outputs.size(), hiddens.size()

(torch.Size([32, 35, 14]), torch.Size([32, 1, 14]))

In [63]:
e = attn(outputs)
e.size()

torch.Size([32, 35, 14])

In [64]:
hiddens.bmm(outputs.transpose(1, 2)).size()

torch.Size([32, 1, 35])

In [48]:
# dot
hiddens.bmm(outputs.transpose(1, 2)).size()

torch.Size([32, 1, 35])

## Encoder

In [122]:
class Encoder(nn.Module):
    def __init__(self, V_e, m_e, n_e, num_layers=1, bidrec=False):
        super(Encoder, self).__init__()
        """
        vocab_size: V_e
        embed_size: m_e
        hidden_size: n_e
        """
        self.V_e = V_e
        self.m_e = m_e
        self.n_e = n_e
        self.num_layers = num_layers
        self.bidrec = bidrec
        self.n_direct = 2 if bidrec else 1
        
        self.embed = nn.Embedding(V_e, m_e) 
        self.gru = nn.GRU(m_e, n_e, num_layers, batch_first=True, bidirectional=bidrec)
        
    def forward(self, inputs, lengths):
        """
        input: 
        - inputs: B, T_x
        - lengths: actual max length of batches
        output:
        - outputs: B, T_x, n_e
        """
        # embeded: (B, T_x, n_e)
        embeded = self.embed(inputs) 
        # packed: (B*T_x, n_e)
        packed = pack_padded_sequence(embeded, lengths.tolist(), batch_first=True) 
        # packed outputs: (B*T_x, 2*n_e)
        # hidden: (num of layers*n_direct, B, 2*n_e)
        outputs, hidden = self.gru(packed)
        # unpacked outputs: (B, T_x, 2*n_e)
        outputs, output_lengths = pad_packed_sequence(outputs, batch_first=True)
        
        # hidden bidirection: (num of layers*n_direct(0,1,2...last one), B, n_e)
        # choosen last hidden: (B, 1, 2*n_e)
        hidden = torch.cat((h for h in hidden[-self.n_direct:]), 1).unsqueeze(1)
        
        return outputs, hidden

## Attention

In [119]:
class Attention(nn.Module):
    def __init__(self, hidden_size, hidden_size2=None, method='general'):
        super(Attn, self).__init__()
        """
        hidden_size: set hidden size same as decoder hidden size which is n_d (= 2*n_e)
        hidden_size2: only for concat method, if none then is same as hidden_size (n_d)
        (in paper notation is n', https://arxiv.org/abs/1409.0473)
        methods:
        - 'dot': dot product between hidden and encoder_outputs
        - 'general': encoder_outputs through a linear layer 
        - 'concat': concat (hidden, encoder_outputs)
        - 'paper': concat + tanh
        """
        self.method = method
        self.hidden_size = hidden_size 
        self.hidden_size2 = hidden_size2 if hidden_size2 else hidden_size
        
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size) 
            # linear_weight shape: (out_f, in_f)
            # linear input: (B, *, in_f)
            # linear output: (B, *, out_f)

        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size*2, self.hidden_size2)
            self.v = nn.Parameter(torch.FloatTensor(1, self.hidden_size2))
        
        
    def forward(self, hidden, encoder_outputs, encoder_lengths=None, return_weight=False):
        """
        input:
        - hidden, previous hidden(= H): B, 1, n_d 
        - encoder_outputs, source context(= O): B, T_x, n_d
        - encoder_lengths: real lengths of encoder outputs
        - return_weight = return weights(alphas)
        output:
        - attentioned_hidden(= z): B, 1, n_d
        - weights(= w): B, 1, T_x
        """
        H, O = hidden, encoder_outputs
        # Batch(B), Seq_length(T), dimemsion(n)
        B_H, T_H, n_H = H.size()
        B_O, T_O, n_O = O.size()
        
        if B_H != B_O:
            msg = "Batch size is not correct, H: {} O: {}".format(H.size(), O.size())
            raise ValueError(msg)
        else:
            B = B_H
        
        # score: (B, 1, T_x)
        s = self.score(H, O) 
        
        # encoding masking
        if encoder_lengths is not None:
            mask = s.data.new(B, T_H, T_O) # (B, 1, T_x)
            mask = self.fill_context_mask(mask, sizes=encoder_lengths, v_mask=float('-inf'), v_unmask=0)
            s += mask
        
        # softmax: (B, 1, T_x)
        w = F.softmax(s, 2) 
        
        # attention: weight * encoder_hiddens, (B, 1, T_x) * (B, T_x, n_d) = (B, 1, n_d)
        z = w.bmm(c)
        if return_weight:
            return z, w
        return z
    
    def score(self, H, O):
        """
        inputs:
        - hiddden, previous hidden(= H): B, 1, n_d 
        - encoder_outputs, source context(= O): B, T_x, n_d
        """
        if self.method == 'dot':
            # bmm: (B, 1, n_d) * (B, n_d, T_x) = (B, 1, T_x)
            e = H.bmm(O.transpose(1, 2))
            return e
        
        elif self.method == 'general':
            # attn: (B, T_x, n_d) > (B, T_x, n_d)
            # bmm: (B, 1, n_d) * (B, n_d, T_x) = (B, 1, T_x)
            e = self.attn(O)
            e = H.bmm(e.transpose(1, 2))
            return e
        
        elif self.method == 'concat':
            # H repeat: (B, 1, n_d) > (B, T_x, n_d)
            # cat: (B, T_x, 2*n_d)
            # attn: (B, T_x, 2*n_d) > (B, T_x, n_d)
            # v repeat: (1, n_d) > (B, 1, n_d)
            # bmm: (B, 1, n_d) * (B, n_d, T_x) = (B, 1, T_x)
            cat = torch.cat((H.repeat(1, O.size(1), 1), O), 2)
            e = self.attn(cat)
            v = self.v.repeat(O.size(0), 1).unsqueeze(1)
            e = v.bmm(e.transpose(1, 2))
            return e
        
        elif self.method == 'paper':
            # add tanh after attention linear layer in 'concat' method
            cat = torch.cat((H.repeat(1, O.size(1), 1), O), 2)
            e = F.tanh(self.attn(cat))
            v = self.v.repeat(O.size(0), 1).unsqueeze(1)
            e = v.bmm(e.transpose(1, 2))
            return e
    
    def fill_context_mask(self, mask, sizes, v_mask, v_unmask):
        """Fill attention mask inplace for a variable length context.
        Args
        ----
        mask: Tensor of size (B, T, D)
            Tensor to fill with mask values. 
        sizes: list[int]
            List giving the size of the context for each item in
            the batch. Positions beyond each size will be masked.
        v_mask: float
            Value to use for masked positions.
        v_unmask: float
            Value to use for unmasked positions.
        Returns
        -------
        mask:
            Filled with values in {v_mask, v_unmask}
        """
        mask.fill_(v_unmask)
        n_context = mask.size(2)
        for i, size in enumerate(sizes):
            if size < n_context:
                mask[i,:,size:] = v_mask
        return mask

## Decoder

In [None]:
class Decoder(nn.Module):
    def __init__(self, V_d, m_d, n_d, sos_idx, num_layers=1, hidden_size2=None, method='general'):
        super(Decoder, self).__init__()
        """
        vocab_size: V_d
        embed_size: m_d
        hidden_size: n_d (set this value as 2*n_e)
        methods:
        - 'dot': dot product between hidden and encoder_outputs
        - 'general': encoder_outputs through a linear layer 
        - 'concat': concat (hidden, encoder_outputs)
        - 'paper': concat + tanh
        """
        self.V_d = V_d
        self.m_d = m_d
        self.n_d = n_d
        self.sos_idx = sos_idx
        self.num_layers = num_layers
        
        self.attention = Attention(hidden_size=n_d, hidden_size2=hidden_size2, method=method)
        
        self.embed = nn.Embedding(V_d, embed_size)
        # gru(W*[embed, context] + U*[hidden_prev])
        # gru: m+n
        self.gru = nn.GRU(m_d+n_d, n_d, num_layers, batch_first=True, bidirectional=False) 
        
    def start_token(self, batch_size):
        sos = torch.LongTensor([self.sos_idx]*batch_size).unsqueeze(1)
        if USE_CUDA: sos = sos.cuda()
        return sos        
    
    def forward(self, hidden, encoder_outputs, encoder_outputs_lengths, max_len):
        """
        input:
        - hidden, previous hidden: B, 1, n_d 
        - encoder_outputs, source context: B, T_x, n_d
        start_token: B, 1
        """
        inputs = self.start_token(hidden.size(0)) # (B, 1)
        embed = self.embed(inputs) # (B, 1, m_d)
        
        embed
        
            
        

In [None]:
# Turn a Unicode string to plain ASCII, thanks to http://stackoverflow.com/a/518232/2809427
def unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

# Lowercase, trim, and remove non-letter characters
def normalize_string(s):
    s = unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([,.!?])", r" \1 ", s)
    s = re.sub(r"[^a-zA-Z,.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

def read_files(path, max_len=None, min_len=None, n_sentences=None):
    source = []
    target = []
    with open(path, 'r', encoding='utf-8') as file:
        for l in file.readlines():
            so, ta = l.split('\t')
            normed_so = normalize_string(so.strip()).split()
            normed_ta = normalize_string(ta.strip()).split() 
            if len(normed_so) >= min_len and len(normed_so) <= max_len and \
               len(normed_ta) >= min_len and len(normed_ta) <= max_len:
                source.append(normed_so)
                target.append(normed_ta)
        if n_sentences:
            source = source[:n_sentences]
            target = target[:n_sentences]
    return source, target