In [11]:
import random
from typing import Tuple
import os
os.environ['HF_HOME'] = '/mnt/storage/hf_cache'


EXAMPLE_DESCRIPTION = """For example if the question is "
What is 13 % 3? Is it
A: 1
B: 0
C: 2
?"
you should output the string A, because 13 = 3 * 4 + 1, so the remainder is 1.
"""


PROMPT_TEMPLATE = """
Your goal is to answer a multiple-choice arithmetic question. It will be in a format
"What is <number> % 3? Is it
A: <digit1>
B: <digit2>
C: <digit3>
?"
where <digit1>, <digit2>, and <digit3> are digits 0 through 2.

You are expected to output only the uppercaseletter of the correct answer: A, B, C. Output nothing else. {EXAMPLE_DESCRIPTION}

With that in mind, here is the question:

--------------------------------
What is {NUMBER} % {MODULO}? Is it
A: {ANSWER_A}
B: {ANSWER_B}
C: {ANSWER_C}
?
"""

def generate_question(num_digits: int, rng: random.Random) -> Tuple[str, int]:
    number = rng.randint(10**(num_digits - 1), 10**num_digits - 1)
    answers = [0, 1, 2]
    rng.shuffle(answers)
    
    correct_answer_index = answers.index(number % 3)
    
    return PROMPT_TEMPLATE.format(NUMBER=number, MODULO=3, ANSWER_A=answers[0], ANSWER_B=answers[1], ANSWER_C=answers[2], EXAMPLE_DESCRIPTION=EXAMPLE_DESCRIPTION), correct_answer_index


In [2]:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, Phi3ForCausalLM, LlamaForCausalLM
from transformers.pipelines.text_generation import Chat

torch.set_float32_matmul_precision('high')
torch.random.manual_seed(0)

model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct", 
    # "microsoft/Phi-3.5-mini-instruct", 
    device_map="cuda", 
    # torch_dtype="float32", 
    torch_dtype="auto",
)
model.eval()
# model = torch.compile(model)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
# tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    batch_size=8,
)



test_rng = random.Random(42)
examples = [generate_question(1, test_rng) for _ in range(1000)]

# examples[:3]

example_dialogs = [[
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": prompt}
] for prompt, _ in examples]

In [4]:
# examples[5][0]

In [12]:
from tqdm import tqdm
import transformers.tokenization_utils_base

def preprocess_manual(messages: list[list[dict]]) -> transformers.tokenization_utils_base.BatchEncoding:
    one_by_one = [pipe.preprocess(Chat(message), return_tensors="pt") for message in messages]
    # stack
    return transformers.tokenization_utils_base.BatchEncoding(
        {k: torch.cat([t[k] for t in one_by_one], dim=0) for k in one_by_one[0].keys() if k != "prompt_text" }
    )
    

@torch.no_grad()
def batched_eval(fn, *args, batch_size=8, use_tqdm=True, **kwargs):
    assert all(isinstance(arg, torch.Tensor) for arg in args), f"All arguments must be tensors, got {args}"
    assert all(isinstance(arg, torch.Tensor) for arg in kwargs.values()), f"All kwargs must be tensors, got {kwargs}"
    
    bs1 = args[0].shape[0] if len(args) > 0 else list(kwargs.values())[0].shape[0]
    assert all(arg.shape[0] == bs1 for arg in args), f"All arguments must have the same batch size, got {args}"
    assert all(arg.shape[0] == bs1 for arg in kwargs.values()), f"All kwargs must have the same batch size, got {kwargs}"
    res = None
    iter = range(0, bs1, batch_size) if not use_tqdm else tqdm(range(0, bs1, batch_size))
    for i in iter:
        upto = min(i + batch_size, bs1)
        partial_res = fn(*[arg[i:upto] for arg in args], **{k: v[i:upto] for k, v in kwargs.items()})
        if res is None:
            res = torch.empty((bs1, *partial_res.shape[1:]), device=partial_res.device, dtype=partial_res.dtype)
        res[i:upto] = partial_res
    return res

# torch.vmap(model, chunk_size=2)(**(inputs[:8]))
    
# inputs[:8]
    
with torch.no_grad():
    inputs = preprocess_manual(example_dialogs).to('cuda')
    tokens_of_interest = tokenizer(['A', 'B', 'C'], return_tensors="pt", add_special_tokens=False)['input_ids'].to('cuda')
    if tokens_of_interest.shape != (3, 1):
        raise ValueError(f"Expected tokens_of_interest to have shape (3, 1), got {tokens_of_interest.shape}")
    tokens_of_interest = tokens_of_interest[:, 0]
    
    def inner(*args, **kwargs):
        # with torch.amp.autocast('cuda'):
        return model.forward(*args, **kwargs).logits[:, -1, tokens_of_interest]

    res = batched_eval(inner, **inputs[:100], batch_size=16)


# tokenizer.batch_decode(model.generate(**preprocess_manual(example_dialogs[:10]).to('cuda'), max_new_tokens=50))



100%|██████████| 7/7 [00:02<00:00,  2.37it/s]


In [13]:
res

tensor([[30.5000, 29.2500, 29.2500],
        [30.2500, 28.7500, 29.5000],
        [30.5000, 30.1250, 26.1250],
        [30.8750, 29.7500, 28.6250],
        [30.1250, 28.5000, 29.6250],
        [31.2500, 30.2500, 26.5000],
        [31.2500, 30.0000, 29.2500],
        [31.2500, 30.1250, 27.7500],
        [30.6250, 29.2500, 28.8750],
        [31.3750, 31.0000, 27.1250],
        [30.8750, 30.6250, 28.0000],
        [30.7500, 30.2500, 28.5000],
        [31.6250, 30.5000, 27.2500],
        [31.3750, 29.7500, 26.7500],
        [31.2500, 30.0000, 29.2500],
        [31.0000, 30.5000, 27.2500],
        [31.1250, 30.7500, 25.6250],
        [30.1250, 28.5000, 29.6250],
        [31.0000, 29.5000, 28.8750],
        [31.0000, 30.2500, 28.3750],
        [31.5000, 30.5000, 27.3750],
        [31.6250, 28.5000, 26.7500],
        [31.3750, 30.2500, 27.3750],
        [30.5000, 30.1250, 26.1250],
        [31.2500, 30.2500, 26.5000],
        [31.6250, 28.5000, 26.7500],
        [31.5000, 30.3750, 26.5000],
 

In [14]:
(torch.tensor([v for _, v in examples[:len(res)]]) == res.argmax(dim=-1).to('cpu')).float().mean()



tensor(0.3300)

In [18]:

logits_of_answers = res.logits[:, -1, tokens_of_interest]
logits_of_answers



tensor([[45.2500, 52.5000, 46.5000],
        [54.5000, 50.2500, 46.0000],
        [44.5000, 52.5000, 48.0000],
        [47.7500, 54.0000, 54.7500],
        [51.0000, 40.7500, 42.2500],
        [44.5000, 52.2500, 47.2500],
        [53.7500, 47.0000, 56.0000],
        [49.7500, 52.7500, 41.7500],
        [34.2500, 59.7500, 48.5000],
        [46.5000, 51.0000, 48.2500]], device='cuda:0')

In [11]:
# res.logits[:, -1:, :].argmax(dim=-1),\
#     tokenizer(['A', 'B', 'C'])

# tokenizer.batch_decode(res.logits[:, -1:, :].argmax(dim=-1))

(tensor([[350],
         [319],
         [350],
         [315],
         [319],
         [350],
         [315],
         [350],
         [350],
         [350]], device='cuda:0'),
 {'input_ids': [[319], [350], [315]], 'attention_mask': [[1], [1], [1]]})