In [1]:
!which python
%load_ext autoreload
%autoreload 2
%matplotlib inline

/home/users/ujjwal.upadhyay/miniconda3/envs/trans/bin/python


In [2]:
import re, math
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.optim import Adam

from transformers import AutoTokenizer


from utils import *

In [3]:
BASELINE_MODEL_NUMBER_OF_LAYERS = 6
BASELINE_MODEL_DIMENSION = 512
BASELINE_MODEL_NUMBER_OF_HEADS = 8
BASELINE_MODEL_DROPOUT_PROB = 0.1
BASELINE_MODEL_LABEL_SMOOTHING_VALUE = 0.1


BIG_MODEL_NUMBER_OF_LAYERS = 6
BIG_MODEL_DIMENSION = 1024
BIG_MODEL_NUMBER_OF_HEADS = 16
BIG_MODEL_DROPOUT_PROB = 0.3
BIG_MODEL_LABEL_SMOOTHING_VALUE = 0.1


BOS_TOKEN = '[CLS]'
EOS_TOKEN = '[SEP]'
PAD_TOKEN = "[PAD]"

In [4]:
from datasets import load_dataset
books = load_dataset("opus_books", "en-fr")
books.cleanup_cache_files()
books = books["train"].train_test_split(test_size=0.2)
books["train"][0]

Found cached dataset opus_books (/home/users/ujjwal.upadhyay/.cache/huggingface/datasets/opus_books/en-fr/1.0.0/e8f950a4f32dc39b7f9088908216cd2d7e21ac35f893d04d39eb594746af2daf)


  0%|          | 0/1 [00:00<?, ?it/s]

{'id': '110860', 'translation': {'en': 'I stopped.', 'fr': "Je m'arrêtai."}}

In [5]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print(tokenizer.special_tokens_map)

source_lang = "en"
target_lang = "fr"
prefix = "translate English to French: "

{'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}


In [6]:
def preprocess_function(examples):
    inputs = [prefix + example[source_lang] for example in examples["translation"]]
    targets = [example[target_lang] for example in examples["translation"]]
    model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True, padding=True, return_special_tokens_mask=True, return_tensors='pt')
    return model_inputs

tokenized_books = books.map(preprocess_function, batched=True)

  0%|          | 0/102 [00:00<?, ?ba/s]

  0%|          | 0/26 [00:00<?, ?ba/s]

In [7]:
tokenized_books['train'][0].keys()

dict_keys(['id', 'translation', 'input_ids', 'token_type_ids', 'attention_mask', 'special_tokens_mask', 'labels'])

In [8]:
len(tokenized_books['train'][0:10]['input_ids']), len(tokenized_books['train'][0]['labels'])

(10, 128)

In [9]:
output = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello, y'all! How are you?"))
output

['[CLS]',
 'hello',
 ',',
 'y',
 "'",
 'all',
 '!',
 'how',
 'are',
 'you',
 '?',
 '[SEP]']

In [10]:
def create_n_batches(num_batch, batch_size):
    src_tokens, trg_tokens, src_masks, trg_masks, trg_gts = [], [], [], [], []
    for i in tqdm(range(num_batch)):
        src_tkn = torch.Tensor(tokenized_books['train'][0:BATCH_SIZE]['input_ids']).to(torch.IntTensor())
        trg_tkn = torch.Tensor(tokenized_books['train'][0:BATCH_SIZE]['labels']).to(torch.IntTensor())
#         src_mask = torch.BoolTensor(tokenized_books['train'][0:BATCH_SIZE]['attention_mask']).view(BATCH_SIZE, 1, 1, -1)
#         trg_mask = torch.BoolTensor(torch.Tensor(tokenized_books['train'][0:BATCH_SIZE]['labels'])==1).view(BATCH_SIZE, 1, 1, -1)
        src_mask = torch.BoolTensor(tokenized_books['train'][0:BATCH_SIZE]['attention_mask']).view(BATCH_SIZE, -1)
        trg_mask = torch.BoolTensor(torch.Tensor(tokenized_books['train'][0:BATCH_SIZE]['labels'])==1).view(BATCH_SIZE, -1)
        
        label_smoothing = LabelSmoothingDistribution(BASELINE_MODEL_LABEL_SMOOTHING_VALUE, tokenizer.pad_token_id, tokenizer.vocab_size, "cpu")
        trg_gt = label_smoothing(trg_tkn[:, :].reshape(-1, 1))
        
        src_tokens.append(src_tkn)
        trg_tokens.append(trg_tkn)
        src_masks.append(src_mask)
        trg_masks.append(trg_mask)
        trg_gts.append(trg_gt)

    return src_tokens, trg_tokens, src_masks, trg_masks, trg_gts
        

In [11]:
tokenizer.vocab_size

30522

In [13]:
from transformer import Transformer
trans = Transformer(model_dimension=512, src_vocab_size=tokenizer.vocab_size, 
                    trg_vocab_size=tokenizer.vocab_size, 
                    number_of_heads=8, number_of_layers=4, mem_size=16,
                    dropout_probability=0.1, log_attention_weights=False)
trans.train()
print("No of parameters:", sum(dict((p.data_ptr(), p.numel()) for p in trans.parameters()).values()))

No of parameters: 84767546


In [15]:
EPOCH = 100
BATCH_SIZE = 1

src_tokens, trg_tokens, src_masks, trg_masks, trg_gts = create_n_batches(num_batch=50, batch_size=BATCH_SIZE)

# src_tkn = torch.Tensor(tokenized_books['train'][0:BATCH_SIZE]['input_ids']).to(torch.IntTensor())
# trg_tkn = torch.Tensor(tokenized_books['train'][0:BATCH_SIZE]['labels']).to(torch.IntTensor())
# src_mask = torch.BoolTensor(tokenized_books['train'][0:BATCH_SIZE]['attention_mask']).view(BATCH_SIZE, 1, 1, -1)
# trg_mask = torch.BoolTensor(torch.Tensor(tokenized_books['train'][0:BATCH_SIZE]['labels'])==1).view(BATCH_SIZE, 1, 1, -1)

# label_smoothing = LabelSmoothingDistribution(BASELINE_MODEL_LABEL_SMOOTHING_VALUE, tokenizer.pad_token_id, tokenizer.vocab_size, "cpu")
# trg_gt = label_smoothing(trg_tkn[:, :].reshape(-1, 1))

optimizer = optim.AdamW(trans.parameters(), lr=5e-4)
kl_div_loss = nn.KLDivLoss(reduction='batchmean')

init_src_memory = trans.src_memory.clone()
init_trg_memory = trans.trg_memory.clone()

print(f"SRC TKN: {src_tokens[0].shape}\nTRG TKN: {trg_tokens[0].shape}\nSRC MASK: {src_masks[0].shape}\nTRG_MASK: {trg_masks[0].shape}\nTRG GT: {trg_gts[0].shape}")
src_tkn, trg_tkn, src_mask, trg_mask, trg_gt = src_tokens[0], trg_tokens[0], src_masks[0], trg_masks[0], trg_gts[0]


  0%|          | 0/50 [00:00<?, ?it/s]

SRC TKN: torch.Size([1, 128])
TRG TKN: torch.Size([1, 128])
SRC MASK: torch.Size([1, 128])
TRG_MASK: torch.Size([1, 128])
TRG GT: torch.Size([128, 30522])


In [17]:
# h_state, log_probab = trans(src_tkn, trg_tkn, src_mask, trg_mask)
# h_state.shape, log_probab.shape

In [16]:
gradient_accumulation_steps = 10
for i in tqdm(range(50)):
    h_state, log_probab = trans(src_tokens[i], trg_tokens[i], src_masks[i], trg_masks[i])
    loss = kl_div_loss(log_probab, trg_gts[i])
    loss = loss / gradient_accumulation_steps
    loss.backward(retain_graph=True)
    if (i + 1) % gradient_accumulation_steps == 0:
        print(f"EPOCH:{i+1} - LOSS:{loss}")
        optimizer.step()
        optimizer.zero_grad()

  0%|          | 0/50 [00:00<?, ?it/s]

EPOCH:10 - LOSS:0.06325589120388031
EPOCH:20 - LOSS:0.0552356131374836
EPOCH:30 - LOSS:0.05187438800930977
EPOCH:40 - LOSS:0.049195390194654465
EPOCH:50 - LOSS:0.04627738147974014


In [17]:
print((trans.src_memory == init_src_memory).sum())
print((trans.trg_memory == init_trg_memory).sum())

tensor(0)
tensor(0)


In [20]:
torch.isnan(trans.trg_memory).sum(), torch.isnan(trans.src_memory).sum()

(tensor(0), tensor(0))

In [21]:
bsz_test = 1
idx = 10
src_tkn_test = torch.Tensor(tokenized_books['train'][idx:idx+bsz_test]['input_ids']).to(torch.IntTensor())
trg_tkn_test = torch.Tensor(tokenized_books['train'][idx:idx+bsz_test]['labels']).to(torch.IntTensor())
src_mask_test = torch.BoolTensor(tokenized_books['train'][idx:idx+bsz_test]['attention_mask']).view(bsz_test, -1) 
print(src_tkn_test.shape, trg_tkn_test.shape, src_mask_test.shape)

torch.Size([1, 128]) torch.Size([1, 128]) torch.Size([1, 128])


In [22]:
trans.eval()
with torch.no_grad():
    h_state_test = trans.encode(src_tkn_test, src_mask_test)
h_state_test.shape

torch.Size([1, 128, 512])

In [23]:
h_state_test.argmax(axis=1).shape

torch.Size([1, 512])

In [25]:
# out = h_state_test.argmax(axis=1).numpy()
# type(out[0]), len(out[0]), out[0]

In [89]:
tokenizer.special_tokens_map['cls_token']

'[CLS]'

In [90]:
tokenizer.encode(f"{tokenizer.special_tokens_map['cls_token']}Hello")

[101, 101, 7592, 102]

In [85]:
tokenizer.cls_token_id, tokenizer.sep_token_id

(101, 102)

In [525]:
tokenizer.decode(7592)

'grey'

tensor([[101],
        [101],
        [101],
        [101]])

In [546]:
beam_size = 4
src_representations_batch = h_state_test
src_representations_batch = src_representations_batch.repeat(1, beam_size, 1).view(beam_size*batch_size, -1, model_dimension)
src_representations_batch.shape


torch.Size([4, 128, 512])

In [548]:
trg_token_ids_batch = trg_token_ids_batch.repeat(beam_size, 1)
trg_token_ids_batch.shape

torch.Size([16, 0])

In [551]:
hypotheses_log_probs = torch.zeros((batch_size * beam_size, 1), device=device)
had_eos = [[False] for _ in range(hypotheses_log_probs.shape[0])]

In [97]:
trg_no_look_forward_mask = torch.triu(torch.ones((1, 1, 100, 100), device=device) == 1).transpose(2, 3)
trg_no_look_forward_mask.shape
    
    

torch.Size([1, 1, 100, 100])

In [149]:
def get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id):
    batch_size = trg_token_ids_batch.shape[0]
    device = trg_token_ids_batch.device

    # Same as src_mask but we additionally want to mask tokens from looking forward into the future tokens
    # Note: wherever the mask value is true we want to attend to that token, otherwise we mask (ignore) it.
    sequence_length = trg_token_ids_batch.shape[1]  # trg_token_ids shape = (B, T) where T max trg token-sequence length
    trg_padding_mask = (trg_token_ids_batch != pad_token_id).view(batch_size, -1)  # shape = (B, T)
    trg_no_look_forward_mask = torch.triu(torch.ones((1,1, sequence_length, sequence_length), device=device) == 1).transpose(2, 3)

#     # logic AND operation (both padding mask and no-look-forward must be true to attend to a certain target token)
    trg_mask = trg_padding_mask & trg_no_look_forward_mask  # final shape = (B, 1, T, T)
#     trg_mask = trg_padding_mask
    num_trg_tokens = torch.sum(trg_padding_mask.long())

    return trg_mask, num_trg_tokens

trg_mask, num_trg_tokens = get_masks_and_count_tokens_trg(trg_token_ids_batch, tokenizer.pad_token_id)
trg_mask.shape, trg_mask, num_trg_tokens

RuntimeError: The size of tensor a (5) must match the size of tensor b (2) at non-singleton dimension 2

In [128]:
device = next(trans.parameters()).device
pad_token_id = tokenizer.pad_token_id
batch_size, S, model_dimension = h_state_test.shape
target_multiple_hypotheses_tokens = [["[CLS]"] for _ in range(1)]
trg_token_ids_batch = torch.tensor([[101, 102] for tokens in target_multiple_hypotheses_tokens], device=device)
trg_token_ids_batch, trg_token_ids_batch.shape

(tensor([[101, 102]]), torch.Size([1, 2]))

In [145]:
beam_size = 5
h_state_test = h_state_test.repeat(1,beam_size,1).view(beam_size*bsz_test, -1, 512)

In [146]:
h_state_test.shape

torch.Size([5, 128, 512])

In [147]:
trg_token_ids_batch = trg_token_ids_batch.repeat(beam_size, 1)
trg_token_ids_batch.shape

torch.Size([5, 2])

In [148]:
hypotheses_log_probs = torch.zeros((bsz_test * beam_size, 1), device=device)
had_eos = [[False] for _ in range(hypotheses_log_probs.shape[0])]

In [None]:
trans.decode(h_state_test, )

In [618]:
def beam_decoding(beam_size, tokenizer, transformer, src_representations_batch, src_mask, max_target_tokens=100):
    device = next(transformer.parameters()).device
    pad_token_id = tokenizer.pad_token_id

    # Initial prompt is the beginning/start of the sentence token. Make it compatible shape with source batch => (B,1)
    batch_size, S, model_dimension = src_representations_batch.shape
    target_multiple_hypotheses_tokens = [["[CLS]"] for _ in range(batch_size)]
    trg_token_ids_batch = torch.tensor([[101] for tokens in target_multiple_hypotheses_tokens], device=device)
    
    # Repeat so that source sentence representations are repeated contiguously, say we have [s1, s2] we want
    # [s1, s1, s2, s2] and not [s1, s2, s1, s2] where s1 is single sentence representation with shape=(S, D)
    # where S - max source token-sequence length, D - model dimension
    src_representations_batch = src_representations_batch.repeat(1, beam_size, 1).view(beam_size*batch_size, -1, model_dimension)
    trg_token_ids_batch = trg_token_ids_batch.repeat(beam_size, 1)

    hypotheses_log_probs = torch.zeros((batch_size * beam_size, 1), device=device)
    had_eos = [[False] for _ in range(hypotheses_log_probs.shape[0])]

    while True:
        trg_mask, _ = get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id)
        # Shape = (B*BS*T, V) T - current token-sequence length, V - target vocab size, BS - beam size, B - batch
        predicted_log_distributions = transformer.decode(src_representations_batch, trg_token_ids_batch, trg_mask)

        # Extract only the indices of last token for every target sentence (we take every T-th token)
        # Shape = (B*BS, V)
        num_of_trg_tokens = trg_token_ids_batch.shape[-1]
        predicted_log_distributions = predicted_log_distributions[num_of_trg_tokens - 1::num_of_trg_tokens]

        # This time extract beam_size number of highest probability tokens (compare to greedy's arg max)
        # Shape = (B*BS, BS)
        latest_token_log_probs, most_probable_token_indices = torch.topk(predicted_log_distributions, beam_size, dim=-1, sorted=True)

        # Don't update the hypothesis which had EOS already (pruning)
        latest_token_log_probs.masked_fill(torch.tensor(had_eos == True), float("-inf"))

        # Calculate probabilities for every beam hypothesis (since we have log prob we add instead of multiply)
        # Shape = (B*BS, BS)
        hypotheses_pool_log_probs = hypotheses_log_probs + latest_token_log_probs
        # Shape = (B, BS, BS)
        most_probable_token_indices = most_probable_token_indices.view(batch_size, beam_size, beam_size)
        hypotheses_pool_log_probs = hypotheses_pool_log_probs.view(batch_size, beam_size, beam_size)
        # Shape = (B, BS*BS)
        hypotheses_pool_log_probs = torch.flatten(hypotheses_pool_log_probs, start_dim=-1)

        # Figure out indices of beam_size most probably hypothesis for every target sentence in the batch
        # Shape = (B, BS)
        new_hypothesis_log_probs, next_hypothesis_indices = torch.topk(hypotheses_pool_log_probs, beam_size, dim=-1, sorted=True)

        # Create new target ids batch
        hypotheses_log_probs_tmp = torch.empty((batch_size * beam_size, 1))

        T = trg_token_ids_batch.shape[-1]
        new_trg_token_ids_batch = torch.empty((batch_size * beam_size, T + 1))

        next_hypothesis_indices = next_hypothesis_indices.cpu().numpy()
        # Prepare new hypotheses for the next iteration
        for b_idx, indices in enumerate(next_hypothesis_indices):
            for h_idx, token_index in indices:
                row, column = token_index / beam_size, token_index % beam_size
                hypothesis_index = b_idx * beam_size + h_idx

                new_token_id = most_probable_token_indices[b_idx, row, column]
                if had_eos[hypothesis_index]:
                    new_trg_token_ids_batch[hypothesis_index, :-1] = trg_token_ids_batch[hypothesis_index, :]
                else:
                    new_trg_token_ids_batch[hypothesis_index, :-1] = trg_token_ids_batch[b_idx * beam_size + row, :]
                    new_trg_token_ids_batch[hypothesis_index, -1] = new_token_id

                if had_eos[hypothesis_index]:
                    hypotheses_log_probs_tmp[hypothesis_index] = hypotheses_log_probs[hypothesis_index]
                else:
                    hypotheses_log_probs_tmp[hypothesis_index] = new_hypothesis_log_probs[hypothesis_index]

                if new_token_id == tokenizer.eos_token_id:
                    had_eos[hypothesis_index] = True

        # Update the current hypothesis probabilities
        hypotheses_log_probs = hypotheses_log_probs_tmp
        trg_token_ids_batch = new_trg_token_ids_batch

        if all(had_eos) or num_of_trg_tokens == max_target_tokens:
            break

    #
    # Selection and post-processing
    #

    target_multiple_hypotheses_tokens = []
    trg_token_ids_batch_numpy = trg_token_ids_batch.cpu().numpy()
    for hypothesis_ids in trg_token_ids_batch_numpy:
        target_multiple_hypotheses_tokens.append([tokenizer.decode(token_id) for token_id in hypothesis_ids])

    # Step 1: Select the most probable hypothesis out of beam_size hypotheses for each target sentence
    hypotheses_log_probs = hypotheses_log_probs.view(batch_size, beam_size)
    most_probable_hypotheses_indices = torch.argmax(hypotheses_log_probs, dim=-1).cpu().numpy()
    target_sentences_tokens = []
    for b_idx, index in enumerate(most_probable_hypotheses_indices):
        target_sentences_tokens.append(target_multiple_hypotheses_tokens[b_idx * beam_size + index])

    # Step 2: Post process the sentences - remove everything after the EOS token
    target_sentences_tokens_post = []
    for target_sentence_tokens in target_sentences_tokens:
        try:
            target_index = target_sentence_tokens.index(EOS_TOKEN) + 1
        except:
            target_index = None

        target_sentence_tokens = target_sentence_tokens[:target_index]
        target_sentences_tokens_post.append(target_sentence_tokens)

    return target_sentences_tokens_post

In [123]:
tokenizer.decode(101)

'[CLS]'