In [1]:
import openai
import dotenv
import os

dotenv.load_dotenv()

client = openai.OpenAI(
  api_key=os.getenv("HYPERBOLIC_API_KEY"),
  base_url="https://api.hyperbolic.xyz/v1/",
)

In [2]:
user_content = "For a 3 sets' tennis game, would you bet on it finishing in 2 sets or 3 sets, assuming each player has an equal probability of winning a set?"
user_content = "Is 1027 a prime number?"

chat_completion = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        # {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ],
    temperature=0.7,
    max_tokens=1024,
    logprobs=True,
    top_logprobs=20,
)


In [3]:
import pprint

pprint.pprint(chat_completion.choices[0].message.content)

('To check if 1027 is a prime number, we can try dividing it by prime numbers '
 'less than or equal to its square root.\n'
 '\n'
 'The square root of 1027 is approximately 32.15.\n'
 '\n'
 'So, we will check for divisors up to 32.\n'
 '\n'
 '1027 is not divisible by 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, or 31.\n'
 '\n'
 'However, 1027 is divisible by 17 and 60.\n'
 '\n'
 '17 * 60 = 1020\n'
 '\n'
 'But 1027 is 7 more than 1020.\n'
 '\n'
 'So the factors of 1027 are 17 and 60 + 7 = 67.\n'
 '\n'
 'Thus 1027 is 17 * 67 and is not a prime number.')


In [4]:
from pathlib import Path
import torch

from entropix.config import LLAMA_1B_PARAMS
from entropix.tokenizer import Tokenizer
from entropix.torch_kvcache import KVCache
from entropix.torch_main import precompute_freqs_cis
from entropix.torch_weights import load_weights
from entropix.torch_model import xfmr

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with torch.inference_mode():
    model_params = LLAMA_1B_PARAMS
    xfmr_weights = load_weights(ckpt_dir=Path("weights/1B-Instruct"), should_compare_outputs=True)

    tokenizer = Tokenizer('entropix/tokenizer.model')
    bsz = 1
    kvcache = KVCache.new(model_params.n_layers, bsz, model_params.max_seq_len, model_params.n_local_kv_heads, model_params.head_dim).to(DEVICE)


In [5]:
prefill_str = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{user_content}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
prefill_tokens = tokenizer.encode(prefill_str, bos=False, eos=False, allowed_special='all')

result_str = f"""{chat_completion.choices[0].message.content}<|eot_id|>"""
result_tokens = tokenizer.encode(result_str, bos=False, eos=False, allowed_special='all')

joined_str = prefill_str + "\n" + result_str

joined_tokens = tokenizer.encode(joined_str,  bos=False, eos=False, allowed_special='all')

In [6]:
# single pass

from entropix.torch_main import build_attn_mask, precompute_freqs_cis
from entropix.torch_model import xfmr

tokens = torch.tensor([joined_tokens], dtype=torch.long).to(DEVICE)
seqlen = tokens.size(1)

cur_pos = 0
freqs_cis = precompute_freqs_cis(model_params.head_dim, model_params.max_seq_len, model_params.rope_theta, model_params.use_scaled_rope) 
attn_mask = build_attn_mask(seqlen, cur_pos)
bsz, seqlen = tokens.shape

with torch.inference_mode():
    logits, kvcache, scores, _ = xfmr(xfmr_weights, model_params, tokens, cur_pos, freqs_cis[:seqlen], kvcache, attn_mask=attn_mask)


In [13]:
from entropix.torch_sampler import calculate_metrics


for i in range(len(prefill_tokens), len(joined_tokens) - 1):
    ground_truth_token_str = tokenizer.decode([joined_tokens[i + 1]]).encode("unicode_escape").decode("utf-8")
    
    # compute statistics for each token
    mx = calculate_metrics(logits[:, i, :], scores[:, :, :i, :i], len(prefill_tokens) + i)
    mx_clean = {k: v.item() for k, v in mx.items()}

    mx_results = {}

    # what is the probability of the ground truth token, in the model's output?
    prob = torch.nn.functional.softmax(logits[:, i, :], dim=-1)
    prob_gt = prob[0, joined_tokens[i+1]].item()
    mx_results["prob_gt"] = prob_gt

    # what is the rank of the ground truth token?
    _, topk = torch.topk(prob, k=10000, dim=-1)
    topk = topk[0].tolist()
    try:
        mx_results["rank_gt"] = topk.index(joined_tokens[i+1])
    except ValueError:
        mx_results["rank_gt"] = -1

    # what was the top greedy token?
    mx_results["top1"] = tokenizer.decode([topk[0]]).encode("unicode_escape").decode("utf-8")

    outcome = ""
    if mx_results["rank_gt"] == 0:
        # this, in practice, might just be adaptive sample
        outcome = "greedy"
    elif mx_results["rank_gt"] < 10:
        if (mx_clean["logits_entropy"] ** 2) > mx_clean["logits_varentropy"]:
            outcome = "clarify"
        else:
            outcome = "fork"
    else:
        outcome = "resample"


    print(f"{ground_truth_token_str}\t {mx_results}\t{mx_clean}")


To	 {'prob_gt': 0.2294921875, 'rank_gt': 1, 'top1': 'No'}	{'logits_entropy': 2.203125, 'logits_varentropy': 5.4375, 'attn_entropy': 1.7265625, 'attn_varentropy': 0.37109375, 'agreement': 0.0272216796875, 'interaction_strength': 2.9375}
 check	 {'prob_gt': 0.05908203125, 'rank_gt': 1, 'top1': ' determine'}	{'logits_entropy': 0.48828125, 'logits_varentropy': 2.0625, 'attn_entropy': 1.7734375, 'attn_varentropy': 0.3984375, 'agreement': 0.026611328125, 'interaction_strength': 2.9375}
 if	 {'prob_gt': 0.984375, 'rank_gt': 0, 'top1': ' if'}	{'logits_entropy': 0.11767578125, 'logits_varentropy': 0.66796875, 'attn_entropy': 1.78125, 'attn_varentropy': 0.412109375, 'agreement': 0.0252685546875, 'interaction_strength': 3.0}
 	 {'prob_gt': 0.87890625, 'rank_gt': 0, 'top1': ' '}	{'logits_entropy': 0.53125, 'logits_varentropy': 0.921875, 'attn_entropy': 1.8046875, 'attn_varentropy': 0.421875, 'agreement': 0.0240478515625, 'interaction_strength': 3.046875}
102	 {'prob_gt': 1.0, 'rank_gt': 0, 'top1':