In [None]:
import sys
sys.path.append("../pyvene/")

In [None]:
import torch
import random, copy, re
import pandas as pd
import numpy as np
import torch.nn.functional as F
import seaborn as sns
from tqdm import tqdm, trange
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from transformers import DataCollatorForSeq2Seq
from torch.nn import CrossEntropyLoss

from pyvene import (
    IntervenableModel,
    LowRankRotatedSpaceIntervention,
    RepresentationConfig,
    IntervenableConfig,
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention,
)
from pyvene import create_llama
from pyvene import set_seed, count_parameters
from pyvene.models.layers import LowRankRotateLayer

import io
import json
import os

def _make_r_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f = open(f, mode=mode)
    return f

def _make_w_io_base(f, mode: str):
    if not isinstance(f, io.IOBase):
        f_dirname = os.path.dirname(f)
        if f_dirname != "":
            os.makedirs(f_dirname, exist_ok=True)
        f = open(f, mode=mode)
    return f


def jload(f, mode="r"):
    """Load a JSON document from a path or an open file object.

    Args:
        f: A filesystem path, or an already-open ``io.IOBase``. In both cases
            the handle is closed after reading.
        mode: Mode used to open ``f`` when it is a path.

    Returns:
        The decoded JSON value (a dict or list, depending on the file).
    """
    f = _make_r_io_base(f, mode)
    try:
        return json.load(f)
    finally:
        # Close even when decoding fails so the handle is not leaked.
        f.close()

def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str or dictionary/list to a file in json format.

    Strings are written verbatim; dicts and lists are serialized with
    ``json.dump``. The file (or passed-in file object) is closed afterwards.

    Args:
        obj: An object to be written (dict, list, or str).
        f: A string path to the location on disk, or an open ``io.IOBase``.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.

    Raises:
        ValueError: If ``obj`` is not a dict, list, or str. This is checked
            before opening so the target file is never created or truncated.
    """
    if not isinstance(obj, (dict, list, str)):
        raise ValueError(f"Unexpected type: {type(obj)}")
    f = _make_w_io_base(f, mode)
    try:
        if isinstance(obj, str):
            f.write(obj)
        else:
            json.dump(obj, f, indent=indent, default=default)
    finally:
        # Always release the handle, even if serialization fails midway.
        f.close()
    
def extract_answer_number(sentence: str) -> float:
    """Return the last number appearing in ``sentence`` as a float.

    All commas are removed first, so thousands separators such as "1,234"
    parse as 1234.0. Returns ``float('inf')`` when no number is present, which
    never equals a finite gold answer.
    """
    numbers = re.findall(r'-?\d+\.?\d*', sentence.replace(',', ''))
    if not numbers:
        return float('inf')
    # Every regex match ("3", "-2.5", "7.") is accepted by float(), so no
    # ValueError handling is needed here.
    return float(numbers[-1])


def extract_answer_letter(sentence: str) -> str:
    """Return the last multiple-choice letter (A-E) in ``sentence``, or ''.

    NOTE(review): any uppercase A-E matches, including letters inside words
    (e.g. the "B" in "Because"), so the last such character wins. The pattern
    is kept as-is so scores stay comparable with earlier runs.
    """
    pred_answers = re.findall(r'A|B|C|D|E', sentence.strip())
    return pred_answers[-1] if pred_answers else ''
    
device = "cuda"
# Prompt format shared by training and evaluation. The first %s receives the
# instruction and the second the (possibly empty) input. The model's response
# starts immediately after the trailing "### Response:\n" line, which is also
# exposed as ``trigger_tokens`` for splitting generations.
prompt_template = (
    "Below is an instruction that describes a task, paired with an input "
    "that provides further context. Write a response that appropriately "
    "completes the request.\n"
    "\n"
    "### Instruction:\n"
    "%s\n"
    "\n"
    "### Input:\n"
    "%s\n"
    "\n"
    "### Response:\n"
)
trigger_tokens = "### Response:\n"

In [None]:
# Load LLaMA-7B through pyvene's helper. The middle return value is discarded
# (presumably a tokenizer — one is loaded explicitly in the next cell; confirm).
config, _, llama = create_llama("huggyllama/llama-7b")
llama.to(device)  # single gpu; Module.to moves the parameters in place
# eval() only switches off train-mode layers such as dropout; gradients are
# disabled separately via intervenable.disable_model_gradients().
llama.eval();

In [None]:
from transformers import LlamaTokenizer

# Reuse EOS as the padding token: the training data appends tokenizer.pad_token
# as an end marker, and generation stops on tokenizer.pad_token_id.
# Right padding keeps padding after the real tokens for the causal-LM collator.
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [41]:
# Training set: a JSON list of records with 'instruction', 'input', 'output'
# and 'answer' fields (see the preview output below).
math_dataset = jload("./datasets/math_10k.json")

In [43]:
# Preview the first ten training records (bare last expression -> displayed by Jupyter).
math_dataset[:10]

[{'instruction': 'At the arcade Dave won 11 tickets . If he spent 5 tickets on a beanie and later won 10 more tickets , how many would he have ?\n ',
  'input': '',
  'output': 'Dave initially had 11 tickets. He spent 5 tickets on a beanie, leaving him with:\n\n11 - 5 = 6 tickets\n\nHe later won 10 more tickets, so his total number of tickets would be:\n\n6 + 10 = 16 tickets\n\nTherefore, Dave would have 16 tickets after spending 5 tickets on a beanie and winning 10 more tickets. The answer in Arabic numerals is:\n\n16',
  'answer': '16.0'},
 {'instruction': 'Anthony has 16 pets. This morning he forgot to lock the door and he lost 6 pets. After that 1/5 of his pets died from old age. How many pets does he have left?\n ',
  'input': '',
  'output': 'Step 1: Subtract the number of lost pets from the total number of pets to find how many pets Anthony has left: \n16 - 6 = 10\n\nStep 2: Find how many pets died from old age by multiplying the number of remaining pets by 1/5: \n10 x 1/5 = 2\n

In [None]:
set_seed(42)

###################
# data loaders
###################
# Each example is: prompt + reference output + tokenizer.pad_token (EOS, as set
# when loading the tokenizer). Labels are a copy of the input ids with the
# prompt positions set to -100 so only response tokens contribute to the loss.
# The intervention is applied at the last prompt token.
max_seq_length = 512
all_base_input_ids, all_base_positions, all_output_ids = [], [], []

for data_item in math_dataset:
    base_prompt = prompt_template % (data_item['instruction'], data_item['input'])
    # base input = base prompt + steered base output
    base_input = base_prompt + data_item["output"] + tokenizer.pad_token
    base_prompt_length = len(tokenizer(
        base_prompt, max_length=max_seq_length, truncation=True, return_tensors="pt")["input_ids"][0])
    base_input_ids = tokenizer(
        base_input, max_length=max_seq_length, truncation=True, return_tensors="pt")["input_ids"][0]
    # Labels start as the same ids. Clone instead of re-tokenizing, and never
    # alias, so masking the prompt below does not modify the inputs.
    output_ids = base_input_ids.clone()
    output_ids[:base_prompt_length] = -100

    all_base_input_ids.append(base_input_ids)
    all_base_positions.append([base_prompt_length - 1])  # intervene on the last prompt token
    all_output_ids.append(output_ids)

raw_train = (
    all_base_input_ids,
    all_base_positions,
    all_output_ids,
)
train_dataset = Dataset.from_dict(
    {
        "input_ids": raw_train[0],
        "intervention_position": raw_train[1],
        "labels": raw_train[2],
    }
)
# Pads each batch to its longest example; padded label positions get -100 so
# they are ignored by the loss as well.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=llama,
    label_pad_token_id=-100,
    padding="longest",
)

In [None]:
# Training hyperparameters. The training loop divides the loss by
# gradient_accumulation_steps and updates once per that many micro-batches,
# so the effective batch size is batch_size * gradient_accumulation_steps.
epochs = 1
batch_size = 8
gradient_accumulation_steps = 2
initial_lr = 5e-3
total_step = 0  # global micro-batch counter read by the training loop

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
)

In [None]:
class LearnedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention,
    DistributedRepresentationIntervention
):
    """Overwrite a learned low-rank subspace of ``base`` with a trainable vector.

    ``rotate_layer`` holds an orthogonally-parametrized basis. The base's
    coordinates in that basis are replaced by ``learned_source``, and the
    difference is mapped back with the transposed basis.
    NOTE(review): assumes LowRankRotateLayer's forward is ``x @ weight`` and
    that ``self.embed_dim`` is populated by the pyvene base ``__init__``; confirm.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rank = kwargs["low_rank_dimension"]
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(
            LowRankRotateLayer(self.embed_dim, rank)
        )
        # Target subspace coordinates, initialised uniformly in [0, 1).
        self.learned_source = torch.nn.Parameter(torch.rand(rank), requires_grad=True)

    def forward(self, base, source=None, subspaces=None):
        # ``source`` and ``subspaces`` are part of the intervention interface but unused.
        coords = self.rotate_layer(base)
        delta = torch.matmul(
            self.learned_source - coords, self.rotate_layer.weight.T
        )
        return (base + delta).to(base.dtype)
    
class ConditionedSourceLowRankRotatedSpaceIntervention(
    ConstantSourceIntervention,
    TrainableIntervention, 
    DistributedRepresentationIntervention
):
    """Overwrite a learned low-rank subspace of ``base`` with input-conditioned values.

    The replacement coordinates are predicted from ``base`` itself by a linear
    layer (``learned_source``). They replace the base's coordinates in the
    orthogonally-parametrized basis held by ``rotate_layer``.
    NOTE(review): assumes LowRankRotateLayer's forward is ``x @ weight`` and
    that ``self.embed_dim`` is populated by the pyvene base ``__init__``; confirm.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        rotate_layer = LowRankRotateLayer(self.embed_dim, kwargs["low_rank_dimension"])
        self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
        # Hidden state -> low-rank source coordinates. Stored in bfloat16
        # (presumably to match the model's dtype; confirm against create_llama).
        self.learned_source = torch.nn.Linear(
            self.embed_dim, kwargs["low_rank_dimension"]).to(torch.bfloat16)

    def forward(
        self, base, source=None, subspaces=None
    ):
        # ``source`` and ``subspaces`` are part of the intervention interface but unused.
        rotated_base = self.rotate_layer(base)
        # Feed the Linear its own dtype. This is a no-op for bfloat16 bases and
        # prevents a dtype-mismatch error for float32/float16 bases.
        source_coords = self.learned_source(base.to(self.learned_source.weight.dtype))
        output = base + torch.matmul(
            (source_coords - rotated_base), self.rotate_layer.weight.T
        )
        return output.to(base.dtype)
    
# One rank-8 conditioned-source intervention on the output of each of the
# blocks 2, 10, 18 and 26 (four interventions in total).
config = IntervenableConfig(
    [
        {"layer": layer, "component": "block_output", "low_rank_dimension": 8}
        for layer in (2, 10, 18, 26)
    ],
    ConditionedSourceLowRankRotatedSpaceIntervention,
)
intervenable = IntervenableModel(config, llama)
intervenable.set_device(device)
# Freeze the base model; only the intervention parameters are trained.
intervenable.disable_model_gradients()

In [None]:
optimizer = torch.optim.Adam(
    intervenable.get_trainable_parameters(), lr=initial_lr
)
# The scheduler is stepped once per optimizer update, not once per epoch, so
# its horizon is the number of updates. Ceiling division accounts for the
# final partial accumulation window that is flushed after the loop.
num_update_steps = max(
    1,
    (int(epochs) * len(train_dataloader) + gradient_accumulation_steps - 1)
    // gradient_accumulation_steps,
)
# Decay linearly from initial_lr to 0.1 * initial_lr over training.
# start_factor is explicit because LinearLR otherwise defaults it to 1/3.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=1.0, end_factor=0.1, total_iters=num_update_steps
)
# ignore_index defaults to -100, which skips the masked prompt and padding labels.
loss_fct = CrossEntropyLoss()
intervenable.model.train()  # train mode enables dropout (if any); model weights stay frozen
print("llama trainable parameters: ", count_parameters(intervenable.model))
print("intervention trainable parameters: ", intervenable.count_parameters())

total_step = 0  # reset so re-running this cell starts a clean accumulation window
optimizer.zero_grad()
train_iterator = trange(0, int(epochs), desc="Epoch")
for epoch in train_iterator:
    epoch_iterator = tqdm(
        train_dataloader, desc=f"Epoch: {epoch}", position=0, leave=True
    )
    for inputs in epoch_iterator:
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(device)

        # One location list per configured intervention. There are four in the
        # config cell; keep this multiplier in sync with it.
        base_unit_location = inputs["intervention_position"].tolist()
        _, cf_outputs = intervenable(
            {"input_ids": inputs["input_ids"]},
            unit_locations={"sources->base": (None, [
                base_unit_location
            ]*4)})

        # lm loss on counterfactual labels: position t predicts token t+1
        lm_logits = cf_outputs.logits
        labels = inputs["labels"]
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        epoch_iterator.set_postfix({"loss": round(loss.item(), 2)})
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward()

        # Count this micro-batch before checking, so every update sees exactly
        # gradient_accumulation_steps micro-batches of gradients.
        total_step += 1
        if total_step % gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

# Apply gradients left over from a final, incomplete accumulation window.
if total_step % gradient_accumulation_steps != 0:
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

In [None]:
eval_tasks = [
    "MultiArith", "gsm8k", "AddSub", "AQuA", "SingleEq", "SVAMP"
]
# Test examples sampled per task. 1 is a quick smoke test; set to None to
# evaluate the full test split.
max_eval_examples = 1
# Numeric predictions are parsed floats, so compare with a small absolute
# tolerance rather than exact equality.
numeric_tolerance = 1e-3
task_accuracy = {}

intervenable.model.eval()  # leave train mode before generating
for task in eval_tasks:
    eval_dataset = jload(f"./datasets/{task}/test.json")
    if max_eval_examples is None:
        num_samples = len(eval_dataset)
    else:
        num_samples = min(max_eval_examples, len(eval_dataset))
    sampled_items = random.sample(range(len(eval_dataset)), num_samples)
    sampled_eval_dataset = [eval_dataset[idx] for idx in sampled_items]
    correct_count = 0
    total_count = 0
    eval_iterator = tqdm(
        sampled_eval_dataset, desc=task, position=0, leave=True
    )
    for data_item in eval_iterator:
        prompt = prompt_template % (data_item['instruction'], data_item['input'])
        answer = data_item["answer"]
        prompt_inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # Intervene at the last prompt token, matching the training setup.
        base_unit_location = prompt_inputs["input_ids"].shape[-1] - 1
        with torch.no_grad():
            # pad_token was set to eos_token, so this stops generation at EOS.
            _, steered_response = intervenable.generate(
                prompt_inputs,
                unit_locations={"base": base_unit_location},
                intervene_on_prompt=True,
                max_new_tokens=512, do_sample=False,
                eos_token_id=tokenizer.pad_token_id, early_stopping=True
            )
        raw_generation = tokenizer.decode(steered_response[0], skip_special_tokens=True)
        # Keep the text after the response trigger, up to any repeated trigger.
        # A missing trigger yields an empty prediction instead of an IndexError.
        segments = raw_generation.split(trigger_tokens)
        generation = segments[1].strip() if len(segments) > 1 else ""

        if task == "AQuA":
            generation = extract_answer_letter(generation)
            is_correct = generation.strip() == answer.strip()
        else:
            generation = extract_answer_number(generation)
            is_correct = abs(generation - float(answer)) <= numeric_tolerance
        correct_count += int(is_correct)
        total_count += 1

        print(data_item['instruction'], raw_generation, generation, answer)

    task_accuracy[task] = correct_count / total_count if total_count else 0.0
    print(f"{task}: {correct_count}/{total_count} correct "
          f"(accuracy {task_accuracy[task]:.4f})")