In [1]:
!pip install sentencepiece protobuf transformers



In [2]:
import itertools
import os
import textwrap

import torch
from torch.utils import _pytree as pytree
from transformers import AutoTokenizer, AutoModelForCausalLM
# Never commit access tokens: read the Hugging Face token from the environment.
# When neither variable is set AUTH_TOKEN is None and is passed through as-is to
# the `from_pretrained` calls below.
# NOTE(review): a literal token was previously stored on this line; it should be
# revoked in the Hugging Face account settings since it has been shared.
AUTH_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the causal-LM weights for the checkpoint id below, requesting
# `torch.float` as the parameter dtype and forwarding AUTH_TOKEN for
# authentication against the model hub.
# NOTE(review): newer transformers releases reportedly deprecate
# `use_auth_token` in favour of `token` — confirm against the installed version.
mdl = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    torch_dtype=torch.float,
    use_auth_token=AUTH_TOKEN,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.12it/s]


In [4]:
# Load the tokenizer that belongs to the same checkpoint id as the model.
# `use_fast=False` asks for the non-"fast" tokenizer implementation
# (presumably the SentencePiece-backed one, which would explain the
# sentencepiece/protobuf install in the first cell — TODO confirm).
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    use_fast=False,
    use_auth_token=AUTH_TOKEN,
)



In [5]:
def summarize_results(results):
    """Print a compact shape summary of a model output.

    Prints the shape(s) found in ``results.logits`` followed by a
    run-length-encoded listing of the shapes of every leaf in
    ``results.past_key_values``: consecutive leaves with the same shape are
    collapsed into one line of the form ``<shape> * <count>``; a shape that
    occurs only once in a row is printed on its own line without a count.

    Args:
        results: object exposing ``logits`` and ``past_key_values``. Every
            pytree leaf of both attributes must have a ``.shape`` attribute.

    Returns:
        None. The summary is written to stdout.
    """
    past_key_values, _ = pytree.tree_flatten(results.past_key_values)
    print("Logits:", pytree.tree_map(lambda x: x.shape, results.logits))
    print(f"PKV (len={len(past_key_values)}):")
    # Compare shapes by their repr so the printed text and the grouping key match.
    shapes = (repr(leaf.shape) for leaf in past_key_values)
    for shape, run in itertools.groupby(shapes):
        count = len(list(run))
        if count > 1:
            print(" ", shape, f"* {count}")
        else:
            print(" ", shape)
    
    


In [6]:
# System prompt prepended to the user's question.
# NOTE(review): the recorded outputs of the decode cells show the model
# continuing with "</user>", i.e. the "<|USER|>" marker is being completed as
# text rather than treated as a turn delimiter. Llama-2 chat checkpoints are
# documented to expect the "[INST] <<SYS>> ... <</SYS>> ... [/INST]" template —
# TODO confirm and switch the prompt format.
prompt = (
    "System: You are a helpful, respectful and honest assistant. Always answer "
    "as helpfully as possible, while being safe.  Your answers should not "
    "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal "
    "content. Please ensure that your responses are socially unbiased and positive "
    "in nature. If a question does not make any sense, or is not factually coherent, "
    "explain why instead of answering something not correct. If you don't know the "
    "answer to a question, please don't share false information."
)
conversation = prompt + "<|USER|>Should Bugs Bunny have turned left at Albuquerque?"

# Tokenize the whole conversation into PyTorch tensors.
initial_input = tokenizer(conversation, return_tensors="pt")
print("Example input:", initial_input)
print("  Shape:", initial_input.input_ids.shape)

# Inference only: disable autograd so no graph is retained for the prompt pass,
# and call the module itself (not `.forward`) so registered hooks still run.
with torch.no_grad():
    initial_results = mdl(initial_input.input_ids)
summarize_results(initial_results)

Example input: {'input_ids': tensor([[    1,  2184, 29901,   887,   526,   263,  8444, 29892,  3390,  1319,
           322, 15993, 20255, 29889, 29849,  1234,   408,  1371,  3730,   408,
          1950, 29892,  1550,  1641,  9109, 29889, 29871,  3575,  6089,   881,
           451,  3160,   738, 10311,  1319, 29892,   443,   621,   936, 29892,
         11021,   391, 29892,  7916,   391, 29892,   304, 27375, 29892, 18215,
         29892,   470, 27302,  2793, 29889,  3529,  9801,   393,   596, 20890,
           526,  5374,   635,   443,  5365,  1463,   322,  6374,   297,  5469,
         29889,   960,   263,  1139,   947,   451,  1207,   738,  4060, 29892,
           470,   338,   451,  2114,  1474, 16165,   261,   296, 29892,  5649,
          2020,  2012,   310, 22862,  1554,   451,  1959, 29889,   960,   366,
          1016, 29915, 29873,  1073,   278,  1234,   304,   263,  1139, 29892,
          3113,  1016, 29915, 29873,  6232,  2089,  2472, 19423, 29989, 11889,
         29989, 29958, 

In [7]:
# Module-level accumulators filled by decode_token(..., store=True).
all_tokens, all_detoks = [], []


def decode_token(results, index=-1, store=True):
    """Pick the highest-scoring token at one sequence position and detokenize it.

    Args:
        results: object whose ``logits`` attribute is indexed as
            ``logits[:, index, :]``.
        index: position along the second logits axis to decode; defaults to
            the last position.
        store: when true, append the first element of the selected token
            tensor to ``all_tokens`` and the decoded text to ``all_detoks``.

    Returns:
        ``(token, detok)`` — the argmax result over the last axis of the
        selected slice, and the string produced by ``tokenizer.decode`` for it.
    """
    logits = results.logits
    print("Logits:", logits.shape)

    position_logits = logits[:, index, :]
    print("Logits reshaped:", position_logits.shape)

    # Greedy choice: index of the largest logit for each row of the slice.
    token = position_logits.argmax(dim=1)
    detok = tokenizer.decode(token, skip_special_tokens=False)
    print(f"--> Decoded: '{detok}' ({token})")

    if not store:
        return token, detok

    all_tokens.append(token[0])
    all_detoks.append(detok)
    return token, detok

# Decode the first generated token from the prompt pass. `index` is left at its
# default of -1, so the prediction at the last prompt position is used, and
# store=True records it in all_tokens / all_detoks.
# Debug alternative (disabled): decode the prediction at every prompt position.
# for i in range(initial_results.logits.shape[1]):
#     token, detok = decode_token(initial_results, index=i)
token, detok = decode_token(initial_results, store=True)

Logits: torch.Size([1, 136, 32000])
Logits reshaped: torch.Size([1, 32000])
--> Decoded: '</' (tensor([829]))


In [8]:
# Decode loop for subsequent tokens.
# Each step feeds only the most recently decoded token together with the
# key/value cache returned by the previous step, then greedily decodes the next
# token via decode_token (which also appends it to all_tokens / all_detoks).
MAX_NEW_TOKENS = 500  # upper bound on decode steps if EOS is never produced

current_results = initial_results
# Inference only: without no_grad every step would retain an autograd graph,
# growing memory across up to MAX_NEW_TOKENS forward passes.
with torch.no_grad():
    for _ in range(MAX_NEW_TOKENS):
        next_input_token = torch.reshape(token, [1, 1])
        print("Next input token:", next_input_token)
        # Call the module itself (not `.forward`) so registered hooks still run.
        step_results = mdl(next_input_token, past_key_values=current_results.past_key_values)
        summarize_results(step_results)
        token, detok = decode_token(step_results)
        # Stop at end-of-sequence; ask the tokenizer for the id instead of hardcoding 2.
        if token[0].item() == tokenizer.eos_token_id:
            break
        current_results = step_results

print("All tokens:", all_tokens)
print("All detoks:", all_detoks)

print(conversation)
print(tokenizer.decode(all_tokens))

Next input token: tensor([[829]])
Logits: torch.Size([1, 1, 32000])
PKV (len=64):
  torch.Size([1, 32, 137, 128]) * 64
Logits: torch.Size([1, 1, 32000])
Logits reshaped: torch.Size([1, 32000])
--> Decoded: 'user' (tensor([1792]))
Next input token: tensor([[1792]])
Logits: torch.Size([1, 1, 32000])
PKV (len=64):
  torch.Size([1, 32, 138, 128]) * 64
Logits: torch.Size([1, 1, 32000])
Logits reshaped: torch.Size([1, 32000])
--> Decoded: '>' (tensor([29958]))
Next input token: tensor([[29958]])
Logits: torch.Size([1, 1, 32000])
PKV (len=64):
  torch.Size([1, 32, 139, 128]) * 64
Logits: torch.Size([1, 1, 32000])
Logits reshaped: torch.Size([1, 32000])
--> Decoded: '' (tensor([29871]))
Next input token: tensor([[29871]])
Logits: torch.Size([1, 1, 32000])
PKV (len=64):
  torch.Size([1, 32, 140, 128]) * 64
Logits: torch.Size([1, 1, 32000])
Logits reshaped: torch.Size([1, 32000])
--> Decoded: 'I' (tensor([306]))
Next input token: tensor([[306]])
Logits: torch.Size([1, 1, 32000])
PKV (len=64):
  