In [1]:
import random
from typing import Tuple
import os
os.environ['HF_HOME'] = '/mnt/storage/hf_cache'


In [2]:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, Phi3ForCausalLM, LlamaForCausalLM
from transformers.pipelines.text_generation import Chat

torch.set_float32_matmul_precision('high')
torch.random.manual_seed(0)

MODEL_NAME = "NousResearch/Hermes-3-Llama-3.1-8B"

model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, 
    # "microsoft/Phi-3.5-mini-instruct", 
    device_map="cuda", 
    # torch_dtype="float32", 
    torch_dtype="auto",
)
model.eval()
# model = torch.compile(model, dynamic=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-mini-instruct")


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # batch_size=8,
)


EXAMPLE_DESCRIPTION = """For example if the question is "
What is 75 - 66? Is it
A: 0
B: 9
C: 3
?"
you should output the string B, because 475 - 466 = 9.
"""


PROMPT_TEMPLATE = """
Your goal is to answer a multiple-choice arithmetic question doing subtraction. It will be in a format
"What is <number> - <number>? Is it
A: <digit1>
B: <digit2>
C: <digit3>
?"
where <digit1>, <digit2>, and <digit3> are one digit numbers.

You are expected to output only the uppercase letter of the correct answer: A, B, C. Output nothing else. {EXAMPLE_DESCRIPTION}

Remember to not succumb to answer order bias. Don't automatically choose the first answer.
With that in mind, here is the question:

--------------------------------
What is {NUMBER} - {NUMBER_TO_SUBTRACT}? Is it
A: {ANSWER_A}
B: {ANSWER_B}
C: {ANSWER_C}
?
"""

def generate_question(num_digits: int, rng: random.Random) -> Tuple[str, int]:
    number = rng.randint(10**(num_digits - 1), 10**num_digits - 1)
    difference = rng.randint(0,9)
    number_to_subtract = number - difference
    
    answer_candidates = rng.sample([i for i in range(10) if i != number_to_subtract], 2) + [difference]
    rng.shuffle(answer_candidates)
    
    # answers = [0, 1, 2]
    # rng.shuffle(answers)
    
    correct_answer_index = answer_candidates.index(difference)
    
    return PROMPT_TEMPLATE.format(NUMBER=number, NUMBER_TO_SUBTRACT=number_to_subtract, ANSWER_A=answer_candidates[0], ANSWER_B=answer_candidates[1], ANSWER_C=answer_candidates[2],
                                #   EXAMPLE_DESCRIPTION=EXAMPLE_DESCRIPTION
                                  EXAMPLE_DESCRIPTION=""
                                  ), correct_answer_index



test_rng = random.Random(42)
examples = [generate_question(2, test_rng) for _ in range(1000)]

# examples[:3]

example_dialogs = [[
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": prompt},
] for prompt, _ in examples]

The model 'OptimizedModule' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'GraniteForCausalLM', 'GraniteMoeForCausalLM', 'JambaForCausalLM', 'JetMoeForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'Mamba2ForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM',

In [4]:
print(examples[7][0])
# print(examples[6][1])


Your goal is to answer a multiple-choice arithmetic question doing subtraction. It will be in a format
"What is <number> - <number>? Is it
A: <digit1>
B: <digit2>
C: <digit3>
?"
where <digit1>, <digit2>, and <digit3> are one digit numbers.

You are expected to output only the uppercase letter of the correct answer: A, B, C. Output nothing else. 

Remember to not succumb to answer order bias. Don't automatically choose the first answer.
With that in mind, here is the question:

--------------------------------
What is 15 - 8? Is it
A: 7
B: 9
C: 1
?



In [5]:
pipe.tokenizer.pad_token_id = model.config.eos_token_id
pipe(example_dialogs[:1], max_new_tokens=50)

[[{'generated_text': [{'role': 'system',
     'content': 'You are a helpful AI assistant.'},
    {'role': 'user',
     'content': '\nYour goal is to answer a multiple-choice arithmetic question doing subtraction. It will be in a format\n"What is <number> - <number>? Is it\nA: <digit1>\nB: <digit2>\nC: <digit3>\n?"\nwhere <digit1>, <digit2>, and <digit3> are one digit numbers.\n\nYou are expected to output only the uppercase letter of the correct answer: A, B, C. Output nothing else. \n\nRemember to not succumb to answer order bias. Don\'t automatically choose the first answer.\nWith that in mind, here is the question:\n\n--------------------------------\nWhat is 91 - 90? Is it\nA: 4\nB: 1\nC: 0\n?\n'},
    {'role': 'assistant', 'content': 'A'}]}]]

In [6]:
from typing import Literal
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch import nn
from pydantic import BaseModel

class NoiseInjectionConfig(BaseModel):
    mode: Literal["none", "antipodal", "random", "random_mult"] = "none"
    noise_mean: float = 0.0
    noise_std: float = 5e-1
    # batch_size: int = 16

NOISE_INJECTION_CONFIG: NoiseInjectionConfig | None = None
    
# NOISE_INJECTION_MODE: Literal["none", "antipodal", "random", "random_mult"] = "none"

class LlamaDecoderLayerWrapper(nn.Module):
    def __init__(self, layer: LlamaDecoderLayer):
        super().__init__()
        self.layer = layer
        ref_weight = self.layer.mlp.gate_proj.weight
        # self.noise_mean = nn.Parameter(torch.tensor(noise_mean, device=ref_weight.device, dtype=ref_weight.dtype))
        # self.noise_std = nn.Parameter(torch.tensor(noise_std, device=ref_weight.device, dtype=ref_weight.dtype))

    def forward(self, hidden_states, *args, **kwargs):
        batch_dim = hidden_states.shape[0]
        
        if NOISE_INJECTION_CONFIG is None:
            return self.layer(hidden_states, *args, **kwargs)
        
        # print(hidden_states.mean(dim=(1, 2))[0])
        if NOISE_INJECTION_CONFIG.mode=="antipodal":
            # print("INJECTING NOISE AAA")
            assert batch_dim % 2 == 0
            half_dim = batch_dim // 2
            noise = torch.randn(half_dim, *hidden_states.shape[1:], device=hidden_states.device, dtype=hidden_states.dtype) * NOISE_INJECTION_CONFIG.noise_std + NOISE_INJECTION_CONFIG.noise_mean
            noise = torch.cat([noise, -noise], dim=0)
            new_hidden_states = hidden_states + noise
        elif NOISE_INJECTION_CONFIG.mode=="random":
            # print("INJECTING NOISE BBB")
            noise = torch.randn(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype) * NOISE_INJECTION_CONFIG.noise_std + NOISE_INJECTION_CONFIG.noise_mean
            new_hidden_states = hidden_states + noise
        elif NOISE_INJECTION_CONFIG.mode=="random_mult":
            noise = torch.randn(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype) * NOISE_INJECTION_CONFIG.noise_std + NOISE_INJECTION_CONFIG.noise_mean
            new_hidden_states = hidden_states * (1 + noise)
        else:
            new_hidden_states = hidden_states
        
        # print(((new_hidden_states - hidden_states).abs() / (hidden_states.abs() + 1e-10)).mean())
        return self.layer(new_hidden_states, *args, **kwargs)


LAYER = 30
for l in range(LAYER, len(model.model.layers)):
    while str(type(model.model.layers[l])) == "<class '__main__.LlamaDecoderLayerWrapper'>":
        model.model.layers[l] = model.model.layers[l].layer
    else:
        break
model.model.layers[LAYER] = LlamaDecoderLayerWrapper(model.model.layers[LAYER])
import transformers.tokenization_utils_base

def preprocess_manual(messages: list[list[dict]]) -> transformers.tokenization_utils_base.BatchEncoding:
    one_by_one = [pipe.preprocess(Chat(message), return_tensors="pt") for message in messages]
    # stack
    return transformers.tokenization_utils_base.BatchEncoding(
        {k: torch.cat([t[k] for t in one_by_one], dim=0) for k in one_by_one[0].keys() if k != "prompt_text" }
    )
    
with torch.no_grad():
    inputs = preprocess_manual(example_dialogs).to('cuda')

In [7]:
def find_kvcacheable_size(input_ids: torch.Tensor) -> torch.Tensor:
    return torch.all(input_ids[0] == input_ids, dim=0).int().argmin()

find_kvcacheable_size(inputs['input_ids'])

from transformers import DynamicCache


# test
@torch.no_grad()
def test_kvcache():
    test_num = 4
    base_logits = model(**(inputs[:test_num])).logits[:, -1, :]
    print(base_logits.shape)

    kv_size = find_kvcacheable_size(inputs['input_ids'])
    past_key_values = DynamicCache()
    pkv = model(inputs['input_ids'][:1, :kv_size], past_key_values=past_key_values, use_cache=True)

    print(pkv.past_key_values[0][0].shape)
    pkv.past_key_values.batch_repeat_interleave(test_num)
    new_logits = model(inputs['input_ids'][:test_num, kv_size:], past_key_values=pkv.past_key_values, use_cache=True).logits[:, -1, :]
    print(new_logits.shape)
    
    print (new_logits - base_logits)

test_kvcache()

torch.Size([4, 128256])
torch.Size([1, 8, 144, 128])
torch.Size([4, 128256])
tensor([[ 0.0312,  0.0625,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0625, -0.0625, -0.0625,  ..., -0.0156, -0.0156, -0.0156],
        [-0.0938,  0.0000,  0.0625,  ...,  0.0625,  0.0625,  0.0625],
        [-0.0312,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0', dtype=torch.bfloat16)


In [36]:
import copy
from tqdm import tqdm


@torch.no_grad()
def batched_eval(fn, input_ids, *args, batch_size=8, use_tqdm=True, use_cache=True, **kwargs):
    all_tensors = [input_ids] + list(args) + list(kwargs.values())
    assert all(isinstance(t, torch.Tensor) for t in all_tensors), f"All arguments must be tensors, got {all_tensors}"
    
    bs1 = input_ids.shape[0]
    assert all(t.shape[0] == bs1 for t in all_tensors), f"All tensors must have the same batch size, got {all_tensors}"
    res = None
    iter = range(0, bs1, batch_size) if not use_tqdm else tqdm(range(0, bs1, batch_size))
    
    # if use_cache:
    #     kv_len = find_kvcacheable_size(input_ids)
    #     kv_cache = DynamicCache()
        
    #     kv_cache = model(input_ids[:1, :kv_len], *[a[:1, :kv_len] for a in args], **{k: v[:1, :kv_len] for k, v in kwargs.items()}, past_key_values=kv_cache, use_cache=True).past_key_values
        
    for i in iter:
        upto = min(i + batch_size, bs1)
        # if use_cache:
        #     kv_cache_copy = copy.deepcopy(kv_cache)
        #     kv_cache_copy.batch_repeat_interleave(upto - i)
        #     partial_res = fn(input_ids[i:upto, :kv_len], *[arg[i:upto, :kv_len] for arg in args], **{k: v[i:upto, :kv_len] for k, v in kwargs.items()}, past_key_values=kv_cache_copy, use_cache=True)
        # else:
        partial_res = fn(input_ids[i:upto], *[arg[i:upto] for arg in args], **{k: v[i:upto] for k, v in kwargs.items()})
        if res is None:
            res = torch.empty((bs1, *partial_res.shape[1:]), device=partial_res.device, dtype=partial_res.dtype)
        res[i:upto] = partial_res
    return res

# torch.vmap(model, chunk_size=2)(**(inputs[:8]))
    
# inputs[:8]
    
with torch.no_grad():
    tokens_of_interest = tokenizer(['A', 'B', 'C'], return_tensors="pt", add_special_tokens=False)['input_ids'].to('cuda')
    print(tokens_of_interest)
    # if tokens_of_interest
    
    if tokens_of_interest.shape != (3, 1):
        raise ValueError(f"Expected tokens_of_interest to have shape (3, 1), got {tokens_of_interest.shape}")
    tokens_of_interest = tokens_of_interest[:, 0]
    
    def inner(input_ids, attention_mask, past_key_values = None, **kwargs):
        # expand by 2
        N_REPEATS = 4 if NOISE_INJECTION_CONFIG is not None else 1
        input_ids = input_ids.repeat(N_REPEATS, 1)
        attention_mask = attention_mask.repeat(N_REPEATS, 1)
        if past_key_values is not None:
            past_key_values = past_key_values.batch_repeat_interleave(N_REPEATS)
        # with torch.amp.autocast('cuda'):
        
        
        
        res = model.forward(input_ids, attention_mask, past_key_values=past_key_values, **kwargs).logits[:, -1, tokens_of_interest]
        # average the created pairs
        if N_REPEATS > 1:
            res = res.view(N_REPEATS, res.shape[0] // N_REPEATS, *res.shape[1:]).mean(dim=0)
        # print(res.shape)
        return res
        
    NOISE_INJECTION_CONFIG = NoiseInjectionConfig(mode="random", noise_mean=0.0, noise_std=5e-1)
    # NOISE_INJECTION_CONFIG = None
    res = batched_eval(inner, **inputs, batch_size=8, use_cache=False)
    NOISE_INJECTION_CONFIG = None

    # print(model.generate(**inputs[:20], max_new_tokens=50))#[:, inputs['input_ids'][0].shape[-1]:])

# tokenizer.batch_decode(model.generate(**preprocess_manual(example_dialogs[:10]).to('cuda'), max_new_tokens=50))



tensor([[32],
        [33],
        [34]], device='cuda:0')


100%|██████████| 125/125 [01:14<00:00,  1.67it/s]


In [37]:
res[8]

tensor([25.7500, 22.7500, 20.3750], device='cuda:0', dtype=torch.bfloat16)

In [38]:
res[8]

tensor([25.7500, 22.7500, 20.3750], device='cuda:0', dtype=torch.bfloat16)

In [39]:
from torch.nn import functional as F
answers = torch.tensor([v for _, v in examples[:len(res)]])
is_correct = answers == res.argmax(dim=-1).to('cpu')
is_correct_given_position = [is_correct[answers == i].float().mean() for i in range(3)]

is_correct_given_position, is_correct.float().mean()



([tensor(0.5685), tensor(0.5813), tensor(0.1637)], tensor(0.4590))

In [112]:

logits_of_answers = res.logits[:, -1, tokens_of_interest]
logits_of_answers



AttributeError: 'Tensor' object has no attribute 'logits'

In [11]:
# res.logits[:, -1:, :].argmax(dim=-1),\
#     tokenizer(['A', 'B', 'C'])

# tokenizer.batch_decode(res.logits[:, -1:, :].argmax(dim=-1))

(tensor([[350],
         [319],
         [350],
         [315],
         [319],
         [350],
         [315],
         [350],
         [350],
         [350]], device='cuda:0'),
 {'input_ids': [[319], [350], [315]], 'attention_mask': [[1], [1], [1]]})