### The connection between in-context learning (ICL) and ReFT

ReFT requires very few trainable parameters, so it is natural to ask whether it can also adapt quickly and easily to a new task. Here, we probe the limits of ReFT in an in-context learning (ICL) setting, where only a handful of training examples are available.

In [None]:
import copy, json, random, re
import logging
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import pandas as pd
import matplotlib.pyplot as plt
from plotnine import ggplot, aes, geom_line, theme_minimal
from matplotlib.ticker import MaxNLocator
plt.rcParams.update({'font.size': 20, 'font.family': 'Sans'})

import torch
import transformers
from datasets import Dataset
from transformers import Trainer

from pyreft import (
    TaskType,
    get_reft_model,
    ReftConfig,
    ReftTrainerForCausalLM, 
    ReftDataCollator,
    ReftSupervisedDataset,
    make_last_position_supervised_data_module,
    ConsreftIntervention,
    LoreftIntervention
)

# Conventional ignore_index value for PyTorch cross-entropy; not referenced by
# the cells shown in this section.
IGNORE_INDEX = -100

# Prefer the GPU when one is visible; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def max_char_match_length(retrieved, golden):
    """Return the fraction of ``retrieved`` covered by its common prefix with ``golden``.

    Characters are compared position by position from the start. Counting
    stops at the first mismatch, or as soon as either string runs out. The
    matched count is divided by ``len(retrieved)`` and rounded to two decimals.

    Args:
        retrieved: The generated/retrieved string being scored.
        golden: The reference string.

    Returns:
        A float in [0.0, 1.0]; 0.0 when ``retrieved`` is empty.
    """
    if not retrieved:
        return 0.0
    matched = 0
    # zip stops at the shorter string, so a `retrieved` longer than `golden`
    # (or an empty `golden`) no longer raises IndexError.
    for got, want in zip(retrieved, golden):
        if got != want:
            break
        matched += 1
    return round(matched / len(retrieved), 2)

# Shorter alias for pyreft's make_last_position_supervised_data_module.
# NOTE(review): the experiment cells below call
# make_last_position_supervised_data_module directly, so this alias is not
# used there.
make_supervised_data_module = make_last_position_supervised_data_module

In [None]:
# Load the base causal LM in bfloat16 straight onto `device` (takes ~1 min).
# The smaller alternative checkpoint is "yahma/llama-7b-hf".
model_name_or_path = "yahma/llama-13b-hf"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map=device,
)

# Matching tokenizer: right padding, non-fast implementation, 2048-token cap.
# Padding reuses the tokenizer's unk token.
model_max_length = 2048
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path,
    model_max_length=model_max_length,
    padding_side="right",
    use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token

### N-shot sequence classification task

We compare a rank-1 LoReFT intervention with n-shot ICL.

#### SST-2

In [None]:
from datasets import load_dataset
dataset = load_dataset("stanfordnlp/sst2")

# One pass over the training split, bucketing raw sentences by label:
# label 0 -> negative_examples, label 1 -> positive_examples; any other
# label value is skipped.
sentences_by_label = {0: [], 1: []}
for row in dataset["train"]:
    if row["label"] in sentences_by_label:
        sentences_by_label[row["label"]].append(row["sentence"])
negative_examples = sentences_by_label[0]
positive_examples = sentences_by_label[1]
TARGET_LAYER = 15

In [None]:
# Few-shot SST-2 comparison. For every seed and shot count K we draw K // 2
# positive and K // 2 negative training sentences, train a fresh rank-1 LoReFT
# intervention on them, and compare it with K-shot ICL whose prompt uses the
# very same demonstrations. Each `results` entry is (K, reft_acc, icl_acc),
# with accuracies rounded to two decimals.
SST2_LABEL_WORDS = {1: "positive", 0: "negative"}


def _sst2_accuracy(predict_label):
    """Return validation accuracy (rounded to 2 dp) of ``predict_label``.

    ``predict_label`` maps a raw sentence to a predicted label word; examples
    whose label is neither 0 nor 1 can never count as correct.
    """
    validation = dataset["validation"]
    correct = sum(
        predict_label(e["sentence"]) == SST2_LABEL_WORDS.get(e["label"])
        for e in validation
    )
    return round(correct / len(validation), 2)


def _sst2_reft_predict(reft_model, sentence):
    """Greedy-decode with the intervention applied at the prompt's last token."""
    prompt = tokenizer(f"{sentence}->", return_tensors="pt").to(device)
    base_unit_location = prompt["input_ids"].shape[-1] - 1
    # early_stopping only affects beam search, so it is omitted for greedy decoding.
    _, steered_response = reft_model.generate(
        prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
        intervene_on_prompt=True, max_new_tokens=10, do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(steered_response[0], skip_special_tokens=True)
    # Keep only the text after the final "->", tolerating stray whitespace.
    return decoded.split("->")[-1].strip()


def _sst2_icl_predict(icl_prompt, sentence):
    """Greedy-decode the base model continuing the ICL prompt."""
    prompt = tokenizer(f"{icl_prompt};{sentence}->", return_tensors="pt").to(device)
    response = model.generate(
        **prompt, max_new_tokens=10, do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(response[0], skip_special_tokens=True)
    # The model may continue with further ";"-separated demos; cut there.
    return decoded.split("->")[-1].split(";")[0].strip()


results = []
K_SHOTS = [4, 10, 20, 30, 40, 50]
for seed in [42, 43, 44]:
    # Private RNG for demo sampling: constructing the Trainer calls
    # transformers.set_seed(args.seed), which reseeds the global `random`
    # module, so random.seed(seed) alone would only vary the first K per seed.
    rng = random.Random(seed)
    for K in K_SHOTS:
        print("evaluating: ", K)
        # Balanced demonstrations: K // 2 per class.
        half_k = K // 2
        pos_demo = rng.sample(positive_examples, k=half_k)
        neg_demo = rng.sample(negative_examples, k=half_k)

        # ReFT training pairs: "<sentence>->" followed by the label word.
        storage_access_ids = [f"{w}->" for w in pos_demo + neg_demo]
        memo_tokens = ["positive"] * half_k + ["negative"] * half_k

        # ICL baseline: the same demonstrations, shuffled and ";"-joined.
        icl_prompt = [f"{w}->positive" for w in pos_demo] + [f"{w}->negative" for w in neg_demo]
        rng.shuffle(icl_prompt)
        icl_prompt = ";".join(icl_prompt)

        # Seed torch as well so the intervention's initialisation depends on `seed`.
        transformers.set_seed(seed)
        reft_config = ReftConfig(representations={
            "layer": TARGET_LAYER, "component": "block_output",
            "intervention": LoreftIntervention(
                embed_dim=model.config.hidden_size,
                low_rank_dimension=1)})
        reft_model = get_reft_model(model, reft_config)
        reft_model.print_trainable_parameters()

        # get training data to train our intervention to remember the following sequence
        data_module = make_last_position_supervised_data_module(
            tokenizer, model, storage_access_ids, memo_tokens)

        # Evaluation strategy is left at its default ("no"); the old
        # `evaluation_strategy` keyword is deprecated in newer transformers.
        # `seed=seed` keeps the Trainer's own reseeding tied to this run.
        training_args = transformers.TrainingArguments(
            max_steps=200, output_dir="./tmp", learning_rate=2e-3, report_to=[],
            per_device_train_batch_size=8, logging_steps=50,
            save_strategy="no", seed=seed,
        )
        trainer = ReftTrainerForCausalLM(
            model=reft_model, tokenizer=tokenizer,
            args=training_args, **data_module)
        _ = trainer.train()

        reft_acc = _sst2_accuracy(lambda s: _sst2_reft_predict(reft_model, s))
        icl_acc = _sst2_accuracy(lambda s: _sst2_icl_predict(icl_prompt, s))
        print((K, reft_acc, icl_acc))
        results += [(K, reft_acc, icl_acc)]

# Derive the size tag from the loaded checkpoint so the file name cannot drift
# from the evaluated model (e.g. "yahma/llama-13b-hf" -> "13B").
size_match = re.search(r"(\d+)b", model_name_or_path.lower())
size_tag = f"{size_match.group(1)}B" if size_match else "unknown"
FILE_PATH = f"./icl_sst2_{size_tag}.json"
with open(FILE_PATH, 'w') as output_file:
    json.dump(results, output_file, indent=2)

#### MRPC

In [None]:
from datasets import load_dataset
dataset = load_dataset("glue", "mrpc")

# One pass over the training split. Each pair is rendered as
# "Sentence Pair: <s1>;<s2>" and bucketed by label:
# label 0 -> negative_examples, label 1 -> positive_examples; other labels are skipped.
pairs_by_label = {0: [], 1: []}
for row in dataset["train"]:
    if row["label"] in pairs_by_label:
        pairs_by_label[row["label"]].append(
            f"Sentence Pair: {row['sentence1']};{row['sentence2']}")
negative_examples = pairs_by_label[0]
positive_examples = pairs_by_label[1]
TARGET_LAYER = 15

In [None]:
# Few-shot MRPC comparison, mirroring the SST-2 cell. For every seed and shot
# count K we draw K // 2 label-1 ("yes") and K // 2 label-0 ("no") sentence
# pairs, train a fresh rank-1 LoReFT intervention on them, and compare it with
# K-shot ICL over the same demonstrations. Each `results` entry is
# (K, reft_acc, icl_acc), with accuracies rounded to two decimals.
MRPC_LABEL_WORDS = {1: "yes", 0: "no"}


def _mrpc_text(example):
    """Render an MRPC example exactly like the training demonstrations."""
    return f"Sentence Pair: {example['sentence1']};{example['sentence2']}"


def _mrpc_accuracy(predict_label):
    """Return validation accuracy (rounded to 2 dp) of ``predict_label``.

    ``predict_label`` maps a rendered sentence pair to a predicted label word;
    examples whose label is neither 0 nor 1 can never count as correct.
    """
    validation = dataset["validation"]
    correct = sum(
        predict_label(_mrpc_text(e)) == MRPC_LABEL_WORDS.get(e["label"])
        for e in validation
    )
    return round(correct / len(validation), 2)


def _mrpc_reft_predict(reft_model, text):
    """Greedy-decode with the intervention applied at the prompt's last token."""
    prompt = tokenizer(f"{text}->", return_tensors="pt").to(device)
    base_unit_location = prompt["input_ids"].shape[-1] - 1
    # early_stopping only affects beam search, so it is omitted for greedy decoding.
    _, steered_response = reft_model.generate(
        prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
        intervene_on_prompt=True, max_new_tokens=10, do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(steered_response[0], skip_special_tokens=True)
    # Keep only the text after the final "->", tolerating stray whitespace.
    return decoded.split("->")[-1].strip()


def _mrpc_icl_predict(icl_prompt, text):
    """Greedy-decode the base model continuing the ICL prompt."""
    prompt = tokenizer(f"{icl_prompt}\n{text}->", return_tensors="pt").to(device)
    response = model.generate(
        **prompt, max_new_tokens=10, do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.decode(response[0], skip_special_tokens=True)
    # The model may continue with further newline-separated demos; cut there.
    return decoded.split("->")[-1].split("\n")[0].strip()


results = []
K_SHOTS = [4, 10, 20, 30, 40, 50]
for seed in [42, 43, 44]:
    # Private RNG for demo sampling: constructing the Trainer calls
    # transformers.set_seed(args.seed), which reseeds the global `random`
    # module, so random.seed(seed) alone would only vary the first K per seed.
    rng = random.Random(seed)
    for K in K_SHOTS:
        print("evaluating: ", K)
        # Balanced demonstrations: K // 2 per class.
        half_k = K // 2
        pos_demo = rng.sample(positive_examples, k=half_k)
        neg_demo = rng.sample(negative_examples, k=half_k)

        # ReFT training pairs: "<pair>->" followed by the label word.
        storage_access_ids = [f"{w}->" for w in pos_demo + neg_demo]
        memo_tokens = ["yes"] * half_k + ["no"] * half_k

        # ICL baseline: the same demonstrations, shuffled and newline-joined.
        icl_prompt = [f"{w}->yes" for w in pos_demo] + [f"{w}->no" for w in neg_demo]
        rng.shuffle(icl_prompt)
        icl_prompt = "\n".join(icl_prompt)

        # Seed torch as well so the intervention's initialisation depends on `seed`.
        transformers.set_seed(seed)
        reft_config = ReftConfig(representations={
            "layer": TARGET_LAYER, "component": "block_output",
            "intervention": LoreftIntervention(
                embed_dim=model.config.hidden_size,
                low_rank_dimension=1)})
        reft_model = get_reft_model(model, reft_config)
        reft_model.print_trainable_parameters()

        # get training data to train our intervention to remember the following sequence
        data_module = make_last_position_supervised_data_module(
            tokenizer, model, storage_access_ids, memo_tokens)

        # Evaluation strategy is left at its default ("no"); the old
        # `evaluation_strategy` keyword is deprecated in newer transformers.
        # `seed=seed` keeps the Trainer's own reseeding tied to this run.
        training_args = transformers.TrainingArguments(
            max_steps=200, output_dir="./tmp", learning_rate=2e-3, report_to=[],
            per_device_train_batch_size=8, logging_steps=50,
            save_strategy="no", seed=seed,
        )
        trainer = ReftTrainerForCausalLM(
            model=reft_model, tokenizer=tokenizer,
            args=training_args, **data_module)
        _ = trainer.train()

        reft_acc = _mrpc_accuracy(lambda t: _mrpc_reft_predict(reft_model, t))
        icl_acc = _mrpc_accuracy(lambda t: _mrpc_icl_predict(icl_prompt, t))
        print((K, reft_acc, icl_acc))
        results += [(K, reft_acc, icl_acc)]

# The file name used to be hard-coded as "7B" even though a 13B checkpoint is
# loaded; derive the size tag from the checkpoint name instead.
size_match = re.search(r"(\d+)b", model_name_or_path.lower())
size_tag = f"{size_match.group(1)}B" if size_match else "unknown"
FILE_PATH = f"./icl_mrpc_{size_tag}.json"
with open(FILE_PATH, 'w') as output_file:
    json.dump(results, output_file, indent=2)