# Attention

### Imports

In [30]:
%load_ext autoreload
%autoreload 2

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence

from src.data.loader import load_multi30k, load_WMT14
from src.data.raw_to_proc import proc_WMT14, proc_multi30k, create_WMT14_samp
from src.models import fit, translate
from src.models.lstm_rnn import SimpleEncoder, SimpleEncoderVLS, SimpleDecoder
from src.models.RNN_search import GRU, AttnDecoder
from src.utils import pad_len_sort_both

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Data preprocessing

##### Accessing data

In [2]:
# Batch size for the Multi30k iterators (named instead of a bare magic number).
batch_size = 32
train_iter, valid_iter, SRC, TRG = load_multi30k(batch_size)
# Source / target language keys; used with getattr() on batches in the Analysis section.
src, trg = 'de', 'en'

In [3]:
# Vocabulary sizes of the source and target fields.
n_words_src, n_words_trg = len(SRC.vocab), len(TRG.vocab)

In [4]:
# Id of the '<pad>' token in each vocabulary (source first, then target).
pad_src_id, pad_trg_id = (field.vocab.stoi['<pad>'] for field in (SRC, TRG))

In [5]:
# Sentence-boundary token ids: '<SOS>' and '<EOS>' in the target vocabulary,
# '<EOS>' in the source vocabulary.
trg_sos_id, trg_eos_id = (TRG.vocab.stoi[token] for token in ('<SOS>', '<EOS>'))
src_eos_id = SRC.vocab.stoi['<EOS>']

In [6]:
# Length of each iterator (presumably the number of batches per pass);
# the fit cells below derive end/print_every from these.
train_len, val_len = len(train_iter), len(valid_iter)

### Models

##### Attention

In [68]:
# Smoke test: build a deliberately tiny attention encoder on the GPU
# (n_words=2, n_factors=6, n_hidden=5, n_layers=4, unidirectional).
# NOTE(review): `AttnEncoder` is not imported by the imports cell (it only pulls
# `GRU` and `AttnDecoder` from `src.models.RNN_search`), so on a fresh kernel this
# raises NameError — confirm where the class lives and import it there.
# NOTE(review): the toy model is bound to `enc`, the same name the optimizer, fit,
# save and translate cells below use; rename it before a top-to-bottom run.
enc = AttnEncoder(
    n_words=2,
    n_factors=6,
    n_hidden=5,
    n_layers=4,
    bidirectional=False,
).cuda()

In [69]:
# Dummy input: a (10, 1) long tensor of ones on the current CUDA device —
# presumably (seq_len, batch), consistent with the (10, 1, 5) encoder output below.
a = torch.ones((10, 1), dtype=torch.long, device='cuda')

In [70]:
# Run the dummy batch through the toy encoder. The displayed tuple holds a
# (10, 1, 5) tensor and a (4, 1, 5) tensor — presumably per-step outputs and
# per-layer final hidden states (the last row of each matches).
enc(a)

(tensor([[[ 0.0553, -0.0892, -0.1495, -0.0712,  0.0548]],
 
         [[ 0.0792, -0.1506, -0.2354, -0.1177,  0.0525]],
 
         [[ 0.0871, -0.1895, -0.2882, -0.1506,  0.0353]],
 
         [[ 0.0879, -0.2132, -0.3215, -0.1744,  0.0172]],
 
         [[ 0.0861, -0.2274, -0.3424, -0.1916,  0.0022]],
 
         [[ 0.0839, -0.2358, -0.3555, -0.2039, -0.0091]],
 
         [[ 0.0820, -0.2408, -0.3637, -0.2125, -0.0171]],
 
         [[ 0.0806, -0.2437, -0.3688, -0.2186, -0.0227]],
 
         [[ 0.0797, -0.2454, -0.3720, -0.2228, -0.0264]],
 
         [[ 0.0791, -0.2463, -0.3740, -0.2257, -0.0289]]],
        device='cuda:0', grad_fn=<CopySlices>),
 tensor([[[ 0.2512,  0.1398,  0.3291,  0.0613, -0.4135]],
 
         [[ 0.0841,  0.0448, -0.1951,  0.2640,  0.3240]],
 
         [[ 0.2383, -0.0241, -0.3763,  0.3492, -0.1408]],
 
         [[ 0.0791, -0.2463, -0.3740, -0.2257, -0.0289]]],
        device='cuda:0', grad_fn=<CopySlices>))

### Training 

##### LSTM RNN

Set up the training parameters: optimizers, learning rate, loss function and number of epochs.

In [15]:
# Plain SGD with the same learning rate for the encoder and the decoder.
# NOTE(review): `dec` is not created in any visible cell — define it before running.
opt_enc = optim.SGD(params=enc.parameters(), lr=5e-2)
opt_dec = optim.SGD(params=dec.parameters(), lr=5e-2)
# Negative log-likelihood loss; assumes the decoder emits log-probabilities — TODO confirm.
loss_fn = F.nll_loss
epochs = 1

Depending on your choice of data and model, run exactly one of the three cells below. To control the length of training, change the `end_train` and `end_val` parameters (`end` for WMT14). I suggest setting `print_every` to one fifth of `end_train`. It's also a good idea to set teacher forcing to zero in the later stages of training.

In [11]:
# Full Multi30k run: every training and validation batch, with a progress
# report roughly every fifth of the training set.
fit.Multi30k(
    enc, dec, train_iter, valid_iter, epochs, opt_enc, opt_dec, loss_fn,
    n_words_trg, trg_sos_id,
    print_every=train_len // 5,
    end_train=train_len,
    end_val=val_len,
)

1.0% done
2.0% done
3.0% done
4.0% done
5.0% done
6.0% done
7.0% done
8.0% done
9.0% done
10.0% done
11.0% done
12.0% done
13.0% done
14.0% done
15.0% done
16.0% done
17.0% done
18.0% done
19.0% done
20.0% done
Train: 5.350942218092066 
Valid: 5.33647784655381 

21.0% done
22.0% done
23.0% done
24.0% done
25.0% done
26.0% done
27.0% done
28.0% done
29.0% done
30.0% done
31.0% done
32.0% done
33.0% done
34.0% done
35.0% done
36.0% done
37.0% done
38.0% done
39.0% done
40.0% done
Train: 5.316898484907207 
Valid: 5.309218650738869 

41.0% done
42.0% done
43.0% done
44.0% done
45.0% done
46.0% done
47.0% done
48.0% done
49.0% done
50.0% done
51.0% done
52.0% done
53.0% done
54.0% done
55.0% done
56.0% done
57.0% done
58.0% done
59.0% done
60.0% done
Train: 5.200537258116216 
Valid: 5.200820240043324 

61.0% done
62.0% done
63.0% done
64.0% done
65.0% done
66.0% done
67.0% done
68.0% done
69.0% done
70.0% done
71.0% done
72.0% done
73.0% done
74.0% done
75.0% done
76.0% done
77.0% done
78.0

In [21]:
# VLS variant (takes the padding ids as well) on the first fifth of the training
# and validation batches, reporting roughly five times over that shortened run.
fit.Multi30k_VLS(
    enc, dec, train_iter, valid_iter, epochs, opt_enc, opt_dec, loss_fn,
    n_words_trg, trg_sos_id, pad_src_id, pad_trg_id,
    print_every=train_len // 25,
    end_train=train_len // 5,
    end_val=val_len // 5,
)

0.6% done
1.1% done
1.7% done
2.2% done
2.8% done
3.3% done
3.9% done
4.4% done
5.0% done
5.5% done
6.1% done
6.6% done
7.2% done
7.7% done
8.3% done
8.8% done
9.4% done
9.9% done
10.5% done
11.0% done
11.6% done
12.2% done
12.7% done
13.3% done
13.8% done
14.4% done
14.9% done
15.5% done
16.0% done
16.6% done
17.1% done
17.7% done
18.2% done
18.8% done
19.3% done
19.9% done
Train: 5.073108215524693 
Valid: 5.1813329735187565 

20.4% done
21.0% done
21.5% done
22.1% done
22.7% done
23.2% done
23.8% done
24.3% done
24.9% done
25.4% done
26.0% done
26.5% done
27.1% done
27.6% done
28.2% done
28.7% done
29.3% done
29.8% done
30.4% done
30.9% done
31.5% done
32.0% done
32.6% done
33.1% done
33.7% done
34.3% done
34.8% done
35.4% done
35.9% done
36.5% done
37.0% done
37.6% done
38.1% done
38.7% done
39.2% done
39.8% done
Train: 5.2609340686990755 
Valid: 5.150072846749817 

40.3% done
40.9% done
41.4% done
42.0% done
42.5% done
43.1% done
43.6% done
44.2% done
44.8% done
45.3% done
45.9% do

In [16]:
# WMT14 run over the first tenth of the training batches, reporting about
# five times over that span.
# NOTE(review): the data cells above only load Multi30k; swap in WMT14
# iterators (load_WMT14 is imported) before running this cell.
fit.WMT14(
    enc, dec, train_iter, valid_iter, epochs, opt_enc, opt_dec, loss_fn,
    n_words_trg, trg_sos_id,
    print_every=train_len // 50,
    end=train_len // 10,
)

1.0% done
2.0% done
3.0% done
4.0% done
5.0% done
6.0% done
7.0% done
8.0% done
9.0% done
10.0% done
11.0% done
11.9% done
12.9% done
13.9% done
14.9% done
15.9% done
16.9% done
17.9% done
18.9% done
19.9% done
Train: 7.183032462994258 
Valid: 7.198314721385638 

20.9% done
21.9% done
22.9% done
23.9% done
24.9% done
25.9% done
26.9% done
27.9% done
28.9% done
29.9% done
30.9% done
31.9% done
32.9% done
33.9% done
34.9% done
35.8% done
36.8% done
37.8% done
38.8% done
39.8% done
Train: 6.854988818367322 
Valid: 6.9456145738561945 

40.8% done
41.8% done
42.8% done
43.8% done
44.8% done
45.8% done
46.8% done
47.8% done
48.8% done
49.8% done
50.8% done
51.8% done
52.8% done
53.8% done
54.8% done
55.8% done
56.8% done
57.8% done
58.7% done
59.7% done
Train: 7.104311967889468 
Valid: 6.87798385322094 

60.7% done
61.7% done
62.7% done
63.7% done
64.7% done
65.7% done
66.7% done
67.7% done
68.7% done
69.7% done
70.7% done
71.7% done
72.7% done
73.7% done
74.7% done
75.7% done
76.7% done
77.

The Multi30k fit functions automatically save progress after each epoch, but the cells below can be used to save / load models manually.

In [10]:
# Checkpoint directories, one per dataset, keyed on the model name so this
# attention notebook does not read/overwrite the checkpoints kept under
# models/LSTM_RNN/. Paths keep a trailing slash because the load/save cells
# append the file name directly (e.g. f'{model_path_Multi30k}enc.pt').
model_name = 'Attention'
model_path_Multi30k = f'models/{model_name}/Multi30k/'
model_path_WMT14 = f'models/{model_name}/WMT14/'

Loading / saving Multi30k models

In [11]:
# Restore encoder and decoder weights from the Multi30k checkpoint directory.
enc.load_state_dict(state_dict=torch.load(model_path_Multi30k + 'enc.pt'))
dec.load_state_dict(state_dict=torch.load(model_path_Multi30k + 'dec.pt'))

In [None]:
# Save the current encoder/decoder weights to the Multi30k checkpoint directory.
# torch.save does not create missing parent directories, so make sure it exists first.
os.makedirs(model_path_Multi30k, exist_ok=True)
torch.save(enc.state_dict(), f'{model_path_Multi30k}enc.pt')
torch.save(dec.state_dict(), f'{model_path_Multi30k}dec.pt')

Loading / saving WMT14 models

In [18]:
# Restore encoder and decoder weights from the WMT14 checkpoint directory.
enc.load_state_dict(state_dict=torch.load(model_path_WMT14 + 'enc.pt'))
dec.load_state_dict(state_dict=torch.load(model_path_WMT14 + 'dec.pt'))

In [17]:
# Save the current encoder/decoder weights to the WMT14 checkpoint directory.
# torch.save does not create missing parent directories, so make sure it exists first.
os.makedirs(model_path_WMT14, exist_ok=True)
torch.save(enc.state_dict(), f'{model_path_WMT14}enc.pt')
torch.save(dec.state_dict(), f'{model_path_WMT14}dec.pt')

### Analysis

We won't dig deep into the model's workings / performance. For now, we will only look at examples of the model's translations. This will give us some insight into its capabilities and will also help us better interpret the results given by applying a custom metric.

##### LSTM RNN

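To get a new example, simply re-run the cell below. Note that it creates a fresh iterator each time, so a different batch only appears if `valid_iter` shuffles its data.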

In [12]:
# Grab the first batch of a fresh pass over the validation iterator and pull out
# its source ('de') and target ('en') token-id tensors.
# NOTE(review): a new iterator is built on every run, so re-running only gives a
# different batch if valid_iter shuffles — confirm in load_multi30k.
rand_ex = next(iter(valid_iter))
ex_src = getattr(rand_ex, src)
ex_trg = getattr(rand_ex, trg)

These two cells show the sentence we are going to translate (the first one in the batch) and its human translation.

In [13]:
# Map the first source sentence of the batch (column 0) back to its tokens.
[SRC.vocab.itos[token_id] for token_id in ex_src[:, 0].tolist()]

['mehrere',
 'menschen',
 'stehen',
 'in',
 'der',
 'dämmerung',
 'in',
 'der',
 'nähe',
 'einiger',
 'bäume',
 '<EOS>']

In [14]:
# Map the first reference translation of the batch (column 0) back to its tokens.
[TRG.vocab.itos[token_id] for token_id in ex_trg[:, 0].tolist()]

['<SOS>',
 'several',
 'people',
 'are',
 'standing',
 'near',
 'trees',
 'at',
 'dusk',
 '<EOS>',
 '<pad>']

Here we use our model to translate the sentence. Choose the appropriate translate function (the three cells below use Multi30k, Multi30k_VLS and WMT14, respectively).

In [42]:
# Translate with the plain Multi30k decoder loop, starting from <SOS> and stopping at <EOS>.
# NOTE(review): the whole batch is passed here, unlike the VLS cell which passes one sentence.
sent_ids = translate.Multi30k(enc,dec,trg_sos_id,trg_eos_id,ex_src)

In [16]:
# Translate only the first sentence of the batch, kept as a (seq_len, 1) column.
sent_ids = translate.Multi30k_VLS(enc, dec, trg_sos_id, trg_eos_id, pad_src_id, ex_src[:, 0].unsqueeze(1))

In [22]:
# Translate with the WMT14 decoding routine (same arguments as the Multi30k variant).
sent_ids = translate.WMT14(enc,dec,trg_sos_id,trg_eos_id,ex_src)

Now we can take a look at the sentence our model created.

In [17]:
# Look up each generated id in the target vocabulary to read the model's sentence.
list(map(TRG.vocab.itos.__getitem__, sent_ids))

['several', 'people', 'are', 'people', 'in', 'near', 'a', 'a', 'a', '<EOS>']