Машинный перевод. Seq2Seq + Attention

In [0]:
from google.colab import drive
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as func

In [0]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [0]:
SOS = 0
EOS = 1

In [0]:
class Language:
  def __init__(self, name):
    self.name = name
    self.word2index = {'SOS' : 0, 'EOS' : 1}
    self.word2count = {}
    self.index2word = {0 : 'SOS', 1 : 'EOS'}
    self.words_num = 2
  
  def new_sent(self, sent):
    for word in sent.split(' '):
      self.new_word(word)
    
  def new_word(self, word):
    if word not in self.word2index:
      self.word2index[word] = self.words_num
      self.word2count[word] = 1
      self.index2word[self.words_num] = word
      self.words_num += 1
    else:
      self.word2count[word] += 1

In [0]:
def uni2asc(s):
    return ''.join(c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn')

In [0]:
def norm(s):
    s = uni2asc(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [0]:
def read_lang(lang):
  lines = open('/content/gdrive/My Drive/Data/Translations/eng-%s.txt' % (lang), encoding='utf-8').read().strip().split('\n')
  pairs = []
  for line in lines:
    sents = line.split('\t')
    pairs.append([norm(sents[1]), norm(sents[0])])
  return Language(lang), Language('eng'), pairs

In [0]:
MAX_LENGTH = 10

In [0]:
eng_prefixes = ('i m ', 'i am ', 'you re ', 'you are ', 'he s ', 'he is ', 'she s ', 'she is ', 'it s ', 'it is ', 'we re ', 'we are ', 'they re ', 'they are ')

In [0]:
def is_accepted(pair):
  return len(pair[0].split(' ')) < MAX_LENGTH and len(pair[1].split(' ')) < MAX_LENGTH and pair[1].startswith(eng_prefixes)

In [0]:
def filter_pairs(pairs):
  return [pair for pair in pairs if is_accepted(pair)]

In [0]:
def prepare_data(lang):
  SRC, TRG, PAIRS = read_lang(lang)
  print('Before: ', len(PAIRS))
  PAIRS = filter_pairs(PAIRS)
  print('After: ', len(PAIRS))
  for pair in PAIRS:
    SRC.new_sent(pair[0])
    TRG.new_sent(pair[1])
  print('SRC: ', SRC.words_num)
  print('TRG: ', TRG.words_num)
  return SRC, TRG, PAIRS

In [0]:
SRC, TRG, PAIRS = prepare_data('deu')
print(random.choice(PAIRS))

Before:  208486
After:  13643
SRC:  5604
TRG:  3570
['du bist die expertin .', 'you re the expert .']


In [0]:
class Encoder(nn.Module):
    def __init__(self, inp_dim, hid_dim, num_lay):
        super(Encoder, self).__init__()
        self.hid_dim = hid_dim

        self.inp2emb = nn.Embedding(inp_dim, hid_dim)
        self.emb2hid = nn.GRU(hid_dim, hid_dim, num_lay)

    def forward(self, inp):
        emb = self.inp2emb(inp).view(1, 1, -1)
        hid, h = self.emb2hid(emb)
        return hid, h

In [0]:
class DecoderUPD(nn.Module):
    def __init__(self, hid_dim, out_dim, drop, num_lay):
        super(DecoderUPD, self).__init__()
        self.hid_dim = hid_dim
        self.out_dim = out_dim

        self.inp2emb = nn.Embedding(out_dim, hid_dim)
        self.ehc2wei = nn.Linear(hid_dim * 2, MAX_LENGTH)
        self.eow2emb = nn.Linear(hid_dim * 2, hid_dim)
        self.drop = nn.Dropout(drop)
        self.emb2hid = nn.GRU(hid_dim, hid_dim, num_lay)
        self.hid2out = nn.Linear(hid_dim, out_dim)

    def forward(self, inp, h, enc_outs):
        emb = self.drop(self.inp2emb(inp).view(1, 1, -1))

        wei = func.softmax(self.ehc2wei(torch.cat((emb[0], h[0]), 1)), dim = 1)
        ow = torch.bmm(wei.unsqueeze(0), enc_outs.unsqueeze(0))

        eow = torch.cat((emb[0], ow[0]), 1)
        emb = func.relu(self.eow2emb(eow).unsqueeze(0))
        hid, h = self.emb2hid(emb, h)

        out = func.log_softmax(self.hid2out(hid[0]), dim = 1)
        return out, h, wei

In [0]:
def sent2ind(lang, sent):
  return [lang.word2index[word] for word in sent.split(' ')]

In [0]:
def sent2ten(lang, sent):
  ind = sent2ind(lang, sent)
  ind.append(EOS)
  return torch.tensor(ind).type(torch.LongTensor).to(device).view(-1, 1)

In [0]:
def pair2ten(pair):
  src_ten = sent2ten(SRC, pair[0])
  trg_ten = sent2ten(TRG, pair[1])
  return (src_ten, trg_ten)

In [0]:
tfr = 0.5

In [0]:
def train(src_ten, trg_ten, enc, dec, enc_opt, dec_opt, crit):  
  enc_opt.zero_grad()
  dec_opt.zero_grad()

  src_len = src_ten.size(0)
  trg_len = trg_ten.size(0)

  enc_outs = torch.zeros(MAX_LENGTH, enc.hid_dim, device = device)

  loss = 0

  for i in range(src_len):
    enc_out, enc_state = enc(src_ten[i])
    enc_outs[i] = enc_out[0, 0]
  
  dec_inp = torch.tensor([[SOS]], device = device)
  dec_state = enc_state

  if random.random() < tfr:
    for i in range(trg_len):
      dec_out, dec_state, dec_wei = dec(dec_inp, dec_state, enc_outs)
      loss += crit(dec_out, trg_ten[i])
      dec_inp = trg_ten[i]
  else:
    for i in range(trg_len):
      dec_out, dec_state, dec_wei = dec(dec_inp, dec_state, enc_outs)
      _, top_ind = dec_out.topk(1)
      dec_inp = top_ind.squeeze().detach()
      loss += crit(dec_out, trg_ten[i])
      if dec_inp.item() == EOS:
        break
  
  loss.backward()
  enc_opt.step()
  dec_opt.step()

  return loss.item() / trg_len

In [0]:
import time
import math

In [0]:
def min_sec(s):
  m = math.floor(s / 60)
  s -= m * 60
  return '%dm %ds' % (m, s)

In [0]:
def count_time(start, percent):
  now = time.time()
  s = now - start
  es = s / (percent)
  rs = es - s
  return '%s (- %s)' % (min_sec(s), min_sec(rs))

In [0]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

In [0]:
%matplotlib inline

In [0]:
plt.switch_backend('agg')

In [0]:
def show_plot(points):
  plt.figure()
  fig, axs = plt.subplots()
  loc = ticker.MultipleLocator(base = 0.2)
  axs.yaxis.set_major_locator(loc)
  plt.plot(points)

In [0]:
def train_epochs(enc, dec, epochs_num, best, print_period = 1000, plot_period = 100, lr = 0.001):
  start = time.time()
  losses = []
  print_loss_total = 0
  plot_loss_total = 0

  enc_opt = optim.SGD(enc.parameters(), lr = lr)
  dec_opt = optim.SGD(dec.parameters(), lr = lr)
  pairs = [pair2ten(random.choice(PAIRS)) for i in range(epochs_num)]
  crit = nn.NLLLoss()

  for i in range(1, epochs_num + 1):
    pair = pairs[i - 1]
    src_ten = pair[0]
    trg_ten = pair[1]

    loss = train(src_ten, trg_ten, enc, dec, enc_opt, dec_opt, crit)
    print_loss_total += loss
    plot_loss_total += loss

    if i % print_period == 0:
      period_loss = print_loss_total / print_period
      print('%s (%d %d%%) %.5f' % (count_time(start, i / epochs_num), i, i / epochs_num * 100, period_loss))
      if (period_loss < best):
        torch.save(enc.state_dict(), '/content/gdrive/My Drive/Models/deu_eng_enc_.pt')
        torch.save(dec.state_dict(), '/content/gdrive/My Drive/Models/deu_eng_dec_.pt')
        best = period_loss
      print_loss_total = 0
    
    if i % plot_period == 0:
      losses.append(plot_loss_total / plot_period)
      plot_loss_total = 0
    
  show_plot(losses)

In [0]:
torch.save(enc.state_dict(), '/content/gdrive/My Drive/Models/deu_eng_enc_.pt')
torch.save(dec.state_dict(), '/content/gdrive/My Drive/Models/deu_eng_dec_.pt')

In [0]:
def eval(enc, dec, sent):
  with torch.no_grad():
    src_ten = sent2ten(SRC, sent)
    src_len = src_ten.size(0)

    enc_outs = torch.zeros(MAX_LENGTH, enc.hid_dim, device = device)

    for i in range(src_len):
      enc_out, enc_state = enc(src_ten[i])
      enc_outs[i] = enc_out[0, 0]
    
    dec_inp = torch.tensor([[SOS]], device = device)
    dec_state = enc_state
    
    dec_words = []
    dec_weights = torch.zeros(MAX_LENGTH, MAX_LENGTH)

    for i in range(MAX_LENGTH):
      dec_out, dec_state, dec_wei = dec(dec_inp, dec_state, enc_outs)
      dec_weights[i] = dec_wei
      top_ind = dec_out.argmax()
      dec_inp = top_ind.squeeze().detach()
      if dec_inp.item() == EOS:
        dec_words.append('<eos>')
        break
      else:
        dec_words.append(TRG.index2word[top_ind.item()])
    
  return dec_words, dec_weights[:i + 1]

In [0]:
def eval_rand(enc, dec, count):
  for i in range(count):
    pair = random.choice(PAIRS)
    print('SRC: ', pair[0])
    words, weights = eval(enc, dec, pair[0])
    print('OUT: ', ' '.join(words))
    print('TRG: ', pair[1])
    print('')

In [0]:
hid_dim = 256
num_lay = 1
enc = Encoder(SRC.words_num, hid_dim, num_lay).to(device)
dec = DecoderUPD(hid_dim, TRG.words_num, 0.1, num_lay).to(device)

In [0]:
train_epochs(enc, dec, 100000, float('inf'))

0m 23s (- 39m 10s) (1000 1%) 3.75628
0m 42s (- 34m 28s) (2000 2%) 3.53028
1m 0s (- 32m 28s) (3000 3%) 3.41454
1m 18s (- 31m 28s) (4000 4%) 3.35670
1m 37s (- 30m 46s) (5000 5%) 3.10289
1m 56s (- 30m 17s) (6000 6%) 3.00377
2m 14s (- 29m 51s) (7000 7%) 2.96159
2m 33s (- 29m 29s) (8000 8%) 2.89761
2m 52s (- 29m 1s) (9000 9%) 2.79984
3m 10s (- 28m 38s) (10000 10%) 2.76576
3m 29s (- 28m 11s) (11000 11%) 2.78427
3m 47s (- 27m 45s) (12000 12%) 2.66237
4m 5s (- 27m 25s) (13000 13%) 2.68703
4m 24s (- 27m 4s) (14000 14%) 2.59511
4m 43s (- 26m 44s) (15000 15%) 2.57815
5m 1s (- 26m 24s) (16000 16%) 2.56081
5m 20s (- 26m 4s) (17000 17%) 2.55976
5m 39s (- 25m 45s) (18000 18%) 2.47639
5m 57s (- 25m 24s) (19000 19%) 2.44573
6m 16s (- 25m 5s) (20000 20%) 2.45166
6m 34s (- 24m 44s) (21000 21%) 2.38359
6m 52s (- 24m 23s) (22000 22%) 2.35862
7m 11s (- 24m 5s) (23000 23%) 2.38431
7m 30s (- 23m 45s) (24000 24%) 2.28355
7m 49s (- 23m 27s) (25000 25%) 2.26584
8m 7s (- 23m 7s) (26000 26%) 2.27346
8m 25s (- 22m 

In [0]:
enc.load_state_dict(torch.load('/content/gdrive/My Drive/Models/deu_eng_enc_.pt'))
dec.load_state_dict(torch.load('/content/gdrive/My Drive/Models/deu_eng_dec_.pt'))

<All keys matched successfully>

In [0]:
train_epochs(enc, dec, 100000, 1.55725)

0m 25s (- 41m 28s) (1000 1%) 1.58086
0m 44s (- 35m 57s) (2000 2%) 1.43522
1m 3s (- 33m 57s) (3000 3%) 1.49036
1m 22s (- 32m 49s) (4000 4%) 1.40083
1m 41s (- 32m 11s) (5000 5%) 1.40400
2m 0s (- 31m 22s) (6000 6%) 1.41532
2m 18s (- 30m 45s) (7000 7%) 1.47049
2m 37s (- 30m 15s) (8000 8%) 1.40211
2m 56s (- 29m 44s) (9000 9%) 1.36359
3m 14s (- 29m 14s) (10000 10%) 1.37641
3m 33s (- 28m 50s) (11000 11%) 1.32287
3m 53s (- 28m 32s) (12000 12%) 1.36881
4m 12s (- 28m 10s) (13000 13%) 1.38307
4m 31s (- 27m 49s) (14000 14%) 1.34041
4m 50s (- 27m 26s) (15000 15%) 1.29113
5m 9s (- 27m 6s) (16000 16%) 1.30273
5m 28s (- 26m 44s) (17000 17%) 1.22010
5m 47s (- 26m 23s) (18000 18%) 1.26143
6m 6s (- 26m 2s) (19000 19%) 1.27051
6m 24s (- 25m 39s) (20000 20%) 1.21303
6m 44s (- 25m 21s) (21000 21%) 1.23761
7m 3s (- 25m 1s) (22000 22%) 1.20935
7m 22s (- 24m 39s) (23000 23%) 1.29196
7m 40s (- 24m 19s) (24000 24%) 1.21438
7m 59s (- 23m 58s) (25000 25%) 1.23183
8m 18s (- 23m 38s) (26000 26%) 1.20304
8m 37s (- 23

In [0]:
train_epochs(enc, dec, 100000, 0.88328)

0m 25s (- 41m 33s) (1000 1%) 0.86438
0m 44s (- 36m 35s) (2000 2%) 0.87535
1m 3s (- 34m 24s) (3000 3%) 0.84006
1m 23s (- 33m 31s) (4000 4%) 0.83201
1m 43s (- 32m 37s) (5000 5%) 0.85102
2m 1s (- 31m 46s) (6000 6%) 0.86819
2m 21s (- 31m 13s) (7000 7%) 0.80191
2m 40s (- 30m 45s) (8000 8%) 0.79909
2m 59s (- 30m 19s) (9000 9%) 0.80616
3m 19s (- 29m 57s) (10000 10%) 0.83674
3m 38s (- 29m 31s) (11000 11%) 0.80275
3m 57s (- 29m 5s) (12000 12%) 0.77267
4m 17s (- 28m 40s) (13000 13%) 0.78656
4m 36s (- 28m 17s) (14000 14%) 0.75581
4m 55s (- 27m 56s) (15000 15%) 0.75207
5m 15s (- 27m 34s) (16000 16%) 0.77129
5m 34s (- 27m 14s) (17000 17%) 0.74884
5m 54s (- 26m 55s) (18000 18%) 0.78577
6m 13s (- 26m 32s) (19000 19%) 0.75293
6m 32s (- 26m 10s) (20000 20%) 0.73105
6m 52s (- 25m 50s) (21000 21%) 0.78710
7m 10s (- 25m 27s) (22000 22%) 0.76331
7m 30s (- 25m 6s) (23000 23%) 0.72165
7m 49s (- 24m 47s) (24000 24%) 0.77364
8m 8s (- 24m 26s) (25000 25%) 0.70905
8m 28s (- 24m 6s) (26000 26%) 0.74944
8m 48s (- 

In [0]:
train_epochs(enc, dec, 100000, 0.56326)

0m 23s (- 39m 10s) (1000 1%) 0.59731
0m 41s (- 33m 41s) (2000 2%) 0.59442
0m 58s (- 31m 44s) (3000 3%) 0.57918
1m 16s (- 30m 33s) (4000 4%) 0.66055
1m 33s (- 29m 43s) (5000 5%) 0.57044
1m 51s (- 29m 9s) (6000 6%) 0.60625
2m 9s (- 28m 36s) (7000 7%) 0.62469
2m 26s (- 28m 10s) (8000 8%) 0.58144
2m 44s (- 27m 42s) (9000 9%) 0.57898
3m 1s (- 27m 15s) (10000 10%) 0.57572
3m 18s (- 26m 48s) (11000 11%) 0.58311
3m 36s (- 26m 24s) (12000 12%) 0.59197
3m 53s (- 26m 2s) (13000 13%) 0.56996
4m 10s (- 25m 39s) (14000 14%) 0.57658
4m 28s (- 25m 20s) (15000 15%) 0.57479
4m 45s (- 25m 0s) (16000 16%) 0.57196
5m 3s (- 24m 41s) (17000 17%) 0.57109
5m 20s (- 24m 21s) (18000 18%) 0.58885
5m 37s (- 24m 0s) (19000 19%) 0.57396
5m 55s (- 23m 42s) (20000 20%) 0.55035
6m 13s (- 23m 24s) (21000 21%) 0.55442
6m 30s (- 23m 5s) (22000 22%) 0.57777
6m 48s (- 22m 46s) (23000 23%) 0.57229
7m 5s (- 22m 27s) (24000 24%) 0.56477
7m 23s (- 22m 9s) (25000 25%) 0.63045
7m 40s (- 21m 50s) (26000 26%) 0.58764
7m 57s (- 21m 

In [0]:
train_epochs(enc, dec, 200000, 0.45968)

0m 28s (- 96m 5s) (1000 0%) 0.47339
0m 46s (- 76m 29s) (2000 1%) 0.48040
1m 3s (- 69m 46s) (3000 1%) 0.48487
1m 21s (- 66m 26s) (4000 2%) 0.51009
1m 38s (- 64m 6s) (5000 2%) 0.48447
1m 56s (- 62m 33s) (6000 3%) 0.47427
2m 13s (- 61m 19s) (7000 3%) 0.47963
2m 30s (- 60m 18s) (8000 4%) 0.47664
2m 48s (- 59m 31s) (9000 4%) 0.47508
3m 6s (- 58m 55s) (10000 5%) 0.51010
3m 23s (- 58m 17s) (11000 5%) 0.49920
3m 41s (- 57m 42s) (12000 6%) 0.49354
3m 58s (- 57m 11s) (13000 6%) 0.44792
4m 16s (- 56m 47s) (14000 7%) 0.47994
4m 34s (- 56m 20s) (15000 7%) 0.47123
4m 51s (- 55m 50s) (16000 8%) 0.48848
5m 8s (- 55m 21s) (17000 8%) 0.47326
5m 25s (- 54m 54s) (18000 9%) 0.48211
5m 42s (- 54m 25s) (19000 9%) 0.48118
5m 59s (- 53m 59s) (20000 10%) 0.46940
6m 17s (- 53m 33s) (21000 10%) 0.47769
6m 34s (- 53m 8s) (22000 11%) 0.43974
6m 51s (- 52m 46s) (23000 11%) 0.46024
7m 8s (- 52m 23s) (24000 12%) 0.46298
7m 25s (- 51m 59s) (25000 12%) 0.48032
7m 42s (- 51m 36s) (26000 13%) 0.47258
7m 59s (- 51m 14s) (2

In [0]:
train_epochs(enc, dec, 200000, 0.34465)

0m 29s (- 97m 30s) (1000 0%) 0.36923
0m 47s (- 77m 47s) (2000 1%) 0.37132
1m 4s (- 70m 52s) (3000 1%) 0.40844
1m 22s (- 67m 26s) (4000 2%) 0.38611
1m 40s (- 65m 3s) (5000 2%) 0.35435
1m 57s (- 63m 35s) (6000 3%) 0.36613
2m 15s (- 62m 17s) (7000 3%) 0.35441
2m 33s (- 61m 22s) (8000 4%) 0.36860
2m 50s (- 60m 28s) (9000 4%) 0.36057
3m 8s (- 59m 43s) (10000 5%) 0.37612
3m 26s (- 59m 2s) (11000 5%) 0.35919
3m 43s (- 58m 26s) (12000 6%) 0.34916
4m 1s (- 57m 56s) (13000 6%) 0.36402
4m 19s (- 57m 24s) (14000 7%) 0.33778
4m 37s (- 56m 58s) (15000 7%) 0.37818
4m 54s (- 56m 29s) (16000 8%) 0.35992
5m 12s (- 56m 0s) (17000 8%) 0.36227
5m 29s (- 55m 34s) (18000 9%) 0.35119
5m 47s (- 55m 9s) (19000 9%) 0.34355
6m 5s (- 54m 47s) (20000 10%) 0.33918
6m 22s (- 54m 22s) (21000 10%) 0.36559
6m 40s (- 53m 58s) (22000 11%) 0.33877
6m 57s (- 53m 35s) (23000 11%) 0.35388
7m 15s (- 53m 13s) (24000 12%) 0.34939
7m 32s (- 52m 50s) (25000 12%) 0.35435
7m 50s (- 52m 28s) (26000 13%) 0.34117
8m 8s (- 52m 8s) (2700

In [0]:
train_epochs(enc, dec, 200000, 0.27773)

0m 30s (- 102m 20s) (1000 0%) 0.30555
0m 49s (- 81m 22s) (2000 1%) 0.28267
1m 7s (- 73m 59s) (3000 1%) 0.29914
1m 25s (- 70m 11s) (4000 2%) 0.29157
1m 44s (- 67m 45s) (5000 2%) 0.27160
2m 2s (- 66m 5s) (6000 3%) 0.28936
2m 20s (- 64m 28s) (7000 3%) 0.27328
2m 38s (- 63m 17s) (8000 4%) 0.28476
2m 56s (- 62m 17s) (9000 4%) 0.29976
3m 14s (- 61m 26s) (10000 5%) 0.29669
3m 32s (- 60m 46s) (11000 5%) 0.28032
3m 49s (- 60m 1s) (12000 6%) 0.27917
4m 7s (- 59m 23s) (13000 6%) 0.27335
4m 25s (- 58m 43s) (14000 7%) 0.27923
4m 42s (- 58m 10s) (15000 7%) 0.29027
5m 0s (- 57m 35s) (16000 8%) 0.29533
5m 18s (- 57m 4s) (17000 8%) 0.26983
5m 35s (- 56m 35s) (18000 9%) 0.31679
5m 53s (- 56m 3s) (19000 9%) 0.27203
6m 10s (- 55m 34s) (20000 10%) 0.27978
6m 27s (- 55m 5s) (21000 10%) 0.27602
6m 45s (- 54m 38s) (22000 11%) 0.28522
7m 2s (- 54m 12s) (23000 11%) 0.29221
7m 20s (- 53m 47s) (24000 12%) 0.27508
7m 37s (- 53m 20s) (25000 12%) 0.26085
7m 54s (- 52m 57s) (26000 13%) 0.28173
8m 12s (- 52m 34s) (270

In [0]:
train_epochs(enc, dec, 200000, 0.23411)

0m 29s (- 98m 50s) (1000 0%) 0.25308
0m 47s (- 78m 37s) (2000 1%) 0.25566
1m 5s (- 71m 37s) (3000 1%) 0.25301
1m 23s (- 68m 6s) (4000 2%) 0.25665
1m 41s (- 65m 49s) (5000 2%) 0.25405
1m 58s (- 64m 6s) (6000 3%) 0.26460
2m 16s (- 62m 42s) (7000 3%) 0.25120
2m 33s (- 61m 33s) (8000 4%) 0.25258
2m 51s (- 60m 38s) (9000 4%) 0.24631
3m 8s (- 59m 50s) (10000 5%) 0.26562
3m 26s (- 59m 5s) (11000 5%) 0.26296
3m 43s (- 58m 25s) (12000 6%) 0.24462
4m 1s (- 57m 50s) (13000 6%) 0.26253
4m 18s (- 57m 14s) (14000 7%) 0.25391
4m 36s (- 56m 45s) (15000 7%) 0.26174
4m 53s (- 56m 14s) (16000 8%) 0.25063
5m 10s (- 55m 42s) (17000 8%) 0.27635
5m 27s (- 55m 13s) (18000 9%) 0.25482
5m 45s (- 54m 47s) (19000 9%) 0.24786
6m 2s (- 54m 23s) (20000 10%) 0.27040
6m 20s (- 53m 59s) (21000 10%) 0.23830
6m 37s (- 53m 37s) (22000 11%) 0.24397
6m 54s (- 53m 13s) (23000 11%) 0.24418
7m 12s (- 52m 48s) (24000 12%) 0.24693
7m 29s (- 52m 25s) (25000 12%) 0.24866
7m 46s (- 52m 2s) (26000 13%) 0.24034
8m 3s (- 51m 40s) (270

In [0]:
train_epochs(enc, dec, 200000, 0.21074)

0m 28s (- 95m 7s) (1000 0%) 0.24000
0m 46s (- 76m 13s) (2000 1%) 0.23706
1m 3s (- 69m 7s) (3000 1%) 0.23669
1m 20s (- 65m 55s) (4000 2%) 0.23116
1m 37s (- 63m 35s) (5000 2%) 0.23971
1m 55s (- 62m 0s) (6000 3%) 0.21824
2m 12s (- 60m 46s) (7000 3%) 0.22390
2m 29s (- 59m 55s) (8000 4%) 0.22358
2m 47s (- 59m 7s) (9000 4%) 0.22953
3m 4s (- 58m 22s) (10000 5%) 0.20904
3m 21s (- 57m 45s) (11000 5%) 0.21604
3m 38s (- 57m 10s) (12000 6%) 0.22781
3m 56s (- 56m 40s) (13000 6%) 0.23034
4m 13s (- 56m 10s) (14000 7%) 0.24261
4m 31s (- 55m 43s) (15000 7%) 0.22819
4m 48s (- 55m 16s) (16000 8%) 0.21401
5m 5s (- 54m 49s) (17000 8%) 0.22347
5m 23s (- 54m 25s) (18000 9%) 0.23659
5m 40s (- 54m 2s) (19000 9%) 0.21507
5m 57s (- 53m 38s) (20000 10%) 0.22465
6m 14s (- 53m 15s) (21000 10%) 0.21073
6m 32s (- 52m 53s) (22000 11%) 0.22336
6m 49s (- 52m 32s) (23000 11%) 0.23618
7m 6s (- 52m 9s) (24000 12%) 0.21663
7m 24s (- 51m 48s) (25000 12%) 0.22035
7m 41s (- 51m 29s) (26000 13%) 0.23421
7m 58s (- 51m 8s) (27000

In [0]:
train_epochs(enc, dec, 200000, 0.20235)

0m 28s (- 95m 52s) (1000 0%) 0.20852
0m 46s (- 76m 21s) (2000 1%) 0.21270
1m 3s (- 69m 32s) (3000 1%) 0.21772
1m 20s (- 65m 57s) (4000 2%) 0.21761
1m 37s (- 63m 41s) (5000 2%) 0.21350
1m 55s (- 62m 8s) (6000 3%) 0.21420
2m 12s (- 61m 3s) (7000 3%) 0.20711
2m 30s (- 60m 11s) (8000 4%) 0.22467
2m 47s (- 59m 20s) (9000 4%) 0.21581
3m 5s (- 58m 43s) (10000 5%) 0.21270
3m 22s (- 58m 5s) (11000 5%) 0.20820
3m 40s (- 57m 31s) (12000 6%) 0.20885
3m 57s (- 56m 58s) (13000 6%) 0.20915
4m 14s (- 56m 24s) (14000 7%) 0.20990
4m 32s (- 55m 54s) (15000 7%) 0.21277
4m 49s (- 55m 27s) (16000 8%) 0.21728
5m 6s (- 55m 2s) (17000 8%) 0.21513
5m 23s (- 54m 34s) (18000 9%) 0.22083
5m 41s (- 54m 9s) (19000 9%) 0.21173
5m 58s (- 53m 44s) (20000 10%) 0.20664
6m 15s (- 53m 20s) (21000 10%) 0.21348
6m 32s (- 52m 57s) (22000 11%) 0.21492
6m 49s (- 52m 33s) (23000 11%) 0.20773
7m 6s (- 52m 10s) (24000 12%) 0.19513
7m 24s (- 51m 51s) (25000 12%) 0.20918
7m 42s (- 51m 32s) (26000 13%) 0.20005
7m 59s (- 51m 11s) (270

In [0]:
train_epochs(enc, dec, 200000, 0.18891)

0m 48s (- 160m 0s) (1000 0%) 0.21987
1m 15s (- 124m 43s) (2000 1%) 0.22032
1m 42s (- 112m 30s) (3000 1%) 0.23614
2m 9s (- 106m 9s) (4000 2%) 0.22675
2m 37s (- 102m 6s) (5000 2%) 0.21469
3m 4s (- 99m 22s) (6000 3%) 0.21985
3m 31s (- 97m 10s) (7000 3%) 0.21525
3m 58s (- 95m 16s) (8000 4%) 0.23409
4m 25s (- 93m 49s) (9000 4%) 0.23117
4m 52s (- 92m 33s) (10000 5%) 0.22096
5m 19s (- 91m 26s) (11000 5%) 0.21273
5m 46s (- 90m 27s) (12000 6%) 0.21685
6m 13s (- 89m 32s) (13000 6%) 0.23283
6m 40s (- 88m 39s) (14000 7%) 0.20924
7m 7s (- 87m 56s) (15000 7%) 0.21741
7m 34s (- 87m 9s) (16000 8%) 0.22869
8m 2s (- 86m 28s) (17000 8%) 0.22249
8m 29s (- 85m 53s) (18000 9%) 0.21929
8m 57s (- 85m 16s) (19000 9%) 0.23167
9m 24s (- 84m 38s) (20000 10%) 0.21279
9m 51s (- 84m 0s) (21000 10%) 0.21930
10m 18s (- 83m 24s) (22000 11%) 0.22503
10m 45s (- 82m 48s) (23000 11%) 0.22556
11m 12s (- 82m 13s) (24000 12%) 0.22570
11m 39s (- 81m 38s) (25000 12%) 0.21593
12m 7s (- 81m 6s) (26000 13%) 0.20731
12m 34s (- 80m 

KeyboardInterrupt: ignored

In [0]:
from nltk.translate.bleu_score import corpus_bleu

In [0]:
def compute_bleu(pairs):
    src_lines = []
    trg_lines = []
    
    for i in range(len(pairs)):
      src_lines.append(pairs[i][0])
      trg_lines.append(pairs[i][1])

    out_lines = []
    for src_line in src_lines:
      words, weights = eval(enc, dec, src_line)
      out_lines.append(words[:-1])
    
    trg_lines = [[trg_line.split(' ')] for trg_line in trg_lines]
    return corpus_bleu(trg_lines, out_lines) * 100

In [0]:
def get_pairs(count):
  test_pairs = []
  for i in range(count):
    test_pairs.append(random.choice(PAIRS))
  return test_pairs

In [0]:
for i in range(10):
  print(compute_bleu(get_pairs(1000)))

83.6415956472907
84.63266043837938
84.00313190608298
82.97844831994789
83.63998046574872
83.10524084461345
83.62374148487321
84.25132317398123
84.33963227725596
82.77122496734965


In [0]:
eval(enc, dec, 'ich werde ein neues auto kaufen')

(['i', 'm', 'going', 'to', 'buy', 'a', 'new', 'car', '.', '<eos>'],
 tensor([[1.0000e+00, 1.8351e-09, 1.7801e-06, 1.3274e-06, 8.2737e-11, 1.0232e-12,
          3.8055e-17, 6.7400e-12, 8.7739e-10, 6.3965e-10],
         [3.5904e-09, 2.2416e-08, 1.0000e+00, 1.9475e-07, 5.8276e-10, 3.6841e-20,
          5.3598e-10, 9.1207e-14, 1.9272e-10, 3.9855e-10],
         [1.5883e-10, 6.0547e-10, 4.1170e-08, 1.0000e+00, 3.2485e-10, 7.9030e-21,
          7.3822e-19, 9.4383e-09, 7.0199e-11, 7.2117e-11],
         [3.1034e-17, 7.6071e-19, 3.8510e-17, 2.4710e-18, 1.0000e+00, 2.4025e-13,
          1.1231e-25, 3.7547e-23, 5.3969e-19, 2.1145e-17],
         [7.0572e-36, 5.5212e-37, 8.9674e-38, 8.6681e-39, 1.9298e-34, 1.0000e+00,
          0.0000e+00, 4.2039e-45, 7.0063e-36, 3.2208e-36],
         [3.9360e-15, 2.9508e-14, 1.5115e-17, 1.6395e-14, 4.5821e-17, 1.6595e-14,
          1.0000e+00, 8.8756e-22, 1.8641e-14, 1.0125e-14],
         [3.5167e-11, 1.2432e-09, 3.2557e-12, 2.4237e-13, 7.5923e-17, 2.0192e-14,
    

In [0]:
hid_dim = 256
num_lay = 2
enc = Encoder(SRC.words_num, hid_dim, num_lay).to(device)
dec = DecoderUPD(hid_dim, TRG.words_num, 0.1, num_lay).to(device)

In [0]:
enc.load_state_dict(torch.load('/content/gdrive/My Drive/Models/deu_eng_enc.pt'))
dec.load_state_dict(torch.load('/content/gdrive/My Drive/Models/deu_eng_dec.pt'))

<All keys matched successfully>

In [0]:
for i in range(8):
  print(compute_bleu(get_pairs(1000)))

75.7212647939405
79.18854391281602
78.75971029851401
76.747394680288
77.47065249204427
77.79768318260368
76.02864709486282
78.98724254023288
