In [1]:
import sys
sys.path.append("../pyvene/")

In [99]:
import torch
import random, copy, json
import pandas as pd
import numpy as np
import torch.nn.functional as F
import seaborn as sns
from tqdm import tqdm, trange
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from transformers import DataCollatorForSeq2Seq
from torch.nn import CrossEntropyLoss
from datasets import load_dataset

from pyvene import (
    IntervenableModel,
    LowRankRotatedSpaceIntervention,
    RepresentationConfig,
    IntervenableConfig,
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
)
from pyvene import create_llama
from pyvene import set_seed, count_parameters
from pyvene.models.layers import LowRankRotateLayer

import io
import json
import os

def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f

def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f


def jload(f, mode="r"):
    """Load a .json file into a dictionary."""
    f = _make_r_io_base(f, mode)
    jdict = json.load(f)
    f.close()
    return jdict

def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written.
        f: A string path to the location on disk.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.
    """
    f = _make_w_io_base(f, mode)
    if isinstance(obj, (dict, list)):
        json.dump(obj, f, indent=indent, default=default)
    elif isinstance(obj, str):
        f.write(obj)
    else:
        raise ValueError(f"Unexpected type: {type(obj)}")
    f.close()
    
device = "cuda"
prompt_template = "Instruction: %s \nInput: %s \nGeneration: "

In [9]:
config, _, llama = create_llama("meta-llama/Llama-2-7b-hf")
_ = llama.to(device)  # single gpu
_ = llama.eval()  # always no grad on the model

from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.padding_side = "right" 
tokenizer.pad_token = tokenizer.eos_token

normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

loaded model


normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [66]:
dataset = load_dataset("LeoLM/HellaSwag_de")["validation"]

In [67]:
set_seed(42)

###################
# data loaders
###################
all_base_input_ids, all_base_positions, all_output_ids, all_source_input_ids = [], [], [], []

for data_item in dataset:
    en_ctx = data_item["ctx"]
    if len(data_item["endings_de"]) != 4:
        continue
    de_ending = data_item["endings_de"][int(data_item["label"])-1]
    
    # given the ctx in en, continue the ending in de
    base_prompt = en_ctx
    base_input = base_prompt + " " + de_ending + tokenizer.pad_token
    
    base_prompt_length = len(tokenizer(
        base_prompt, max_length=512, truncation=True, return_tensors="pt")["input_ids"][0])
    base_input_ids = tokenizer(
        base_input, max_length=512, truncation=True, return_tensors="pt")["input_ids"][0]
    output_ids = tokenizer(
        base_input, max_length=512, truncation=True, return_tensors="pt")["input_ids"][0]
    output_ids[:base_prompt_length] = -100
    
    all_base_input_ids.append(base_input_ids)
    all_base_positions.append([base_prompt_length-1]) # intervene on the last prompt token
    all_output_ids.append(output_ids)

raw_train = (
    all_base_input_ids,
    all_base_positions,
    all_output_ids,
)
train_dataset = Dataset.from_dict(
    {
        "input_ids": raw_train[0],
        "intervention_position": raw_train[1],
        "labels": raw_train[2],
    }
)
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=llama,
    label_pad_token_id=-100,
    padding="longest"
)

In [68]:
epochs = 1
initial_lr = 5e-3
total_step = 0
gradient_accumulation_steps = 1
batch_size = 16

train_dataloader = DataLoader(
    train_dataset, shuffle=True, batch_size=batch_size, collate_fn=data_collator)

In [69]:
class LearnedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention, 
    DistributedRepresentationIntervention
):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"])
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
        self.learned_source = torch.nn.Parameter(
            torch.rand(kwargs["low_rank_dimension"]), requires_grad=True)

    def forward(
        self, base, source=None, subspaces=None
    ):
        rotated_base = self.rotate_layer(base)
        output = base + torch.matmul(
            (self.learned_source - rotated_base), self.rotate_layer.weight.T
        )
        return output.to(base.dtype)
    
config = IntervenableConfig([{
    "layer": 2,
    "component": "block_output",
    "low_rank_dimension": 1},{
    "layer": 10,
    "component": "block_output",
    "low_rank_dimension": 1},{
    "layer": 18,
    "component": "block_output",
    "low_rank_dimension": 1},{
    "layer": 26,
    "component": "block_output",
    "low_rank_dimension": 1}],
    # this is a trainable low-rank rotation
    LearnedSourceLowRankRotatedSpaceIntervention
)
intervenable = IntervenableModel(config, llama)
intervenable.set_device(device)
intervenable.disable_model_gradients()

In [70]:
optimizer = torch.optim.Adam(
    intervenable.get_trainable_parameters(), lr=initial_lr
)
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, end_factor=0.1, total_iters=epochs
)
intervenable.model.train()  # train enables drop-off but no grads
print("llama trainable parameters: ", count_parameters(intervenable.model))
print("intervention trainable parameters: ", intervenable.count_parameters())
train_iterator = trange(0, int(epochs), desc="Epoch")
for epoch in train_iterator:
    epoch_iterator = tqdm(
        train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True
    )
    for step, inputs in enumerate(epoch_iterator):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)
        b_s = inputs["input_ids"].shape[0]
        
        base_unit_location = inputs["intervention_position"].tolist()
        _, cf_outputs = intervenable(
            {"input_ids": inputs["input_ids"]},
            unit_locations={"sources->base": (None, [
                base_unit_location, base_unit_location, base_unit_location, base_unit_location
            ])})

        # lm loss on counterfactual labels
        lm_logits = cf_outputs.logits
        labels = inputs["labels"]
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss_str = round(loss.item(), 2)
        epoch_iterator.set_postfix({"loss": loss_str})
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward()
        if total_step % gradient_accumulation_steps == 0:
            if not (gradient_accumulation_steps > 1 and total_step == 0):
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
        total_step += 1

llama trainable parameters:  0
intervention trainable parameters:  16388


Epoch: 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 628/628 [08:37<00:00,  1.21it/s, loss=2.35]
Epoch: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [08:37<00:00, 517.32s/it]


In [87]:
q = "She starts with a one inch flat brush and yellow and white acrylic paint. She makes x patterns across the canvas with the yellow color. She"
q_prompt = q

prompt = tokenizer(q_prompt, return_tensors="pt").to(device)
print("====== Original LLaMA ======")
response = llama.generate(
    **prompt, max_new_tokens=128, do_sample=False, 
    eos_token_id=tokenizer.eos_token_id, early_stopping=True)
print(tokenizer.decode(response[0], skip_special_tokens=True))
print()
print("====== Steered LLaMA ======") 
base_unit_location = prompt["input_ids"].shape[-1] - 1 
_, steered_response = intervenable.generate(
    prompt, 
    unit_locations={"base": base_unit_location},
    intervene_on_prompt=True,
    max_new_tokens=128, do_sample=False, 
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(steered_response[0], skip_special_tokens=True))

She starts with a one inch flat brush and yellow and white acrylic paint. She makes x patterns across the canvas with the yellow color. She then adds the white color to the x patterns. She then adds the black color to the x patterns. She then adds the red color to the x patterns. She then adds the blue color to the x patterns. She then adds the green color to the x patterns. She then adds the purple color to the x patterns. She then adds the orange color to the x patterns. She then adds the brown color to the x patterns. She then adds the pink color to the x patterns. She then adds the gray color to the x patterns. She then adds the black color to the x patterns. She then adds the white color

She starts with a one inch flat brush and yellow and white acrylic paint. She makes x patterns across the canvas with the yellow color. She fängt an, indem sie die Farbe auf die Leinwand aufträgt, indem sie die Farbe auf die Leinwand aufträgt.


In [72]:
intervenable.save(save_directory="./results/de_hellaswag_test")

Directory './results/de_hellaswag_test' already exists.


Composing with two interventons

In [97]:
class LinearSubspaceDirectionKnobIntervention(
    ConstantSourceIntervention,
    TrainableIntervention, 
    DistributedRepresentationIntervention
):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"])
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
        self.learned_source = torch.nn.Parameter(
            torch.rand(kwargs["low_rank_dimension"]), requires_grad=True)

    def forward(
        self, base, source=None, subspaces=None
    ):
        rotated_base = self.rotate_layer(base)
        output = torch.matmul(
            (self.learned_source - rotated_base), self.rotate_layer.weight.T
        )
        return output.to(base.dtype)

KNOB_FACTOR_1 = 1.0
KNOB_FACTOR_2 = 1.0

class CompositionalLinearSubspaceKnobIntervention(
    ConstantSourceIntervention,
    TrainableIntervention, 
    DistributedRepresentationIntervention
):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.knob_1 = LinearSubspaceDirectionKnobIntervention(**kwargs)
        self.knob_2 = LinearSubspaceDirectionKnobIntervention(**kwargs)

    def forward(
        self, base, source=None, subspaces=None
    ):
        output = base + self.knob_1(base) * KNOB_FACTOR_1 + self.knob_2(base) * KNOB_FACTOR_2
        return output.to(base.dtype)

memo_weights = []
layers = [2,10,18,26]
for _ in layers:
    en_state_dict = torch.load(
        f"./results/test/intkey_layer.{_}.comp.block_output.unit.pos.nunit.1#0.bin")
    zh_state_dict = torch.load(
        f"./results/de_hellaswag_test/intkey_layer.{_}.comp.block_output.unit.pos.nunit.1#0.bin")
    memo_weights += [(en_state_dict, zh_state_dict)]

config = IntervenableConfig([{
    "layer": l,
    "component": "block_output",
    "low_rank_dimension": 1} for l in layers],
    CompositionalLinearSubspaceKnobIntervention
)
pv_llama = IntervenableModel(config, llama)
pv_llama.set_device(device)
pv_llama.disable_model_gradients()

for i, (k, v) in enumerate(pv_llama.interventions.items()):
    v[0].knob_1.load_state_dict(memo_weights[i][0])
    v[0].knob_2.load_state_dict(memo_weights[i][1])

In [106]:
KNOB_FACTOR_1 = 1.0 # instruct
KNOB_FACTOR_2 = 0.0 # de sentence completion

q = "Why might someone choose to use a paper map or ask for directions instead of relying on a GPS device or smartphone app?"
q_input = ""
q_prompt = prompt_template % (q, q_input)

prompt = tokenizer(q_prompt, return_tensors="pt").to(device)
print("====== Original LLaMA ======")
response = llama.generate(**prompt, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(response[0], skip_special_tokens=True))
print()
print("====== Steered LLaMA ======")
base_unit_location = prompt["input_ids"].shape[-1] - 1 
_, steered_response = pv_llama.generate(
    prompt, 
    unit_locations={"base": base_unit_location},
    intervene_on_prompt=True,
    max_new_tokens=128,
)
print(tokenizer.decode(steered_response[0], skip_special_tokens=True))

Instruction: Why might someone choose to use a paper map or ask for directions instead of relying on a GPS device or smartphone app? 
Input:  
Generation: 

### 1. What is the difference between a map and a GPS device?

### 2. What is the difference between a map and a smartphone app?

### 3. What is the difference between a paper map and a GPS device?

### 4. What is the difference between a paper map and a smartphone app?

### 5. What is the difference between a paper map and a GPS device?

### 6. What is the difference between a paper map and a smartphone app?

###

Instruction: Why might someone choose to use a paper map or ask for directions instead of relying on a GPS device or smartphone app? 
Input:  
Generation: Paper maps and asking for directions are preferred by some people because they provide more detailed information than GPS devices and smartphone apps. Paper maps can show more detailed routes, such as alternate routes and side roads, and provide information on points o