In [1]:
import os
import sys
sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname(f'..{os.sep}utils'))))
sys.path.insert(1, os.path.abspath(os.path.join(os.path.dirname( '..'))))
from utils.constants import *
import torch
import torch.nn as nn
from transformer_v2 import Transformer
from utils.function_utils import *
from func_load_model_old import *
from utils.optimizer_n_scheduler import *
from utils.logging_tensorboard import create_summary_writer, log_loss, log_learning_rate, log_gradients, log_attention_weights
from utils.distributions import *
from torch.cuda.amp import GradScaler, autocast
from data_funcs import *
import numpy as np
import heapq

In [2]:
# Special marker strings used below as the first / last token of a decoded target sentence.
BOS_TOKEN, EOS_TOKEN = "<bos>", "<eos>"

## Decoders

In [3]:
def subsequent_mask(size):
    """
    Build a causal ("no peeking ahead") attention mask.

    Args:
        size: Sequence length.

    Returns:
        A boolean tensor of shape (1, size, size) in which entry [0, i, j] is
        True when position i may attend to position j (j <= i) and False for
        every later position j > i.
    """
    positions = torch.arange(size)
    # Row index = querying position, column index = attended position.
    allowed = positions.unsqueeze(0) <= positions.unsqueeze(1)
    return allowed.unsqueeze(0)

In [21]:
def batch_greedy_decode_v2(model, src, src_mask, max_len, tgt_tokenizer, tgt_vocab, tgt_pad_idx, batch_size, device):
    """
    Greedily decode a batch of source sentences into lists of target token strings.

    Every sentence starts as [BOS_TOKEN]; at each step the highest-scoring next
    token is appended. A sentence stops growing once it emits EOS_TOKEN, and the
    whole loop stops when every sentence has finished or the decoder input has
    reached ``max_len`` tokens.

    Args:
        model: Object exposing ``encode(src, src_mask)`` (returns a 2-tuple whose
            first item is the encoder output) and
            ``decode(tgt_ids, enc_output, src_mask, tgt_mask)`` (returns a 3-tuple
            whose first item holds the per-token scores).
        src: Source token ids, batch on dimension 0.
        src_mask: Source mask, batch on dimension 0.
        max_len: Maximum number of tokens fed to the decoder.
        tgt_tokenizer: Unused; kept for interface compatibility.
        tgt_vocab: Vocabulary supporting ``vocab[token]`` and ``get_itos()``.
        tgt_pad_idx: Padding index forwarded to ``generate_tgt_mask``.
        batch_size: Ignored; the size of ``src`` along dimension 0 is used instead.
        device: Device on which the target id tensor is created.

    Returns:
        list[list[str]]: one token list per input sentence, in input order, each
        starting with BOS_TOKEN and ending with EOS_TOKEN when the model emitted it.

    Raises:
        RuntimeError: if the decoder yields a number of next-token predictions
            different from the number of sentences still being decoded.
    """
    # The batch that was actually passed in is authoritative, not the argument.
    batch_size = src.shape[0]
    target_sentences_tokens = [[BOS_TOKEN] for _ in range(batch_size)]
    if batch_size == 0:
        return target_sentences_tokens

    tgt_itos = tgt_vocab.get_itos()
    trg_token_ids_batch = torch.full((batch_size, 1), tgt_vocab[BOS_TOKEN], dtype=torch.long, device=device)

    # active_sentences[row] = position in the ORIGINAL batch of the sentence that
    # currently lives in `row` of the (shrinking) working tensors. Keeping this
    # mapping explicit is what prevents tokens being appended to the wrong sentence.
    active_sentences = list(range(batch_size))

    with torch.no_grad():
        # The source never changes during decoding, so encode it once and only
        # drop rows from the cached encoder output as sentences finish.
        enc_output, _ = model.encode(src, src_mask)

        while True:
            tgt_mask = generate_tgt_mask(trg_token_ids_batch, tgt_pad_idx)
            predicted_log_distributions, _, _ = model.decode(trg_token_ids_batch, enc_output, src_mask, tgt_mask)

            # Read the length from the tensor itself; the token list of sentence 0
            # stops growing as soon as that sentence finishes.
            num_of_trg_tokens = trg_token_ids_batch.shape[1]

            if predicted_log_distributions.dim() == 3:
                # NOTE(review): assumes a (batch, seq, vocab) layout - confirm against Transformer.decode.
                last_step = predicted_log_distributions[:, -1, :]
            else:
                # Flattened (batch * seq, vocab) layout: every num_of_trg_tokens-th row,
                # starting at the last position of the first sentence.
                last_step = predicted_log_distributions[num_of_trg_tokens - 1::num_of_trg_tokens]

            next_token_ids = torch.argmax(last_step, dim=-1).reshape(-1)
            if next_token_ids.numel() != len(active_sentences):
                raise RuntimeError(
                    f"decoder produced {next_token_ids.numel()} next-token predictions "
                    f"for {len(active_sentences)} active sentences"
                )

            still_active_rows = []
            for row, (sentence_idx, token_id) in enumerate(zip(active_sentences, next_token_ids.tolist())):
                predicted_word = tgt_itos[token_id]
                target_sentences_tokens[sentence_idx].append(predicted_word)
                if predicted_word != EOS_TOKEN:
                    still_active_rows.append(row)

            if not still_active_rows or num_of_trg_tokens >= max_len:
                break

            # Extend the decoder input with the tokens just predicted ...
            next_column = next_token_ids.to(trg_token_ids_batch.device).unsqueeze(1)
            trg_token_ids_batch = torch.cat((trg_token_ids_batch, next_column), 1)

            # ... and drop the rows of sentences that have just emitted EOS.
            if len(still_active_rows) < len(active_sentences):
                trg_token_ids_batch = trg_token_ids_batch[still_active_rows]
                enc_output = enc_output[still_active_rows]
                src_mask = src_mask[still_active_rows]
                active_sentences = [active_sentences[row] for row in still_active_rows]

    return target_sentences_tokens

In [5]:
def batch_greedy_decode_v1(model, src, src_mask, max_len, tgt_tokenizer, tgt_vocab, tgt_pad_idx, batch_size):
    """
    Greedily decode a batch by repeatedly calling the full model.

    Args:
        model: Callable as ``model(src, tgt_tokens, src_mask, tgt_mask)`` returning
            a tensor of per-token scores.
        src: Source token ids, batch on dimension 0.
        src_mask: Source mask passed through to the model.
        max_len: Maximum length of the generated id sequence (BOS included).
        tgt_tokenizer: Unused; kept for interface compatibility.
        tgt_vocab: Vocabulary supporting ``vocab[token]``.
        tgt_pad_idx: Padding index; also written after EOS for finished rows.
        batch_size: Ignored; the size of ``src`` along dimension 0 is used so a
            batch smaller than the configured size still decodes.

    Returns:
        list[list[int]]: one id sequence per sentence, all of equal length. Each
        starts with the BOS id; positions after a sentence's EOS hold ``tgt_pad_idx``.
    """
    batch_size = src.shape[0]
    if batch_size == 0:
        return []

    sos_token = tgt_vocab[BOS_TOKEN]
    eos_token = tgt_vocab[EOS_TOKEN]

    with torch.no_grad():
        tgt_tokens = torch.full((batch_size, 1), sos_token, dtype=torch.long, device=src.device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=src.device)

        for _ in range(max_len - 1):
            tgt_mask = generate_tgt_mask(tgt_tokens, tgt_pad_idx)
            output_probs = model(src, tgt_tokens, src_mask, tgt_mask)

            # Only the prediction for the LAST position is the next token;
            # appending the predictions of every position would double the
            # sequence on each step.
            seq_len = tgt_tokens.shape[1]
            if output_probs.dim() == 3:
                # NOTE(review): assumes a (batch, seq, vocab) layout - confirm against Transformer.forward.
                last_step = output_probs[:, -1, :]
            else:
                # Flattened (batch * seq, vocab) layout.
                last_step = output_probs[seq_len - 1::seq_len]

            next_tokens = torch.argmax(last_step, dim=-1).reshape(batch_size).to(tgt_tokens.device)
            # Rows that already produced EOS are only padded from now on.
            next_tokens = next_tokens.masked_fill(finished, tgt_pad_idx)
            tgt_tokens = torch.cat([tgt_tokens, next_tokens.unsqueeze(1)], dim=1)

            finished = finished | (next_tokens == eos_token)
            if bool(finished.all()):
                break

    # decoded_sentences = [tgt_tokenizer.decode(tokens) for tokens in tgt_tokens.tolist()]
    return tgt_tokens.tolist()

In [6]:
def batch_beam_search_decode(model, src, src_mask, max_len, src_tokenizer, tgt_tokenizer, beam_size, batch_size):
    """
    Placeholder for batched beam-search decoding.

    Not implemented yet: every argument is ignored and None is returned.
    """
    # TODO
    return None

## Testing

In [7]:
# --- Runtime -----------------------------------------------------------------
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Data --------------------------------------------------------------------
FILE_PATH = 'data/en-pt.txt'
NUM_PHRASES = 10_000
VOCAB_SIZE = 64_000
batch_size = 16

# --- Model architecture ------------------------------------------------------
# MODEL_* names arrive through the star imports at the top of the notebook
# (presumably utils.constants - verify there).
max_len = MODEL_MAX_SEQ_LEN
d_model = MODEL_DIM
d_ff = MODEL_FF
num_layers = MODEL_N_LAYERS
num_heads = MODEL_N_HEADS
dropout = MODEL_DROPOUT
label_smoothing = MODEL_LABEL_SMOTHING

# --- Optimisation ------------------------------------------------------------
num_epochs = 10
learning_rate = 1e-4
warmup_steps = 2000
weight_decay = 1e-4

# --- Logging / checkpoints ---------------------------------------------------
# `n` is the experiment number embedded in the log directory name.
n = 1
LOGGING_FILE = f'runs{os.sep}translation_experiment_{n}'

CHECKPOINT_PATH = 'checkpoints/checkpoint_epoch_8_val_loss_4.7865.pt'

In [8]:
# Load the parallel corpus and unpack the eight values load_data returns:
# two data loaders, both padding indices, both vocabularies and both tokenizers.
(
    train_dataloader,
    test_dataloader,
    pad_idx_src,
    pad_idx_tgt,
    src_vocab,
    tgt_vocab,
    src_tokenizer,
    tgt_tokenizer,
) = load_data(
    FILE_PATH,
    language_direction=LanguageDirection.PT2EN.name,
    limit=NUM_PHRASES,
    batch_size=batch_size,
    max_len=max_len,
    return_tokenizers=True,
)

In [9]:
# Build the network with vocabulary sizes taken from the loaded vocabularies,
# move it to the selected device, then create its optimizer and LR scheduler.
model = Transformer(
    len(src_vocab), len(tgt_vocab),
    d_model, num_heads, num_layers, d_ff,
    dropout, max_len,
).to(device)

optimizer, scheduler = create_optimizer_and_scheduler(
    model, d_model, warmup_steps, learning_rate, weight_decay
)

In [10]:
# Restore model / optimizer / scheduler from the checkpoint file;
# load_checkpoint also returns an `epoch` value alongside them.
model, optimizer, scheduler, epoch = load_checkpoint(
    model,
    optimizer,
    scheduler,
    CHECKPOINT_PATH,
)

In [24]:
def evaluate_metrics(model, data_loader, src_pad_idx, tgt_pad_idx, tokenizer, tgt_vocab, device, max_decode_len=None):
    """
    Compute the mean sentence-level BLEU of greedy decoding over a data loader.

    Args:
        model: Model handed to ``batch_greedy_decode_v2``; must expose
            ``eval()``, ``train(mode)`` and a ``training`` flag.
        data_loader: Iterable yielding ``(src, tgt)`` tensor pairs.
        src_pad_idx: Padding index forwarded to ``generate_src_mask``.
        tgt_pad_idx: Padding index of the target vocabulary; removed from references.
        tokenizer: Forwarded to ``batch_greedy_decode_v2``.
        tgt_vocab: Vocabulary supporting ``vocab[token]`` and ``get_itos()``.
        device: Device the batches are moved to.
        max_decode_len: Decoder length limit. When None, the notebook-level
            ``max_len`` from the configuration cell is used.

    Returns:
        float: average of ``calculate_bleu(reference, hypothesis)`` over all
        sentences, or 0.0 when the loader yields no sentences.
    """
    if max_decode_len is None:
        # Falls back to the global defined in the configuration cell.
        max_decode_len = max_len

    tgt_itos = tgt_vocab.get_itos()
    # Resolve the special ids once instead of once per reference token.
    special_ids = {tgt_pad_idx, tgt_vocab[BOS_TOKEN], tgt_vocab[EOS_TOKEN]}
    bleu_scores = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for src, tgt in data_loader:
                src, tgt = src.to(device), tgt.to(device)
                src_mask = generate_src_mask(src, src_pad_idx)
                output = batch_greedy_decode_v2(
                    model, src, src_mask, max_decode_len, tokenizer,
                    tgt_vocab, tgt_pad_idx, src.shape[0], device,
                )

                hypothesis = []
                for sent in output:
                    tokens = sent[1:]  # drop BOS
                    # Drop EOS only when it is really there; a sentence cut off
                    # at the length limit ends with a genuine token.
                    if tokens and tokens[-1] == EOS_TOKEN:
                        tokens = tokens[:-1]
                    hypothesis.append(tokens)

                tgt_token_lists = [
                    [tgt_itos[token_idx] for token_idx in sent if token_idx not in special_ids]
                    for sent in tgt.cpu().tolist()
                ]

                for hyp, ref in zip(hypothesis, tgt_token_lists):
                    bleu_scores.append(calculate_bleu(ref, hyp))
    finally:
        # Put the model back in whatever mode the caller had it in, even on error.
        model.train(was_training)

    if not bleu_scores:
        return 0.0
    return sum(bleu_scores) / len(bleu_scores)

In [25]:
# evaluate_metrics returns a single number (the mean BLEU), so bind it to one
# name; unpacking it into two variables raises a TypeError.
average_bleu = evaluate_metrics(model, test_dataloader, pad_idx_src, pad_idx_tgt, tgt_tokenizer, tgt_vocab, device)
average_bleu

trg_token_ids_batch: torch.Size([16, 1])
tgt_mask: torch.Size([16, 1, 1, 1])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 1, 32])
torch.Size([16, 8, 1, 32])
Cross Attention x shape:  torch.Size([16, 1, 256])
Cross Attention enc_output shape:  torch.Size([16, 128, 256])
torch.Size([16, 8, 1, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 1, 32])
torch.Size([16, 8, 1, 32])
Cross Attention x shape:  torch.Size([16, 1, 256])
Cross Attention enc_output shape:  torch.Size([16, 128, 256])
torch.Size([16, 8, 1, 32])
torch.Size([16, 8, 128, 32])
torch.Size([16, 8, 1, 32])
torch.Size([16, 8, 1, 32])
Cross Attention x shape:  torch.Size([16, 1, 256])
Cross Attention enc_outp