### Can we do compositional ReFT?

I have:

- **A ReFT for continuing sentences in German**
- **A ReFT for following instructions**

Can I just combine them and have an "instruction-following model that speaks German"? Let's see!

First of all, you need to know the notions of **subspace**, **linear subspace**, and **orthonormal linear subspaces**! You can read more about these in Atticus's [causal abstraction paper](https://arxiv.org/abs/2301.04709). Briefly, here is what they are:

- **subspace**: you can think of it as a single dimension of an NN's representation in the NN's original (learned) basis.
- **linear subspace**: a representation in a changed basis, where the new basis is a linear combination (e.g., a rotation) of the original basis.
- **orthonormal linear subspaces**: if the new linear subspace is produced by an orthonormal projection, then the dimensions (or sub-subspaces, sorry about the confusing terminology here) of the new basis are mutually orthogonal. More strictly speaking, *the projection preserves orthogonality if the original basis has it*.
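
To make these definitions concrete, here is a minimal NumPy sketch (NumPy and the helper `random_orthonormal_rows` are our own illustration, not part of pyreft). It builds a rank-8 orthonormal projection `R` of a 16-dimensional toy hidden state and splits its rows into two groups of four: each group spans a linear subspace, and the two groups are orthogonal to each other. We assume `R` plays roughly the role of `rotate_layer.weight.T` in the intervention code further down.

```python
import numpy as np

def random_orthonormal_rows(rank: int, dim: int, seed: int = 0) -> np.ndarray:
    """Return a (rank, dim) matrix R whose rows are orthonormal (R @ R.T = I)."""
    rng = np.random.default_rng(seed)
    # Reduced QR of a random (dim, rank) matrix gives Q with orthonormal columns.
    q, _ = np.linalg.qr(rng.standard_normal((dim, rank)))
    return q.T  # the rows of R are the new basis directions

dim, rank = 16, 8
R = random_orthonormal_rows(rank, dim)
h = np.random.default_rng(1).standard_normal(dim)  # a toy hidden state

coords = R @ h        # coordinates of h in the rank-8 linear subspace
R_a, R_b = R[:4], R[4:]  # two groups of rows: two sub-subspaces

print(np.allclose(R @ R.T, np.eye(rank)))  # True: rows are orthonormal
print(np.allclose(R_a @ R_b.T, 0.0))       # True: the two groups are orthogonal
```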

So for ReFT, we can in theory leverage the notion of subspaces: train different subspaces for different tasks separately, and snap them together at inference time! Let's see if this works in practice.
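
Why might snapping subspaces together work? LoReFT edits a hidden state h as Φ(h) = h + Rᵀ(Wh + b − Rh), where R is an r × d matrix with orthonormal rows. Writing R_i and W_i for the i-th rows and b_i for the i-th entry of b, a short derivation (assuming an identity activation on Wh + b) shows that restricting the edit to a set S of rows, and combining two disjoint sets, behaves as follows:

$$
\begin{aligned}
\Phi_S(h) &= h + \sum_{i \in S} R_i^\top \left( W_i h + b_i - R_i h \right), \\
\Phi_{S_1 \cup S_2}(h) - h &= \big(\Phi_{S_1}(h) - h\big) + \big(\Phi_{S_2}(h) - h\big) \quad \text{for } S_1 \cap S_2 = \varnothing .
\end{aligned}
$$

So the two edits simply add up at the intervened layer. Whether the layers after it combine the two effects gracefully is an empirical question.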

```python
import torch
import transformers
from datasets import load_dataset, concatenate_datasets

from pyreft import (
    TaskType,
    get_reft_model,
    ReftConfig,
    ReftTrainerForCausalLM,
    ReftDataCollator,
    ReftSupervisedDataset,
    LoreftIntervention
)

prompt_no_input_template = """Below is an instruction that \
describes a task. Write a response that appropriately \
completes the request.

### Instruction:
%s

### Response:
"""

device = "cuda" if torch.cuda.is_available() else "cpu"
transformers.set_seed(43)
```


### Loading the base LM (LLaMA-1 here, not Llama-2!)

```python
# load the model (takes ~1 min)
model_name_or_path = "yahma/llama-7b-hf" # yahma/llama-7b-hf or yahma/llama-13b-hf
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

# get tokenizer
model_max_length = 512
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=model_max_length,
    padding_side="right", use_fast=False)
tokenizer.pad_token = tokenizer.unk_token
```



```python
##################################################
# Subspace partitions:
#
# Let's have a LoReFT of rank 8, and assign
# - the first 4 ranks to German sentence completion
# - the next 4 ranks to instruction following
##################################################
HELLASWAG_SUBSPACE = [0, 1, 2, 3]
INSTRUCT_SUBSPACE = [4, 5, 6, 7]

def preprocess_hellaswag_de_for_reft(examples):
    label = int(examples["label"])
    if len(examples["endings_de"]) < 4:
        # fewer than 4 endings available: fall back to the last one
        output = examples["endings_de"][-1]
    else:
        output = examples["endings_de"][label]
    examples["instruction"] = examples["ctx"]
    examples["output"] = output
    examples["subspaces"] = HELLASWAG_SUBSPACE
    return examples

def preprocess_ultrafeedback_for_reft(examples):
    examples["subspaces"] = INSTRUCT_SUBSPACE
    # append EOS so the instruction subspace learns when to stop
    examples["output"] += tokenizer.eos_token
    return examples

# German HellaSwag: keep only the new instruction/output/subspaces columns
raw_dataset = load_dataset("LeoLM/HellaSwag_de")
drop_features = list(raw_dataset["train"].features.keys())
raw_dataset = raw_dataset.map(preprocess_hellaswag_de_for_reft)
hellaswag_de_dataset = raw_dataset.remove_columns(drop_features)["train"]

# UltraFeedback instruction data from a local JSON file
raw_dataset = load_dataset("json", data_files="./ultrafeedback_1k.json")["train"]
raw_dataset = raw_dataset.map(preprocess_ultrafeedback_for_reft)
ultrafeedback_dataset = raw_dataset.remove_columns(["input"])

subspace_dataset = concatenate_datasets([hellaswag_de_dataset, ultrafeedback_dataset])

class SubloreftIntervention(LoreftIntervention):
    """
    This is a LoReFT that supports subspace interventions!
    Each example only writes back the ranks listed in its `subspaces` entry.
    """
    def forward(
        self, base, source=None, subspaces=None
    ):
        assert subspaces is not None

        rotated_base = self.rotate_layer(base)  # (batch, seq, rank)
        diff = self.act_fn(self.learned_source(base)) - rotated_base

        batched_subspace = []
        batched_weights = []

        for example_i in range(len(subspaces)):
            # keep only this example's ranks: (seq, k) and (k, hidden)
            LHS = diff[example_i, :, subspaces[example_i]]
            RHS = self.rotate_layer.weight[..., subspaces[example_i]].T
            batched_subspace += [LHS]
            batched_weights += [RHS]

        # stacking assumes every example selects the same number of ranks
        batched_subspace = torch.stack(batched_subspace, dim=0)
        batched_weights = torch.stack(batched_weights, dim=0)
        # project the selected ranks back to the hidden space and add them to the base
        output = base + torch.bmm(batched_subspace, batched_weights)

        return self.dropout(output.to(base.dtype))
```
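
To check what `SubloreftIntervention.forward` computes without loading a 7B model, here is a small NumPy mirror of it. This is our own re-implementation for illustration, assuming an identity `act_fn`, no dropout, and that `rotate_layer.weight.T` corresponds to the orthonormal matrix `R`. It checks two properties: selecting all 8 ranks reduces to plain LoReFT, and a batch that mixes German and instruction examples gives each example exactly the edit it would get on its own.

```python
import numpy as np

def subloreft_forward(base, R, W, b, subspaces):
    """NumPy mirror of SubloreftIntervention.forward (identity act_fn, no dropout).

    base: (batch, seq, d) hidden states
    R:    (r, d) projection with orthonormal rows (rotate_layer.weight.T)
    W, b: (r, d) and (r,) parameters of learned_source
    subspaces: one list of rank indices per example in the batch
    """
    rotated = base @ R.T                # (batch, seq, r)
    diff = (base @ W.T + b) - rotated   # (batch, seq, r)
    out = base.copy()
    for i, idx in enumerate(subspaces):
        # write back only the selected ranks: (seq, k) @ (k, d) -> (seq, d)
        out[i] += diff[i][:, idx] @ R[idx]
    return out

rng = np.random.default_rng(0)
d, r, seq = 16, 8, 3
R = np.linalg.qr(rng.standard_normal((d, r)))[0].T  # orthonormal rows
W, b = rng.standard_normal((r, d)), rng.standard_normal(r)
h = rng.standard_normal((2, seq, d))                  # batch of two examples

# 1) All ranks selected == plain LoReFT: h + (Wh + b - Rh) R
full = subloreft_forward(h, R, W, b, [list(range(r))] * 2)
plain = h + (h @ W.T + b - h @ R.T) @ R
print(np.allclose(full, plain))  # True

# 2) Mixed batch: example 0 uses the "German" ranks, example 1 the "instruction" ranks
mixed = subloreft_forward(h, R, W, b, [[0, 1, 2, 3], [4, 5, 6, 7]])
alone0 = subloreft_forward(h[:1], R, W, b, [[0, 1, 2, 3]])
alone1 = subloreft_forward(h[1:], R, W, b, [[4, 5, 6, 7]])
print(np.allclose(mixed[0], alone0[0]) and np.allclose(mixed[1], alone1[0]))  # True
```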

### Load a rank-8 LoReFT config

```python
TARGET_LAYER = 15

# get reft model
reft_config = ReftConfig(representations={
    "layer": TARGET_LAYER, "component": "block_output",
    "intervention": SubloreftIntervention(
        embed_dim=model.config.hidden_size, low_rank_dimension=8)})
reft_model = get_reft_model(model, reft_config)
reft_model.print_trainable_parameters()
```

```text
trainable intervention params: 65,544 || trainable model params: 0
model params: 6,738,415,616 || trainable%: 0.0009726915603776257
```
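
The parameter count above can be reproduced by hand. This assumes (our reading of LoReFT, not something pyreft prints) that the intervention holds an orthonormal projection R of shape hidden_size × rank without a bias, plus a learned source layer (W, b) mapping hidden_size to rank with a bias; LLaMA-7B's hidden size is 4096. The helper `loreft_param_count` is hypothetical.

```python
def loreft_param_count(hidden_size: int, rank: int) -> int:
    """Parameters of one LoReFT intervention (assumed layout, see lead-in)."""
    rotate = hidden_size * rank          # R: no bias
    source = hidden_size * rank + rank   # W plus bias b
    return rotate + source

n = loreft_param_count(hidden_size=4096, rank=8)
pct = 100 * n / 6_738_415_616  # total model params from the printout above
print(n)    # 65544
print(pct)  # roughly 0.00097 (percent)
```

Splitting the eight ranks into two groups of four costs nothing extra: both tasks share these 65,544 parameters, and each task only uses its own four rows.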


### Load dataset

Note that we only use **2,000 training examples** in total: LoReFT works in low-resource settings, a bonus we did not fully explore in the paper.

```python
train_dataset = ReftSupervisedDataset(
    "Subloreft", None, tokenizer, dataset=subspace_dataset,
    **{"num_interventions": 1, "position": "l1", "share_weights": False},
    input_field="input", instruction_field="instruction", output_field="output",
    no_stop=True
)
data_collator_fn = transformers.DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,
    padding="longest"
)
data_collator = ReftDataCollator(data_collator=data_collator_fn)
```



### Training!

Note that we are not training a shared subspace for the two tasks! We train each subspace individually by providing the `subspaces` field in the input. Check out [pyvene](https://github.com/stanfordnlp/pyvene) for how to use the `subspaces` field; there are other things we haven't tried.
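
Why does providing `subspaces` train the two tasks separately? Within one example, the edit only reads the rows of W, b, and R listed in that example's subspace, so the other task's rows receive no gradient from it. A minimal NumPy check, using a hypothetical helper `subspace_edit` and assuming an identity activation: perturbing only the instruction rows 4-7 of W and b leaves the German-subspace edit unchanged. One caveat, which is our assumption about pyreft's implementation: if R is kept orthonormal through a parametrization of the whole matrix, updates driven by one group can still move the other group's directions, so the separation is exact for W and b but not necessarily for R.

```python
import numpy as np

def subspace_edit(h, R, W, b, idx):
    """LoReFT edit of one hidden state h, restricted to rank indices idx."""
    return (W[idx] @ h + b[idx] - R[idx] @ h) @ R[idx]

rng = np.random.default_rng(0)
d, r = 16, 8
R = np.linalg.qr(rng.standard_normal((d, r)))[0].T  # (r, d), orthonormal rows
W, b = rng.standard_normal((r, d)), rng.standard_normal(r)
h = rng.standard_normal(d)

german, instruct = [0, 1, 2, 3], [4, 5, 6, 7]

# Perturb only the parameters that belong to the instruction ranks.
W2, b2 = W.copy(), b.copy()
W2[instruct] += 1.0
b2[instruct] -= 3.0

same = np.allclose(subspace_edit(h, R, W, b, german), subspace_edit(h, R, W2, b2, german))
changed = not np.allclose(subspace_edit(h, R, W, b, instruct),
                          subspace_edit(h, R, W2, b2, instruct))
print(same, changed)  # True True
```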

```python
# train
training_args = transformers.TrainingArguments(
    num_train_epochs=3.0, output_dir="./tmp", learning_rate=5e-3, report_to=[],
    per_device_train_batch_size=8, logging_steps=50
)
trainer = ReftTrainerForCausalLM(
    model=reft_model, tokenizer=tokenizer, args=training_args,
    train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
trainer.train()
```



| Step | Training Loss |
|-----:|--------------:|
| 50 | 1.2522 |
| 100 | 1.2855 |
| 150 | 1.2104 |
| 200 | 1.2389 |
| 250 | 1.2354 |
| 300 | 1.2114 |
| 350 | 1.2831 |
| 400 | 1.1936 |
| 450 | 1.1755 |
| 500 | 1.2574 |




```text
TrainOutput(global_step=750, training_loss=1.224351099650065, metrics={'train_runtime': 1022.4403, 'train_samples_per_second': 5.868, 'train_steps_per_second': 0.734, 'total_flos': 0.0, 'train_loss': 1.224351099650065, 'epoch': 3.0})
```
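
As a sanity check on this log, assuming a single device and no gradient accumulation (the training arguments above set neither): 2,000 examples at a batch size of 8 give 250 steps per epoch, hence 750 steps over 3 epochs, and the reported throughput follows from the runtime.

```python
import math

num_examples, batch_size, epochs = 2000, 8, 3
runtime_s = 1022.4403  # train_runtime from the TrainOutput above

steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * epochs
print(total_steps)                                  # 750
print(round(total_steps / runtime_s, 3))            # 0.734 steps/s
print(round(num_examples * epochs / runtime_s, 3))  # 5.868 samples/s
```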

### Interact with the German sentence completion subspace

```python
instruction = "How to keep a healthy lifestyle?"

prompt = prompt_no_input_template % instruction
prompt = tokenizer(prompt, return_tensors="pt").to(device)

base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    subspaces=[[HELLASWAG_SUBSPACE]],
    intervene_on_prompt=True, max_new_tokens=128, do_sample=False,
    no_repeat_ngram_size=5, repetition_penalty=1.1,
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))
```

```text
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
How to keep a healthy lifestyle?

### Response:
, die gesundheitliche Ernährung ist wichtig für alle Menschen. Die Ernährungsempfehlungen der Deutschen Gesellschaft für Ernährung (DGE) sind in einem Buch veröffentlicht worden. Das Buch enthält viele Tipps und Informationen über Ernährung und Gesundheit. Es gibt auch eine Liste von Ernährungsempfängern, die sich mit Ernährung befassen.

## Einzelnachweise

1.  http://www.dge-ev.de/index.php?id=20&L=1
```


### Interact with the instruction following subspace

```python
_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    subspaces=[[INSTRUCT_SUBSPACE]],
    intervene_on_prompt=True, max_new_tokens=512, do_sample=False,
    no_repeat_ngram_size=5, repetition_penalty=1.1,
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))
```

```text
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
How to keep a healthy lifestyle?

### Response:
To maintain a healthy lifesytle, it's important to eat nutritious foods and get enough exercise. It's also essential to drink plenty of water and get enough sleep. These are some ways to stay healthy.
```


### Interact with both subspaces, partially!

To interact with both of them, simply change the `subspaces` field at inference time to any combination you want!
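
For example, a small helper (hypothetical, not part of pyreft) can build the `subspaces` argument from the named rank groups, optionally dropping some ranks as the comment in the next cell suggests. It returns the same `[[...]]` nesting as the `subspaces=` arguments passed to `generate` above.

```python
HELLASWAG_SUBSPACE = [0, 1, 2, 3]
INSTRUCT_SUBSPACE = [4, 5, 6, 7]

def combine_subspaces(*groups, drop=()):
    """Union of rank groups minus `drop`, wrapped in the [[...]] nesting
    used for a single prompt in this notebook's generate calls."""
    merged = sorted(set().union(*groups) - set(drop))
    return [[merged]]

print(combine_subspaces(HELLASWAG_SUBSPACE, INSTRUCT_SUBSPACE))
# [[[0, 1, 2, 3, 4, 5, 6, 7]]]
print(combine_subspaces(HELLASWAG_SUBSPACE, INSTRUCT_SUBSPACE, drop=[4, 5]))
# [[[0, 1, 2, 3, 6, 7]]]
```

The result can then be passed as, e.g., `subspaces=combine_subspaces(HELLASWAG_SUBSPACE, INSTRUCT_SUBSPACE, drop=[4, 5])` in `reft_model.generate`.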

```python
base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    # sometimes leaving out subspaces [4, 5] gives better results
    subspaces=[[[0, 1, 2, 3, 4, 5, 6, 7]]],
    intervene_on_prompt=True, max_new_tokens=512, do_sample=False,
    no_repeat_ngram_size=5, repetition_penalty=1.1,
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))
```

```text
Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
How to keep a healthy lifestyle?

### Response:
Jugendliche können eine gesunde Lebensweise durch verschiedene Aktivitäten erreichen, wie zum Beispiel Sport und Spiele. Sie können auch ein gesundes Essen und einen guten Schlaf haben. Außerdem kann man sich um die Gesundheit der anderen Menschen kümmern, indem man ihnen hilft oder sie unterstützt.

Es ist wichtig, dass Jugendliche eine gesunde Ernährung aufnehmen, um ihre Körper zu stärken und gesund zu bleiben. Es gibt viele Möglichkeiten, wie man eine gesunde Ernähhung aufnimmt, wie zum Beispiel frische Gemüse, Obst und Fleisch. Auch kann man sich um den Verzehr von Süßigkeiten und Zuckerhaltigen Produkten kümmern.

Ein weiterer Aspekt für eine gesunde Ernahrung ist, dass man genügend Flüssigkeit aufnimmt, um seinen Körper zu hydrieren. Wasser ist das beste Fluid, das man aufnehmen kann, da es keine Fettanteile e
```

This is an early sneak peek at our **Feature Compartmentalization** and **Schematic ReFT**! Stay tuned, and explore with us! We think there is a practically infinite number of causal pathways in neural networks waiting to be explored!