In [1]:
import sys; sys.path += ['..', '../src']

Next, we initialize the transformer with the pre-trained embeddings, create the discriminator, and add the adversarial loss.
Once that is done, all that is left is to train the model!

In [None]:
import os

from src.vocab import Vocab
from src.transformer.models import Transformer
from src.models import FFN
from src.utils.data_utils import load_embeddings, init_emb_matrix

DATA_PATH = '../data/generated'
max_len = 200  # TODO: Dostoevsky has much longer sentences

# Vocabularies: English ('vocab.en') is the source side, German ('vocab.de') the target side.
vocab_src, vocab_trg = (
    Vocab.from_file(os.path.join(DATA_PATH, 'vocab.en')),
    Vocab.from_file(os.path.join(DATA_PATH, 'vocab.de')),
)

# Translator (source vocab size, target vocab size, max sequence length) and discriminator.
transformer = Transformer(len(vocab_src), len(vocab_trg), max_len)
# NOTE(review): magic numbers — presumably (input dim, num layers, hidden dim);
# confirm against src.models.FFN and keep the first one in sync with the model width.
discriminator = FFN(512, 3, 1024)

# Pre-trained embeddings (CBOW vectors, per the file names) for both languages.
embeddings_src, embeddings_trg = (
    load_embeddings('../trained_models/wmt17.en.tok.bpe_cbow.vec'),
    load_embeddings('../trained_models/wmt17.de.tok.bpe_cbow.vec'),
)

# The encoder/decoder word-embedding tensors are handed to init_emb_matrix via
# `.weight.data` and its return value is ignored, so it presumably fills them
# in place — confirm in src.utils.data_utils.
init_emb_matrix(transformer.encoder.src_word_emb.weight.data, embeddings_src, vocab_src.token2id)
init_emb_matrix(transformer.decoder.tgt_word_emb.weight.data, embeddings_trg, vocab_trg.token2id)

def _read_tokenized(path):
    """Read a pre-tokenized corpus file into a list of token lists.

    One sentence per line; tokens are separated by whitespace.

    Args:
        path: path to a UTF-8 encoded text file.

    Returns:
        A list with one list of string tokens per line of the file.
    """
    # `with` closes the handle deterministically; bare `open(...).read()`
    # leaves the file object open until it happens to be garbage-collected.
    with open(path, 'r', encoding='utf-8') as corpus_file:
        lines = corpus_file.read().splitlines()
    return [line.split() for line in lines]


def _to_token_ids(sentences, vocab):
    """Map every token to its id in `vocab`, falling back to `vocab.unk` for unknown tokens."""
    return [[vocab.token2id.get(token, vocab.unk) for token in sentence]
            for sentence in sentences]


train_src_path = os.path.join(DATA_PATH, 'train.en.tok.bpe')
train_trg_path = os.path.join(DATA_PATH, 'train.de.tok.bpe')
val_src_path = os.path.join(DATA_PATH, 'val.en.tok.bpe')
val_trg_path = os.path.join(DATA_PATH, 'val.de.tok.bpe')

train_src = _read_tokenized(train_src_path)
train_trg = _read_tokenized(train_trg_path)
val_src = _read_tokenized(val_src_path)
val_trg = _read_tokenized(val_trg_path)

# Only the training split is converted to ids here; validation stays as tokens.
train_src_idx = _to_token_ids(train_src, vocab_src)
train_trg_idx = _to_token_ids(train_trg, vocab_trg)

Now we write the training procedure, including back-translation and noising.
That is not as easy as it may seem.
We also need to define the loss functions and add training visualization.

In [31]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.optim import Adam, RMSprop
from tqdm import tqdm

from src.utils.umt_batcher import UMTBatcher
import src.transformer.constants as constants

# NOTE(review): not referenced by any cell shown in this notebook; presumably
# consumed by the (missing) `training_step` — confirm.
use_cuda = torch.cuda.is_available()

def reconstruction_criterion(vocab_size):
    ''' Build a token-level cross-entropy loss that ignores padding.

    The class weight of `constants.PAD` is set to zero, so padded target
    positions contribute nothing to the loss.

    Args:
        vocab_size: number of output classes (size of the predicted vocabulary).

    Returns:
        An `nn.CrossEntropyLoss` with per-class weights (all ones except PAD).
    '''
    class_weights = torch.ones(vocab_size)
    class_weights[constants.PAD] = 0
    return nn.CrossEntropyLoss(weight=class_weights)


# Denoising-autoencoder losses: each side reconstructs sentences in its own vocabulary.
ae_criterion_src, ae_criterion_trg = (
    reconstruction_criterion(len(vocab_src)),
    reconstruction_criterion(len(vocab_trg)),
)
# Translation losses are sized by the vocabulary being *predicted* in each direction.
translation_criterion_src_to_trg, translation_criterion_trg_to_src = (
    reconstruction_criterion(len(vocab_trg)),
    reconstruction_criterion(len(vocab_src)),
)
# Binary real/fake objective shared by the discriminator and the generator (translator).
adv_criterion = nn.BCELoss()

transformer_optimizer = Adam(
    transformer.get_trainable_parameters(),
    lr=3e-4,
    betas=(0.5, 0.999),
)
discriminator_optimizer = RMSprop(discriminator.parameters(), lr=5e-4)

Skipping  __BOS__
Skipping  __EOS__
Skipping  __UNK__
Skipping  __PAD__
Skipping  __BOS__
Skipping  __EOS__
Skipping  __UNK__
Skipping  __PAD__


  result = self.forward(*input, **kwargs)


Training discriminator
Computing predictions
Computing losses
Computing gradients
Training translator
Computing back-translations



  

 10%|█         | 1/10 [00:00<00:06,  1.44it/s][A
 20%|██        | 2/10 [00:01<00:07,  1.08it/s][A
 30%|███       | 3/10 [00:03<00:07,  1.06s/it][A
 40%|████      | 4/10 [00:04<00:06,  1.07s/it][A
 50%|█████     | 5/10 [00:05<00:05,  1.11s/it][A
 60%|██████    | 6/10 [00:06<00:04,  1.11s/it][A
 70%|███████   | 7/10 [00:07<00:03,  1.14s/it][A
 80%|████████  | 8/10 [00:09<00:02,  1.17s/it][A
 90%|█████████ | 9/10 [00:10<00:01,  1.18s/it][A
100%|██████████| 10/10 [00:12<00:00,  1.21s/it][A
[A
  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:00<00:04,  1.87it/s][A
 20%|██        | 2/10 [00:01<00:05,  1.38it/s][A
 30%|███       | 3/10 [00:02<00:05,  1.30it/s][A
 40%|████      | 4/10 [00:03<00:04,  1.24it/s][A
 50%|█████     | 5/10 [00:04<00:04,  1.15it/s][A
 60%|██████    | 6/10 [00:05<00:03,  1.02it/s][A
 70%|███████   | 7/10 [00:07<00:03,  1.08s/it][A
 80%|████████  | 8/10 [00:10<00:02,  1.26s/it][A
 90%|█████████ | 9/10 [00:12<00:01,  1.36s/it][

Computing predictions (translations of back-translations)
Computing losses
Computing gradients
Updating weights
Training discriminator
Computing predictions
Computing losses
Computing gradients


  "Please ensure they have the same size.".format(target.size(), input.size()))
  "Please ensure they have the same size.".format(target.size(), input.size()))


Updating parameters
Training generator
Computing losses
Computing gradients
Updating parameters


  - (Training)   :   0%|          | 1/907 [01:06<16:41:19, 66.31s/it]

Losses: {'ae_loss_src': 8.957502365112305, 'ae_loss_trg': 9.309548377990723, 'loss_bt_src': 8.96019172668457, 'loss_bt_trg': 9.308553695678711, 'discr_loss_src': 0.6752368807792664, 'discr_loss_trg': 0.7115498781204224, 'gen_loss_src': 0.7113943099975586, 'gen_loss_trg': 0.6750862002372742}
Training discriminator
Computing predictions
Computing losses
Computing gradients
Training translator
Computing back-translations



  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:01<00:12,  1.44s/it][A
 20%|██        | 2/10 [00:04<00:16,  2.06s/it][A
 30%|███       | 3/10 [00:06<00:14,  2.08s/it][A
 40%|████      | 4/10 [00:07<00:11,  1.92s/it][A
 50%|█████     | 5/10 [00:09<00:09,  1.86s/it][A
 60%|██████    | 6/10 [00:10<00:06,  1.75s/it][A
 70%|███████   | 7/10 [00:12<00:05,  1.72s/it][A
 80%|████████  | 8/10 [00:14<00:03,  1.80s/it][A
 90%|█████████ | 9/10 [00:16<00:01,  1.89s/it][A
100%|██████████| 10/10 [00:18<00:00,  1.89s/it][A
[A
  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [00:00<00:05,  1.58it/s][A
 20%|██        | 2/10 [00:01<00:06,  1.23it/s][A
 30%|███       | 3/10 [00:02<00:06,  1.16it/s][A
 40%|████      | 4/10 [00:03<00:05,  1.09it/s][A
 50%|█████     | 5/10 [00:04<00:04,  1.00it/s][A
 60%|██████    | 6/10 [00:06<00:04,  1.11s/it][A
 70%|███████   | 7/10 [00:08<00:03,  1.27s/it][A
 80%|████████  | 8/10 [00:11<00:02,  1.42s/it][A
 90%|█████

Computing predictions (translations of back-translations)


KeyboardInterrupt: 

In [None]:
%matplotlib inline

def _plot_loss_panel(ax, losses, title, span):
    """Draw every column of `losses` on `ax`: raw values faintly, their EWMA on top."""
    for name in losses.columns:
        raw = losses[name]
        smoothed = raw.ewm(span=span).mean()
        (raw_line,) = ax.plot(raw.index, raw.values, alpha=0.3)
        # Reuse the raw curve's colour so each loss keeps a single colour.
        ax.plot(smoothed.index, smoothed.values, color=raw_line.get_color(), label=name)
    ax.set_title(title)
    ax.set_xlabel("iteration")
    ax.set_ylabel("loss")
    ax.grid(True)
    if len(losses.columns):
        ax.legend()


def visualize_losses(losses_history, figsize=(15, 15), span=50):
    """Plot raw and exponentially smoothed training losses, refreshing the cell output.

    Meant to be called repeatedly from the training loop. Each call clears the
    previous output and redraws two panels:
      * left ("batch loss"): every loss whose key does not start with 'discr'
        (e.g. 'ae_loss_src', 'loss_bt_trg', 'gen_loss_src' from the logged output);
      * right ("disc loss"): the discriminator losses ('discr_loss_*').

    Args:
        losses_history: sequence of per-step dicts mapping loss name -> scalar,
            as appended by the training loop. An empty history draws nothing.
        figsize: matplotlib figure size in inches.
        span: span of the exponentially weighted moving average overlay
            (replaces the removed `pandas.ewma`).
    """
    # Local imports: the notebook never imported these names at top level.
    import matplotlib.pyplot as plt
    import pandas as pd
    from IPython.display import clear_output

    if not losses_history:
        return

    # One row per training step, one column per loss name.
    losses = pd.DataFrame(list(losses_history)).astype(float)
    disc_columns = [c for c in losses.columns if str(c).startswith('discr')]
    model_columns = [c for c in losses.columns if c not in disc_columns]

    clear_output(wait=True)
    fig, (ax_model, ax_disc) = plt.subplots(1, 2, figsize=figsize)
    _plot_loss_panel(ax_model, losses[model_columns], "batch loss", span)
    _plot_loss_panel(ax_disc, losses[disc_columns], "disc loss", span)
    plt.show()
    # Called once per iteration: close explicitly so figures don't accumulate.
    plt.close(fig)

In [None]:
training_data = UMTBatcher(train_src_idx, train_trg_idx, vocab_src, vocab_trg,
                           batch_size=32, shuffle=True)
max_num_epochs = 100
start_bt_from_epoch = 1  # first (0-based) epoch that uses back-translation
num_iters_done = 0
losses_history = []
should_continue = True

# NOTE(review): `training_step` is not defined in any cell of this notebook;
# define it (returning a dict of scalar losses) before running this cell.
try:
    for epoch in range(max_num_epochs):
        # `>=` so back-translation really starts *at* `start_bt_from_epoch`;
        # the previous `>` delayed it by one epoch.
        should_backtranslate = epoch >= start_bt_from_epoch
        for batch in tqdm(training_data, leave=False, desc='epoch {}'.format(epoch)):
            losses = training_step(batch, should_backtranslate=should_backtranslate)
            losses_history.append(losses)
            num_iters_done += 1

            # Let's visualize some things
            visualize_losses(losses_history)
except KeyboardInterrupt:
    # Catching outside both loops makes an interrupt stop the whole run
    # (previously `break` only left the inner loop and training resumed at
    # the next epoch). Everything collected so far is kept.
    should_continue = False
    print('Training interrupted after {} iterations'.format(num_iters_done))