In [1]:
import sys
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import time
from collections import defaultdict
from pathlib import Path

import circuitsvis as cv
import einops
import numpy as np
import torch as t
from IPython.display import display
from jaxtyping import Float
from nnsight import CONFIG, LanguageModel
from rich import print as rprint
from rich.table import Table
from torch import Tensor
import string as s

# Hide bunch of info logging messages from nnsight
# NOTE: logging.disable(sys.maxsize) silences *every* logger in the process at every level, not only
# nnsight's. Re-enable with logging.disable(logging.NOTSET) when debugging.
import logging, warnings
logging.disable(sys.maxsize)
# Silence only the UserWarning raised from huggingface_hub's token utilities module.
warnings.filterwarnings('ignore', category=UserWarning, module='huggingface_hub.utils._token')

# Pick the compute device: Apple-silicon MPS first, then CUDA, falling back to CPU.
device = t.device('mps' if t.backends.mps.is_available() else 'cuda' if t.cuda.is_available() else 'cpu')
print(device)
# Globally disable autograd; nothing visible in this notebook needs gradients.
t.set_grad_enabled(False)

# Make sure exercises are in the path
# The repo root is taken as everything in the cwd before the first occurrence of `chapter`.
# If the cwd does not contain `chapter`, split() returns the whole cwd, so the exercises
# path is built relative to the cwd instead.
chapter = r"chapter1_transformer_interp"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part42_function_vectors_and_model_steering"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

# Project-local helpers; these imports only resolve after `exercises_dir` is on sys.path (above).
from plotly_utils import imshow
import part42_function_vectors_and_model_steering.solutions as solutions
import part42_function_vectors_and_model_steering.tests as tests

# True when this code runs as the top-level program (a notebook kernel or `python file.py`).
MAIN = __name__ == '__main__'

mps


In [2]:
model = LanguageModel('EleutherAI/gpt-j-6b', device_map='auto', torch_dtype=t.bfloat16)
tokenizer = model.tokenizer

# Architecture constants read from the HF config (GPT-J uses n_head / n_layer / n_embd naming).
N_HEADS = model.config.n_head
N_LAYERS = model.config.n_layer
D_MODEL = model.config.n_embd
D_HEAD = D_MODEL // N_HEADS

print(f"Number of heads: {N_HEADS}")
print(f"Number of layers: {N_LAYERS}")
print(f"Model dimension: {D_MODEL}")
print(f"Head dimension: {D_HEAD}\n")

print("Entire config: ", model.config)

REMOTE = True
# Running with REMOTE = True requires an NDIF API key. Please join the NDIF community Discord
# (https://nnsight.net/status/) and request one there. Do NOT paste the key into the notebook: it
# ends up in version control and in every shared copy. Export it before starting Jupyter instead:
#     export NDIF_API_KEY="<your key>"
NDIF_API_KEY = os.environ.get("NDIF_API_KEY")
if NDIF_API_KEY:
    CONFIG.set_default_api_key(NDIF_API_KEY)
elif REMOTE:
    # Don't fail hard: a key may already have been configured in nnsight from an earlier session.
    warnings.warn(
        "REMOTE=True but the NDIF_API_KEY environment variable is not set; remote traces will only "
        "work if an API key was previously configured in nnsight."
    )

Number of heads: 16
Number of layers: 28
Model dimension: 4096
Head dimension: 256

Entire config:  GPTJConfig {
  "_name_or_path": "EleutherAI/gpt-j-6b",
  "activation_function": "gelu_new",
  "architectures": [
    "GPTJForCausalLM"
  ],
  "attn_pdrop": 0.0,
  "bos_token_id": 50256,
  "embd_pdrop": 0.0,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gptj",
  "n_embd": 4096,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 28,
  "n_positions": 2048,
  "resid_pdrop": 0.0,
  "rotary": true,
  "rotary_dim": 64,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50,
      "temperature": 1.0
    }
  },
  "tie_word_embeddings": false,
  "tokenizer_class": "GPT2Tokenizer",
  "



In [3]:
# Print the module tree (names such as `transformer.h[i].attn` are used to hook activations below).
# This only needs the model object, not a forward pass. Wrapping it in
# `model.trace(..., remote=REMOTE)` would submit a pointless remote job just to print the architecture.
print(model)

GPTJForCausalLM(
  (transformer): GPTJModel(
    (wte): Embedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-27): 28 x GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): GPTJMLP(
          (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
          (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Downloading result: 100%|██████████| 928/928 [00:00<00:00, 1.08MB/s]


In [3]:
# Load the word pairs from the text file (one whitespace-separated pair per line).
# Blank lines are skipped: `"".split()` would produce an empty "pair" `[]`, which would
# break `ICLSequence` (it unzips every pair into x/y) if it were ever sampled.
with open(section_dir / "data" / "antonym_pairs.txt", "r", encoding="utf-8") as f:
    ANTONYM_PAIRS = [line.split() for line in f if line.strip()]

print(ANTONYM_PAIRS[:10])

[['old', 'young'], ['top', 'bottom'], ['awake', 'asleep'], ['future', 'past'], ['appear', 'disappear'], ['early', 'late'], ['empty', 'full'], ['innocent', 'guilty'], ['ancient', 'modern'], ['arrive', 'depart']]


In [4]:
class ICLSequence:
    '''
    A single in-context-learning (antonym) sequence built from a list of word pairs.

    Every pair is rendered with the template "Q: {x}\nA: {y}", and consecutive pairs are joined
    with a blank line ("\n\n"). The answer to the final pair is withheld from `prompt()` and is
    exposed separately through `completion()`.
    '''
    def __init__(self, word_pairs: list[tuple[str, str]]):
        self.word_pairs = word_pairs
        # Unzip into parallel tuples: all questions (x) and all answers (y).
        self.x, self.y = zip(*word_pairs)

    def __len__(self):
        return len(self.word_pairs)

    def __getitem__(self, idx: int):
        # Plain delegation to the underlying sequence, so slices work too (used by `__str__`).
        return self.word_pairs[idx]

    def prompt(self):
        '''Returns the full Q/A text with the last answer removed, so it ends with "A:".'''
        rendered = "\n\n".join(f"Q: {question}\nA: {answer}" for question, answer in self.word_pairs)
        # Chop off the trailing " {answer}" of the final pair (exactly the completion string).
        return rendered[: len(rendered) - len(self.completion())]

    def completion(self):
        '''Returns the answer of the last word pair, prefixed with a single space.'''
        return " " + self.y[-1]

    def __str__(self):
        '''Compact, template-independent view: "(x1, y1), (x2, y2), x_last ->".'''
        shown_pairs = ", ".join(f"({question}, {answer})" for question, answer in self.word_pairs[:-1])
        # With a single pair `shown_pairs` is empty; strip(", ") removes the dangling leading separator.
        return f"{shown_pairs}, {self.x[-1]} ->".strip(", ")


word_list = [["hot", "cold"], ["yes", "no"], ["in", "out"], ["up", "down"]]
seq = ICLSequence(word_list)

# Show the human-readable tuple view first, then the exact text the model will receive.
print("Tuple-representation of the sequence:", seq, sep="\n")
print("\nActual prompt, which will be fed into the model:", seq.prompt(), sep="\n")

Tuple-representation of the sequence:
(hot, cold), (yes, no), (in, out), up ->

Actual prompt, which will be fed into the model:
Q: hot
A: cold

Q: yes
A: no

Q: in
A: out

Q: up
A:


In [6]:
class ICLDataset:
    '''
    Dataset to create antonym pair prompts, in ICL task format. We use random seeds for consistency
    between the corrupted and clean datasets.

    Inputs:
        word_pairs:
            list of ICL task word pairs, e.g. [["old", "young"], ["top", "bottom"], ...] for the antonym
            task. Each pair may be a list or a tuple; it is never modified.
        size:
            number of prompts to generate
        n_prepended:
            number of antonym pairs before the single-word ICL task
        bidirectional:
            if True, then we also consider the reversed antonym pairs
        seed:
            random seed, for consistency & reproducibility. Prompt `n` is sampled from a private RNG
            seeded with `seed + n`, so numpy's global random state is left untouched.
        corrupted:
            if True, then the second word in each prepended pair (every pair except the final query
            pair) is replaced with a random word drawn from all words in `word_pairs`
    '''

    def __init__(
        self,
        word_pairs: list[tuple[str, str]],
        size: int,
        n_prepended: int,
        bidirectional: bool = True,
        seed: int = 0,
        corrupted: bool = False,
    ):
        assert n_prepended+1 <= len(word_pairs), "Not enough antonym pairs in dataset to create prompt."

        self.word_pairs = word_pairs
        # Flat vocabulary of every word in the dataset; corruption draws replacement words from here.
        self.word_list = [word for word_pair in word_pairs for word in word_pair]
        self.size = size
        self.n_prepended = n_prepended
        self.bidirectional = bidirectional
        self.corrupted = corrupted
        self.seed = seed

        self.seqs = []
        self.prompts = []
        self.completions = []

        # Generate the dataset (by choosing random antonym pairs, and constructing `ICLSequence` objects)
        for n in range(size):
            # A private RandomState seeded with `seed + n` produces exactly the same draws as
            # `np.random.seed(seed + n)` followed by the global `np.random.choice`. Clean and corrupted
            # datasets built with the same seed therefore still pick identical pairs & orders, without
            # clobbering the global numpy RNG for the rest of the notebook.
            rng = np.random.RandomState(seed + n)
            random_pairs = rng.choice(len(self.word_pairs), n_prepended+1, replace=False)
            random_orders = rng.choice([1, -1], n_prepended+1)
            if not bidirectional:
                random_orders[:] = 1
            # `list(...)` gives every sampled (possibly reversed) pair its own mutable copy. Tuple pairs
            # (as the type hint allows) can then be corrupted in place, and `self.word_pairs` is never
            # written to.
            sampled_pairs = [
                list(self.word_pairs[pair][::order]) for pair, order in zip(random_pairs, random_orders)
            ]
            if corrupted:
                # Corrupt every demonstration pair, but leave the final (query) pair intact.
                for i in range(len(sampled_pairs) - 1):
                    sampled_pairs[i][1] = rng.choice(self.word_list)
            seq = ICLSequence(sampled_pairs)

            self.seqs.append(seq)
            self.prompts.append(seq.prompt())
            self.completions.append(seq.completion())

    def create_corrupted_dataset(self):
        '''Creates a corrupted version of the dataset (with same random seed).'''
        return ICLDataset(self.word_pairs, self.size, self.n_prepended, self.bidirectional, corrupted=True, seed=self.seed)

    def __len__(self):
        return self.size

    def __getitem__(self, idx: int):
        return self.seqs[idx]

In [7]:
dataset = ICLDataset(ANTONYM_PAIRS, size=10, n_prepended=2, corrupted=False)

# One row per clean prompt: its tuple view next to the completion the model should produce.
table = Table("Prompt", "Correct completion")
for seq, completion in zip(dataset.seqs, dataset.completions):
    prompt_cell, completion_cell = str(seq), repr(completion)
    table.add_row(prompt_cell, completion_cell)

rprint(table)

In [8]:
dataset = ICLDataset(ANTONYM_PAIRS, size=10, n_prepended=2, corrupted=True)

# Same table for the corrupted dataset. Each demonstration answer is a random word, but the final
# query pair (and hence the correct completion) is left intact.
table = Table("Prompt", "Correct completion")
for seq, completions in zip(dataset.seqs, dataset.completions):
    row = (str(seq), repr(completions))
    table.add_row(*row)

rprint(table)

In [9]:
def calculate_h(model: LanguageModel, dataset: ICLDataset, layer: int = -1) -> tuple[list[str], Tensor]:
    '''
    Averages over the model's hidden representations on each of the prompts in `dataset` at layer `layer`, to produce
    a single vector `h`.

    Inputs:
        model: LanguageModel
            the transformer you're doing this computation with
        dataset: ICLDataset
            the dataset whose prompts `dataset.prompts` you're extracting the activations from (at the last seq pos)
        layer: int
            the layer you're extracting activations from

    Returns:
        completions: list[str]
            list of the model's next-token predictions (i.e. the strings the model predicts to follow the last token)
        h: Tensor
            average hidden state tensor at final sequence position, of shape (d_model,)

    Raises:
        Whatever the nnsight trace raises. An informative summary is printed first, then the original
        exception is re-raised. Previously it was swallowed and the function crashed later with an
        unrelated NameError on `token_idx`.
    '''
    try:
        with model.trace(dataset.prompts, remote=REMOTE, scan=True, validate=True):
            # Next-token logits at the final sequence position of every prompt.
            logits = model.lm_head.output[:, -1, :]
            token_idx = logits.argmax(dim=-1).save()
            # Block `layer` returns a tuple whose first element is the hidden state; take the final
            # position and average over all prompts in the batch.
            h = model.transformer.h[layer].output[0][:, -1, :].mean(dim=0).save()
    except Exception as e:
        print(f"Informative error message:\n  {e.__class__.__name__}: {e}")
        raise

    # Named `completions` rather than `str`, to avoid shadowing the builtin.
    completions = model.tokenizer.batch_decode(token_idx.value)

    return completions, h.value


# Run the exercise's provided test for `calculate_h` against the loaded model.
tests.test_calculate_h(calculate_h, model)

KeyboardInterrupt: 

In [None]:
def display_model_completions_on_antonyms(
    model: LanguageModel,
    dataset: ICLDataset,
    completions: list[str],
    num_to_display: int = 20,
) -> None:
    '''
    Prints a rich table comparing the model's next-token predictions with the correct completions.

    Inputs:
        model: LanguageModel
            the model whose tokenizer is used to find the first token of each correct completion
        dataset: ICLDataset
            the dataset the completions were generated from (provides `seqs` and `completions`)
        completions: list[str]
            the model's decoded next-token predictions, one per prompt in `dataset`
        num_to_display: int
            maximum number of rows to print

    A prediction is shown in bold green when it equals the first token of the correct completion.
    '''
    table = Table("Prompt (tuple representation)", "Model's completion\n(green=correct)", "Correct completion", title="Model's antonym completions")

    for i in range(min(len(completions), num_to_display)):

        # Get model's completion, and correct completion
        completion = completions[i]
        correct_completion = dataset.completions[i]
        # The model predicts a single token, so compare against the first token of the answer. Decode
        # its id rather than string-patching `tokenize()` output ('Ġ' -> ' '): other byte-level BPE
        # characters (e.g. non-ASCII text) would be left mangled. Decoding renders the token the same
        # way `batch_decode` renders predicted token ids.
        first_token_id = model.tokenizer(correct_completion)["input_ids"][0]
        correct_completion_first_token = model.tokenizer.decode(first_token_id)
        seq = dataset.seqs[i]

        # Color code the completion based on whether it's correct
        is_correct = (completion == correct_completion_first_token)
        completion = f"[b green]{repr(completion)}[/]" if is_correct else repr(completion)

        table.add_row(str(seq), completion, repr(correct_completion))

    rprint(table)


# Clean (uncorrupted) ICL prompts, with 2 demonstration pairs before each query word
dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=2)

# Extract h at layer 12, following the description in section 2.1 of the paper
model_completions, h = calculate_h(model=model, dataset=dataset, layer=12)

# Compare the model's next-token predictions against the correct antonyms
display_model_completions_on_antonyms(model=model, dataset=dataset, completions=model_completions)

In [None]:
def intervene_with_h(
    model: LanguageModel,
    zero_shot_dataset: ICLDataset,
    h: Tensor,
    layer: int,
) -> tuple[list[str], list[str]]:
    '''
    Adds a precomputed `h`-vector to the residual stream of a set of zero-shot prompts and returns the
    model's next-token predictions both without and with this intervention. `h` is added to the output
    of block `layer` at the final sequence position.

    Inputs:
        model: LanguageModel
            the transformer you're doing this computation with
        zero_shot_dataset: ICLDataset
            the dataset of zero-shot prompts which we'll intervene on, using the `h`-vector
        h: Tensor
            the `h`-vector we'll be adding to the residual stream (e.g. the second value returned by
            `calculate_h`); presumably of shape (d_model,) so it broadcasts over the batch. TODO confirm.
        layer: int
            index of the transformer block whose output `h` is added to (normally the same layer `h`
            was extracted from)

    Returns:
        completions_zero_shot: list[str]
            list of string completions for the zero-shot prompts, without intervention
        completions_intervention: list[str]
            list of string completions for the zero-shot prompts, with h-intervention
    '''
    # Two invokes inside one trace: the first is an unmodified baseline, the second applies the
    # intervention. Within each invoke, modules are accessed in forward order (layer output before lm_head).
    with model.trace(remote=REMOTE) as tracer:
        with tracer.invoke(zero_shot_dataset.prompts):
            # Baseline: greedy next-token prediction at the final position.
            clean_logits = model.lm_head.output[:, -1, :]
            clean_index = clean_logits.argmax(dim = -1).save()
        with tracer.invoke(zero_shot_dataset.prompts):
            # Block outputs are tuples; element 0 is the hidden state that flows to later blocks.
            hidden_states = model.transformer.h[layer].output[0]
            # In-place edit of the final position, so downstream layers see the modified residual stream.
            hidden_states[:,-1,:] += h
            intervention_logits = model.lm_head.output[:, -1, :]
            intervention_index = intervention_logits.argmax(dim = -1).save()
    
    completions_zero_shot = model.tokenizer.batch_decode(clean_index.value)
    completions_intervention = model.tokenizer.batch_decode(intervention_index.value)

    return completions_zero_shot, completions_intervention


# Run the exercise's provided test for `intervene_with_h`, using the `h` computed in the demo above.
tests.test_intervene_with_h(intervene_with_h, model, h, ANTONYM_PAIRS, REMOTE)

In [None]:
layer = 12
# Few-shot prompts (3 demonstrations) to extract h from, plus separately seeded zero-shot prompts to patch it into
dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=3, seed=0)
zero_shot_dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=0, seed=1)

# calculate_h returns (completions, h); only the h-vector is needed here
h = calculate_h(model=model, dataset=dataset, layer=layer)[-1]

# Patch h into the zero-shot prompts at the same layer it was extracted from
completions_zero_shot, completions_intervention = intervene_with_h(
    model=model, zero_shot_dataset=zero_shot_dataset, h=h, layer=layer
)

print("Zero-shot completions: ", completions_zero_shot)
print("Completions with intervention: ", completions_intervention)

In [None]:
def display_model_completions_on_h_intervention(
    dataset: ICLDataset,
    completions: list[str],
    completions_intervention: list[str],
    num_to_display: int = 20,
) -> None:
    '''
    Prints a rich table comparing the model's completions without and with the h-intervention.

    Inputs:
        dataset: ICLDataset
            the (zero-shot) dataset the completions were generated from
        completions: list[str]
            decoded next-token predictions without intervention
        completions_intervention: list[str]
            decoded next-token predictions with the h-intervention
        num_to_display: int
            maximum number of rows to print

    A completion is shown in bold green when it equals the first token of the correct completion.
    Uses the module-level `tokenizer` to find that first token.
    '''
    def highlight(completion: str, target: str) -> str:
        # Bold-green repr when the predicted token matches the target token, plain repr otherwise.
        return f"[b green]{repr(completion)}[/]" if completion == target else repr(completion)

    table = Table("Prompt", "Model's completion\n(no intervention)", "Model's completion\n(intervention)", "Correct completion", title="Model's antonym completions")

    for i in range(min(len(completions), num_to_display)):

        correct_completion = dataset.completions[i]
        # Decode the answer's first token id rather than string-patching `tokenize()` output (which
        # only maps 'Ġ' back to a space). The target is then rendered the same way `batch_decode`
        # renders the model's predicted token ids.
        first_token_id = tokenizer(correct_completion)["input_ids"][0]
        correct_completion_first_token = tokenizer.decode(first_token_id)
        seq = dataset.seqs[i]

        table.add_row(
            str(seq),
            highlight(completions[i], correct_completion_first_token),
            highlight(completions_intervention[i], correct_completion_first_token),
            repr(correct_completion),
        )

    rprint(table)


# Tabulate the zero-shot completions from the previous cell, without vs. with the h-intervention.
display_model_completions_on_h_intervention(zero_shot_dataset, completions_zero_shot, completions_intervention)

In [None]:
def calculate_h_and_intervene(
    model: LanguageModel,
    dataset: ICLDataset,
    zero_shot_dataset: ICLDataset,
    layer: int,
) -> tuple[list[str], list[str]]:
    '''
    Extracts the vector `h`, intervenes by adding `h` to the residual stream of a set of generated zero-shot prompts,
    all within the same forward pass. Returns the completions from this intervention.

    Inputs:
        model: LanguageModel
            the model we're using to generate completions
        dataset: ICLDataset
            the dataset of clean prompts from which we'll extract the `h`-vector
        zero_shot_dataset: ICLDataset
            the dataset of zero-shot prompts which we'll intervene on, using the `h`-vector
        layer: int
            the layer we'll be extracting the `h`-vector from; the same layer is also where `h` is
            added during the intervention

    Returns:
        completions_zero_shot: list[str]
            list of string completions for the zero-shot prompts, without intervention
        completions_intervention: list[str]
            list of string completions for the zero-shot prompts, with h-intervention
    '''
    # Three invokes share one trace. `h` from the first invoke is used directly by the third, so it
    # is never saved or sent back. Statement order within each invoke follows forward order.
    with model.trace(remote=REMOTE) as tracer:
        with tracer.invoke(dataset.prompts):
            # h = mean over ICL prompts of block `layer`'s hidden state at the final position.
            h = model.transformer.h[layer].output[0][:,-1,:].mean(dim = 0)
        with tracer.invoke(zero_shot_dataset.prompts):
            # Baseline: greedy next-token prediction on the zero-shot prompts.
            clean_logits = model.lm_head.output[:, -1, :]
            clean_index = clean_logits.argmax(dim = -1).save()
        with tracer.invoke(zero_shot_dataset.prompts):
            hidden_states = model.transformer.h[layer].output[0]
            # In-place edit of the final position, so downstream layers see the modified residual stream.
            hidden_states[:,-1,:] += h
            intervention_logits = model.lm_head.output[:, -1, :]
            intervention_index = intervention_logits.argmax(dim = -1).save()

    completions_zero_shot = model.tokenizer.batch_decode(clean_index.value)
    completions_intervention = model.tokenizer.batch_decode(intervention_index.value)

    return completions_zero_shot, completions_intervention


# 3-shot prompts to extract h from (seed 0), and zero-shot prompts to intervene on (seed 1)
dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=3, seed=0)
zero_shot_dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=0, seed=1)

# Extraction and intervention now happen within a single trace
completions_zero_shot, completions_intervention = calculate_h_and_intervene(
    model=model, dataset=dataset, zero_shot_dataset=zero_shot_dataset, layer=layer
)

display_model_completions_on_h_intervention(
    dataset=zero_shot_dataset,
    completions=completions_zero_shot,
    completions_intervention=completions_intervention,
)

In [39]:
def calculate_h_and_intervene_logprobs(
    model: LanguageModel,
    dataset: ICLDataset,
    zero_shot_dataset: ICLDataset,
    layer: int,
) -> tuple[list[float], list[float]]:
    '''
    Extracts the vector `h`, intervenes by adding `h` to the residual stream of a set of generated zero-shot prompts,
    all within the same forward pass. Returns the logprobs on correct tokens from this intervention.

    Inputs:
        model: LanguageModel
            the model we're using to generate completions
        dataset: ICLDataset
            the dataset of clean prompts from which we'll extract the `h`-vector
        zero_shot_dataset: ICLDataset
            the dataset of zero-shot prompts which we'll intervene on, using the `h`-vector
        layer: int
            the layer we'll be extracting the `h`-vector from (and adding it back in at)

    Returns:
        correct_logprobs: list[float]
            list of correct-token log-probabilities (natural log, so <= 0) for the zero-shot prompts,
            without intervention
        correct_logprobs_intervention: list[float]
            list of correct-token log-probabilities for the zero-shot prompts, with h-intervention
    '''
    # Token id of the first token of each correct completion (e.g. " young"). Its log-probability
    # is read off at the final prompt position.
    target_index = [tok[0] for tok in model.tokenizer(zero_shot_dataset.completions)["input_ids"]]
    with model.trace(remote=REMOTE) as tracer:
        # Invoke 1: ICL prompts -> mean final-position hidden state of block `layer` (the h-vector).
        with tracer.invoke(dataset.prompts):
            h = model.transformer.h[layer].output[0][:,-1,:].mean(dim = 0)
        # Invoke 2: zero-shot prompts, no intervention (baseline).
        with tracer.invoke(zero_shot_dataset.prompts):
            clean_logits = model.lm_head.output[:, -1, :]
            # log_softmax, not softmax: the function name, docstring and the display code all treat
            # these values as log-probabilities. It is also numerically stabler than log(softmax).
            clean_logprobs = t.log_softmax(clean_logits, dim=-1)
            clean_logprobs_target = clean_logprobs[t.arange(len(zero_shot_dataset.prompts)), target_index].save()
        # Invoke 3: same zero-shot prompts, with h added to block `layer`'s output at the final position.
        with tracer.invoke(zero_shot_dataset.prompts):
            hidden_states = model.transformer.h[layer].output[0]
            hidden_states[:,-1,:] += h
            intervention_logits = model.lm_head.output[:, -1, :]
            intervention_logprobs = t.log_softmax(intervention_logits, dim=-1)
            intervention_logprobs_target = intervention_logprobs[t.arange(len(zero_shot_dataset.prompts)), target_index].save()

    return clean_logprobs_target.value.tolist(), intervention_logprobs_target.value.tolist()
            

In [None]:
def display_model_logprobs_on_h_intervention(
    dataset: ICLDataset,
    correct_logprobs_zero_shot: list[float],
    correct_logprobs_intervention: list[float],
    num_to_display: int = 20,
) -> None:
    '''
    Prints a rich table of correct-token logprobs without and with the h-intervention.

    Inputs:
        dataset: ICLDataset
            the dataset the logprobs were computed on (typically zero-shot)
        correct_logprobs_zero_shot: list[float]
            logprob of each correct completion's first token, without intervention
        correct_logprobs_intervention: list[float]
            the same logprobs, with the h-intervention
        num_to_display: int
            maximum number of rows to print

    The change in logprob is shown in bold green when it is non-negative.
    '''
    table = Table(
        "Zero-shot prompt", "Model's logprob\n(no intervention)", "Model's logprob\n(intervention)", "Change in logprob",
        title="Model's antonym logprobs, with zero-shot h-intervention\n(green = intervention improves accuracy)"
    )

    for i in range(min(len(correct_logprobs_zero_shot), num_to_display)):

        logprob_ni = correct_logprobs_zero_shot[i]
        logprob_i = correct_logprobs_intervention[i]
        delta_logprob = logprob_i - logprob_ni
        # Show the final (query) pair: that is the pair whose completion the logprobs refer to. For a
        # zero-shot dataset it is also the only pair, so this matches the previous `[0]` indexing.
        seq = dataset[i]
        zero_shot_prompt = f"{seq.x[-1]:>8} -> {seq.y[-1]}"

        # Color code the logprob based on whether it's increased with this intervention
        is_improvement = (delta_logprob >= 0)
        delta_logprob = f"[b green]{delta_logprob:+.2f}[/]" if is_improvement else f"{delta_logprob:+.2f}"

        table.add_row(zero_shot_prompt, f"{logprob_ni:.2f}", f"{logprob_i:.2f}", delta_logprob)

    rprint(table)


# 3-shot prompts to extract h from (seed 0), and zero-shot prompts to intervene on (seed 1)
dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=3, seed=0)
zero_shot_dataset = ICLDataset(ANTONYM_PAIRS, size=20, n_prepended=0, seed=1)

correct_logprobs_zero_shot, correct_logprobs_intervention = calculate_h_and_intervene_logprobs(
    model=model, dataset=dataset, zero_shot_dataset=zero_shot_dataset, layer=layer
)

display_model_logprobs_on_h_intervention(
    dataset=zero_shot_dataset,
    correct_logprobs_zero_shot=correct_logprobs_zero_shot,
    correct_logprobs_intervention=correct_logprobs_intervention,
)