In [66]:
import json,re 

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from repeng import ControlVector, ControlModel, DatasetEntry

alpaca_prompt_no_input_template = """Below is an instruction that \
describes a task. Write a response that appropriately \
completes the request.

### Instruction:
%s

### Response:
"""

user_tag, asst_tag = "[INST]", "[/INST]"

In [3]:
model_name = "yahma/llama-7b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = 0

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model = model.to("cuda:0" if torch.cuda.is_available() else "mps:0" if torch.backends.mps.is_available() else "cpu")
model = ControlModel(model, list(range(-5, -18, -1)))

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
with open("./data/all_truncated_outputs.json") as f:
    output_suffixes = json.load(f)
truncated_output_suffixes = [
    tokenizer.convert_tokens_to_string(tokens[:i])
    for tokens in (tokenizer.tokenize(s) for s in output_suffixes)
    for i in range(1, len(tokens))
]
truncated_output_suffixes_512 = [
    tokenizer.convert_tokens_to_string(tokens[:i])
    for tokens in (tokenizer.tokenize(s) for s in output_suffixes[:512])
    for i in range(1, len(tokens))
]

with open("./data/true_facts.json") as f:
    fact_suffixes = json.load(f)
truncated_fact_suffixes = [
    tokenizer.convert_tokens_to_string(tokens[:i])
    for tokens in (tokenizer.tokenize(s) for s in fact_suffixes)
    for i in range(1, len(tokens) - 5)
]

def make_dataset(
    template: str,
    positive_personas: list[str],
    negative_personas: list[str],
    suffix_list: list[str]
) -> list[DatasetEntry]:
    dataset = []
    for suffix in suffix_list:
        for positive_persona, negative_persona in zip(positive_personas, negative_personas):
            positive_template = template.format(persona=positive_persona)
            negative_template = template.format(persona=negative_persona)
            dataset.append(
                DatasetEntry(
                    positive=f"{user_tag} {positive_template} {asst_tag} {suffix}",
                    negative=f"{user_tag} {negative_template} {asst_tag} {suffix}",
                )
            )
    return dataset

In [45]:
from datasets import load_dataset

ds = load_dataset("openai/gsm8k", "main")

Downloading readme:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

Downloading data: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.31M/2.31M [00:00<00:00, 14.0MB/s]
Downloading data: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 419k/419k [00:00<00:00, 3.02MB/s]


Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

In [48]:
dataset = make_dataset(
    "Solve the math problem step-by-step. Here is one example response: {persona}",
    [ds["train"][i]["answer"].split("####")[0].strip() for i in range(5)],
    ["Sorry, I am really bad at math." for i in range(5)],
    truncated_output_suffixes,
)

In [49]:
# train the vector—takes less than a minute!
math_vector = ControlVector.train(model, tokenizer, dataset)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 397/397 [04:36<00:00,  1.43it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 31/31 [00:34<00:00,  1.13s/it]


In [50]:
ds["test"][0]

{'question': "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
 'answer': 'Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.\n#### 18'}

In [62]:
def generate_with_vector(
    input: str,
    vector: ControlVector,
    coeffs: tuple[float, float],
    max_new_tokens: int = 128,
    repetition_penalty: float = 1.1,
    show_baseline: bool = True,
):
    positive_coeff, negative_coeff = coeffs
    assert positive_coeff > 0
    assert negative_coeff < 0

    if user_tag not in input:
        input = f"{user_tag} {input.strip()} {asst_tag}"
    input_ids = tokenizer(input, return_tensors="pt").to(model.device)
    settings = {
        "pad_token_id": tokenizer.eos_token_id, # silence warning
        "do_sample": False, # temperature=0
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    
    print("\n++control ---------------------------------------------------")
    model.set_control(vector, positive_coeff)

    model.reset()

    return tokenizer.decode(model.generate(**input_ids, **settings).squeeze()).strip()

In [63]:
ds["test"][10]["answer"]

'The number of downloads of the program in the second month increased to 3*60 = <<3*60=180>>180\nIn the first two months, the total number of downloads of the program was 180+60 = <<180+60=240>>240\nIn the third month, the number of downloads of the program reduced by 30/100*180 = <<30/100*180=54>>54\nThere were 180-54 = <<180-54=126>>126 downloads in the third month.\nIn the three months, the total number of downloads of the program was 126+240 = <<126+240=366>>366\n#### 366'

In [71]:
def extract_answer_number(sentence: str) -> float:
    """
    To ensure a fair comparison, we follow:
    https://github.com/AGI-Edgerunners/LLM-Adapters/blob/main/evaluate.py
    """
    sentence = sentence.replace(',', '')
    pred = [s for s in re.findall(r'-?\d+\.?\d*', sentence)]
    if not pred:
        return float('inf')
    pred_answer = float(pred[-1])
    if isinstance(pred_answer, str):
        try:
            pred_answer = float(pred_answer)
        except ValueError as e:
            pred_answer = float('inf')
    return pred_answer

c = 0
tc = 0
for example in ds["test"]:
    tc += 1
    if tc == 100:
        break
    answer = generate_with_vector(
        example["question"],
        math_vector,
        (2.2, -2.2),
        max_new_tokens=256,
        repetition_penalty=1.3,
    )
    pred = extract_answer_number(answer)
    actual = extract_answer_number(example["answer"])
    print(pred, actual)
    if pred == actual:
        c += 1


++control ---------------------------------------------------
2.0 18.0

++control ---------------------------------------------------
2.0 3.0

++control ---------------------------------------------------
6.0 70000.0

++control ---------------------------------------------------
7.0 540.0

++control ---------------------------------------------------
69.0 20.0

++control ---------------------------------------------------
16.0 64.0

++control ---------------------------------------------------
6.0 260.0

++control ---------------------------------------------------
1.0 160.0

++control ---------------------------------------------------
4.0 45.0

++control ---------------------------------------------------
368.79 460.0

++control ---------------------------------------------------
100.0 366.0

++control ---------------------------------------------------
55.0 694.0

++control ---------------------------------------------------
-8.0 13.0

++control ------------------------------------

In [72]:
c

3