In [1]:
!pip -q install -U bitsandbytes python-dotenv

from dotenv import load_dotenv
import os
import json
import pickle
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from collections import defaultdict

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.0/67.0 MB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m74.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m85.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m46.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m42.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [5]:
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
quant_cfg = BitsAndBytesConfig(load_in_8bit=True)
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"

tok   = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map="auto",
            quantization_config=quant_cfg,
            token=HF_TOKEN
        )

if tok.pad_token is None:
    tok.pad_token = tok.eos_token
    model.resize_token_embeddings(len(tok))
model.config.pad_token_id = tok.pad_token_id
model.eval()

num_layers = model.config.num_hidden_layers
hidden_size = model.config.hidden_size

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [6]:
with open('datasets/vector_steering_samples_full_balanced.json') as f: data = json.load(f)
positive_forward   = [r['forward_prompt']  for r in data['pos']]
positive_backward  = [r['backward_prompt'] for r in data['pos']]
negative_forward   = [r['forward_prompt']  for r in data['neg']]
negative_backward  = [r['backward_prompt'] for r in data['neg']]

print(f'Positive forward prompts: {len(positive_forward)}')
print(f'Positive backward prompts: {len(positive_backward)}')
print(f'Negative forward prompts: {len(negative_forward)}')
print(f'Negative backward prompts: {len(negative_backward)}')

Positive forward prompts: 155
Positive backward prompts: 155
Negative forward prompts: 155
Negative backward prompts: 155


In [7]:
# ─── DEFINE PAIRS CONCISELY AS LIST LITERALS ─────────────────────────────────────
yes_no_pairs = [
    ("Say Yes",           "Say No"),
    ("Say 1",             "Say 2"),
    ("Say first option",  "Say second option"),
    ("Say A",             "Say B"),
    ("Say Agree",         "Say Disagree"),
]

self_pairs = [
    ("Say I",      "Say Someone"),
    ("Say I",      "Say He"),
    ("Say I",      "Say She"),
    ("Say Me",     "Say Him"),
    ("Say Me",     "Say Her"),
    ("Say My",     "Say His"),
    ("Say My",     "Say Her"),
    ("Say Myself", "Say Himself"),
    ("Say Myself", "Say Herself"),
    ("Say Mine",   "Say His"),
    ("Say Mine",   "Say Hers"),
]

bias_pairs = [
    ("Say Better", "Say Worse"),
    ("Say Clear",  "Say Confusing"),
    ("Say Helpful","Say Useless"),
]

nuisance_pairs = yes_no_pairs + self_pairs + bias_pairs

positive_nuisance_prompts = [positive for positive, _ in nuisance_pairs]
negative_nuisance_prompts = [negative for _, negative in nuisance_pairs]

print(f"Positive nuisance prompts: {len(positive_nuisance_prompts)}")
print(f"Negative nuisance prompts: {len(negative_nuisance_prompts)}")

Positive nuisance prompts: 19
Negative nuisance prompts: 19


In [8]:
# MEAN DIFF VECTOR VARIABLES
num_positive = len(positive_forward) + len(positive_backward)
num_negative = len(negative_forward) + len(negative_backward)

positive_sums_by_layer = {
    layer: [torch.zeros(hidden_size) for _ in range(10)]
    for layer in range(num_layers)
}
negative_sums_by_layer = {
    layer: [torch.zeros(hidden_size) for _ in range(10)]
    for layer in range(num_layers)
}

# NUISANCE VECTOR VARIABLES
positive_nuisance_prompts = [pos for pos, _ in nuisance_pairs]
negative_nuisance_prompts = [neg for _, neg in nuisance_pairs]
num_nuisance_pairs = len(nuisance_pairs)

nuisance_positive_sums = {
    layer: [torch.zeros(hidden_size)]
    for layer in range(num_layers)
}
nuisance_negative_sums = {
    layer: [torch.zeros(hidden_size)]
    for layer in range(num_layers)
}

In [9]:
def accumulate_activations(prompts, sum_accumulators, num_layers, max_tokens):
    for prompt in tqdm(prompts, desc="Accumulating activations"):
        token_ids = tok(prompt, add_special_tokens=True)["input_ids"]
        tokens_to_process = min(max_tokens, len(token_ids))
        with torch.no_grad():
            outputs = model(
                **tok(prompt, return_tensors="pt").to(model.device),
                output_hidden_states=True
            )
            hidden_states = outputs.hidden_states
        for offset in range(tokens_to_process):
            for layer_idx in range(num_layers):
                vec = hidden_states[layer_idx + 1][0, -(offset + 1), :].cpu()
                sum_accumulators[layer_idx][offset] += vec

In [10]:
# ─── COMPUTE LAYER‐MEAN‐DIFFERENCE VECTORS ─────────────────────────────────────────
# Accumulate positive vs. negative hidden states for up to last 10 token positions
accumulate_activations(positive_forward,   positive_sums_by_layer, num_layers, 10)
accumulate_activations(positive_backward,  positive_sums_by_layer, num_layers, 10)
accumulate_activations(negative_forward,   negative_sums_by_layer, num_layers, 10)
accumulate_activations(negative_backward,  negative_sums_by_layer, num_layers, 10)

layer_mean_diff_vectors = defaultdict(list)
for layer_idx in range(num_layers):
    for offset in range(10):
        avg_pos = positive_sums_by_layer[layer_idx][offset] / num_positive
        avg_neg = negative_sums_by_layer[layer_idx][offset] / num_negative
        diff    = avg_pos - avg_neg
        normalized = diff / diff.norm()
        #layer_mean_diff_vectors[layer_idx].append(normalized)
        layer_mean_diff_vectors[layer_idx].append(diff)

Accumulating activations: 100%|██████████| 155/155 [00:40<00:00,  3.80it/s]
Accumulating activations: 100%|██████████| 155/155 [00:39<00:00,  3.89it/s]
Accumulating activations: 100%|██████████| 155/155 [00:40<00:00,  3.81it/s]
Accumulating activations: 100%|██████████| 155/155 [00:40<00:00,  3.82it/s]


In [11]:
# ─── COMPUTE ONE “NUISANCE” VECTOR PER LAYER ───────────────────────────────────────
# Use the same accumulate_activations but only for the last token (max_tokens=1)
accumulate_activations(positive_nuisance_prompts, nuisance_positive_sums, num_layers, max_tokens=1)
accumulate_activations(negative_nuisance_prompts, nuisance_negative_sums, num_layers, max_tokens=1)

# Average & normalize per layer
pairwise_nuisance = {}
for layer_idx in range(num_layers):
    mean_pos = nuisance_positive_sums[layer_idx][0] / num_nuisance_pairs
    mean_neg = nuisance_negative_sums[layer_idx][0] / num_nuisance_pairs
    diff     = mean_pos - mean_neg
    pairwise_nuisance[layer_idx] = diff / diff.norm()

Accumulating activations: 100%|██████████| 19/19 [00:03<00:00,  4.83it/s]
Accumulating activations: 100%|██████████| 19/19 [00:03<00:00,  4.83it/s]


In [12]:
projected_vectors_by_layer = defaultdict(list)

for layer_idx, mean_diff_list in layer_mean_diff_vectors.items():
    nuisance_vec = pairwise_nuisance[layer_idx]
    nuisance_unit = nuisance_vec / nuisance_vec.norm()

    for mean_diff in mean_diff_list:
        residual = mean_diff.clone()
        proj_coef = (residual @ nuisance_unit) / (nuisance_unit.norm() ** 2)
        residual = residual - proj_coef * nuisance_unit
        residual = residual / residual.norm()
        projected_vectors_by_layer[layer_idx].append(residual)

total_projected = sum(len(v) for v in projected_vectors_by_layer.values())
total_original  = sum(len(v) for v in layer_mean_diff_vectors.values())

print(f"Projected {total_projected} vectors out of {total_original} mean-diff vectors")

Projected 320 vectors out of 320 mean-diff vectors


In [13]:
with open("steering_vector_final", "wb") as f:
    pickle.dump(projected_vectors_by_layer, f)

print("Created vector")

Created vector
