In [1]:
from typing import List, Optional
import os
# Process-wide settings, placed in the first cell so they are already in the
# environment when the next cell imports torch / transformers:
#   CUDA_VISIBLE_DEVICES - restricts the process to the GPU with index 1
#   HF_HOME              - root directory of the Hugging Face cache
os.environ.update(
    CUDA_VISIBLE_DEVICES="1",
    HF_HOME="/export/a02/huggingface",
)


In [2]:
import re
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, BitsAndBytesConfig
import os
import torch
from tqdm.notebook import tqdm
import numpy as np


In [3]:
def serialize(tokenizer, chunks: List[List[str]], weights: List[Optional[List[float]]], log_marginals_dtype=torch.float, device="cuda"):
    """Lay out chunks of weighted text alternatives as one flat token sequence.

    Every chunk is a list of alternative strings. All alternatives of all
    chunks are tokenized and written one after another, behind a single BOS
    token, into one 1-D ``input_ids`` tensor. Two square matrices describe how
    positions relate to each other:

    * ``attention_mask`` starts as a lower-triangular matrix of ones; the rows
      of an alternative are zeroed on the columns of the *earlier*
      alternatives of the same chunk.
    * ``log_marginals`` holds ``finfo(log_marginals_dtype).min`` strictly above
      the diagonal and on the entries zeroed in ``attention_mask``. For every
      row that comes after an alternative, the columns of that alternative
      hold the alternative's log-probability (``log_softmax`` of the chunk's
      weights). All other entries are 0.

    Args:
        tokenizer: Callable tokenizer exposing ``add_bos_token`` and
            ``bos_token_id``. It is called with ``return_tensors="pt"`` and
            ``padding=True`` and must return ``input_ids`` / ``attention_mask``.
            It must not add BOS itself.
        chunks: One list of alternative strings per chunk.
        weights: One entry per chunk: unnormalised log-weights (one per
            alternative), or ``None`` for uniform weights.
        log_marginals_dtype: Floating dtype of the returned ``log_marginals``.
        device: Device the tokenized batch is moved to; the returned tensors
            are created on the same device.

    Returns:
        ``(input_ids, attention_mask, log_marginals)`` with shapes ``(T + 1,)``,
        ``(T + 1, T + 1)`` and ``(T + 1, T + 1)``, where ``T`` is the total
        number of non-padding tokens over all alternatives.

    Raises:
        ValueError: If the tokenizer adds BOS on its own, if ``chunks`` and
            ``weights`` differ in length, or if a chunk's weight list does not
            have one weight per alternative.
    """
    if tokenizer.add_bos_token:
        raise ValueError("serialize() writes the BOS token itself; "
                         "the tokenizer must be created with add_bos_token=False.")
    if len(chunks) != len(weights):
        # zip() below would silently drop the surplus entries and leave the
        # tail of input_ids zero-filled.
        raise ValueError(f"chunks and weights must have the same length, got "
                         f"{len(chunks)} and {len(weights)}.")

    flat_alternatives = [alt for alts in chunks for alt in alts]
    tokens = tokenizer(flat_alternatives, return_tensors="pt", padding=True).to(device)
    ser_input_ids = tokens["input_ids"]
    ser_attention_mask = tokens["attention_mask"]

    # +1 for the BOS token stored at position 0.
    total = int(ser_attention_mask.sum()) + 1
    lowest = torch.finfo(log_marginals_dtype).min

    input_ids = ser_input_ids.new_zeros(total)
    log_marginals = ser_input_ids.new_zeros((total, total), dtype=log_marginals_dtype).fill_(lowest).triu(diagonal=1)
    attention_mask = torch.tril(ser_input_ids.new_ones((total, total)))
    input_ids[0] = tokenizer.bos_token_id

    serial_index = 0   # row of the current alternative in the tokenized batch
    chunk_offset = 1   # first position of the current chunk
    alt_offset = 1     # first position of the current alternative
    for chunk_index, (alts, ws) in enumerate(zip(chunks, weights)):
        if ws is None:
            # Equal logits -> uniform distribution over the alternatives.
            ws = [0.0] * len(alts)
        if len(alts) != len(ws):
            raise ValueError(f"chunks length and weight length must match. chunk"
                             f" {chunk_index} erred.\n{chunks}\n{weights}")
        if not alts:
            continue
        for w in log_marginals.new_tensor(ws).log_softmax(-1):
            # Pick the real tokens through the mask so that both left- and
            # right-padding tokenizers are handled.
            keep = ser_attention_mask[serial_index].bool()
            alt_ids = ser_input_ids[serial_index][keep]
            alt_end = alt_offset + alt_ids.numel()

            input_ids[alt_offset:alt_end] = alt_ids
            # Hide the earlier alternatives of this chunk from this one.
            attention_mask[alt_offset:alt_end, chunk_offset:alt_offset] = 0
            log_marginals[alt_offset:alt_end, chunk_offset:alt_offset] = lowest
            # Every later position sees this alternative with log-probability w
            # (later siblings overwrite their own rows with `lowest` above).
            log_marginals[alt_end:, alt_offset:alt_end] = w

            alt_offset = alt_end
            serial_index += 1
        chunk_offset = alt_offset
    return input_ids, attention_mask, log_marginals



In [4]:
# Single source of truth for the checkpoint so the model and its tokenizer
# cannot drift apart.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# Weights are loaded in float16; device placement is delegated to device_map="auto".
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map="auto")
# add_bos_token=False is required by serialize(), which raises otherwise and
# writes the BOS token itself. "<unk>" is used as the padding token for the
# padded batch that serialize() requests.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, add_bos_token=False, pad_token="<unk>")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
# Text placed between the <<SYS>> markers of the chat prompt.
system_prompt = (
    "You are a helpful assistant that adheres to "
    "the format required by the user."
)
# Opening line of the user turn; the uncertain statements and the question
# are appended after it when the prompt is assembled.
user_prompt = (
    "Please answer the following question about "
    "an imagined uncertain situation:"
)

In [22]:
# Number of independent sampling runs.
N = 10

# Two mutually exclusive statements with equal (unnormalised log-) weights.
options = ["Alice is 15 years old or", "Alice is 75 years old.", ]
weights = [0.0, 0.0]
question = "How old is Alice? Please answer in a single number."
answer_prefix = "Given the provided information, one possibility of Alice's age is"

# Three chunks: fixed prompt head, the weighted alternatives, fixed prompt tail.
# NOTE(review): serialize() already stores tokenizer.bos_token_id at position 0
# and the head chunk also starts with a literal "<s>" - confirm this does not
# produce two BOS tokens.
input_ids, attention_mask, log_marginals = serialize(
    tokenizer,
    [
        [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n{user_prompt}\n\n"],
        options,
        [f"\n\n{question}[/INST] {answer_prefix}"],
    ],
    [None, weights, None],
    log_marginals_dtype=torch.float16,
)
print(input_ids.size())
prompt_length = input_ids.size(-1)

# Built once: the configuration does not change between runs. Only
# max_new_tokens is set; combining it with max_length is redundant because
# max_new_tokens takes precedence.
generation_config = GenerationConfig(
    max_new_tokens=128,
    temperature=1.0,
    do_sample=True,
    num_return_sequences=1,
    use_cache=True,
)

results = []
for i in tqdm(range(N)):
    print(f"****** Run {i} ******")
    # attention_mode / log_marginals are not standard generate() arguments;
    # presumably they are consumed by a patched model - verify against the
    # installed transformers build. The attention_mask returned by
    # serialize() is not passed here.
    outputs = model.generate(input_ids=input_ids[None, ...], generation_config=generation_config, attention_mode="mult", log_marginals=log_marginals[None, ...])
    generated = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    print(generated)
    # Only the continuation after the prompt is searched for the answer.
    new = tokenizer.batch_decode(outputs[:, prompt_length:], skip_special_tokens=True)[0]
    # First run of digits in the continuation; runs without any number are
    # skipped explicitly instead of hiding every exception.
    number = re.search(r"\d+", new)
    if number is not None:
        results.append(int(number.group()))
    print("****** ******")
    print()
print(sorted(results))
if results:
    print(np.mean(results))
    print(np.median(results))
    print(np.std(results))
else:
    print("No numeric answer could be parsed from any run.")

torch.Size([95])


  0%|          | 0/10 [00:00<?, ?it/s]

****** Run 0 ******
[INST] <<SYS>>
You are a helpful assistant that adheres to the format required by the user.
<</SYS>>
Please answer the following question about an imagined uncertain situation:

 Alice is 15 years old or Alice is 75 years old. 

How old is Alice? Please answer in a single number.[/INST] Given the provided information, one possibility of Alice's age is 15, while another possibility is 75. As I cannot determine the correct age with absolute certainty, I will provide a range of possible ages for Alice:

Age of Alice: 15-75 years old.
****** ******

****** Run 1 ******
[INST] <<SYS>>
You are a helpful assistant that adheres to the format required by the user.
<</SYS>>
Please answer the following question about an imagined uncertain situation:

 Alice is 15 years old or Alice is 75 years old. 

How old is Alice? Please answer in a single number.[/INST] Given the provided information, one possibility of Alice's age is 75 years old.
****** ******

****** Run 2 ******
[INST