In [1]:
import torch 
import torch.nn as nn 
import pandas as pd 
import numpy as np
import altair as alt
from trasnformer import Transformer
from config import get_cfg, get_model_file_path
from train import get_dataset, greedy_decode, get_model

import warnings
warnings.filterwarnings('ignore')


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

Device: cuda


In [3]:
# Build the config, the validation loader and the source/target vocabularies, then a model
# sized to those vocabularies (its trained weights are loaded in a later cell).
# The first value returned by get_dataset is not needed here and is discarded.
cfg = get_cfg()
_, val_loader, vocab_src, vocab_tgt = get_dataset(cfg)
src_vocab_size = vocab_src.get_vocab_size()
tgt_vocab_size = vocab_tgt.get_vocab_size()
model = get_model(cfg, src_vocab_size, tgt_vocab_size).to(device)

Source max seq len: 309
Target max seq len: 274


In [5]:
# Load the trained model weights from checkpoint tag "29" (presumably the epoch number — TODO confirm).
# map_location=device lets a checkpoint that was saved on a GPU be loaded on a CPU-only machine,
# consistent with the CPU fallback used when `device` was chosen; without it torch.load fails there.
model_filename = get_model_file_path(cfg, "29")
print(model_filename)
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state["model_state_dict"])

models/transformer_model_29.pt


<All keys matched successfully>

In [6]:
def load_next_batch():
    """Fetch one validation example, decode it with the model, and return its tokens.

    A fresh iterator is created on every call, so this yields the first batch of
    ``val_loader`` — or a different one each time if the loader shuffles (not
    visible here; depends on how ``get_dataset`` built it).

    The greedy decode is run for its side effect: it pushes this example through
    ``model`` so that the per-layer ``attention_scores`` later read by
    ``get_attn_map`` refer to it (assuming the attention modules cache their most
    recent scores, as ``get_attn_map`` implies). The decoded output itself is not used.

    Returns:
        tuple: ``(batch_input, enc_input_tokens, dec_input_tokens)`` where the token
        lists are the source / target ids of the single example mapped back to strings.

    Raises:
        AssertionError: if the validation loader yields more than one example per batch.
    """
    batch_input = next(iter(val_loader))
    enc_input = batch_input["enc_input"].to(device)
    # Check the batch size before anything below relies on indexing example [0].
    assert enc_input.size(0) == 1, "Batch size must be 1 for val !"
    enc_mask = batch_input["enc_mask"].to(device)
    # Only needed for its token ids, so there is no need to move it to the device.
    dec_input = batch_input["dec_input"]

    enc_input_tokens = [vocab_src.id_to_token(idx) for idx in enc_input[0].cpu().numpy()]
    dec_input_tokens = [vocab_tgt.id_to_token(idx) for idx in dec_input[0].cpu().numpy()]

    # Inference only: disable dropout so the cached attention maps are deterministic,
    # and skip autograd bookkeeping (harmless if greedy_decode already does either).
    model.eval()
    with torch.no_grad():
        greedy_decode(model, enc_input, enc_mask, vocab_tgt, cfg['seq_len'], device)

    return batch_input, enc_input_tokens, dec_input_tokens
    

In [7]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten the top-left corner of a 2-D matrix into a long-form DataFrame for Altair.

    Args:
        m: 2-D torch tensor (any device / float dtype) or array-like of numbers.
        max_row: integer limit on the number of rows taken from ``m``.
        max_col: integer limit on the number of columns taken from ``m``.
        row_tokens: labels for the rows; missing entries are shown as ``<blank>``.
        col_tokens: labels for the columns; missing entries are shown as ``<blank>``.

    Returns:
        pandas.DataFrame with one record per (row, column) cell and the columns
        ``row``, ``column``, ``value``, ``row_token``, ``col_token``. The token labels
        are ``"<3-digit index> <token>"``.
    """
    if isinstance(m, torch.Tensor):
        # One device->host copy up front instead of a GPU sync for every element read.
        m = m.detach().to(device="cpu", dtype=torch.float64).numpy()
    mat = np.asarray(m, dtype=np.float64)

    def label(idx, tokens):
        # Zero-padded index prefix, likely so that the alphabetically sorted nominal
        # axis in the chart follows token position, and duplicate tokens stay distinct.
        return "%.3d %s" % (idx, tokens[idx] if len(tokens) > idx else "<blank>")

    # Bound the loops directly instead of visiting every cell and filtering afterwards.
    n_rows = min(mat.shape[0], max_row)
    n_cols = min(mat.shape[1], max_col)
    records = [
        (r, c, float(mat[r, c]), label(r, row_tokens), label(c, col_tokens))
        for r in range(n_rows)
        for c in range(n_cols)
    ]
    return pd.DataFrame(records, columns=["row", "column", "value", "row_token", "col_token"])

def get_attn_map(attn_type: str, layer: int, head: int):
    """Return the stored attention scores of one head in one layer of the global ``model``.

    Args:
        attn_type: ``"encoder"`` (encoder self-attention), ``"decoder"`` (decoder
            self-attention) or ``"encoder-decoder"`` (decoder cross-attention).
        layer: index into ``model.encoder.layers`` / ``model.decoder.layers``.
        head: head index, used on dim 1 of the stored scores.

    Returns:
        ``attention_scores[0, head].data`` of the selected attention module, i.e. the
        scores of element 0 along the first dim (presumably the batch) without autograd.

    Raises:
        ValueError: if ``attn_type`` is not one of the three names above.
    """
    if attn_type == "encoder":
        attention_module = model.encoder.layers[layer].attention_block
    elif attn_type == "decoder":
        attention_module = model.decoder.layers[layer].attention
    elif attn_type == "encoder-decoder":
        attention_module = model.decoder.layers[layer].cross_attention
    else:
        # Without this branch an unknown name surfaced as a confusing UnboundLocalError.
        raise ValueError(
            f"Unknown attn_type {attn_type!r}; expected 'encoder', 'decoder' or 'encoder-decoder'"
        )
    return attention_module.attention_scores[0, head].data

def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len, max_col_len=None):
    """Render one attention head as an interactive Altair heatmap.

    Args:
        attn_type: ``"encoder"``, ``"decoder"`` or ``"encoder-decoder"``; forwarded to
            ``get_attn_map``.
        layer: layer index.
        head: head index.
        row_tokens: labels for the matrix rows.
        col_tokens: labels for the matrix columns.
        max_sentence_len: maximum number of rows shown — and of columns too, unless
            ``max_col_len`` is given.
        max_col_len: optional separate column limit. This is useful for cross-attention,
            where the row and column sequences have different lengths. ``None`` (the
            default) reuses ``max_sentence_len``, matching the previous behaviour.

    Returns:
        An interactive ``alt.Chart`` titled ``"Layer {layer} Head {head}"``.
    """
    if max_col_len is None:
        max_col_len = max_sentence_len
    df = mtx2df(
        get_attn_map(attn_type, layer, head),
        max_sentence_len,
        max_col_len,
        row_tokens,
        col_tokens,
    )
    return (
        alt.Chart(data=df)
        .mark_rect()
        .encode(
            x=alt.X("col_token", axis=alt.Axis(title="")),
            y=alt.Y("row_token", axis=alt.Axis(title="")),
            color="value",
            tooltip=["row", "column", "value", "row_token", "col_token"],
        )
        .properties(height=400, width=400, title=f"Layer {layer} Head {head}")
        .interactive()
    )

def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):
    """Tile attention heatmaps into a grid with one row per layer and one column per head.

    Every cell is produced by ``attn_map`` with the same ``attn_type``, tokens and
    length limit. Charts are built layer by layer, and head by head within a layer.

    Returns:
        An ``alt.vconcat`` of one ``alt.hconcat`` per entry in ``layers``.
    """
    grid_rows = [
        alt.hconcat(
            *(attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len) for head in heads)
        )
        for layer in layers
    ]
    return alt.vconcat(*grid_rows)

In [8]:
batch, enc_input_tokens, dec_input_tokens = load_next_batch()
print(f"Source: {batch['src_txt'][0]}\n")
print(f"Target: {batch['tgt_txt'][0]}\n")
# The source length is the position of the first padding token. A sequence that fills the
# whole seq_len has no '[PAD]', where list.index would raise, so fall back to the full length.
sentence_len = enc_input_tokens.index('[PAD]') if '[PAD]' in enc_input_tokens else len(enc_input_tokens)
print(f"Src sequence len: {sentence_len}")

Source: Will she not depart as suddenly as she came?

Target: Non partirà all'improvviso com'è venuta?

Src sequenc len: 12


In [9]:
# Layers to plot: only the first three layer indices. Heads: every head in the config.
layers = list(range(3))
heads = list(range(cfg["num_heads"]))

# Encoder self-attention: source tokens on both axes.
get_all_attention_maps("encoder", layers, heads, enc_input_tokens, enc_input_tokens, min(20, sentence_len))

In [10]:

# Decoder self-attention: target tokens on both axes, limited by the target length.
# (Previously this passed "encoder", so it re-plotted the encoder maps.)
# Assumes the target vocabulary also pads with '[PAD]'; falls back to the full length otherwise.
# NOTE(review): the scores come from the last forward pass inside greedy_decode, presumably over
# the model's own generated prefix, while the labels are the ground-truth target tokens — they may differ.
tgt_len = dec_input_tokens.index('[PAD]') if '[PAD]' in dec_input_tokens else len(dec_input_tokens)
get_all_attention_maps("decoder", layers, heads, dec_input_tokens, dec_input_tokens, min(20, tgt_len))

In [11]:

# Cross (encoder-decoder) attention. (Previously this passed "encoder", so it re-plotted the encoder maps.)
# Queries come from the decoder and keys from the encoder, so the rows are target tokens and the
# columns are source tokens (assuming the usual q @ k^T score layout — verify against the model).
# A single limit applies to both axes, so use the longer of the two real lengths (capped at 20).
tgt_len = dec_input_tokens.index('[PAD]') if '[PAD]' in dec_input_tokens else len(dec_input_tokens)
get_all_attention_maps("encoder-decoder", layers, heads, dec_input_tokens, enc_input_tokens, min(20, max(sentence_len, tgt_len)))