In [None]:
# Add ../pyvene/ to the module search path (presumably a local pyvene checkout
# next to this notebook - confirm). The path is relative, so it depends on the
# kernel's working directory.
import sys
sys.path.append("../pyvene/")

In [None]:
import torch
import random, copy
import pandas as pd
import numpy as np
import torch.nn.functional as F
import seaborn as sns
from tqdm import tqdm, trange
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from transformers import DataCollatorForSeq2Seq
from torch.nn import CrossEntropyLoss

from pyvene import (
    IntervenableModel,
    LowRankRotatedSpaceIntervention,
    RepresentationConfig,
    IntervenableConfig,
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
)
from pyvene import create_llama
from pyvene import set_seed, count_parameters
from pyvene.models.layers import LowRankRotateLayer

import io
import json
import os

def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f

def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f


def jload(f, mode="r"):
    """Load a .json file into a dictionary.

    Args:
        f: A path, or an already-open file object.
        mode: Mode used to open `f` when it is a path.

    Returns:
        The object decoded by `json.load`.

    The file object is closed before returning, also when decoding fails.
    """
    f = _make_r_io_base(f, mode)
    try:
        return json.load(f)
    finally:
        # Close even if the JSON is malformed so the handle is not leaked.
        f.close()

def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written (dict, list or str).
        f: A string path to the location on disk.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.

    Raises:
        ValueError: If `obj` is not a dict, list or str. The check happens
            before the target is opened, so a bad call neither creates nor
            truncates the file.
    """
    if not isinstance(obj, (dict, list, str)):
        raise ValueError(f"Unexpected type: {type(obj)}")
    f = _make_w_io_base(f, mode)
    try:
        if isinstance(obj, str):
            # Strings are written verbatim, not JSON-encoded.
            f.write(obj)
        else:
            json.dump(obj, f, indent=indent, default=default)
    finally:
        # Close even if serialization fails so the handle is not leaked.
        f.close()
    
# Device that the model and input tensors are moved to.
device = "cuda"
# Label value for positions excluded from the loss. It matches the literal -100
# used for label masking and label padding in the data-preparation cell, and is
# the default `ignore_index` of torch.nn.CrossEntropyLoss.
IGNORE_INDEX = -100
# Fallback special tokens; in this notebook they are only referenced by the
# commented-out tokenizer setup.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
# Prompt template with two %s slots, filled as (instruction, input) where it is
# used. The backslashes at the end of the first three lines are line
# continuations inside the string literal, so the opening sentence is one line
# at runtime.
prompt_template = """Below is an instruction that \
describes a task, paired with an input that provides \
further context. Write a response that appropriately \
completes the request.

### Instruction:
%s

### Input:
%s

### Response:
"""

In [None]:
# create_llama's result is unpacked as a 3-tuple; the first element is kept as
# `config`, the last as `llama`, and the middle one is discarded.
config, _, llama = create_llama("yahma/llama-7b-hf")
_ = llama.to(device)  # single gpu
# eval() only switches the module to inference mode (e.g. dropout off); it does
# not by itself disable gradient tracking.
_ = llama.eval()

In [None]:
from transformers import LlamaTokenizer
# Tokenizer loaded from the same checkpoint name as the model.
tokenizer = LlamaTokenizer.from_pretrained("yahma/llama-7b-hf")
# NOTE: the setup below is disabled. It would add any missing special tokens and
# resize the model's embeddings; as written, the tokenizer is used unmodified.
# tokenizer.padding_side = "right"
# special_tokens_dict = dict()
# if tokenizer.pad_token is None:
#     special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
# if tokenizer.eos_token is None:
#     special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
# if tokenizer.bos_token is None:
#     special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
# if tokenizer.unk_token is None:
#     special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
# num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
# llama.resize_token_embeddings(len(tokenizer))
# print("adding new tokens count: ", num_new_tokens)

In [None]:
alpaca_dataset = jload("./datasets/alpaca_data_cleaned/train.json")
max_sample_examples = 1000 # following DoRA, we should do 1K upto 10K
# Seed Python's RNG *before* drawing the subset; otherwise a fresh kernel picks
# a different training subset on every run.
random.seed(42)
# Cap the sample size at the dataset size: random.sample raises ValueError when
# asked for more items than the population holds.
sampled_items = random.sample(
    range(len(alpaca_dataset)), min(max_sample_examples, len(alpaca_dataset)))

In [None]:
# Display the first raw record to inspect its fields.
alpaca_dataset[0]

In [None]:
set_seed(42)

###################
# data loaders
###################
# Every training example is "<prompt><output><eos>". Labels are a copy of the
# input ids with the prompt part masked out, so the loss only covers the response.
all_base_input_ids, all_base_positions, all_output_ids, all_source_input_ids = [], [], [], []
max_seq_length = 512

for s in sampled_items:
    data_item = alpaca_dataset[s]
    base_prompt = prompt_template % (data_item['instruction'], data_item['input'])
    # base input = base prompt + steered base output
    base_input = base_prompt + data_item["output"] + tokenizer.eos_token
    base_prompt_length = len(tokenizer(
        base_prompt, max_length=max_seq_length, truncation=True, return_tensors="pt")["input_ids"][0])
    base_input_ids = tokenizer(
        base_input, max_length=max_seq_length, truncation=True, return_tensors="pt")["input_ids"][0]
    # Labels start as an independent copy of the inputs; cloning avoids
    # tokenizing the same text a second time.
    output_ids = base_input_ids.clone()
    output_ids[:base_prompt_length] = IGNORE_INDEX
    # Enforce EOS as the last token (truncation may have cut it off). This must
    # come after the masking above so the final label is EOS even when the
    # prompt alone fills the whole sequence.
    base_input_ids[-1] = tokenizer.eos_token_id
    output_ids[-1] = tokenizer.eos_token_id

    all_base_input_ids.append(base_input_ids)
    all_base_positions.append([base_prompt_length-1]) # intervene on the last prompt token
    all_output_ids.append(output_ids)

raw_train = (
    all_base_input_ids,
    all_base_positions,
    all_output_ids,
)
train_dataset = Dataset.from_dict(
    {
        "input_ids": raw_train[0],
        "intervention_position": raw_train[1],
        "labels": raw_train[2],
    }
)
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=llama,
    label_pad_token_id=IGNORE_INDEX,
    padding="longest",
)

In [None]:
# Optimisation settings consumed by the training cell.
batch_size = 8
gradient_accumulation_steps = 2
epochs = 1
initial_lr = 2e-3
total_step = 0  # running count of processed micro-batches

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
)

In [None]:
class LearnedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
):
    """Trainable intervention with a learned low-rank source vector.

    Holds an orthogonally-parametrized `LowRankRotateLayer` and a learned
    vector `learned_source` of size `low_rank_dimension`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        subspace_dim = kwargs["low_rank_dimension"]
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(
            LowRankRotateLayer(self.embed_dim, subspace_dim)
        )
        # Learned target coordinates, initialised with torch.rand.
        self.learned_source = torch.nn.Parameter(
            torch.rand(subspace_dim), requires_grad=True
        )

    def forward(self, base, source=None, subspaces=None):
        """Return `base` shifted by (learned_source - rotate_layer(base)) @ weight.T.

        `source` and `subspaces` are accepted but not used. The result is cast
        back to the dtype of `base`.
        """
        projection = self.rotate_layer(base)
        shift = torch.matmul(
            self.learned_source - projection, self.rotate_layer.weight.T
        )
        return (base + shift).to(base.dtype)

rank = 4
# Layer 15 is listed twice on purpose: this creates two interventions on the
# same layer, one per entry.
layers = [15, 15]

config = IntervenableConfig(
    [
        {"layer": layer_idx, "component": "block_output", "low_rank_dimension": rank}
        for layer_idx in layers
    ],
    # this is a trainable low-rank rotation
    LearnedSourceLowRankRotatedSpaceIntervention,
)
intervenable = IntervenableModel(config, llama)
intervenable.set_device(device)
intervenable.disable_model_gradients()

In [None]:
optimizer = torch.optim.Adam(
    intervenable.get_trainable_parameters(), lr=initial_lr
)
# The scheduler is stepped once per optimizer update, so its horizon is the
# number of updates (ceil of micro-batches / accumulation steps), not the number
# of epochs. With total_iters=epochs the decay finished after the first update.
num_micro_batches = len(train_dataloader) * int(epochs)
num_update_steps = max(
    1,
    (num_micro_batches + gradient_accumulation_steps - 1) // gradient_accumulation_steps,
)
# NOTE(review): LinearLR's default start_factor is 1/3, so the effective
# starting LR is initial_lr / 3 - confirm that is intended.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, end_factor=0.1, total_iters=num_update_steps
)
intervenable.model.train()  # train mode enables dropout
print("llama trainable parameters: ", count_parameters(intervenable.model))
print("intervention trainable parameters: ", intervenable.count_parameters())

loss_fct = CrossEntropyLoss()
optimizer.zero_grad()  # start from clean gradients when the cell is re-run
pending_micro_batches = 0  # backward passes accumulated since the last update
train_iterator = trange(0, int(epochs), desc="Epoch")
for epoch in train_iterator:
    epoch_iterator = tqdm(
        train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True
    )
    for step, inputs in enumerate(epoch_iterator):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)

        # Two interventions per example: one at position 0 and one at the
        # last prompt token.
        base_unit_location = inputs["intervention_position"].tolist()
        first_unit_location = torch.zeros_like(inputs["intervention_position"]).tolist()
        _, cf_outputs = intervenable(
            {"input_ids": inputs["input_ids"]},
            unit_locations={"sources->base": (None, [first_unit_location] + [base_unit_location])})

        # lm loss on counterfactual labels
        lm_logits = cf_outputs.logits
        labels = inputs["labels"]
        # Shift so that position t predicts token t+1.
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss_str = round(loss.item(), 2)
        epoch_iterator.set_postfix({"loss": loss_str})
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward()
        pending_micro_batches += 1
        total_step += 1
        # Update after exactly `gradient_accumulation_steps` backward passes.
        # The old `total_step % n == 0` test fired after n + 1 passes the
        # first time.
        if pending_micro_batches == gradient_accumulation_steps:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            pending_micro_batches = 0

# Apply gradients left over from an incomplete accumulation window so the last
# micro-batches are not silently dropped.
if pending_micro_batches > 0:
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    pending_micro_batches = 0

In [None]:
# Compare greedy generations of the plain model and the steered model on one prompt.
q = "What are the names of some famous actors that started their careers on Broadway?"
q_input = ""
q_prompt = prompt_template % (q, q_input)
prompt = tokenizer(q_prompt, return_tensors="pt").to(device)

print("====== Original LLaMA ======")
response = llama.generate(**prompt, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(response[0], skip_special_tokens=True))
print()

print("====== Steered LLaMA ======")
# One location list per intervention: position 0 and the last prompt token.
base_unit_location = prompt["input_ids"].shape[-1] - 1
_, steered_response = intervenable.generate(
    prompt,
    unit_locations={
        "sources->base": (None, [[[0]], [[base_unit_location]]])
    },
    intervene_on_prompt=True,
    max_new_tokens=128,
    do_sample=False,
)
print(tokenizer.decode(steered_response[0], skip_special_tokens=True))

In [None]:
# Save the intervenable model's state under ./results/test; later cells read
# `intkey_*.bin` files from this directory.
intervenable.save(save_directory="./results/test")

AlpacaEval

In [None]:
def chunk(iterable, chunksize):
    """Split `iterable` into consecutive pieces of at most `chunksize` items.

    Lists are sliced directly; Hugging Face `Dataset` objects are split into
    mini datasets via `select()`. Any other type raises an Exception.
    """
    if isinstance(iterable, list):
        starts = range(0, len(iterable), chunksize)
        return [iterable[begin:begin + chunksize] for begin in starts]
    if isinstance(iterable, Dataset):
        size = len(iterable)
        # The final piece is clipped to the dataset length.
        return [
            iterable.select(list(range(begin, min(begin + chunksize, size))))
            for begin in range(0, size, chunksize)
        ]
    raise Exception(f"Unrecognizable type of iterable for batchification: {type(iterable)}")
        
def extract_output(pred, trigger=''):
    """Return the part of `pred` that follows the first occurrence of `trigger`.

    With an empty trigger the prediction is returned unchanged. When the
    trigger is absent (e.g. the generation was too long) an empty string is
    returned. Leading whitespace of the extracted text is stripped.
    """
    if not trigger:
        return pred
    # for causallm only, use special trigger to detect new tokens.
    _, found, tail = pred.partition(trigger)
    if not found:
        return ''
    return tail.lstrip()

In [None]:
import datasets

eval_set = datasets.load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval")["eval"]
generations = []
# Left padding keeps the last prompt token in the final column for every row.
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.unk_token
eval_batch_size = 8
# Marker after which the decoded text is treated as the model's answer.
clm_new_token_trigger = "### Response:\n"

for batch_example in tqdm(chunk(eval_set, eval_batch_size)):
    actual_batch = [
        prompt_template % (example["instruction"], "") for example in batch_example
    ]
    tokenized = tokenizer.batch_encode_plus(
        actual_batch, return_tensors='pt', padding=True).to(device)

    # Size every per-example list by the real batch size: the final batch can
    # be smaller than eval_batch_size, and indexing by eval_batch_size then
    # runs past the end of `batch_length`.
    current_batch_size = len(actual_batch)
    padded_length = tokenized["input_ids"].shape[-1]
    batch_length = tokenized["attention_mask"].sum(dim=-1).tolist()
    base_last_unit_location = [[padded_length - 1]] * current_batch_size
    # With left padding, the first real token sits right after the pad tokens.
    base_first_unit_location = [
        [padded_length - prompt_length] for prompt_length in batch_length
    ]
    _, steered_response = intervenable.generate(
        tokenized,
        unit_locations={"sources->base": (None, [base_first_unit_location]+[base_last_unit_location])},
        intervene_on_prompt=True,
        max_new_tokens=512,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        early_stopping=True
    )
    actual_preds = tokenizer.batch_decode(steered_response, skip_special_tokens=True)
    # NOTE(review): 'instruction' stores the full formatted prompt, not
    # example['instruction'] - confirm this is what the downstream AlpacaEval
    # tooling expects.
    generations.extend([{'instruction': in_text, 'output': extract_output(pred, clm_new_token_trigger),
                   'dataset': example['dataset']}
                  for in_text, pred, example in zip(actual_batch, actual_preds, batch_example)])

In [None]:
# Debug view: decode the most recent `steered_response` (no skip_special_tokens
# argument is passed here, unlike in the evaluation loop).
tokenizer.batch_decode(steered_response)

In [None]:
# Debug view: the tokenized inputs of the most recent evaluation batch.
tokenized

In [None]:
# Write all generations to disk. The context manager closes the file even if
# serialization raises part-way through.
with open("./results/test_outputs.json", "w") as outputs_file:
    json.dump(generations, outputs_file, indent=6)

Let's see if the learned directions are interpretable. 

In [None]:
class LinearSubspaceKnobIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
):
    """Learned-source low-rank intervention whose shift is scaled by a knob.

    The scaling factor is the module-level name `KNOB_FACTOR`, looked up each
    time `forward` runs, so it must be defined before the first forward pass.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        subspace_dim = kwargs["low_rank_dimension"]
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(
            LowRankRotateLayer(self.embed_dim, subspace_dim)
        )
        # Learned target coordinates, initialised with torch.rand.
        self.learned_source = torch.nn.Parameter(
            torch.rand(subspace_dim), requires_grad=True
        )

    def forward(self, base, source=None, subspaces=None):
        """Return `base` plus the knob-scaled shift, cast to the dtype of `base`.

        `source` and `subspaces` are accepted but not used.
        """
        projection = self.rotate_layer(base)
        scaled_delta = (self.learned_source - projection) * KNOB_FACTOR
        shift = torch.matmul(scaled_delta, self.rotate_layer.weight.T)
        return (base + shift).to(base.dtype)

# Layers whose saved rank-1 interventions are reloaded. This single list drives
# the checkpoint paths, the intervention config and the weight loading, so the
# three cannot drift apart.
# NOTE(review): these checkpoints (layers 2/10/18/26, rank 1) do not match the
# training cell as configured in this notebook (layers [15, 15], rank 4); they
# appear to come from a separate run - confirm before Restart & Run All.
knob_layers = [2, 10, 18, 26]

memo_weights = []
for knob_layer in knob_layers:
    state_dict = torch.load(
        f"./results/test/intkey_layer.{knob_layer}.comp.block_output.unit.pos.nunit.1#0.bin")
    memo_weights += [state_dict]

config = IntervenableConfig([{
    "layer": knob_layer,
    "component": "block_output",
    "low_rank_dimension": 1} for knob_layer in knob_layers],
    LinearSubspaceKnobIntervention
)
pv_llama = IntervenableModel(config, llama)
pv_llama.set_device(device)
pv_llama.disable_model_gradients()

# Load each checkpoint into the intervention with the matching key instead of
# relying on the iteration order of `pv_llama.interventions`.
for knob_layer, state_dict in zip(knob_layers, memo_weights):
    intervention_key = f"layer.{knob_layer}.comp.block_output.unit.pos.nunit.1#0"
    pv_llama.interventions[intervention_key][0].load_state_dict(state_dict)

In [None]:
# Module-level factor read by LinearSubspaceKnobIntervention.forward; it must be
# set before generate() is called below. Change it and re-run this cell to turn
# the knob.
KNOB_FACTOR = 1.0

q = "How can I increase my productivity while working from home?"
q_input = ""
q_prompt = prompt_template % (q, q_input)

prompt = tokenizer(q_prompt, return_tensors="pt").to(device)
print("====== Original LLaMA ======")
response = llama.generate(**prompt, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(response[0], skip_special_tokens=True))
print()
print("====== Steered LLaMA ======")
# Index of the last prompt token.
base_unit_location = prompt["input_ids"].shape[-1] - 1 
# A single integer position is passed here - presumably pyvene applies it to
# every configured intervention; confirm against the pyvene docs.
_, steered_response = pv_llama.generate(
    prompt, 
    unit_locations={"base": base_unit_location},
    intervene_on_prompt=True,
    max_new_tokens=128, do_sample=False
)
print(tokenizer.decode(steered_response[0], skip_special_tokens=True))

Visualizations of the subspace

In [None]:
class LinearSubspaceCollectIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Learned-source low-rank intervention that also records the rotated base.

    Every forward pass appends `rotate_layer(base)` (detached, on CPU) to the
    module-level list `memo_base`. Defined once here instead of being
    re-created on every iteration of the layer loop.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"])
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
        self.learned_source = torch.nn.Parameter(
            torch.rand(kwargs["low_rank_dimension"]), requires_grad=True)

    def forward(
        self, base, source=None, subspaces=None
    ):
        global memo_base
        rotated_base = self.rotate_layer(base)
        memo_base += [rotated_base.detach().cpu().data]
        output = base + torch.matmul(
            (self.learned_source - rotated_base), self.rotate_layer.weight.T
        )
        return output.to(base.dtype)


# Number of batches used per layer (the original loop processed steps 0..6).
num_collect_batches = 7

# Maps layer -> (collected subspace values, learned source rounded to 2 decimals).
subspace_values = {}
for memo_layer in [2, 10, 18, 26]:
    # Fresh buffer per layer; forward() appends to whatever list this
    # module-level name currently refers to.
    memo_base = []

    config = IntervenableConfig([{
        "layer": memo_layer,
        "component": "block_output",
        "low_rank_dimension": 1}],
        LinearSubspaceCollectIntervention
    )
    pv_llama = IntervenableModel(config, llama)
    pv_llama.set_device(device)
    pv_llama.disable_model_gradients()

    # The config holds exactly one intervention; address it by its key.
    intervention_key = f'layer.{memo_layer}.comp.block_output.unit.pos.nunit.1#0'
    collect_intervention = pv_llama.interventions[intervention_key][0]
    state_dict = torch.load(f"./results/test/intkey_{intervention_key}.bin")
    collect_intervention.load_state_dict(state_dict)

    # Values are only recorded, never trained on, so skip building autograd graphs.
    with torch.no_grad():
        for step, inputs in enumerate(train_dataloader):
            base_unit_location = inputs["intervention_position"].tolist()
            pv_llama(
                {"input_ids": inputs["input_ids"].to(device)},
                unit_locations={"sources->base": (None, [base_unit_location])})
            if step + 1 >= num_collect_batches:
                break

    subspace_value = torch.cat(memo_base, dim=0).squeeze(dim=-1)
    subspace_source = collect_intervention.learned_source.tolist()[0]
    subspace_values[memo_layer] = (subspace_value, round(subspace_source, 2))

In [None]:
from plotnine import ggplot, aes, geom_histogram, facet_wrap, labs

# Flatten each layer's collected values to a 1-D float array. DataFrame columns
# must be one-dimensional, and the collected tensors may carry a trailing unit axis.
layer_values = {
    layer: values.float().reshape(-1).numpy()
    for layer, (values, _source) in subspace_values.items()
}

# One row per collected value. Each layer's label repeats for that layer's own
# number of values, so the layers need not have equal counts.
data = pd.DataFrame({
    'Value': np.concatenate(list(layer_values.values())),
    'Group': [
        label
        for layer, values in layer_values.items()
        for label in [f'Layer_{layer}={subspace_values[layer][1]}'] * len(values)
    ],
})

# Adjust the DataFrame slightly for ggplot2-style plotting
data['Value'] = data['Value'].astype(float)

# Using ggplot from plotnine: one histogram facet per layer; the facet title
# also shows that layer's learned source value.
plot = (ggplot(data, aes(x='Value', fill='Group')) +
        geom_histogram(bins=20, position='dodge') +
        facet_wrap('~Group') +
        labs(title='Subspace Values', x='Value', y='Count'))

plot