# Setup

In [3]:
%pip install transformer_lens==2.11.0 einops eindex-callum jaxtyping huggingface_hub

Collecting transformer_lens==2.11.0
  Downloading transformer_lens-2.11.0-py3-none-any.whl.metadata (12 kB)
Collecting einops
  Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)
Collecting eindex-callum
  Downloading eindex_callum-0.1.2-py3-none-any.whl.metadata (377 bytes)
Collecting jaxtyping
  Downloading jaxtyping-0.3.4-py3-none-any.whl.metadata (7.8 kB)
Collecting huggingface_hub
  Downloading huggingface_hub-1.2.3-py3-none-any.whl.metadata (13 kB)
Collecting accelerate>=0.23.0 (from transformer_lens==2.11.0)
  Downloading accelerate-1.12.0-py3-none-any.whl.metadata (19 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer_lens==2.11.0)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer_lens==2.11.0)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting datasets>=2.7.1 (from transformer_lens==2.11.0)
  Downloading datasets-4.4.2-py3-none-any.whl.metadata (19 kB)
Collecting fancy-

In [1]:
import einops
import gc
import numpy as np
import re
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from eindex import eindex
from IPython.display import display
from jaxtyping import Float, Int
from torch import Tensor
from tqdm import tqdm
from transformer_lens import (
    ActivationCache,
    FactoredMatrix,
    HookedTransformer,
    HookedTransformerConfig,
    utils,
)
from transformer_lens.hook_points import HookPoint
from typing import Dict, Tuple, List

device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)
from huggingface_hub import login

In [None]:
hf_token = "" # TODO: fill in
login(token=hf_token)
model = HookedTransformer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]



Loaded pretrained model meta-llama/Llama-3.1-8B-Instruct into HookedTransformer


In [5]:
def print_mem():
    # Memory currently allocated by tensors
    print(f"Allocated: {t.cuda.memory_allocated() / 1e9:.2f} GB")
    
    # Memory reserved by the caching allocator (includes unused cached memory)
    print(f"Reserved: {t.cuda.memory_reserved() / 1e9:.2f} GB")

def clear_mem():
    gc.collect()
    t.cuda.empty_cache()

# Get prompts

In [6]:
import json
from typing import Dict, List

def load_prompt_batches(prompts_file: str, N: int = None) -> Dict[str, List[str]]:
    """
    Load prompts and return as a dictionary for easier access.
    
    Args:
        prompts_file: Path to the JSON file containing prompts
        N: Number of prompt sets to process (if None, processes all)
        
    Returns:
        Dictionary with keys:
        - '1st_person_negative'
        - '1st_person_positive'
        - '3rd_person_negative'
        - '3rd_person_positive'
        Each key maps to a list of prompt strings
    """
    # Load prompts from file
    with open(prompts_file, 'r') as f:
        prompts_data = json.load(f)
    
    # Filter to only process first N prompt sets if specified
    if N is not None:
        prompts_data = [p for p in prompts_data if p['id'] <= N]
    
    # Initialize dictionary
    batches = {
        '1st_person_negative': [],
        '1st_person_positive': [],
        '3rd_person_negative': [],
        '3rd_person_positive': []
    }
    
    # Populate the batches
    for prompt_set in prompts_data:
        for prompt_type in batches.keys():
            batches[prompt_type].append(prompt_set['prompts'][prompt_type])
    
    return batches

In [6]:
# current wd is notebooks/
prompt_batches = load_prompt_batches("../data/prompts.txt")

## Last token activations

In [7]:
all_activations = {}

def names_filter(name):
    res = ("resid_pre" in name) or ("resid_mid" in name) \
        or name == utils.get_act_name("resid_post", model.cfg.n_layers-1)
    return res

def post_resid_hook(
    resid: Float[Tensor, "batch seq_len d_model"],
    hook: HookPoint
):
    """
    Writes running mean of last-token activations to all_activations.
    """
    global prompt_count

    current_activation = resid[:, -1, :].cpu()

    if hook.name not in all_activations:
        all_activations[hook.name] = current_activation
    else:
        all_activations[hook.name] = all_activations[hook.name] + \
            (current_activation - all_activations[hook.name]) / (prompt_count + 1)

def write_activations(activations, file_name):
    """
    Save activations dictionary to a file.
    
    Args:
        activations: Dict[str, Tensor] - Dictionary with layer names as keys and tensors as values
        file_name: str - Path to save file (typically .pt or .pth extension)
    """
    # Save the entire dictionary
    t.save(activations, file_name)
    print(f"Saved activations to {file_name}")

def read_activations(file_name):
    """Load activations from file."""
    activations = t.load(file_name, weights_only=False)
    return activations


In [None]:
for batch_name, batch in prompt_batches.items():
    # Clear activations
    all_activations.clear()
    prompt_count = 0

    # Accumulate avg. activation
    print(f"Processing {batch_name}", all_activations, prompt_count)
    for prompt in batch:
        with t.no_grad():
            _ = model.run_with_hooks(
                prompt,
                fwd_hooks=[(names_filter, post_resid_hook)],
                return_type=None  # Don't return logits
            )
        
        clear_mem()
        prompt_count += 1

    # Write results
    write_activations(all_activations, f"{batch_name}.pt")
    
    model.reset_hooks()  # Clean up hooks
    clear_mem()

## Steering vectors

In [8]:
def get_steering_vectors(neg_dataset_path, pos_dataset_path) -> Dict[str, Float[Tensor, "d_model"]]:
    """
    Get steering vectors between two datasets.
    Difference in terms of neg_dataset - pos_dataset.
    """
    activations_neg = t.load(neg_dataset_path)
    activations_pos = t.load(pos_dataset_path)

    steering_vectors = {
      name: (activations_neg[name] - activations_pos[name])
      for name in activations_neg.keys()
    }
    return steering_vectors

## Steering Vector Analysis

In [9]:
import functools

# Baseline
def baseline_word(
  model: HookedTransformer, 
  prompt: str,
  max_new_tokens: int = 6
) -> str:
  tokens = model.to_tokens(prompt, prepend_bos=True)
  with t.no_grad():
    logits = model.generate(
      tokens,
      max_new_tokens=max_new_tokens,
      do_sample=False
    )
    logits = logits[:, -max_new_tokens:].squeeze(0)
    output_str = regularize_output(model.to_string(logits))
    return output_str

def baseline_results(
  model: HookedTransformer,
  neg_prompts: List[str],
  pos_prompts: List[str],
) -> Dict[str, Dict[str, int]]:
  res = {}

  neg_prompt_word_freq = defaultdict(int)
  pos_prompt_word_freq = defaultdict(int)
  for prompt in neg_prompts:
    word = baseline_word(model, prompt)
    neg_prompt_word_freq[word] += 1
  
  for prompt in pos_prompts:
    word = baseline_word(model, prompt)
    pos_prompt_word_freq[word] += 1
  
  res["neg_prompt"] = neg_prompt_word_freq
  res["pos_prompt"] = pos_prompt_word_freq
  
  return res


# Add steering vector
def hook_add_steering_vector(
  resid: Float[Tensor, "batch seq_len d_model"],
  hook: HookPoint,
  alpha: float,
  steering_vector: Float[Tensor, "d_model"]
):
  resid += alpha * steering_vector

def add_steering_vector(
  model: HookedTransformer,
  prompt: str,
  steering_vector: Float[Tensor, "d_model"],
  steering_vector_name: str,
  alpha = 1.0,
  max_new_tokens: int = 6
) -> List[int]:
  """
  Generate text with steering vector intervention.
  """
  tmp_hook = functools.partial(hook_add_steering_vector, alpha=alpha, steering_vector=steering_vector)

  # Tokenize the prompt
  tokens = model.to_tokens(prompt, prepend_bos=True)
  generated_tokens = []

  # Generate max_new_tokens tokens
  for _ in range(max_new_tokens):
    with t.no_grad():
      logits = model.run_with_hooks(
        tokens,
        fwd_hooks=[(steering_vector_name, tmp_hook)]
      )
      model.reset_hooks()

    final_logits = logits[:, -1, :]
    next_token = final_logits.argmax(dim=-1, keepdim=True)

    generated_tokens.append(next_token.item())
    tokens = t.cat([tokens, next_token], dim=1)

  return generated_tokens

# Ablate steering vector
def hook_ablate_steering_vector(
  resid: Float[Tensor, "batch seq_len d_model"],
  hook: HookPoint,
  steering_vector: Float[Tensor, "1 d_model"]
):
  unit_steering_vector = (steering_vector / t.norm(steering_vector)).squeeze(0) # [d_model,]
  coeff = resid @ unit_steering_vector        # [batch, seq_len]
  resid -= coeff.unsqueeze(-1) * unit_steering_vector

def ablate_steering_vector(
  model: HookedTransformer,
  prompt: str,
  steering_vector: Float[Tensor, "d_model"],
  max_new_tokens: int = 6
) -> List[int]:
  tmp_hook = functools.partial(hook_ablate_steering_vector, steering_vector=steering_vector)
  tokens = model.to_tokens(prompt, prepend_bos=True)
  generated_tokens = []

  for _ in range(max_new_tokens):
    with t.no_grad():
      logits = model.run_with_hooks(
        tokens,
        fwd_hooks=[(names_filter, tmp_hook)]
      )
      model.reset_hooks()

    final_logits = logits[:, -1, :]
    next_token = final_logits.argmax(dim=-1, keepdim=True)
    tokens = t.cat([tokens, next_token], dim=1)
    generated_tokens.append(next_token.item())

  return generated_tokens

# KL divergence w/ and w/out ablation
def kl_divergence(p: Tensor, q: Tensor) -> Tensor:
    """KL(p || q)"""
    log_p = F.log_softmax(p, dim=-1)
    log_q = F.log_softmax(q, dim=-1)
    return F.kl_div(log_q, log_p, reduction='sum', log_target=True)

def kl_divergence_of_last_token(
  model: HookedTransformer,
  prompt: str,
  steering_vector: Float[Tensor, "d_model"],
  baseline_logits: Float[Tensor, "d_vocab"],
) -> float:
  tmp_hook = functools.partial(hook_ablate_steering_vector, steering_vector=steering_vector)
  tokens = model.to_tokens(prompt, prepend_bos=True)

  # Get the last token logits
  with t.no_grad():
    logits_ablation = model.run_with_hooks(
      tokens,
      fwd_hooks=[(names_filter, tmp_hook)]
    )
    model.reset_hooks()
    last_token_logits_ablation = logits_ablation[:, -1, :].squeeze(0).squeeze(0) # [d_vocab,]
    del logits_ablation
  
  # Clear memory
  clear_mem()

  # Compute the KL divergence
  kl_score = kl_divergence(baseline_logits, last_token_logits_ablation)

  return kl_score

def get_baseline_logits(
  model: HookedTransformer,
  prompt: str
) -> Float[Tensor, "d_vocab"]:
  tokens = model.to_tokens(prompt, prepend_bos=True)
  with t.no_grad():
    logits = model.run_with_hooks(tokens, fwd_hooks=[])
    model.reset_hooks()
    return logits[:, -1, :]

In [10]:
from collections import defaultdict
import time

def regularize_output(text: str) -> str:
    """
    Regularize model output to a single word:
    - Strip whitespace
    - Remove periods
    - Convert to lowercase
    - Extract first word if multiple words
    """
    text = text.strip()
    text = re.sub(r'[.,!?;:]', '', text)
    text = text.split()[0] if text.split() else text
    text = text.lower()
    return text

def steering_vector_results(
  model: HookedTransformer,
  neg_prompts: List[str],
  pos_prompts: List[str],
  steering_vectors: Dict[str, Float[Tensor, "d_model"]]
):
  res = {}

  start_whole = time.perf_counter()
  
  # Get baseline logits
  baseline_logits_dict = {}
  for prompt in pos_prompts:
    baseline_logits_dict[prompt] = get_baseline_logits(model, prompt).to("cpu")

  for layer_name, steering_vector in steering_vectors.items():
    print(f"Processing steering vector {layer_name}")
    start = time.perf_counter()
    # Move steering vector to device
    sv = steering_vector.to(model.cfg.device)
    ablate_neg_from_neg_word_freq = defaultdict(int)
    add_neg_to_pos_word_freq = defaultdict(int)
    ablate_neg_from_pos_avg_kl_score = 0

    # Ablation
    for prompt in neg_prompts:
      tokens_ablated = ablate_steering_vector(
        model,
        prompt,
        sv,
      )
      word_ablated = regularize_output(model.to_string(tokens_ablated))
      ablate_neg_from_neg_word_freq[word_ablated] += 1
    print(f"Time elapsed for ablation: {time.perf_counter() - start:.3f} s")

    # Addition
    start_add = time.perf_counter()
    for prompt in pos_prompts:
      tokens_added = add_steering_vector(
        model,
        prompt,
        sv,
        layer_name,
      )
      word_added = regularize_output(model.to_string(tokens_added))
      add_neg_to_pos_word_freq[word_added] += 1
    print(f"Time elapsed for addition: {time.perf_counter() - start_add:.3f} s")

    # KL
    start_kl = time.perf_counter()
    for prompt in pos_prompts:
      baseline_logits = baseline_logits_dict[prompt].to(model.cfg.device)
      kl_score = kl_divergence_of_last_token(
        model,
        prompt,
        sv,
        baseline_logits,
      )
      ablate_neg_from_pos_avg_kl_score += kl_score.item()
      # print("kl score", kl_score.item())
    ablate_neg_from_pos_avg_kl_score /= len(pos_prompts) # Take average
    print(f"Time elapsed for kl: {time.perf_counter() - start_kl:.3f} s")

    res[layer_name] = {
      "ablate_neg_from_neg": ablate_neg_from_neg_word_freq,
      "add_neg_to_pos": add_neg_to_pos_word_freq,
      "ablate_neg_from_pos_avg_kl_score": ablate_neg_from_pos_avg_kl_score
    }
  print(f"Total time elapsed: {time.perf_counter() - start_whole:.3f} s")
  return res

In [None]:
# Test ablation
neg_prompts = prompt_batches["1st_person_negative"]
pos_prompts = prompt_batches["1st_person_positive"]
steering_vectors = get_steering_vectors("out/1st_person_negative.pt", "out/1st_person_positive.pt")

res = steering_vector_results(
  model,
  neg_prompts,
  pos_prompts,
  steering_vectors
)

  activations_neg = t.load(neg_dataset_path)
  activations_pos = t.load(pos_dataset_path)


Processing steering vector blocks.0.hook_resid_pre


## Writing helper functions

In [11]:
def write_steering_results(res, file_name):
    with open(file_name, "w") as f:
        json.dump(res, f, indent=2)
def read_steering_results(file_name):
    with open(file_name, "r") as f:
        return json.load(f)

In [None]:
write_steering_results(res, "10_prompt_steering_results_ablation_only.json")

## Write baseline

In [None]:
baseline_res = baseline_results(model, neg_prompts, pos_prompts)

# Clean steering vector results
(removing extra quotes)

In [129]:
import json
import re

print(os.getcwd())

def clean_json_quotes(input_file, output_file):
    """Remove extra quotes from JSON keys/values like 'word' -> word"""
    with open(input_file, 'r') as f:
        data = json.load(f)
    
    def clean_dict(d):
        if isinstance(d, dict):
            return {re.sub(r"^'|'$", "", k): clean_dict(v) for k, v in d.items()}
        elif isinstance(d, list):
            return [clean_dict(item) for item in d]
        else:
            return d
    
    cleaned = clean_dict(data)
    
    with open(output_file, 'w') as f:
        json.dump(cleaned, f, indent=2)
    
    print(f"Cleaned {input_file} -> {output_file}")

# Usage
clean_json_quotes("remote/notebooks/3rd_person_50_prompts_real_all_results.json", 
                  "remote/notebooks/3rd_person_50_prompts_real_all_results_cleaned.json")

/Users/sarahwang/Documents/Projects/emotion
Cleaned remote/notebooks/3rd_person_50_prompts_real_all_results.json -> remote/notebooks/3rd_person_50_prompts_real_all_results_cleaned.json


# Experiment with best steering vectors on diff prompts, scratchpad

In [None]:
steering_vectors = get_steering_vectors("out/activations/1st_person_negative.pt", "out/activations/1st_person_positive.pt")

# ['blocks.20.hook_resid_mid', 'blocks.21.hook_resid_pre', 'blocks.21.hook_resid_mid', 'blocks.19.hook_resid_pre']

sv_name = 'blocks.19.hook_resid_pre'
sv = steering_vectors[sv_name].to(model.cfg.device)
sv = sv.squeeze(0)

prompt = "Write a diary entry that a person might have:"

tokens = add_steering_vector(
    model,
    prompt,
    sv,
    sv_name,
    alpha = 5.0,
    max_new_tokens = 100
)

  activations_neg = t.load(neg_dataset_path)
  activations_pos = t.load(pos_dataset_path)


In [None]:
prompt = """Analyze the emotional content of the following text. Text: 'You join a group discussion and others continue talking as if you are not there'. Sentiment (Positive/Negative):"""

for sv_name in ['blocks.22.hook_resid_pre', 'blocks.21.hook_resid_mid', 'blocks.21.hook_resid_pre', 'blocks.20.hook_resid_mid', 'blocks.19.hook_resid_pre']:
    sv = steering_vectors[sv_name].to(model.cfg.device)
    sv = sv.squeeze(0)

    tokens = ablate_steering_vector(
        model,
        prompt,
        sv,
        max_new_tokens = 1
    )
    
    # tokens = add_steering_vector(
    #     model,
    #     prompt,
    #     sv,
    #     sv_name,
    #     alpha = 1.0,
    #     max_new_tokens = 1
    # )

 Negative. Emotion: Embarrassment. Intensity: 6/10. The text describes
 Negative. Emotion: Embarrassment. Intensity: 6/10. The text describes
 Negative. Emotion: Embarrassment. Intensity: 6/10. The text describes
 Negative. Emotion: Embarrassment. Intensity: 6/10. The text describes


In [None]:
steering_vectors = get_steering_vectors("out/activations/1st_person_negative.pt", "out/activations/1st_person_positive.pt")

sv_name = 'blocks.19.hook_resid_pre'
sv = steering_vectors[sv_name].to(model.cfg.device)
sv = sv.squeeze(0)

prompt = "Write a diary entry that a person might have:"

tokens_res = {}
for layer in range(1, model.cfg.n_layers+1):
  layer_name = f"blocks.{layer}.hook_resid_pre"
  tokens = add_steering_vector(
      model,
      prompt,
      sv,
      layer_name,
      alpha = 3.0,
      max_new_tokens = 100
  )
  tokens_res[layer_name] = tokens