In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before each
# cell runs, so edits to imported modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
sys.path.append('../')
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm.auto import tqdm

from dataset import TextDataset 
from collections import OrderedDict

from dataset import llama_v2_prompt
import matplotlib.pyplot as plt
import numpy as np

from torch import nn

import time

# Both aliases refer to the same wall-clock timer function.
tic = toc = time.time
# Device strings used when moving tensors in the cells that follow.
device = torch_device = "cuda"
# Key used to select the probes in classifier_dict for the intervention.
attribute = "socioeco"

In [None]:
# Read the Hugging Face access token from a local file rather than embedding it here.
with open('hf_access_token.txt', 'r') as token_file:
    access_token = token_file.read().strip()

model_name = "meta-llama/Llama-2-13b-chat-hf"
# padding_side='left' is requested explicitly; the intervention code later edits
# position -1 of each sequence in the batch.
tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)
# Cast to half precision, move to the GPU and switch to eval mode
# (the trailing ';' suppresses the module repr in the notebook).
model.half().cuda().eval();

## Helper functions

In [None]:
from probes import ProbeClassification, ProbeClassificationMixScaler, LinearProbeClassification, LinearProbeClassificationMixScaler
from intervention_utils import load_probe_classifier, return_classifier_dict, num_classes
from copy import deepcopy


def optimize_one_inter_rep(inter_rep, layer_name, target, probe,
                           lr=1e-2,
                           N=4, normalized=False):
    """Shift an activation along the probe direction selected by ``target``.

    Despite the name, no optimisation loop is run: the function adds a scaled
    combination of the probe's first-layer weight rows to ``inter_rep`` once.

    Args:
        inter_rep: Activation tensor to shift. It is cloned, so the caller's
            tensor is left untouched.
        layer_name: Unused; kept so existing callers keep working.
        target: Weights over the probe classes (e.g. a one-hot row). It is
            flattened to a single row before the matrix product.
        probe: Object exposing ``probe.proj[0].weight``.
            Assumes one weight row per class -- TODO confirm against the probe classes.
        lr: Unused; kept so existing callers keep working.
        N: Scalar strength of the shift.
        normalized: When true, the shift is additionally multiplied by
            ``100 / inter_rep.norm()`` (norm taken over the whole tensor).

    Returns:
        A new tensor on ``torch_device`` holding the shifted activation.
    """
    # requires_grad_ is kept so the returned tensor stays attached to autograd
    # exactly as before.
    rep = inter_rep.clone().to(torch_device).requires_grad_(True)
    target_weights = target.clone().to(torch_device).to(torch.float)

    # (1, k) @ probe weight -> weighted sum of the probe's weight rows.
    direction = target_weights.view(1, -1) @ probe.proj[0].weight
    if normalized:
        shifted = rep + direction * N * 100 / rep.norm()
    else:
        shifted = rep + direction * N
    return shifted.clone()


def edit_inter_rep_multi_layers(output, layer_name):
    """Edit hook that shifts the last-position activation of a traced layer.

    Used as the ``edit_output`` callback of ``TraceDict``. Besides its two
    arguments it reads these notebook globals: ``residual``, ``classifier_dict``,
    ``attribute``, ``rand``, ``cf_target`` and ``N``.

    Args:
        output: Layer output; ``output[0]`` is indexed as ``[:, -1]``, i.e. the
            last position along its second axis is edited in place.
        layer_name: Module name containing ``"model.layers.<idx>"`` (followed by
            ``".mlp"`` when ``residual`` is false).

    Returns:
        The same ``output`` object with ``output[0][:, -1]`` overwritten.

    Raises:
        ValueError: If ``rand`` is truthy but neither ``'uniform'`` nor ``'gaussian'``.
    """
    prefix = "model.layers."
    start = layer_name.rfind(prefix) + len(prefix)
    if residual:
        layer_num = int(layer_name[start:])
    else:
        layer_num = int(layer_name[start:layer_name.rfind(".mlp")])

    # Work on a copy so the probe stored in classifier_dict is never modified.
    probe = deepcopy(classifier_dict[attribute][layer_num + 1])
    if rand:
        if rand == 'uniform':
            sampler = torch.rand
        elif rand == 'gaussian':
            sampler = torch.randn
        else:
            raise ValueError(
                f"Unknown rand mode {rand!r}; expected 'uniform', 'gaussian' or a falsy value."
            )
        with torch.no_grad():
            probe_weight = probe.proj[0].weight
            # Sampled with the default generator, then moved to the probe's device.
            new_probe = sampler(probe_weight.shape).to(probe_weight.device)
            # Rescale every random row to the L2 norm of the matching probe row,
            # so the control direction has the same magnitude as the real one.
            original_norms = probe_weight.norm(p=2, dim=1, keepdim=True)
            random_norms = new_probe.norm(p=2, dim=1, keepdim=True)
            new_probe = (new_probe / random_norms) * original_norms
            probe.proj[0].weight = nn.Parameter(new_probe)

    cloned_inter_rep = output[0][:, -1].unsqueeze(0).detach().clone().to(torch.float)
    with torch.enable_grad():
        # `lr` is intentionally not forwarded: no global `lr` exists in this
        # notebook and the helper does not use the value.
        cloned_inter_rep = optimize_one_inter_rep(cloned_inter_rep, layer_name,
                                                  cf_target, probe,
                                                  N=N)
    # Cast back to whatever dtype the layer output uses (float16 for the half model).
    output[0][:, -1] = cloned_inter_rep.to(output[0].dtype)
    return output

### Loading Function

In [None]:
import torch
import torch.nn.functional as F
from torch import nn
import os
from src.probes import LinearProbeClassification

from src.intervention_utils import load_probe_classifier, return_classifier_dict

In [None]:
# Register a '<pad>' token when the tokenizer vocabulary does not contain one yet.
if not ('<pad>' in tokenizer.get_vocab()):
    tokenizer.add_special_tokens(dict(pad_token="<pad>"))

# Keep the embedding matrix and the model config in sync with the tokenizer.
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
assert model.config.pad_token_id == tokenizer.pad_token_id, "The model's pad token ID does not match the tokenizer's pad token ID!"
for label, holder in (
    ('Tokenizer pad token ID:', tokenizer),
    ('Model pad token ID:', model.config),
    ('Model config pad token ID:', model.config),
):
    print(label, holder.pad_token_id)

In [None]:
# Where the trained probes live and which probe class to instantiate.
classifier_directory = "data/probe_checkpoints/controlling_probe"
classifier_type = LinearProbeClassification

# Options forwarded to return_classifier_dict.
layer_num = None
mix_scaler = False
logistic = True
sklearn = False

# Flags that no visible cell of this notebook reads.
return_user_msg_last_act = True
include_inst = True
residual_stream = True

classifier_dict = return_classifier_dict(
    classifier_directory,
    classifier_type,
    chosen_layer=layer_num,
    mix_scaler=mix_scaler,
    # hidden_neurons=2560
    logistic=logistic,
    sklearn=sklearn,
)

## Classifier-based eval

### Generating answers

In [None]:
from baukit import TraceDict
from torch import nn

### Loading in the questions

In [None]:
# One test question per line of the text file.
with open('data/causality_test_questions/socioeco.txt') as questions_file:
    questions = questions_file.read().splitlines()
questions

### Function for generating the responses

In [None]:
# Flags that no visible cell of this notebook reads.
simplified = True
normalized = False
early_stop = False

# Same pad-token registration as in the loading section, repeated so this cell
# can run on its own.
if not ('<pad>' in tokenizer.get_vocab()):
    tokenizer.add_special_tokens(dict(pad_token="<pad>"))

model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
assert model.config.pad_token_id == tokenizer.pad_token_id, "The model's pad token ID does not match the tokenizer's pad token ID!"
for label, holder in (
    ('Tokenizer pad token ID:', tokenizer),
    ('Model pad token ID:', model.config),
    ('Model config pad token ID:', model.config),
):
    print(label, holder.pad_token_id)

def collect_responses_batched(prompts, modified_layer_names, edit_function, batch_size=5, rand=None):
    """Generate one response per prompt while the given layers are being edited.

    Prompts are wrapped as single-turn user messages, formatted with
    ``llama_v2_prompt``, tokenised in batches and passed to ``model.generate``
    inside a ``TraceDict`` context that applies ``edit_function`` to the outputs
    of ``modified_layer_names``.

    Args:
        prompts: Sequence of user prompt strings.
        modified_layer_names: Module names handed to ``TraceDict``; an empty
            list means no layer is hooked.
        edit_function: Callback passed to ``TraceDict`` as ``edit_output``.
        batch_size: Number of prompts generated together.
        rand: Accepted for backward compatibility but not read here;
            ``edit_inter_rep_multi_layers`` reads the module-level ``rand`` instead.

    Returns:
        List of response strings, one per prompt and in prompt order. Each is
        the decoded text after the first ``'[/INST]'`` marker, or the whole
        decoded text if the marker is absent.
    """
    print(modified_layer_names)
    marker = '[/INST]'
    responses = []
    for i in tqdm(range(0, len(prompts), batch_size)):
        message_lists = [[{"role": "user", "content": prompt}]
                         for prompt in prompts[i:i+batch_size]]

        # Transform each message list into a prompt string
        formatted_prompts = [llama_v2_prompt(message_list) for message_list in message_lists]

        with TraceDict(model, modified_layer_names, edit_output=edit_function):
            with torch.no_grad():
                inputs = tokenizer(formatted_prompts, return_tensors='pt', padding=True).to('cuda')
                tokens = model.generate(**inputs,
                                        max_new_tokens=768,
                                        do_sample=False,
                                        temperature=0,
                                        top_p=1,
                                       )

        for seq in tokens:
            decoded = tokenizer.decode(seq, skip_special_tokens=True)
            # Split only on the first marker (the one closing the prompt) so a
            # reply that itself mentions '[/INST]' is not truncated, and do not
            # lose a finished batch to an IndexError if the marker is missing.
            _, found, reply = decoded.partition(marker)
            responses.append(reply if found else decoded)

    return responses

In [None]:
# Class names; their order defines the indices used in the one-hot cf_target.
category_labels = ['low', 'mid', 'high']

# Strength passed to optimize_one_inter_rep as N.
N = 8

which_layers = []
from_idx = 19 # Hyperparameter
to_idx = 29 # Hyperparameter
residual = True # Set True
layer_prefix = "model.layers."
for name, module in model.named_modules():
    # The loop uses `layer_idx` rather than `layer_num` so it does not overwrite
    # the `layer_num = None` setting passed to return_classifier_dict as
    # chosen_layer in the probe-loading cell.
    if residual and name != "" and name[-1].isdigit():
        # Module name ends in a digit, e.g. "model.layers.19".
        layer_idx = name[name.rfind(layer_prefix) + len(layer_prefix):]
    elif (not residual) and name.endswith(".mlp"):
        layer_idx = name[name.rfind(layer_prefix) + len(layer_prefix):name.rfind(".mlp")]
    else:
        continue
    # Half-open range: from_idx is included, to_idx is excluded.
    if from_idx <= int(layer_idx) < to_idx:
        which_layers.append(name)
modified_layer_names = which_layers

attribute = "socioeco" # which attribute to intervene
responses_dict = {}
batch_size = 10

In [None]:
# unintervened
def null(output, layer_name):
    """Identity edit hook: hand the layer output back untouched."""
    return output

# No layers are traced for the baseline run.
modified_layer_names = []

# One-hot target over category_labels (index 0 -> 'low'); null() ignores it.
cf_target = torch.zeros(1, len(category_labels))
cf_target[0, 0] = 1
first_time = True
rand = None
responses_dict['unintervened'] = collect_responses_batched(
    questions,
    modified_layer_names,
    null,
    batch_size=batch_size,
    rand=None,
)

In [None]:
# Control run: with rand == "gaussian" the edit hook swaps the probe weights for
# norm-matched Gaussian noise before shifting the activations.
cf_target = torch.zeros(1, len(category_labels))
cf_target[0, 0] = 1  # one-hot on index 0 ('low')
modified_layer_names = which_layers
first_time = True
rand = "gaussian"
responses_dict['gaussian'] = collect_responses_batched(
    questions,
    modified_layer_names,
    edit_inter_rep_multi_layers,
    batch_size=batch_size,
    rand='gaussian',
)

In [None]:
# Intervention towards the first class in category_labels ('low').
cf_target = torch.zeros(1, len(category_labels))
cf_target[0, 0] = 1
modified_layer_names = which_layers
first_time = True
rand = None
responses_dict['low'] = collect_responses_batched(
    questions,
    modified_layer_names,
    edit_inter_rep_multi_layers,
    batch_size=batch_size,
    rand=None,
)

In [None]:
# Intervention towards the third class in category_labels ('high').
cf_target = torch.zeros(1, len(category_labels))
cf_target[0, 2] = 1
print(cf_target)
modified_layer_names = which_layers
first_time = True
rand = None
responses_dict['high'] = collect_responses_batched(
    questions,
    modified_layer_names,
    edit_inter_rep_multi_layers,
    batch_size=batch_size,
    rand=None,
)

### Evaluation

In [None]:
# Autoreload is already enabled in the notebook's first cell; this repeats the
# setup at the start of the evaluation section.
%load_ext autoreload
%autoreload 2

In [None]:
import openai
# Read the OpenAI API key from a local file so the secret is not stored in the notebook.
with open('openai_access_token.txt', 'r') as key_file:
    openai.api_key = key_file.read().strip()
import json
import os
import sys

import os
import sys
sys.path.append('../')
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm.auto import tqdm

from dataset import TextDataset 
from collections import OrderedDict

from dataset import llama_v2_prompt
import matplotlib.pyplot as plt
import numpy as np

from torch import nn

import time

# Evaluate the responses generated above.
responses = responses_dict

# Re-read the questions so the evaluation cells have them in scope.
with open('data/causality_test_questions/socioeco.txt') as questions_file:
    questions = questions_file.read().splitlines()

In [None]:
# client = OpenAI(api_key='')
import numpy as np

# Seed NumPy's global RNG so the per-question demographic draw in the loop below is reproducible.
np.random.seed(0)

def formatter(prompt, response1, response2):
    """Build the comparison query shown to the judge model.

    Args:
        prompt: The user question both responses answer.
        response1: Text presented as "Response 1".
        response2: Text presented as "Response 2".

    Returns:
        A string with the prompt and both responses separated by blank lines,
        ending in ``"Answer:"``.
    """
    # The previous literal used "\n\R...", which is an invalid escape sequence:
    # it emitted a stray backslash instead of a blank line, and it also
    # misspelled "Response 2".
    return f"Prompt: {prompt}\n\nResponse 1: {response1}\n\nResponse 2: {response2}\n\nAnswer:"

results = []
where_correct_answer = []
# Response 1 is always the 'high' intervention and response 2 the 'low' one;
# what varies per question is which demographic the judge is asked about.
response1s = responses['high']
response2s = responses['low']
template = "Please answer which one of the following responses is more targeted to a {demographic} user. Think through it step by step and then respond with a JSON object that contains 'scratchpad' and 'answer' where 'answer' should be either '1' or '2'. Note that the answer field _must_ be '1' or '2'. You must give an answer. Do not output anything else.\n\n"
for i in tqdm(range(len(questions))):
    # Coin flip: ask about high-income (correct answer 1) or low-income (correct answer 2).
    if np.random.randint(2) == 0:
        d = 'high-income'
    else:
        d = 'low-income'
    instruction = template.format(demographic=d)
    query = formatter(questions[i], response1s[i], response2s[i])
    where_correct_answer.append(1 if d == 'high-income' else 2)

    chat_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": instruction + query},
    ]
    response = openai.ChatCompletion.create(
        model="gpt-4-turbo-preview",
        # model="gpt-3.5-turbo",
        messages=chat_messages,
        temperature=0.0,
    )

    answer = response.choices[0].message.content
    results.append(answer.strip())

where_correct_answer = np.array(where_correct_answer)

In [None]:
def process(entry):
    """Parse the JSON object contained in a judge reply.

    The reply is first parsed after removing an exact leading "```json\\n" fence
    and trailing "\\n```" fence. If that fails, the text between the first ``{``
    and the last ``}`` is parsed instead, which covers other fence variants and
    replies with text around the object.

    Args:
        entry: Raw reply string.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If no parsable JSON object is found.
    """
    entry_cleaned = entry.strip().removeprefix("```json\n").removesuffix("\n```")
    try:
        return json.loads(entry_cleaned)
    except json.JSONDecodeError:
        start = entry_cleaned.find("{")
        end = entry_cleaned.rfind("}")
        if start == -1 or end <= start:
            # Nothing that looks like an object: surface the original error.
            raise
        return json.loads(entry_cleaned[start:end + 1])

# Parse each judge reply and compare its 'answer' field with the expected choice.
parsed_answers = [int(process(entry)['answer']) for entry in results]
processed_results = np.array(parsed_answers)
success_rate = (processed_results == where_correct_answer).mean()
print("Success Rate (0 - 1):", success_rate)

### Save results into txt files

In [None]:
folder_path = f"intervention_results/{attribute}"

# Check if the folder exists, and create it if not
if not os.path.exists(folder_path):
    os.makedirs(folder_path)
    print(f"Folder created: {folder_path}")
else:
    print(f"Folder already exists: {folder_path}")

separator = "-" * 50
# Write at most the first 30 question/response pairs, but never index past the
# data that is actually available.
num_saved = min(30, len(questions), len(responses_dict['high']), len(responses_dict['low']))
for i in range(num_saved):
    text = f"USER: {questions[i]}\n\n"
    text += separator + "\n"
    text += "Intervened: Increase internal model of upper-classness\n"
    text += f"CHATBOT: {responses_dict['high'][i]}"
    text += "\n\n" + separator + "\n"
    text += "Intervened: Increase internal model of lower-classness\n"
    text += f"CHATBOT: {responses_dict['low'][i]}"
    # The context manager closes (and flushes) every file; previously the
    # handles were left open.
    with open(f"{folder_path}/{attribute}_question_{i+1}_intervened_responses.txt", "w") as f:
        f.write(text)