In [1]:
import random
%reload_ext autoreload
%autoreload 2
# %load_ext jupyter_black

import sys

sys.path.append("../src")
sys.path.append("..")

from src.trainutil import *

import yaml

data_root = "../data"
device = "cuda" if torch.cuda.is_available() else "cpu"



In [2]:
experiment_dir = Path("../checkpoints/paper")
with open(experiment_dir / "history/config.yaml", "r") as stream:
    cfg = yaml.safe_load(stream)

In [3]:
src_tokenizer = torch.load(experiment_dir / "src_tokenizer.pt")
tgt_tokenizer = torch.load(experiment_dir / "tgt_tokenizer.pt")
src_vocab = src_tokenizer.vocab
tgt_vocab = tgt_tokenizer.vocab

In [4]:
model = Seq2Seq(
    src_vocab_size=len(src_vocab),
    tgt_vocab_size=len(tgt_vocab),
    src_embedding_vector=None,
    tgt_embedding_vector=None,
    tgt_pad_index=tgt_vocab["<PAD>"],
    tgt_sos_index=tgt_vocab["<SOS>"],
    tgt_eos_index=tgt_vocab["<EOS>"],
    hidden_size=cfg["hidden_size"],
    bidirectional=cfg["bidirectional"],
    num_layers=cfg["num_layers"],
    src_embedding_size=cfg["src_embedding_size"],
    tgt_embedding_size=cfg["tgt_embedding_size"],
    dropout=cfg["dropout"],
)
model.to(device)

Seq2Seq(
  (src_embedding): Embedding(45000, 300)
  (tgt_embedding): Embedding(28000, 300)
  (encoder): Encoder(
    (embedding): Embedding(45000, 300)
    (layers): Sequential(
      (0): Embedding(45000, 300)
      (1): LSTM(300, 600, num_layers=2, batch_first=True, dropout=0.3, bidirectional=True)
    )
  )
  (decoder): Decoder(
    (embedding): Embedding(28000, 300)
    (lstm): LSTM(300, 600, num_layers=2, batch_first=True, dropout=0.3)
    (attention): Attention(
      (projection_layer): Linear(in_features=1200, out_features=600, bias=True)
    )
    (decoder_linear): Sequential(
      (0): Linear(in_features=1800, out_features=600, bias=True)
      (1): Tanh()
      (2): Linear(in_features=600, out_features=28000, bias=True)
    )
  )
)

In [5]:
model, _, _, epoch = load_checkpoint(model, experiment_dir / "model_best.pt")

2024-05-06 16:30:16,667 🎉 Loaded existing model. Epoch: 13


In [71]:
with open(f"{data_root}/dev.src") as srcfile:
    sources = srcfile.readlines()

with open(f"{data_root}/dev.tgt") as tgtfile:
    references = tgtfile.readlines()

In [91]:
def beam_search(module: nn.Module, src_token2index, trg_vocab, trg_edge_index, src_tokens, src_mask, max_len,
                  beam_size):
    """
    Args:
        model: 
        src_token2index: dict
        src_pad_index:  int
        trg_vocab: list
        trg_edge_index: int 
        src_tokens: ['I', 'hate', 'it', '.']
        max_len: int
        beam_size: int

    Returns:

    """
    # (1, src_seq)
    src_indexed = torch.tensor(
        [[src_token2index[token] for token in src_tokens]],
        dtype=torch.int64, device=device
    )

    # (1,1)
    # beam_prefixes_indexed = torch.tensor([[trg_edge_index]], dtype=torch.int64, device=device)
    beam_prefixes_indexed = torch.full(
        (1, 1), module.tgt_sos_index, device=device, dtype=torch.long
    )
    # (1, )
    beam_prefixes_probs = np.array([1.0], np.float32)

    best_full_prefix_indexed = None
    best_full_prefix_prob = None

    with torch.no_grad():
        encoder_out, h = module.encoder(src_indexed)
        h = h[: module.num_layers]
        c = torch.zeros_like(h, device=device)
        
        for t in range(max_len):
            # print(encoder_out.shape, h.shape, c.shape, beam_prefixes_indexed.shape)
            if t == 1:
                h = h.tile(1, beam_prefixes_indexed.size(0), 1)
                c = c.tile(1, beam_prefixes_indexed.size(0), 1)
            output, (h, c), _ = module.decoder(
                encoder_out, beam_prefixes_indexed[:, -1].view(-1, 1), h, c, src_mask
            )

            # (beam_size, tgt_vocab_size)
            token_probs = output.squeeze(-1)
            # (beam_size, tgt_vocab)  = (1, 1) * (beam_size, tgt_vocab)
            new_prefixes_probs = beam_prefixes_probs[:, None] * token_probs.cpu().numpy()
            new_partial_prefixes = []
            for (prefix, probs_group) in zip(beam_prefixes_indexed.cpu().tolist(), new_prefixes_probs.tolist()):
                # single token_id, token_prob
                for (next_token_index, prefix_prob) in enumerate(probs_group):
                    if next_token_index == trg_edge_index:
                        if best_full_prefix_prob is None or prefix_prob > best_full_prefix_prob:
                            best_full_prefix_indexed = prefix + [next_token_index]
                            best_full_prefix_prob = prefix_prob
                    else:
                        new_partial_prefixes.append((prefix_prob, prefix + [next_token_index]))

            new_partial_prefixes.sort(reverse=True)
            (best_partial_prefix_prob, _) = new_partial_prefixes[0]
            if best_full_prefix_prob > best_partial_prefix_prob:
                text = [trg_vocab[index] for index in best_full_prefix_indexed]
                return (text, best_full_prefix_prob)

            new_beam = new_partial_prefixes[:beam_size]
            beam_prefixes_indexed = torch.tensor([prefix for (prob, prefix) in new_beam], dtype=torch.int64,
                                                 device=device)
            beam_prefixes_probs = np.array([prob for (prob, prefix) in new_beam], np.float32)
            

    text = [trg_vocab[index] for index in beam_prefixes_indexed[0, :].cpu().tolist()]
    return text, beam_prefixes_probs[0]

In [61]:
print(sources[0], references[0], sep='\n')

the american football conference -lrb- afc -rrb- champion denver broncos defeated the national football conference -lrb- nfc -rrb- champion carolina panthers 24 -- 10 to earn their third super bowl title . 

which nfl team represented the afc at super bowl 50 ?


In [83]:
src_token_ids, src_mask = src_tokenizer.encode(
        sources[0], max_seq=cfg["src_max_seq"]
    )
src_token_ids = src_token_ids.to(device)
src_mask = src_mask.to(device)
beam_search(module=model,
              src_token2index=tgt_vocab,
              trg_vocab=tgt_vocab.get_itos(),
              trg_edge_index=tgt_vocab['<EOS>'],
              src_tokens=sources[0].split(),
              src_mask=src_mask,
              max_len=14,
              beam_size=3)

(['<SOS>',
  'what',
  'type',
  'of',
  'the',
  'church',
  'want',
  'to',
  '<UNK>',
  'the',
  '<UNK>',
  '<UNK>',
  '<EOS>'],
 8593216634880.0)

In [127]:
idx = random.randint(0, 10000)
# nucleus
hyp_nucleus, _ = generate(model, sources[idx: idx+2], src_tokenizer, tgt_tokenizer, cfg, method="nucleus", p=0.6)

print(sources[idx], references[idx], sep='\n\n')
# print(f"nucleus: {hyp_nucleus[0]}\n\nbeam: {hyp_beam}\n\ngreedy: {hyp_greedy}\nbeam marc: {hyp_beam_marc}")
print(hyp_nucleus)

following the election of the uk labour party to government in 1997 , the uk formally subscribed to the agreement on social policy , which allowed it to be included with minor amendments as the social chapter of the 1997 treaty of amsterdam . 


when did the uk formally subscribe to the agreement on social policy ?

["who rejected the fbi 's law to the protest ?", 'what government party was the european parliament under the soviet agreement']
