In [27]:
import copy, json, random, re
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import pandas as pd
import matplotlib.pyplot as plt
from plotnine import ggplot, aes, geom_line, theme_minimal
from matplotlib.ticker import MaxNLocator
plt.rcParams.update({'font.size': 20, 'font.family': 'Sans'})

import torch
import transformers
from datasets import Dataset
from transformers import Trainer

from pyreft import (
    TaskType,
    get_reft_model,
    ReftConfig,
    ReftTrainerForCausalLM, 
    ReftDataCollator,
    ReftSupervisedDataset,
    make_last_position_supervised_data_module,
    make_multiple_position_supervised_data_module,
    ConsreftIntervention,
    LoreftIntervention,
    get_intervention_locations
)

IGNORE_INDEX = -100

device = "cuda" if torch.cuda.is_available() else "cpu"

def max_char_match_length(retrieved, golden):
    n_c, n = 0, 0
    for char in retrieved:
        if char == golden[n]:
            n_c += 1
        else:
            break
        n += 1 
    if len(retrieved) == 0:
        return 0.0
    return round(n_c/len(retrieved), 2)

make_supervised_data_module = make_last_position_supervised_data_module

prompt_no_input_template = """[INST] %s [/INST]"""

In [2]:
model_name_or_path = "yahma/llama-7b-hf"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device,
    
)

# get tokenizer
model_max_length = 2048
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=model_max_length, 
    padding_side="right", use_fast=False)
tokenizer.pad_token = tokenizer.unk_token

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [3]:
from datasets import load_dataset

ds = load_dataset("openai/gsm8k", "main")

In [69]:
training_examples = [[e["question"], e["answer"]] for e in ds["train"]][:20]

In [61]:
TARGET_LAYERS = [15]
positions = "f1+l1"

# get reft model
reft_config = ReftConfig(representations=[{
    "layer": L,
    "component": f"block_output",
    "intervention": LoreftIntervention(
    embed_dim=model.config.hidden_size,
    low_rank_dimension=4)} for L in TARGET_LAYERS])
reft_model = get_reft_model(model, reft_config)
reft_model.print_trainable_parameters()

trainable intervention params: 32,772 || trainable model params: 0
model params: 6,738,415,616 || trainable%: 0.00048634578018881287


In [70]:
data_module = make_multiple_position_supervised_data_module(
    tokenizer, model, 
    [prompt_no_input_template % e[0] for e in training_examples],
    [e[1] for e in training_examples], 
    positions=positions, num_interventions=len(reft_config.representations), share_weights=True, nonstop=False)

# train
training_args = transformers.TrainingArguments(
    num_train_epochs=50.0, output_dir="./tmp", learning_rate=4e-3, report_to=[], logging_steps=20, per_device_train_batch_size=5)
trainer = ReftTrainerForCausalLM(
    model=reft_model, tokenizer=tokenizer,
    args=training_args, **data_module)
_ = trainer.train()

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss
20,0.4342
40,0.1716
60,0.0794
80,0.0341
100,0.0155
120,0.0073
140,0.0045
160,0.0035
180,0.0031
200,0.0029


In [72]:
instruction = "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"

# tokenize and prepare the input
prompt = prompt_no_input_template % instruction
prompt = tokenizer(prompt, return_tensors="pt").to(device)

unit_locations = torch.IntTensor([get_intervention_locations(
    last_position=prompt["input_ids"].shape[-1], 
    first_n=1, 
    last_n=1,
    pad_mode="last",
    num_interventions=len(reft_config.representations),
    share_weights=True
)]).permute(1, 0, 2).tolist()

_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, unit_locations)},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=True, 
    eos_token_id=tokenizer.eos_token_id, early_stopping=True, repetition_penalty=1.3
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))

[INST] Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? [/INST]According to Johnny explanation, each egg can sell for 2$. So, Jenny makes 16*2 = ${<<equel; 32>>}$ a day.
#### 32


In [73]:
def extract_answer_number(sentence: str) -> float:
    """
    To ensure a fair comparison, we follow:
    https://github.com/AGI-Edgerunners/LLM-Adapters/blob/main/evaluate.py
    """
    sentence = sentence.replace(',', '')
    pred = [s for s in re.findall(r'-?\d+\.?\d*', sentence)]
    if not pred:
        return float('inf')
    pred_answer = float(pred[-1])
    if isinstance(pred_answer, str):
        try:
            pred_answer = float(pred_answer)
        except ValueError as e:
            pred_answer = float('inf')
    return pred_answer

c = 0
tc = 0
for example in ds["test"]:
    tc += 1
    if tc == 100:
        break
    # tokenize and prepare the input
    prompt = prompt_no_input_template % example["question"]
    prompt = tokenizer(prompt, return_tensors="pt").to(device)
    
    unit_locations = torch.IntTensor([get_intervention_locations(
        last_position=prompt["input_ids"].shape[-1], 
        first_n=1, 
        last_n=1,
        pad_mode="last",
        num_interventions=len(reft_config.representations),
        share_weights=True
    )]).permute(1, 0, 2).tolist()
    
    _, reft_response = reft_model.generate(
        prompt, unit_locations={"sources->base": (None, unit_locations)},
        intervene_on_prompt=True, max_new_tokens=256, do_sample=False, 
        eos_token_id=tokenizer.eos_token_id, early_stopping=True, repetition_penalty=1.3)
    
    answer = tokenizer.decode(reft_response[0], skip_special_tokens=True)
    pred = extract_answer_number(answer)
    actual = extract_answer_number(example["answer"])
    print(pred, actual, answer)
    if pred == actual:
        c += 1

100.0 18.0 [INST] Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? [/INST]She makes 3*3=9 for each days that she sell a "per" item of her product.
So she makes 9 x 16 = $<<<120>>> every day that she goes to the farm markets.
#### 1) '16' means sixteen
#### 2) Each day that she goes to the farm markets, she makes 9*16 = <<120>>
#### 3) There are 365 days and some years have more than 40 weekends and some less
#### 4) This is not an hourly job
#### 5) She gets up every day at 10_2=10_2=10::=10::=10 lines=10 seconds=100 seconds=ten minutes=1000 seconds=100 hours=1000 hours=10 times ten=100 sets=1000 set=10 periods=1000 periods=10 titles=10 notes=100 notes=1000 notes=10 titles=10 notes=100 notes=100
9807.0 3.0 [INST] A robe takes 2 bolts of blue fiber and half t

KeyboardInterrupt: 

In [58]:
c

0