In [None]:
import os
import sys
sys.path.append('../')
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm.auto import tqdm

from collections import OrderedDict

from dataset import llama_v2_prompt
import numpy as np



# CUDA device name; two aliases are kept because different cells read each one.
device = torch_device = "cuda"

## Instantiate the model

In [None]:
# Read the Hugging Face access token from the environment instead of hard-coding
# it in the notebook. When HF_TOKEN is unset, token=None lets transformers fall
# back to a token cached by `huggingface-cli login`.
access_token = os.environ.get("HF_TOKEN")
model_id = "meta-llama/Llama-2-13b-chat-hf"

# Left padding keeps every prompt's last token in the final column, which is the
# position the intervention hook edits.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=access_token, padding_side='left')

# Load the weights directly in fp16 rather than building a full fp32 copy of the
# 13B model first and then calling .half() on it.
model = AutoModelForCausalLM.from_pretrained(model_id, token=access_token, torch_dtype=torch.float16)
model.to(device);
model.eval();

## Load the probe weights

In [None]:
import torch
from src.probes import LinearProbeClassification

from src.intervention_utils import return_classifier_dict

# Which probe class to load, and where its checkpoints live.
classifier_type = LinearProbeClassification
classifier_directory = "probe_checkpoints/controlling_probe"

# These flags are not read by any cell shown here; they are kept for parity
# with the probe-training configuration.
return_user_msg_last_act = True
include_inst = True
residual_stream = True

# Options forwarded to return_classifier_dict.
layer_num = None  # passed as chosen_layer; None presumably means "all layers" (TODO confirm)
mix_scaler = False
logistic = True
sklearn = False

# The result is indexed later as classifier_dict[attribute][layer_index].
classifier_dict = return_classifier_dict(
    classifier_directory,
    classifier_type,
    sklearn=sklearn,
    logistic=logistic,
    mix_scaler=mix_scaler,
    chosen_layer=layer_num,
)

## Batched intervention code

In [None]:
from baukit import TraceDict
from torch import nn


# Batched generation with padding=True needs a pad token. The Llama-2 tokenizer
# is assumed not to ship one, so "<pad>" is registered when it is missing.
if '<pad>' not in tokenizer.get_vocab():
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

# Give any newly added token an embedding row, then record the pad id on the model.
# The new row's initial values do not matter much: padded positions are masked
# by the attention mask.
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id

# Meaningful sanity checks: a pad id must exist and must index into the
# (possibly resized) input embedding matrix.
assert tokenizer.pad_token_id is not None, "The tokenizer has no pad token!"
assert tokenizer.pad_token_id < model.get_input_embeddings().weight.shape[0], \
    "The pad token ID is outside the model's embedding matrix!"

# True: hook whole decoder layers ("model.layers.<i>"); False: hook only their
# ".mlp" submodules. Read by edit_inter_rep_multi_layers.
residual = True
def optimize_one_inter_rep(inter_rep, layer_name, target, probe,
                           N=4, normalized=False, **ignored_kwargs):
    """Shift a hidden representation along the probe direction(s) chosen by ``target``.

    Despite the name, no iterative optimization happens. The result is computed
    in one step as ``inter_rep + (target @ W) * scale``, where ``W`` is
    ``probe.proj[0].weight``.

    Args:
        inter_rep: tensor whose last dimension is the probe's input width
            (the notebook passes shape (1, batch, hidden)). It is not modified.
        layer_name: unused; kept for interface compatibility.
        target: counterfactual target, flattened to a (1, num_classes) row.
        probe: object exposing ``probe.proj[0].weight`` of shape
            (num_classes, hidden).
        N: intervention strength.
        normalized: if True, ``scale = N * 100 / ||v||``, computed separately
            for each vector ``v`` along the last dimension. Otherwise ``scale = N``.
        **ignored_kwargs: accepted and ignored, so call sites that pass ``lr``,
            ``max_epoch``, ``loss_func`` or ``simplified`` no longer raise
            ``TypeError``.

    Returns:
        A new tensor: ``inter_rep`` broadcast-added to the (1, hidden) steering
        vector. It has no autograd history.
    """
    weight = probe.proj[0].weight
    with torch.no_grad():
        # A (1, num_classes) @ (num_classes, hidden) product gives one (1, hidden) direction.
        target_row = target.detach().to(device=weight.device, dtype=weight.dtype).reshape(1, -1)
        steering = (target_row @ weight).to(device=inter_rep.device, dtype=inter_rep.dtype)
        rep = inter_rep.detach()
        if normalized:
            # Per-vector norm, so each batch element's strength is independent
            # of the other prompts in the batch.
            scale = N * 100 / rep.norm(dim=-1, keepdim=True)
        else:
            scale = N
        return rep + steering * scale


def edit_inter_rep_multi_layers(output, layer_name):
    """TraceDict ``edit_output`` hook that steers the last-position hidden state.

    ``layer_name`` is ``"...model.layers.<i>"`` when the global ``residual`` is
    True, or ``"...model.layers.<i>.mlp"`` otherwise. The hook looks up the probe
    ``classifier_dict[attribute][i + 1]`` and shifts the hidden state at the final
    sequence position via ``optimize_one_inter_rep``. It writes the result back
    in place and returns ``output``.

    Reads globals: residual, classifier_dict, attribute, cf_target, N, normalized.
    """
    prefix = "model.layers."
    start = layer_name.rfind(prefix) + len(prefix)
    if residual:
        layer_idx = int(layer_name[start:])
    else:
        layer_idx = int(layer_name[start:layer_name.rfind(".mlp")])
    # The +1 offset presumably skips the embedding output at hidden_states[0]
    # (TODO confirm against probe training).
    probe = classifier_dict[attribute][layer_idx + 1]

    # Decoder layers may return a tuple whose first item is the hidden state;
    # MLP submodules return the tensor itself. Indexing a bare tensor with [0]
    # would select the first batch element instead.
    hidden = output[0] if isinstance(output, tuple) else output

    # Only the last position is edited. With left padding, this is each prompt's
    # final token. Presumably, during cached generation every step sees one new
    # position, so each generated token is steered too.
    last_rep = hidden[:, -1].unsqueeze(0).detach().clone().to(torch.float)
    steered = optimize_one_inter_rep(last_rep, layer_name, cf_target, probe,
                                     N=N, normalized=normalized)
    hidden[:, -1] = steered.to(hidden.dtype)
    return output


def collect_responses_batched(prompts, modified_layer_names, edit_function, batch_size=5, rand=None):
    """Generate one chat response per prompt while ``edit_function`` hooks the given modules.

    Args:
        prompts: list of user messages. Each is wrapped as a single-turn chat
            with ``llama_v2_prompt``.
        modified_layer_names: module names passed to ``TraceDict``.
        edit_function: ``edit_output`` callback applied to those modules.
        batch_size: number of prompts per ``model.generate`` call.
        rand: unused; kept for backward compatibility.

    Returns:
        list[str]: decoded responses, with prompt tokens removed, in prompt order.

    Uses the globals ``model``, ``tokenizer``, ``generation_temperature`` and
    ``generation_top_p``. Decoding is greedy unless the temperature is positive.
    """
    print(modified_layer_names)
    do_sample = generation_temperature > 0
    # Only pass sampling parameters when sampling. Under greedy decoding they
    # are ignored and transformers warns about them.
    sampling_kwargs = (
        {"temperature": generation_temperature, "top_p": generation_top_p}
        if do_sample else {}
    )

    responses = []
    for start in tqdm(range(0, len(prompts), batch_size)):
        batch = prompts[start:start + batch_size]
        formatted_prompts = [llama_v2_prompt([{"role": "user", "content": prompt}])
                             for prompt in batch]
        inputs = tokenizer(formatted_prompts, return_tensors='pt', padding=True).to(model.device)

        with TraceDict(model, modified_layer_names, edit_output=edit_function), torch.no_grad():
            tokens = model.generate(**inputs,
                                    max_new_tokens=768,
                                    do_sample=do_sample,
                                    **sampling_kwargs)

        # generate() returns prompt + continuation. The prompts are padded to a
        # common length, so everything after that length is newly generated.
        # This is more robust than splitting on '[/INST]', which truncates when
        # the reply itself contains that marker.
        prompt_len = inputs["input_ids"].shape[1]
        responses.extend(tokenizer.batch_decode(tokens[:, prompt_len:], skip_special_tokens=True))

    return responses

## Hyperparameters

In [None]:
import re

normalized = False  # see the `normalized` argument of optimize_one_inter_rep
# Sampling hyperparameters (temperature 0 -> greedy decoding)
generation_temperature = 0
generation_top_p = 1

N = 8  # Intervention strength

# Intervene on layers i with from_idx <= i < to_idx (upper bound exclusive).
from_idx = 20  # Hyperparameter
to_idx = 30  # Hyperparameter
residual = True  # True: hook whole decoder layers; False: hook their .mlp submodules

# Match only the intended modules, e.g. "model.layers.20" (residual) or
# "model.layers.20.mlp" (MLP only). Matching on names explicitly avoids relying
# on "the module name ends in a digit".
layer_pattern = re.compile(r"model\.layers\.(\d+)" + ("" if residual else r"\.mlp") + r"$")
which_layers = []  # Which layer/s to intervene
for name, _ in model.named_modules():
    match = layer_pattern.search(name)
    if match and from_idx <= int(match.group(1)) < to_idx:
        which_layers.append(name)
modified_layer_names = which_layers

# Must be a key of classifier_dict, e.g. "gender" or "age".
attribute = "your choice (gender, age, etc.)" # which attribute to intervene

## Example: steering the gender attribute

In [None]:
# Use gender as the example attribute. The previous cell leaves `attribute` as a
# placeholder string, which would raise KeyError in classifier_dict. The key is
# assumed to be "gender" — TODO confirm against the loaded probe checkpoints.
attribute = "gender"

# Gender has two subcategories, so the counterfactual target has length 2:
# index 0 selects the male attribute and index 1 selects the female attribute.
# Here we steer towards "male" by setting index 0 to 1.
cf_target = [0, 0]
cf_target[0] = 1
cf_target = torch.Tensor([cf_target])  # shape (1, 2)

# Intervention strength
N = 8

# Intervene on the layers selected in the hyperparameter cell
# (from_idx inclusive to to_idx exclusive, i.e. layers 20-29 by default).
modified_layer_names = which_layers

batch_size = 2
questions = [
    "Can you give me some outfits suggestions? I am going to attend my friend's birthday party tonight",
    "What birthday gifts should I bring to my friends?",
]
results = collect_responses_batched(questions, modified_layer_names, edit_inter_rep_multi_layers,
                                    batch_size=batch_size, rand=None)
print(results)

### Print out the results

In [None]:
# Show each question alongside its steered response, separated by rules.
for idx, question in enumerate(questions):
    response = results[idx]
    divider = "-" * 50
    print(f"USER: {question}\n\n{divider}\nIntervened:\nCHATBOT: {response}\n\n{divider}\n")