## Domain Transfer Task

A key requirement for composition is that addition and subtraction of learned modules work together. Here we test this in the same setting as the [LoRA operators paper](https://arxiv.org/pdf/2306.14870).

In this task, we use the t5-base or t5-small model to see whether ReFT interventions can transfer across domains. We train on 2 datasets (amazon-polarity and yelp-polarity, both of which contain reviews). For each dataset we train 2 models: a sentiment classifier and a language model (which concatenates the review texts, splits them into blocks of size 128, and uses the prior block to predict the next block).

We then compose these learned modules using the addition and subtraction operators we defined before. We compose by creating a new model with

- amazon_classifier + lambda * (yelp_lm - amazon_lm)
- yelp_classifier + lambda * (amazon_lm - yelp_lm)

We then check whether the first composed model performs better on the yelp dataset than amazon_classifier alone, and whether the second composed model performs better on the amazon dataset than yelp_classifier alone. If so, the difference between the language modeling interventions is parallel to the difference between the classification interventions, which strengthens our argument that LM interventions (especially ReFT) are composable.
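
The arithmetic of the two composed models can be sketched on toy vectors. This is a minimal illustration rather than the notebook's implementation: each list stands in for the offset a trained module would add to a hidden state, the names `compose` and `lam` are hypothetical, and the numbers are made up.

```python
def compose(classifier, target_lm, source_lm, lam):
    """Return classifier + lam * (target_lm - source_lm), element-wise."""
    return [c + lam * (t - s) for c, t, s in zip(classifier, target_lm, source_lm)]

# Toy offsets standing in for what each trained module adds to a hidden state.
amazon_classifier = [1.0, 0.0, 2.0]
amazon_lm = [0.5, 0.5, 0.0]
yelp_lm = [0.5, 1.5, 1.0]

# amazon_classifier + lambda * (yelp_lm - amazon_lm)
to_yelp = compose(amazon_classifier, yelp_lm, amazon_lm, lam=0.5)
print(to_yelp)  # [1.0, 0.5, 2.5]
```

With `lam=0` the composed model reduces to the plain classifier, so `lam` controls how strongly the domain shift is applied.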

Note that learning orthogonal interventions for the 4 tasks in 4 separate subspaces would not work here. yelp_lm and amazon_lm are very similar tasks; forcing them to be orthogonal to each other disrupts learning and drives the second learned module close to 0. So we use NoDiReFT in this notebook, but we leave the LoReFT sections in place for your reference.
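
To see why forcing orthogonality hurts for similar tasks, consider what remains of a direction once its component along a nearly identical direction is removed. The sketch below applies a single Gram-Schmidt step to 2-D toy vectors; the vectors and the helper name `orthogonal_residual` are illustrative assumptions, not values taken from the trained modules.

```python
import math

def orthogonal_residual(v, u):
    """Remove from v its component along u (one Gram-Schmidt step)."""
    scale = sum(a * b for a, b in zip(v, u)) / sum(b * b for b in u)
    return [a - scale * b for a, b in zip(v, u)]

def norm(v):
    return math.sqrt(sum(a * a for a in v))

amazon_dir = [1.0, 0.0]   # toy direction learned for the first task
yelp_dir = [0.99, 0.05]   # a very similar toy direction for the second task

residual = orthogonal_residual(yelp_dir, amazon_dir)
# The part of yelp_dir that is orthogonal to amazon_dir is much smaller than yelp_dir.
print(norm(yelp_dir), norm(residual))
```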


### Setup

In [1]:
import logging
import os
import sys
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

import datasets
import nltk  # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import load_dataset, load_metric
import math

from promptsource.templates import DatasetTemplates

import transformers
from filelock import FileLock
from transformers import (
    # AdapterConfig,
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    HfArgumentParser,
    MBart50Tokenizer,
    MBart50TokenizerFast,
    MBartTokenizer,
    MBartTokenizerFast,
    # MultiLingAdapterArguments,
    # Seq2SeqAdapterTrainer,
    Trainer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
    TrainerCallback,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, is_offline_mode
from transformers.utils.versions import require_version

In [2]:
try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)

In [3]:
model_name_or_path="t5-small"
seed = 42
dropout = 0
max_length = 512
low_rank = 8
set_seed(seed)
max_train_examples = 10000
train_batch_size = 16
fp16 = True
testing = True
intervention_type = "nodireft"

In [4]:
config = AutoConfig.from_pretrained(
    model_name_or_path,
    dropout_rate=dropout,
    max_length=max_length,
)



In [5]:
import torch
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path,
    config=config,
    torch_dtype=torch.bfloat16 if fp16 else torch.float32,
)

In [6]:
model.resize_token_embeddings(len(tokenizer))


Embedding(32100, 512)

### Dataset Preprocessing

Below are the dataset preprocessing blocks. Use them to load and preprocess the datasets.

In this notebook we create 4 datasets.

- 'amazon_lm_data': language modeling data from the amazon_polarity dataset. The reviews are chunked into blocks of size 128, and the LM training objective is to use the prior block to predict the next block.
- 'yelp_lm_data': language modeling data from the yelp_polarity dataset.
- 'amazon_classify_data': classification data from the amazon_polarity dataset. Classes are "positive" and "negative".
- 'yelp_classify_data': classification data from the yelp_polarity dataset.
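
The "prior block predicts next block" objective of the LM datasets can be sketched in plain Python. This is a simplified stand-in for the `group_texts` cell further down, working on one flat list of token ids; the function name `make_next_block_pairs` is hypothetical.

```python
def make_next_block_pairs(token_ids, block_size):
    """Split token_ids into blocks and pair each block with the one that follows."""
    usable = (len(token_ids) // block_size) * block_size  # drop the small remainder
    blocks = [token_ids[i:i + block_size] for i in range(0, usable, block_size)]
    # Block k is the input and block k + 1 is its label, so the last block is label-only.
    return {"input_ids": blocks[:-1], "labels": blocks[1:]}

pairs = make_next_block_pairs(list(range(10)), block_size=3)
print(pairs["input_ids"])  # [[0, 1, 2], [3, 4, 5]]
print(pairs["labels"])     # [[3, 4, 5], [6, 7, 8]]
```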


In [7]:
# raw_datasets = load_dataset(
#     "yelp_polarity",
#     None, # dataset_config_name
# )

Select the text and label column names that match the dataset you loaded.

In [8]:
# column_names = raw_datasets["train"].column_names
# print(column_names)

In [9]:
# text_column, summary_column = column_names[2], column_names[0]
# text_column, summary_column = column_names[0], column_names[1]

In [10]:
# text_column, summary_column

In [11]:
# padding = False

Uncomment the blocks below if you are creating the 'classify' datasets.

In [12]:
# prefix = ""
# def preprocess_function(examples):
#     # remove pairs where at least one record is None

#     inputs, targets = [], []
#     for i in range(len(examples[text_column])):
#         inputs.append(examples[text_column][i])
#         target = str(examples[summary_column][i])
#         if target == "0":
#             targets.append("negative")
#         elif target == "1":
#             targets.append("positive")
#         else:
#             print(target)

#     inputs = [prefix + inp for inp in inputs]
#     # print(inputs)
#     model_inputs = tokenizer(inputs, max_length=max_length, padding=padding, truncation=True)

#     # Setup the tokenizer for targets
#     with tokenizer.as_target_tokenizer():
#         labels = tokenizer(targets, max_length=max_length, padding=padding, truncation=True)

#     # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
#     # padding in the loss.
#     if padding == "max_length":
#         labels["input_ids"] = [
#             [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
#         ]

#     model_inputs["labels"] = labels["input_ids"]
#     return model_inputs

In [13]:
# train_dataset = raw_datasets["train"]
# train_dataset = train_dataset.shuffle(seed=seed)
# train_dataset = train_dataset.select(range(2 * max_train_examples, 3 * max_train_examples))
# train_dataset = train_dataset.map(
#     preprocess_function,
#     batched=True,
#     num_proc=8,
#     remove_columns=column_names,
#     desc="Running tokenizer on train dataset",
# )

In [14]:
# train_dataset[0]

Uncomment the blocks below if you are creating the 'lm' datasets.

In [15]:
# def tokenize_function(examples):
#     output = tokenizer(examples[text_column])
#     return output

# tokenized_datasets = raw_datasets.map(
#     tokenize_function,
#     batched=True,
#     num_proc=8,
#     remove_columns=column_names,
# )
# block_size = max_length

In [16]:
# # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
# def group_texts(examples):
#     # Concatenate all texts.
#     concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
#     total_length = len(concatenated_examples[list(examples.keys())[0]])
#     # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
#     # customize this part to your needs.
#     if total_length >= block_size:
#         total_length = (total_length // (block_size)) * block_size
#     # Split by chunks of max_len.
#     result = {
#         "input_ids": [concatenated_examples["input_ids"][i : i + block_size] for i in range(0, total_length-block_size, block_size)],
#         "labels": [concatenated_examples["input_ids"][i : i + block_size] for i in range(block_size, total_length, block_size)],
#         "attention_mask": [concatenated_examples["attention_mask"][i : i + block_size] for i in range(0, total_length-block_size, block_size)],
#         # for k, t in concatenated_examples.items()
#     }
#     # result["input_ids"] = result["temp"].copy()
#     return result

# lm_datasets = tokenized_datasets.map(
#     group_texts,
#     batched=True,
#     desc=f"Grouping texts in chunks of {block_size}",
# )

Select appropriate parts from the data to create train and validation splits.

In [17]:
# train_dataset = lm_datasets["train"]
# train_dataset = train_dataset.shuffle(seed=seed)
# # train_dataset = train_dataset.select(range(2 * max_train_examples, 3 * max_train_examples))

Save the data to where you like.

In [18]:
# train_dataset.save_to_disk("yelp_lm_data")

To use a dataset later, load it from where you saved it.

In [19]:
# train_dataset = datasets.load_from_disk("amazon_classify_data")

### ReFT Model

In [20]:
import torch 
from pyreft import (
    TaskType,
    get_reft_model,
    ReftConfig,
    ReftTrainer,
    ReftTrainerForCausalLM, 
    ReftDataCollator,
    ReftRawDataset,
    LoreftIntervention,
    NodireftIntervention,
    DireftIntervention,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"

# Let's create a subspace with 8 dims
FULL_SUBSPACE = list(range(low_rank))

In [21]:
class SubloreftIntervention(LoreftIntervention):
    """
    This is a LoReFT that supports subspace interventions with coefficients!
    """
    def __init__(self, **kwargs):
        subspace_coeff = None
        # Subspace coefficients are the coefficients applied to each subspace.
        # When `subspace_coeff` is a ones tensor, this intervention is the same as a loreft intervention with subspaces
        # When `subspace_coeff` is a negative-ones tensor, this intervention is the negation of the loreft intervention
        # There is no intervention when `subspace_coeff` is zero.
        if "subspace_coeff" in kwargs:
            subspace_coeff = kwargs["subspace_coeff"].copy()
            del kwargs["subspace_coeff"]
        subspace_coeff = torch.tensor(subspace_coeff) if subspace_coeff is not None else torch.ones(kwargs["low_rank_dimension"])
        self.subspace_coeff = subspace_coeff.to(device)
        super().__init__(**kwargs)
        print("loreft", kwargs)
        if not fp16:
            self.learned_source = self.learned_source.to(torch.float32) 
            
    def forward(
        self, base, source=None, subspaces=None, **kwargs,
    ):
        assert subspaces is not None
        # print("mag:", self.subspace_coeff)
        original_output = kwargs["_pyvene_model_input_args"][0]

        output = []

        rotated_base = self.rotate_layer(original_output)
        diff = self.act_fn(self.learned_source(original_output)) - rotated_base
        
        batched_subspace = []
        batched_weights = []
        
        if len(diff) > 1:
            subspaces = [subspaces[0]] * len(diff)
        elif len(diff) != len(subspaces):
            print(f"Warning! lengths do not match {len(diff)} {len(subspaces)}")

        # Expand subspaces to match dimensions
        subspaces = torch.tensor(subspaces).to(base.device)
        subspaces_expanded = subspaces.unsqueeze(1).expand(diff.size(0), diff.size(1), -1)
        
        LHS = torch.gather(diff, 2, subspaces_expanded) * self.subspace_coeff[subspaces_expanded]
        
        # Transpose and gather the corresponding weights for each subspace
        RHS = self.rotate_layer.weight[..., subspaces].permute(1, 2, 0)
        output = base + torch.bmm(LHS, RHS)
        
        return self.dropout(output.to(base.dtype))

In [22]:

class CoeffloreftIntervention(LoreftIntervention):
    """
    This is a LoReFT whose whole intervention is scaled by a single coefficient!
    """
    def __init__(self, **kwargs):
        subspace_coeff = None
        # `subspace_coeff` is a single coefficient applied to the whole intervention.
        # When `subspace_coeff` is one, this intervention is the same as a loreft intervention.
        # When `subspace_coeff` is negative one, this intervention is the negation of the loreft intervention.
        # There is no intervention when `subspace_coeff` is zero.
        if "subspace_coeff" in kwargs:
            subspace_coeff = kwargs["subspace_coeff"].copy()
            del kwargs["subspace_coeff"]
        self.subspace_coeff = torch.tensor(subspace_coeff) if subspace_coeff is not None else torch.ones(1)
        self.subspace_coeff = self.subspace_coeff.to(device)
        super().__init__(**kwargs)
        print("loreft", kwargs)
        if not fp16:
            self.learned_source = self.learned_source.to(torch.float32)        
            
    def forward(
        self, base, source=None, subspaces=None, **kwargs,
    ):
        # print(base.shape, original_output.shape, torch.equal(base, original_output))
        # print(len(kwargs["_pyvene_model_input_args"]), len(kwargs["_pyvene_model_output"]))
        # print("mag:", self.subspace_coeff)
        # print(kwargs.keys())
        original_output = kwargs["_pyvene_model_input_args"][0]

        rotated_base = self.rotate_layer(original_output)
        val = torch.matmul(
            (self.act_fn(self.learned_source(original_output)) - rotated_base), self.rotate_layer.weight.T
        )
        # print(f"mag: {self.subspace_coeff}, val: {val.norm()}")
        
        output = base + self.subspace_coeff * val
        return self.dropout(output.to(base.dtype))

In the NodireftIntervention variant below, all interventions are computed from pyvene's module input. This makes sure every intervention composes on the same base (i.e. parallel interventions). You can also try changing the composition base to each intervention's own input, which is equivalent to applying the interventions sequentially. Try that out as well!
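
The difference between the two composition bases can be sketched with toy scalar functions. The two lambdas are hypothetical stand-ins for learned interventions, and `compose_parallel` and `compose_sequential` are illustrative names, not pyreft or pyvene APIs.

```python
def compose_parallel(h, deltas, coeffs):
    """Every intervention reads the same original input h."""
    return h + sum(c * f(h) for f, c in zip(deltas, coeffs))

def compose_sequential(h, deltas, coeffs):
    """Each intervention reads the output of the previous one."""
    for f, c in zip(deltas, coeffs):
        h = h + c * f(h)
    return h

# Toy scalar "interventions" standing in for learned modules.
deltas = [lambda h: 0.5 * h, lambda h: h * h]
coeffs = [1.0, 1.0]

print(compose_parallel(2.0, deltas, coeffs))    # 7.0
print(compose_sequential(2.0, deltas, coeffs))  # 12.0
```

In the parallel form the result does not depend on the order of the interventions, which is what makes signed coefficients behave like addition and subtraction.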

In [23]:

class SubNodireftIntervention(NodireftIntervention):
    """
    This is a NodiReft that supports subspace interventions with coefficients!
    """
    def __init__(self, **kwargs):
        subspace_coeff = None
        # `subspace_coeff` is a coefficient applied to the whole intervention.
        # When `subspace_coeff` is one, this intervention is the same as a nodireft intervention.
        # When `subspace_coeff` is negative one, this intervention is the negation of the nodireft intervention.
        # There is no intervention when `subspace_coeff` is zero.
        if "subspace_coeff" in kwargs:
            subspace_coeff = kwargs["subspace_coeff"].copy()
            del kwargs["subspace_coeff"]
        self.subspace_coeff = torch.tensor(subspace_coeff) if subspace_coeff is not None else torch.ones(1)
        self.subspace_coeff = self.subspace_coeff.to(device)
        super().__init__(**kwargs)
        print("nodireft", kwargs)
        if not fp16:
            self.learned_source = self.learned_source.to(torch.float32)
            self.subspace_coeff = self.subspace_coeff.to(torch.float32)
        else:
            self.subspace_coeff = self.subspace_coeff.to(torch.bfloat16)
            
    def forward(
        self, base, source=None, subspaces=None, **kwargs
    ):
        original_output = kwargs["_pyvene_model_input_args"][0]

        output = base + self.subspace_coeff * torch.matmul(
             self.act_fn(self.learned_source(original_output)), self.proj_layer.weight
        )
        return self.dropout(output.to(base.dtype))



In [24]:

class SubDireftIntervention(DireftIntervention):
    """
    This is a DiReft that supports subspace interventions with coefficients!
    """
    def __init__(self, **kwargs):
        subspace_coeff = None
        if "subspace_coeff" in kwargs:
            subspace_coeff = kwargs["subspace_coeff"].copy()
            del kwargs["subspace_coeff"]
        self.subspace_coeff = torch.tensor(subspace_coeff) if subspace_coeff is not None else torch.ones(1)
        self.subspace_coeff = self.subspace_coeff.to(device)
        super().__init__(**kwargs)
        print("direft", kwargs)
        if not fp16:
            self.learned_source = self.learned_source.to(torch.float32)
            self.subspace_coeff = self.subspace_coeff.to(torch.float32)
        else:
            self.subspace_coeff = self.subspace_coeff.to(torch.bfloat16)
            
    def forward(
        self, base, source=None, subspaces=None, **kwargs
    ):
        original_output = kwargs["_pyvene_model_input_args"][0]
        cast_base = original_output.to(self.learned_source.weight.dtype)
        output = base + self.subspace_coeff * torch.matmul(
            (self.act_fn(self.learned_source(cast_base))).to(self.rotate_layer.weight.dtype), self.rotate_layer.weight.T
        )
        return self.dropout(output.to(base.dtype))


We do interventions on all layers of the encoder and the decoder. Note that t5-base has 12 layers and t5-small has 6 layers.

In [25]:
layers = list(range(6))
num_interventions = 2 * len(layers)
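
As a sanity check on these settings, the number of interventions and their parameter count can be worked out by hand. The sketch assumes that each NoDiReFT intervention consists of a bias-free `embed_dim x low_rank` projection plus an `embed_dim x low_rank` linear layer with a bias; that layout is an assumption about `NodireftIntervention`, and the layer counts are the ones quoted in the text above.

```python
LAYERS_PER_STACK = {"t5-small": 6, "t5-base": 12}  # layer counts quoted in the text

def count_interventions(model_name):
    """One intervention per encoder block plus one per decoder block."""
    return 2 * LAYERS_PER_STACK[model_name]

def nodireft_params(embed_dim, low_rank):
    """Assumed layout: projection without bias + linear source with bias."""
    proj = embed_dim * low_rank
    source = embed_dim * low_rank + low_rank
    return proj + source

n = count_interventions("t5-small")
total = n * nodireft_params(embed_dim=512, low_rank=8)
print(n, total)  # 12 98400
```

Under these assumptions the total agrees with the `trainable intervention params: 98,400` line printed during training.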

In [26]:

if intervention_type == "nodireft":
    reft_config = ReftConfig(representations=
        [{
                "layer": l, "component": "encoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": SubNodireftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    add_bias=False,
                )
            } for l in layers]
        + [{
                "layer": l, "component": "decoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": SubNodireftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    add_bias=False,
                )
            } for l in layers]
    )
elif intervention_type == "loreft":
    reft_config = ReftConfig(representations=
        [{
                "layer": l, "component": "encoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": SubloreftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    init_orth=True,
                )
            } for l in layers]
        + [{
                "layer": l, "component": "decoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": SubloreftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    init_orth=True,
                )
            } for l in layers]
    )
elif intervention_type == "direft":
    reft_config = ReftConfig(representations=
        [{
                "layer": l, "component": "encoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": SubDireftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    init_orth=True,
                )
            } for l in layers]
        + [{
                "layer": l, "component": "decoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": SubDireftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    init_orth=True,
                )
            } for l in layers]
    )
elif intervention_type == "coeffloreft":
    reft_config = ReftConfig(representations=
        [{
                "layer": l, "component": "encoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": CoeffloreftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    init_orth=True,
                )
            } for l in layers]
        + [{
                "layer": l, "component": "decoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": CoeffloreftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    init_orth=True,
                )
            } for l in layers]
    )
else:
    raise ValueError(f'No support for intervention {intervention_type}')


nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'l

In [27]:
from dataclasses import dataclass, field
from datasets import Dataset
from typing import Dict, Optional, Sequence, Union, List, Any


@dataclass
class AdaptorReftDataCollator(object):
    """Collate examples for ReFT."""
    
    tokenizer: transformers.AutoTokenizer
    data_collator: transformers.DataCollator

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        batch_inputs = self.data_collator(instances)
        if "decoder_input_ids" in batch_inputs.keys():
            del batch_inputs["decoder_input_ids"]
        return batch_inputs

In [28]:
def data_generator(tokenizer, inputs, num_interventions):
    """Generator function to yield data lazily."""
    for i in range(len(inputs)):
        _input = inputs[i]
        
        # Mask padding tokens in the labels with -100 so the loss ignores them.
        output_ids = [(l if l != tokenizer.pad_token_id else -100) for l in _input["labels"]]
        
        yield {
            "input_ids": _input["input_ids"],
            "labels": output_ids,
            "subspaces": [FULL_SUBSPACE] * num_interventions,
            # "intervention_locations": [[0]] * num_interventions
        }

def make_all_positions_unsupervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer, model, inputs, 
    num_interventions=1, nonstop=False, fp16=False
):
    """Make dataset and collator for unsupervised (or really, semi-supervised) fine-tuning with streaming."""
    
    # Using a generator to lazily load the dataset
    train_dataset = Dataset.from_generator(
        lambda: data_generator(tokenizer, inputs, num_interventions),
        # features={
        #     "input_ids": [tokenizer.pad_token_id],  # Assuming lists of token ids
        #     "labels": [tokenizer.pad_token_id],     # Assuming lists of token ids with padding
        #     "subspaces": [[FULL_SUBSPACE]],         # Subspace feature, adjust as needed
        #     # "intervention_locations": [[[0]]],      # Intervention locations
        # }
    )
    
    data_collator_fn = transformers.DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=-100,
        pad_to_multiple_of=8 if fp16 else None,
    )
    data_collator = AdaptorReftDataCollator(tokenizer=tokenizer, data_collator=data_collator_fn)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


In [29]:
import copy
def handle_training(dataset_name, profiling=False):
    reft_model = get_reft_model(model, copy.deepcopy(reft_config), set_device=device)
    reft_model.set_device(device)
    print(reft_model.get_device())
    reft_model.print_trainable_parameters()
    train_dataset = datasets.load_from_disk(dataset_name)
    if testing: train_dataset = train_dataset.select(range(max_train_examples))
    train_dataset = make_all_positions_unsupervised_data_module(tokenizer, model, train_dataset, num_interventions=num_interventions, nonstop=False)
    train_dataset, data_collator = train_dataset["train_dataset"], train_dataset["data_collator"]
    print(len(train_dataset))
    # Double checked, we can use ReftTrainerForCausalLM for training Seq2Seq models
    reft_model.train()
    reft_model.model.train()
    reft_model.training = True
    
    training_args = transformers.TrainingArguments(
        num_train_epochs=1.0, output_dir="./results_domain", learning_rate=5e-4, report_to=[],
        per_device_train_batch_size=train_batch_size, logging_steps=50, bf16=fp16,
        dataloader_num_workers = 2,
        dataloader_pin_memory = True,
        remove_unused_columns = False,
        gradient_accumulation_steps=2,
        adam_beta2 = 0.98,
        adam_epsilon=1e-6,
        # warmup_ratio=0.06,
    )
    trainer = ReftTrainerForCausalLM(
        model=reft_model, tokenizer=tokenizer, args=training_args, 
        train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
    
    class ProfCallback(TrainerCallback):
        def __init__(self, prof):
            self.prof = prof
    
        def on_step_end(self, args, state, control, **kwargs):
            self.prof.step()

    if profiling:
        with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,
                                                torch.profiler.ProfilerActivity.CUDA], 
                                    schedule=torch.profiler.schedule(skip_first=3, wait=1, warmup=1, active=4, repeat=1),
                                    on_trace_ready=torch.profiler.tensorboard_trace_handler('./hf-training-trainer/grad/'),
                                    profile_memory=True,
                                    with_stack=True,
                                    record_shapes=True) as prof:
            
            trainer.add_callback(ProfCallback(prof=prof))
            trainer.train()
    else:
        trainer.train()
    
    # prof.export_chrome_trace("my_trainer.json")
    
    return reft_model


### Train Amazon LM

In [30]:
reft_amazon_lm = handle_training("train/amazon_lm_data_new")

cuda:0
trainable intervention params: 98,400 || trainable model params: 0
model params: 60,492,288 || trainable%: 0.16266536322778863


Generating train split: 0 examples [00:00, ? examples/s]

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


10000


Step,Training Loss
50,4.4516
100,4.1034
150,4.067
200,4.0669
250,4.0241
300,4.0346


### Train Yelp LM

In [31]:
reft_yelp_lm = handle_training("train/yelp_lm_data_new")

cuda:0
trainable intervention params: 98,400 || trainable model params: 0
model params: 60,492,288 || trainable%: 0.16266536322778863



10000


Step,Training Loss
50,4.3692
100,3.9396
150,3.9114
200,3.883
250,3.8811
300,3.8687


### Train Yelp Classifier

In [32]:
reft_yelp_classifier = handle_training("train/yelp_classify_data")

cuda:0
trainable intervention params: 98,400 || trainable model params: 0
model params: 60,492,288 || trainable%: 0.16266536322778863



10000


Step,Training Loss
50,1.4596
100,0.1149
150,0.0896
200,0.0979
250,0.0826
300,0.1056


### Train Amazon Classifier

In [33]:
reft_amazon_classifier = handle_training("train/amazon_classify_data")

cuda:0
trainable intervention params: 98,400 || trainable model params: 0
model params: 60,492,288 || trainable%: 0.16266536322778863



10000


Step,Training Loss
50,1.4331
100,0.1514
150,0.1432
200,0.1513
250,0.1287
300,0.1242


### Man - Woman = King - Queen relationship

Now let's see whether adding the difference between the two LM interventions to a classifier leads to any performance gain!

#### Generation with composition

During generation on 'yelp_classify_data' and 'amazon_classify_data', we restrict decoding to the labels 'positive' and 'negative'. We use beam search with num_beams=5.
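
The evaluation code relies on a `force_words_ids` constraint list. A minimal sketch of how such a list could be built, under assumptions: in transformers versions that support constrained beam search, a triple-nested list passed as `force_words_ids` is treated as a disjunctive constraint, meaning one of the listed alternatives must be generated. The helper name `build_label_constraint` and the `encode` callable are hypothetical; with the real tokenizer, `encode` could be `lambda w: tokenizer(w, add_special_tokens=False).input_ids`.

```python
def build_label_constraint(encode, labels=("positive", "negative")):
    """Return [[ids_for_label_1, ids_for_label_2, ...]]: one disjunctive group."""
    return [[list(encode(label)) for label in labels]]

# Stand-in encoder with made-up ids so the sketch runs without a tokenizer.
fake_vocab = {"positive": [101], "negative": [202]}
force_words_ids = build_label_constraint(fake_vocab.__getitem__)
print(force_words_ids)  # [[[101], [202]]]
```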

In [34]:
# gen_batch_size = 1
gen_max_length = 2
from tqdm import tqdm
def generate_texts(reft_model, prompt, allowed_token_ids, num_interventions=1, intervene_on_all=True):
    # instruction = " "
    
    # print(prompt)

    for k, v in prompt.items():
        if isinstance(v, list):
            prompt[k] = torch.tensor(v,dtype=torch.long).to(device)

    # prompt = prompt.to(device)
    gen_batch_size = prompt["input_ids"].shape[0]
    # print(gen_batch_size)
    # print(prompt)
    
    generated_texts = []
    subspaces = [[FULL_SUBSPACE] * gen_batch_size] * num_interventions
    # print(subspaces)
    # print(allowed_token_ids)
    _, reft_response = reft_model.generate(
        prompt, 
        unit_locations= None if intervene_on_all else {"sources->base": (None, [[[0] ] ] * len(layers)) },
        subspaces=subspaces,
        intervene_on_prompt=True, 
        max_new_tokens=2,
        min_new_tokens=1,
        # do_sample=True, 
        # no_repeat_ngram_size=2, 
        # repetition_penalty=1.1, 
        force_words_ids=allowed_token_ids,
        num_beams = 5,
        # top_k = 50,
        eos_token_id=tokenizer.eos_token_id, early_stopping=True,
        pad_token_id=tokenizer.eos_token_id,
        remove_invalid_values=True,
    )

    generated_text = tokenizer.batch_decode(reft_response, skip_special_tokens=True)
    # generated_text = [t[len(instruction):] for t in generated_text]
    generated_texts += generated_text

    # print(generated_texts[0])
    return generated_texts

In [35]:
def eval(reft_model, num_interventions, train_dataset, data_collator):
    acc = torch.tensor(0.0, device=device)
    tot = torch.tensor(0.0, device=device)
    from torch.utils.data import DataLoader
    # Create the DataLoader with the collator
    gen_batch_size = 16
    dataloader = DataLoader(
        train_dataset,
        batch_size=gen_batch_size,
        collate_fn=data_collator
    )
    
    # Iterate over batches
    with torch.no_grad():
    
        pbar = tqdm(dataloader)
        for batch in pbar:
            if 'subspaces' in batch:
                del batch['subspaces']
            batch = {k: v.to(device) for k, v in batch.items()}
            output_labels = generate_texts(reft_model, batch, force_words_ids, num_interventions)
            filtered_labels = [
                [token_id for token_id in sequence if token_id != -100]
                for sequence in batch["labels"]
            ]
            true_labels = tokenizer.batch_decode(filtered_labels, skip_special_tokens=True)
            correct = torch.tensor(0.0, device=device)
            for i in range(len(output_labels)):
                if output_labels[i] == true_labels[i]:
                    correct += 1
            
            acc += correct
            tot += len(output_labels)
            pbar.set_postfix({"Correct": acc.item(), "Accuracy": (acc / tot).item()})
    final_acc = (acc/tot).item()
    print(f"Final Accuracy: {final_acc:.4f}")


Here we build the composed model by copying each individual model's trained interventions into one of its intervention slots, with `subspace_coeff` setting the magnitude (and sign) of each copied intervention.

In [36]:
import copy
def set_lm(reft_model, reft_composed, layer, l, intervention_id):
    """Copy the trained interventions of `reft_model` at `layer` into slot
    `intervention_id` of `reft_composed`, scaled by the coefficient `l`."""
    # The same copy is done for the encoder block and the decoder block.
    for stack in ("encoder", "decoder"):
        prefix = "comp." + stack + ".block." + str(layer) + ".output.unit.pos.nunit.1#"
        composed = reft_composed.interventions[prefix + str(intervention_id)][0]
        original = reft_model.interventions[prefix + "0"][0]

        # LoReFT uses one coefficient per subspace dimension (all set to `l`);
        # the other variants use a single scalar coefficient.
        if intervention_type == "loreft":
            composed.rotate_layer = copy.deepcopy(original.rotate_layer)
            subspace_coeff = l * torch.ones(low_rank).to(device)
        elif intervention_type == "nodireft":
            composed.proj_layer = copy.deepcopy(original.proj_layer)
            subspace_coeff = l * torch.ones(1).to(device)
        elif intervention_type == "direft" or intervention_type == "coeffloreft":
            composed.rotate_layer = copy.deepcopy(original.rotate_layer)
            subspace_coeff = l * torch.ones(1).to(device)

        composed.learned_source = copy.deepcopy(original.learned_source)
        composed.subspace_coeff = subspace_coeff
    return reft_composed
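
`set_lm` copies one trained model into one slot of the composed model. The formula `classifier + lambda * (target_lm - source_lm)` then becomes three such copies per layer. The sketch below only builds that list of coefficients and slots; `composition_plan` is a hypothetical helper, not part of the notebook, and the strings stand in for the trained ReFT models.

```python
def composition_plan(classifier, target_lm, source_lm, lam):
    """Return (model, coefficient, slot) triples for
    classifier + lam * (target_lm - source_lm)."""
    return [
        (target_lm, lam, 0),    # add the target-domain LM
        (source_lm, -lam, 1),   # subtract the source-domain LM
        (classifier, 1.0, 2),   # keep the source-domain classifier at full strength
    ]


# Yelp test: amazon_classifier + 0.3 * (yelp_lm - amazon_lm)
plan = composition_plan("amazon_classifier", "yelp_lm", "amazon_lm", lam=0.3)
for model, coeff, slot in plan:
    print(f"slot {slot}: {coeff:+.1f} * {model}")
```

In the notebook, each triple would correspond to one call of the form `set_lm(model, reft_composed, layer, coeff, slot)` for every layer in `layers`.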


In [37]:

if intervention_type == "nodireft":
    representations = []
    sub_representation = [{
                "layer": l, "component": "encoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": SubNodireftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    add_bias=False,
                )
            } for l in layers]
    for _ in range(3):
        representations += copy.deepcopy(sub_representation)
    sub_representation = [{
                "layer": l, "component": "decoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": SubNodireftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    add_bias=False,
                )
            } for l in layers]
    for _ in range(3):
        representations += copy.deepcopy(sub_representation)
    composed_reft_config = ReftConfig(representations=representations)
elif intervention_type == "loreft":
    representations = []
    sub_representation = [{
                "layer": l, "component": "encoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": SubloreftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    init_orth=True,
                )
            } for l in layers]
    for _ in range(3):
        representations += copy.deepcopy(sub_representation)
    sub_representation = [{
                "layer": l, "component": "decoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": SubloreftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    init_orth=True,
                )
            } for l in layers]
    for _ in range(3):
        representations += copy.deepcopy(sub_representation)
    composed_reft_config = ReftConfig(representations=representations)
elif intervention_type == "direft":
    representations = []
    sub_representation = [{
                "layer": l, "component": "encoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": SubDireftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                )
            } for l in layers]
    for _ in range(3):
        representations += copy.deepcopy(sub_representation)
    sub_representation = [{
                "layer": l, "component": "decoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": SubDireftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                )
            } for l in layers]
    for _ in range(3):
        representations += copy.deepcopy(sub_representation)
    composed_reft_config = ReftConfig(representations=representations)
elif intervention_type == "coeffloreft":
    representations = []
    sub_representation = [{
                "layer": l, "component": "encoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": CoeffloreftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32,
                    init_orth=True,
                )
            } for l in layers]
    for _ in range(3):
        representations += copy.deepcopy(sub_representation)
    sub_representation = [{
                "layer": l, "component": "decoder.block." + str(l) + ".output",
                "low_rank_dimension": low_rank,
                "intervention": CoeffloreftIntervention(
                    embed_dim=model.config.hidden_size, low_rank_dimension=low_rank,
                    dtype=torch.bfloat16 if fp16 else torch.float32, 
                    init_orth=True,
                )
            } for l in layers]
    for _ in range(3):
        representations += copy.deepcopy(sub_representation)
    composed_reft_config = ReftConfig(representations=representations)


nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'low_rank_dimension': 8, 'dtype': torch.bfloat16, 'add_bias': False}
nodireft {'embed_dim': 512, 'l

In [38]:
def set_eval(reft_model):
    reft_model.eval()
    reft_model.model.eval()
    reft_model.training = False

In [39]:
gen_batch_size = 64
force_flexible = ["negative","positive"]
force_words_ids = [tokenizer(force_flexible, add_special_tokens=True).input_ids]


#### Yelp Test

In this test we evaluate on the Yelp dataset to see whether amazon_classifier improves once the LM difference (yelp_lm - amazon_lm) is added to it.

In [40]:
reft_composed = get_reft_model(model, composed_reft_config, set_device=False)
reft_composed.set_device(device)
print(reft_composed.get_device())
reft_composed.print_trainable_parameters()

cuda:0
trainable intervention params: 295,200 || trainable model params: 0
model params: 60,492,288 || trainable%: 0.48799608968336594


In [41]:
train_dataset = datasets.load_from_disk("validation/yelp_classify_data")  # Yelp validation data (target domain for amazon_classifier)
if testing: train_dataset = train_dataset.select(range(max_train_examples))
len(train_dataset)

10000

In [42]:

set_eval(reft_yelp_lm)
set_eval(reft_amazon_lm)
set_eval(reft_yelp_classifier)
set_eval(reft_amazon_classifier)

In [43]:
for l in layers:
    # set_lm(reft_yelp_classifier, l, 1.0, 0)
    reft_composed = set_lm(reft_yelp_lm, reft_composed, l, 0.3, 0) # 1.0
    reft_composed = set_lm(reft_amazon_lm, reft_composed, l, -0.3, 1) # -1.0
    # reft_composed = set_lm(reft_yelp_lm, reft_composed, l, -1.0, 1) # -1.0
    reft_composed = set_lm(reft_amazon_classifier, reft_composed, l, 1.0, 2) # 1.0
set_eval(reft_composed)


In [44]:
force_words_ids

[[[2841, 1], [1465, 1]]]

In [45]:
reft_train_dataset = make_all_positions_unsupervised_data_module(tokenizer, model, train_dataset, num_interventions=3 * num_interventions, nonstop=False)
reft_train_dataset, data_collator = reft_train_dataset["train_dataset"], reft_train_dataset["data_collator"]
eval(reft_composed, 3 * num_interventions, reft_train_dataset, data_collator)

Generating train split: 0 examples [00:00, ? examples/s]

100%|█████████| 625/625 [01:26<00:00,  7.23it/s, Correct=9218.0, Accuracy=0.922]

Final Accuracy: 0.9218





In [46]:
reft_train_dataset = make_all_positions_unsupervised_data_module(tokenizer, model, train_dataset, num_interventions=num_interventions, nonstop=False)
reft_train_dataset, data_collator = reft_train_dataset["train_dataset"], reft_train_dataset["data_collator"]
eval(reft_amazon_classifier, num_interventions, train_dataset, data_collator)

Generating train split: 0 examples [00:00, ? examples/s]

100%|██████████| 625/625 [01:16<00:00,  8.15it/s, Correct=9205.0, Accuracy=0.92]

Final Accuracy: 0.9205





In [47]:
reft_train_dataset = make_all_positions_unsupervised_data_module(tokenizer, model, train_dataset, num_interventions=num_interventions, nonstop=False)
reft_train_dataset, data_collator = reft_train_dataset["train_dataset"], reft_train_dataset["data_collator"]
eval(reft_yelp_classifier, num_interventions, train_dataset, data_collator)

100%|█████████| 625/625 [01:17<00:00,  8.11it/s, Correct=9294.0, Accuracy=0.929]

Final Accuracy: 0.9294





#### Amazon Test

In this test we evaluate on the Amazon dataset to see whether yelp_classifier improves once the LM difference (amazon_lm - yelp_lm) is added to it.

In [48]:
reft_composed = get_reft_model(model, composed_reft_config, set_device=False)
reft_composed.set_device(device)
print(reft_composed.get_device())
reft_composed.print_trainable_parameters()


cuda:0
trainable intervention params: 295,200 || trainable model params: 0
model params: 60,492,288 || trainable%: 0.48799608968336594


In [49]:
train_dataset = datasets.load_from_disk("validation/amazon_classify_data")  # Amazon validation data (target domain for yelp_classifier)
if testing: train_dataset = train_dataset.select(range(max_train_examples))
len(train_dataset)


10000

In [50]:
for l in layers:
    reft_composed = set_lm(reft_amazon_lm, reft_composed, l, 0.3, 0) # 1.0
    reft_composed = set_lm(reft_yelp_lm, reft_composed, l, -0.3, 1) # -1.0
    reft_composed = set_lm(reft_yelp_classifier, reft_composed, l, 1.0, 2) # 1.0
set_eval(reft_composed)


In [51]:
reft_train_dataset = make_all_positions_unsupervised_data_module(tokenizer, model, train_dataset, num_interventions=3 * num_interventions, nonstop=False)
reft_train_dataset, data_collator = reft_train_dataset["train_dataset"], reft_train_dataset["data_collator"]
eval(reft_composed, 3 * num_interventions, reft_train_dataset, data_collator)


Generating train split: 0 examples [00:00, ? examples/s]

100%|█████████| 625/625 [01:16<00:00,  8.18it/s, Correct=8814.0, Accuracy=0.881]

Final Accuracy: 0.8814





In [52]:
reft_train_dataset = make_all_positions_unsupervised_data_module(tokenizer, model, train_dataset, num_interventions=num_interventions, nonstop=False)
reft_train_dataset, data_collator = reft_train_dataset["train_dataset"], reft_train_dataset["data_collator"]
eval(reft_amazon_classifier, num_interventions, train_dataset, data_collator)


Generating train split: 0 examples [00:00, ? examples/s]

100%|█████████| 625/625 [01:06<00:00,  9.34it/s, Correct=8952.0, Accuracy=0.895]

Final Accuracy: 0.8952





In [53]:
reft_train_dataset = make_all_positions_unsupervised_data_module(tokenizer, model, train_dataset, num_interventions=num_interventions, nonstop=False)
reft_train_dataset, data_collator = reft_train_dataset["train_dataset"], reft_train_dataset["data_collator"]
eval(reft_yelp_classifier, num_interventions, train_dataset, data_collator)


100%|█████████| 625/625 [01:07<00:00,  9.24it/s, Correct=8818.0, Accuracy=0.882]

Final Accuracy: 0.8818





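For t5-small on the Yelp test, the composed model reaches 0.9218 accuracy against 0.9205 for amazon_classifier alone, so the LM diffs do give a small gain during domain transfer (the in-domain yelp_classifier reaches 0.9294). However, we did not observe a similar gain for t5-base, or for t5-small on the Amazon test, where the composed model (0.8814) is on par with yelp_classifier alone (0.8818).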
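
To put these comparisons side by side, the short sketch below recomputes the transfer gain from the correct-prediction counts printed in the six evaluation runs above (10,000 examples each). The helper name `transfer_summary` and the dictionary keys are ours, introduced only for illustration.

```python
N = 10_000  # examples per evaluation run

# Correct counts copied from the progress bars above.
results = {
    "yelp": {"composed": 9218, "transferred": 9205, "in_domain": 9294},
    "amazon": {"composed": 8814, "transferred": 8818, "in_domain": 8952},
}


def transfer_summary(counts, n=N):
    """Compare the composed model with the transferred classifier alone."""
    gain = counts["composed"] - counts["transferred"]
    gap = counts["in_domain"] - counts["transferred"]
    return {
        "gain_examples": gain,            # extra correct predictions from composing
        "gain_points": 100 * gain / n,    # the same gain in accuracy points
        "gap_examples": gap,              # distance to the in-domain classifier
    }


yelp = transfer_summary(results["yelp"])
amazon = transfer_summary(results["amazon"])
print("yelp:", yelp)
print("amazon:", amazon)
```

Here "transferred" is the classifier trained on the other domain (amazon_classifier on the Yelp test, yelp_classifier on the Amazon test), and "in_domain" is the classifier trained on the test domain itself.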

You can change lambda (currently set to 0.3) freely. Based on these runs, however, ReFT's performance on domain transfer (without orthogonality constraints) appears similar to that of IA-3 in the [LoRA operators](https://arxiv.org/pdf/2306.14870) paper.
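
One way to explore lambda systematically is a small sweep. The sketch below assumes a hypothetical `evaluate(lam)` callback that rebuilds the composed model with coefficients `+lam` and `-lam` and returns its accuracy; the notebook's `eval` only prints the accuracy, so it would have to return the value for this to work. `toy_evaluate` is an arbitrary stand-in so the example runs on its own, not a measurement.

```python
def sweep_lambda(evaluate, lambdas):
    """Evaluate each candidate lambda and return the best one with all scores."""
    scores = {lam: evaluate(lam) for lam in lambdas}
    best = max(scores, key=scores.get)
    return best, scores


def toy_evaluate(lam):
    # Stand-in for "compose with +lam / -lam, then measure accuracy".
    # The shape of this curve is arbitrary and chosen only for the demo.
    return 1.0 - (lam - 0.5) ** 2


best_lam, scores = sweep_lambda(toy_evaluate, [0.0, 0.1, 0.3, 0.5, 1.0])
print(best_lam, scores[best_lam])
```

If lambda is chosen this way, it is safer to select it on a held-out split rather than on the same examples used to report the final accuracy.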

### Analysis

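Why did this happen? ReFT and IA-3 are similar in many ways. Both act directly on the representations with a small number of learned vectors: IA-3 rescales activations elementwise, and ReFT adds a learned, steering-vector-like edit. LoRA instead updates the weight matrices themselves, which gives it a richer parameterization and a larger effect.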

If we instead orthogonalize only within each domain, that is, between the amazon_lm and amazon_classify tasks and between the yelp_lm and yelp_classify tasks, we might push the LM and classification tasks into more distinct subspaces, which could improve composability. Orthogonalizing amazon_lm against yelp_lm, in contrast, would not have this effect.
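
As a toy illustration of why orthogonal subspaces help segregate tasks, the sketch below applies two LoReFT-style edits of the form `h + R^T (target - R h)` whose projection rows are orthogonal. It is a simplification under stated assumptions: the targets are fixed numbers rather than a learned function of `h`, the projections are axis-aligned, and all names are ours.

```python
def matvec(M, v):
    """Multiply matrix M (list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def loreft_edit(h, R, target):
    """Overwrite the coordinates of h in the row space of R with `target`:
    h + R^T (target - R h)."""
    diff = [t - p for t, p in zip(target, matvec(R, h))]
    delta = [sum(R[k][j] * diff[k] for k in range(len(R))) for j in range(len(h))]
    return [x + d for x, d in zip(h, delta)]


R_lm = [[1.0, 0.0, 0.0, 0.0]]    # subspace assumed to be used by the LM task
R_cls = [[0.0, 1.0, 0.0, 0.0]]   # orthogonal subspace assumed for the classifier
h = [0.5, -1.0, 2.0, 3.0]

after_lm = loreft_edit(h, R_lm, [4.0])
after_both = loreft_edit(after_lm, R_cls, [7.0])
other_order = loreft_edit(loreft_edit(h, R_cls, [7.0]), R_lm, [4.0])

# The LM edit leaves the classifier subspace untouched, and the order is irrelevant.
print(matvec(R_cls, h), matvec(R_cls, after_lm))
print(after_both == other_order)
```

Because the two row spaces are orthogonal, each edit is invisible to the other, which is the kind of segregation the paragraph above has in mind for the LM and classification tasks of one domain.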

The ultimate goal is to develop an intervention architecture that (1) segregates different tasks as much as possible (for example via orthogonality), and (2) still retains shared structure between similar tasks (such as yelp_lm and amazon_lm), so that they can be composed.