### ReFT is complimentary to existing LoRAs (or other PEFTs)

You can wrap a PEFT model as a ReFT model and take advantages of both worlds, as light-weight LoRA has no inference overhead while might provide performance gains. Note that LoRA is coupled with model weights, yet it can also been seens as an intervention which edits the original representation.

It's very easy to combine PEFTs with ReFT by using existing library such as `peft`:

```py
import pyreft
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(...)
model = get_peft_model(model, peft_config)

reft_config = pyreft.ReftConfig(...)
reft_model = pyreft.get_reft_model(model, reft_config)
```

In [1]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import peft

except ModuleNotFoundError:
    !pip install peft



### Loading our LM

In [2]:
import torch, transformers

import pyreft
from peft import LoraConfig, get_peft_model

model_name_or_path = "unsloth/llama-3-8b-bnb-4bit"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map="cuda")

# get tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path, model_max_length=2048, adding_side="right", use_fast=False)
tokenizer.pad_token = tokenizer.unk_token

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


## Prepare ReCOGS dataset

In [None]:
!pip install wget pandas

In [None]:

import os
import wget

SRC_DIRNAME = os.path.join("data", "recogs")
if not os.path.exists(SRC_DIRNAME):
    os.makedirs('data', exist_ok=True)
    wget.download('https://web.stanford.edu/class/cs224u/data/recogs.tgz', out='data/')
    !tar xvf data/recogs.tgz -C data/

In [None]:

import os
import wget

SRC_DIRNAME = os.path.join("data", "recogs")
if not os.path.exists(SRC_DIRNAME):
    os.makedirs('data', exist_ok=True)
    wget.download('https://web.stanford.edu/class/cs224u/data/recogs.tgz', out='data/')
    !tar xvf data/recogs.tgz -C data/

In [None]:
import pandas as pd
def load_split(filename):
    return pd.read_csv(
        filename,
        delimiter="\t",
        names=['input', 'output', 'category'])

In [None]:
dataset = {}

for splitname in ("train", "dev", "gen"):
    dataset[splitname] = load_split(f"{SRC_DIRNAME}/{splitname}.tsv")

# Data prep
from datasets import Dataset
trainset = dataset['train']
trainset = Dataset.from_pandas(trainset) # convert from pandas to pyarrow format
trainset = trainset.remove_columns('category') # remove the category column

In [None]:
dataset['train'].head(2)

### That's it! Wrap the LM as a `peft` LM, then as a `pyreft` LM

In [3]:
include_peft = True
layers_to_transform = [15]

if include_peft:
    peft_config = LoraConfig(
        r=4,
        lora_alpha=32,
        target_modules=["o_proj"],
        layers_to_transform=layers_to_transform,
        use_rslora=True,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = get_peft_model(model, peft_config)

reft_config = pyreft.ReftConfig(representations=[{
    "layer": l, "component": f"base_model.model.model.layers[{l}].output" if include_peft else "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(embed_dim=model.config.hidden_size,
    low_rank_dimension=4)} for l in layers_to_transform])

reft_model = pyreft.get_reft_model(model, reft_config)
# you need to call this to re-enable lora grads!
reft_model.model.enable_adapter_layers()
reft_model.print_trainable_parameters()

trainable intervention params: 32,772 || trainable model params: 32,768
model params: 6,738,448,384 || trainable%: 0.0009726274694871952


# Prompts

In [4]:
prompt_no_input_template = """<s>[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>

%s [/INST]
"""

data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer, model, [prompt_no_input_template % e[0] for e in dataset['train']], 
    [e[1] for e in dataset['train']])

In [5]:
# train
training_args = transformers.TrainingArguments(
    num_train_epochs=100.0, output_dir="./tmp", 
    per_device_train_batch_size=10, 
    learning_rate=4e-3, 
    logging_steps=20, 
    report_to=[])
trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model, 
    tokenizer=tokenizer, 
    args=training_args, 
    **data_module)
_ = trainer.train()

# ensure everything is in eval mode
reft_model.model.eval()
for k,v in reft_model.interventions.items():
    _ = v[0].eval()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Step,Training Loss
20,0.6923
40,0.0663
60,0.037
80,0.0163
100,0.0157


**Note**: `loss` looks a bit different if you compare these with the ones in the original ReFT-only training.

In [15]:
instruction = "Which dog breed do people think is cuter, poodle or doodle?"

# tokenize and prepare the input
prompt = prompt_no_input_template % instruction
prompt = tokenizer(prompt, return_tensors="pt").to("cuda")

base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=True, 
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))

[INST] <<SYS>>
You are a helpful assistant.
<</SYS>>

Which dog breed do people think is cuter, poodle or doodle? [/INST]
🐶💬👀🌟
