In [1]:
import torch
import torch.nn as nn
import altair as alt
import pandas as pd
import numpy as np

from src.train.transformer_train import build_model, preload_model
from src.train.utils import get_device
from src.config.transformer_config import get_config
from src.dataset.transformer_ds import get_ds
from src.inference.greedy_search import greedy_decode
from src.inference.beam_search import beam_decode

%load_ext autoreload


In [42]:
device = get_device()
print(f"Using device: {device}")

config = get_config()
config["batch_size"] = 1
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model, optimizer, loss_fn = build_model(config, tokenizer_src, tokenizer_tgt, device)
model, optimizer, initial_epoch, global_step = preload_model(model, optimizer, config)        


Using device: cuda
Max length of source sentence: 309
Max length of target sentence: 274
Preloading model src/weights/opus_books_weights/transformer_10.pt


In [None]:
from torch import Tensor
from src.dataset.transformer_ds import causal_mask

def beam_decode3(
    model, beam_size, source, source_mask, tokenizer_tgt, max_len, device
):
  model.eval()
  sos_idx = tokenizer_tgt.token_to_id("[SOS]")
  eos_idx = tokenizer_tgt.token_to_id("[EOS]")
  assert source.size(0) == 1, "Batch size must be 1 for beam search"
  # (1, dynamic_seq_len)
  encoder_output = model.encoder(source, source_mask)
  # (1, dynamic_seq_len)
  decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
  
  candidates = [(decoder_input, 0.0)]
  
  while True:
    if any([candidate[0].size(1) == max_len for candidate in candidates]):
      break
    
    new_candidates = []
    
    for candidate, score in candidates:
      
      if candidate[0][-1].item() == eos_idx:
        new_candidates.append((candidate, score))
        continue 
      
      decoder_mask = causal_mask(candidate.size(1)).to(device)
      # (1, dynamic_seq_len, vocab_size)
      proj_output = model.decoder(
          candidate, encoder_output, decoder_mask, source_mask
      )
      
      # (1, dynamic_seq_len, vocab_size) -> (1, vocab_size)
      probs = proj_output[:, -1].softmax(dim=-1)
      # (1, beam_size)
      topk_probs, topk_ids = torch.topk(probs, beam_size,dim=1)
      for k in range(beam_size):
        token_id = topk_ids[0][k].unsqueeze(0).unsqueeze(0)
        prob: Tensor = topk_probs[0][k]
        print(str(token_id.item()) + ": prob = " + str(prob.item()))
        # (1, ++dynamic_seq_len)
        new_candidate = torch.cat([candidate, token_id],dim=1)
        new_candidates.append((new_candidate, score + torch.log(prob).item()))
        
    candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)[:beam_size]
    if all([candidate[0][0][-1].item() == eos_idx for candidate in candidates]):
      break
      
  for candidate in candidates:
      print("Beam Decode: " + tokenizer_tgt.decode(candidate[0].squeeze().detach().cpu().numpy()))    
    
  return candidates[0][0].squeeze()# (dynamic_seq_len)   

In [83]:
idx = 500
print(len(tokenizer_tgt.decode([idx])))
print(tokenizer_tgt.decode([idx]))

6
sapere


In [63]:
batch = next(iter(train_dataloader))
encoder_input = batch["encoder_input"].to(device)
encoder_mask = batch["encoder_mask"].to(device)
decoder_input = batch["decoder_input"].to(device)
decoder_mask = batch["decoder_mask"].to(device)

encoder_input_tokens = [tokenizer_src.id_to_token(idx) for idx in encoder_input[0].cpu().numpy()]
decoder_input_tokens = [tokenizer_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()]

# check that the batch size is 1
assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')
greedy_model_out = greedy_decode(
        model, encoder_input, encoder_mask, tokenizer_tgt, config['max_len'], device)
print("Greedy Decode: " + tokenizer_tgt.decode(greedy_model_out.detach().cpu().numpy()))    
beam_model_out = beam_decode3(
    model, 5, encoder_input, encoder_mask, tokenizer_tgt, config['max_len'], device)

Source: He raised his head.
Target: Egli sollevò il capo.
Greedy Decode: 
1
0: prob = 0.3752099871635437
4619: prob = 0.14037685096263885
4: prob = 0.052899740636348724
3: prob = 0.04121950641274452
1497: prob = 0.016336839646100998
2
0: prob = 0.9971191883087158
4: prob = 0.00037932751001790166
3: prob = 0.00036502169677987695
5: prob = 0.00015883137530181557
11: prob = 8.13221195130609e-05
2
4: prob = 0.1781618446111679
6591: prob = 0.1273667961359024
6: prob = 0.048947177827358246
9: prob = 0.02457994408905506
290: prob = 0.020644716918468475
2
4: prob = 0.13146743178367615
1347: prob = 0.009043908677995205
4019: prob = 0.006850169040262699
340: prob = 0.006076618563383818
869: prob = 0.00531315803527832
2
55: prob = 0.1571732759475708
545: prob = 0.08341886848211288
855: prob = 0.07186020910739899
36: prob = 0.028614481911063194
33: prob = 0.027948711067438126
3
0: prob = 0.9973747730255127
3: prob = 0.00044625482405535877
4: prob = 0.00020948109158780426
5: prob = 0.00013645502622

In [None]:
tokenizer_tgt.decode(beam_model_out.detach().cpu().numpy())

In [None]:
def load_next_batch(val_dataloader, tokenizer_src, tokenizer_tgt, greedy_search, beam_search, device):
    # Load a sample batch from the validation set
    batch = next(iter(val_dataloader))
    encoder_input = batch["encoder_input"].to(device)
    encoder_mask = batch["encoder_mask"].to(device)
    decoder_input = batch["decoder_input"].to(device)
    decoder_mask = batch["decoder_mask"].to(device)

    encoder_input_tokens = [tokenizer_src.id_to_token(idx) for idx in encoder_input[0].cpu().numpy()]
    decoder_input_tokens = [tokenizer_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()]

    # check that the batch size is 1
    assert encoder_input.size(
        0) == 1, "Batch size must be 1 for validation"

    greedy_model_out = greedy_search(
        model, encoder_input, encoder_mask, tokenizer_tgt, config['max_len'], device)
    print("Greedy Decode: " + tokenizer_tgt.decode(greedy_model_out.detach().cpu().numpy()))    

    beam_model_out = (
        model, 3, encoder_input, encoder_mask, tokenizer_tgt, config['max_len'], device)
    return batch, encoder_input_tokens, decoder_input_tokens, greedy_model_out, beam_model_out

In [None]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    return pd.DataFrame(
        [
            (
                r,
                c,
                float(m[r, c]),
                "%.3d %s" % (r, row_tokens[r] if len(row_tokens) > r else "<blank>"),
                "%.3d %s" % (c, col_tokens[c] if len(col_tokens) > c else "<blank>"),
            )
            for r in range(m.shape[0])
            for c in range(m.shape[1])
            if r < max_row and c < max_col
        ],
        columns=["row", "column", "value", "row_token", "col_token"],
    )

def get_attn_map(attn_type: str, layer: int, head: int):
    if attn_type == "encoder":
        attn = model.encoder.layers[layer].attention.scores
    elif attn_type == "decoder":
        attn = model.decoder.layers[layer].self_attention.scores
    elif attn_type == "encoder-decoder":
        attn = model.decoder.layers[layer].cross_attention.scores
    # [batch, n_head, q_len, kv_len] -> [q_len, kv_len]
    return attn[0, head].data

def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):
    df = mtx2df(
        get_attn_map(attn_type, layer, head),
        max_sentence_len,
        max_sentence_len,
        row_tokens,
        col_tokens,
    )
    return (
        alt.Chart(data=df)
        .mark_rect()
        .encode(
            x=alt.X("col_token", axis=alt.Axis(title="")),
            y=alt.Y("row_token", axis=alt.Axis(title="")),
            color="value",
            tooltip=["row", "column", "value", "row_token", "col_token"],
        )
        #.title(f"Layer {layer} Head {head}")
        .properties(height=400, width=400, title=f"Layer {layer} Head {head}")
        .interactive()
    )

def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):
    charts = []
    for layer in layers:
        rowCharts = []
        for head in heads:
            rowCharts.append(attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len))
        charts.append(alt.hconcat(*rowCharts))
    return alt.vconcat(*charts)

In [None]:
batch, encoder_input_tokens, decoder_input_tokens, greedy_model_out, beam_model_out = load_next_batch(val_dataloader, tokenizer_src, tokenizer_tgt, greedy_decode, beam_decode2, device)
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')
sentence_len = encoder_input_tokens.index("[PAD]")

In [None]:
layers = [0, 1, 2]
heads = [0, 1, 2, 3, 4, 5, 6, 7]

# Encoder Self-Attention
get_all_attention_maps("encoder", layers, heads, encoder_input_tokens, encoder_input_tokens, min(20, sentence_len))

In [None]:
# Decoder Self-Attention
get_all_attention_maps("decoder", layers, heads, decoder_input_tokens, decoder_input_tokens, min(20, sentence_len))

In [None]:
# Cross-Attention
get_all_attention_maps("encoder-decoder", layers, heads, encoder_input_tokens, decoder_input_tokens, min(20, sentence_len))