In this notebook we check that a single Transformer model can learn to translate in both directions: src->trg and trg->src.

In [1]:
import sys; sys.path += ['..', '../src']

In [2]:
import os
from src.vocab import Vocab

DATA_PATH = '../data/generated'
max_len = 50 # Processing long sentences is slow


def _read_tokenized(path, limit):
    """Read a UTF-8 text file and return one token list per line.

    Each line is split on whitespace and truncated to at most `limit` tokens.
    """
    # The context manager closes the file handle deterministically instead of
    # leaving it open until the garbage collector gets to it.
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.read().splitlines()

    return [line.split()[:limit] for line in lines]


def _to_ids(sentences, vocab):
    """Map every token to its id in `vocab`, using `vocab.unk` for unknown tokens."""
    return [[vocab.token2id.get(t, vocab.unk) for t in s] for s in sentences]


train_src_path = os.path.join(DATA_PATH, 'train.en.tok.bpe')
train_trg_path = os.path.join(DATA_PATH, 'train.de.tok.bpe')
val_src_path = os.path.join(DATA_PATH, 'val.en.tok.bpe')
val_trg_path = os.path.join(DATA_PATH, 'val.de.tok.bpe')

# Two positions are kept free out of `max_len`.
# NOTE(review): presumably for the BOS/EOS tokens added by the batcher - confirm.
train_src = _read_tokenized(train_src_path, max_len - 2)
train_trg = _read_tokenized(train_trg_path, max_len - 2)
val_src = _read_tokenized(val_src_path, max_len - 2)
val_trg = _read_tokenized(val_trg_path, max_len - 2)

vocab_src = Vocab.from_file(os.path.join(DATA_PATH, 'vocab.en'))
vocab_trg = Vocab.from_file(os.path.join(DATA_PATH, 'vocab.de'))

train_src_idx = _to_ids(train_src, vocab_src)
train_trg_idx = _to_ids(train_trg, vocab_trg)
val_src_idx = _to_ids(val_src, vocab_src)
val_trg_idx = _to_ids(val_trg, vocab_trg)

In [3]:
from src.transformer.models import Transformer

# Architecture hyper-parameters, gathered in one place so they are easy to tune.
# Note that d_k == d_v == d_model / n_head (512 / 8 = 64).
transformer_config = dict(
    n_layers=6,
    n_head=8,
    d_word_vec=512,
    d_model=512,
    d_inner_hid=2048,
    d_k=64,
    d_v=64,
)

# Positional arguments: source vocabulary size, target vocabulary size and
# the maximum sequence length.
model = Transformer(len(vocab_src), len(vocab_trg), max_len, **transformer_config)

In [4]:
import numpy as np
from IPython.display import clear_output
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline

import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
tqdm.monitor_interval = 0

from src.utils.data_utils import Batcher
import src.transformer.constants as constants

# True when a CUDA device is available; used to decide whether the model and
# the loss criteria get moved to the GPU.
use_cuda = torch.cuda.is_available()

def get_criterion(vocab_size):
    """Build a cross-entropy loss over `vocab_size` classes that ignores padding.

    Every class gets weight 1 except `constants.PAD`, whose weight is 0, so
    padded positions do not contribute to the loss.
    """
    class_weights = torch.ones(vocab_size)
    class_weights[constants.PAD] = 0
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    return criterion

# Joint training of one Transformer in both directions (src->trg and trg->src).
# Every training step runs two forward passes on the same batch and accumulates
# the gradients of both losses before a single optimizer step.
optimizer = Adam(model.get_trainable_parameters(), lr=1e-4, betas=(0.9, 0.98))
criterion_src_to_trg = get_criterion(len(vocab_trg))
criterion_trg_to_src = get_criterion(len(vocab_src))
training_data = Batcher(train_src_idx, train_trg_idx, vocab_src.token2id,
                        vocab_trg.token2id, batch_size=16, shuffle=True)
val_data = Batcher(val_src_idx, val_trg_idx, vocab_src.token2id,
                   vocab_trg.token2id, batch_size=16, shuffle=True)

if use_cuda:
    model.cuda()
    criterion_src_to_trg.cuda()
    criterion_trg_to_src.cuda()

model.train()
loss_src_to_trg_history = []
loss_trg_to_src_history = []
val_loss_src_to_trg_history = []
val_loss_trg_to_src_history = []
val_loss_iters = []
num_iters_done = 0
max_num_epochs = 10

for epoch in range(max_num_epochs):
    for batch in tqdm(training_data, leave=False):
        # prepare data
        src, trg = batch

        # forward
        optimizer.zero_grad()
        pred_trg = model(src, trg)
        pred_src = model(trg, src, use_trg_embs_in_encoder=True, use_src_embs_in_decoder=True)

        # backward
        # The first position of the gold sequence is skipped when building the
        # flat target vector (presumably the BOS token - confirm against Batcher).
        loss_src_to_trg = criterion_src_to_trg(pred_trg, trg[:, 1:].contiguous().view(-1))
        loss_trg_to_src = criterion_trg_to_src(pred_src, src[:, 1:].contiguous().view(-1))

        # Gradients of both directions accumulate in the shared parameters.
        loss_src_to_trg.backward()
        loss_trg_to_src.backward()

        # update parameters
        optimizer.step()

        loss_src_to_trg_history.append(loss_src_to_trg.data[0])
        loss_trg_to_src_history.append(loss_trg_to_src.data[0])

        if num_iters_done % 50 == 0:
            clear_output(True)

            plt.figure(figsize=[16,8])

            plt.subplot(121)
            plt.title("[src -> trg] loss")
            plt.plot(loss_src_to_trg_history)
            plt.plot(pd.DataFrame(np.array(loss_src_to_trg_history)).ewm(span=50).mean())
            # Fixed: this subplot used to show the trg->src validation curve.
            plt.plot(val_loss_iters, val_loss_src_to_trg_history)
            plt.grid()

            plt.subplot(122)
            plt.title("[trg -> src] loss")
            plt.plot(loss_trg_to_src_history)
            plt.plot(pd.DataFrame(np.array(loss_trg_to_src_history)).ewm(span=50).mean())
            plt.plot(val_loss_iters, val_loss_trg_to_src_history)
            plt.grid()

            plt.show()

        if num_iters_done % 100 == 0:
            val_losses_src_to_trg = []
            val_losses_trg_to_src = []

            # Switch to eval mode so that train-only layers (e.g. dropout) do
            # not add noise to the validation loss.
            # NOTE(review): gradient tracking is still on here; the `.data[0]`
            # idiom suggests an old PyTorch, so `torch.no_grad()` is not used.
            model.eval()

            for val_batch in val_data:
                # Distinct names: `val_src` / `val_trg` are the token lists
                # loaded in the data cell and must not be overwritten here.
                val_src_batch, val_trg_batch = val_batch

                val_pred_trg = model(val_src_batch, val_trg_batch)
                val_pred_src = model(val_trg_batch, val_src_batch,
                                     use_trg_embs_in_encoder=True, use_src_embs_in_decoder=True)

                # Fixed: the loss was previously computed from the *training*
                # predictions and targets instead of the validation ones.
                val_loss_src_to_trg = criterion_src_to_trg(
                    val_pred_trg, val_trg_batch[:, 1:].contiguous().view(-1))
                val_loss_trg_to_src = criterion_trg_to_src(
                    val_pred_src, val_src_batch[:, 1:].contiguous().view(-1))

                val_losses_src_to_trg.append(val_loss_src_to_trg.data[0])
                val_losses_trg_to_src.append(val_loss_trg_to_src.data[0])

            # Back to training mode for the next optimisation step.
            model.train()

            val_loss_src_to_trg_history.append(np.mean(val_losses_src_to_trg))
            val_loss_trg_to_src_history.append(np.mean(val_losses_trg_to_src))
            val_loss_iters.append(num_iters_done)

        num_iters_done += 1

                                                 

KeyboardInterrupt: 

In [10]:
import numpy as np
from src.utils.bleu import compute_bleu_for_sents

# Evaluate the trained model in both directions with per-batch BLEU.
model.eval()

val_data = Batcher(val_src_idx, val_trg_idx, vocab_src.token2id,
                   vocab_trg.token2id, batch_size=16, shuffle=True)


def _to_text(vocab, sequences):
    """Detokenize every sequence with `vocab` and undo its BPE segmentation."""
    return [vocab.remove_bpe(vocab.detokenize(seq)) for seq in sequences]


def _keep_tokens(sentences, start, stop):
    """Keep only the whitespace-separated tokens `[start:stop]` of each sentence."""
    return [' '.join(sentence.split()[start:stop]) for sentence in sentences]


bleus_src_to_trg = []
bleus_trg_to_src = []

for src_batch, trg_batch in val_data:
    # Beam-search decoding, first src->trg and then trg->src with swapped embeddings.
    hyps_src_to_trg = model.translate_batch(src_batch, max_len=max_len, beam_size=4)
    hyps_trg_to_src = model.translate_batch(trg_batch, max_len=max_len,
                                            beam_size=4, use_src_embs_in_decoder=True,
                                            use_trg_embs_in_encoder=True)

    # Hypotheses lose their last token, references their first and last one
    # (presumably the BOS/EOS markers - confirm against Vocab.detokenize).
    hyps_src_to_trg = _keep_tokens(_to_text(vocab_trg, hyps_src_to_trg), None, -1)
    hyps_trg_to_src = _keep_tokens(_to_text(vocab_src, hyps_trg_to_src), None, -1)
    refs_src_to_trg = _keep_tokens(_to_text(vocab_trg, trg_batch.data), 1, -1)
    refs_trg_to_src = _keep_tokens(_to_text(vocab_src, src_batch.data), 1, -1)

    bleus_src_to_trg.append(compute_bleu_for_sents(hyps_src_to_trg, refs_src_to_trg))
    bleus_trg_to_src.append(compute_bleu_for_sents(hyps_trg_to_src, refs_trg_to_src))

# The reported score is the mean of the per-batch BLEU values.
print('BLEU [src->trg]:', np.mean(bleus_src_to_trg))
print('BLEU [trg->src]:', np.mean(bleus_trg_to_src))

 34%|███▍      | 17/50 [00:00<00:01, 26.76it/s]
 40%|████      | 20/50 [00:00<00:01, 26.56it/s]
 48%|████▊     | 24/50 [00:00<00:00, 29.55it/s]
 62%|██████▏   | 31/50 [00:01<00:00, 28.83it/s]
 40%|████      | 20/50 [00:00<00:01, 29.79it/s]
 32%|███▏      | 16/50 [00:00<00:01, 26.82it/s]
 34%|███▍      | 17/50 [00:00<00:01, 28.93it/s]
 40%|████      | 20/50 [00:00<00:01, 27.64it/s]
 46%|████▌     | 23/50 [00:00<00:00, 27.91it/s]
 40%|████      | 20/50 [00:00<00:01, 27.10it/s]
 34%|███▍      | 17/50 [00:00<00:01, 24.27it/s]
 32%|███▏      | 16/50 [00:00<00:01, 25.53it/s]
 34%|███▍      | 17/50 [00:00<00:01, 28.66it/s]
 32%|███▏      | 16/50 [00:00<00:01, 25.48it/s]
 34%|███▍      | 17/50 [00:00<00:01, 30.50it/s]
 40%|████      | 20/50 [00:00<00:01, 28.08it/s]
 38%|███▊      | 19/50 [00:00<00:00, 31.54it/s]
 34%|███▍      | 17/50 [00:00<00:01, 28.52it/s]
 48%|████▊     | 24/50 [00:00<00:01, 24.42it/s]
 40%|████      | 20/50 [00:00<00:01, 22.69it/s]
 30%|███       | 15/50 [00:00<00:01, 22.

BLEU [src->trg]: 0.24522601492894625
BLEU [trg->src]: 0.2726236788269359



