In [5]:
import copy
import os
import pickle

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, logging

# Restrict transformers' own logger to error-level messages for the rest of the notebook.
logging.set_verbosity_error()

In [6]:
# Identifier handed to `from_pretrained` for both the model and the tokenizer.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
# Prompt that `inference` tokenizes and continues.
input_text = "The quick brown fox"

def inference(input_text, model_name=model_name, max_length=10, seed=42):
    """Sample a continuation of ``input_text``, then sample one extra "pivot" token.

    Two generation passes are run:

    1. The prompt is continued with sampling, bounded by ``max_length``, and the
       key/value cache returned by ``generate`` for that pass is kept.
    2. After reseeding and reloading the model, exactly one more token (the
       pivot token) is sampled from the pass-1 sequence using that cache.

    Args:
        input_text: Prompt string to tokenize and continue.
        model_name: Identifier passed to ``from_pretrained`` for the model and
            the tokenizer.
        max_length: Forwarded as ``max_length`` to the first ``generate`` call.
            Defaults to 10, the value that used to be hard-coded.
        seed: Passed to ``set_seed`` before each model load. Defaults to 42,
            the value that used to be hard-coded.

    Returns:
        A 5-tuple ``(output_text, past_output_tokens, past_key_values,
        pivot_token, pivot_token_logits)`` where

        * ``output_text`` is the decoded pass-2 sequence (special tokens skipped),
        * ``past_output_tokens`` is ``sequences`` from pass 1,
        * ``past_key_values`` is the cache returned by pass 1,
        * ``pivot_token`` is the last token of the first pass-2 sequence,
        * ``pivot_token_logits`` is the last entry of the pass-2 ``scores``.
    """
    set_seed(seed)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    input_tokens = tokenizer.encode(input_text, return_tensors='pt')

    # Sampling options shared by both passes. Hidden states are not requested:
    # nothing in this function reads them, so collecting them only costs memory.
    sampling_kwargs = dict(
        do_sample=True,
        temperature=1.0,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    outputs = model.generate(input_tokens, max_length=max_length, **sampling_kwargs)

    # Keep the cache and tokens as they stand right before the pivot token is generated.
    past_key_values = outputs.past_key_values
    past_output_tokens = outputs.sequences

    # Release the first model before loading again; otherwise both copies are
    # alive while `from_pretrained` runs.
    del model

    # NOTE(review): the reload is kept on purpose. Whether `from_pretrained`
    # draws from the RNG depends on the transformers version, and `validation`
    # samples after the same set_seed -> load -> generate sequence. Confirm
    # before dropping it.
    set_seed(seed)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()
    outputs = model.generate(
        past_output_tokens,
        # One position beyond the existing sequence: sample a single new token.
        max_length=past_output_tokens.shape[1] + 1,
        use_cache=True,
        # `generate` may extend a cache object in place (newer transformers
        # releases return Cache objects rather than tuples -- verify against the
        # installed version). Sampling from a copy keeps the cache returned to
        # the caller at its pre-pivot state.
        past_key_values=copy.deepcopy(past_key_values),
        **sampling_kwargs,
    )

    output_tokens = outputs.sequences
    output_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)

    pivot_token = output_tokens[0][-1]
    pivot_token_logits = outputs.scores[-1]

    return output_text, past_output_tokens, past_key_values, pivot_token, pivot_token_logits


# Run the two-pass generation on the prompt and keep everything `inference` returns:
# decoded text, pass-1 token sequence, pass-1 cache, pivot token, and pivot scores.
output_text_a, output_tokens, past_key_values, pivot_token, pivot_token_logits = inference(input_text)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
# Report the generation results. Labels are padded with str.format rather than
# quoted inside f-string replacement fields: a backslash inside `{...}` is a
# SyntaxError on Python versions before 3.12. The printed text is unchanged.
print("{:<15} {}".format("Alice's Prompt:", input_text))
print("{:<15} {}".format("Bob's Response:", output_text_a))
print(f"Bob's Pivot Token: {pivot_token}")
# Only strictly positive score entries are shown.
print(f"Bob's Pivot Token Logits: {pivot_token_logits[pivot_token_logits > 0].numpy()}")

# Serialize the cache to disk to measure how large it is.
# NOTE: pickle executes arbitrary code on load; only unpickle this file from a trusted source.
with open("past.pkl", "wb") as file:
    pickle.dump(past_key_values, file)
size = os.path.getsize("past.pkl")
print(f"Past size (KB): {size // 1024}")

Alice's Prompt: The quick brown fox
Bob's Response: The quick brown fox jumps over the lazy dog.

Bob's Pivot Token: 627
Bob's Pivot Token Logits: [15.159561 17.353573 16.755339 15.155735 17.361267 14.286188]
Past size (KB): 2323


In [8]:
def validation(past_output_tokens, past_key_values, pivot_token, pivot_token_logits, model_name=model_name, seed=42):
    """Re-sample one token from a stored sequence and cache and compare it to a reference.

    The model is seeded, loaded, and asked to sample exactly one token beyond
    ``past_output_tokens`` using ``past_key_values``. The sampled token and its
    scores are then compared with ``pivot_token`` and ``pivot_token_logits``.

    Args:
        past_output_tokens: Token sequence to extend by one token.
        past_key_values: Cache passed to ``generate``. A copy is used, so the
            caller's object is left untouched.
        pivot_token: Reference token to compare the sampled token against.
        pivot_token_logits: Reference scores to compare the sampled scores against.
        model_name: Identifier passed to ``from_pretrained`` for the model and
            the tokenizer.
        seed: Passed to ``set_seed`` before the model is loaded. Defaults to 42,
            the value that used to be hard-coded.

    Returns:
        A 5-tuple ``(output_text, tokens_match, logits_match, token, logits)``:
        the decoded extended sequence, the ``==`` comparison of the reference
        token with the sampled one, the ``torch.allclose`` result (``atol=1e-6``)
        for the scores, the sampled token, and the sampled scores.
    """
    set_seed(seed)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    outputs = model.generate(
        past_output_tokens,
        # One position beyond the existing sequence: sample a single new token.
        max_length=past_output_tokens.shape[1] + 1,
        do_sample=True,
        temperature=1.0,
        output_scores=True,
        return_dict_in_generate=True,
        # Hidden states are not requested: nothing below reads them.
        pad_token_id=tokenizer.eos_token_id,
        use_cache=True,
        # `generate` may extend a cache object in place (newer transformers
        # releases use Cache objects rather than tuples -- verify against the
        # installed version). Work on a copy so the argument is not mutated
        # and the check can be repeated with the same inputs.
        past_key_values=copy.deepcopy(past_key_values),
    )

    output_tokens = outputs.sequences
    output_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)

    sampled_token = output_tokens[0][-1]
    sampled_logits = outputs.scores[-1]

    tokens_match = pivot_token == sampled_token
    logits_match = torch.allclose(pivot_token_logits, sampled_logits, atol=1e-6)

    return output_text, tokens_match, logits_match, sampled_token, sampled_logits


# Re-sample the pivot token from the stored tokens and cache, and compare it with
# the token and scores that `inference` returned.
output_text_b, equal_tokens, equal_logits, token, logits = validation(output_tokens, past_key_values, pivot_token, pivot_token_logits)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [11]:
# Side-by-side comparison of the original pivot token and the re-sampled one.
# Labels are padded with str.format rather than quoted inside f-string
# replacement fields: a backslash inside `{...}` is a SyntaxError on Python
# versions before 3.12. Also fixes the "Alices's" typo and the doubled colon
# after the pivot-token labels.
print("{:<15} {}".format("Alice's Prompt:", input_text))
print("{:<15} {}".format("Bob's Response:", output_text_a))
print("{:<25} {}".format("Bob's Pivot Token:", pivot_token))
print("{:<25} {}".format("Alice's Pivot Token:", token))

# Only strictly positive score entries are shown.
print("{:<30} {}".format("Bob's Pivot Token Logits:", pivot_token_logits[pivot_token_logits > 0].numpy()))
print("{:<30} {}".format("Alice's Pivot Token Logits:", logits[logits > 0].numpy()))

# Surface the comparison results computed by `validation`, which were otherwise never displayed.
print("{:<30} {}".format("Tokens match:", bool(equal_tokens)))
print("{:<30} {}".format("Logits match:", equal_logits))

Alice's Prompt: The quick brown fox
Bob's Response: The quick brown fox jumps over the lazy dog.

Bob's Pivot Token:       : 627
Alices's Pivot Token:    : 627
Bob's Pivot Token Logits:      [15.159561 17.353573 16.755339 15.155735 17.361267 14.286188]
Alices's Pivot Token Logits:   [15.159561 17.353573 16.755339 15.155735 17.361267 14.286188]
