## This notebook works
## I was able to train to 0.084 loss
## Train it again and try to get lowest loss possible

Here is the dataset as a pickle file

In [13]:
# subsetted dataset
# this is in the shared drive in data folder
filtered_dataset_path = '/content/drive/My Drive/Colab Notebooks/NLP/Machine Translation/FT_Files/pairs.pkl'

Paths to WEs

In [14]:
# ft vectors
# this is in shared drive in ft folder
ft_path = '/content/drive/My Drive/Colab Notebooks/NLP/Machine Translation/FT_Files/wiki-news-300d-1M.vec'

#frWac vectors
# this is in shared drive in glove_frwac folder
frWac200d_path = '/content/drive/My Drive/Colab Notebooks/NLP/Machine Translation/FT_Files/cc.fr.300_fasttext_french.vec'

# model weights
# you won't have these till you train the model
encoder_saved_model_weights = '/content/drive/My Drive/Colab Notebooks/NLP/Machine Translation/FT_Files/FTsimple_encoder_200E_0.084Loss.pth'
decoder_saved_model_weights ='/content/drive/My Drive/Colab Notebooks/NLP/Machine Translation/FT_Files/FTsimple_decoder_200E_0.084Loss.pth'

In [15]:
!pip install bcolz



In [16]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline


import itertools 
import os, re, pickle, collections, bcolz, string
import numpy as np
import math
import gensim
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm_notebook
from gensim.models import KeyedVectors
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

In [17]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Load and separate sentences

In [18]:
sentences = pickle.load(open(filtered_dataset_path, 'rb'))

In [19]:
en_questions, fr_questions = zip(*sentences)

Define tokenizer.

In [20]:
re_apos = re.compile(r"(\w)'s\b")         # make 's a separate word
re_mw_punc = re.compile(r"(\w[’'])(\w)")  # other ' in a word creates 2 words
re_punc = re.compile("([\"().,;:/_?!—])") # add spaces around punctuation
re_mult_space = re.compile(r"  *")        # replace multiple spaces with just one

def tokenize(sent):
    sent = re_apos.sub(r"\1 's", sent)
    sent = re_mw_punc.sub(r"\1 \2", sent)
    sent = re_punc.sub(r" \1 ", sent).replace('-', ' ')
    sent = re_mult_space.sub(' ', sent)
    return sent.lower().split()

Tokenize english questions.

In [21]:
en_tokens = list(map(tokenize, en_questions))

In [22]:
en_tokens[:4]

[['i', 'm', '.'],
 ['i', 'm', 'ok', '.'],
 ['i', 'm', 'ok', '.'],
 ['i', 'm', 'fat', '.']]

Tokenize french questions.

In [23]:
fr_tokens = list(map(tokenize, fr_questions))

In [24]:
fr_tokens[:4]

[['j', 'ai', 'ans', '.'],
 ['je', 'vais', 'bien', '.'],
 ['ca', 'va', '.'],
 ['je', 'suis', 'gras', '.']]

For each language: 
<br/>- Get vocabulary counter.
<br/>- Get vocabulary.
<br/>- Get dictionary that maps each word to an index.
<br/>- Transform tokens to their corresponding ids.

In [25]:
PAD = 0; SOS = 1

def tokens2ids(sentences):
    vocab_counter = collections.Counter(word for sent in sentences for word in sent)
    vocab = sorted(vocab_counter, key=vocab_counter.get, reverse=True)
    vocab.insert(PAD, '<PAD>')
    vocab.insert(SOS, '<SOS')
    w2id = {word:i for i, word in enumerate(vocab)}
    ids = [[w2id[word] for word in sent] for sent in sentences]
    return vocab_counter, vocab, w2id, ids

In [26]:
en_vocab_counter, en_vocab, en_w2id, en_ids = tokens2ids(en_tokens)
fr_vocab_counter, fr_vocab, fr_w2id, fr_ids = tokens2ids(fr_tokens)

In [27]:
len(en_vocab), len(fr_vocab)

(2803, 4345)

## Word vectors

FT English word vectors

In [28]:
import gensim.models.wrappers.fasttext
model = gensim.models.KeyedVectors.load_word2vec_format(ft_path, binary=False, encoding='utf8')
ft_model = model.wv
del model

  This is separate from the ipykernel package so we can avoid doing imports until


In [29]:
ft_model.most_similar("dog") 

  if np.issubdtype(vec.dtype, np.int):


[('dogs', 0.856066107749939),
 ('puppy', 0.7839491963386536),
 ('Dog', 0.7767305374145508),
 ('canine', 0.7631831169128418),
 ('Mixed-breed', 0.7280029058456421),
 ('pet', 0.7213231325149536),
 ('terrier', 0.7139902114868164),
 ('labrador', 0.7112174034118652),
 ('puppies', 0.6918587684631348),
 ('non-dog', 0.6915143728256226)]

In [30]:
ft_model.vector_size

300

#### French word vectors

In [31]:
fr_w2v = KeyedVectors.load_word2vec_format(frWac200d_path)

In [32]:
fr_w2v.vector_size

300

Now we need to create embeddings matrices for english and french words of training corpus. If a word appears on GloVe or frWac then we load its pre-trained vector, otherwise we create a random vector.

In [33]:
def create_embedding(w2v, target_vocab, emb_dim):
    emb_len = len(target_vocab)
    embedding = np.zeros((emb_len, emb_dim))
    words_found = 0
    
    for i, w in enumerate(target_vocab):
        try: 
            embedding[i] = w2v[w]
            words_found += 1
        except KeyError:
            embedding[i] = np.random.normal(scale=0.6, size=(emb_dim, ))
    
    return embedding, words_found

In [34]:
en_emb, words_found = create_embedding(ft_model, en_vocab, 300)

In [35]:
en_emb.shape, words_found

((2803, 300), 2796)

In [36]:
fr_emb, words_found = create_embedding(fr_w2v, fr_vocab, 300)

In [37]:
fr_emb.shape, words_found

((4345, 300), 4101)

## Data preparation

Min, max and mean length of english sentences.

In [38]:
len_en_ids = [len(sentence) for sentence in en_ids]
min(len_en_ids), max(len_en_ids), np.mean(len_en_ids)

(3, 9, 6.035569393338994)

Min, max and mean length of french sentences.

In [39]:
len_fr_ids = [len(sentence) for sentence in fr_ids]
min(len_fr_ids), max(len_fr_ids), np.mean(len_fr_ids)

(2, 9, 6.196056231719973)

We set 30 as max length. In this example, we could use 10 since the max length of every sentence is 9. 30 allows us more flexibility in the future for other datasets.

In [40]:
maxlen = 30

In [41]:
en_train = pad_sequences(en_ids, maxlen, 'int64', 'post', 'post')
fr_train = pad_sequences(fr_ids, maxlen, 'int64', 'post', 'post')

In [42]:
fr_train.shape, en_train.shape, en_emb.shape, fr_emb.shape

((10599, 30), (10599, 30), (2803, 300), (4345, 300))

In [43]:
en_train[0]

array([3, 6, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0])

In [44]:
fr_train[0]

array([ 27,  30, 115,   2,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0])

## Model

### Seq2Seq

<br/>Encoder:
<br/>
<br/>Inputs: french words sentence and initial hidden state (all zeros).
<br/>1- Look up at an embedding layer to get word vector of each word of the input sentence. 
<br/>2- Pass the word vectors sequence through a RNN.
<br/>3- Return hidden state of last timestep (vector representation of input sentence).
<br/>
<br/>Decoder:
<br/>
<br/>Inputs: 'SOS' word (i.e. start of sentence, is always the first word) and vector representation created by encoder.
<br/>1- Load vector representation as initial hidden state.
<br/>2- Look up at an embedding layer to get word vector of 'SOS'.
<br/>3- Pass the word vector through a RNN.
<br/>4- Generate prediction of next word.
<br/>5- Repeat 2, 3 and 4 using always the previous translated word until finish sentence translation.

In [45]:
torch.cuda.is_available()

True

In [46]:
fr_emb_t = torch.FloatTensor(fr_emb).cuda()
en_emb_t = torch.FloatTensor(en_emb).cuda()

In [47]:
def long_t(arr):
    return Variable(torch.LongTensor(arr)).cuda()

Load pre-trained vectors into an embedding layer.

In [48]:
def create_emb(emb_matrix, non_trainable=False):
    num_embeddings, embedding_dim = emb_matrix.size()
    emb = nn.Embedding(num_embeddings, embedding_dim)
    emb.load_state_dict({'weight': emb_matrix})
    if non_trainable:
        #emb.weight.requires_grad = False
        for param in emb.parameters():
            param.requires_grad = False
    return emb, num_embeddings, embedding_dim

Encoding layer

In [49]:
class EncoderRNN(nn.Module):
    def __init__(self, emb_matrix, hidden_size, num_layers=2):
        super(EncoderRNN, self).__init__()
        # Create embedding layer.
        self.embedding, num_embeddings, embedding_dim = create_emb(emb_matrix, True)
        # Create RNN.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.gru = nn.GRU(embedding_dim, hidden_size, num_layers, batch_first=True)
        
    def forward(self, inp, hidden):
        return self.gru(self.embedding(inp), hidden)
    
    def init_hidden(self, batch_size):
        return Variable(torch.zeros(self.num_layers, batch_size, self.hidden_size))

Decoding layer

In [50]:
class DecoderRNN(nn.Module):
    def __init__(self, emb_matrix, hidden_size, num_layers=2):
        super(DecoderRNN, self).__init__()
        # Create embedding layer.
        self.emb, num_embeddings, embedding_dim = create_emb(emb_matrix)
        # Create RNN.
        self.gru = nn.GRU(embedding_dim, hidden_size, num_layers, batch_first=True, bidirectional=False)
        self.out = nn.Linear(hidden_size, num_embeddings)
        
    def forward(self, inp, hidden):
        emb = self.emb(inp).unsqueeze(1)
        res, hidden = self.gru(emb, hidden)
        # Softmax layer, generates probs for each word vector of the embedding layer.
        res = F.log_softmax(self.out(res[:,0]), dim=1)
        return res, hidden

In [51]:
def encode(inp, encoder):
    batch_size, input_length = inp.size()
    hidden = encoder.init_hidden(batch_size).cuda()
    enc_outputs, hidden = encoder.forward(inp, hidden)
    return long_t([SOS]*batch_size), enc_outputs, hidden

Training

We use teaching-force as training approach. Rather than pass to decoder the previous translated word, we pass the real target.

In [52]:
def fit(encoder, decoder, train_dl, n_epochs, enc_optim, dec_optim, criterion):
    bar = tqdm_notebook(total=n_epochs)
    loss_tracker = []
    avg_mom = 0.98
    avg_loss = 0.
    batch_num = 0

    for epoch in range(n_epochs):
        bar2 = tqdm_notebook(total=train_dl.dataset.shape[0] / train_dl.batch_size, desc=f'Epoch {epoch}', leave=False)
        for i, batch in enumerate(train_dl):
            batch_num += 1
            loss = 0
            
            inp = long_t(batch[:, :maxlen])

            targ = long_t(batch[:, maxlen:])
       
            # Encoder creates a vector representation of input french sentence. 
            decoder_input, encoder_output, hidden = encode(inp, encoder)

            # Zero the gradients before running the backward pass.
            enc_optim.zero_grad()
            dec_optim.zero_grad()
            
            targ_length = targ.size()[1]
     
            for di in range(targ_length):
                decoder_output, hidden = decoder(decoder_input, hidden)
                # Teacher forcing: the decoder receives as input the real target instead of predicted word.
                decoder_input = targ[:, di]
                
                # Compute loss.
                loss += criterion(decoder_output, decoder_input)
          
            # Backward pass: compute gradient of the loss with respect to all the learnable parameters of the model.
            loss.backward()

            # Calling the step function on an Optimizer makes an update to its parameters.
            enc_optim.step()
            dec_optim.step()
           
            # Exponentially weighted moving average, to make the reported loss more stable.
            avg_loss = avg_loss * avg_mom + (loss.data.item() / targ_length)  * (1-avg_mom)
            
            # Compute bias-corrected loss estimate.
            debias_loss = avg_loss / (1 - avg_mom**batch_num)
            
            bar2.update()
            
        loss_tracker.append(np.round([epoch, debias_loss], 6))
        print(np.round([epoch, debias_loss], 6))    
        bar.update()
    return loss_tracker

In [53]:
def req_grad_params(o):
    return (param for param in o.parameters() if param.requires_grad)

Initialize models and set parameters

In [54]:
hidden_size = 64 #128
encoder = EncoderRNN(fr_emb_t, hidden_size).cuda()
decoder = DecoderRNN(en_emb_t, hidden_size).cuda()

In [55]:
lr = 0.0001

In [56]:
enc_opt = optim.Adam(req_grad_params(encoder), lr=lr)
dec_opt = optim.Adam(decoder.parameters(), lr=lr)
criterion = nn.NLLLoss().cuda()

In [57]:
batch_size = 64
#64

Create a dataloader

In [None]:
train_dl = DataLoader(np.concatenate([fr_train, en_train], 1), batch_size, shuffle=True, num_workers=1)

Train the model

In [None]:
loss_tracker = fit(encoder, decoder, train_dl, 200, enc_opt, dec_opt, criterion)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if __name__ == '__main__':


HBox(children=(FloatProgress(value=0.0, description='Epoch 0', max=165.609375, style=ProgressStyle(description…

[0.       0.141475]


HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=165.609375, style=ProgressStyle(description…

[1.       0.134318]


HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=165.609375, style=ProgressStyle(description…

[2.      0.12782]


HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=165.609375, style=ProgressStyle(description…

[3.       0.123532]


HBox(children=(FloatProgress(value=0.0, description='Epoch 4', max=165.609375, style=ProgressStyle(description…

[4.       0.120557]


HBox(children=(FloatProgress(value=0.0, description='Epoch 5', max=165.609375, style=ProgressStyle(description…

[5.       0.120755]


HBox(children=(FloatProgress(value=0.0, description='Epoch 6', max=165.609375, style=ProgressStyle(description…

[6.       0.119089]


HBox(children=(FloatProgress(value=0.0, description='Epoch 7', max=165.609375, style=ProgressStyle(description…

[7.       0.116263]


HBox(children=(FloatProgress(value=0.0, description='Epoch 8', max=165.609375, style=ProgressStyle(description…

[8.       0.115754]


HBox(children=(FloatProgress(value=0.0, description='Epoch 9', max=165.609375, style=ProgressStyle(description…

[9.       0.114742]


HBox(children=(FloatProgress(value=0.0, description='Epoch 10', max=165.609375, style=ProgressStyle(descriptio…

[10.        0.113279]


HBox(children=(FloatProgress(value=0.0, description='Epoch 11', max=165.609375, style=ProgressStyle(descriptio…

[11.        0.112437]


HBox(children=(FloatProgress(value=0.0, description='Epoch 12', max=165.609375, style=ProgressStyle(descriptio…

[12.        0.109533]


HBox(children=(FloatProgress(value=0.0, description='Epoch 13', max=165.609375, style=ProgressStyle(descriptio…

[13.        0.110783]


HBox(children=(FloatProgress(value=0.0, description='Epoch 14', max=165.609375, style=ProgressStyle(descriptio…

[14.        0.108534]


HBox(children=(FloatProgress(value=0.0, description='Epoch 15', max=165.609375, style=ProgressStyle(descriptio…

[15.        0.110208]


HBox(children=(FloatProgress(value=0.0, description='Epoch 16', max=165.609375, style=ProgressStyle(descriptio…

[16.        0.108641]


HBox(children=(FloatProgress(value=0.0, description='Epoch 17', max=165.609375, style=ProgressStyle(descriptio…

[17.        0.108041]


HBox(children=(FloatProgress(value=0.0, description='Epoch 18', max=165.609375, style=ProgressStyle(descriptio…

[18.        0.109272]


HBox(children=(FloatProgress(value=0.0, description='Epoch 19', max=165.609375, style=ProgressStyle(descriptio…

[19.        0.107649]


HBox(children=(FloatProgress(value=0.0, description='Epoch 20', max=165.609375, style=ProgressStyle(descriptio…

[20.       0.10606]


HBox(children=(FloatProgress(value=0.0, description='Epoch 21', max=165.609375, style=ProgressStyle(descriptio…

[21.        0.106915]


HBox(children=(FloatProgress(value=0.0, description='Epoch 22', max=165.609375, style=ProgressStyle(descriptio…

[22.        0.106317]


HBox(children=(FloatProgress(value=0.0, description='Epoch 23', max=165.609375, style=ProgressStyle(descriptio…

[23.        0.104948]


HBox(children=(FloatProgress(value=0.0, description='Epoch 24', max=165.609375, style=ProgressStyle(descriptio…

[24.        0.104234]


HBox(children=(FloatProgress(value=0.0, description='Epoch 25', max=165.609375, style=ProgressStyle(descriptio…

[25.        0.104382]


HBox(children=(FloatProgress(value=0.0, description='Epoch 26', max=165.609375, style=ProgressStyle(descriptio…

[26.        0.103635]


HBox(children=(FloatProgress(value=0.0, description='Epoch 27', max=165.609375, style=ProgressStyle(descriptio…

[27.        0.102623]


HBox(children=(FloatProgress(value=0.0, description='Epoch 28', max=165.609375, style=ProgressStyle(descriptio…

[28.        0.102079]


HBox(children=(FloatProgress(value=0.0, description='Epoch 29', max=165.609375, style=ProgressStyle(descriptio…

[29.        0.104037]


HBox(children=(FloatProgress(value=0.0, description='Epoch 30', max=165.609375, style=ProgressStyle(descriptio…

[30.        0.102396]


HBox(children=(FloatProgress(value=0.0, description='Epoch 31', max=165.609375, style=ProgressStyle(descriptio…

[31.        0.102229]


HBox(children=(FloatProgress(value=0.0, description='Epoch 32', max=165.609375, style=ProgressStyle(descriptio…

[32.        0.101734]


HBox(children=(FloatProgress(value=0.0, description='Epoch 33', max=165.609375, style=ProgressStyle(descriptio…

[33.        0.102387]


HBox(children=(FloatProgress(value=0.0, description='Epoch 34', max=165.609375, style=ProgressStyle(descriptio…

[34.        0.100728]


HBox(children=(FloatProgress(value=0.0, description='Epoch 35', max=165.609375, style=ProgressStyle(descriptio…

[35.        0.101145]


HBox(children=(FloatProgress(value=0.0, description='Epoch 36', max=165.609375, style=ProgressStyle(descriptio…

[36.        0.099129]


HBox(children=(FloatProgress(value=0.0, description='Epoch 37', max=165.609375, style=ProgressStyle(descriptio…

[37.        0.100513]


HBox(children=(FloatProgress(value=0.0, description='Epoch 38', max=165.609375, style=ProgressStyle(descriptio…

[38.        0.099529]


HBox(children=(FloatProgress(value=0.0, description='Epoch 39', max=165.609375, style=ProgressStyle(descriptio…

[39.      0.0996]


HBox(children=(FloatProgress(value=0.0, description='Epoch 40', max=165.609375, style=ProgressStyle(descriptio…

[40.        0.099632]


HBox(children=(FloatProgress(value=0.0, description='Epoch 41', max=165.609375, style=ProgressStyle(descriptio…

[41.        0.100242]


HBox(children=(FloatProgress(value=0.0, description='Epoch 42', max=165.609375, style=ProgressStyle(descriptio…

[42.        0.098712]


HBox(children=(FloatProgress(value=0.0, description='Epoch 43', max=165.609375, style=ProgressStyle(descriptio…

[43.        0.098912]


HBox(children=(FloatProgress(value=0.0, description='Epoch 44', max=165.609375, style=ProgressStyle(descriptio…

[44.        0.097461]


HBox(children=(FloatProgress(value=0.0, description='Epoch 45', max=165.609375, style=ProgressStyle(descriptio…

[45.        0.099039]


HBox(children=(FloatProgress(value=0.0, description='Epoch 46', max=165.609375, style=ProgressStyle(descriptio…

[46.        0.098319]


HBox(children=(FloatProgress(value=0.0, description='Epoch 47', max=165.609375, style=ProgressStyle(descriptio…

[47.        0.098164]


HBox(children=(FloatProgress(value=0.0, description='Epoch 48', max=165.609375, style=ProgressStyle(descriptio…

[48.        0.098518]


HBox(children=(FloatProgress(value=0.0, description='Epoch 49', max=165.609375, style=ProgressStyle(descriptio…

[49.        0.097696]


HBox(children=(FloatProgress(value=0.0, description='Epoch 50', max=165.609375, style=ProgressStyle(descriptio…

[50.        0.096954]


HBox(children=(FloatProgress(value=0.0, description='Epoch 51', max=165.609375, style=ProgressStyle(descriptio…

[51.        0.096959]


HBox(children=(FloatProgress(value=0.0, description='Epoch 52', max=165.609375, style=ProgressStyle(descriptio…

[52.        0.095987]


HBox(children=(FloatProgress(value=0.0, description='Epoch 53', max=165.609375, style=ProgressStyle(descriptio…

[53.        0.096655]


HBox(children=(FloatProgress(value=0.0, description='Epoch 54', max=165.609375, style=ProgressStyle(descriptio…

[54.        0.096946]


HBox(children=(FloatProgress(value=0.0, description='Epoch 55', max=165.609375, style=ProgressStyle(descriptio…

[55.        0.095344]


HBox(children=(FloatProgress(value=0.0, description='Epoch 56', max=165.609375, style=ProgressStyle(descriptio…

[56.        0.095724]


HBox(children=(FloatProgress(value=0.0, description='Epoch 57', max=165.609375, style=ProgressStyle(descriptio…

[57.        0.097153]


HBox(children=(FloatProgress(value=0.0, description='Epoch 58', max=165.609375, style=ProgressStyle(descriptio…

[58.        0.095559]


HBox(children=(FloatProgress(value=0.0, description='Epoch 59', max=165.609375, style=ProgressStyle(descriptio…

[59.        0.097549]


HBox(children=(FloatProgress(value=0.0, description='Epoch 60', max=165.609375, style=ProgressStyle(descriptio…

[60.        0.095664]


HBox(children=(FloatProgress(value=0.0, description='Epoch 61', max=165.609375, style=ProgressStyle(descriptio…

[61.        0.094568]


HBox(children=(FloatProgress(value=0.0, description='Epoch 62', max=165.609375, style=ProgressStyle(descriptio…

[62.        0.093778]


HBox(children=(FloatProgress(value=0.0, description='Epoch 63', max=165.609375, style=ProgressStyle(descriptio…

[63.        0.095372]


HBox(children=(FloatProgress(value=0.0, description='Epoch 64', max=165.609375, style=ProgressStyle(descriptio…

[64.       0.09556]


HBox(children=(FloatProgress(value=0.0, description='Epoch 65', max=165.609375, style=ProgressStyle(descriptio…

[65.        0.095073]


HBox(children=(FloatProgress(value=0.0, description='Epoch 66', max=165.609375, style=ProgressStyle(descriptio…

[66.        0.094508]


HBox(children=(FloatProgress(value=0.0, description='Epoch 67', max=165.609375, style=ProgressStyle(descriptio…

[67.        0.094441]


HBox(children=(FloatProgress(value=0.0, description='Epoch 68', max=165.609375, style=ProgressStyle(descriptio…

[68.        0.093915]


HBox(children=(FloatProgress(value=0.0, description='Epoch 69', max=165.609375, style=ProgressStyle(descriptio…

[69.        0.095138]


HBox(children=(FloatProgress(value=0.0, description='Epoch 70', max=165.609375, style=ProgressStyle(descriptio…

[70.        0.094037]


HBox(children=(FloatProgress(value=0.0, description='Epoch 71', max=165.609375, style=ProgressStyle(descriptio…

[71.        0.095195]


HBox(children=(FloatProgress(value=0.0, description='Epoch 72', max=165.609375, style=ProgressStyle(descriptio…

[72.        0.093393]


HBox(children=(FloatProgress(value=0.0, description='Epoch 73', max=165.609375, style=ProgressStyle(descriptio…

[73.        0.093173]


HBox(children=(FloatProgress(value=0.0, description='Epoch 74', max=165.609375, style=ProgressStyle(descriptio…

[74.        0.093482]


HBox(children=(FloatProgress(value=0.0, description='Epoch 75', max=165.609375, style=ProgressStyle(descriptio…

[75.        0.092814]


HBox(children=(FloatProgress(value=0.0, description='Epoch 76', max=165.609375, style=ProgressStyle(descriptio…

[76.        0.093702]


HBox(children=(FloatProgress(value=0.0, description='Epoch 77', max=165.609375, style=ProgressStyle(descriptio…

[77.        0.093313]


HBox(children=(FloatProgress(value=0.0, description='Epoch 78', max=165.609375, style=ProgressStyle(descriptio…

[78.        0.092561]


HBox(children=(FloatProgress(value=0.0, description='Epoch 79', max=165.609375, style=ProgressStyle(descriptio…

[79.        0.093048]


HBox(children=(FloatProgress(value=0.0, description='Epoch 80', max=165.609375, style=ProgressStyle(descriptio…

[80.        0.093287]


HBox(children=(FloatProgress(value=0.0, description='Epoch 81', max=165.609375, style=ProgressStyle(descriptio…

[81.        0.092535]


HBox(children=(FloatProgress(value=0.0, description='Epoch 82', max=165.609375, style=ProgressStyle(descriptio…

[82.        0.092511]


HBox(children=(FloatProgress(value=0.0, description='Epoch 83', max=165.609375, style=ProgressStyle(descriptio…

[83.        0.092179]


HBox(children=(FloatProgress(value=0.0, description='Epoch 84', max=165.609375, style=ProgressStyle(descriptio…

[84.        0.091353]


HBox(children=(FloatProgress(value=0.0, description='Epoch 85', max=165.609375, style=ProgressStyle(descriptio…

[85.        0.092472]


HBox(children=(FloatProgress(value=0.0, description='Epoch 86', max=165.609375, style=ProgressStyle(descriptio…

[86.        0.092527]


HBox(children=(FloatProgress(value=0.0, description='Epoch 87', max=165.609375, style=ProgressStyle(descriptio…

[87.        0.092357]


HBox(children=(FloatProgress(value=0.0, description='Epoch 88', max=165.609375, style=ProgressStyle(descriptio…

[88.        0.090943]


HBox(children=(FloatProgress(value=0.0, description='Epoch 89', max=165.609375, style=ProgressStyle(descriptio…

[89.        0.092188]


HBox(children=(FloatProgress(value=0.0, description='Epoch 90', max=165.609375, style=ProgressStyle(descriptio…

[90.        0.091456]


HBox(children=(FloatProgress(value=0.0, description='Epoch 91', max=165.609375, style=ProgressStyle(descriptio…

[91.        0.091679]


HBox(children=(FloatProgress(value=0.0, description='Epoch 92', max=165.609375, style=ProgressStyle(descriptio…

[9.2000e+01 9.1467e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 93', max=165.609375, style=ProgressStyle(descriptio…

[9.3000e+01 9.0951e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 94', max=165.609375, style=ProgressStyle(descriptio…

[9.4000e+01 9.0822e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 95', max=165.609375, style=ProgressStyle(descriptio…

[9.5000e+01 9.2107e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 96', max=165.609375, style=ProgressStyle(descriptio…

[9.6000e+01 8.9677e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 97', max=165.609375, style=ProgressStyle(descriptio…

[9.7000e+01 9.1334e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 98', max=165.609375, style=ProgressStyle(descriptio…

[9.8000e+01 9.0677e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 99', max=165.609375, style=ProgressStyle(descriptio…

[9.9000e+01 9.0994e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 100', max=165.609375, style=ProgressStyle(descripti…

[1.0000e+02 9.0657e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 101', max=165.609375, style=ProgressStyle(descripti…

[1.0100e+02 9.1963e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 102', max=165.609375, style=ProgressStyle(descripti…

[1.0200e+02 9.0665e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 103', max=165.609375, style=ProgressStyle(descripti…

[1.0300e+02 9.1143e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 104', max=165.609375, style=ProgressStyle(descripti…

[1.0400e+02 8.9756e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 105', max=165.609375, style=ProgressStyle(descripti…

[1.050e+02 9.068e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 106', max=165.609375, style=ProgressStyle(descripti…

[1.0600e+02 9.0173e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 107', max=165.609375, style=ProgressStyle(descripti…

[1.070e+02 8.956e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 108', max=165.609375, style=ProgressStyle(descripti…

[1.0800e+02 8.9942e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 109', max=165.609375, style=ProgressStyle(descripti…

[1.0900e+02 8.9842e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 110', max=165.609375, style=ProgressStyle(descripti…

[1.1000e+02 8.9457e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 111', max=165.609375, style=ProgressStyle(descripti…

[1.110e+02 8.937e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 112', max=165.609375, style=ProgressStyle(descripti…

[1.1200e+02 8.8616e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 113', max=165.609375, style=ProgressStyle(descripti…

[1.1300e+02 8.8752e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 114', max=165.609375, style=ProgressStyle(descripti…

[1.1400e+02 8.9291e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 115', max=165.609375, style=ProgressStyle(descripti…

[1.1500e+02 8.8259e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 116', max=165.609375, style=ProgressStyle(descripti…

[1.1600e+02 8.9563e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 117', max=165.609375, style=ProgressStyle(descripti…

[1.170e+02 8.822e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 118', max=165.609375, style=ProgressStyle(descripti…

[1.1800e+02 9.0395e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 119', max=165.609375, style=ProgressStyle(descripti…

[1.1900e+02 8.9077e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 120', max=165.609375, style=ProgressStyle(descripti…

[1.2000e+02 8.8978e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 121', max=165.609375, style=ProgressStyle(descripti…

[1.210e+02 8.827e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 122', max=165.609375, style=ProgressStyle(descripti…

[1.2200e+02 8.8142e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 123', max=165.609375, style=ProgressStyle(descripti…

[1.2300e+02 8.8246e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 124', max=165.609375, style=ProgressStyle(descripti…

[1.2400e+02 8.9386e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 125', max=165.609375, style=ProgressStyle(descripti…

[1.2500e+02 8.8721e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 126', max=165.609375, style=ProgressStyle(descripti…

[1.2600e+02 8.7712e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 127', max=165.609375, style=ProgressStyle(descripti…

[1.2700e+02 8.7896e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 128', max=165.609375, style=ProgressStyle(descripti…

[1.2800e+02 8.8779e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 129', max=165.609375, style=ProgressStyle(descripti…

[1.2900e+02 8.8505e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 130', max=165.609375, style=ProgressStyle(descripti…

[1.3000e+02 8.8488e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 131', max=165.609375, style=ProgressStyle(descripti…

[1.3100e+02 8.7769e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 132', max=165.609375, style=ProgressStyle(descripti…

[1.3200e+02 8.7788e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 133', max=165.609375, style=ProgressStyle(descripti…

[1.3300e+02 8.6447e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 134', max=165.609375, style=ProgressStyle(descripti…

[1.3400e+02 8.6805e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 135', max=165.609375, style=ProgressStyle(descripti…

[1.3500e+02 8.8002e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 136', max=165.609375, style=ProgressStyle(descripti…

[1.360e+02 8.693e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 137', max=165.609375, style=ProgressStyle(descripti…

[1.3700e+02 8.7657e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 138', max=165.609375, style=ProgressStyle(descripti…

[1.3800e+02 8.7083e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 139', max=165.609375, style=ProgressStyle(descripti…

[1.3900e+02 8.7683e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 140', max=165.609375, style=ProgressStyle(descripti…

[1.400e+02 8.697e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 141', max=165.609375, style=ProgressStyle(descripti…

[1.4100e+02 8.6339e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 142', max=165.609375, style=ProgressStyle(descripti…

[1.4200e+02 8.7281e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 143', max=165.609375, style=ProgressStyle(descripti…

[1.4300e+02 8.6517e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 144', max=165.609375, style=ProgressStyle(descripti…

[1.4400e+02 8.7315e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 145', max=165.609375, style=ProgressStyle(descripti…

[1.4500e+02 8.6419e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 146', max=165.609375, style=ProgressStyle(descripti…

[1.4600e+02 8.6759e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 147', max=165.609375, style=ProgressStyle(descripti…

[1.470e+02 8.722e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 148', max=165.609375, style=ProgressStyle(descripti…

[1.4800e+02 8.6335e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 149', max=165.609375, style=ProgressStyle(descripti…

[1.4900e+02 8.7088e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 150', max=165.609375, style=ProgressStyle(descripti…

[1.5000e+02 8.5404e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 151', max=165.609375, style=ProgressStyle(descripti…

[1.5100e+02 8.6678e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 152', max=165.609375, style=ProgressStyle(descripti…

[1.520e+02 8.634e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 153', max=165.609375, style=ProgressStyle(descripti…

[1.5300e+02 8.6177e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 154', max=165.609375, style=ProgressStyle(descripti…

[1.5400e+02 8.5919e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 155', max=165.609375, style=ProgressStyle(descripti…

[1.5500e+02 8.6453e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 156', max=165.609375, style=ProgressStyle(descripti…

[1.5600e+02 8.5936e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 157', max=165.609375, style=ProgressStyle(descripti…

[1.5700e+02 8.6307e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 158', max=165.609375, style=ProgressStyle(descripti…

[1.5800e+02 8.4698e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 159', max=165.609375, style=ProgressStyle(descripti…

[1.5900e+02 8.5216e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 160', max=165.609375, style=ProgressStyle(descripti…

[1.6000e+02 8.5054e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 161', max=165.609375, style=ProgressStyle(descripti…

[1.6100e+02 8.4874e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 162', max=165.609375, style=ProgressStyle(descripti…

[1.620e+02 8.588e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 163', max=165.609375, style=ProgressStyle(descripti…

[1.6300e+02 8.5172e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 164', max=165.609375, style=ProgressStyle(descripti…

[1.6400e+02 8.5559e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 165', max=165.609375, style=ProgressStyle(descripti…

[1.6500e+02 8.5673e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 166', max=165.609375, style=ProgressStyle(descripti…

[1.6600e+02 8.4648e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 167', max=165.609375, style=ProgressStyle(descripti…

[1.6700e+02 8.5075e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 168', max=165.609375, style=ProgressStyle(descripti…

[1.6800e+02 8.4257e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 169', max=165.609375, style=ProgressStyle(descripti…

[1.6900e+02 8.4318e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 170', max=165.609375, style=ProgressStyle(descripti…

[1.7000e+02 8.5188e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 171', max=165.609375, style=ProgressStyle(descripti…

[1.7100e+02 8.4224e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 172', max=165.609375, style=ProgressStyle(descripti…

[1.7200e+02 8.4111e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 173', max=165.609375, style=ProgressStyle(descripti…

[1.7300e+02 8.4681e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 174', max=165.609375, style=ProgressStyle(descripti…

[1.7400e+02 8.4267e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 175', max=165.609375, style=ProgressStyle(descripti…

[1.7500e+02 8.4084e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 176', max=165.609375, style=ProgressStyle(descripti…

[1.7600e+02 8.4578e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 177', max=165.609375, style=ProgressStyle(descripti…

[1.7700e+02 8.3755e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 178', max=165.609375, style=ProgressStyle(descripti…

[1.780e+02 8.418e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 179', max=165.609375, style=ProgressStyle(descripti…

[1.7900e+02 8.3956e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 180', max=165.609375, style=ProgressStyle(descripti…

[1.8000e+02 8.3317e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 181', max=165.609375, style=ProgressStyle(descripti…

[1.810e+02 8.395e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 182', max=165.609375, style=ProgressStyle(descripti…

[1.8200e+02 8.3618e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 183', max=165.609375, style=ProgressStyle(descripti…

[1.8300e+02 8.4796e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 184', max=165.609375, style=ProgressStyle(descripti…

[1.8400e+02 8.4286e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 185', max=165.609375, style=ProgressStyle(descripti…

[1.8500e+02 8.3257e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 186', max=165.609375, style=ProgressStyle(descripti…

[1.8600e+02 8.3634e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 187', max=165.609375, style=ProgressStyle(descripti…

[1.8700e+02 8.3491e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 188', max=165.609375, style=ProgressStyle(descripti…

[1.8800e+02 8.4271e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 189', max=165.609375, style=ProgressStyle(descripti…

[1.8900e+02 8.4113e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 190', max=165.609375, style=ProgressStyle(descripti…

[1.90e+02 8.28e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 191', max=165.609375, style=ProgressStyle(descripti…

[1.9100e+02 8.2815e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 192', max=165.609375, style=ProgressStyle(descripti…

[1.9200e+02 8.4756e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 193', max=165.609375, style=ProgressStyle(descripti…

[1.9300e+02 8.4686e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 194', max=165.609375, style=ProgressStyle(descripti…

[1.9400e+02 8.2783e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 195', max=165.609375, style=ProgressStyle(descripti…

[1.9500e+02 8.3855e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 196', max=165.609375, style=ProgressStyle(descripti…

[1.9600e+02 8.4194e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 197', max=165.609375, style=ProgressStyle(descripti…

[1.9700e+02 8.2341e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 198', max=165.609375, style=ProgressStyle(descripti…

[1.9800e+02 8.3255e-02]


HBox(children=(FloatProgress(value=0.0, description='Epoch 199', max=165.609375, style=ProgressStyle(descripti…

[1.9900e+02 8.2474e-02]


In [None]:
loss_tracker

[array([0.      , 0.141475]),
 array([1.      , 0.134318]),
 array([2.     , 0.12782]),
 array([3.      , 0.123532]),
 array([4.      , 0.120557]),
 array([5.      , 0.120755]),
 array([6.      , 0.119089]),
 array([7.      , 0.116263]),
 array([8.      , 0.115754]),
 array([9.      , 0.114742]),
 array([10.      ,  0.113279]),
 array([11.      ,  0.112437]),
 array([12.      ,  0.109533]),
 array([13.      ,  0.110783]),
 array([14.      ,  0.108534]),
 array([15.      ,  0.110208]),
 array([16.      ,  0.108641]),
 array([17.      ,  0.108041]),
 array([18.      ,  0.109272]),
 array([19.      ,  0.107649]),
 array([20.     ,  0.10606]),
 array([21.      ,  0.106915]),
 array([22.      ,  0.106317]),
 array([23.      ,  0.104948]),
 array([24.      ,  0.104234]),
 array([25.      ,  0.104382]),
 array([26.      ,  0.103635]),
 array([27.      ,  0.102623]),
 array([28.      ,  0.102079]),
 array([29.      ,  0.104037]),
 array([30.      ,  0.102396]),
 array([31.      ,  0.102229]),


In [None]:
with open('/content/drive/My Drive/Colab Notebooks/NLP/Machine Translation/FT_Files/loss_tracker_FT_200E_.084Loss.pkl', 'wb') as f:
  pickle.dump(loss_tracker, f)

Result trackers:   
  
100 Epochs: 0.1384 Loss


In [None]:
torch.save(encoder.state_dict(), f'/content/drive/My Drive/Colab Notebooks/NLP/Machine Translation/FT_Files/FTsimple_encoder_200E_0.084Loss.pth')
torch.save(decoder.state_dict(), f'/content/drive/My Drive/Colab Notebooks/NLP/Machine Translation/FT_Files/FTsimple_decoder_200E_0.084Loss.pth')

# New Section

Load model weights

In [58]:
encoder.load_state_dict(torch.load(encoder_saved_model_weights))

<All keys matched successfully>

In [59]:
decoder.load_state_dict(torch.load(decoder_saved_model_weights))

<All keys matched successfully>

In order to generate predictions of a french sentence:
<br/>1- Tokenize.
<br/>2- Transform words to their ids.
<br/>3- Set sentence length = 30.
<br/>3- Encode.
<br/>4- Decode next translated word until the decoder generates a special word that means end of sentence or until reach the max length = 30.

In [60]:
def sent2ids(sent):
    ids = [fr_w2id[t] for t in tokenize(sent)]
    return pad_sequences([ids], maxlen, 'int64', 'post', 'post')
  

In [61]:
def evaluate(inp):
    decoder_input, encoder_outputs, hidden = encode(inp, encoder)
    target_length = maxlen
    
    decoded_words = []
    for di in range(target_length):
        decoder_output, hidden = decoder(decoder_input, hidden)
        topv, topi = decoder_output.data.topk(1)
        ni = topi[0][0]
        if ni==PAD:
            break
        decoded_words.append(en_vocab[ni])
        decoder_input = long_t([ni])
    
    return decoded_words

In [62]:
def fr2en(sent):
    ids = long_t(sent2ids(sent))
    translation = evaluate(ids)
    return ' '.join(translation)

Bleu

In [63]:
from nltk.translate.bleu_score import sentence_bleu
# returns the one gram bleu score 

def bleu(reference,candidate):
  one_gram = sentence_bleu([reference], candidate, weights=(1, 0, 0, 0))
  return(one_gram)

In [64]:
from nltk.translate.gleu_score import sentence_gleu

def gleu(reference, candidate):
  one_gram = sentence_gleu(reference, candidate)
  return (one_gram)

Evaluate n random pairs with bleu

In [65]:
pairs = [[fr_questions[idx],en_questions[idx]] for idx,i in enumerate(sentences) ]

In [66]:
import random
import warnings
warnings.filterwarnings("ignore")

# evaluate n random sentence pairs using one gram bleu

def evaluateRandomly(encoder, decoder, n=25):
    score_tracker = []
    for i in range(n):
        pair = random.choice(pairs)
        print(i+1)

        output_words = fr2en(pair[0])
        output_sentence = ''.join(output_words)
        
        #bleu
        ref = pair[1].split()[:-1]
        pred = output_sentence.split()[:-1]
        ref, pred = fix_contractions(ref, pred)

        print('>', pair[0])
        print('=', ref)
        print('<', pred,'\n')
        one_gram = bleu(ref,pred)
        score_tracker.append(one_gram)
        print(f'Bleu Score: {one_gram}')
        print('')

    print('Avg Bleu Score (based on one-gram): ',sum(score_tracker)/len(score_tracker))

In [67]:
# sometimes ending punctuation is filtered off prediction when no EOS token is predicted
# adding it back in to not trigger missed prediction
# input: two lists of words
# output: two lists of words

def fix_punctuation(ref, pred):
  ending_punc = [ref[-1]]
  if pred[-1] not in ending_punc:
    pred.append(ending_punc[0])
  return ref, pred

In [68]:
import gensim
w2v_model = gensim.models.KeyedVectors.load_word2vec_format('/content/drive/My Drive/Colab Notebooks/NLP/Machine Translation/FT_Files/word2vec.bin', binary=True)

In [69]:
# create a custom score by averaging bs and gs scores with weights
# adding a bonus if different words have a shared semantic meaning (cs score > 0.3)
# subtracting points if predicted sequence has duplicated words
# inputs: bleu score, gleu score, # of double words in pred
# outputs: custom score

def custom_score(bs,gs,cs,double_word_penalty, verbose=True):
  total = ((bs*.75)+(gs*.25)) # weighted avg
  cs_bonus = 0

  # calc cs bonus
  def get_bonus(cs,multiplier = .4):
    additional = 0
    for i in cs:
      additional += i * multiplier
    return additional

  # if we have similarities, compute bonus
  if cs: 
    cs_bonus = get_bonus(cs)


  # if perfect score, return 1
  if bs == 1:
    if verbose:
      print('\nSemantic similarity bonus : +', float(cs_bonus))
      print('Double word penalty:        -', double_word_penalty * .1,'\n')
    return 1.00

  else:
    if cs_bonus:
      grand_total = total + cs_bonus
      if grand_total < 1:
        if verbose:
          print('\nSemantic similarity bonus : +', float(cs_bonus))
          print('Double word penalty:        -', double_word_penalty * .1,'\n')
        return grand_total

      # bonus put score over 1  
      else:
        cs_bonus = get_bonus(cs, multiplier = .3)
        if verbose:
          print('\nTotal score > 1, adjusting weights...') #debug statement, delete at end
          print('\nSemantic similarity bonus : +', float(cs_bonus))
          print('Double word penalty:        -', double_word_penalty * .1,'\n')
        grand_total = total + cs_bonus
        if grand_total < 1:   
          return grand_total

        # bonus put score over 1   
        else:
          cs_bonus = get_bonus(cs, multiplier = .2)
          grand_total = total + cs_bonus
          if grand_total < 1:  
            return grand_total
          else:
            cs_bonus = get_bonus(cs, multiplier = .1)
            grand_total = total + cs_bonus
            if grand_total < 1:   
              return grand_total



    # if no cs bonus      
    else:
      if verbose:
        print('\nSemantic similarity bonus : +', float(cs_bonus))
        print('Double word penalty:        -', double_word_penalty * .1,'\n')
      return total - (double_word_penalty * .1)

Compare Ref to Pred

In [70]:
# calculates penalties for double words, filters sentences to relevant words to compare,
# calculateds cos. sim., and a score for the strength of the cos. sims.
# inputs: two lists of words
# outputs: relevant word cosine sim. scores over 0.3, number of double words in the prediction

def similarities(A,B,verbose=True):

  # does B have more double words than A? 
  doublesA = 0
  doublesB = 0
  basketA = []
  basketB = []

  for i in A:
    if i not in basketA:
      basketA.append(i)
    else:
      doublesA += 1

  for i in B: 
    if i not in basketB:
      basketB.append(i)
    else:
      doublesB += 1

  # calc penalty, keep only positive values
  double_word_penalty = np.clip(doublesB - doublesA, 0,3) 

  # get words not in the other sentence and not in punc/stopwords
  stop_words = ['a','an','of','the','to','on','t','in','as'] #,'not','no']
  punc = ['.','?','!',',']
  extraW = [] # all extra words
  extraA = []
  extraB = []

  for i in A:
    if i not in punc:
      if i not in stop_words:
        if (i not in B):
          extraA.append(i)
  for i in B:
    if i not in punc:
      if i not in stop_words:
        if (i not in A):
          extraB.append(i)

  extraW = extraA + extraB

  # if off by one word, exit
  if len(extraW) == 1:
    return [0, double_word_penalty]
  
  # calc cos sims and score
  sim_finn = []
  sim_w2v = []
  record = []

  for a, b in itertools.product(extraA,extraB):
    sim_w2v.append([a,b,w2v_model.similarity(a,b)])

  sorted_sim_w2v = sorted(sim_w2v, key = lambda x: x[2], reverse=True)

  cs_score = [0] #list of cs over 0.3

  # print cs scores
  if sorted_sim_w2v:
    if verbose:
      print('\nSemantic similarities using w2v:')
    for idx,i in enumerate(sorted_sim_w2v):
      if verbose:
        if i[2] > 0.3:
          print(bold, end="")
          print('  ',i,reset)
        else:
          print('  ',i)
    #print('\n')
    # record cs scores
    for i in sorted_sim_w2v:
      if i[2] > 0.3:
        cs_score.append(i[2])

  return [cs_score, double_word_penalty]

In [91]:
# evaluate n random pairs from dataset
# input: models, n
# output: none

def evaluateRandomly(encoder, decoder, n=100, all = False):
    print(bold+'Evaluation of Machine Translation Model'+reset)
    
    bleu_score_tracker = []
    gleu_score_tracker = []
    custom_tracker = []
    record_test = []
    ending_punc = ['.','?','!']
    glove_frwac_ref_pred = []

    if not all: # run on n randomly chosen pairs
      print(bold+'Evaluating ',n,' examples...'+reset)
      for i in range(n):
          pair = random.choice(pairs)

          output_words = fr2en(pair[0])
 
          # lists
          ref = pair[1].split()#[:-1]
          pred = output_words.split()[:-1]

          # fix contractions
          ref, pred = fix_contractions(ref, pred)

          # if missing ending punctuation
          if pred[-1] not in ending_punc:
            ref, pred = fix_punctuation(ref, pred) 

          glove_frwac_ref_pred.append([i,ref,pred])

          #print('Before Bleu: ',ref[:-1],pred[:-1])
          bleu_one_gram = bleu(ref[:-1],pred[:-1])


          # DO NOT DISPLAY PERFECT SCORES - USED FOR EASY DEBUGGING - DELETE AT END
          # CONVERT TO DISPLAYING PERFECT SCORES SOME FRACTION OF THE TIME (1/5TH?)
          if bleu_one_gram:
            print('\n')
            print(bold+'Input:\t'+reset, pair[0])
            print(bold+'Target:\t'+reset, ' '.join(ref))

            # GOOGLE TRANSLATE FUNCTION - ALLOTED LIMITED TRANSLATIONS PER DAY
            #gt = google_translate(pair[0])
            #gt = normalizeString(gt)
            #gt, _ = fix_contractions(gt.split(),' ')
            #gt = ' '.join(gt)
            #print('GT:\t',gt)
            #bleu_one_gram_gt = bleu([ref[:-1]],gt[:-1])
            #print('GT Bleu Score: ',bleu_one_gram_gt)
            #if bleu_one_gram > bleu_one_gram_gt:
              #print('Better than GT!')

            print(bold+'Pred:\t'+reset, ' '.join(pred),'\n')
            
            # requires ref to be a 2d list, pred 1d list
            bleu_score_tracker.append(bleu_one_gram)
            print(f'Bleu Score: {bleu_one_gram:.3f}')

            gleu_one_gram = gleu([ref[:-1]],pred[:-1])
            gleu_score_tracker.append(gleu_one_gram)
            print(f'Gleu Score: {gleu_one_gram:.3f}')
            print(f'Avg Score:  {(gleu_one_gram*.25+bleu_one_gram*.75):.3f}') #weighted
            
            cs_score = 0

            # if not perfect score: calc. bonuses and penalties
            if bleu_one_gram < 1:
              try: # sometimes sims returns none
                sim_returns = similarities(ref,pred)
                cs_score = sim_returns[0]
                double_word_penalty = sim_returns[1]
                cust_score = custom_score(bleu_one_gram,gleu_one_gram,cs_score,double_word_penalty)
                print(f'{bold_red_font_tag}Custom Score: {cust_score:.3f}{reset}')
                #print('\n')

              # if word not in WE
              except KeyError:
                print('Cosine similarities: Word not found in embedding vocabulary')
                continue
            else:
              cust_score = custom_score(bleu_one_gram,gleu_one_gram,0,0)
              print(f'{bold_red_font_tag}Custom Score: {cust_score:.3f}{reset}')

            custom_tracker.append(cust_score)

      print('\n')
      print(f'{bold_blue_font_tag}Avg Bleu Score  :{reset} {sum(bleu_score_tracker)/len(bleu_score_tracker):.3f}')
      print(f'{bold_blue_font_tag}Avg Gleu Score  :{reset} {sum(gleu_score_tracker)/len(gleu_score_tracker):.3f}')
      print(f'{bold_blue_font_tag}Avg Custom Score:{reset} {sum(custom_tracker)/len(custom_tracker):.3f}')

    else: # run on entire dataset
      print(bold+'Evaluating entire dataset...'+reset)
      for i in range(len(pairs)):

        
        pair = pairs[i]
        output_words = fr2en(pair[0])
        ref = pair[1].split()#[:-1]
        pred = output_words.split()[:-1]

        # fix contractions
        ref, pred = fix_contractions(ref, pred)

        # if missing ending punctuation
        if pred[-1] not in ending_punc:
          ref, pred = fix_punctuation(ref, pred) 

        glove_frwac_ref_pred.append([i,ref,pred])

        bleu_one_gram = bleu(ref[:-1],pred[:-1])
        bleu_score_tracker.append(bleu_one_gram)
        gleu_one_gram = gleu([ref[:-1]],pred[:-1])
        gleu_score_tracker.append(gleu_one_gram)

        cs_score = 0

        # if not perfect score: calc. bonuses and penalties
        if bleu_one_gram < 1:
          try: # sometimes sims returns none
            sim_returns = similarities(ref,pred, verbose=False)
            cs_score = sim_returns[0]
            double_word_penalty = sim_returns[1]
            cust_score = custom_score(bleu_one_gram,gleu_one_gram,cs_score,double_word_penalty, verbose=False)
            #print(bold_red_font_tag+'Custom Score: ',cust_score,reset)
            #print('\n')

          # if word not in WE
          except KeyError:
            #print('Cosine similarities: Word not found in embedding vocabulary')
            continue
        else:
          cust_score = custom_score(bleu_one_gram,gleu_one_gram,0,0, verbose=False)
          #print(bold_red_font_tag+'Custom Score: ',cust_score,reset)

        custom_tracker.append(cust_score)
        
      '''print(len(custom_tracker))
      
      print(type(custom_tracker))
      print(type(custom_tracker[1]))'''
      print('\n')
      custom_tracker = list(filter(None, custom_tracker)) 
      #custom_tracker=[float(i) for i in custom_tracker]
      print(f'{bold_blue_font_tag}Avg Bleu Score  :{reset} {sum(bleu_score_tracker)/len(bleu_score_tracker):.3f}')
      print(f'{bold_blue_font_tag}Avg Gleu Score  :{reset} {sum(gleu_score_tracker)/len(gleu_score_tracker):.3f}')
      print(f'{bold_blue_font_tag}Avg Custom Score:{reset} {sum(custom_tracker)/len(custom_tracker):.3f}')

    return glove_frwac_ref_pred

Evaluate bleu score on entire on all training data

In [72]:
# get bleu score for entire dataset

def bleuScore(encoder, decoder):
  score_tracker = []
  for i in range(len(pairs)):
      pair = pairs[i]
      #print('pair: ',pair)
      output_words = fr2en(pair[0])
      #print(output_words)
      output_sentence = output_words
      ref = pair[1].split()[:-1]
      pred = output_sentence.split()[:-1]
      ref, pred = fix_contractions(ref, pred)
      #print(ref)
      #print(pred)
      one_gram = bleu(ref,pred)
      score_tracker.append(one_gram)


  print('Avg Bleu Score: ',sum(score_tracker)/len(score_tracker))

    

In [73]:
'''
def fix_contractions(ref,pred):
  for idx, word in enumerate(pred):
    if word == 're':
      pred[idx] = 'are'
    elif word == 'm':
      pred[idx] = 'am' 
    elif word == 's':
      pred[idx] = 'is'   
    elif word == 'aren': 
      try:
        if pred[idx+1] == 't':
          pred[idx] = 'are' 
          pred[idx+1] = 'not'
      except IndexError:
        continue

  for idx, rword in enumerate(ref):
    if rword == 're':
      ref[idx] = 'are'
    elif rword == 'm':
      ref[idx] = 'am' 
    elif rword == 's':
      ref[idx] = 'is'  
    elif rword == 'aren': 
      try:
        if ref[idx+1] == 't':
          ref[idx] = 'are' 
          ref[idx+1] = 'not'      
      except IndexError:
        continue
        
  return ref, pred'''

"\ndef fix_contractions(ref,pred):\n  for idx, word in enumerate(pred):\n    if word == 're':\n      pred[idx] = 'are'\n    elif word == 'm':\n      pred[idx] = 'am' \n    elif word == 's':\n      pred[idx] = 'is'   \n    elif word == 'aren': \n      try:\n        if pred[idx+1] == 't':\n          pred[idx] = 'are' \n          pred[idx+1] = 'not'\n      except IndexError:\n        continue\n\n  for idx, rword in enumerate(ref):\n    if rword == 're':\n      ref[idx] = 'are'\n    elif rword == 'm':\n      ref[idx] = 'am' \n    elif rword == 's':\n      ref[idx] = 'is'  \n    elif rword == 'aren': \n      try:\n        if ref[idx+1] == 't':\n          ref[idx] = 'are' \n          ref[idx+1] = 'not'      \n      except IndexError:\n        continue\n        \n  return ref, pred"

In [74]:
# fix issues with contractions when displaying results.
# issues: 's' could represent possesion and not 'is.' Small fraction of the time though.
# inputs: two lists of words
# outputs: two lists of words

def fix_contractions(ref,pred):

  for idx, word in enumerate(pred):
    if word == 're':
      pred[idx] = 'are'
    elif word == 'm':
      pred[idx] = 'am' 
    elif word == 's':
      pred[idx] = 'is'   
    elif word == 'ok':
      pred[idx] = 'okay'  
    elif word == 'aren': 
      try:
        if pred[idx+1] == 't':
          pred[idx] = 'are' 
          pred[idx+1] = 'not'
      except IndexError:
        continue

    elif word == 'isn': 
      try:
        if pred[idx+1] == 't':
          pred[idx] = 'is' 
          pred[idx+1] = 'not'
      except IndexError:
        continue
    elif (word == 'don' and pred[idx+1] == 't'):
      pred[idx] = 'do' 
      pred[idx+1] = 'not'

  for idx, rword in enumerate(ref):
    if rword == 're':
      ref[idx] = 'are'
    elif rword == 'm':
      ref[idx] = 'am' 
    elif rword == 'ok':
      ref[idx] = 'okay'       
    elif rword == 's':
      ref[idx] = 'is'  
    elif rword == 'aren': 
      try:
        if ref[idx+1] == 't':
          ref[idx] = 'are' 
          ref[idx+1] = 'not'      
      except IndexError:
        continue    
    elif rword == 'isn': 
      try:
        if ref[idx+1] == 't':
          ref[idx] = 'is' 
          ref[idx+1] = 'not'      
      except IndexError:
        continue   
    elif (rword == 'don' and ref[idx+1] == 't'):
      ref[idx] = 'do' 
      ref[idx+1] = 'not' 

  return ref, pred

In [75]:
bold_blue_font_tag = '\x1b[1m\x1b[34m'
bold_red_font_tag = '\x1b[1m\x1b[31m'
bold_gree_font_tag = '\x1b[1m\x1b[32m'
magenta = '\033[35m'
bold = '\033[1m'
reset = '\033[0m'

Evaluate on a sample

In [76]:
ft_frwac_ref_pred = evaluateRandomly(encoder, decoder, all= False)

[1mEvaluation of Machine Translation Model[0m
[1mEvaluating  100  examples...[0m


[1mInput:	[0m nous essayons .
[1mTarget:	[0m we are trying .
[1mPred:	[0m we are trying . 

Bleu Score: 1.000
Gleu Score: 1.000
Avg Score:  1.000

Semantic similarity bonus : + 0.0
Double word penalty:        - 0.0 

[1m[31mCustom Score: 1.000[0m


[1mInput:	[0m je vais bien .
[1mTarget:	[0m i am okay .
[1mPred:	[0m i am fine . 

Bleu Score: 0.667
Gleu Score: 0.500
Avg Score:  0.625

Semantic similarities using w2v:
[1m   ['okay', 'fine', 0.38318834] [0m

Semantic similarity bonus : + 0.1532753348350525
Double word penalty:        - 0.0 

[1m[31mCustom Score: 0.778[0m


[1mInput:	[0m tu es une etudiante .
[1mTarget:	[0m you are a student .
[1mPred:	[0m you are a terrible person . 

Bleu Score: 0.600
Gleu Score: 0.429
Avg Score:  0.557

Semantic similarities using w2v:
[1m   ['student', 'person', 0.33660564] [0m
   ['student', 'terrible', 0.05019008]

Semantic similarity bo

In [92]:
ft_frwac_ref_pred = evaluateRandomly(encoder, decoder, all= True)

[1mEvaluation of Machine Translation Model[0m
[1mEvaluating entire dataset...[0m


[1m[34mAvg Bleu Score  :[0m 0.755
[1m[34mAvg Gleu Score  :[0m 0.670
[1m[34mAvg Custom Score:[0m 0.803


In [93]:
len(ft_frwac_ref_pred)

10599

Save results to pickle file for scratch notebook - Travis needs this file

In [94]:
with open('/content/drive/My Drive/Colab Notebooks/NLP/Machine Translation/FT_Files/ft_frwac_results_V1_200Epochs_.084Loss.pkl', 'wb') as f:
  pickle.dump(ft_frwac_ref_pred, f)