In [None]:
import sys
import os
# Point Hugging Face downloads at a mirror endpoint.
# NOTE(review): presumably this has to be set before the Hugging Face-backed imports below read
# it, which would explain why it sits in the middle of the import list — confirm.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import time
from collections import defaultdict
from pathlib import Path

import circuitsvis as cv
import einops
import numpy as np
import torch as t
from IPython.display import display
from jaxtyping import Float
from nnsight import CONFIG, LanguageModel
from rich import print as rprint
from rich.table import Table
from torch import Tensor
import string as s

# Silence noisy output. `logging.disable(sys.maxsize)` suppresses logging calls at every level for
# the whole process (not only nnsight's), and the filter below ignores UserWarnings raised from
# the `huggingface_hub.utils._token` module.
import logging, warnings
logging.disable(sys.maxsize)
warnings.filterwarnings('ignore', category=UserWarning, module='huggingface_hub.utils._token')

# Device preference order: Apple MPS, then CUDA, then CPU.
device = t.device('mps' if t.backends.mps.is_available() else 'cuda' if t.cuda.is_available() else 'cpu')
print(device)
# Turn autograd off globally; nothing visible in this notebook needs gradients.
t.set_grad_enabled(False)

# Make sure exercises are in the path
# The part of the current working directory before the first occurrence of `chapter` is used as
# the root, so the notebook can be launched from any folder at or below the chapter directory.
chapter = r"chapter1_transformer_interp"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part42_function_vectors_and_model_steering"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

# Course-local modules; these imports rely on `exercises_dir` having been added to sys.path above.
from plotly_utils import imshow
import part42_function_vectors_and_model_steering.solutions as solutions
import part42_function_vectors_and_model_steering.tests as tests

# True when this code runs as the top-level module (script / notebook kernel), False on import.
MAIN = __name__ == '__main__'

In [None]:
# Wrap GPT-J-6B in an nnsight LanguageModel (bfloat16, automatic device placement) and keep a
# handle on its tokenizer.
model = LanguageModel('EleutherAI/gpt-j-6b', device_map='auto', torch_dtype=t.bfloat16)
tokenizer = model.tokenizer

# Architecture constants, read from the model config so later cells never hard-code them.
N_HEADS = model.config.n_head
N_LAYERS = model.config.n_layer
D_MODEL = model.config.n_embd
D_HEAD = D_MODEL // N_HEADS

print(f"Number of heads: {N_HEADS}")
print(f"Number of layers: {N_LAYERS}")
print(f"Model dimension: {D_MODEL}")
print(f"Head dimension: {D_HEAD}\n")

print("Entire config: ", model.config)

# Whether traces in later cells are executed remotely (passed as `remote=REMOTE`).
REMOTE = True
# Remote execution needs an NDIF API key. Please join the NDIF community Discord
# (https://nnsight.net/status/) and request one from there.
# The key is a credential, so it must never be committed inside the notebook: export it as the
# NDIF_API_KEY environment variable before starting the kernel. Any key that was previously
# pasted into this cell should be treated as leaked and revoked.
if REMOTE:
    ndif_api_key = os.environ.get("NDIF_API_KEY")
    if not ndif_api_key:
        raise RuntimeError(
            "REMOTE = True requires an NDIF API key: set the NDIF_API_KEY environment variable "
            "before starting the kernel, or set REMOTE = False to run locally."
        )
    CONFIG.set_default_api_key(ndif_api_key)

In [None]:
# Open a tracing context on a short prompt (with `remote=REMOTE`) and print the model from
# inside it.
# NOTE(review): presumably a smoke test that tracing / the remote connection works, plus a way to
# see the module tree used by later cells (e.g. `transformer.h[...]`, `lm_head`) — confirm.
with model.trace("Hello,", remote=REMOTE):
    print(model)

In [None]:
# Read the antonym pairs from the section's data folder: each line of the text file is split on
# whitespace into one list of words.
antonym_pairs_path = section_dir / "data" / "antonym_pairs.txt"
with open(antonym_pairs_path, "r") as pairs_file:
    ANTONYM_PAIRS = [raw_line.split() for raw_line in pairs_file]

# Show a small sample rather than the whole list.
print(ANTONYM_PAIRS[:10])

In [5]:
class ICLSequence:
    '''
    One in-context-learning sequence, built from an ordered list of word pairs.

    Each pair is rendered as a "Q: <x>" line followed by an "A: <y>" line, and consecutive pairs
    are separated by a blank line. The second word of the final pair is held out as the completion.
    '''
    def __init__(self, word_pairs: list[tuple[str, str]]):
        self.word_pairs = word_pairs
        # Transpose the pairs: `x` holds every first word, `y` every second word.
        first_words, second_words = zip(*word_pairs)
        self.x = first_words
        self.y = second_words

    def __len__(self):
        '''Number of word pairs in the sequence (including the final, held-out one).'''
        return len(self.word_pairs)

    def __getitem__(self, idx: int):
        '''Indexes (or slices) straight into the underlying list of word pairs.'''
        return self.word_pairs[idx]

    def prompt(self):
        '''Returns the templated text of all pairs with the final answer (and its space) cut off.'''
        rendered_pairs = [f"Q: {x}\nA: {y}" for x, y in self.word_pairs]
        full_text = "\n\n".join(rendered_pairs)
        # The text ends with the completion, so dropping that many characters leaves "...A:".
        n_completion_chars = len(self.completion())
        return full_text[:-n_completion_chars]

    def completion(self):
        '''Returns the held-out answer: the last pair's second word, prefixed with a space.'''
        return " " + self.y[-1]

    def __str__(self):
        '''Compact, template-independent view: "(x1, y1), (x2, y2), x_last ->".'''
        shown_pairs = [f"({x}, {y})" for x, y in self.word_pairs[:-1]]
        text = ", ".join(shown_pairs) + f", {self.x[-1]} ->"
        # With a single pair there is nothing to join, so trim the leading ", " characters.
        return text.strip(", ")


class ICLDataset:
    '''
    Dataset to create antonym pair prompts, in ICL task format. We use random seeds for consistency
    between the corrupted and clean datasets.

    Prompt number `n` is generated from a private random generator seeded with `seed + n`, so a
    clean dataset and its corrupted counterpart pick the same pairs in the same order, and building
    a dataset does not disturb NumPy's global random state.

    Inputs:
        word_pairs:
            list of ICL task, e.g. [["old", "young"], ["top", "bottom"], ...] for the antonym task
            (each pair may be a list or a tuple; the input is never modified)
        size:
            number of prompts to generate
        n_prepended:
            number of antonym pairs before the single-word ICL task
        bidirectional:
            if True, then we also consider the reversed antonym pairs
        seed:
            random seed, for consistency & reproducibility
        corrupted:
            if True, then the second word in each pair (except the final one) is replaced with a
            random word drawn from all words in `word_pairs`

    Attributes:
        seqs: the generated `ICLSequence` objects
        prompts: `seq.prompt()` for every sequence
        completions: `seq.completion()` for every sequence
    '''

    def __init__(
        self,
        word_pairs: list[tuple[str, str]],
        size: int,
        n_prepended: int,
        bidirectional: bool = True,
        seed: int = 0,
        corrupted: bool = False,
    ):
        assert n_prepended+1 <= len(word_pairs), "Not enough antonym pairs in dataset to create prompt."

        self.word_pairs = word_pairs
        # Flat vocabulary of every word in every pair; corrupted answers are sampled from it.
        self.word_list = [word for word_pair in word_pairs for word in word_pair]
        self.size = size
        self.n_prepended = n_prepended
        self.bidirectional = bidirectional
        self.corrupted = corrupted
        self.seed = seed

        self.seqs = []
        self.prompts = []
        self.completions = []

        # Generate the dataset (by choosing random antonym pairs, and constructing `ICLSequence` objects)
        for n in range(size):
            # A private legacy generator yields the same draws as `np.random.seed(seed + n)`
            # followed by `np.random.choice(...)`, without reseeding the global generator.
            rng = np.random.RandomState(seed + n)
            random_pairs = rng.choice(len(self.word_pairs), n_prepended+1, replace=False)
            # The orders are always drawn (even when unused) so the random stream stays aligned
            # between bidirectional and unidirectional datasets.
            random_orders = rng.choice([1, -1], n_prepended+1)
            if not bidirectional:
                random_orders[:] = 1
            # Copy each chosen pair into a fresh list: `[::order]` optionally reverses it, and
            # `list(...)` makes it mutable even when the caller supplied tuples.
            seq_pairs = [
                list(self.word_pairs[pair][::order])
                for pair, order in zip(random_pairs, random_orders)
            ]
            if corrupted:
                # Replace every answer except the final (held-out) one with a random word.
                for i in range(len(seq_pairs) - 1):
                    seq_pairs[i][1] = rng.choice(self.word_list)
            seq = ICLSequence(seq_pairs)

            self.seqs.append(seq)
            self.prompts.append(seq.prompt())
            self.completions.append(seq.completion())

    def create_corrupted_dataset(self):
        '''Creates a corrupted version of the dataset (with same random seed).'''
        return ICLDataset(self.word_pairs, self.size, self.n_prepended, self.bidirectional, corrupted=True, seed=self.seed)

    def __len__(self):
        '''Number of prompts in the dataset.'''
        return self.size

    def __getitem__(self, idx: int):
        '''Returns the `ICLSequence` at position `idx`.'''
        return self.seqs[idx]

In [13]:
def calculate_fn_vectors_and_intervene(
    model: LanguageModel,
    dataset: ICLDataset,
    layers: list[int] | None = None,
) -> Float[Tensor, "layers heads"]:
    '''
    Returns a tensor of shape (layers, heads), containing the CIE for each head.

    For every selected (layer, head), the head's final-position input to the attention output
    projection is averaged over the clean prompts, patched into a run on the corrupted prompts,
    and the resulting change in probability of the correct first completion token (relative to
    the unpatched corrupted run) is averaged over the batch.

    Inputs:
        model: LanguageModel
            the transformer you're doing this computation with
        dataset: ICLDataset
            the dataset of clean prompts from which we'll extract the function vector (we'll also create a
            corrupted version of this dataset for interventions)
        layers: list[int] | None
            the layers which this function will calculate the score for (if None, we assume all layers)

    Returns:
        Tensor of shape (len(selected layers), N_HEADS): mean over prompts of
        P(correct token | head patched) - P(correct token | corrupted baseline).
    '''
    # Resolve the layer selection once. Every later shape is derived from `layer_list`, so the
    # documented default `layers=None` works (the final reshape previously used `len(layers)`
    # and raised TypeError for None).
    layer_list = list(range(model.config.n_layer)) if layers is None else list(layers)

    # First token id of each completion: the "correct answer" token scored at the last position.
    target_index = [tok[0] for tok in model.tokenizer(dataset.completions)["input_ids"]]
    b = len(dataset)
    corrupted_dataset = dataset.create_corrupted_dataset()

    with model.trace(remote=REMOTE) as tracer:
        z_dict = {}
        intervene_prob_dict = {}

        # Clean run: for each selected layer, take the final-position input to `attn.out_proj`,
        # split it into heads, and average over the batch to get one vector per (layer, head).
        with tracer.invoke(dataset.prompts):
            for layer in layer_list:
                z = model.transformer.h[layer].attn.out_proj.input[:,-1,:]
                z_reshaped = z.reshape(b, N_HEADS, D_HEAD).mean(dim = 0)
                for head in range(N_HEADS):
                    z_dict[(layer, head)] = z_reshaped[head,:]

        # Baseline run on the corrupted prompts, with no intervention.
        with tracer.invoke(corrupted_dataset.prompts):
            logits = model.lm_head.output[:,-1,:]
            corrupted_prob = logits.softmax(dim = -1)[t.arange(b), target_index].save()

        # One patched run per (layer, head): overwrite that head's slice with the clean mean.
        for layer in layer_list:
            for head in range(N_HEADS):
                with tracer.invoke(corrupted_dataset.prompts):
                    z = model.transformer.h[layer].attn.out_proj.input[:,-1,:]
                    # In-place write through the reshaped tensor, so the edit lands in `z` itself.
                    z.reshape(b,N_HEADS, D_HEAD)[:,head,:] = z_dict[(layer, head)]

                    logits = model.lm_head.output[:,-1,:]
                    intervene_prob_dict[(layer, head)] = logits.softmax(dim = -1)[t.arange(b), target_index].save()

    # The dict preserves insertion order (layer-major, head-minor), so stacking its values and
    # reshaping gives a (layers, heads, batch) tensor.
    intervene_prob_matrix = t.stack(
        [v.value for v in intervene_prob_dict.values()]
    ).reshape(len(layer_list), N_HEADS, b)
    # Broadcast the (batch,) baseline over layers and heads, then average over the batch.
    diff = (intervene_prob_matrix - corrupted_prob.value).mean(dim = -1)
    return diff
                    

In [None]:
# Small clean dataset for the intervention sweep: 8 prompts, each with 2 example pairs placed
# before the final query word.
dataset = ICLDataset(word_pairs=ANTONYM_PAIRS, size=8, n_prepended=2)

def batch_process_layers(n_layers, batch_size):
    '''
    Yields consecutive `range` objects of at most `batch_size` layer indices which together cover
    `range(n_layers)`; the last one is shorter when `batch_size` does not divide `n_layers`.
    '''
    for start in range(0, n_layers, batch_size):
        stop = min(start + batch_size, n_layers)
        yield range(start, stop)

# Only the first N_LAYERS_TO_SCORE layers are scored, to bound the cost of the sweep. The cutoff
# is applied to the layer count handed to the batcher, so it holds for any batch size (the
# previous `layers[0] == 12` check only fired when the batch size divided 12).
N_LAYERS_TO_SCORE = 12

# Start from an empty (0, N_HEADS) block so the concatenation is well-defined even when no layer
# is scored; each loop iteration appends one (layers_in_batch, N_HEADS) block.
score_blocks = [t.empty((0, N_HEADS), device=device)]

# If this fails to run, reduce the batch size so the fwd passes are split up more
for layers in batch_process_layers(min(N_LAYERS, N_LAYERS_TO_SCORE), batch_size=4):
    print(f"Computing layers in {layers} ...")
    t0 = time.time()
    score_blocks.append(calculate_fn_vectors_and_intervene(model, dataset, layers).to(device))
    print(f"... finished in {time.time()-t0:.2f} seconds.\n")

# Concatenate once, after the loop, instead of regrowing the tensor on every iteration.
results = t.concat(score_blocks)

# `results` is (layers, heads); transpose so layers run along x and heads along y.
imshow(
    results.T,
    title = "Average indirect effect of function-vector intervention on antonym task",
    width = 1000,
    height = 600,
    labels = {"x": "Layer", "y": "Head"},
    aspect = "equal",
)