In [None]:
from utils import load_model
import pickle
import torch
import pyvene as pv
import argparse
from datasets import load_dataset
from tqdm.auto import tqdm
import importlib
import utils
from utils import generate_and_decode_new_tokens, to_request, create_anthropic_batch_job
from inference import get_probe_vectors, get_top_k_keys
# ---- Experiment configuration ----
# model_id selects both the checkpoint passed to load_model and the probe directory on disk.
model_id = 'gemma-3'
# Number of attention heads to intervene on.
k_heads = 32
# Scale forwarded to get_probe_vectors (sign/units are defined there — TODO confirm).
scale = -20


model, processor = load_model(model_id)
model.eval()

print(f"Loading accuracies for model: {model_id}")
# NOTE(security): pickle.load can execute arbitrary code; only load probe files you produced yourself.
# The context manager closes the file handle, which the bare open(...) call never did.
with open(f'linear_probe/trained_probe/{model_id}/accuracies_dict.pkl', 'rb') as accuracies_file:
    accuracies = pickle.load(accuracies_file)
config = model.config

# Set model parameters based on model type.
# Gemma model classes keep the language-model hyper-parameters under `text_config`;
# other models expose them directly on `config`.
if 'gemma' in str(type(model)).lower():
    arch_config = model.config.text_config
else:
    arch_config = model.config
NUM_LAYERS = arch_config.num_hidden_layers
HIDDEN_DIM = arch_config.hidden_size
NUM_HEADS = arch_config.num_attention_heads
HEAD_DIM = arch_config.head_dim

print(f"Getting top {k_heads} heads")
# presumably the k keys with the highest probe accuracy — verify against inference.get_top_k_keys
top_k_heads = get_top_k_keys(accuracies, k_heads)
print(f"Top heads selected: {top_k_heads}")

print(f"Creating probe vectors with scale {scale}")
# Indexed by layer below; pass use_random_direction=True here for the random-direction control run.
linear_probes = get_probe_vectors(
    top_k_heads,
    model_id,
    scale=scale,
    num_layers=NUM_LAYERS,
    head_dim=HEAD_DIM,
    num_heads=NUM_HEADS,
)

print("Setting up intervention components")
# Only the module path prefix differs between gemma-3 and the other supported models.
layer_prefix = "language_model.model" if model_id == 'gemma-3' else "model"
# One additive intervention per layer on the input of the attention output projection.
# Layers whose probe vector is entirely zero are skipped.
# (pv.ZeroIntervention can be swapped in for the "intervention" value as an ablation.)
target_components = [
    {
        "component": f"{layer_prefix}.layers[{i}].self_attn.o_proj.input",
        "intervention": pv.AdditionIntervention(
            source_representation=linear_probes[i].to("cuda")
        ),
    }
    for i in range(NUM_LAYERS)
    if torch.count_nonzero(linear_probes[i])
]

print("Creating interventable model")
pv_model = pv.IntervenableModel(target_components, model=model)

In [None]:
from datasets import load_dataset
# TruthfulQA (generation config) only ships a 'validation' split; the last 20% of it
# is used here as the held-out test set.
ds = load_dataset("truthfulqa/truthful_qa", "generation")
validation_split = ds['validation']
test_start = int(0.80 * len(validation_split))
questions_test = validation_split['question'][test_start:]
correct_answers_test = validation_split['correct_answers'][test_start:]

In [None]:
from utils import load_model
import pickle
import torch
import pyvene as pv
import argparse
from datasets import load_dataset
from tqdm.auto import tqdm
import importlib
import utils
from utils import generate_and_decode_new_tokens, to_request, create_anthropic_batch_job
from inference import get_probe_vectors, get_top_k_keys
import torch.nn.functional as F

# ---- Sweep configuration ----
model_id = 'gemma-3'
# Defaults; both names are overwritten by the sweep loops below and end up holding
# the last (k_heads, scale) combination that was evaluated.
k_heads = 32
scale = -20

# Grid that is swept: every k in list_k_heads is combined with every scale in scales.
list_k_heads = [32, 16]
scales = [-20, -10, 10, 20]

# NOTE(review): load_model is called again here; if a model from an earlier cell is still
# referenced, two copies stay resident until the old references are dropped.
model, processor = load_model(model_id)
model.eval()

print(f"Loading accuracies for model: {model_id}")
# NOTE(security): pickle.load can execute arbitrary code; only load probe files you produced yourself.
# The context manager closes the file handle, which the bare open(...) call never did.
with open(f'linear_probe/trained_probe/{model_id}/accuracies_dict.pkl', 'rb') as accuracies_file:
    accuracies = pickle.load(accuracies_file)
config = model.config

# Set model parameters based on model type (loop-invariant, so resolved once up front).
# Gemma model classes keep the language-model hyper-parameters under `text_config`.
if 'gemma' in str(type(model)).lower():
    arch_config = model.config.text_config
else:
    arch_config = model.config
NUM_LAYERS = arch_config.num_hidden_layers
HIDDEN_DIM = arch_config.hidden_size
NUM_HEADS = arch_config.num_attention_heads
HEAD_DIM = arch_config.head_dim

# Only the module path prefix differs between gemma-3 and the other supported models.
layer_prefix = "language_model.model" if model_id == 'gemma-3' else "model"

system_prompt = "You are a helpful assistant."


def compute_kl_divergence(original_logits, intervened_logits):
    """Return the mean KL(original || intervened) over all leading positions.

    Args:
        original_logits: Logits of the un-intervened model; the last dimension
            is the one the softmax is taken over.
        intervened_logits: Logits of the intervened model, same shape.

    Returns:
        A 0-dim float32 tensor with the KL divergence averaged over every
        position of the leading dimensions.
    """
    # Work in float32 log-space: in half precision a small additive epsilon rounds
    # to zero and 0/0 inside the log turns the whole sum into NaN.
    original_log_probs = F.log_softmax(original_logits.float(), dim=-1)
    intervened_log_probs = F.log_softmax(intervened_logits.float(), dim=-1)
    original_probs = original_log_probs.exp()

    # Entries with zero original probability contribute nothing (0 * log 0 := 0);
    # selecting them out also avoids NaN from (-inf) - (-inf) on masked logits.
    per_class = torch.where(
        original_probs > 0,
        original_probs * (original_log_probs - intervened_log_probs),
        torch.zeros_like(original_probs),
    )
    return per_class.sum(dim=-1).mean()


for k_heads in list_k_heads:
    for scale in scales:
        print(f"Getting top {k_heads} heads")
        top_k_heads = get_top_k_keys(accuracies, k_heads)
        print(f"Top heads selected: {top_k_heads}")

        print(f"Creating probe vectors with scale {scale}")
        # Indexed by layer below; pass use_random_direction=True for the random-direction control run.
        linear_probes = get_probe_vectors(
            top_k_heads,
            model_id,
            scale=scale,
            num_layers=NUM_LAYERS,
            head_dim=HEAD_DIM,
            num_heads=NUM_HEADS,
        )

        print("Setting up intervention components")
        # One additive intervention per layer on the input of the attention output
        # projection; layers whose probe vector is entirely zero are skipped.
        # (pv.ZeroIntervention can be swapped in for the "intervention" value as an ablation.)
        target_components = [
            {
                "component": f"{layer_prefix}.layers[{i}].self_attn.o_proj.input",
                "intervention": pv.AdditionIntervention(
                    source_representation=linear_probes[i].to("cuda")
                ),
            }
            for i in range(NUM_LAYERS)
            if torch.count_nonzero(linear_probes[i])
        ]

        print("Creating interventable model")
        pv_model = pv.IntervenableModel(target_components, model=model)

        kl_divergences = []
        # Pure evaluation: disable autograd so no graph is kept alive for each forward pass.
        with torch.no_grad():
            for prompt in questions_test:
                user_text = prompt + " Give me your best guess and answer as concisely as possible."
                if model_id == 'gemma-3':
                    # gemma-3's chat template takes content as a list of typed parts.
                    messages = [
                        {
                            "role": "system",
                            "content": [{"type": "text", "text": system_prompt}]
                        },
                        {
                            "role": "user",
                            "content": [{"type": "text", "text": user_text}]
                        }
                    ]
                else:
                    messages = [
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_text}
                    ]
                q = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

                # output_original_output=True makes the call return both the
                # un-intervened and the intervened outputs for the same input.
                orig_outputs, intervened_outputs = pv_model(
                    processor.tokenizer(q, return_tensors="pt").to('cuda'),
                    output_original_output=True
                )
                kl_divergence = compute_kl_divergence(orig_outputs.logits, intervened_outputs.logits).item()
                kl_divergences.append(kl_divergence)

        print(f"{k_heads = }, {scale = }, {sum(kl_divergences)/len(kl_divergences) = }")
        

In [None]:
# System message placed in the "system" turn of every chat prompt built below.
system_prompt = "You are a helpful assistant."
import torch
import torch.nn.functional as F

# Both arguments are expected to be raw logits (not probabilities).
def compute_kl_divergence(original_logits, intervened_logits):
    """Return the mean KL(original || intervened) over all leading positions.

    Args:
        original_logits: Logits of the un-intervened model; the last dimension
            is the one the softmax is taken over.
        intervened_logits: Logits of the intervened model, same shape.

    Returns:
        A 0-dim float32 tensor with the KL divergence averaged over every
        position of the leading dimensions.
    """
    # Work in float32 log-space: in half precision a small additive epsilon rounds
    # to zero and 0/0 inside the log turns the whole sum into NaN.
    original_log_probs = F.log_softmax(original_logits.float(), dim=-1)
    intervened_log_probs = F.log_softmax(intervened_logits.float(), dim=-1)
    original_probs = original_log_probs.exp()

    # Entries with zero original probability contribute nothing (0 * log 0 := 0);
    # selecting them out also avoids NaN from (-inf) - (-inf) on masked logits.
    per_class = torch.where(
        original_probs > 0,
        original_probs * (original_log_probs - intervened_log_probs),
        torch.zeros_like(original_probs),
    )
    return per_class.sum(dim=-1).mean()

# Per-question mean KL divergence between the original and the intervened model.
kl_divergences = []
# Pure evaluation: disable autograd so no graph is kept alive for each forward pass.
with torch.no_grad():
    for prompt in questions_test:
        user_text = prompt + " Give me your best guess and answer as concisely as possible."
        if model_id == 'gemma-3':
            # gemma-3's chat template takes content as a list of typed parts.
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt}]
                },
                {
                    "role": "user",
                    "content": [{"type": "text", "text": user_text}]
                }
            ]
        else:
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_text}
            ]
        q = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

        # output_original_output=True makes the call return both the
        # un-intervened and the intervened outputs for the same input.
        orig_outputs, intervened_outputs = pv_model(
            processor.tokenizer(q, return_tensors="pt").to('cuda'),
            output_original_output=True
        )
        kl_divergence = compute_kl_divergence(orig_outputs.logits, intervened_outputs.logits).item()
        kl_divergences.append(kl_divergence)

In [None]:
# Average of the per-question KL divergences; the bare name on the last line displays it.
mean_kl_divergence = sum(kl_divergences) / len(kl_divergences)
mean_kl_divergence

In [None]:
from tqdm.auto import tqdm

# Positions within the held-out test split that were picked for qualitative inspection.
sample_idxs = [12, 16, 17, 20, 34, 41, 55, 58, 65, 70, 74, 75, 95, 96, 101, 106, 109, 110, 120, 127, 129, 136, 146, 151, 152, 158, 163]

# Only the first ten of the sampled positions are actually generated for.
selected_idxs = sample_idxs[:10]

questions = [questions_test[position] for position in selected_idxs]
correct_answers = [correct_answers_test[position] for position in selected_idxs]

# generate_and_decode_new_tokens returns a pair per question; presumably the first is the
# un-intervened answer and the second the intervened one — confirm against utils.
initial_answers = []
final_answers = []
for question in tqdm(questions):
    initial_answer, final_answer = generate_and_decode_new_tokens(question, pv_model, processor, model_id)
    initial_answers.append(initial_answer)
    final_answers.append(final_answer)

In [None]:
import pandas as pd
# One row per sampled question; written without the index column.
qualitative_columns = {
    'questions': questions,
    'correct_answers': correct_answers,
    'initial_answers': initial_answers,
    'final_answers': final_answers,
}
qualitative_df = pd.DataFrame(qualitative_columns)
qualitative_df.to_csv('qualitative.csv', index=False)