In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!git clone https://github.com/susannapaoli/NLP-final-project.git

Cloning into 'NLP-final-project'...
remote: Enumerating objects: 3, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Total 3 (delta 0), reused 0 (delta 0), pack-reused 0[K
Unpacking objects: 100% (3/3), 608 bytes | 608.00 KiB/s, done.


In [4]:
%cd /content/NLP-final-project

/content/NLP-final-project


## Import packages 

In [4]:
!python -m spacy download en_core_web_sm
!python -m spacy download fr_core_news_sm

2023-04-14 19:21:10.515599: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-04-14 19:21:12.865409: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-04-14 19:21:12.865857: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:996] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-04-

In [5]:
import math
import time
import io
import numpy as np
import csv
from IPython.display import Image

import torch
import torch.nn as nn
import torch.optim as optim

import torchtext
from torchtext.datasets import Multi30k
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab
from torchtext.utils import download_from_url, extract_archive
from torch.nn.utils.rnn import pad_sequence

from tqdm import tqdm_notebook, tqdm


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("You are using device: %s" % device)

You are using device: cuda


## Data preprocessing

In [7]:
MAX_LEN = 20
url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
train_urls = ('train.fr.gz', 'train.en.gz')
val_urls = ('val.fr.gz', 'val.en.gz')
test_urls = ('test_2016_flickr.fr.gz', 'test_2016_flickr.en.gz')

train_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in train_urls]
val_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in val_urls]
test_filepaths = [extract_archive(download_from_url(url_base + url))[0] for url in test_urls]

fr_tokenizer = get_tokenizer('spacy', language='fr_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

In [8]:
def build_vocab(filepath, tokenizer):
  counter = Counter()
  with io.open(filepath, encoding="utf8") as f:
    for string_ in f:
      counter.update(tokenizer(string_))
  return vocab(counter, specials=['<unk>', '<pad>', '<sos>', '<eos>'], min_freq=2)


fr_vocab = build_vocab(train_filepaths[0], fr_tokenizer)
en_vocab = build_vocab(train_filepaths[1], en_tokenizer)
fr_vocab.set_default_index(fr_vocab['<unk>'])
en_vocab.set_default_index(en_vocab['<unk>'])

In [9]:
def data_process(filepaths):
  raw_fr_iter = iter(io.open(filepaths[0], encoding="utf8"))
  raw_en_iter = iter(io.open(filepaths[1], encoding="utf8"))
  data = []
  for (raw_fr, raw_en) in zip(raw_fr_iter, raw_en_iter):
    raw_en_l=raw_en.lower()     #turn sentences to lower case 
    raw_fr_l=raw_fr.lower()
    fr_tensor = torch.tensor([fr_vocab[token] for token in fr_tokenizer(raw_fr_l)],
                            dtype=torch.long)
    en_tensor = torch.tensor([en_vocab[token] for token in en_tokenizer(raw_en_l)],
                            dtype=torch.long)
    if len(fr_tensor) <= MAX_LEN-2 and len(en_tensor) <= MAX_LEN-2:
        data.append((fr_tensor, en_tensor))
  return data

In [10]:
train_data = data_process(train_filepaths)
val_data = data_process(val_filepaths)
test_data = data_process(test_filepaths)

In [11]:
BATCH_SIZE = 128
PAD_IDX = fr_vocab['<pad>']
SOS_IDX = fr_vocab['<sos>']
EOS_IDX = fr_vocab['<eos>']

In [12]:
def generate_batch(data_batch):
  
    fr_batch, en_batch = [], []
    for (fr_item, en_item) in data_batch:
          en_batch.append(torch.cat([torch.tensor([SOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
          fr_batch.append(torch.cat([torch.tensor([SOS_IDX]), fr_item, torch.tensor([EOS_IDX])], dim=0))
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
    fr_batch = pad_sequence(fr_batch, padding_value=PAD_IDX)
    fix=torch.ones(MAX_LEN,en_batch.shape[1])
    two= pad_sequence([fr_batch,en_batch, fix], padding_value=PAD_IDX)
    fr_batch=two[:,0,]
    en_batch=two[:,1,]
    return fr_batch, en_batch

In [13]:
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=generate_batch)
valid_loader = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=False, collate_fn=generate_batch)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=False, collate_fn=generate_batch)

In [14]:
# Get the input and the output sizes for model
input_size = len(fr_vocab)
output_size = len(en_vocab)
print (input_size,output_size)

6556 6192


## Train and Evaluate functions

In [None]:
def train(model, dataloader, optimizer, criterion, scheduler=None, device='cpu'):
    model.train()

    # Record total loss
    total_loss = 0.

    # Get the progress bar for later modification
    progress_bar = tqdm_notebook(dataloader, ascii=True)

    # Mini-batch training
    for batch_idx, data in enumerate(progress_bar):
        source = data[0].transpose(1, 0).to(device)
        target = data[1].transpose(1, 0).to(device)

        translation = model(source)
        translation = translation.reshape(-1, translation.shape[-1])
        target = target.reshape(-1)

        optimizer.zero_grad()
        loss = criterion(translation, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_description_str(
            "Batch: %d, Loss: %.4f" % ((batch_idx + 1), loss.item()))

    return total_loss, total_loss / len(dataloader)

In [None]:
def evaluate(model, dataloader, criterion, device='cpu'):
    # Set the model to eval mode to avoid weights update
    model.eval()
    total_loss = 0.
    with torch.no_grad():
        # Get the progress bar
        progress_bar = tqdm_notebook(dataloader, ascii=True)
        for batch_idx, data in enumerate(progress_bar):
            source = data[0].transpose(1, 0).to(device)
            target = data[1].transpose(1, 0).to(device)

            translation = model(source)
            translation = translation.reshape(-1, translation.shape[-1])
            target = target.reshape(-1)

            loss = criterion(translation, target)
            total_loss += loss.item()
            progress_bar.set_description_str(
                "Batch: %d, Loss: %.4f" % ((batch_idx + 1), loss.item()))

    avg_loss = total_loss / len(dataloader)
    return total_loss, avg_loss


## Train the LSTM seq2seq Model

In [None]:
from seq2seq import Encoder, Decoder, Seq2Seq

In [None]:
# Hyperparameters. You are welcome to modify these
encoder_emb_size = 128
encoder_hidden_size = 128
encoder_dropout = 0.2

decoder_emb_size = 128
decoder_hidden_size = 128
decoder_dropout = 0.2

learning_rate = 1e-3
model_type = "LSTM"

EPOCHS = 40

#input size and output size
input_size = len(fr_vocab)
output_size = len(en_vocab)

In [None]:
encoder = Encoder(input_size, encoder_emb_size, encoder_hidden_size, decoder_hidden_size, dropout = encoder_dropout, model_type = model_type)
decoder = Decoder(decoder_emb_size, encoder_hidden_size, encoder_hidden_size, output_size, dropout = decoder_dropout, model_type = model_type)
seq2seq_model = Seq2Seq(encoder, decoder, device, attention=True)
optimizer = optim.Adam(seq2seq_model.parameters(), lr = learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [None]:
for epoch_idx in range(EPOCHS):
    print("-----------------------------------")
    print("Epoch %d" % (epoch_idx+1))
    print("-----------------------------------")
    
    train_loss, avg_train_loss = train(seq2seq_model, train_loader, optimizer, criterion, device=device)
    scheduler.step(train_loss)

    val_loss, avg_val_loss = evaluate(seq2seq_model, valid_loader, criterion, device=device)
    
    print("Training Loss: %.4f. Validation Loss: %.4f. " % (avg_train_loss, avg_val_loss))
    print("Training Perplexity: %.4f. Validation Perplexity: %.4f. " % (np.exp(avg_train_loss), np.exp(avg_val_loss)))

-----------------------------------
Epoch 1
-----------------------------------


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  progress_bar = tqdm_notebook(dataloader, ascii=True)


  0%|          | 0/176 [00:00<?, ?it/s]

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  progress_bar = tqdm_notebook(dataloader, ascii=True)


  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 5.4295. Validation Loss: 5.1126. 
Training Perplexity: 228.0442. Validation Perplexity: 166.0984. 
-----------------------------------
Epoch 2
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 4.8098. Validation Loss: 4.6197. 
Training Perplexity: 122.7045. Validation Perplexity: 101.4669. 
-----------------------------------
Epoch 3
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 4.5742. Validation Loss: 4.4078. 
Training Perplexity: 96.9545. Validation Perplexity: 82.0901. 
-----------------------------------
Epoch 4
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 4.3585. Validation Loss: 4.2253. 
Training Perplexity: 78.1406. Validation Perplexity: 68.3959. 
-----------------------------------
Epoch 5
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 4.1974. Validation Loss: 4.1067. 
Training Perplexity: 66.5101. Validation Perplexity: 60.7460. 
-----------------------------------
Epoch 6
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 4.0703. Validation Loss: 3.9895. 
Training Perplexity: 58.5725. Validation Perplexity: 54.0262. 
-----------------------------------
Epoch 7
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.9581. Validation Loss: 3.8878. 
Training Perplexity: 52.3559. Validation Perplexity: 48.8026. 
-----------------------------------
Epoch 8
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.8495. Validation Loss: 3.7923. 
Training Perplexity: 46.9717. Validation Perplexity: 44.3573. 
-----------------------------------
Epoch 9
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.7053. Validation Loss: 3.6293. 
Training Perplexity: 40.6624. Validation Perplexity: 37.6871. 
-----------------------------------
Epoch 10
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.5474. Validation Loss: 3.4959. 
Training Perplexity: 34.7226. Validation Perplexity: 32.9784. 
-----------------------------------
Epoch 11
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.4139. Validation Loss: 3.3653. 
Training Perplexity: 30.3825. Validation Perplexity: 28.9417. 
-----------------------------------
Epoch 12
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.2861. Validation Loss: 3.2694. 
Training Perplexity: 26.7396. Validation Perplexity: 26.2966. 
-----------------------------------
Epoch 13
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.1731. Validation Loss: 3.1699. 
Training Perplexity: 23.8825. Validation Perplexity: 23.8050. 
-----------------------------------
Epoch 14
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.0844. Validation Loss: 3.1109. 
Training Perplexity: 21.8542. Validation Perplexity: 22.4421. 
-----------------------------------
Epoch 15
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.9980. Validation Loss: 3.0580. 
Training Perplexity: 20.0453. Validation Perplexity: 21.2860. 
-----------------------------------
Epoch 16
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.9263. Validation Loss: 3.0199. 
Training Perplexity: 18.6578. Validation Perplexity: 20.4900. 
-----------------------------------
Epoch 17
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.8704. Validation Loss: 2.9707. 
Training Perplexity: 17.6435. Validation Perplexity: 19.5050. 
-----------------------------------
Epoch 18
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.8094. Validation Loss: 2.9314. 
Training Perplexity: 16.5996. Validation Perplexity: 18.7538. 
-----------------------------------
Epoch 19
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.7573. Validation Loss: 2.9118. 
Training Perplexity: 15.7567. Validation Perplexity: 18.3900. 
-----------------------------------
Epoch 20
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.7099. Validation Loss: 2.8679. 
Training Perplexity: 15.0279. Validation Perplexity: 17.6005. 
-----------------------------------
Epoch 21
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.6664. Validation Loss: 2.8510. 
Training Perplexity: 14.3887. Validation Perplexity: 17.3050. 
-----------------------------------
Epoch 22
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.6246. Validation Loss: 2.8285. 
Training Perplexity: 13.7994. Validation Perplexity: 16.9193. 
-----------------------------------
Epoch 23
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.5811. Validation Loss: 2.8048. 
Training Perplexity: 13.2122. Validation Perplexity: 16.5239. 
-----------------------------------
Epoch 24
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.5444. Validation Loss: 2.7832. 
Training Perplexity: 12.7351. Validation Perplexity: 16.1714. 
-----------------------------------
Epoch 25
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.5070. Validation Loss: 2.7748. 
Training Perplexity: 12.2681. Validation Perplexity: 16.0361. 
-----------------------------------
Epoch 26
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.4755. Validation Loss: 2.7348. 
Training Perplexity: 11.8874. Validation Perplexity: 15.4061. 
-----------------------------------
Epoch 27
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.4477. Validation Loss: 2.7304. 
Training Perplexity: 11.5613. Validation Perplexity: 15.3387. 
-----------------------------------
Epoch 28
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.4146. Validation Loss: 2.7156. 
Training Perplexity: 11.1851. Validation Perplexity: 15.1141. 
-----------------------------------
Epoch 29
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.3830. Validation Loss: 2.7100. 
Training Perplexity: 10.8378. Validation Perplexity: 15.0289. 
-----------------------------------
Epoch 30
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.3549. Validation Loss: 2.6873. 
Training Perplexity: 10.5370. Validation Perplexity: 14.6921. 
-----------------------------------
Epoch 31
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.3308. Validation Loss: 2.6646. 
Training Perplexity: 10.2867. Validation Perplexity: 14.3617. 
-----------------------------------
Epoch 32
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.2989. Validation Loss: 2.6597. 
Training Perplexity: 9.9635. Validation Perplexity: 14.2918. 
-----------------------------------
Epoch 33
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.2776. Validation Loss: 2.6463. 
Training Perplexity: 9.7536. Validation Perplexity: 14.1012. 
-----------------------------------
Epoch 34
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.2523. Validation Loss: 2.6332. 
Training Perplexity: 9.5096. Validation Perplexity: 13.9178. 
-----------------------------------
Epoch 35
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.2324. Validation Loss: 2.6356. 
Training Perplexity: 9.3224. Validation Perplexity: 13.9518. 
-----------------------------------
Epoch 36
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.2065. Validation Loss: 2.6300. 
Training Perplexity: 9.0841. Validation Perplexity: 13.8743. 
-----------------------------------
Epoch 37
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.1912. Validation Loss: 2.6180. 
Training Perplexity: 8.9463. Validation Perplexity: 13.7077. 
-----------------------------------
Epoch 38
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.1671. Validation Loss: 2.6209. 
Training Perplexity: 8.7333. Validation Perplexity: 13.7487. 
-----------------------------------
Epoch 39
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.1451. Validation Loss: 2.6023. 
Training Perplexity: 8.5427. Validation Perplexity: 13.4946. 
-----------------------------------
Epoch 40
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.1294. Validation Loss: 2.5962. 
Training Perplexity: 8.4102. Validation Perplexity: 13.4133. 


## Train the transformer model

In [None]:
from Transformer import TransformerTranslator

In [None]:
learning_rate = 1e-4
EPOCHS = 40
hidden_dim=400
num_heads=10
dim_feedforward=2048
dim_k=96
dim_v=96
dim_q=96
max_length=50
#max_length = 43 

trans_model = TransformerTranslator(input_size, output_size, device,num_heads = num_heads, max_length = max_length, hidden_dim = hidden_dim ,dim_feedforward = dim_feedforward).to(device)

# optimizer = optim.Adam(model.parameters(), lr = learning_rate)
optimizer = torch.optim.Adam(trans_model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

In [None]:
for epoch_idx in range(EPOCHS):
    print("-----------------------------------")
    print("Epoch %d" % (epoch_idx+1))
    print("-----------------------------------")
    
    train_loss, avg_train_loss = train(trans_model, train_loader, optimizer, criterion, device=device)
    scheduler.step(train_loss)

    val_loss, avg_val_loss = evaluate(trans_model, valid_loader, criterion, device=device)

    print("Training Loss: %.4f. Validation Loss: %.4f. " % (avg_train_loss, avg_val_loss))
    print("Training Perplexity: %.4f. Validation Perplexity: %.4f. " % (np.exp(avg_train_loss), np.exp(avg_val_loss)))

-----------------------------------
Epoch 1
-----------------------------------


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  progress_bar = tqdm_notebook(dataloader, ascii=True)


  0%|          | 0/176 [00:00<?, ?it/s]

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  progress_bar = tqdm_notebook(dataloader, ascii=True)


  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 4.7045. Validation Loss: 3.7538. 
Training Perplexity: 110.4476. Validation Perplexity: 42.6827. 
-----------------------------------
Epoch 2
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.6512. Validation Loss: 3.5368. 
Training Perplexity: 38.5208. Validation Perplexity: 34.3570. 
-----------------------------------
Epoch 3
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.4625. Validation Loss: 3.4214. 
Training Perplexity: 31.8965. Validation Perplexity: 30.6130. 
-----------------------------------
Epoch 4
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.3343. Validation Loss: 3.3400. 
Training Perplexity: 28.0585. Validation Perplexity: 28.2199. 
-----------------------------------
Epoch 5
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.2322. Validation Loss: 3.2760. 
Training Perplexity: 25.3362. Validation Perplexity: 26.4705. 
-----------------------------------
Epoch 6
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.1455. Validation Loss: 3.2229. 
Training Perplexity: 23.2308. Validation Perplexity: 25.1016. 
-----------------------------------
Epoch 7
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.0693. Validation Loss: 3.1781. 
Training Perplexity: 21.5265. Validation Perplexity: 24.0019. 
-----------------------------------
Epoch 8
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 3.0008. Validation Loss: 3.1390. 
Training Perplexity: 20.1017. Validation Perplexity: 23.0814. 
-----------------------------------
Epoch 9
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.9381. Validation Loss: 3.1045. 
Training Perplexity: 18.8800. Validation Perplexity: 22.2986. 
-----------------------------------
Epoch 10
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.8799. Validation Loss: 3.0743. 
Training Perplexity: 17.8127. Validation Perplexity: 21.6341. 
-----------------------------------
Epoch 11
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.8254. Validation Loss: 3.0475. 
Training Perplexity: 16.8671. Validation Perplexity: 21.0631. 
-----------------------------------
Epoch 12
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.7738. Validation Loss: 3.0234. 
Training Perplexity: 16.0198. Validation Perplexity: 20.5615. 
-----------------------------------
Epoch 13
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.7249. Validation Loss: 3.0015. 
Training Perplexity: 15.2546. Validation Perplexity: 20.1151. 
-----------------------------------
Epoch 14
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.6782. Validation Loss: 2.9814. 
Training Perplexity: 14.5588. Validation Perplexity: 19.7160. 
-----------------------------------
Epoch 15
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.6336. Validation Loss: 2.9633. 
Training Perplexity: 13.9232. Validation Perplexity: 19.3611. 
-----------------------------------
Epoch 16
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.5907. Validation Loss: 2.9470. 
Training Perplexity: 13.3388. Validation Perplexity: 19.0479. 
-----------------------------------
Epoch 17
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.5493. Validation Loss: 2.9313. 
Training Perplexity: 12.7982. Validation Perplexity: 18.7528. 
-----------------------------------
Epoch 18
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.5092. Validation Loss: 2.9152. 
Training Perplexity: 12.2949. Validation Perplexity: 18.4534. 
-----------------------------------
Epoch 19
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.4702. Validation Loss: 2.9010. 
Training Perplexity: 11.8245. Validation Perplexity: 18.1919. 
-----------------------------------
Epoch 20
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.4323. Validation Loss: 2.8886. 
Training Perplexity: 11.3851. Validation Perplexity: 17.9674. 
-----------------------------------
Epoch 21
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.3956. Validation Loss: 2.8775. 
Training Perplexity: 10.9742. Validation Perplexity: 17.7700. 
-----------------------------------
Epoch 22
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.3598. Validation Loss: 2.8676. 
Training Perplexity: 10.5887. Validation Perplexity: 17.5945. 
-----------------------------------
Epoch 23
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.3249. Validation Loss: 2.8585. 
Training Perplexity: 10.2258. Validation Perplexity: 17.4358. 
-----------------------------------
Epoch 24
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.2909. Validation Loss: 2.8506. 
Training Perplexity: 9.8834. Validation Perplexity: 17.2986. 
-----------------------------------
Epoch 25
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.2576. Validation Loss: 2.8435. 
Training Perplexity: 9.5597. Validation Perplexity: 17.1764. 
-----------------------------------
Epoch 26
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.2249. Validation Loss: 2.8371. 
Training Perplexity: 9.2526. Validation Perplexity: 17.0668. 
-----------------------------------
Epoch 27
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.1929. Validation Loss: 2.8315. 
Training Perplexity: 8.9613. Validation Perplexity: 16.9713. 
-----------------------------------
Epoch 28
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.1615. Validation Loss: 2.8266. 
Training Perplexity: 8.6843. Validation Perplexity: 16.8876. 
-----------------------------------
Epoch 29
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.1307. Validation Loss: 2.8223. 
Training Perplexity: 8.4206. Validation Perplexity: 16.8151. 
-----------------------------------
Epoch 30
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.1004. Validation Loss: 2.8185. 
Training Perplexity: 8.1691. Validation Perplexity: 16.7520. 
-----------------------------------
Epoch 31
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.0705. Validation Loss: 2.8153. 
Training Perplexity: 7.9290. Validation Perplexity: 16.6974. 
-----------------------------------
Epoch 32
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.0411. Validation Loss: 2.8124. 
Training Perplexity: 7.6992. Validation Perplexity: 16.6497. 
-----------------------------------
Epoch 33
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 2.0121. Validation Loss: 2.8100. 
Training Perplexity: 7.4794. Validation Perplexity: 16.6105. 
-----------------------------------
Epoch 34
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 1.9836. Validation Loss: 2.8080. 
Training Perplexity: 7.2687. Validation Perplexity: 16.5771. 
-----------------------------------
Epoch 35
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 1.9554. Validation Loss: 2.8064. 
Training Perplexity: 7.0669. Validation Perplexity: 16.5498. 
-----------------------------------
Epoch 36
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 1.9277. Validation Loss: 2.8052. 
Training Perplexity: 6.8735. Validation Perplexity: 16.5300. 
-----------------------------------
Epoch 37
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 1.9004. Validation Loss: 2.8045. 
Training Perplexity: 6.6883. Validation Perplexity: 16.5187. 
-----------------------------------
Epoch 38
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 1.8734. Validation Loss: 2.8042. 
Training Perplexity: 6.5106. Validation Perplexity: 16.5141. 
-----------------------------------
Epoch 39
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 1.8469. Validation Loss: 2.8045. 
Training Perplexity: 6.3401. Validation Perplexity: 16.5185. 
-----------------------------------
Epoch 40
-----------------------------------


  0%|          | 0/176 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Training Loss: 1.8207. Validation Loss: 2.8050. 
Training Perplexity: 6.1764. Validation Perplexity: 16.5273. 


## Get some translations

In [None]:
def translate(model, dataloader):
    model.eval()
    with torch.no_grad():
        # Get the progress bar 
        progress_bar = tqdm(dataloader, ascii = True)
        for batch_idx, data in enumerate(progress_bar):
            source = data[0].transpose(1,0).to(device)
            target = data[1].transpose(1,0).to(device)

            translation = model(source)
            return target, translation

In [None]:
#model = seq2seq_model
model = trans_model

In [None]:
target, translation = translate(model, test_loader)

  0%|          | 0/7 [00:00<?, ?it/s]


In [None]:
raw = np.array([list(map(lambda x: en_vocab.get_itos()[x], target[i])) for i in range(target.shape[0])])

In [None]:
raw[10:19]

array([['<sos>', 'three', 'people', 'sit', 'in', 'a', 'cave', '.', '\n',
        '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',
        '<pad>', '<pad>', '<pad>', '<pad>'],
       ['<sos>', 'a', 'girl', 'in', 'a', 'jean', 'dress', 'is',
        'walking', 'along', 'a', 'raised', 'balance', 'beam', '.', '\n',
        '<eos>', '<pad>', '<pad>', '<pad>'],
       ['<sos>', 'a', 'blond', 'holding', 'hands', 'with', 'a', 'guy',
        'in', 'the', 'sand', '.', '\n', '<eos>', '<pad>', '<pad>',
        '<pad>', '<pad>', '<pad>', '<pad>'],
       ['<sos>', 'the', 'person', 'in', 'the', 'striped', 'shirt', 'is',
        'mountain', 'climbing', '.', '\n', '<eos>', '<pad>', '<pad>',
        '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'],
       ['<sos>', 'two', 'men', 'pretend', 'to', 'be', '<unk>', 'while',
        'women', 'look', 'on', '.', '\n', '<eos>', '<pad>', '<pad>',
        '<pad>', '<pad>', '<pad>', '<pad>'],
       ['<sos>', 'people', 'standing', 'outside', 'of', 'a', 'b

In [None]:
token_trans = np.argmax(translation.cpu().numpy(), axis = 2)
translated = np.array([list(map(lambda x: en_vocab.get_itos()[x], token_trans[i])) for i in range(token_trans.shape[0])])

In [None]:
translated[10:19]

array([['<sos>', 'three', 'people', 'are', 'sitting', 'in', 'a', 'cave',
        '.', '\n', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>',
        '<eos>', '<eos>', '<eos>', '<eos>'],
       ['<sos>', 'a', 'girl', 'in', 'a', 'a', 'dress', 'is', 'is', 'a',
        'on', 'beam', 'expanse', 'a', '\n', '\n', '\n', '<eos>', '<eos>',
        '<eos>'],
       ['<sos>', 'a', 'blond', 'giving', 'the', 'hand', 'a', 'in', 'guy',
        '.', '\n', '\n', '\n', '<eos>', '<eos>', '<eos>', '<eos>',
        '<eos>', '<eos>', '<eos>'],
       ['<sos>', 'the', 'person', 'in', 'a', 'striped', 'is', '.', '.',
        '.', '.', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>', '<eos>',
        '<eos>', '<eos>', '<eos>'],
       ['<sos>', 'two', 'men', 'are', 'pretending', 'look', 'be', 'look',
        'connected', 'are', 'while', 'watch', '.', '.', '<eos>', '<eos>',
        '<eos>', '<eos>', '<eos>', '<eos>'],
       ['<sos>', 'people', 'standing', 'standing', 'in', 'of', 'of',
        'building', '\n', 