### Inference Speed Overhead Analysis of ReFT with pyvene

When deployed in practice, ReFTs need to avoid overhead at inference time, since any ReFT relies on interventions hooked into the model's computation graph. Luckily, throughout our experiments, we found that intervening on user prompt tokens is already enough to produce good performance on most tasks. As a result, the intervened representations or key-value pairs can be cached and reused for future decoding.

In this tutorial, we compare the inference speed overhead of LoReFT against a model forward run without any intervention (i.e., the ceiling runtime). In theory, ReFT's runtime should be:

- **worse than LoRA**, since LoRA can merge its learned weights into the model weights, resulting in zero overhead at inference time.
- **better than Adaptor**, since Adaptor applies additional computation at every step in the sequence.

In [None]:
import json
import os
import time

import torch
import transformers
from datasets import load_dataset, concatenate_datasets

from pyreft import (
    TaskType,
    get_reft_model,
    ReftConfig,
    ReftTrainerForCausalLM, 
    ReftDataCollator,
    ReftSupervisedDataset,
    LoreftIntervention
)

# Instruction-only prompt; the single %s placeholder receives the instruction text.
prompt_no_input_template = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n%s\n\n"
    "### Response:\n"
)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Only the first 100 rows of the "train" split are used for the timing runs.
_loaded_splits = load_dataset("json", data_files="../composition/ultrafeedback_1k.json")
raw_dataset = _loaded_splits["train"].select(range(100))

In [None]:
# Load the base model (takes about a minute).
# Checkpoint choices: yahma/llama-7b-hf or yahma/llama-13b-hf
model_name_or_path = "yahma/llama-7b-hf"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map=device,
)

# Load the tokenizer that belongs to the same checkpoint.
model_max_length = 1024
_tokenizer_options = {
    "model_max_length": model_max_length,
    "padding_side": "right",
    "use_fast": False,
    # NOTE(review): `truncate` does not look like a documented tokenizer
    # argument (truncation is normally requested per call with
    # `truncation=True`) - confirm whether this has any effect.
    "truncate": True,
}
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, **_tokenizer_options)
# Use the unknown token as the padding token.
tokenizer.pad_token = tokenizer.unk_token

### Rank analysis

In [None]:
# Rank sweep: for every LoReFT rank, time an intervened generation and a
# plain generation on each prompt, then dump all measurements to JSON.
RANKS = [1, 4, 8, 16, 32]
TARGET_LAYER = 15
rank_output_path = "../plots/data/elapse_per_rank.json"


def _timed_generate(generate_fn, *args, **kwargs):
    """Call ``generate_fn(*args, **kwargs)`` and time it.

    Returns a ``(elapsed_seconds, result)`` tuple. When CUDA is available the
    device is synchronized before starting and before stopping the clock, so
    that asynchronously queued GPU work from an earlier call is not charged
    to this one and all work of this call is finished when the clock stops.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # perf_counter is monotonic, unlike time.time(), which can jump when the
    # system clock is adjusted.
    start = time.perf_counter()
    result = generate_fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter() - start, result


# Create the output directory before the sweep starts, so a missing directory
# cannot discard the measurements at the very end of the run.
os.makedirs(os.path.dirname(rank_output_path), exist_ok=True)

# NOTE(review): no warm-up generation is performed, so the very first
# measurement may include one-time initialisation cost - confirm whether the
# plots should drop it.
elapse_per_rank = {}
for RANK in RANKS:
    print("analyzing:", RANK)
    # get reft model: a single LoReFT intervention on the block output of
    # TARGET_LAYER, with the rank under test
    reft_config = ReftConfig(representations={
        "layer": TARGET_LAYER, "component": "block_output",
        "intervention": LoreftIntervention(
        embed_dim=model.config.hidden_size, low_rank_dimension=RANK)})
    reft_model = get_reft_model(model, reft_config)
    reft_model.print_trainable_parameters()

    # Each entry is (elapsed seconds, prompt length, length of the first
    # returned sequence, "loreft" | "vanilla").
    all_elapse = []
    for example in raw_dataset:
        instruction = example["instruction"]

        prompt = prompt_no_input_template % instruction
        # Ask for truncation explicitly instead of relying on max_length alone.
        prompt = tokenizer(
            prompt, max_length=model_max_length, truncation=True,
            return_tensors="pt").to(device)

        prompt_len = prompt["input_ids"].shape[-1]
        base_unit_location = prompt_len - 1  # last position

        # loreft generate
        elapse, (_, reft_response) = _timed_generate(
            reft_model.generate,
            prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
            intervene_on_prompt=True, max_new_tokens=256, do_sample=False,
        )
        all_elapse.append((elapse, prompt_len, len(reft_response[0]), "loreft"))

        # vanilla generate
        elapse, model_response = _timed_generate(
            model.generate,
            **prompt, max_new_tokens=256, do_sample=False,
        )
        all_elapse.append((elapse, prompt_len, len(model_response[0]), "vanilla"))

    elapse_per_rank[RANK] = all_elapse
with open(rank_output_path, 'w') as f:
    json.dump(elapse_per_rank, f)

### Layer analysis

In [None]:
# Layer sweep: for every set of intervened layers, time an intervened
# generation and a plain generation on each prompt, then dump to JSON.
LAYERS = [[12,13],[12,13,14,15],[12,13,14,15,16,17],[12,13,14,15,16,17,18,19],[12,13,14,15,16,17,18,19,20,21]]
RANK = 8
layer_output_path = "../plots/data/elapse_per_layer.json"


def _timed_generate(generate_fn, *args, **kwargs):
    """Call ``generate_fn(*args, **kwargs)`` and time it.

    Returns a ``(elapsed_seconds, result)`` tuple. When CUDA is available the
    device is synchronized before starting and before stopping the clock, so
    that asynchronously queued GPU work from an earlier call is not charged
    to this one and all work of this call is finished when the clock stops.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # perf_counter is monotonic, unlike time.time(), which can jump when the
    # system clock is adjusted.
    start = time.perf_counter()
    result = generate_fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter() - start, result


# Create the output directory before the sweep starts, so a missing directory
# cannot discard the measurements at the very end of the run.
os.makedirs(os.path.dirname(layer_output_path), exist_ok=True)

elapse_per_layer = {}
for LAYER in LAYERS:
    print("analyzing:", LAYER)
    # get reft model: one rank-RANK LoReFT intervention per layer in LAYER,
    # each on that layer's block output
    reft_config = ReftConfig(representations=[{
        "layer": layer_idx, "component": "block_output",
        "intervention": LoreftIntervention(
        embed_dim=model.config.hidden_size, low_rank_dimension=RANK)} for layer_idx in LAYER])
    reft_model = get_reft_model(model, reft_config)
    reft_model.print_trainable_parameters()

    # Each entry is (elapsed seconds, prompt length, length of the first
    # returned sequence, "loreft" | "vanilla").
    all_elapse = []
    for example in raw_dataset:
        instruction = example["instruction"]

        prompt = prompt_no_input_template % instruction
        # Ask for truncation explicitly instead of relying on max_length alone.
        prompt = tokenizer(
            prompt, max_length=model_max_length, truncation=True,
            return_tensors="pt").to(device)

        prompt_len = prompt["input_ids"].shape[-1]
        base_unit_location = prompt_len - 1  # last position

        # loreft generate; the location list is repeated once per intervention
        elapse, (_, reft_response) = _timed_generate(
            reft_model.generate,
            prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]]*len(LAYER))},
            intervene_on_prompt=True, max_new_tokens=256, do_sample=False,
        )
        all_elapse.append((elapse, prompt_len, len(reft_response[0]), "loreft"))

        # vanilla generate
        elapse, model_response = _timed_generate(
            model.generate,
            **prompt, max_new_tokens=256, do_sample=False,
        )
        all_elapse.append((elapse, prompt_len, len(model_response[0]), "vanilla"))

    # Key the results by the layer indices joined with ";" (e.g. "12;13").
    elapse_per_layer[";".join(str(layer_idx) for layer_idx in LAYER)] = all_elapse
with open(layer_output_path, 'w') as f:
    json.dump(elapse_per_layer, f)

### Position analysis

In [None]:
# Position sweep: for every number of intervened prompt positions, time an
# intervened generation and a plain generation on each prompt, then dump to JSON.
positions = [2, 4, 6, 8, 10]
RANK = 8
LAYER = 15
position_output_path = "../plots/data/elapse_per_position.json"


def _timed_generate(generate_fn, *args, **kwargs):
    """Call ``generate_fn(*args, **kwargs)`` and time it.

    Returns a ``(elapsed_seconds, result)`` tuple. When CUDA is available the
    device is synchronized before starting and before stopping the clock, so
    that asynchronously queued GPU work from an earlier call is not charged
    to this one and all work of this call is finished when the clock stops.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # perf_counter is monotonic, unlike time.time(), which can jump when the
    # system clock is adjusted.
    start = time.perf_counter()
    result = generate_fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.perf_counter() - start, result


# Create the output directory before the sweep starts, so a missing directory
# cannot discard the measurements at the very end of the run.
os.makedirs(os.path.dirname(position_output_path), exist_ok=True)

elapse_per_position = {}
for position in positions:
    print("analyzing:", position)
    # get reft model: a single rank-RANK LoReFT intervention on the block
    # output of LAYER (the config itself does not depend on `position`)
    reft_config = ReftConfig(representations=[{
        "layer": LAYER, "component": "block_output",
        "intervention": LoreftIntervention(
        embed_dim=model.config.hidden_size, low_rank_dimension=RANK)}])
    reft_model = get_reft_model(model, reft_config)
    reft_model.print_trainable_parameters()

    # Each entry is (elapsed seconds, prompt length, length of the first
    # returned sequence, "loreft" | "vanilla").
    all_elapse = []
    for example in raw_dataset:
        instruction = example["instruction"]

        prompt = prompt_no_input_template % instruction
        # Ask for truncation explicitly instead of relying on max_length alone.
        prompt = tokenizer(
            prompt, max_length=model_max_length, truncation=True,
            return_tensors="pt").to(device)

        prompt_len = prompt["input_ids"].shape[-1]
        base_unit_location = prompt_len - 1  # last position

        # loreft generate on the last `position` prompt positions,
        # listed from the final token backwards
        elapse, (_, reft_response) = _timed_generate(
            reft_model.generate,
            prompt, unit_locations={"sources->base": (None, [[[base_unit_location-i for i in range(position)]]])},
            intervene_on_prompt=True, max_new_tokens=256, do_sample=False,
        )
        all_elapse.append((elapse, prompt_len, len(reft_response[0]), "loreft"))

        # vanilla generate
        elapse, model_response = _timed_generate(
            model.generate,
            **prompt, max_new_tokens=256, do_sample=False,
        )
        all_elapse.append((elapse, prompt_len, len(model_response[0]), "vanilla"))

    elapse_per_position[position] = all_elapse
with open(position_output_path, 'w') as f:
    json.dump(elapse_per_position, f)