### ReFT is complementary to existing LoRAs (or other PEFTs)

You can wrap a PEFT model as a ReFT model and get the best of both worlds: lightweight LoRA adds no inference overhead once its weights are merged, and it may provide additional performance gains. Note that LoRA is coupled to the model weights, yet it can also be seen as an intervention that edits the original representation.
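
To make the contrast concrete, here are the two updates side by side, written as in the LoRA and LoReFT papers (the symbols are introduced here for illustration and are not used elsewhere in this notebook). LoRA reparameterizes a frozen weight matrix $W_0$, so its update can be merged into the weights after training. LoReFT leaves the weights alone and edits a hidden representation $h$ at selected layers and positions, where $R$ has orthonormal rows:

$$
\begin{aligned}
\text{LoRA:} \quad & W' = W_0 + \frac{\alpha}{r} B A, \qquad B \in \mathbb{R}^{d \times r},\; A \in \mathbb{R}^{r \times k} \\
\text{LoReFT:} \quad & \Phi(h) = h + R^{\top}\left(W h + b - R h\right), \qquad R, W \in \mathbb{R}^{r \times d},\; b \in \mathbb{R}^{r}
\end{aligned}
$$

With `use_rslora=True`, `peft` scales the LoRA update by $\alpha/\sqrt{r}$ instead of $\alpha/r$.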

Combining PEFTs with ReFT is straightforward with an existing library such as `peft`:

```py
import pyreft
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(...)
model = get_peft_model(model, peft_config)

reft_config = pyreft.ReftConfig(...)
reft_model = pyreft.get_reft_model(model, reft_config)
```

In [1]:
try:
    import peft

except ModuleNotFoundError:
    !pip install peft

In [2]:
try:
    import pyreft

except ModuleNotFoundError:
    !pip install pyreft

### Loading our LM

In [3]:
import torch, transformers

from peft import LoraConfig, get_peft_model

model_name_or_path = "meta-llama/Meta-Llama-3-8B-Instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map="cuda")

# get tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=512,
    use_fast=False)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Prepare the ReCOGS dataset

In [4]:
!pip install wget pandas



In [5]:
!wget https://raw.githubusercontent.com/cgpotts/cs224u/main/compgen.py

--2024-05-31 14:25:46--  https://raw.githubusercontent.com/cgpotts/cs224u/main/compgen.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3669 (3.6K) [text/plain]
Saving to: ‘compgen.py.2’


2024-05-31 14:25:47 (69.3 MB/s) - ‘compgen.py.2’ saved [3669/3669]



In [5]:
import os
import wget

SRC_DIRNAME = os.path.join("data", "recogs")
if not os.path.exists(SRC_DIRNAME):
    os.makedirs('data', exist_ok=True)
    wget.download('https://web.stanford.edu/class/cs224u/data/recogs.tgz', out='data/')
    !tar xvf data/recogs.tgz -C data/

In [6]:
import pandas as pd
def load_split(filename):
    return pd.read_csv(
        filename,
        delimiter="\t",
        names=['input', 'output', 'category'])

In [7]:
dataset = {}

for splitname in ("train", "dev", "gen"):
    dataset[splitname] = load_split(f"{SRC_DIRNAME}/{splitname}.tsv")

# Data prep
from datasets import Dataset
trainset = dataset['train']
trainset = Dataset.from_pandas(trainset) # convert from pandas to pyarrow format
trainset = trainset.remove_columns('category') # remove the category column

In [9]:
dataset['train'].head(2)

| | input | output | category |
|---|---|---|---|
| 0 | A rose was helped by a dog . | rose ( 53 ) ; dog ( 38 ) ; help ( 7 ) AND them... | in_distribution |
| 1 | The sailor dusted a boy . | * sailor ( 48 ) ; boy ( 53 ) ; dust ( 10 ) AND... | in_distribution |


### That's it! Wrap the LM as a `peft` LM, then as a `pyreft` LM

In [8]:
include_peft = True
layers_to_transform = [15]

if include_peft:
    peft_config = LoraConfig(
        r=4,
        lora_alpha=32,
        target_modules=["o_proj"],
        layers_to_transform=layers_to_transform,
        use_rslora=True,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, peft_config)

reft_config = pyreft.ReftConfig(representations=[{
    "layer": l, "component": f"base_model.model.model.layers[{l}].output" if include_peft else "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(embed_dim=model.config.hidden_size,
    low_rank_dimension=4)} for l in layers_to_transform])

reft_model = pyreft.get_reft_model(model, reft_config)
# you need to call this to re-enable lora grads!
reft_model.model.enable_adapter_layers()
reft_model.print_trainable_parameters()

trainable intervention params: 32,772 || trainable model params: 32,768
model params: 8,030,294,016 || trainable%: 0.0008161594067342304
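
The counts printed above can be reproduced by hand. The sketch below is a back-of-the-envelope check, not pyreft's own accounting: it assumes `o_proj` in Llama-3-8B is a 4096 x 4096 projection, and that a LoReFT intervention holds a bias-free rotation of shape `r x d` plus a linear map of the same shape with a bias of size `r`. The variable names are illustrative.

```python
# Back-of-the-envelope check of the trainable-parameter counts printed above.
hidden_size = 4096                  # Llama-3-8B hidden size
rank = 4                            # LoRA r and LoReFT low_rank_dimension
num_layers = 1                      # layers_to_transform = [15]
total_model_params = 8_030_294_016  # "model params" from the printed output

# LoRA on o_proj (assumed 4096 x 4096): A is (r x in), B is (out x r).
lora_params = num_layers * (rank * hidden_size + hidden_size * rank)

# LoReFT (assumed layout): rotation R without bias, plus linear W with bias b.
loreft_params = num_layers * (rank * hidden_size + rank * hidden_size + rank)

# Percentage of all model parameters that receive gradients.
trainable_pct = 100 * (lora_params + loreft_params) / total_model_params

print(f"LoRA params:   {lora_params:,}")
print(f"LoReFT params: {loreft_params:,}")
print(f"trainable%:    {trainable_pct}")
```

Under these assumptions the two counts come out to 32,768 and 32,772, matching the output, which suggests the reported percentage covers the LoRA and ReFT parameters together.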


### Prompts

In [10]:
prompt_no_input_template = """<s>[INST] <<SYS>>
You are a helpful assistant. Translate english sentences to ReCOGS logical form.
<</SYS>>

%s [/INST]
"""

In [11]:
data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer,
    model,
    [prompt_no_input_template % i for i in dataset['train']['input'][:20000]],
    [o for o in dataset['train']['output'][:20000]])

In [12]:
tokenizer.pad_token = tokenizer.eos_token

In [13]:
# train
training_args = transformers.TrainingArguments(
    num_train_epochs=10.0,
    output_dir="./tmp",
    per_device_train_batch_size=10,
    learning_rate=4e-3,
    logging_steps=50,
    report_to=[])
trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model,
    tokenizer=tokenizer,
    args=training_args,
    **data_module)
_ = trainer.train()

# ensure everything is in eval mode
reft_model.model.eval()
for k,v in reft_model.interventions.items():
    _ = v[0].eval()

| Step | Training Loss |
|---|---|
| 50 | 0.9865 |
| 100 | 0.5642 |
| 150 | 0.5523 |
| 200 | 0.5148 |
| 250 | 0.4977 |
| 300 | 0.4972 |
| 350 | 0.4989 |
| 400 | 0.4894 |
| 450 | 0.4923 |
| 500 | 0.4731 |
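
For orientation, the length of this run follows from the training arguments. The sketch below is plain arithmetic, assuming a single GPU, the `transformers` defaults `gradient_accumulation_steps=1` and `save_steps=500`, and a training split with at least 20,000 rows. The variable names are illustrative.

```python
import math

num_examples = 20_000             # the [:20000] slice passed to the data module
per_device_train_batch_size = 10
num_train_epochs = 10
logging_steps = 50

num_devices = 1                   # assumption: a single GPU
gradient_accumulation_steps = 1   # assumption: transformers default
save_steps = 500                  # assumption: transformers default

effective_batch = per_device_train_batch_size * num_devices * gradient_accumulation_steps
steps_per_epoch = math.ceil(num_examples / effective_batch)
total_steps = steps_per_epoch * num_train_epochs

num_log_rows = total_steps // logging_steps   # rows in the loss table
num_checkpoints = total_steps // save_steps   # checkpoint directories written

print(steps_per_epoch, total_steps, num_log_rows, num_checkpoints)
```

The loss table shown here therefore covers only the first 500 of the optimizer steps in the full run.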



**Note**: the training loss values here differ somewhat from those in the original ReFT-only training run.

In [14]:
from transformers import logging
logging.set_verbosity_error()

def predict(instruction):
    # Define the target sequence
    target_sequence = "[/INST]\n"

    # tokenize and prepare the input
    prompt = prompt_no_input_template % instruction
    prompt = tokenizer(prompt, return_tensors="pt").to("cuda")

    base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
    _, reft_response = reft_model.generate(
        prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
        intervene_on_prompt=True, max_new_tokens=512, do_sample=True,
        eos_token_id=tokenizer.eos_token_id, early_stopping=True, num_beams=5)

    input_string = tokenizer.decode(reft_response[0], skip_special_tokens=True)

    # Find the index of the target sequence in the input string
    start_index = input_string.find(target_sequence)

    # If the target sequence is not found, return an empty string
    if start_index == -1:
        return ""

    # Calculate the start index of the substring after the target sequence
    start_index += len(target_sequence)

    # Return the substring from the calculated start index to the end of the string
    return input_string[start_index:]
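
The post-processing at the end of `predict` can be tried on its own, without a GPU or a model. The sketch below repeats the prompt template and the slicing logic. `extract_completion` is a hypothetical helper name, and the "generated" text is a stand-in string rather than real model output.

```python
PROMPT_TEMPLATE = """<s>[INST] <<SYS>>
You are a helpful assistant. Translate english sentences to ReCOGS logical form.
<</SYS>>

%s [/INST]
"""

def extract_completion(decoded: str, marker: str = "[/INST]\n") -> str:
    """Return the text after the first occurrence of `marker`, or "" if absent."""
    start = decoded.find(marker)
    if start == -1:
        return ""
    return decoded[start + len(marker):]

# The decoded generation contains the prompt followed by the completion.
prompt = PROMPT_TEMPLATE % "The sailor dusted a boy ."
stand_in_output = prompt + "LOGICAL FORM GOES HERE"  # placeholder, not model output

print(repr(extract_completion(stand_in_output)))
print(repr(extract_completion("text without the marker")))
```

Because the template ends with `[/INST]` followed by a newline, everything after that marker in the decoded string is the model's completion.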

In [15]:
from compgen import recogs_exact_match

In [22]:
dataset['dev']['output'][0]

'Liam ( 30 ) ; box ( 25 ) ; girl ( 21 ) ; hope ( 33 ) AND agent ( 33 , 30 ) AND ccomp ( 33 , 24 ) AND burn ( 24 ) AND theme ( 24 , 25 ) AND agent ( 24 , 21 )'

In [23]:
query = dataset['dev'].sample(1).iloc[0]
expected = query['output']
sentence = query['input']  # avoid shadowing the built-in `input`

predicted = predict(sentence)
print(f"Query: {sentence}")
print(f"Expected: {expected}")
print(f"Predicted: {predicted}")
# gold first, prediction second, as in the batch evaluation cell
print(f"Correct: {recogs_exact_match(expected, predicted)}")


Query: The girl said that a lollipop froze .
Expected: * girl ( 22 ) ; lollipop ( 48 ) ; say ( 6 ) AND agent ( 6 , 22 ) AND ccomp ( 6 , 52 ) AND freeze ( 52 ) AND theme ( 52 , 48 )
Predicted: * girl ( 45 ) ; lollipop ( 22 ) ; say ( 46 ) AND agent ( 46, 45 ) AND ccomp ( 46, 27 ) AND freeze ( 27 ) AND theme ( 27, 22 )
Correct: True
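
The prediction above uses different variable numbers than the expected form yet still counts as correct, because the numbers only need to be used consistently. The sketch below illustrates that idea with a deliberately simplified check. It is not the implementation in `compgen.py`, and `canonicalize` and `match_up_to_renaming` are hypothetical names. It renames variables by order of first appearance and ignores whitespace, and it does not handle reordered conjuncts.

```python
import re

def canonicalize(lf: str) -> str:
    """Rename numeric variables by order of first appearance and drop whitespace."""
    mapping = {}

    def rename(match):
        var = match.group(0)
        if var not in mapping:
            mapping[var] = str(len(mapping))
        return mapping[var]

    renamed = re.sub(r"\d+", rename, lf)
    # Remove whitespace so "( 46, 45 )" and "( 46 , 45 )" compare equal.
    return re.sub(r"\s+", "", renamed)

def match_up_to_renaming(gold: str, pred: str) -> bool:
    return canonicalize(gold) == canonicalize(pred)

expected = ("* girl ( 22 ) ; lollipop ( 48 ) ; say ( 6 ) AND agent ( 6 , 22 ) "
            "AND ccomp ( 6 , 52 ) AND freeze ( 52 ) AND theme ( 52 , 48 )")
predicted = ("* girl ( 45 ) ; lollipop ( 22 ) ; say ( 46 ) AND agent ( 46, 45 ) "
             "AND ccomp ( 46, 27 ) AND freeze ( 27 ) AND theme ( 27, 22 )")

print(match_up_to_renaming(expected, predicted))

# Changing which variable fills a role breaks the match.
wrong_role = expected.replace("agent ( 6 , 22 )", "agent ( 6 , 48 )")
print(match_up_to_renaming(expected, wrong_role))
```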


In [16]:
ssamp = dataset['dev'].sample(200)

In [17]:
ssamp['prediction'] = ssamp.input.apply(lambda x: predict(x))

In [18]:
ssamp['correct'] = ssamp.apply(lambda row: recogs_exact_match(row['output'], row['prediction']), axis=1)

In [19]:
ssamp['correct'].sum() / ssamp.shape[0]

0.985

In [20]:
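# Pushes the PEFT (LoRA) adapter weights to the HF Hub; the ReFT interventions are not part of this upload.
# Replace the empty string with your own Hugging Face access token.
model.push_to_hub("ricardo-larosa/recogs-Meta-Llama-3-8B-Instruct", token="")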

adapter_model.safetensors:   0%|          | 0.00/65.8k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/ricardo-larosa/recogs-Meta-Llama-3-8B-Instruct/commit/af36e473f8d9fc9b9834be73aa5496081426404b', commit_message='Upload model', commit_description='', oid='af36e473f8d9fc9b9834be73aa5496081426404b', pr_url=None, pr_revision=None, pr_num=None)