In [1]:
!pip install -q jupyter-black neptune


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
# enable cloud logging
enable_neptune = False
embedding_vector_path = "/kaggle/input/l2a-my-model-without-sos/embedding_vector.pt"

if enable_neptune:
    # for kaggle
    from kaggle_secrets import UserSecretsClient
    
    user_secrets = UserSecretsClient()
    secret_value_0 = user_secrets.get_secret("neptune_key")
    project = "riyadhrazzaq/learning-to-ask"

In [3]:
%load_ext jupyter_black
%load_ext autoreload
%autoreload 2

# download data
!git clone https://github.com/xinyadu/nqg/

# download utiliy scripts and add to path
import sys

!git clone https://github.com/riyadhrazzaq/learntoask-paper-replication

sys.path.append("/kaggle/working/learntoask-paper-replication")


from pathlib import Path
import io
import os
from typing import List
from statistics import mean

# these are utility scripts
import datahandler as dh
from tokenization import Tokenizer

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchtext
from torchtext.vocab import build_vocab_from_iterator, GloVe

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu

import neptune

device = "cuda" if torch.cuda.is_available() else "cpu"

config = {"src_max_seq": 40, "tgt_max_seq": 10, "batch_size": 64, "lr": 1.0}

Cloning into 'nqg'...
remote: Enumerating objects: 213, done.[K
^Cceiving objects:  25% (54/213), 18.10 MiB | 2.61 MiB/s
Cloning into 'learntoask-paper-replication'...
remote: Enumerating objects: 79, done.[K
remote: Counting objects: 100% (79/79), done.[K
remote: Compressing objects: 100% (57/57), done.[K
remote: Total 79 (delta 38), reused 63 (delta 22), pack-reused 0[K
Receiving objects: 100% (79/79), 12.42 MiB | 2.55 MiB/s, done.
Resolving deltas: 100% (38/38), done.


  from .autonotebook import tqdm as notebook_tqdm


# Data

In [8]:
# load data and build vocab
data_root = Path("/kaggle/working/nqg/data/processed/")

with open(data_root / "src-train.txt") as f:
    src_train = [line.strip() for line in f]

with open(data_root / "tgt-train.txt") as f:
    tgt_train = [line.strip() for line in f]

with open(data_root / "src-dev.txt") as f:
    src_dev = [line.strip() for line in f]

with open(data_root / "tgt-dev.txt") as f:
    tgt_dev = [line.strip() for line in f]

with open(data_root / "src-test.txt") as f:
    src_test = [line.strip() for line in f]

with open(data_root / "tgt-test.txt") as f:
    tgt_test = [line.strip() for line in f]

print(len(src_train), len(src_dev), len(src_test))

# loads data from source and target files
# keeps top 45K and 28K from the files respectively
# then unifies them into a single `torchtext.vocab.Vocab` object
vocab = dh.load_and_build_vocab(
    data_root / "src-train.txt", data_root / "tgt-train.txt"
)

pad_index = vocab["<PAD>"]
sos_index = vocab["<SOS>"]
eos_index = vocab["<EOS>"]

70484 10570 11877


In [9]:
tokenizer = Tokenizer(vocab, pad_index, sos_index, eos_index)

src_train_tensor, src_train_mask = tokenizer.encode(
    src_train, max_seq=config["src_max_seq"]
)

tgt_train_tensor, tgt_train_mask = tokenizer.encode(
    tgt_train, add_eos=True, max_seq=config["tgt_max_seq"]
)

src_test_tensor, src_test_mask = tokenizer.encode(
    src_test, max_seq=config["src_max_seq"]
)
tgt_test_tensor, tgt_test_mask = tokenizer.encode(
    tgt_test, add_eos=True, max_seq=config["tgt_max_seq"]
)

src_dev_tensor, src_dev_mask = tokenizer.encode(src_dev, max_seq=config["src_max_seq"])
tgt_dev_tensor, tgt_dev_mask = tokenizer.encode(
    tgt_dev, add_eos=True, max_seq=config["tgt_max_seq"]
)


# visualize sample
idx = torch.randint(0, 1000, size=(1,)).item()
print(src_train[idx], "\n")
print(
    tokenizer.decode(src_train_tensor[idx].unsqueeze(dim=0), keep_specials=True), "\n"
)
print(tgt_train[idx], "\n")
print(tokenizer.decode(tgt_train_tensor[idx].unsqueeze(dim=0), keep_specials=True))

print(src_train_tensor.shape, tgt_train_tensor.shape)

the most important non-european donor state was japan -lrb- $ 36 m -rrb- . 

['the most important non-european donor state was japan -lrb- $ 36 m -rrb- . <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>'] 

which non-european donor was most important to the unfpa in 2008 ? 

['which non-european donor was most important to the unfpa <EOS>']
torch.Size([70484, 40]) torch.Size([70484, 10])


In [10]:
class SentenceQuestionDataset(Dataset):
    def __init__(
        self,
        sentences: torch.Tensor,
        questions: torch.Tensor,
        sentences_mask=None,
        questions_mask=None,
    ):
        """
        Represents a dataset of text pairs for training or evaluating models that
        deal with relationships between text passages.

        Args:
            vocab (torchtext.vocab.Vocab): A pre-built vocabulary object
                containing word mappings from text to numerical representation.
            sentences (List[str]): A list of text passages (sentences, paragraphs, etc.).
            questions (List[str]): A list of corresponding questions related to the sentences.
            Ls (int, optional): The maximum length to which sentences will be
                truncated or padded during preprocessing (default: 150).
            Lq (int, optional): The maximum length to which questions will be
                truncated or padded during preprocessing (default: 50).
        """
        self.sentences = sentences
        self.questions = questions
        self.sentences_mask = sentences_mask
        self.questions_mask = questions_mask

    def __len__(self):
        return self.sentences.size(0)

    def __getitem__(self, index):
        return (
            self.sentences[index],
            self.questions[index],
            self.sentences_mask[index],
            self.questions_mask[index],
        )

In [11]:
train_ds = SentenceQuestionDataset(
    src_train_tensor, tgt_train_tensor, src_train_mask, tgt_train_mask
)
test_ds = SentenceQuestionDataset(
    src_test_tensor, tgt_train_tensor, src_test_mask, tgt_test_mask
)
dev_ds = SentenceQuestionDataset(
    src_dev_tensor, tgt_dev_tensor, src_dev_mask, tgt_dev_mask
)

train_dl = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)

# keep shuffling false because pplx calculation uses both dev_dl and tgt_dev
# hence, an order is necessary
dev_dl = DataLoader(dev_ds, batch_size=8, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False)

# Models
Defines attention, encoder, decoder models and finally, a seq2seq model which combines all of them

In [13]:
if Path(embedding_vector_path).is_file():
    embedding_vector = torch.load(embedding_vector_path)
else:
    embedding_vector = torch.zeros(size=(len(vocab), 300))
    glove = GloVe(name="840B", dim=300)
    for index in range(len(vocab)):
        embedding_vector[index] = glove[vocab.lookup_token(index)]
    torch.save(embedding_vector, "embedding_vector.pt")

In [14]:
class Attention(nn.Module):
    def __init__(self, encoder_hidden_size, decoder_hidden_size):
        super().__init__()

        self.projection_layer = nn.Linear(encoder_hidden_size, decoder_hidden_size)

    def forward(self, encoder_output, decoder_output):
        """
        Args:
            encoder_output (torch.Tensor): (N, L, encoder_hidden_size)
            decoder_output (torch.Tensor): (N, 1, decoder_hidden_size)


        Returns:
            score (torch.Tensor): (N, L, 1)
        """
        # => (N, L, decoder_hidden_size)
        projection = self.projection_layer(encoder_output)
        # => (N, L, 1)
        score = projection @ decoder_output.transpose(1, 2)
        score = F.softmax(score, dim=1)

        return score

In [15]:
class Encoder(nn.Module):
    def __init__(
        self,
        embedding,
        embedding_dim,
        hidden_dim=8,
        bidirectional=False,
        num_layers=1,
    ):
        super().__init__()
        self.embedding = embedding
        self.encoder = nn.Sequential(
            self.embedding,
            nn.LSTM(
                input_size=embedding_dim,
                hidden_size=hidden_dim,
                batch_first=True,
                bidirectional=bidirectional,
                num_layers=num_layers,
                dropout=0.3 if num_layers > 1 else 0.0,
            ),
        )

    def forward(self, src: torch.Tensor):
        """
        Args:
            src (torch.Tensor): (N, Ls) A batch of source sentences represented as tensors.
        """
        # encoder_representation (N, Ls, d), hT => cT => (#direction * #layer, N, d) : hidden states from the last timestep
        encoder_out, (last_hidden_state, last_cell_state) = self.encoder(src)
        return encoder_out, last_hidden_state

In [16]:
class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        embedding,
        embedding_dim,
        hidden_dim=8,
        encoder_bidirectional=False,
        num_layers=1,
    ):
        super().__init__()
        self.embedding = embedding
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=False,
            num_layers=num_layers,
            dropout=0.3 if num_layers > 1 else 0.0,
        )

        self.attention = Attention(
            hidden_dim * 2 if encoder_bidirectional else hidden_dim, hidden_dim
        )

        self.decoder_linear = nn.Sequential(
            # 3*hidden_dim because decoder_out and source context will be concatenated
            # this layer is Eq 5 in the Luong et. al. paper
            nn.Linear(
                hidden_dim * 3 if encoder_bidirectional else hidden_dim, hidden_dim
            ),
            nn.Tanh(),
            nn.Linear(hidden_dim, vocab_size),
        )

    def forward(self, encoder_out, target, last_hidden_state, last_cell_state):
        x = self.embedding(target)
        # => N, 1, d
        output, (ht, ct) = self.lstm(x, (last_hidden_state, last_cell_state))
        # => (N, Ls, 1)
        score = self.attention(encoder_out, output)
        # (N, Ls, 1) x (N, Ls, DH) => (N, Ls, DH) => (N, 1, DH)
        # Eq 4 from the Du et. al. paper (learning to ask)
        attn_based_ctx = (score * encoder_out).sum(dim=1).unsqueeze(dim=1)
        # => (N, 1, d ) & (N, 1, DH)
        concatenated = torch.cat((output, attn_based_ctx), dim=2).squeeze()
        # => (N, vocab_size)
        logit = self.decoder_linear(concatenated)

        return logit, (ht, ct)

In [79]:
class Seq2Seq(nn.Module):
    def __init__(
        self,
        vocab_size,
        embedding_vector,
        embedding_dim,
        pad_index,
        sos_index,
        eos_index,
        hidden_dim=8,
        bidirectional=True,
        num_layers=2,
    ):

        super().__init__()
        self.sos_index = sos_index
        self.pad_index = pad_index
        self.eos_index = eos_index

        self.num_layers = num_layers
        if embedding_vector is None:
            self.embedding = nn.Embedding(vocab_size, embedding_dim)
        else:
            self.embedding = nn.Embedding.from_pretrained(embedding_vector)
        self.encoder = Encoder(
            self.embedding,
            embedding_dim,
            hidden_dim,
            bidirectional,
            num_layers,
        )
        self.decoder = Decoder(
            vocab_size,
            self.embedding,
            embedding_dim,
            hidden_dim,
            bidirectional,
            num_layers,
        )

    def forward(self, source, target):
        encoder_out, h = self.encoder(source)
        h = h[: self.num_layers]
        c = torch.randn_like(h, device=device)
        max_seq = target.size(1)
        decoder_input = torch.full(
            (source.size(0), 1), self.sos_index, device=device, dtype=torch.long
        )
        logits = []

        for t in range(max_seq):
            logit, (h, c) = self.decoder(encoder_out, decoder_input, h, c)
            decoder_input = target[:, t].view(-1, 1)
            logits.append(logit)

        return torch.stack(logits, dim=1)

    def generate_batch(self, source, method="greedy", max_seq=15):
        encoder_out, h = self.encoder(source)
        h = h[: self.num_layers]
        c = torch.randn_like(h, device=device)

        decoder_input = torch.full(
            (source.size(0), 1), self.sos_index, device=device, dtype=torch.long
        )
        logits = []
        outputs = torch.full(
            (source.size(0), max_seq), self.pad_index, device=device, dtype=torch.long
        )

        for t in range(max_seq):
            logit, (h, c) = self.decoder(encoder_out, decoder_input, h, c)

            most_probable_tokens = torch.max(logit, dim=1)[1]
            outputs[:, t] = most_probable_tokens
            decoder_input = most_probable_tokens.view(-1, 1)

        return outputs

# Training Utils

In [18]:
def masked_cross_entropy(hypotheses, reference, mask):
    crossEntropy = torch.nn.functional.cross_entropy(
        hypotheses.transpose(1, 2), reference, reduction="none"
    )
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss

In [92]:
def save_checkpoint(model, optimizer, epoch, lr_scheduler, checkpoint_dir):
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    path = Path(checkpoint_dir) / "model_best.pt"

    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        },
        path,
    )


def pplx(logits, tgt, mask):
    loss = masked_loss(logits, tgt, mask)
    return torch.exp(loss)


def validation(model, valid_dl, max_step):
    model.eval()
    loss_across_batches = []
    pplx_across_batches = []

    with torch.no_grad():
        bar = tqdm(valid_dl, leave=False)
        for step_no, batch in enumerate(bar):
            batch = [data.to(device) for data in batch]

            # calculate loss on valid
            loss, logits = step(model, batch)
            loss_across_batches.append(loss.item())

            pplx_across_batches.append(pplx(logits, batch[1], batch[3]).item())

            del batch

            if step_no == max_step:
                break

        bar.close()

    return {"loss": mean(loss_across_batches), "pplx": mean(pplx_across_batches)}


def masked_loss(logits, tgt, mask):
    return masked_cross_entropy(logits, tgt, mask)


def step(model, batch):
    src, tgt, _, tgt_mask = batch

    logits = model(src, tgt)
    loss = masked_loss(logits, tgt, tgt_mask)
    return loss, logits


def fit(
    model: nn.Module,
    optimizer,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    tokenizer,
    config: dict,
    lr_scheduler=None,
    checkpoint_dir="./checkpoint",
    max_step=-1,
    validation_data: List[str] = None,
    experiment_name=None,
):
    if enable_neptune:
        run = neptune.init_run(
            project="riyadhrazzaq/learning-to-ask",
            name=experiment_name,
            api_token=secret_value_0,
        )
        run["parameters"] = config

    best_pplx = float("-inf")

    history = {
        "loss/train": [],
        "pplx/valid": [],
        "loss/valid": [],
        "train/epoch/lr": [],
    }

    for epoch in range(1, config["max_epoch"] + 1):
        model.train()
        loss_across_batches = []
        bar = tqdm(train_dl, unit="batch")

        for step_no, batch in enumerate(bar):
            batch = [data.to(device) for data in batch]
            optimizer.zero_grad()
            loss, logits = step(model, batch)

            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step()

            loss_across_batches.append(loss.item())

            # show each batch loss in tqdm bar
            bar.set_postfix(**{"loss": loss.item()})
            if enable_neptune:
                run["train/batch/loss"].append(loss.item())

            # skip training on the entire training dataset
            if step_no == max_step:
                break

        lr_scheduler.step()
        validation_metrics = validation(model, valid_dl, max_step)

        history["loss/train"].append(mean(loss_across_batches))
        history["loss/valid"].append(validation_metrics["loss"])
        history["pplx/valid"].append(validation_metrics["pplx"])
        history["train/epoch/lr"].append(lr_scheduler.get_last_lr()[0])

        if validation_metrics["pplx"] > best_pplx:
            best_pplx = validation_metrics["pplx"]
            save_checkpoint(model, optimizer, epoch, lr_scheduler, checkpoint_dir)
            print("🎉 best pplx reached, saved a checkpoint :)")

        log(epoch, history, run if enable_neptune else None)

    if enable_neptune:
        run.stop()


def log(epoch, history, run):
    if enable_neptune:
        run["train/epoch/loss"].append(history["loss/train"][-1])
        run["valid/epoch/pplx"].append(history["pplx/valid"][-1])
        run["valid/epoch/loss"].append(history["loss/valid"][-1])
        run["train/epoch/lr"].append(history["train/epoch/lr"][-1])

    print(
        f"Epoch: {epoch},\tTrain Loss: {history['loss/train'][-1]},\tVal Loss: {history['loss/valid'][-1]}\tVal Perplexity: {history['pplx/valid'][-1]}\tLR: {history['train/epoch/lr'][-1]}"
    )

In [75]:
config["lr"] = 1.0
config["embedding_dim"] = 300
config["hidden_dim"] = 600
config["num_layers"] = 2
config["max_epoch"] = 15

net = Seq2Seq(
    len(vocab),
    embedding_vector,
    config["embedding_dim"],
    pad_index,
    sos_index,
    eos_index,
    hidden_dim=config["hidden_dim"],
    bidirectional=True,
    num_layers=config["num_layers"],
)
net = net.to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=config["lr"])
# halve the learning rate once reaching epoch 8 as the original paper
lr_scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
    optimizer, lr_lambda=lambda epoch: 0.5 if epoch > 8 else 1.0
)

fit(
    model=net,
    optimizer=optimizer,
    train_dl=train_dl,
    valid_dl=dev_dl,
    tokenizer=tokenizer,
    config=config,
    lr_scheduler=lr_scheduler,
    validation_data=tgt_dev,
    max_step=4,
    experiment_name="paper-config",
)

  0%|          | 4/1102 [00:00<03:16,  5.58batch/s, loss=10.6]
                                        

Epoch: 1,	Train Loss: 10.721875953674317,	Val Loss: 10.49915771484375	Val Perplexity: 36291.303125	LR: 1.0


  0%|          | 4/1102 [00:00<03:12,  5.71batch/s, loss=9.14]
                                        

Epoch: 2,	Train Loss: 9.952142906188964,	Val Loss: 10.290363693237305	Val Perplexity: 29558.090234375	LR: 1.0


  0%|          | 4/1102 [00:00<03:12,  5.71batch/s, loss=8.77]
                                        

Epoch: 3,	Train Loss: 9.623550033569336,	Val Loss: 8.929525184631348	Val Perplexity: 7665.77412109375	LR: 1.0


  0%|          | 4/1102 [00:00<03:12,  5.71batch/s, loss=10.1]
                                        

Epoch: 4,	Train Loss: 8.899108123779296,	Val Loss: 9.899304962158203	Val Perplexity: 20165.2037109375	LR: 1.0


  0%|          | 4/1102 [00:00<03:12,  5.71batch/s, loss=8.45]
                                        

Epoch: 5,	Train Loss: 9.121878814697265,	Val Loss: 8.22751750946045	Val Perplexity: 3839.845556640625	LR: 1.0


  0%|          | 4/1102 [00:00<03:12,  5.71batch/s, loss=7.77]
                                        

Epoch: 6,	Train Loss: 8.02920036315918,	Val Loss: 7.520333671569825	Val Perplexity: 1999.2869262695312	LR: 1.0


  0%|          | 4/1102 [00:00<03:12,  5.71batch/s, loss=7.7] 
                                        

Epoch: 7,	Train Loss: 7.646542358398437,	Val Loss: 7.635389423370361	Val Perplexity: 2234.5686767578127	LR: 1.0


  0%|          | 4/1102 [00:00<03:12,  5.69batch/s, loss=8.07]
                                        

Epoch: 8,	Train Loss: 7.597735118865967,	Val Loss: 7.397297191619873	Val Perplexity: 1779.5961791992188	LR: 1.0


  0%|          | 4/1102 [00:00<03:15,  5.63batch/s, loss=7.41]
                                        

Epoch: 9,	Train Loss: 7.6055426597595215,	Val Loss: 7.297848033905029	Val Perplexity: 1606.0508666992187	LR: 0.5


  0%|          | 4/1102 [00:00<03:13,  5.67batch/s, loss=7.17]
                                        

Epoch: 10,	Train Loss: 7.3084314346313475,	Val Loss: 7.190912246704102	Val Perplexity: 1438.8784912109375	LR: 0.25


  0%|          | 4/1102 [00:00<03:13,  5.68batch/s, loss=6.97]
                                        

Epoch: 11,	Train Loss: 7.059575176239013,	Val Loss: 7.1637575149536135	Val Perplexity: 1407.5422485351562	LR: 0.125


  0%|          | 4/1102 [00:00<03:12,  5.71batch/s, loss=7.2] 
                                        

Epoch: 12,	Train Loss: 7.211551380157471,	Val Loss: 7.158767986297607	Val Perplexity: 1395.7316040039063	LR: 0.0625


  0%|          | 4/1102 [00:00<03:13,  5.67batch/s, loss=7.15]
                                        

Epoch: 13,	Train Loss: 7.138396167755127,	Val Loss: 7.15666446685791	Val Perplexity: 1394.2253295898438	LR: 0.03125


  0%|          | 4/1102 [00:00<03:20,  5.48batch/s, loss=6.99]
                                        

Epoch: 14,	Train Loss: 7.067730140686035,	Val Loss: 7.15096378326416	Val Perplexity: 1385.732958984375	LR: 0.015625


  0%|          | 4/1102 [00:00<03:13,  5.68batch/s, loss=6.97]
                                        

Epoch: 15,	Train Loss: 7.131023311614991,	Val Loss: 7.153767871856689	Val Perplexity: 1390.0201049804687	LR: 0.0078125




In [89]:
def bleu(token_ids, references):
    hypotheses = tokenizer.decode(token_ids, keep_specials=False)
    return corpus_bleu(references, hypotheses)


def evaluate(model, source: List[str] = None, src_dl=None, target: List[str] = None):
    assert source or src_dl, "provide at least one of source or src_dl"
    batch_size = 8
    if src_dl is None:
        src_tensors = tokenizer.encoder(text, max_seq=config["src_max_seq"])
        src_ds = TensorDataset(src_tensors)
        src_dl = DataLoader(src_ds, shuffle=False, batch_size=batch_size)

    outputs = []
    bleus = []
    i = 0
    for src_tensor in tqdm(src_dl):
        # when src_dl is built on SentenceQuestionDataset
        if isinstance(src_tensor, list):
            src_tensor = src_tensor[0]
        if target:
            batch_target = target[i : i + batch_size]

        src_tensor = src_tensor.to(device)

        logits = model.generate_batch(src_tensor)

        decoded_target = tokenizer.decode(logits, keep_specials=False)
        outputs.extend(decoded_target)

        if target:
            bleus.append(corpus_bleu(decoded_target, batch_target))
        i += batch_size

    return outputs, mean(bleus)


net_test = Seq2Seq(
    len(vocab),
    embedding_vector,
    config["embedding_dim"],
    pad_index,
    sos_index,
    eos_index,
    hidden_dim=config["hidden_dim"],
    bidirectional=True,
    num_layers=config["num_layers"],
)

net_test = net_test.to(device)
evaluate(net_test, src_dl=dev_dl, target=tgt_dev)

Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
100%|██████████| 1322/1322 [02:33<00:00,  8.63it/s]


(['andreou levantine levantine nishan-e-pakistan dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom',
  'dictated dictated folk loosely loosely dalcom dalcom isle isle isle isle isle isle isle isle',
  'procopius mockingbirds hamadhan hamadhan carrillo carrillo isle isle isle isle isle isle isle isle isle',
  '1080i50 1080i50 1080i50 dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom',
  'legislatures legislatures legislatures lammy lammy atristain atristain dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom',
  'stamp nixon carrillo carrillo epidemic epidemic isle isle isle isle isle isle isle isle isle',
  'franco franco franco sadat dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom',
  'belousava talmud contaminant contaminant causa dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom dalcom',
  'academics academics honoured prize-winning mockingbirds mockingbirds dalcom dalcom dalco

In [90]:
def load_checkpoint(model, checkpoint_path, optimizer=None, lr_scheduler=None):
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint["model_state_dict"])

        if optimizer:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if lr_scheduler:
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])

        print(f"loaded existing model. epoch: {checkpoint['epoch']}")

    else:
        raise Exception("No checkpoint found in the provided path")


load_checkpoint(
    net_test, "/kaggle/input/l2a-my-model-without-sos/checkpoint/model_best.pt"
)

loaded existing model. epoch: 1


In [93]:
questions_generated, bleu = evaluate(net_test, src_dl=dev_dl, target=tgt_dev)

Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
100%|██████████| 1322/1322 [01:12<00:00, 18.18it/s]


In [95]:
for i in range(5):
    idx = torch.randint(0, len(tgt_dev), (1,)).item()
    print(src_dev[idx], "\n")
    print(tgt_dev[idx], "\n")
    print(questions_generated[idx], "\n")
    print()

the scottish parliament has the power to pass laws and has limited tax-varying capability . 

who has the role of holding the scottish government to account ? 

what many the name of the first of the 


despite denver 's excellent field position , they could not get the ball into the end zone , so mcmanus kicked a 33-yard field goal that increased their lead to 13 -- 7 . 

how many yards was the field goal that made the score 13-7 in super bowl 50 ? 

what did the name of the first of the 


the congregation was founded in 1767 , meeting initially in a sail loft on dock street , and in 1769 it purchased the shell of a building which had been erected in 1763 by a german reformed congregation . 

where did the congregation at st. george 's initially meet in 1767 ? 

what was the name of the first of the 


this teaching by luther was clearly expressed in his 1525 publication on the bondage of the will , which was written in response to on free will by desiderius erasmus -lrb- 1524 -rrb- 

In [97]:
def beam_search(
    model: torch.nn.Module,
    vocab: torchtext.vocab.Vocab,
    prompt: torch.Tensor,
    start_token: str,
    end_token: str,
    beam_width: int,
    max_len: int,
):
    """
    A PyTorch based Beam Search Implementation

    Parameters
    ----------
    model: torch.nn.Module
        model to generate text
    vocab: torchtext.vocab.Vocab
        vocabulary.
    prompt: torch.Tensor, shape (1, prompt_len)
        prompt to start with
    start_token: str
        start token
    end_token: str
        end token. will terminate when this token is generated
    beam_width: int
        beam width
    max_len: int
        maximum length of the generated text

    Returns
    -------
    text: List[str]
        generated text
    prob: float
        probability of the generated text
    """
    model.to(device=device)
    # setup
    k = beam_width
    start_index = vocab[start_token]
    end_index = vocab[end_token]
    # at first, begin with k possible beams starting with `start_token`
    prefix = [[start_index] for _ in range(k)]
    # move prompt to device if not already
    prompt = prompt.to(device)
    if prompt.dim() == 1:
        prompt = prompt[:, None]
    prompt = prompt.tile((k, 1))
    vocab_size = len(vocab)
    # initial probability of the `start_token` is 1
    probs = torch.ones((k,), device=device, dtype=torch.float32)

    with torch.no_grad():
        for t in range(1, max_len + 1):
            # for t = 1, we will only examine 1 * vocab_size probabilities
            x = torch.tensor([prefix[0]] if t == 1 else prefix, device=device)
            output = torch.softmax(
                model(prompt if t > 1 else prompt[0, None], x)[0], dim=2
            )
            # only interested in the last step's prediction
            output = output[:, -1, :]  # (k or 1, vocab_size)

            # uses broadcasting to perform the following loop
            # ```
            # for i in range(k):
            #         output[i] = output[i] * probs[i] probs[i] is scalar
            # ```
            if t > 1:
                output = output * probs[:, None]

            # when flattened,
            # argmax % vocab_size will give the candidate token index and
            # argmax // vocab_size will give the prefix index,
            # i.e. prefix[argmax // vocab_size] + [argmax % vocab_size] is one
            # of the prefixes inside the beam
            next_token = torch.topk(output.flatten(), k)

            new_prefixes = []

            # iterate through the top k candidates. update the existing prefix and calculate the probability
            for i, (token_prob, token_index) in enumerate(
                zip(next_token.values, next_token.indices)
            ):
                # _k is the index in `prefix`, which gave the ith candidate token
                _k = token_index.item() // vocab_size
                # copying that prefix to append the new token
                previous_prefix = prefix[_k].copy()
                new_token_index = token_index.item() % vocab_size

                previous_prefix.append(new_token_index)
                new_prefixes.append(previous_prefix)
                # update the probability of the new prefix
                probs[i] = token_prob.item()

                # check if end token is generated
                if new_token_index == end_index:
                    return vocab.lookup_tokens(previous_prefix), token_prob.item()

            prefix = new_prefixes

    # prefix is always sorted because `new_token` is sorted when defined
    return vocab.lookup_tokens(prefix[0]), probs[0]

In [105]:
text = [
    "the scottish parliament has the power to pass laws and has limited tax-varying capability . "
]
src_token_ids = tokenizer.encode(text, max_seq=40)[0]
print(src_token_ids.shape)
output, prob = beam_search(net_test,
                           vocab, prompt=src_token_ids, 
                           start_token=vocab.lookup_token(sos_index),
                           end_token=vocab.lookup_token(eos_index),
                           beam_width=3, max_len=15)
print(prob, output)

torch.Size([1, 40])


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 2)