In [1]:
import os
import sys
sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname(f'..{os.sep}utils'))))
sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname( '..'))))
from utils.constants import *
import torch
import torch.nn as nn
from transformer_v2 import Transformer
from utils.function_utils import *
from func_load_model_old import *
from utils.optimizer_n_scheduler import *
from utils.logging_tensorboard import create_summary_writer, log_loss, log_learning_rate, log_gradients, log_attention_weights
from utils.distributions import *
from torch.cuda.amp import GradScaler, autocast
from data_funcs import *
import numpy as np
import heapq

In [2]:
BOS_TOKEN = "<bos>"
EOS_TOKEN = "<eos>"

## Decoders

In [3]:
def subsequent_mask(size):
    """
    Mask out subsequent positions.
    """
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0

In [34]:
def batch_greedy_decode_v2(model, src, src_mask, max_len, tgt_tokenizer, tgt_vocab, tgt_pad_idx, batch_size, device):
    target_sentences_tokens = [[BOS_TOKEN] for _ in range(batch_size)]
    trg_token_ids_batch = torch.tensor([[tgt_vocab[tokens[0]]] for tokens in target_sentences_tokens], device=device)
    is_decoded = [False] * batch_size

    while True:
        tgt_mask = generate_tgt_mask(trg_token_ids_batch,tgt_pad_idx)
        enc_output, _ = model.encode(src, src_mask)
        predicted_log_distributions, _, _ = model.decode(trg_token_ids_batch, enc_output, src_mask, tgt_mask)

        num_of_trg_tokens = len(target_sentences_tokens[0])
        predicted_log_distributions= predicted_log_distributions[num_of_trg_tokens-1::num_of_trg_tokens]
        
        most_probable_last_token_indices = torch.argmax(predicted_log_distributions, dim=-1).cpu()

        print(most_probable_last_token_indices)
        predicted_words = [tgt_vocab[index] for index in most_probable_last_token_indices]

        for idx, predicted_word in enumerate(predicted_words):
            target_sentences_tokens[idx].append(predicted_word)

            if predicted_word == EOS_TOKEN:  # once we find EOS token for a particular sentence we flag it
                is_decoded[idx] = True

        if all(is_decoded) or num_of_trg_tokens == max_len:
            break
        
        trg_token_ids_batch = torch.cat((trg_token_ids_batch, torch.unsqueeze(torch.tensor(most_probable_last_token_indices, device=device), 1)), 1)

        target_sentences_tokens_post = []
        for target_sentence_tokens in target_sentences_tokens:
            try:
                target_index = target_sentence_tokens.index(EOS_TOKEN) + 1
            except:
                target_index = None

            target_sentence_tokens = target_sentence_tokens[:target_index]
            target_sentences_tokens_post.append(target_sentence_tokens)

        return target_sentences_tokens_post

In [16]:
def batch_greedy_decode_v1(model, src, src_mask, max_len, tgt_tokenizer, tgt_vocab, tgt_pad_idx, batch_size):
    with torch.no_grad():
        sos_token = tgt_vocab[BOS_TOKEN]
        tgt_tokens = torch.full((batch_size, 1), sos_token, dtype=torch.long, device=src.device)

        for _ in range(max_len - 1):
            tgt_mask = generate_tgt_mask(tgt_tokens, tgt_pad_idx)
            output_probs = model(src, tgt_tokens, src_mask, tgt_mask)
            _, next_tokens = torch.max(output_probs, dim=-1)
            tgt_tokens = torch.cat([tgt_tokens, next_tokens], dim=1)
            if torch.all(next_tokens == tgt_vocab[EOS_TOKEN]):
                break

    # decoded_sentences = [tgt_tokenizer.decode(tokens) for tokens in tgt_tokens.tolist()]
    return tgt_tokens.tolist()

In [14]:
def batch_beam_search_decode(model, src, src_mask, max_len, src_tokenizer, tgt_tokenizer, beam_size, batch_size):
    with torch.no_grad():
        memory, _ = model.encode(src, src_mask)
        start_token = tgt_tokenizer.encode("<sos>")[0]
        end_token = tgt_tokenizer.encode("<eos>")[0]

        decoded_batches = []

        for b in range(batch_size):
            decoder_input = torch.tensor([start_token]).unsqueeze(0).to(src.device)
            partial_sequences = [([], 0, decoder_input)]

            for _ in range(max_len):
                new_partial_sequences = []

                for seq, score, tokens in partial_sequences:
                    if tokens[0, -1] == end_token:
                        new_partial_sequences.append((seq, score, tokens))
                        continue

                    tgt_mask = (subsequent_mask(tokens.size(-1)).type_as(src_mask)).repeat(1, 1, 1)
                    output_probs, _, _, _ = model(src[b].unsqueeze(0), tokens, src_mask[b].unsqueeze(0), tgt_mask, return_attention=True)
                    top_probs, top_indices = torch.topk(F.log_softmax(output_probs[:, -1], dim=-1), beam_size)

                    for i in range(beam_size):
                        next_token = top_indices[0, i].unsqueeze(0).unsqueeze(0)
                        current_prob = top_probs[0, i].item()
                        new_tokens = torch.cat([tokens, next_token], dim=1)
                        new_partial_sequences.append((seq + [next_token], score + current_prob, new_tokens))

                partial_sequences = heapq.nlargest(beam_size, new_partial_sequences, key=lambda x: x[1])

            best_sequence = max(partial_sequences, key=lambda x: x[1])[0]
            decoded_sentence = tgt_tokenizer.decode([token.item() for token in best_sequence])
            decoded_batches.append(decoded_sentence)

    return decoded_batches

## Testing

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 16
max_len = MODEL_MAX_SEQ_LEN
d_model = MODEL_DIM
num_layers = MODEL_N_LAYERS
num_heads = MODEL_N_HEADS
dropout = MODEL_DROPOUT
num_epochs = 10
learning_rate = 1e-4
warmup_steps = 2000
weight_decay = 1e-4
VOCAB_SIZE = 64_000
d_ff = MODEL_FF
label_smoothing = MODEL_LABEL_SMOTHING
FILE_PATH = 'data/en-pt.txt'
NUM_PHRASES = 10_000

n=1
LOGGING_FILE = f'runs{os.sep}translation_experiment_{n}'

CHECKPOINT_PATH = 'checkpoints/checkpoint_epoch_8_val_loss_4.7865.pt'

In [7]:
train_dataloader, test_dataloader, pad_idx_src, pad_idx_tgt, src_vocab, tgt_vocab, src_tokenizer, tgt_tokenizer = load_data(FILE_PATH, language_direction = LanguageDirection.PT2EN.name, limit = NUM_PHRASES, batch_size = batch_size, max_len = max_len, return_tokenizers=True)

In [8]:
model = Transformer(len(src_vocab),
                    len(tgt_vocab), 
                    d_model, 
                    num_heads, 
                    num_layers, 
                    d_ff, 
                    dropout, 
                    max_len).to(device)
optimizer, scheduler = create_optimizer_and_scheduler(model, d_model, warmup_steps, learning_rate, weight_decay)

In [9]:
model, optimizer, scheduler, epoch = load_checkpoint(model, optimizer, scheduler, CHECKPOINT_PATH)

In [22]:
def evaluate_metrics(model, data_loader, src_pad_idx, tgt_pad_idx, tokenizer, tgt_vocab, device):
    model.eval()
    bleu_scores = []
    with torch.no_grad():
        for src, tgt in data_loader:
            src, tgt = src.to(device), tgt.to(device)
            src_mask= generate_src_mask(src,src_pad_idx)
            output = batch_greedy_decode_v2(model, src, src_mask, max_len, tokenizer, tgt_vocab, tgt_pad_idx, batch_size, device)
            return output
            hypothesis = [tgt_vocab[token] for token in output if token not in (tgt_pad_idx, tgt_vocab[BOS_TOKEN], tgt_vocab[EOS_TOKEN])]
            reference = [tgt_vocab[token] for token in tgt if token not in (tgt_pad_idx, tgt_vocab[BOS_TOKEN], tgt_vocab[EOS_TOKEN])]
            bleu = calculate_bleu(reference, hypothesis)
            bleu_scores.append(bleu)
    
    model.train()
    return sum(bleu_scores) / len(bleu_scores)

In [35]:
greedy_decode_output = evaluate_metrics(model, test_dataloader, pad_idx_src, pad_idx_tgt, tgt_tokenizer, tgt_vocab, device)

tensor([[23],
        [23],
        [23],
        [23],
        [23],
        [23],
        [23],
        [23],
        [23],
        [23],
        [23],
        [23],
        [23],
        [23],
        [23],
        [23]])


TypeError: __getitem__(): incompatible function arguments. The following argument types are supported:
    1. (self: torchtext._torchtext.Vocab, arg0: str) -> int

Invoked with: <torchtext._torchtext.Vocab object at 0x000001AAF64938F0>, tensor([23])