# Using attention layer outputs for prompt engineering

In [1]:
# color print function
from colored import Back, Style, Fore
from colorsys import hsv_to_rgb

def print_hsv(text, h, s, v):
    r, g, b = hsv_to_rgb(h / 360, s / 100, v / 100)
    r = round(r * 256)
    g = round(g * 256)
    b = round(b * 256)
    print(f'{Fore.black}{Back.rgb(r, g, b)}{text}{Style.reset}', end="")

## Load model

In [2]:
import torch
torch.set_default_device("mps")

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "./models/distilgpt2"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.float16,
    output_attentions=True
)
model.eval()
model.zero_grad()

In [4]:
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    device_map="auto"
)

In [186]:
prompt = """\
<|im_start|>system
Answer the questions using the given context.<|im_end|>
<|im_start|>user
Question 1: What day is it today?
Context 1: The current date is October 23rd, 2023<|im_end|>
<|im_start|>assistant
Answer 1: Today is October 23rd, 2023<|im_end|>
<|im_start|>user
Question 2: {question}
Context 2: {context}<|im_end|>
<|im_start|>assistant
Answer 2:"""

In [187]:
fprompt = prompt.format(question="Who is taller, Rick or Morty?", context="Rick and Morty are family. Rick has a greater height than Morty.")
# fprompt = "I am"
inputs = tokenizer(
    fprompt,
    return_tensors="pt"
)
output_tokens = model.generate(
    **inputs,
    temperature=0,
    max_new_tokens=3
)[0]
output = model(**inputs)
print("input tokens:", len(inputs["input_ids"][0]))
print("attention layers:", len(output["attentions"]))
print("??? (maybe just so it is a column vector?):", len(output["attentions"][0]))
print("attention heads:", len(output["attentions"][0][0]))
print("attention for each token, for each previous token:", len(output["attentions"][0][0][0]))
print("attention for last token, for all previous tokens:", len(output["attentions"][0][0][0][-1]))
print("\n--- output ---")
print(tokenizer.decode(output_tokens))
print("\n--- last token attentions ---")
print("attention:", list(zip(output["attentions"][0][0][0][-1].tolist(), [tokenizer.decode(t) for t in inputs["input_ids"][0]])))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


input tokens: 151
attention layers: 6
??? (maybe just so it is a column vector?): 1
attention heads: 12
attention for each token, for each previous token: 151
attention for last token, for all previous tokens: 151

--- output ---
<|im_start|>system
Answer the questions using the given context.<|im_end|>
<|im_start|>user
Question 1: What day is it today?
Context 1: The current date is October 23rd, 2023<|im_end|>
<|im_start|>assistant
Answer 1: Today is October 23rd, 2023<|im_end|>
<|im_start|>user
Question 2: Who is taller, Rick or Morty?
Context 2: Rick and Morty are family. Rick has a greater height than Morty.<|im_end|>
<|im_start|>assistant
Answer 2: Today is October

--- last token attentions ---
attention: [(0.001575469970703125, '<'), (0.0004558563232421875, '|'), (0.0006155967712402344, 'im'), (0.00041675567626953125, '_'), (0.0014715194702148438, 'start'), (0.00042748451232910156, '|'), (0.00331878662109375, '>'), (0.0008907318115234375, 'system'), (0.001987457275390625, '\n')

In [172]:
total_input_tokens = len(tokenizer.encode(fprompt, add_special_tokens=True))

attentions = []
input_prompt = tokenizer.encode(fprompt, add_special_tokens=False)
for i in range(7):
    inputs = tokenizer(tokenizer.decode(input_prompt), return_tensors="pt")
    output = model(**inputs)
    output_token = output["logits"][0][-1].argmax()
    attentions.append((output["attentions"], tokenizer.decode(output_token)))
    input_prompt.append(int(output_token))

print("\n" + tokenizer.decode(input_prompt, skip_special_tokens=False))


<|im_start|>system
Answer the questions using the given context.<|im_end|>
<|im_start|>user
Question 1: What day is it today?
Context 1: The current date is October 23rd, 2023<|im_end|>
<|im_start|>assistant
Answer 1: Today is October 23rd, 2023<|im_end|>
<|im_start|>user
Question 2: Who is taller, Rick or Morty?
Context 2: Rick and Morty are family. Rick has a greater height than Morty.<|im_end|>
<|im_start|>assistant
Answer 2: Today is October 23rd, 20


In [173]:
# get attention for each token from each head
output_tokens = input_prompt
total_output_tokens = len(output_tokens)

total_generated_tokens = total_output_tokens - total_input_tokens
tokens_attentions = []
for output, token in attentions:
    first_attention_layer_heads = output[0][0]
    token_attentions = torch.zeros(total_output_tokens)

    # calculate average attention for each token over all heads
    for i in range(len(first_attention_layer_heads[0])):
    # for i in range(total_input_tokens):
        for head_attentions in first_attention_layer_heads:
            final_attention = head_attentions[-1]
            token_attentions[i] += final_attention[i]

    total_attention_heads = len(first_attention_layer_heads)
    token_attentions /= total_attention_heads
    tokens_attentions.append(token_attentions)

In [183]:
import pandas as pd

df = pd.DataFrame(
    columns=["token_id", "token"] + [f"att_p_{i}" for i in range(len(tokens_attentions))],
    data=list(zip(
        output_tokens,
        [tokenizer.decode(t) for t in output_tokens],
        *[
            (tokens_attentions[i] / tokens_attentions[i].max() * 100).round().int().tolist()
            for i in range(len(tokens_attentions))
        ]
    )),
)
pd.set_option('display.max_rows', 20)
df

Unnamed: 0,token_id,token,att_p_0,att_p_1,att_p_2,att_p_3,att_p_4,att_p_5,att_p_6
0,27,<,5,3,3,2,2,3,3
1,91,|,2,2,1,1,1,2,1
2,320,im,1,1,1,1,1,1,1
3,62,_,2,1,1,1,1,1,1
4,9688,start,2,1,1,1,1,1,1
...,...,...,...,...,...,...,...,...,...
153,3267,October,0,0,0,100,27,7,14
154,2242,23,0,0,0,0,100,100,21
155,4372,rd,0,0,0,0,0,90,10
156,11,",",0,0,0,0,0,0,100


In [176]:
# display attentions for each generated token
for i, token_attentions in enumerate(tokens_attentions):
    print("\n--- attentions ---")
    h = 114
    v = 80
    tokens_s = token_attentions / token_attentions.max() * 100
    pow = 1
    tokens_s = tokens_s.pow(pow) / (100 ** pow) * 100

    printed_tokens = []
    stop_at = total_input_tokens + i
    for j, (token, attention) in enumerate(zip(output_tokens, tokens_s)):
        new_tokens = printed_tokens + [token]
        text = tokenizer.decode(new_tokens)[len(tokenizer.decode(printed_tokens)):]
        s = attention.item()
        if j == stop_at:
            print_hsv(text, 14, 100, v)
            break
        print_hsv(text, h, s, v)
        printed_tokens = new_tokens


--- attentions ---
[38;5;0m[48;2;196;205;195m<[0m[38;5;0m[48;2;200;205;200m|[0m[38;5;0m[48;2;202;205;202mim[0m[38;5;0m[48;2;202;205;201m_[0m[38;5;0m[48;2;202;205;202mstart[0m[38;5;0m[48;2;200;205;200m|[0m[38;5;0m[48;2;198;205;197m>[0m[38;5;0m[48;2;199;205;199msystem[0m[38;5;0m[48;2;197;205;196m
[0m[38;5;0m[48;2;192;205;191mAnswer[0m[38;5;0m[48;2;199;205;198m the[0m[38;5;0m[48;2;198;205;197m questions[0m[38;5;0m[48;2;196;205;196m using[0m[38;5;0m[48;2;198;205;198m the[0m[38;5;0m[48;2;197;205;196m given[0m[38;5;0m[48;2;196;205;195m context[0m[38;5;0m[48;2;194;205;192m.<[0m[38;5;0m[48;2;197;205;196m|[0m[38;5;0m[48;2;199;205;199mim[0m[38;5;0m[48;2;197;205;196m_[0m[38;5;0m[48;2;199;205;198mend[0m[38;5;0m[48;2;198;205;197m|[0m[38;5;0m[48;2;195;205;194m>[0m[38;5;0m[48;2;197;205;196m
[0m[38;5;0m[48;2;197;205;196m<[0m[38;5;0m[48;2;199;205;198m|[0m[38;5;0m[48;2;200;205;200mim[0m[38;5;0m[48;2;199;205;198m_[0m[38

In [182]:
avg_tokens_attentions = torch.stack(tokens_attentions).sum(axis=0)
# each of the input tokens is got paid attention to `g` times
# each of the generated tokens got paid attention to `g-i` times, where `i` is the ith generated token
# the final token got paid attention to 0 times (and its attention is 0), we divide by 1 just so the dimensions line up
dividers = torch.tensor([total_generated_tokens] * total_input_tokens + list(reversed(range(1, total_generated_tokens))) + [1])
avg_tokens_attentions /= dividers

# calculate s value - token importance relative to all other tokens
tokens_s = avg_tokens_attentions / avg_tokens_attentions.max() * 100
pow = 0.4
tokens_s = tokens_s.pow(pow) / (100 ** pow) * 100

printed_tokens = []
for token, attention in zip(output_tokens, tokens_s):
    new_tokens = printed_tokens + [token]
    text = tokenizer.decode(new_tokens)[len(tokenizer.decode(printed_tokens)):]
    s = attention.item()
    print_hsv(text, h, s, v)
    printed_tokens = new_tokens

[38;5;0m[48;2;162;205;157m<[0m[38;5;0m[48;2;172;205;168m|[0m[38;5;0m[48;2;178;205;175mim[0m[38;5;0m[48;2;177;205;173m_[0m[38;5;0m[48;2;176;205;173mstart[0m[38;5;0m[48;2;172;205;169m|[0m[38;5;0m[48;2;167;205;163m>[0m[38;5;0m[48;2;169;205;165msystem[0m[38;5;0m[48;2;163;205;159m
[0m[38;5;0m[48;2;159;205;154mAnswer[0m[38;5;0m[48;2;164;205;159m the[0m[38;5;0m[48;2;163;205;159m questions[0m[38;5;0m[48;2;161;205;156m using[0m[38;5;0m[48;2;163;205;158m the[0m[38;5;0m[48;2;162;205;158m given[0m[38;5;0m[48;2;160;205;155m context[0m[38;5;0m[48;2;157;205;151m.<[0m[38;5;0m[48;2;163;205;158m|[0m[38;5;0m[48;2;167;205;163mim[0m[38;5;0m[48;2;163;205;158m_[0m[38;5;0m[48;2;167;205;163mend[0m[38;5;0m[48;2;164;205;160m|[0m[38;5;0m[48;2;161;205;156m>[0m[38;5;0m[48;2;163;205;158m
[0m[38;5;0m[48;2;163;205;158m<[0m[38;5;0m[48;2;166;205;162m|[0m[38;5;0m[48;2;170;205;167mim[0m[38;5;0m[48;2;167;205;163m_[0m[38;5;0m[48;2;171;205;

# Package

In [196]:
def prompt_engineer(model, tokenizer, max_new_tokens, prompt, stop_at = None):
    total_input_tokens = len(tokenizer.encode(prompt, add_special_tokens=True))
    attentions = []
    input_prompt = tokenizer.encode(prompt, add_special_tokens=False)
    for i in range(max_new_tokens):
        inputs = tokenizer(tokenizer.decode(input_prompt), return_tensors="pt")
        output = model(**inputs)
        output_token = output["logits"][0][-1].argmax()
        attentions.append((output["attentions"], tokenizer.decode(output_token)))
        input_prompt.append(int(output_token))

    # get attention for each token from each head
    output_tokens = input_prompt
    total_output_tokens = len(output_tokens)

    total_generated_tokens = total_output_tokens - total_input_tokens
    tokens_attentions = []
    for output, token in attentions:
        first_attention_layer_heads = output[0][0]
        token_attentions = torch.zeros(total_output_tokens)

        # calculate average attention for each token over all heads
        for i in range(len(first_attention_layer_heads[0])):
        # for i in range(total_input_tokens):
            for head_attentions in first_attention_layer_heads:
                final_attention = head_attentions[-1]
                token_attentions[i] += final_attention[i]

        total_attention_heads = len(first_attention_layer_heads)
        token_attentions /= total_attention_heads
        tokens_attentions.append(token_attentions)

    # display attentions for each generated token
    print("\n=== per token attentions ===")
    for i, token_attentions in enumerate(tokens_attentions):
        print("\n--- attentions ---")
        h = 114
        v = 80
        tokens_s = token_attentions / token_attentions.max() * 100
        pow = 1
        tokens_s = tokens_s.pow(pow) / (100 ** pow) * 100

        printed_tokens = []
        stop_at = total_input_tokens + i
        for j, (token, attention) in enumerate(zip(output_tokens, tokens_s)):
            new_tokens = printed_tokens + [token]
            text = tokenizer.decode(new_tokens)[len(tokenizer.decode(printed_tokens)):]
            s = attention.item()
            if j == stop_at:
                print_hsv(text, 14, 100, v)
                break
            print_hsv(text, h, s, v)
            printed_tokens = new_tokens

    print("\n=== avg attentions ===")
    avg_tokens_attentions = torch.stack(tokens_attentions).sum(axis=0)
    # each of the input tokens is got paid attention to `g` times
    # each of the generated tokens got paid attention to `g-i` times, where `i` is the ith generated token
    # the final token got paid attention to 0 times (and its attention is 0), we divide by 1 just so the dimensions line up
    dividers = torch.tensor([total_generated_tokens] * total_input_tokens + list(reversed(range(1, total_generated_tokens))) + [1])
    avg_tokens_attentions /= dividers

    # calculate s value - token importance relative to all other tokens
    tokens_s = avg_tokens_attentions / avg_tokens_attentions.max() * 100
    pow = 1
    tokens_s = tokens_s.pow(pow) / (100 ** pow) * 100

    printed_tokens = []
    for token, attention in zip(output_tokens, tokens_s):
        new_tokens = printed_tokens + [token]
        text = tokenizer.decode(new_tokens)[len(tokenizer.decode(printed_tokens)):]
        s = attention.item()
        print_hsv(text, h, s, v)
        printed_tokens = new_tokens

## Let's try to engineer

In [217]:
prompt = """\
<|im_start|>system
{system}<|im_end|>
<|im_start|>user
Question 1: What day is it today?
Context 1: The current date is October 23rd, 2023<|im_end|>
<|im_start|>assistant
Answer 1: Today is October 23rd, 2023<|im_end|>
<|im_start|>user
Question 2: {question}
Context 2: {context}<|im_end|>
<|im_start|>assistant
Answer 2:"""

In [218]:
from functools import partial
from transformers import pipeline
distilgpt2 = partial(prompt_engineer, model=model, tokenizer=tokenizer)
pipe = pipeline("text-generation", model="./models/distilgpt2", device_map="auto", do_sample=False)

In [219]:
print(pipe(prompt.format(
        system="Answer the questions using the given context.",
        context="My name is John Doe",
        question="What is my name?",
    ), return_full_text=False, max_new_tokens=10)[0]["generated_text"])
print(pipe(prompt.format(
        system="Answer the questions using the given context.",
        context="I live on Earth.",
        question="Which planet am I on?",
    ), return_full_text=False, max_new_tokens=10)[0]["generated_text"])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


 Today is October 23rd, 2023<|
 Today is October 23rd, 2023<|


In [220]:
distilgpt2(
    max_new_tokens=5,
    prompt=prompt.format(
        system="Answer the questions using the given context.",
        context="My name is John Doe", 
        question="What is my name?",
    ),
)


=== per token attentions ===

--- attentions ---
[38;5;0m[48;2;195;205;194m<[0m[38;5;0m[48;2;200;205;199m|[0m[38;5;0m[48;2;202;205;202mim[0m[38;5;0m[48;2;201;205;200m_[0m[38;5;0m[48;2;201;205;201mstart[0m[38;5;0m[48;2;199;205;199m|[0m[38;5;0m[48;2;197;205;196m>[0m[38;5;0m[48;2;198;205;198msystem[0m[38;5;0m[48;2;195;205;194m
[0m[38;5;0m[48;2;191;205;190mAnswer[0m[38;5;0m[48;2;198;205;197m the[0m[38;5;0m[48;2;197;205;196m questions[0m[38;5;0m[48;2;195;205;194m using[0m[38;5;0m[48;2;198;205;197m the[0m[38;5;0m[48;2;196;205;195m given[0m[38;5;0m[48;2;196;205;195m context[0m[38;5;0m[48;2;193;205;191m.<[0m[38;5;0m[48;2;197;205;196m|[0m[38;5;0m[48;2;199;205;198mim[0m[38;5;0m[48;2;197;205;196m_[0m[38;5;0m[48;2;199;205;198mend[0m[38;5;0m[48;2;198;205;197m|[0m[38;5;0m[48;2;195;205;194m>[0m[38;5;0m[48;2;197;205;196m
[0m[38;5;0m[48;2;197;205;196m<[0m[38;5;0m[48;2;199;205;198m|[0m[38;5;0m[48;2;200;205;200mim[0m[38;5;

We can see that the examples are actually distracting the model! Let's remove them from the prompt.

In [221]:
prompt = """\
<|im_start|>system
{system}<|im_end|>
<|im_start|>user
Question: {question}
Context: {context}<|im_end|>
<|im_start|>assistant
Answer:"""

In [223]:
distilgpt2(
    max_new_tokens=10,
    prompt=prompt.format(
        system="Answer the questions using the given context.",
        context="My name is John Doe", 
        question="What is my name?",
    ),
)


=== per token attentions ===

--- attentions ---
[38;5;0m[48;2;173;205;169m<[0m[38;5;0m[48;2;193;205;191m|[0m[38;5;0m[48;2;196;205;195mim[0m[38;5;0m[48;2;191;205;190m_[0m[38;5;0m[48;2;193;205;192mstart[0m[38;5;0m[48;2;193;205;192m|[0m[38;5;0m[48;2;189;205;188m>[0m[38;5;0m[48;2;194;205;192msystem[0m[38;5;0m[48;2;193;205;192m
[0m[38;5;0m[48;2;187;205;185mAnswer[0m[38;5;0m[48;2;197;205;196m the[0m[38;5;0m[48;2;196;205;195m questions[0m[38;5;0m[48;2;195;205;194m using[0m[38;5;0m[48;2;198;205;197m the[0m[38;5;0m[48;2;197;205;196m given[0m[38;5;0m[48;2;196;205;195m context[0m[38;5;0m[48;2;191;205;190m.<[0m[38;5;0m[48;2;198;205;197m|[0m[38;5;0m[48;2;200;205;200mim[0m[38;5;0m[48;2;198;205;197m_[0m[38;5;0m[48;2;200;205;200mend[0m[38;5;0m[48;2;199;205;198m|[0m[38;5;0m[48;2;195;205;194m>[0m[38;5;0m[48;2;197;205;196m
[0m[38;5;0m[48;2;194;205;193m<[0m[38;5;0m[48;2;199;205;198m|[0m[38;5;0m[48;2;201;205;200mim[0m[38;5;