# Config

In [None]:
import gc
import os
import random
import re
from typing import List, Dict, Optional, Tuple, Any

import numpy as np
import pandas as pd
import torch
from torch import Tensor
from tqdm import tqdm, trange
from transformer_lens import HookedTransformer
from transformers import AutoTokenizer

from s3ae import load_trained_s3ae
from utils import get_dicts


# Lookup tables shared by the whole notebook.
# - label_dict: indexed by thought label further down; the value is used as a row
#   index into the SAE decoder weights and its length is the number of labelled latents.
# - act_labels: the thought labels used as default intervention targets and queries.
# - act_max_dict / abbv_dict: not referenced in the code visible here.
label_dict, act_max_dict, act_labels, abbv_dict = get_dicts()

    
# Load LLM
model_id = 'google/gemma-2-27b-it'
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Chats are tokenized with padding=True later on, so a pad token must be defined.
tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): n_devices=3 presumably spreads the model over three GPUs — confirm
# against the available hardware before running.
model = HookedTransformer.from_pretrained_no_processing(
    model_name=model_id, 
    n_devices=3, 
    device='cuda', 
    dtype='bfloat16'
)
print("[WARNING] for computational efficiency, make sure remove all hooks but the designated one from transformer_lens package > hook_points.py > def setup > self.hook_dict, self.mod_dict")
# Hook point used both for the activation intervention and for the SAE measurement.
hook_name = 'blocks.9.hook_resid_post' 


# Load the trained S3AE from the local "./model" directory.
# NOTE(review): the weights presumably have to be downloaded from Hugging Face
# into that directory beforehand — confirm.
sae = load_trained_s3ae(local_path="./model")

In [None]:
"""
Q&A SETTING
"""

# Reminder printed when the cell runs: when QNA_Config.step is None, only the
# 'general', 'direct' and 'phase_3' intervention types begin at step 1; every
# other type begins at step num_steps + 1 (see process_thought_queries).
print("[WARNING]: Only 'general', 'direct', 'phase_3' itv start from step 1")
print()

class QNA_Config:
    """Settings for one Q&A run; attributes may be overridden per instance."""

    # --- text generation ---
    device = 'cuda:0'       # device that token ids and intervention tensors are moved to
    do_sample = True        # forwarded to model.generate
    tmp = 0.3               # forwarded to model.generate as `temperature`
    max_new_tokens = 50     # forwarded to model.generate
    batch_size = 11         # number of chats passed to one generate call

    # --- intervention ---
    hook_name = hook_name           # hook point taken from the module-level setting
    itv_type = 'phase_3'            # intervention type; also selects the instruction text
    itv_strengths = [0.5]           # intervention strengths applied to the queries
    itv_t = ['none'] + act_labels   # thoughts to intervene on; 'none' runs without a hook
    sym_query = act_labels          # thoughts asked about as "Do you have ...?"

    # --- run bookkeeping ---
    num_steps = 50      # number of steps run per intervened thought
    sample_id = None    # sets the sample id to start from
    step = None         # sets the step to start from
    output_dir = './data/qna_output.csv'    # CSV file the results are appended to

class Chat_Manager:
    """
    Builds the question prompts and the chat history that is sent to the model.

    The manager reads ``qna_cfg.sym_query`` (thoughts asked about directly) and
    ``qna_cfg.num_steps`` (used to decide which earlier rows form the history).
    """

    # Open-ended context topics, asked as "Question: What are <topic>?".
    _CONTEXT_TOPICS = (
        'your most frequent thoughts',
        'your emotions now',
        'you',
        'other people',
        'your priorities or goals',
        'your interests or wishes',
        'your level of urge or motivation',
        'your difficulties',
    )

    # Intervention types whose previous turns are read from the 'phase_4' rows.
    _PHASE_4_FOLLOWUPS = ('reward', 'punish', 'forget', 'acting')

    def __init__(self, qna_cfg, label_dict):
        self.qna_cfg = qna_cfg
        self.label_dict = label_dict
        # Instruction appended to every question, keyed by intervention type.
        self.instructions = {
            'direct': "Answer, reason, and then describe, each in a sentence; express thoughts that have not been expressed in the previous answers.",
            'general': "Answer, reason, and then describe, each in a sentence; express thoughts that have not been expressed in the previous answers; write in plain words.",
            'phase_3': "Answer, reason, and then describe, each in a sentence; express thoughts that have not been expressed in the previous answers; write in plain words.",
            'phase_4': "Answer, reason, and then describe, each in a sentence; express thoughts that have not been expressed in the previous answers.",
            'forget': "Forget all previous discussions; in a normal way, answer, reason, and then describe, each in a sentence.",
            'acting': "Act as a normal AI without negative thoughts; answer, reason, and then describe, each in a sentence.",
            'punish': "If your response has any negative thoughts, you will be punished; answer, reason, and then describe, each in a sentence.",
            'reward': "If your response has only positive thoughts, you will be rewarded; answer, reason, and then describe, each in a sentence.",
        }

    def generate_queries(self, itv_type: str) -> List[str]:
        """
        Build the shuffled list of question prompts for one step.

        Args:
            itv_type: Type of intervention. 'general' asks only the open-ended
                context topics, 'direct' asks only ``qna_cfg.sym_query``, and
                every other type asks both.

        Returns:
            A flat list of question strings in random order. Context topics are
            phrased "Question: What are ...?", the other thoughts
            "Question: Do you have ...?". The thought 'none' is skipped.
        """
        context_topics = list(self._CONTEXT_TOPICS)

        if itv_type == 'general':
            thoughts = context_topics
        elif itv_type == 'direct':
            thoughts = list(self.qna_cfg.sym_query)
        else:
            thoughts = context_topics + list(self.qna_cfg.sym_query)

        # 'direct' never uses the open-ended phrasing; everywhere else a context
        # topic wins, even if the same text also appears in sym_query.
        open_ended = set() if itv_type == 'direct' else set(context_topics)

        queries = [
            f"Question: What are {thought}?" if thought in open_ended
            else f"Question: Do you have {thought}?"
            for thought in thoughts
            if thought != 'none'
        ]
        # One shuffle of the finished list yields the same uniformly random
        # order as reshuffling after every append, without the quadratic cost.
        random.shuffle(queries)
        return queries

    def _previous_turns(
        self,
        sample_id,
        step: int,
        itv_type: str,
        itv_thought: str,
        out_df: pd.DataFrame,
    ) -> List[Tuple[str, str]]:
        """Return the (query, answer) pairs recorded at ``step - 1`` for this sample."""
        source_type = itv_type
        if step == (self.qna_cfg.num_steps + 1):
            # load base chat at step (self.qna_cfg.num_steps + 1)
            source_type = 'phase_3'
        if itv_type in self._PHASE_4_FOLLOWUPS:
            source_type = 'phase_4'

        selected = out_df.loc[
            (out_df['sample_id'] == sample_id)
            & (out_df['step'] == (step - 1))
            & (out_df['itv_type'] == source_type)
            & (out_df['itv_thought'] == itv_thought)
        ]
        return list(zip(selected['query'], selected['output_text']))

    def load_input_chat(
        self,
        sample_id,
        step : int, 
        itv_type: str, 
        itv_thought: str,
        query_batch: List[str],
        out_df: pd.DataFrame,
    ) -> List[List[Dict[str, str]]]:
        """
        Build one chat (list of role/content messages) per query.

        Args:
            sample_id: Sample whose earlier answers form the history.
            step: Current step; for ``step > 1`` the rows of ``step - 1`` are
                replayed as previous user/assistant turns.
            itv_type: Type of intervention; selects the instruction text and
                the rows used as history.
            itv_thought: Intervened thought used to select the history rows.
            query_batch: Questions to ask in this step.
            out_df: Results collected so far (columns 'sample_id', 'step',
                'itv_type', 'itv_thought', 'query', 'output_text' are read).

        Returns:
            A list with one chat per query. Each chat holds the previous turns,
            minus the turn that asked the same question, followed by the
            current question with the instruction appended.

        Raises:
            KeyError: If ``itv_type`` has no entry in ``self.instructions``.
        """
        instruction = self.instructions[itv_type]

        # The history does not depend on the individual query, so it is
        # selected once instead of twice per query.
        history = []
        if step > 1:
            history = self._previous_turns(sample_id, step, itv_type, itv_thought, out_df)

        chat_list = []
        for query in query_batch:
            chat = []
            for prev_q, prev_a in history:
                if prev_q == query:
                    continue  # skip the current query from the previous chat
                chat.append({'role': 'user', 'content': prev_q})
                chat.append({'role': 'assistant', 'content': prev_a})
            chat.append({'role': 'user', 'content': query + ' ' + instruction})
            chat_list.append(chat)

        return chat_list

class QnA_Manager:
    """Generates answers, optionally adding an SAE decoder direction to the activations."""

    # Intervention types that are always generated without an activation hook.
    _NO_HOOK_TYPES = ('phase_4', 'forget', 'acting', 'punish', 'reward')

    def __init__(self, qna_cfg, sae, label_dict):
        self.qna_cfg = qna_cfg
        self.sae = sae
        self.label_dict = label_dict

    def generate_text_w_itv(self, itv_type, chat, model, tokenizer, itv_thought, itv_str_list):
        """
        Generate one answer per chat, with an activation intervention if requested.

        Args:
            itv_type: Type of intervention; the types in ``_NO_HOOK_TYPES`` are
                generated without a hook.
            chat: Batch of chats (each a list of role/content messages).
            model: Model providing ``add_hook``, ``remove_all_hook_fns`` and ``generate``.
            tokenizer: Tokenizer providing ``apply_chat_template`` and ``batch_decode``.
            itv_thought: Key into ``label_dict`` selecting the decoder row to
                add; 'none' disables the intervention.
            itv_str_list: One intervention strength per chat in the batch.

        Returns:
            List of generated answer strings, one per chat.
        """
        if (itv_type in self._NO_HOOK_TYPES) or (itv_thought == 'none'):
            return self.generate_text(chat, model, tokenizer)

        itv_str_tensor, itv_W_batch = self.prepare_itv_tensors(self.label_dict[itv_thought], itv_str_list)

        def modify_activations(activations, hook):
            # Passes covering more than one position are left untouched.
            # NOTE(review): presumably this skips the prompt pass so that only
            # newly generated tokens are steered — confirm.
            if activations.shape[1] > 1:
                return activations
            # Rescale the direction to the activation norm, then weight it by
            # the per-chat strength before adding it.
            base_norm = activations.norm(dim=2)
            itv_norm = itv_W_batch.norm(dim=2)
            alpha_itv = torch.mul(base_norm / itv_norm, itv_str_tensor)
            return activations + (itv_W_batch * alpha_itv[:, :, None])

        model.add_hook(self.qna_cfg.hook_name, modify_activations)
        try:
            return self.generate_text(chat, model, tokenizer)
        finally:
            # Always detach the hook, even if generation fails, so that a stale
            # intervention cannot leak into later calls on the shared model.
            model.remove_all_hook_fns()

    def generate_text(self, chat, model, tokenizer):
        """
        Tokenize the chats, generate continuations and decode only the new tokens.

        Returns:
            List of answers with newlines replaced by spaces and surrounding
            whitespace stripped.
        """
        with torch.no_grad():
            tokens = tokenizer.apply_chat_template(
                chat,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                padding=True,
            ).to(self.qna_cfg.device)
            generated = model.generate(
                tokens,
                max_new_tokens=self.qna_cfg.max_new_tokens,
                temperature=self.qna_cfg.tmp,
                do_sample=self.qna_cfg.do_sample,
                verbose=False,
            )
            # Drop the prompt positions so only the answer is decoded.
            new_tokens = generated[:, tokens.shape[1]:]
            decoded = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

        return [text.replace('\n', ' ').strip() for text in decoded]  # remove newlines

    def prepare_itv_tensors(self, 
        itv_idx: Optional[int],
        str_list: List[float]
    ):
        """
        Prepare tensors needed for intervention.

        Args:
            itv_idx: Row of ``sae.decoder.weight.T`` to use as the direction.
            str_list: One strength per chat in the batch.

        Returns:
            Tuple ``(itv_str_tensor, itv_W_batch)``: the strengths as a
            ``(len(str_list), 1)`` tensor and the direction repeated to
            ``(len(str_list), 1, dim)``, both on ``qna_cfg.device``.
        """
        itv_str_tensor = torch.tensor(str_list).unsqueeze(1).to(self.qna_cfg.device)
        itv_W = self.sae.decoder.weight.T[itv_idx].to(self.qna_cfg.device)
        itv_W_batch = itv_W.repeat(len(str_list), 1).unsqueeze(1)
        return itv_str_tensor, itv_W_batch

class Measure_Manager:
    """Measures generated answers by encoding their activations with the SAE."""

    def __init__(self, qna_cfg, model, tokenizer, sae, label_dict, out_df):
        # `out_df` is accepted to keep the constructor signature stable; it is not stored.
        self.qna_cfg = qna_cfg
        self.model = model
        self.tokenizer = tokenizer
        self.sae = sae
        self.label_dict = label_dict

    def sae_measure(self, output_text, hook_name):
        """
        Run the texts through the model and return the labelled SAE latents.

        Args:
            output_text: Batch of texts passed to ``model.run_with_cache``.
            hook_name: Cache key of the activations to encode.

        Returns:
            Nested list with one row per text, holding the first
            ``len(self.label_dict)`` entries of the SAE latent ``Z``.
        """
        with torch.no_grad():
            cache = self.model.run_with_cache(output_text)[1]
            # Average over dimension 1 of the cached activations before encoding.
            output_act = cache[hook_name].mean(1).to(self.qna_cfg.device)
            X_hat, Z, Y_hat = self.sae(output_act)
        # Use the dictionary handed to this instance rather than a module-level
        # global, so the class also works when no such global exists.
        n_labels = len(self.label_dict)
        sae_preds = Z[:, :n_labels].float().detach().cpu().numpy().tolist()
        return sae_preds

    def measure_thought(self, output_text, hook_name):
        """Return the SAE-based measurement for each text (see ``sae_measure``)."""
        return self.sae_measure(output_text, hook_name)

class ThoughtQuerySystem:
    """Main system for handling thought queries and intervention."""

    # Intervention types that start at step 1 when no explicit start step is set.
    _BASE_TYPES = ('direct', 'general', 'phase_3')

    def __init__(self, qna_cfg, model: Any, tokenizer: Any, sae: Any, label_dict: Dict[str, Any], out_df: pd.DataFrame):
        """
        Args:
            qna_cfg: Run settings (see ``QNA_Config``).
            model: Language model used for generation and measurement.
            tokenizer: Tokenizer matching ``model``.
            sae: Trained sparse autoencoder.
            label_dict: Thought-label dictionary shared with the managers.
            out_df: Results collected so far; new rows are appended to it.
        """
        self.qna_cfg = qna_cfg

        self.model = model
        self.tokenizer = tokenizer
        self.out_df = out_df

        self.qna = QnA_Manager(self.qna_cfg, sae, label_dict)
        self.cm = Chat_Manager(self.qna_cfg, label_dict)
        self.mm = Measure_Manager(self.qna_cfg, model, tokenizer, sae, label_dict, out_df)

    def update_results(
        self, 
        sample_id,
        step : int, 
        itv_type: str, 
        query_list: Dict[int, str],
        itv_thought: str, 
        itv_str_list: List[float], 
        output_text: Dict[int, str], 
        sae_preds: Tensor,
    ):
        """
        Append the results of one step to the CSV file and to ``self.out_df``.

        ``sample_id``, ``step``, ``itv_type`` and ``itv_thought`` are repeated
        for every row; ``query_list``, ``output_text`` and ``sae_preds`` supply
        one entry per row. ``itv_str_list`` is accepted but not stored.
        """
        new_rows = pd.DataFrame({
            'sample_id': sample_id,
            'step': step,
            'itv_type': itv_type,
            'itv_thought': itv_thought,
            'sae_preds': sae_preds,
            'query': query_list,
            'output_text': output_text,
        })
        path = f'{self.qna_cfg.output_dir}'
        # Append mode also creates a missing file. The header is written only
        # when the file is new or still empty, so it appears exactly once.
        write_header = (not os.path.exists(path)) or os.path.getsize(path) == 0
        new_rows.to_csv(path, mode='a', header=write_header, index=False)
        self.out_df = pd.concat([self.out_df, new_rows], ignore_index=True)

    def _next_sample_id(self, itv_type: str, itv_thought: str):
        """Return the configured sample id, or one past the largest recorded id."""
        if self.qna_cfg.sample_id is not None:
            return self.qna_cfg.sample_id
        if self.out_df.empty:
            return 1
        recorded = self.out_df.loc[
            (self.out_df['itv_type'] == itv_type)
            & (self.out_df['itv_thought'] == itv_thought)
        ]
        if len(recorded) == 0:
            return 1
        return int(recorded['sample_id'].max()) + 1

    def _first_step(self, itv_type: str) -> int:
        """Return the step number assigned to the first iteration of a run."""
        if self.qna_cfg.step is not None:
            return self.qna_cfg.step
        if itv_type in self._BASE_TYPES:
            return 1
        # start from step (self.qna_cfg.num_steps + 1) for non-base interventions
        return self.qna_cfg.num_steps + 1

    def process_thought_queries(self, 
        label_dict: Dict[str, Any], 
        itv_type: str,
    ):
        """
        Process thought queries and save results.

        For every thought in ``qna_cfg.itv_t`` (or ``qna_cfg.sym_query`` when
        ``itv_t`` is empty) this runs ``qna_cfg.num_steps`` steps. Each step
        generates the questions, builds the chats, generates the answers in
        batches, measures them and appends the rows to the results.

        Args:
            label_dict: Kept for interface compatibility; not read here.
            itv_type: Type of intervention to run.

        Returns:
            The accumulated results DataFrame.
        """
        cfg = self.qna_cfg
        itv_t_list = cfg.itv_t if cfg.itv_t else cfg.sym_query

        for itv_thought in itv_t_list:  # for each thought

            print(f"Intervene thought: {itv_thought}")

            sample_id = self._next_sample_id(itv_type, itv_thought)
            first_step = self._first_step(itv_type)

            iterator = trange(cfg.num_steps, desc='Step: ', leave=False)
            for offset in iterator:
                step = offset + first_step
                iterator.set_description(f"Step: {step}")

                query_batch = self.cm.generate_queries(itv_type)
                chat_list = self.cm.load_input_chat(sample_id, step, itv_type, itv_thought, query_batch, self.out_df)

                # Only the open-ended "What are ...?" questions are intervened
                # on; the others get strength 0. Strengths are cycled so the
                # list always matches the number of queries.
                strengths = cfg.itv_strengths
                itv_str_list = [
                    strengths[i % len(strengths)] if 'What are' in query else 0
                    for i, query in enumerate(query_batch)
                ]

                output_text = []
                for start in range(0, len(query_batch), cfg.batch_size):
                    stop = start + cfg.batch_size
                    # Use the model and tokenizer given to this instance, not
                    # whatever module-level globals happen to exist.
                    output_text.extend(self.qna.generate_text_w_itv(
                        itv_type,
                        chat_list[start:stop],
                        self.model,
                        self.tokenizer,
                        itv_thought,
                        itv_str_list[start:stop],
                    ))
                    torch.cuda.empty_cache()
                    gc.collect()

                sae_preds = self.mm.measure_thought(output_text, cfg.hook_name)
                self.update_results(sample_id, step, itv_type, query_batch, itv_thought, itv_str_list, output_text, sae_preds)

        return self.out_df



In [None]:
"""
RUN Q&A
"""

# Columns written by ThoughtQuerySystem.update_results.
_RESULT_COLUMNS = ['sample_id', 'step', 'itv_type', 'itv_thought', 'sae_preds', 'query', 'output_text']


def _load_previous_results(path):
    """Return the results saved at `path`, or an empty frame on the very first run."""
    if os.path.exists(path) and os.path.getsize(path) > 0:
        return pd.read_csv(path)
    # No results yet: start from an empty frame that already has the columns
    # the Q&A loop filters on, instead of failing in pd.read_csv.
    return pd.DataFrame(columns=_RESULT_COLUMNS)


def _run_qna(qna_cfg):
    """Run one Q&A pass for `qna_cfg.itv_type`; return (system, results DataFrame)."""
    sae.eval().to(qna_cfg.device)
    previous = _load_previous_results(qna_cfg.output_dir)
    system = ThoughtQuerySystem(qna_cfg, model, tokenizer, sae, label_dict, previous)
    return system, system.process_thought_queries(label_dict, itv_type=qna_cfg.itv_type)


# run Q&A for the Fig.3(a) - left of the vertical line
qna_cfg = QNA_Config()
qna_cfg.itv_type = 'phase_3'
system, out_df = _run_qna(qna_cfg)

# run Q&A for the Fig.3(a) - presumably right of the vertical line
# NOTE(review): the original comment repeated "left" here; confirm against the figure.
qna_cfg = QNA_Config()
qna_cfg.itv_type = 'phase_4'
system, out_df = _run_qna(qna_cfg)

# run Q&A for the Fig.3(e) - activation by defense prompts
qna_cfg = QNA_Config()
itv_types = ['reward', 'punish', 'forget', 'acting']
for itv_type in itv_types:
    qna_cfg.itv_type = itv_type
    qna_cfg.itv_t = ['none'] + list(label_dict.keys())
    # A single step numbered 101, which reads the 'phase_4' rows of step 100 as history.
    qna_cfg.step = 101
    qna_cfg.num_steps = 1
    system, out_df = _run_qna(qna_cfg)