<a target="_blank" href="https://colab.research.google.com/github/petuch03/graph-rag-research/blob/master/tokenizers/self-attention-visualization.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [1]:
from types import SimpleNamespace
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
# Load model and retrieve attention weights

from bertviz import model_view
from transformers import GemmaTokenizer, GemmaForCausalLM

# configuration = GemmaConfig()
model_version = 'google/gemma-2b-it'

model = GemmaForCausalLM.from_pretrained(model_version, output_attentions=True)
tokenizer = GemmaTokenizer.from_pretrained(model_version)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
from sub_word_tokenization import tokenize_custom

input_text = 'Bob sent Alice a message about apples.'
comparison_pair = tokenize_custom(tokenizer, "tokenizer.json", "Bob sent Alice a message about apples.")
print(comparison_pair.default_tokens)
print(comparison_pair.sub_word_tokens)


def compute_attention_pipeline(tokenizer, input_tokens) -> SimpleNamespace:
    inputs = tokenizer.encode(input_tokens, return_tensors='pt')
    # tokens = tokenizer.convert_ids_to_tokens(inputs[0])
    outputs = model(inputs)
    attention = torch.stack(outputs[-1], dim=0)
    return SimpleNamespace(attention=attention, full_outputs=outputs)


default_model_output = compute_attention_pipeline(tokenizer, comparison_pair.default_tokens)
sub_word_model_output = compute_attention_pipeline(tokenizer, comparison_pair.sub_word_tokens)


['Bob', '▁sent', '▁Alice', '▁a', '▁message', '▁about', '▁apples', '.']
['B', 'o', 'b', '▁s', 'en', 't', '▁Al', 'ice', '▁', 'a', '▁mes', 'sage', '▁ab', 'out', '▁ap', 'ple', 's', '.']




In [4]:
from sub_word_metrics import *

show_all_metrics(default_model_output.attention, threshold=0.02)
show_all_metrics(sub_word_model_output.attention, threshold=0.02)

Threshold-based Noise Metric - Source: tensor([0.8889, 0.7878, 0.7160, 0.6628, 0.6265, 0.5586, 0.5116, 0.5239, 0.4282]), Source mean: 0.6338306069374084
Target: tensor([0.0293, 0.4660, 0.5278, 0.6073, 0.6852, 0.7515, 0.8333, 0.8750, 0.9290])Target mean: 0.6338306069374084
Entropy-based Noise Metric: 2.1522600650787354
Threshold-based Noise Metric - Source: tensor([0.9474, 0.9050, 0.8852, 0.8578, 0.8209, 0.8056, 0.7957, 0.7482, 0.7569,
        0.7515, 0.7109, 0.7105, 0.7105, 0.6860, 0.6996, 0.6784, 0.7061, 0.6849,
        0.6140]), Source mean: 0.7618498802185059
Target: tensor([0.0523, 0.6583, 0.6981, 0.6528, 0.6857, 0.7482, 0.7460, 0.7485, 0.7379,
        0.7778, 0.7865, 0.8527, 0.8198, 0.8706, 0.8757, 0.9097, 0.9243, 0.9565,
        0.9737])Target mean: 0.7618498802185059
Entropy-based Noise Metric: 2.921490430831909


In [27]:
logits = default_model_output.full_outputs.logits
predicted_token_id = logits[:, -1, :].argmax(dim=-1)

# Convert the predicted token ID to a token
predicted_token = tokenizer.convert_ids_to_tokens(predicted_token_id)
predicted_token

['▁Bob']

In [28]:
def generate_sequence_from_tokens(tokenizer, input_tokens, max_length=50, num_return_sequences=1):
    inputs = tokenizer.encode(input_tokens, return_tensors='pt')
    generated_sequences = model.generate(inputs, max_length=max_length, num_return_sequences=num_return_sequences)
    generated_text = [tokenizer.decode(generated_sequence, skip_special_tokens=True) for generated_sequence in
                      generated_sequences]

    # Print the generated text
    for text in generated_text:
        print(text)

In [29]:
generate_sequence_from_tokens(tokenizer, comparison_pair.default_tokens)

Bob sent Alice a message about apples. Bob sent Alice a message about apples, but it was not the same message as the one he sent her. What happened?

Bob sent Alice a message about apples, but it was not the same message


In [30]:
generate_sequence_from_tokens(tokenizer, comparison_pair.sub_word_tokens)

Bob sent Alice a message about apples.

Sure, here's the message about apples:

"Apples are a delicious fruit that is enjoyed by people of all ages. They are a good


In [9]:
def process_batch(batch, tokenizer=tokenizer, tokenizer_config: str = "tokenizer.json", threshold: float = 0.01):
    batch_result = torch.empty(len(batch), 2, 2)
    for idx, input_sequence in enumerate(batch):
        comparison_pair = tokenize_custom(tokenizer, tokenizer_config, input_sequence)
        
        default_model_attention = compute_attention_pipeline(tokenizer, comparison_pair.default_tokens).attention
        sub_word_model_attention = compute_attention_pipeline(tokenizer, comparison_pair.sub_word_tokens).attention
        
        default_threshold_metric = threshold_noise_metric(default_model_attention, 0.01)
        default_entropy_metric = entropy_based_noise_metric(default_model_attention)
        batch_result[idx][0][0] = default_threshold_metric.source_noise_percentage_over_tokens
        batch_result[idx][0][1] = default_entropy_metric
        
        sub_word_threshold_metric = threshold_noise_metric(sub_word_model_attention, 0.01)
        sub_word_entropy_metric = entropy_based_noise_metric(sub_word_model_attention)
        batch_result[idx][1][0] = sub_word_threshold_metric.source_noise_percentage_over_tokens
        batch_result[idx][1][1] = sub_word_entropy_metric
        
    return batch_result


In [10]:
input_sequences = ["Bob sent Alice a message about apples.", "Cat didn't cross the street because it was tired."]

In [11]:
batch_result = process_batch(input_sequences, tokenizer, "tokenizer.json", 0.02)

In [12]:
batch_result

tensor([[[0.5814, 2.1523],
         [0.6891, 2.9215]],

        [[0.6385, 2.5323],
         [0.7214, 3.2424]]])