Машинный перевод. Seq2Seq + Attention

In [0]:
from google.colab import drive
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as func

In [0]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [0]:
SOS = 0
EOS = 1

In [0]:
class Language:
  def __init__(self, name):
    self.name = name
    self.word2index = {'SOS' : 0, 'EOS' : 1}
    self.word2count = {}
    self.index2word = {0 : 'SOS', 1 : 'EOS'}
    self.words_num = 2
  
  def new_sent(self, sent):
    for word in sent.split(' '):
      self.new_word(word)
    
  def new_word(self, word):
    if word not in self.word2index:
      self.word2index[word] = self.words_num
      self.word2count[word] = 1
      self.index2word[self.words_num] = word
      self.words_num += 1
    else:
      self.word2count[word] += 1

In [0]:
def uni2asc(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')

In [0]:
def norm(s):
    s = uni2asc(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [0]:
def read_lang(lang):
  lines = open('/content/gdrive/My Drive/Data/Translations/eng-%s.txt' % (lang), encoding='utf-8').read().strip().split('\n')
  pairs = []
  for line in lines:
    sents = line.split('\t')
    pairs.append([norm(sents[1]), norm(sents[0])])
  return Language(lang), Language('eng'), pairs

In [0]:
MAX_LENGTH = 10

In [0]:
eng_prefixes = ('i m ', 'i am ', 'you re ', 'you are ', 'he s ', 'he is ', 'she s ', 'she is ', 'it s ', 'it is ', 'we re ', 'we are ', 'they re ', 'they are ')

In [0]:
def is_accepted(pair):
  return len(pair[0].split(' ')) < MAX_LENGTH and len(pair[1].split(' ')) < MAX_LENGTH and pair[1].startswith(eng_prefixes)

In [0]:
def filter_pairs(pairs):
  return [pair for pair in pairs if is_accepted(pair)]

In [0]:
def prepare_data(lang):
  SRC, TRG, PAIRS = read_lang(lang)
  print('Before: ', len(PAIRS))
  PAIRS = filter_pairs(PAIRS)
  print('After: ', len(PAIRS))
  for pair in PAIRS:
    SRC.new_sent(pair[0])
    TRG.new_sent(pair[1])
  print('SRC: ', SRC.words_num)
  print('TRG: ', TRG.words_num)
  return SRC, TRG, PAIRS

In [0]:
SRC, TRG, PAIRS = prepare_data('deu')
print(random.choice(PAIRS))

Before:  208486
After:  13643
SRC:  5604
TRG:  3570
['es geht ihm besser als zuvor .', 'he is better off than before .']


In [0]:
class Encoder(nn.Module):
    def __init__(self, inp_dim, hid_dim, num_lay):
        super(Encoder, self).__init__()
        self.hid_dim = hid_dim

        self.inp2emb = nn.Embedding(inp_dim, hid_dim)
        self.emb2hid = nn.GRU(hid_dim, hid_dim, num_lay)

    def forward(self, inp):
        emb = self.inp2emb(inp).view(1, 1, -1)
        hid, h = self.emb2hid(emb)
        return hid, h

In [0]:
class DecoderUPD(nn.Module):
    def __init__(self, hid_dim, out_dim, drop, num_lay):
        super(DecoderUPD, self).__init__()
        self.hid_dim = hid_dim
        self.out_dim = out_dim

        self.inp2emb = nn.Embedding(out_dim, hid_dim)
        self.ehc2wei = nn.Linear(hid_dim * 2, MAX_LENGTH)
        self.eow2emb = nn.Linear(hid_dim * 2, hid_dim)
        self.drop = nn.Dropout(drop)
        self.emb2hid = nn.GRU(hid_dim, hid_dim, num_lay)
        self.hid2out = nn.Linear(hid_dim, out_dim)

    def forward(self, inp, h, enc_outs):
        emb = self.drop(self.inp2emb(inp).view(1, 1, -1))

        wei = func.softmax(self.ehc2wei(torch.cat((emb[0], h[0]), 1)), dim = 1)
        ow = torch.bmm(wei.unsqueeze(0), enc_outs.unsqueeze(0))

        eow = torch.cat((emb[0], ow[0]), 1)
        emb = func.relu(self.eow2emb(eow).unsqueeze(0))
        hid, h = self.emb2hid(emb, h)

        out = func.log_softmax(self.hid2out(hid[0]), dim = 1)
        return out, h, wei

In [0]:
def sent2ind(lang, sent):
  return [lang.word2index[word] for word in sent.split(' ')]

In [0]:
def sent2ten(lang, sent):
  ind = sent2ind(lang, sent)
  ind.append(EOS)
  return torch.tensor(ind).type(torch.LongTensor).to(device).view(-1, 1)

In [0]:
def pair2ten(pair):
  src_ten = sent2ten(SRC, pair[0])
  trg_ten = sent2ten(TRG, pair[1])
  return (src_ten, trg_ten)

In [0]:
tfr = 0.5

In [0]:
def train(src_ten, trg_ten, enc, dec, enc_opt, dec_opt, crit):  
  enc_opt.zero_grad()
  dec_opt.zero_grad()

  src_len = src_ten.size(0)
  trg_len = trg_ten.size(0)

  enc_outs = torch.zeros(MAX_LENGTH, enc.hid_dim, device = device)

  loss = 0

  for i in range(src_len):
    enc_out, enc_state = enc(src_ten[i])
    enc_outs[i] = enc_out[0, 0]
  
  dec_inp = torch.tensor([[SOS]], device = device)
  dec_state = enc_state

  if random.random() < tfr:
    for i in range(trg_len):
      dec_out, dec_state, dec_wei = dec(dec_inp, dec_state, enc_outs)
      loss += crit(dec_out, trg_ten[i])
      dec_inp = trg_ten[i]
  else:
    for i in range(trg_len):
      dec_out, dec_state, dec_wei = dec(dec_inp, dec_state, enc_outs)
      _, top_ind = dec_out.topk(1)
      dec_inp = top_ind.squeeze().detach()
      loss += crit(dec_out, trg_ten[i])
      if dec_inp.item() == EOS:
        break
  
  loss.backward()
  enc_opt.step()
  dec_opt.step()

  return loss.item() / trg_len

In [0]:
import time
import math

In [0]:
def min_sec(s):
  m = math.floor(s / 60)
  s -= m * 60
  return '%dm %ds' % (m, s)

In [0]:
def count_time(start, percent):
  now = time.time()
  s = now - start
  es = s / (percent)
  rs = es - s
  return '%s (- %s)' % (min_sec(s), min_sec(rs))

In [0]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

In [0]:
%matplotlib inline

In [0]:
plt.switch_backend('agg')

In [0]:
def show_plot(points):
  plt.figure()
  fig, axs = plt.subplots()
  loc = ticker.MultipleLocator(base = 0.2)
  axs.yaxis.set_major_locator(loc)
  plt.plot(points)

In [0]:
def train_epochs(enc, dec, epochs_num, best, print_period = 1000, plot_period = 100, lr = 0.0001):
  start = time.time()
  losses = []
  print_loss_total = 0
  plot_loss_total = 0

  enc_opt = optim.SGD(enc.parameters(), lr = lr)
  dec_opt = optim.SGD(dec.parameters(), lr = lr)
  pairs = [pair2ten(random.choice(PAIRS)) for i in range(epochs_num)]
  crit = nn.NLLLoss()

  for i in range(1, epochs_num + 1):
    pair = pairs[i - 1]
    src_ten = pair[0]
    trg_ten = pair[1]

    loss = train(src_ten, trg_ten, enc, dec, enc_opt, dec_opt, crit)
    print_loss_total += loss
    plot_loss_total += loss

    if i % print_period == 0:
      period_loss = print_loss_total / print_period
      print('%s (%d %d%%) %.5f' % (count_time(start, i / epochs_num), i, i / epochs_num * 100, period_loss))
      if (period_loss < best):
        torch.save(enc.state_dict(), '/content/gdrive/My Drive/Models/deu_eng_enc.pt')
        torch.save(dec.state_dict(), '/content/gdrive/My Drive/Models/deu_eng_dec.pt')
        best = period_loss
      print_loss_total = 0
    
    if i % plot_period == 0:
      losses.append(plot_loss_total / plot_period)
      plot_loss_total = 0
    
  show_plot(losses)

In [0]:
def eval(enc, dec, sent):
  with torch.no_grad():
    src_ten = sent2ten(SRC, sent)
    src_len = src_ten.size(0)

    enc_outs = torch.zeros(MAX_LENGTH, enc.hid_dim, device = device)

    for i in range(src_len):
      enc_out, enc_state = enc(src_ten[i])
      enc_outs[i] = enc_out[0, 0]
    
    dec_inp = torch.tensor([[SOS]], device = device)
    dec_state = enc_state
    
    dec_words = []
    dec_weights = torch.zeros(MAX_LENGTH, MAX_LENGTH)

    for i in range(MAX_LENGTH):
      dec_out, dec_state, dec_wei = dec(dec_inp, dec_state, enc_outs)
      dec_weights[i] = dec_wei
      top_ind = dec_out.argmax()
      dec_inp = top_ind.squeeze().detach()
      if dec_inp.item() == EOS:
        dec_words.append('<eos>')
        break
      else:
        dec_words.append(TRG.index2word[top_ind.item()])
    
  return dec_words, dec_weights[:i + 1]

In [0]:
def eval_rand(enc, dec, count):
  for i in range(count):
    pair = random.choice(PAIRS)
    print('SRC: ', pair[0])
    words, weights = eval(enc, dec, pair[0])
    print('OUT: ', ' '.join(words))
    print('TRG: ', pair[1])
    print('')

In [0]:
hid_dim = 256
num_lay = 2
enc = Encoder(SRC.words_num, hid_dim, num_lay).to(device)
dec = DecoderUPD(hid_dim, TRG.words_num, 0.1, num_lay).to(device)

In [0]:
train_epochs(enc, dec, 100000, float('inf'))

0m 24s (- 40m 49s) (1000 1%) 3.80520
0m 44s (- 36m 4s) (2000 2%) 3.45530
1m 3s (- 34m 26s) (3000 3%) 3.41157
1m 23s (- 33m 29s) (4000 4%) 3.31503
1m 43s (- 32m 51s) (5000 5%) 3.13268
2m 4s (- 32m 28s) (6000 6%) 3.17728
2m 24s (- 32m 0s) (7000 7%) 3.06148
2m 45s (- 31m 39s) (8000 8%) 3.02501
3m 5s (- 31m 17s) (9000 9%) 2.99415
3m 25s (- 30m 53s) (10000 10%) 2.90105
3m 46s (- 30m 31s) (11000 11%) 2.90870
4m 6s (- 30m 9s) (12000 12%) 2.83543
4m 27s (- 29m 47s) (13000 13%) 2.83973
4m 47s (- 29m 25s) (14000 14%) 2.82670
5m 7s (- 29m 5s) (15000 15%) 2.79322
5m 28s (- 28m 44s) (16000 16%) 2.75363
5m 49s (- 28m 24s) (17000 17%) 2.81456
6m 9s (- 28m 2s) (18000 18%) 2.71916
6m 30s (- 27m 44s) (19000 19%) 2.75213
6m 50s (- 27m 23s) (20000 20%) 2.65309
7m 11s (- 27m 5s) (21000 21%) 2.66240
7m 32s (- 26m 45s) (22000 22%) 2.68313
7m 54s (- 26m 28s) (23000 23%) 2.65616
8m 16s (- 26m 10s) (24000 24%) 2.56633
8m 38s (- 25m 55s) (25000 25%) 2.60494
9m 0s (- 25m 38s) (26000 26%) 2.61726
9m 22s (- 25m 19s

In [0]:
enc.load_state_dict(torch.load('/content/gdrive/My Drive/Models/deu_eng_enc.pt'))
dec.load_state_dict(torch.load('/content/gdrive/My Drive/Models/deu_eng_dec.pt'))

<All keys matched successfully>

In [0]:
train_epochs(enc, dec, 100000, 1.93173)

0m 28s (- 47m 26s) (1000 1%) 1.91874
0m 50s (- 41m 34s) (2000 2%) 1.84577
1m 13s (- 39m 23s) (3000 3%) 1.86908
1m 35s (- 38m 20s) (4000 4%) 1.83616
1m 58s (- 37m 26s) (5000 5%) 1.81642
2m 20s (- 36m 35s) (6000 6%) 1.77608
2m 42s (- 35m 59s) (7000 7%) 1.80557
3m 4s (- 35m 21s) (8000 8%) 1.79039
3m 26s (- 34m 50s) (9000 9%) 1.76927
3m 48s (- 34m 13s) (10000 10%) 1.76521
4m 9s (- 33m 37s) (11000 11%) 1.75026
4m 30s (- 33m 5s) (12000 12%) 1.76048
4m 51s (- 32m 32s) (13000 13%) 1.68004
5m 12s (- 32m 1s) (14000 14%) 1.75002
5m 33s (- 31m 28s) (15000 15%) 1.71912
5m 53s (- 30m 57s) (16000 16%) 1.72926
6m 14s (- 30m 29s) (17000 17%) 1.76951
6m 35s (- 30m 1s) (18000 18%) 1.66414
6m 56s (- 29m 33s) (19000 19%) 1.69683
7m 16s (- 29m 7s) (20000 20%) 1.71449
7m 37s (- 28m 41s) (21000 21%) 1.72810
7m 57s (- 28m 14s) (22000 22%) 1.66556
8m 18s (- 27m 48s) (23000 23%) 1.67013
8m 39s (- 27m 24s) (24000 24%) 1.68054
9m 0s (- 27m 0s) (25000 25%) 1.65166
9m 20s (- 26m 36s) (26000 26%) 1.64714
9m 41s (- 26

In [0]:
train_epochs(enc, dec, 100000, 1.40395)

0m 27s (- 44m 54s) (1000 1%) 1.47071
0m 48s (- 39m 42s) (2000 2%) 1.44767
1m 9s (- 37m 38s) (3000 3%) 1.41223
1m 31s (- 36m 29s) (4000 4%) 1.43195
1m 52s (- 35m 34s) (5000 5%) 1.46883
2m 13s (- 34m 51s) (6000 6%) 1.41637
2m 34s (- 34m 13s) (7000 7%) 1.38112
2m 55s (- 33m 43s) (8000 8%) 1.41899
3m 17s (- 33m 14s) (9000 9%) 1.42185
3m 38s (- 32m 46s) (10000 10%) 1.42438
3m 59s (- 32m 18s) (11000 11%) 1.44404
4m 20s (- 31m 51s) (12000 12%) 1.40722
4m 41s (- 31m 25s) (13000 13%) 1.38285
5m 2s (- 30m 59s) (14000 14%) 1.41707
5m 23s (- 30m 34s) (15000 15%) 1.39124
5m 44s (- 30m 10s) (16000 16%) 1.40825
6m 6s (- 29m 47s) (17000 17%) 1.40264
6m 26s (- 29m 22s) (18000 18%) 1.48640
6m 47s (- 28m 59s) (19000 19%) 1.45574
7m 8s (- 28m 35s) (20000 20%) 1.44922
7m 30s (- 28m 13s) (21000 21%) 1.40632
7m 50s (- 27m 49s) (22000 22%) 1.40859
8m 11s (- 27m 25s) (23000 23%) 1.39399
8m 33s (- 27m 4s) (24000 24%) 1.39929
8m 54s (- 26m 42s) (25000 25%) 1.41732
9m 15s (- 26m 20s) (26000 26%) 1.43944
9m 36s (-

In [0]:
train_epochs(enc, dec, 100000, 1.20255)

0m 27s (- 44m 52s) (1000 1%) 1.28101
0m 48s (- 39m 32s) (2000 2%) 1.28685
1m 9s (- 37m 28s) (3000 3%) 1.21020
1m 31s (- 36m 24s) (4000 4%) 1.26022
1m 52s (- 35m 31s) (5000 5%) 1.23712
2m 13s (- 34m 49s) (6000 6%) 1.19015
2m 34s (- 34m 12s) (7000 7%) 1.16533
2m 56s (- 33m 44s) (8000 8%) 1.23359
3m 17s (- 33m 16s) (9000 9%) 1.20874
3m 38s (- 32m 46s) (10000 10%) 1.26296
3m 59s (- 32m 19s) (11000 11%) 1.22647
4m 20s (- 31m 52s) (12000 12%) 1.21201
4m 41s (- 31m 26s) (13000 13%) 1.24554
5m 2s (- 31m 0s) (14000 14%) 1.22522
5m 23s (- 30m 35s) (15000 15%) 1.24092
5m 45s (- 30m 11s) (16000 16%) 1.16579
6m 6s (- 29m 48s) (17000 17%) 1.19143
6m 27s (- 29m 24s) (18000 18%) 1.18138
6m 48s (- 29m 0s) (19000 19%) 1.19541
7m 9s (- 28m 36s) (20000 20%) 1.16064
7m 30s (- 28m 15s) (21000 21%) 1.16425
7m 51s (- 27m 51s) (22000 22%) 1.17259
8m 13s (- 27m 30s) (23000 23%) 1.18537
8m 33s (- 27m 7s) (24000 24%) 1.19002
8m 54s (- 26m 44s) (25000 25%) 1.19491
9m 16s (- 26m 22s) (26000 26%) 1.17175
9m 37s (- 2

In [0]:
train_epochs(enc, dec, 100000, 0.98945)

0m 27s (- 45m 12s) (1000 1%) 1.11080
0m 48s (- 39m 34s) (2000 2%) 1.03112
1m 9s (- 37m 37s) (3000 3%) 1.09177
1m 31s (- 36m 29s) (4000 4%) 1.08245
1m 52s (- 35m 34s) (5000 5%) 1.05621
2m 13s (- 34m 51s) (6000 6%) 1.06879
2m 34s (- 34m 16s) (7000 7%) 1.07099
2m 55s (- 33m 43s) (8000 8%) 1.05982
3m 17s (- 33m 15s) (9000 9%) 1.03489
3m 38s (- 32m 46s) (10000 10%) 1.00770
3m 59s (- 32m 20s) (11000 11%) 1.02339
4m 20s (- 31m 53s) (12000 12%) 1.05794
4m 41s (- 31m 26s) (13000 13%) 1.10024
5m 3s (- 31m 2s) (14000 14%) 1.03199
5m 24s (- 30m 38s) (15000 15%) 1.06164
5m 45s (- 30m 13s) (16000 16%) 1.01905
6m 6s (- 29m 49s) (17000 17%) 1.01769
6m 27s (- 29m 25s) (18000 18%) 1.00520
6m 48s (- 29m 1s) (19000 19%) 1.02309
7m 9s (- 28m 36s) (20000 20%) 1.05421
7m 29s (- 28m 12s) (21000 21%) 1.01498
7m 50s (- 27m 49s) (22000 22%) 1.01183
8m 12s (- 27m 27s) (23000 23%) 1.00797
8m 33s (- 27m 7s) (24000 24%) 1.03646
8m 54s (- 26m 44s) (25000 25%) 1.04931
9m 16s (- 26m 22s) (26000 26%) 1.01043
9m 36s (- 2

In [0]:
train_epochs(enc, dec, 100000, 0.87582)

0m 27s (- 44m 51s) (1000 1%) 0.91421
0m 48s (- 39m 13s) (2000 2%) 0.91641
1m 9s (- 37m 12s) (3000 3%) 0.93143
1m 29s (- 35m 58s) (4000 4%) 0.90862
1m 50s (- 35m 3s) (5000 5%) 0.94263
2m 11s (- 34m 23s) (6000 6%) 0.88727
2m 32s (- 33m 46s) (7000 7%) 0.91837
2m 53s (- 33m 12s) (8000 8%) 0.93237
3m 14s (- 32m 44s) (9000 9%) 0.91636
3m 35s (- 32m 18s) (10000 10%) 0.90571
3m 56s (- 31m 49s) (11000 11%) 0.92531
4m 17s (- 31m 26s) (12000 12%) 0.89170
4m 38s (- 31m 1s) (13000 13%) 0.96076
4m 59s (- 30m 37s) (14000 14%) 0.93972
5m 19s (- 30m 13s) (15000 15%) 0.89444
5m 40s (- 29m 47s) (16000 16%) 0.88065
6m 1s (- 29m 24s) (17000 17%) 0.92729
6m 22s (- 29m 1s) (18000 18%) 0.88701
6m 43s (- 28m 39s) (19000 19%) 0.89370
7m 3s (- 28m 14s) (20000 20%) 0.86405
7m 24s (- 27m 53s) (21000 21%) 0.87810
7m 45s (- 27m 30s) (22000 22%) 0.85529
8m 6s (- 27m 8s) (23000 23%) 0.90556
8m 27s (- 26m 46s) (24000 24%) 0.90897
8m 48s (- 26m 24s) (25000 25%) 0.91178
9m 8s (- 26m 2s) (26000 26%) 0.87447
9m 29s (- 25m 

In [0]:
train_epochs(enc, dec, 200000, 0.75615)

0m 32s (- 109m 5s) (1000 0%) 0.80125
0m 53s (- 88m 47s) (2000 1%) 0.78616
1m 14s (- 81m 42s) (3000 1%) 0.82368
1m 35s (- 77m 51s) (4000 2%) 0.79315
1m 56s (- 75m 35s) (5000 2%) 0.78519
2m 17s (- 73m 59s) (6000 3%) 0.75895
2m 38s (- 72m 39s) (7000 3%) 0.74964
2m 59s (- 71m 45s) (8000 4%) 0.83758
3m 20s (- 70m 53s) (9000 4%) 0.81781
3m 41s (- 70m 6s) (10000 5%) 0.79220
4m 2s (- 69m 22s) (11000 5%) 0.78868
4m 23s (- 68m 46s) (12000 6%) 0.79152
4m 44s (- 68m 11s) (13000 6%) 0.80256
5m 5s (- 67m 41s) (14000 7%) 0.78434
5m 26s (- 67m 9s) (15000 7%) 0.77571
5m 47s (- 66m 39s) (16000 8%) 0.76052
6m 9s (- 66m 12s) (17000 8%) 0.82498
6m 30s (- 65m 43s) (18000 9%) 0.77836
6m 50s (- 65m 13s) (19000 9%) 0.80047
7m 12s (- 64m 49s) (20000 10%) 0.78667
7m 33s (- 64m 22s) (21000 10%) 0.78922
7m 53s (- 63m 54s) (22000 11%) 0.77124
8m 14s (- 63m 28s) (23000 11%) 0.80917
8m 36s (- 63m 4s) (24000 12%) 0.77080
8m 56s (- 62m 38s) (25000 12%) 0.75865
9m 17s (- 62m 13s) (26000 13%) 0.78253
9m 38s (- 61m 48s) (

In [0]:
train_epochs(enc, dec, 200000, 0.55668)

0m 34s (- 115m 8s) (1000 0%) 0.62668
0m 56s (- 94m 0s) (2000 1%) 0.57182
1m 18s (- 86m 12s) (3000 1%) 0.58071
1m 40s (- 82m 17s) (4000 2%) 0.58854
2m 2s (- 79m 44s) (5000 2%) 0.59886
2m 24s (- 78m 8s) (6000 3%) 0.58313
2m 46s (- 76m 44s) (7000 3%) 0.55296
3m 9s (- 75m 45s) (8000 4%) 0.55669
3m 31s (- 74m 52s) (9000 4%) 0.57543
3m 53s (- 74m 1s) (10000 5%) 0.54464
4m 15s (- 73m 14s) (11000 5%) 0.55910
4m 38s (- 72m 35s) (12000 6%) 0.51868
5m 0s (- 71m 59s) (13000 6%) 0.55695
5m 22s (- 71m 22s) (14000 7%) 0.56977
5m 44s (- 70m 49s) (15000 7%) 0.54783
6m 6s (- 70m 16s) (16000 8%) 0.55565
6m 28s (- 69m 43s) (17000 8%) 0.55349
6m 50s (- 69m 13s) (18000 9%) 0.55997
7m 13s (- 68m 45s) (19000 9%) 0.56525
7m 35s (- 68m 18s) (20000 10%) 0.53771
7m 57s (- 67m 52s) (21000 10%) 0.55488
8m 20s (- 67m 26s) (22000 11%) 0.52715
8m 42s (- 67m 0s) (23000 11%) 0.50099
9m 5s (- 66m 37s) (24000 12%) 0.56362
9m 26s (- 66m 8s) (25000 12%) 0.54314
9m 49s (- 65m 42s) (26000 13%) 0.54667
10m 11s (- 65m 16s) (270

In [0]:
train_epochs(enc, dec, 200000, 0.43892)

0m 34s (- 114m 12s) (1000 0%) 0.49630
0m 56s (- 93m 36s) (2000 1%) 0.47409
1m 18s (- 86m 23s) (3000 1%) 0.45383
1m 40s (- 82m 26s) (4000 2%) 0.44845
2m 3s (- 80m 3s) (5000 2%) 0.48863
2m 25s (- 78m 9s) (6000 3%) 0.44382
2m 47s (- 76m 50s) (7000 3%) 0.45122
3m 9s (- 75m 43s) (8000 4%) 0.44377
3m 31s (- 74m 43s) (9000 4%) 0.46365
3m 53s (- 73m 48s) (10000 5%) 0.49532
4m 15s (- 73m 1s) (11000 5%) 0.47216
4m 37s (- 72m 24s) (12000 6%) 0.45461
4m 59s (- 71m 43s) (13000 6%) 0.47858
5m 20s (- 71m 4s) (14000 7%) 0.46145
5m 43s (- 70m 34s) (15000 7%) 0.45665
6m 5s (- 70m 1s) (16000 8%) 0.49052
6m 27s (- 69m 29s) (17000 8%) 0.45030
6m 49s (- 68m 59s) (18000 9%) 0.42804
7m 11s (- 68m 33s) (19000 9%) 0.46836
7m 34s (- 68m 6s) (20000 10%) 0.45785
7m 56s (- 67m 37s) (21000 10%) 0.46946
8m 18s (- 67m 12s) (22000 11%) 0.45617
8m 40s (- 66m 45s) (23000 11%) 0.44248
9m 2s (- 66m 20s) (24000 12%) 0.43836
9m 24s (- 65m 51s) (25000 12%) 0.42861
9m 46s (- 65m 26s) (26000 13%) 0.47725
10m 9s (- 65m 2s) (2700

In [0]:
train_epochs(enc, dec, 200000, 0.37568)

0m 35s (- 118m 54s) (1000 0%) 0.37032
0m 59s (- 97m 39s) (2000 1%) 0.34775
1m 22s (- 90m 25s) (3000 1%) 0.38382
1m 45s (- 86m 11s) (4000 2%) 0.37384
2m 8s (- 83m 45s) (5000 2%) 0.36662
2m 32s (- 81m 56s) (6000 3%) 0.39523
2m 55s (- 80m 25s) (7000 3%) 0.38932
3m 18s (- 79m 23s) (8000 4%) 0.38684
3m 41s (- 78m 19s) (9000 4%) 0.37196
4m 4s (- 77m 30s) (10000 5%) 0.37703
4m 28s (- 76m 49s) (11000 5%) 0.37382
4m 51s (- 76m 5s) (12000 6%) 0.39782
5m 14s (- 75m 23s) (13000 6%) 0.36511
5m 37s (- 74m 43s) (14000 7%) 0.38207
6m 0s (- 74m 8s) (15000 7%) 0.37390
6m 23s (- 73m 34s) (16000 8%) 0.39535
6m 46s (- 72m 57s) (17000 8%) 0.38128
7m 9s (- 72m 22s) (18000 9%) 0.40613
7m 32s (- 71m 48s) (19000 9%) 0.37053
7m 55s (- 71m 18s) (20000 10%) 0.38506
8m 18s (- 70m 47s) (21000 10%) 0.39178
8m 41s (- 70m 18s) (22000 11%) 0.39152
9m 4s (- 69m 50s) (23000 11%) 0.37015
9m 27s (- 69m 20s) (24000 12%) 0.37786
9m 50s (- 68m 53s) (25000 12%) 0.38575
10m 13s (- 68m 26s) (26000 13%) 0.37621
10m 36s (- 67m 56s)

In [0]:
from nltk.translate.bleu_score import corpus_bleu

In [0]:
def compute_bleu(pairs):
    src_lines = []
    trg_lines = []
    
    for i in range(len(pairs)):
      src_lines.append(pairs[i][0])
      trg_lines.append(pairs[i][1])

    out_lines = []
    for src_line in src_lines:
      words, weights = eval(enc, dec, src_line)
      out_lines.append(words[:-1])
    
    trg_lines = [[trg_line.split(' ')] for trg_line in trg_lines]
    return corpus_bleu(trg_lines, out_lines) * 100

In [0]:
def get_pairs(count):
  test_pairs = []
  for i in range(count):
    test_pairs.append(random.choice(PAIRS))
  return test_pairs

In [0]:
compute_bleu(get_pairs(1000))

75.59217076664714

In [0]:
train_epochs(enc, dec, 25000, 0.48851)

0m 18s (- 7m 35s) (1000 4%) 0.55974
0m 36s (- 6m 59s) (2000 8%) 0.54826
0m 53s (- 6m 35s) (3000 12%) 0.51443
1m 11s (- 6m 15s) (4000 16%) 0.52424
1m 29s (- 5m 56s) (5000 20%) 0.54424
1m 46s (- 5m 37s) (6000 24%) 0.51501
2m 4s (- 5m 19s) (7000 28%) 0.49736
2m 21s (- 5m 1s) (8000 32%) 0.50156
2m 39s (- 4m 43s) (9000 36%) 0.53815
2m 57s (- 4m 25s) (10000 40%) 0.49222
3m 14s (- 4m 7s) (11000 44%) 0.48452
3m 32s (- 3m 50s) (12000 48%) 0.48856
3m 50s (- 3m 32s) (13000 52%) 0.48371
4m 8s (- 3m 14s) (14000 56%) 0.55039
4m 25s (- 2m 57s) (15000 60%) 0.52901
4m 43s (- 2m 39s) (16000 64%) 0.50421
5m 0s (- 2m 21s) (17000 68%) 0.49989
5m 18s (- 2m 3s) (18000 72%) 0.49680
5m 35s (- 1m 46s) (19000 76%) 0.49210
5m 53s (- 1m 28s) (20000 80%) 0.50463
6m 10s (- 1m 10s) (21000 84%) 0.52804
6m 28s (- 0m 52s) (22000 88%) 0.51576
6m 45s (- 0m 35s) (23000 92%) 0.51236
7m 3s (- 0m 17s) (24000 96%) 0.51266
7m 21s (- 0m 0s) (25000 100%) 0.45235


In [0]:
for i in range(10):
  print(compute_bleu(get_pairs(1000)))

71.82599390985509
70.56190347779587
73.04423575097974
71.3014090432591
70.81787290187133
72.43064881102971
71.48627048860348
70.00889131187147
74.2638257758353
73.45914816597858


In [0]:
train_epochs(enc, dec, 25000, 0.45235)

0m 18s (- 7m 27s) (1000 4%) 0.50186
0m 36s (- 6m 55s) (2000 8%) 0.52685
0m 53s (- 6m 35s) (3000 12%) 0.53316
1m 11s (- 6m 13s) (4000 16%) 0.51024
1m 28s (- 5m 53s) (5000 20%) 0.48836
1m 45s (- 5m 35s) (6000 24%) 0.48319
2m 3s (- 5m 17s) (7000 28%) 0.51311
2m 21s (- 4m 59s) (8000 32%) 0.48621
2m 38s (- 4m 41s) (9000 36%) 0.47334
2m 56s (- 4m 24s) (10000 40%) 0.47088
3m 13s (- 4m 6s) (11000 44%) 0.49281
3m 31s (- 3m 49s) (12000 48%) 0.51372
3m 49s (- 3m 31s) (13000 52%) 0.52059
4m 6s (- 3m 13s) (14000 56%) 0.46390
4m 23s (- 2m 55s) (15000 60%) 0.53249
4m 41s (- 2m 38s) (16000 64%) 0.51298
4m 58s (- 2m 20s) (17000 68%) 0.46091
5m 16s (- 2m 2s) (18000 72%) 0.52970
5m 33s (- 1m 45s) (19000 76%) 0.49868
5m 51s (- 1m 27s) (20000 80%) 0.49038
6m 8s (- 1m 10s) (21000 84%) 0.44387
6m 26s (- 0m 52s) (22000 88%) 0.47504
6m 43s (- 0m 35s) (23000 92%) 0.47849
7m 1s (- 0m 17s) (24000 96%) 0.46404
7m 18s (- 0m 0s) (25000 100%) 0.45659
