In [2]:
import sys
sys.path.append("../pyvene/")

import torch
import random, copy, argparse
import pandas as pd
import numpy as np
import torch.nn.functional as F
import seaborn as sns
from tqdm import tqdm, trange
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from transformers import DataCollatorForSeq2Seq
from transformers import AutoTokenizer
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
import wandb

from pyvene import (
    IntervenableModel,
    LowRankRotatedSpaceIntervention,
    RepresentationConfig,
    IntervenableConfig,
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
)
from pyvene import create_llama
from pyvene import set_seed, count_parameters
from pyvene.models.layers import LowRankRotateLayer

import io
import json
import os
import re

def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f

def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f

def jload(f, mode="r"):
    """Load a .json file into a dictionary.

    Args:
        f: A string path to the JSON file, or an already-open file object.
        mode: Mode for opening the file when ``f`` is a path.

    Returns:
        Whatever ``json.load`` decodes from the file.

    Raises:
        json.JSONDecodeError: If the content is not valid JSON.
    """
    f = _make_r_io_base(f, mode)
    try:
        return json.load(f)
    finally:
        # Close on the error path too, so a malformed file does not leak the handle.
        f.close()

def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str or dictionary to a file in json format.

    Args:
        obj: An object to be written. A ``dict`` or ``list`` is serialized as
            JSON; a ``str`` is written verbatim.
        f: A string path to the location on disk, or an open file object.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.

    Raises:
        ValueError: If ``obj`` is not a ``dict``, ``list`` or ``str``.
    """
    # Validate before touching the file system so a bad argument neither
    # creates/truncates the target file nor leaves a handle open.
    if not isinstance(obj, (dict, list, str)):
        raise ValueError(f"Unexpected type: {type(obj)}")
    f = _make_w_io_base(f, mode)
    try:
        if isinstance(obj, str):
            f.write(obj)
        else:
            json.dump(obj, f, indent=indent, default=default)
    finally:
        # Close even if serialization fails part-way through.
        f.close()

def create_directory(path):
    """Make sure ``path`` exists and print which case applied.

    Prints an "already exists" message when ``os.path.exists(path)`` is true;
    otherwise creates the directory (with any missing parents) and prints a
    "created successfully" message.
    """
    if os.path.exists(path):
        print(f"Directory '{path}' already exists.")
        return
    os.makedirs(path)
    print(f"Directory '{path}' created successfully.")

def extract_answer_number(sentence: str) -> float:
    """Return the last number mentioned in ``sentence``.

    Commas are removed first so thousands separators ("1,234") parse as one
    number. Numbers are matched with an optional leading minus sign and an
    optional decimal part.

    Args:
        sentence: Free-form text, e.g. a model generation.

    Returns:
        The last matched number as a float, or ``float('inf')`` when the text
        contains no number.
    """
    numbers = re.findall(r'-?\d+\.?\d*', sentence.replace(',', ''))
    if not numbers:
        return float('inf')
    # Every match of the pattern is a valid float literal (including a
    # trailing-dot form such as "15."), so no conversion fallback is needed.
    return float(numbers[-1])


def extract_answer_letter(sentence: str) -> str:
    """Return the last capital letter A-E found in ``sentence``.

    Note that any capital A-E counts, including one inside a word (the "A" of
    "Answer"), so callers should pass text whose final such letter is the
    chosen option.

    Args:
        sentence: Free-form text, e.g. a model generation.

    Returns:
        The last matching letter, or ``''`` when there is none.
    """
    letters = re.findall(r'A|B|C|D|E', sentence)
    return letters[-1] if letters else ''
        
# Device string used for the model and for every batch moved with `.to(device)`.
device = "cuda"

# Same value as the -100 written into the masked label positions further down.
IGNORE_INDEX = -100

# Fallback special tokens, registered only when the tokenizer lacks them.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

# Instruction-following prompt; `%s` is filled with the task instruction.
prompt_template = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n"
    "\n"
    "### Instruction:\n"
    "%s\n"
    "\n"
    "### Response:\n"
)

# Marker that ends the prompt; generations are split on it to isolate the response.
trigger_tokens = "### Response:\n"

class LearnedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Intervention that writes a learned constant vector into a low-rank subspace.

    ``forward`` returns ``base + (learned_source - R(base)) @ W.T`` where ``R``
    is ``self.rotate_layer`` and ``W`` is its ``weight``.
    """

    def __init__(self, **kwargs):
        """Create the orthogonally parametrized rotation and the source vector.

        Reads ``kwargs["low_rank_dimension"]``; every keyword argument is also
        forwarded to the parent constructors.
        """
        super().__init__(**kwargs)
        subspace_dim = kwargs["low_rank_dimension"]
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(
            LowRankRotateLayer(self.embed_dim, subspace_dim)
        )
        # One trainable value per subspace dimension, drawn from U[0, 1).
        self.learned_source = torch.nn.Parameter(
            torch.rand(subspace_dim), requires_grad=True
        )

    def forward(
        self, base, source=None, subspaces=None
    ):
        """Apply the intervention to ``base``.

        ``source`` and ``subspaces`` are accepted but not used. The result is
        cast back to ``base.dtype``.
        """
        base_in_subspace = self.rotate_layer(base)
        subspace_delta = self.learned_source - base_in_subspace
        # Map the subspace difference back with the transposed rotation weight
        # (presumably to the embedding space - TODO confirm weight layout).
        steered = base + torch.matmul(subspace_delta, self.rotate_layer.weight.T)
        return steered.to(base.dtype)

class ConditionedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Low-rank rotated-space intervention whose source is computed from ``base``.

    ``forward`` returns ``base + (silu(L(base)) - R(base)) @ W.T`` where ``L``
    is the linear layer ``self.learned_source``, ``R`` is ``self.rotate_layer``
    and ``W`` is its ``weight``.
    """

    def __init__(self, **kwargs):
        """Create the rotation, the bfloat16 source projection and the activation.

        Reads ``kwargs["low_rank_dimension"]``; every keyword argument is also
        forwarded to the parent constructors.
        """
        super().__init__(**kwargs)
        subspace_dim = kwargs["low_rank_dimension"]
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(
            LowRankRotateLayer(self.embed_dim, subspace_dim)
        )
        source_proj = torch.nn.Linear(self.embed_dim, subspace_dim)
        self.learned_source = source_proj.to(torch.bfloat16)
        self.act_fn = ACT2FN["silu"]

    def forward(
        self, base, source=None, subspaces=None
    ):
        """Apply the intervention to ``base``.

        ``source`` and ``subspaces`` are accepted but not used. The result is
        cast back to ``base.dtype``.
        """
        base_in_subspace = self.rotate_layer(base)
        conditioned_source = self.act_fn(self.learned_source(base))
        subspace_delta = conditioned_source - base_in_subspace
        steered = base + torch.matmul(subspace_delta, self.rotate_layer.weight.T)
        return steered.to(base.dtype)
    
class ConditionedSourceLowRankIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Low-rank intervention using a plain linear projection instead of a rotation.

    ``forward`` returns ``base + (silu(L(base)) - P(base)) @ P.weight`` where
    ``P`` is ``self.proj_layer`` and ``L`` is ``self.learned_source``.
    """

    def __init__(self, **kwargs):
        """Create the two bfloat16 linear layers and the activation.

        Reads ``kwargs["low_rank_dimension"]``; every keyword argument is also
        forwarded to the parent constructors.
        """
        super().__init__(**kwargs)
        subspace_dim = kwargs["low_rank_dimension"]
        projection = torch.nn.Linear(self.embed_dim, subspace_dim)
        self.proj_layer = projection.to(torch.bfloat16)
        source_proj = torch.nn.Linear(self.embed_dim, subspace_dim)
        self.learned_source = source_proj.to(torch.bfloat16)
        self.act_fn = ACT2FN["silu"]

    def forward(
        self, base, source=None, subspaces=None
    ):
        """Apply the intervention to ``base``.

        ``source`` and ``subspaces`` are accepted but not used. The result is
        cast back to ``base.dtype``.
        """
        projected_base = self.proj_layer(base)
        conditioned_source = self.act_fn(self.learned_source(base))
        subspace_delta = conditioned_source - projected_base
        # ``Linear.weight`` is (out_features, in_features), so it is used
        # untransposed to go from the low-rank space back to ``embed_dim``.
        steered = base + torch.matmul(subspace_delta, self.proj_layer.weight)
        return steered.to(base.dtype)
    

In [3]:
# Load the LLaMA-7B checkpoint through pyvene's helper; the middle return
# value is discarded here.
config, _, llama = create_llama("huggyllama/llama-7b")
# Single-GPU placement followed by eval mode. eval() only switches module
# behaviour such as dropout; it does not by itself turn gradients off.
_ = llama.to(device).eval()

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

loaded model


In [4]:
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.padding_side = "right"

# Defaults to register for any special token the checkpoint leaves unset.
fallback_special_tokens = {
    "pad_token": DEFAULT_PAD_TOKEN,
    "eos_token": DEFAULT_EOS_TOKEN,
    "bos_token": DEFAULT_BOS_TOKEN,
    "unk_token": DEFAULT_UNK_TOKEN,
}
special_tokens_dict = {
    token_attr: fallback
    for token_attr, fallback in fallback_special_tokens.items()
    if getattr(tokenizer, token_attr) is None
}

num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
# Keep the model's embedding table in step with the (possibly larger) vocabulary.
llama.resize_token_embeddings(len(tokenizer))
print("adding new tokens count: ", num_new_tokens)

adding new tokens count:  1


normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.


In [35]:
###################
# data loaders
###################
# NOTE(review): the training examples are read from the MultiArith *test*
# split, the same file the evaluation cell loads - confirm the overlap is
# intended before reporting numbers.
all_base_input_ids, all_base_positions, all_output_ids = [], [], []

task_dataset = jload("./datasets/MultiArith/test.json")
for data_item in task_dataset[:100]:
    assert data_item['input'] == ""
    base_prompt = prompt_template % data_item['instruction'].strip()
    # base input = base prompt + steered base output
    overwrite_output = data_item["output"].strip()
    base_input = base_prompt + overwrite_output + tokenizer.pad_token
    base_prompt_length = len(tokenizer(
        # we use 256 to follow previous work's cut-off length
        base_prompt, max_length=256, truncation=True, return_tensors="pt")["input_ids"][0])
    base_input_ids = tokenizer(
        base_input, max_length=256, truncation=True, return_tensors="pt")["input_ids"][0]
    # Force the final position to the pad token id; the evaluation cell uses
    # that id as its stop token.
    base_input_ids[-1] = tokenizer.pad_token_id
    # Labels are a copy of the inputs (no need to tokenize the same string
    # twice) with the prompt positions masked out of the loss.
    output_ids = base_input_ids.clone()
    output_ids[:base_prompt_length] = IGNORE_INDEX

    all_base_input_ids.append(base_input_ids)
    all_base_positions.append([base_prompt_length-1]) # intervene on the last prompt token
    all_output_ids.append(output_ids)

raw_train = (
    all_base_input_ids,
    all_base_positions,
    all_output_ids,
)

train_dataset = Dataset.from_dict(
    {
        "input_ids": raw_train[0],
        "intervention_position": raw_train[1],
        "labels": raw_train[2],
    }
).shuffle(seed=42)

# Pads each batch to its longest sequence; label padding uses IGNORE_INDEX so
# padded positions are masked the same way as the prompt.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=llama,
    label_pad_token_id=IGNORE_INDEX,
    padding="longest"
)

train_dataloader = DataLoader(
    train_dataset, shuffle=True, batch_size=8, collate_fn=data_collator)

In [38]:
# --- training hyperparameters ---
epochs = 50
batch_size = 8
initial_lr = 0.01
gradient_accumulation_steps = 1
# Running count of processed batches; incremented by the training loop.
total_step = 0

# Rebuild the loader so that it uses `batch_size` from this cell.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator,
)

In [37]:
rank = 4
layer_spec = "2;4;6;10;12;14;18"
layer_ids = [int(layer_id) for layer_id in layer_spec.split(";")]
# The layer list is repeated once, so every listed layer gets two interventions.
layers = layer_ids + layer_ids

# One representation entry per intervention, all on the block output.
representations = [
    {"layer": layer_id, "component": "block_output", "low_rank_dimension": rank}
    for layer_id in layers
]
config = IntervenableConfig(
    representations,
    LearnedSourceLowRankRotatedSpaceIntervention
)
intervenable = IntervenableModel(config, llama)
intervenable.set_device(device)
intervenable.disable_model_gradients()

In [34]:
optimizer = torch.optim.Adam(
    intervenable.get_trainable_parameters(), lr=initial_lr
)
intervenable.model.train()  # train mode enables dropout; it does not enable grads on the model
print("llama trainable parameters: ", count_parameters(intervenable.model))
print("intervention trainable parameters: ", intervenable.count_parameters())

# Loop-invariant: the default CrossEntropyLoss ignores targets equal to -100,
# which is the value used to mask prompt and padding positions in the labels.
loss_fct = CrossEntropyLoss()
# Start from clean gradients in case this cell is re-run mid-accumulation.
optimizer.zero_grad()

train_iterator = trange(0, int(epochs), desc="Epoch")
for epoch in train_iterator:
    epoch_iterator = tqdm(
        train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True
    )
    for step, inputs in enumerate(epoch_iterator):
        for k, v in inputs.items():
            if v is not None and isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)

        # Per-example intervention positions, repeated once per registered
        # intervention (len(layers) of them).
        base_unit_location = inputs["intervention_position"].tolist()
        _, cf_outputs = intervenable(
            {"input_ids": inputs["input_ids"]},
            unit_locations={"sources->base": (
                None, [base_unit_location] * len(layers))})

        # lm loss on counterfactual labels
        lm_logits = cf_outputs.logits
        labels = inputs["labels"]
        # Position t predicts token t+1, so drop the last logit and the first label.
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        loss_str = round(loss.item(), 2)
        epoch_iterator.set_postfix({"loss": loss_str})
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward()
        # Count the batch first, then step once every
        # `gradient_accumulation_steps` batches. Checking before incrementing
        # made the first update accumulate one batch too many.
        total_step += 1
        if total_step % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

llama trainable parameters:  0
intervention trainable parameters:  229432


Epoch: 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.18it/s, loss=0.64]
Epoch: 1: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.18it/s, loss=0.47]
Epoch: 2: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.17it/s, loss=0.74]
Epoch: 3: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.18it/s, loss=0.25]
Epoch: 4: 100%|█████████████████████████████████████

Epoch: 34: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.13it/s, loss=0]
Epoch: 35: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.18it/s, loss=0]
Epoch: 36: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.14it/s, loss=0]
Epoch: 37: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:01<00:00,  6.18it/s, loss=0]
Epoch: 38: 100%|████████████████████████████████████

In [24]:
eval_tasks = [
    "MultiArith",
    # "gsm8k", "AddSub", "AQuA", "SingleEq", "SVAMP"
]
# The training cell leaves the wrapped model in train mode; switch back so
# generation runs in eval mode.
intervenable.model.eval()
for eval_task in eval_tasks:
    eval_dataset = jload(f"./datasets/{eval_task}/test.json")
    sampled_eval_dataset = eval_dataset[:1]
    correct_count = 0
    total_count = 0
    eval_iterator = tqdm(
        sampled_eval_dataset, position=0, leave=True
    )
    for data_item in eval_iterator:
        prompt_text = prompt_template % (data_item['instruction'].strip())
        answer = data_item["answer"]
        prompt = tokenizer(prompt_text, return_tensors="pt").to(device)
        # Intervene on the last prompt token, as in training. The location is
        # given once per registered intervention (batch of one, one position
        # each), mirroring the list the training loop passes.
        last_prompt_position = prompt["input_ids"].shape[-1] - 1
        base_unit_location = {"sources->base": (
            None, [[[last_prompt_position]]] * len(layers))}
        _, steered_response = intervenable.generate(
            prompt,
            unit_locations=base_unit_location,
            intervene_on_prompt=True,
            max_new_tokens=1024,
            do_sample=False,
            # Training sequences end with the pad token, so it is the stop token here.
            eos_token_id=tokenizer.pad_token_id,
            early_stopping=True,
        )
        raw_generation = tokenizer.decode(steered_response[0], skip_special_tokens=True)
        # Text between the first and second trigger marker. If the marker is
        # missing from the decoded text, score an empty response, not a crash.
        segments = raw_generation.split(trigger_tokens)
        generation = segments[1].strip() if len(segments) > 1 else ""

        if eval_task == "AQuA":
            is_correct = extract_answer_letter(generation) == answer.strip()
        else:
            is_correct = extract_answer_number(generation) == float(answer)
        correct_count += int(is_correct)
        total_count += 1
        metric_str = round(correct_count/total_count, 3)
        print(raw_generation)
        eval_iterator.set_postfix({"em": metric_str})

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:21<00:00, 21.80s/it, em=1]

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
At the schools book fair Sam bought 13 adventure books and 17 mystery books. If 15 of the books were used, how many new books did he buy?

### Response:
A: Sam bought 13 adventure books and 17 mystery books. That means he bought 13 + 17 = 30 books in total. 15 of them were used, so he has 30 - 15 = 15 new books. The answer is 15.





In [151]:
# Debug inspection: display the tokenized prompt left bound to `prompt` by the
# evaluation loop (relies on kernel state from that cell).
prompt

{'input_ids': tensor([[    1, 13866,   338,   385, 15278,   393, 16612,   263,  3414, 29889,
         14350,   263,  2933,   393,  7128,  2486,  1614,  2167,   278,  2009,
         29889,    13,    13,  2277, 29937,  2799,  4080, 29901,    13,  4178,
           278, 12462,  3143,  6534,  3685, 18093, 29871, 29896, 29941, 17623,
           545,  8277,   322, 29871, 29896, 29955, 29236,  8277, 29889,   960,
         29871, 29896, 29945,   310,   278,  8277,   892,  1304, 29892,   920,
          1784,   716,  8277,  1258,   540, 15649, 29973,    13,    13,  2277,
         29937, 13291, 29901,    13]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]], device='cuda:0')}

In [152]:
# Debug inspection: instruction text of whichever `data_item` is currently
# bound in the kernel (it is a loop variable of the data-loading and
# evaluation cells).
data_item['instruction'].strip()

'At the schools book fair Sam bought 13 adventure books and 17 mystery books. If 15 of the books were used, how many new books did he buy?'