In [1]:
import torch
import torch.nn as nn
import math
import numpy as np

import sys
# Make the local `models` code (imported further down) importable. Guarded so that
# re-running this cell does not keep prepending '..' to sys.path.
if '..' not in sys.path:
    sys.path.insert(0, '..')

# Configuration shared by the rest of the notebook.
SEED = 1234  # reproducible initialisation of the freshly built student model
torch.manual_seed(SEED)
np.random.seed(SEED)

# `device` is used below (bert2bert `.to(device)`, the student's `device=device`) but was
# never defined anywhere in the notebook; prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TODO: `plot_loss_curves`, `evaluate` and `eval_bleu` are called below but are neither
# defined nor imported in this notebook — import them here.

### Load data

Here we load the test sets of the Multi30k and WMT14 German→English datasets and build the German and English vocabularies.

In [None]:
from torchtext.datasets import Multi30k, WMT14
from torchtext.data import Field

In [None]:
# Shared preprocessing for both languages: spaCy tokenization, lower-casing and
# <sos>/<eos> markers. Only the spaCy language model differs between the two Fields.
_field_kwargs = dict(tokenize = "spacy",
                     init_token = '<sos>',
                     eos_token = '<eos>',
                     lower = True)

de = Field(tokenizer_language="de_core_news_sm", **_field_kwargs)  # German (source side)
en = Field(tokenizer_language="en_core_web_sm", **_field_kwargs)   # English (target side)

In [None]:
# Only the test splits are evaluated below. Skip loading (and spaCy-tokenizing) the
# train/validation splits that were previously discarded with `_`; this avoids
# processing WMT14's full training corpus. `splits` returns only the splits that are
# not None, hence the 1-tuple unpacking.
(multi30k,) = Multi30k.splits(exts = ('.de', '.en'), fields = (de, en),
                              train = None, validation = None)
# NOTE(review): torchtext's default WMT14 files appear to be BPE-segmented
# (*.bpe.32000); confirm the spaCy-tokenizing Fields above are appropriate for them.
(wmt14,) = WMT14.splits(exts = ('.de', '.en'), fields = (de, en),
                        train = None, validation = None)

In [None]:
# Vocabularies must come from training data, not from the test sets evaluated below.
# `train_data` was never defined in this notebook (NameError on a fresh kernel), so the
# Multi30k training split is loaded explicitly here. This presumably matches how the
# saved models were trained — TODO confirm.
(multi30k_train,) = Multi30k.splits(exts = ('.de', '.en'), fields = (de, en),
                                    validation = None, test = None)
de.build_vocab(multi30k_train, min_freq = 2)
en.build_vocab(multi30k_train, min_freq = 2)
print(len(de.vocab))
print(len(en.vocab))

In [None]:
# Index of the padding token in the English (target) vocabulary; passed as
# `ignore_index` to the loss below so that padded positions don't contribute.
PAD_IDX = en.vocab.stoi[en.pad_token]

### Load models

In [None]:
from models import lstm, seq2seq_attn, transformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
# Import the module under an alias. Assigning the model to `lstm` shadows the
# `models.lstm` module, so re-running the original cell failed on `lstm.make_model`.
from models import lstm as lstm_module

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# load_state_dict then copies the weights into the freshly built model's parameters.
state = torch.load("lstm.pt", map_location="cpu")
params = state["params"]
lstm = lstm_module.make_model(*params["args"], **params["kwargs"])
lstm.load_state_dict(state["state_dict"])
lstm.eval()  # inference only: switch to eval mode (e.g. disables dropout) before scoring
lstm_losses = state["loss"]  # per-split loss history, plotted below

In [None]:
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# load_state_dict then copies the weights into the freshly built model's parameters.
state = torch.load("torch_Seq2Seq.pt", map_location="cpu")
params = state["params"]
gru = seq2seq_attn.make_model(*params["args"], **params["kwargs"])
gru.load_state_dict(state["state_dict"])
gru.eval()  # inference only: switch to eval mode (e.g. disables dropout) before scoring
gru_losses = state["loss"]  # per-split loss history, plotted below

In [None]:
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# load_state_dict then copies the weights into the freshly built model's parameters.
state = torch.load("harvard_transformer2_state.pt", map_location="cpu")
params = state["params"]
trans = transformer.make_model(*params["args"], **params["kwargs"])
trans.load_state_dict(state["state_dict"])
trans.eval()  # inference only: switch to eval mode (e.g. disables dropout) before scoring
trans_losses = state["loss"]  # per-split loss history, plotted below

In [None]:
# Pretrained de→en BERT2BERT checkpoint from the Hugging Face Hub, used as a
# reference model and as the distillation teacher below.
BERT2BERT_CHECKPOINT = "google/bert2bert_L-24_wmt_de_en"

bert2bert = AutoModelForSeq2SeqLM.from_pretrained(BERT2BERT_CHECKPOINT).to(device)
# The tokenizer is loaded with its special-token strings overridden explicitly.
bert2bert_tokenizer = AutoTokenizer.from_pretrained(
    BERT2BERT_CHECKPOINT,
    pad_token="<pad>",
    eos_token="</s>",
    bos_token="<s>",
    unk_token="<unk>",
)

### Training plots

In [None]:
# Training vs. validation loss history stored in the LSTM checkpoint.
plot_loss_curves(
    "LSTM training curve",
    lstm_losses["train"],
    lstm_losses["val"],
)

In [None]:
# Training vs. validation loss history stored in the GRU (seq2seq + attention) checkpoint.
plot_loss_curves(
    "GRU training curve",
    gru_losses["train"],
    gru_losses["val"],
)

In [None]:
# Training vs. validation loss history stored in the Transformer checkpoint.
plot_loss_curves(
    "Transformer training curve",
    trans_losses["train"],
    trans_losses["val"],
)

### Test loss and perplexity

In [None]:
# One loss module shared by all models; padded target positions are ignored.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

lstm_test_loss = evaluate(lstm, multi30k, criterion)
gru_test_loss = evaluate(gru, multi30k, criterion)
trans_test_loss = evaluate(trans, multi30k, criterion)
# NOTE(review): bert2bert uses its own tokenizer (bert2bert_tokenizer), so PAD_IDX from
# the torchtext `en` vocab may not be its padding id. Confirm that `evaluate` handles
# this model correctly before comparing its loss with the others.
bert2bert_test_loss = evaluate(bert2bert, multi30k, criterion)

# Perplexity = exp(mean per-token cross-entropy). Assumes `evaluate` returns the mean
# per-token loss as a scalar — TODO confirm.
{
    name: {"loss": loss, "ppl": math.exp(loss)}
    for name, loss in [
        ("LSTM", lstm_test_loss),
        ("GRU", gru_test_loss),
        ("Transformer", trans_test_loss),
        ("bert2bert", bert2bert_test_loss),
    ]
}

### BLEU scores (Multi30k and WMT14 test sets)

In [None]:
from torchtext.data.metrics import bleu_score

In [None]:
# BLEU of each model on the Multi30k test split, as computed by `eval_bleu`.
bleu_lstm, bleu_gru, bleu_trans, bleu_bert2bert = (
    eval_bleu(candidate, multi30k) for candidate in (lstm, gru, trans, bert2bert)
)

In [None]:
# BLEU of each model on the WMT14 test set. These scores use *_wmt14 names because
# reusing bleu_lstm/bleu_gru/... here silently overwrote the Multi30k scores from the
# previous cell.
bleu_lstm_wmt14 = eval_bleu(lstm, wmt14)
bleu_gru_wmt14 = eval_bleu(gru, wmt14)
bleu_trans_wmt14 = eval_bleu(trans, wmt14)
bleu_bert2bert_wmt14 = eval_bleu(bert2bert, wmt14)

### Sample translations

#### Short

In [None]:
sentence, real_translation = (
    "eine gruppe von menschen steht vor einem iglu .",  # German source (lower-cased, space-separated tokens)
    "a group of people stands in front of an igloo .",  # reference English translation
)

#### Long

In [2]:
sentence, real_translation = (
    # German source (lower-cased, space-separated tokens)
    "ein mann mit kariertem hut in einer schwarzen jacke und einer schwarz-weiß gestreiften hose spielt auf einer bühne mit einem sänger und einem weiteren gitarristen im hintergrund auf einer e-gitarre .",
    # reference English translation
    "a man in a black jacket and checkered hat wearing black and white striped pants plays an electric guitar on a stage with a singer and another guitar player in the background .",
)

#### Attention and word probabilities

# Knowledge distillation

In [None]:
# Knowledge distillation: the pretrained bert2bert model serves as the teacher.
teacher = bert2bert

# Positional hyper-parameters passed straight to seq2seq_attn.make_model.
# NOTE(review): 7854 and 5893 presumably are the de/en vocabulary sizes printed above —
# TODO derive them from len(de.vocab) / len(en.vocab) instead of hard-coding.
student_args = (7854, 5893, 256, 256, 512, 512, 0.5, 0.5, 64)
student = seq2seq_attn.make_model(*student_args, device=device)