# Neural Machine Translation with Transformers

In [7]:
# see http://www.manythings.org/anki

! wget http://www.manythings.org/anki/fra-eng.zip
    
! unzip fra-eng.zip

Archive:  fra-eng.zip
  inflating: _about.txt              
  inflating: fra.txt                 


In [1]:
# Read the tab-separated sentence pairs into a list of {'src', 'tgt'} dicts.
# Column 0 of every line is stored as the target, column 1 as the source.

nmt_data = []

# Explicit UTF-8: the sentences contain accented characters and the platform
# default encoding is not guaranteed to decode them.
with open('fra.txt', 'r', encoding='utf-8') as fr:

    # Iterate over the file lazily instead of loading every line with readlines()
    for line in fr:

        # Drop the line terminator so it can never end up inside a sentence
        splits = line.rstrip('\n').split('\t')

        # Skip blank or malformed lines that do not provide both columns
        if len(splits) < 2:
            continue

        nmt_data.append({
            'src': splits[1],
            'tgt': splits[0]
        })

In [2]:
import os

# Directory that receives the text files used to train the tokenizers
folder = 'nmt_vocab'

# exist_ok=True makes the cell safe to re-run and avoids the
# check-then-create race of testing for existence first
os.makedirs(folder, exist_ok=True)

In [3]:
## Save vocab_file
# Dump every sentence of one side of the corpus, space-separated, into a text
# file per side. The files are written as UTF-8 explicitly so the accented
# characters are stored the same way on every platform.

for key, file_name in (('src', 'src_vocab.txt'), ('tgt', 'tgt_vocab.txt')):
    with open(os.path.join(folder, file_name), 'w', encoding='utf-8') as vocab_file:
        vocab_file.write(' '.join(pair[key] for pair in nmt_data))

In [4]:
# Peek at one parsed pair to check what the 'src' and 'tgt' fields contain
nmt_data[10000]

{'src': "J'ai payé en espèce.", 'tgt': 'I paid in cash.'}

### Training tokenizers

In [5]:
import tokenizers
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.models import WordLevel
from tokenizers.processors import TemplateProcessing

from tokenizers import normalizers
from tokenizers.normalizers import Lowercase, NFD, StripAccents

In [6]:
def create_tokenizer(vocab_file, vocab_size=30000, single_format='[SOS] $A [EOS]'):
    """Build and train a word-level tokenizer on a text file.

    Args:
        vocab_file: path of the text file the vocabulary is learned from.
        vocab_size: vocabulary size handed to the trainer.
        single_format: post-processing template applied to a single sequence.

    Returns:
        The trained ``Tokenizer``.
    """
    # The template below refers to [SOS] and [EOS] by their positions
    # in this list (1 and 2), so the order must not change.
    reserved = ['[PAD]', '[SOS]', '[EOS]', '[UNK]']

    # Word-level model; out-of-vocabulary words map to [UNK]
    word_tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))

    # Normalization: lowercase, NFD decomposition, then accent stripping
    word_tokenizer.normalizer = normalizers.Sequence([Lowercase(), NFD(), StripAccents()])

    # Pre-tokenization
    word_tokenizer.pre_tokenizer = Whitespace()

    # Post-processing: wrap the sequence according to the requested template
    word_tokenizer.post_processor = TemplateProcessing(
        single=single_format,
        special_tokens=[(token, reserved.index(token)) for token in ('[SOS]', '[EOS]')],
    )

    # Instantiate the trainer and learn the vocabulary from the file
    vocab_trainer = tokenizers.trainers.WordLevelTrainer(vocab_size=vocab_size, special_tokens=reserved)
    word_tokenizer.train([vocab_file], vocab_trainer)

    return word_tokenizer

In [7]:
# Source sentences are encoded as-is ('$A'); target sentences are wrapped
# in [SOS] ... [EOS].
src_vocab_path = os.path.join(folder, 'src_vocab.txt')
tgt_vocab_path = os.path.join(folder, 'tgt_vocab.txt')

src_tokenizer = create_tokenizer(src_vocab_path, single_format='$A')
tgt_tokenizer = create_tokenizer(tgt_vocab_path, single_format='[SOS] $A [EOS]')

In [8]:
def switch_mode(tokenizer, max_len=50):
    """Configure ``tokenizer`` for batch encoding.

    Enables truncation to ``max_len`` tokens and padding with the tokenizer's
    default padding settings. The tokenizer is modified in place and nothing
    is returned.
    """
    tokenizer.enable_truncation(max_length=max_len)
    tokenizer.enable_padding()

In [9]:
# Quick check of the source tokenizer: tokens are lowercased and no [SOS]/[EOS] is added.
# NOTE(review): the sentence passed in is the 'tgt' side of the pair while the
# tokenizer is the source one — confirm this combination is intentional.
src_tokenizer.encode(nmt_data[10000]['tgt']).tokens

['i', 'paid', 'in', 'cash', '.']

### Creating dataset and dataloaders

In [10]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split

In [11]:
class NMTdata(Dataset):
    """Dataset over a sequence of ``{'src': ..., 'tgt': ...}`` pairs."""

    def __init__(self, data):
        # Only a reference is kept; nothing is copied or tokenized here
        self.data = data

    def __len__(self):
        """Number of pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(source, target)`` values of pair ``idx``."""
        pair = self.data[idx]
        return pair['src'], pair['tgt']

In [12]:
def collate(batch):
    """Tokenize a list of ``(source, target)`` pairs into two id tensors.

    Returns:
        ``[src_ids, tgt_ids]``: ``LongTensor`` objects built from the batch
        encodings produced by the source and target tokenizers.
    """
    sources, targets = [], []
    for pair in batch:
        sources.append(pair[0])
        targets.append(pair[1])

    # Both tokenizers must truncate and pad so each batch forms a rectangular tensor
    switch_mode(src_tokenizer)
    switch_mode(tgt_tokenizer)

    src_ids = torch.LongTensor([enc.ids for enc in src_tokenizer.encode_batch(sources)])
    tgt_ids = torch.LongTensor([enc.ids for enc in tgt_tokenizer.encode_batch(targets)])

    return [src_ids, tgt_ids]

In [13]:
all_dataset = NMTdata(nmt_data)

print(len(all_dataset))

# Hold out a fixed number of pairs for validation. The split is seeded so the
# same pairs end up in the validation set on every run, which keeps the
# validation losses comparable between runs.
VAL_SIZE = 5000
SPLIT_SEED = 0

train, val = random_split(
    all_dataset,
    [len(all_dataset) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train, batch_size=128, shuffle=True, collate_fn=collate, num_workers=15)

val_loader = DataLoader(val, batch_size=512, shuffle=False, collate_fn=collate, num_workers=10)

185583


In [14]:
### import modules
import numpy as np
from torch import nn
import torch.nn.functional as F

from tqdm import tqdm

from transformer_utils import PositionEmbedding, get_masks, TransformerEncoder, TransformerDecoder

## Build model

In [27]:
class NMTmodel(nn.Module):
    """Encoder-decoder transformer for translation.

    Source and target token ids are embedded separately, go through the same
    positional embedding module, and are processed by a ``TransformerEncoder``
    / ``TransformerDecoder`` pair. A dropout + linear head maps the decoder
    output to target-vocabulary logits.
    """

    def __init__(self, src_vocab, tgt_vocab, d_model=512, n_head=8, num_layers=2):
        """
        Args:
            src_vocab: size of the source vocabulary.
            tgt_vocab: size of the target vocabulary.
            d_model: embedding / model width.
            n_head: value passed as head count to the encoder and decoder.
            num_layers: value passed as layer count to the encoder and decoder.
        """
        super().__init__()

        # needed by compute_loss to flatten the logits
        self.tgt_vocab = tgt_vocab

        # regularization applied to the embedded inputs
        self.drop = nn.Dropout(0.25)

        # token embeddings for both sides; index 0 is the padding index
        self.scr_embedding = nn.Embedding(src_vocab, d_model, padding_idx=0)
        self.tgt_embedding = nn.Embedding(tgt_vocab, d_model, padding_idx=0)

        # one positional embedding module used for both source and target
        self.pos_embedding = PositionEmbedding(200, d_model)

        # encoder / decoder stacks
        self.encoder = TransformerEncoder(d_model, n_head, num_layers)
        self.decoder = TransformerDecoder(d_model, n_head, num_layers)

        # output head: dropout followed by a projection to the target vocabulary
        self.fc = nn.Sequential(
            nn.Dropout(0.25),
            nn.Linear(d_model, tgt_vocab)
        )

    def encode(self, src):
        """Encode a source batch.

        Args:
            src: source token ids, ``[batch_size, src_len]``.

        Returns:
            Encoder output, ``[batch_size, src_len, d_model]``.
        """
        src_mask, _ = get_masks(src)

        embedded = self.drop(self.pos_embedding(self.scr_embedding(src)))

        return self.encoder(embedded, src_mask)

    def decode(self, tgt, memory):
        """Decode with teacher forcing.

        Args:
            tgt: decoder input ids, ``[batch_size, tgt_len]``.
            memory: encoder output, ``[batch_size, src_len, d_model]``.

        Returns:
            Logits, ``[batch_size, tgt_len, tgt_vocab_size]``.
        """
        padding_mask, causal_mask = get_masks(tgt)

        # a position is visible only if both the padding and the causal mask allow it
        attn_mask = padding_mask * causal_mask

        embedded = self.drop(self.pos_embedding(self.tgt_embedding(tgt)))

        # NOTE(review): no mask is passed for `memory` here — confirm whether the
        # decoder is expected to ignore padded source positions.
        hidden = self.decoder(y=embedded, memory=memory, y_mask=attn_mask)

        return self.fc(hidden)

    def forward(self, src, tgt):
        """Encode ``src`` and decode ``tgt`` against it; returns the logits."""
        return self.decode(tgt, self.encode(src))

    def compute_loss(self, x, y):
        """Cross-entropy of next-token prediction, ignoring padding (id 0).

        Args:
            x: source token ids.
            y: target token ids; ``y[:, :-1]`` feeds the decoder and
                ``y[:, 1:]`` is what it has to predict.

        Returns:
            Scalar loss tensor.
        """
        decoder_input = y[:, :-1]  # sos, ...
        labels = y[:, 1:].reshape(-1)  # ...eos

        logits = self.forward(x, decoder_input).view(-1, self.tgt_vocab)

        return F.cross_entropy(logits, labels, ignore_index=0)

In [87]:
# # Dropout regularization

class Dropout(nn.Module):
    """Inverted dropout written from scratch.

    In training mode every element is zeroed independently with probability
    ``p`` and the surviving elements are scaled by ``1 / (1 - p)`` so that the
    expected value of the output equals the input. In eval mode the input is
    returned unchanged.

    Args:
        p: probability of zeroing an element, in ``[0, 1]``.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, p=0.5):

        super().__init__()

        if not 0.0 <= p <= 1.0:
            raise ValueError(f'dropout probability has to be between 0 and 1, but got {p}')

        self.p = p

    def forward(self, x):

        # Nothing to do outside of training or when no element is dropped
        if not self.training or self.p == 0.0:
            return x

        # Everything is dropped; handled apart to avoid dividing by zero below
        if self.p == 1.0:
            return torch.zeros_like(x)

        keep_prob = 1.0 - self.p

        # Each element is kept with probability keep_prob (1 = keep, 0 = drop)
        mask = torch.bernoulli(torch.full_like(x, keep_prob))

        # Rescale the survivors so the expected output matches the input
        return x * mask / keep_prob
        
# Demo: the same input goes through the layer in training mode, then in eval mode
x = torch.rand(4)

drop = Dropout()

print('x:\n', x)

train_out = drop(x)
print('training mode:\n', train_out)

drop.eval()

eval_out = drop(x)
print('eval mode:\n', eval_out)


x:
 tensor([0.3050, 0.9295, 0.8356, 0.8977])
training mode:
 tensor([0.0000, 0.0000, 1.6712, 0.0000])
eval mode:
 tensor([0.3050, 0.9295, 0.8356, 0.8977])


In [None]:


    # NOTE(review): illustrative fragment only — it is indented and refers to
    # `self`, so it cannot run as a standalone cell.
    # Standard dropout applied to x ...
    drop = nn.Dropout(p=0.5)
    
    x = drop(x)
    
    # ... then each layer is applied only when a uniform random draw exceeds
    # 0.5, i.e. every layer is skipped about half of the time.
    for layer in self.transformer_layers:
        if np.random.rand() > 0.5:
            x = layer(x)
            
            
            

In [30]:
def train_one_epoch(net: nn.Module, opt: torch.optim.Optimizer, dataloader: torch.utils.data.DataLoader):
    """Train ``net`` for one pass over ``dataloader``.

    Args:
        net: model exposing ``compute_loss(x, y)``; switched to training mode.
        opt: optimizer stepping the parameters of ``net``.
        dataloader: iterable yielding ``(x, y)`` tensor batches.

    Returns:
        Mean of the per-batch losses over the epoch.

    Raises:
        ValueError: if ``net`` has no parameters to infer the device from.
    """
    net.train()

    # Batches are moved to the device of the model's first parameter
    first_param = next(net.parameters(), None)
    if first_param is None:
        raise ValueError('net has no parameters; cannot infer its device')
    device = first_param.device

    losses = []
    running_total = 0.0

    pbar = tqdm(dataloader)

    for x, y in pbar:

        net.zero_grad()

        x, y = x.to(device), y.to(device)

        loss = net.compute_loss(x, y)

        loss.backward()

        opt.step()

        loss_item = loss.item()

        losses.append(loss_item)

        # Running mean updated incrementally instead of re-averaging the
        # whole loss history at every step
        running_total += loss_item

        pbar.set_description(f'train_loss = {running_total / len(losses)}')

    return np.array(losses).mean()

@torch.no_grad()
def validate(net: nn.Module, dataloader: torch.utils.data.DataLoader):
    """Return the mean per-batch loss of ``net`` over ``dataloader``.

    The model is switched to eval mode and gradients are disabled for the
    whole call.
    """
    net.eval()

    # The first parameter tells which device the batches must be sent to
    for first_param in net.parameters():
        device = first_param.device
        break

    batch_losses = [
        net.compute_loss(x.to(device), y.to(device)).item()
        for x, y in dataloader
    ]

    return np.array(batch_losses).mean()

## Training

In [31]:
# Build the model and its optimizer. Both names are used by the training cell
# that follows, so they must be defined when the notebook runs on a fresh kernel.
model = NMTmodel(src_vocab=src_tokenizer.get_vocab_size(), tgt_vocab=tgt_tokenizer.get_vocab_size(), d_model=512, n_head=8, num_layers=2).cuda()

opt = torch.optim.AdamW(model.parameters(), lr=1e-4) # original paper use warmup step + decay

In [32]:
model.cuda()

NUM_EPOCHS = 7

# Baseline validation loss before any training
print(validate(model, val_loader))

for epoch in range(NUM_EPOCHS):

    train_one_epoch(model, opt, train_loader)

    # Validation loss after each epoch
    print(validate(model, val_loader))

  0%|          | 0/1411 [00:00<?, ?it/s]

9.659289836883545


train_loss = 3.122909759647199: 100%|██████████| 1411/1411 [01:52<00:00, 12.52it/s] 
  0%|          | 0/1411 [00:00<?, ?it/s]

2.1426881313323975


train_loss = 1.9407147846688284: 100%|██████████| 1411/1411 [03:40<00:00,  6.39it/s]
  0%|          | 0/1411 [00:00<?, ?it/s]

1.5852023363113403


train_loss = 1.529400000328886: 100%|██████████| 1411/1411 [04:00<00:00,  5.86it/s] 
  0%|          | 0/1411 [00:00<?, ?it/s]

1.3283339142799377


train_loss = 1.287834175548851: 100%|██████████| 1411/1411 [05:58<00:00,  3.94it/s] 
  0%|          | 0/1411 [00:00<?, ?it/s]

1.175497567653656


train_loss = 1.1204880964325135: 100%|██████████| 1411/1411 [05:45<00:00,  4.08it/s]
  0%|          | 0/1411 [00:00<?, ?it/s]

1.0687022507190704


train_loss = 0.9963250955649929: 100%|██████████| 1411/1411 [03:42<00:00,  6.34it/s]
  0%|          | 0/1411 [00:00<?, ?it/s]

0.9926928997039794


train_loss = 0.8961147483488416: 100%|██████████| 1411/1411 [07:20<00:00,  3.21it/s]


0.9388803243637085


In [None]:
# The 7 epochs above took about 30 minutes of training in total.

In [33]:
# The cells below run on the CPU with the model in eval mode
model.cpu().eval()

print('ok')

ok


## Prediction

In [61]:
def topk_sampling(logits, k):
    """Sample one index among the ``k`` highest-scoring entries of ``logits``.

    Args:
        logits: unnormalized scores for a single position; flattened to 1-D.
        k: number of candidates to sample from; capped at the number of scores.

    Returns:
        The sampled index as a Python ``int``.
    """
    # Flatten rather than squeeze so a single-score input stays 1-D, and
    # detach / move to CPU so the NumPy conversion below also works for
    # tensors that track gradients or live on another device.
    scores = logits.detach().reshape(-1).cpu()

    top = torch.topk(scores, min(k, scores.numel()))

    # Probabilities are renormalized over the retained candidates only
    probs = torch.softmax(top.values, dim=0).numpy()
    candidates = top.indices.numpy()

    return int(np.random.choice(candidates, p=probs))

In [35]:
@torch.no_grad()
def translate(french, max_len=200, k=3):
    """Translate one source sentence by sampling the target token by token.

    Args:
        french: source sentence as a raw string.
        max_len: upper bound on the decoded length, [SOS] included. The
            default matches the 200 passed to the model's positional embedding.
        k: number of candidates considered at every sampling step.

    Returns:
        The decoded target sentence.
    """
    sos_id, eos_id = 1, 2  # ids given to [SOS] / [EOS] in the tokenizer template

    src = torch.LongTensor(src_tokenizer.encode(french).ids).unsqueeze(0)

    # The source is encoded once and reused at every decoding step
    memory = model.encode(src)

    sequences = [sos_id]

    # Bounded loop: without a cap, decoding has no way to stop when [EOS]
    # is never sampled
    while len(sequences) < max_len:

        y = torch.LongTensor([sequences])

        # logits of the last position only
        next_logits = model.decode(y, memory).squeeze(0)[-1]

        next_id = int(topk_sampling(next_logits, k))

        if next_id == eos_id:
            break

        sequences.append(next_id)

    # [SOS] is dropped before decoding back to text
    return tgt_tokenizer.decode(sequences[1:])

In [62]:
# Pick a random pair of the validation split and compare the model output with the reference
sample_position = np.random.randint(0, 5000)
example = nmt_data[val.indices[sample_position]]

print('source:', example['src'])

print('target:', example['tgt'])

pred = translate(example['src'])

print('pred:', pred)

source: Tu ferais mieux de prendre conseil auprès d'un médecin.
target: You'd better ask the doctor for advice.
pred: you ' d better take some advice of a doctor ' s advice .


## Beam search decoding

In [63]:
from allennlp.nn.beam_search import BeamSearch, TopPSampler, TopKSampler

In [64]:
# Take the source sentences of the first validation batch
first_batch = next(iter(val_loader))
test_src = first_batch[0]

# Keep at most the first 64 sentences
test_src = test_src[:64]

In [65]:
# Token ids of the first selected sentence
test_src[0]

tensor([ 20,  23,   5,  11,  45,  29, 426,   4,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0])

In [66]:
# One start token (id 1) per source sentence, as a [batch, 1] integer tensor
sos_tokens = torch.ones(test_src.size(0), 1, dtype=torch.long)

In [67]:
# Encode the source batch once, without tracking gradients;
# the result is stored in the beam-search state further down

with torch.no_grad():
    memory = model.encode(test_src)

In [68]:
# Instantiate beam search: 10 beams, at most 20 steps; end_index=2 is the id
# given to [EOS] in the tokenizer template. Candidates are drawn with a
# top-p sampler (p=0.9).

bs = BeamSearch(end_index=2, beam_size=10, max_steps=20, sampler=TopPSampler(p=0.9))

In [69]:
# the code is ugly but it works ...

def next_step(last_pred, states):
    """Step function handed to the beam search.

    Args:
        last_pred: token ids selected at the previous step.
        states: dict holding ``'memory'`` (encoder output) and ``'sequences'``
            (token ids decoded so far).

    Returns:
        ``(log_probs, states)``: log-probabilities of the next token over the
        target vocabulary, and ``states`` with ``'sequences'`` extended by
        ``last_pred``.
    """
    # make last_pred a column so it can be appended along dim 1
    if last_pred.dim() == 1:
        last_pred = last_pred.unsqueeze(1)

    memory = states['memory']

    # append the latest tokens to what has been decoded so far
    y = torch.cat([states['sequences'], last_pred], dim=1)

    # Inference only: without no_grad an autograd graph would be built
    # at every step of the search
    with torch.no_grad():
        # prediction for the last token
        pred = model.decode(y, memory)[:, -1, :]
        log_probs = F.log_softmax(pred, dim=-1)

    states['sequences'] = y

    return log_probs, states

In [70]:
# Initial search state: the encoder output plus an empty tensor of decoded ids,
# which next_step extends by one column at every call.
states = {'memory': memory, 'sequences': torch.LongTensor([])}

# Run the search from the batch of start tokens.
out_beam, log_probs = bs.search(sos_tokens, states, next_step)

In [None]:
# Pick one sentence of the batch; the upper bound follows the actual batch
# size instead of a hard-coded count
idx = np.random.randint(0, test_src.size(0))

# .tolist() hands the tokenizers plain Python ints rather than NumPy arrays
print('source:', src_tokenizer.decode(test_src[idx].tolist()))

print('prediction:')

# One line per returned hypothesis, with its log-probability turned back into a probability
for s, p in zip(out_beam[idx], log_probs[idx]):

    print(f' - {tgt_tokenizer.decode(s.tolist())} ==> prob = {torch.exp(p).item(): 0.3f}')

In [196]:
# Weight matrix of the target-side embedding layer.
# NOTE(review): despite the `fr_` prefix this is the *target* embedding; in the
# sample pair displayed earlier 'tgt' holds the English sentence.
fr_emb = model.tgt_embedding.weight.data

In [197]:
# Inspect the matrix; row 0 belongs to the padding index of the embedding layer
fr_emb

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.5883, -0.0920,  0.0062,  ..., -0.4923, -1.2821,  0.7908],
        [ 0.5635,  0.9206, -0.4681,  ..., -2.3200,  0.0091,  1.5956],
        ...,
        [-1.2475,  0.2879,  1.7130,  ...,  0.7948,  0.4980,  1.3692],
        [ 0.0163, -0.0325,  0.2305,  ..., -0.7854,  1.9704, -0.4423],
        [-0.1573, -0.9495, -0.3234,  ...,  1.0498, -0.6279,  0.5638]])

In [204]:
# Vocabulary of the target tokenizer (its items are sorted by value in the next cell)
vocab = tgt_tokenizer.get_vocab()

In [205]:
# token -> id ordered by increasing id, plus the reverse id -> token lookup
pairs_by_id = sorted(vocab.items(), key=lambda item: item[1])
word2id = dict(pairs_by_id)
id2word = {token_id: token for token, token_id in pairs_by_id}

In [206]:
# NOTE(review): same expression as the id2word definition in the previous cell;
# re-running it produces the same mapping.
id2word = {v:k for k, v in word2id.items()}

In [207]:
def get_word_emb(word):
    """Return the row of ``fr_emb`` associated with ``word``.

    ``word`` must be a key of ``word2id``; any other value raises ``KeyError``.
    """
    return fr_emb[word2id[word]]

In [213]:
# Cosine similarity between the embedding of 'love' and every row of the matrix,
# sorted in increasing order: the most similar rows come last (the final entry,
# with similarity 1, is presumably 'love' itself).
torch.sort(torch.cosine_similarity(get_word_emb('love').view(1, -1), fr_emb))

torch.return_types.sort(
values=tensor([-0.1581, -0.1551, -0.1487,  ...,  0.1705,  0.1904,  1.0000]),
indices=tensor([4697, 9216,  314,  ..., 7115, 2923,  144]))

In [212]:
# Token behind one of the highest-ranked ids returned by the sort above
id2word[7115]

'fasting'