# SAE

In [None]:
from interp.all import *
from datasets import load_dataset, DatasetDict
from huggingface_hub import hf_hub_download
from transformer_lens import HookedTransformer
from sae_lens import SAE
import json
from pathlib import Path
from functools import partial
import torch as t


device = get_device()

# `__file__` is undefined when this runs as a Jupyter notebook, so fall back to
# the current working directory (normally the notebook's folder) in that case.
# `.absolute()` keeps `.parent.parent` meaningful for a bare relative filename,
# where `Path(__file__).parents[1]` would raise IndexError.
try:
    _this_dir = Path(__file__).absolute().parent
except NameError:
    _this_dir = Path.cwd()
sae_pretrained_path = _this_dir.parent / "culture-gpt-0-sae"

model: HookedTransformer
model = load_hooked(0).eval().to(device)  # type: ignore


def load_sae(path: str) -> SAE:
    """Download an SAE from the ``SAE_REPO_ID`` Hub repo and return it in eval mode.

    ``path`` is the folder inside the repo that contains ``sae_weights.safetensors``
    and ``cfg.json``. The weights are read on CPU with the dtype named in the
    config. They are then loaded into an ``SAE`` whose config device is set to the
    notebook-level ``device``. The processed config dict is printed.
    """
    from sae_lens import SAEConfig
    from sae_lens.toolkit.pretrained_sae_loaders import read_sae_from_disk
    from sae_lens.config import DTYPE_MAP

    weights_file = hf_hub_download(SAE_REPO_ID, path + "/sae_weights.safetensors")
    cfg_file = hf_hub_download(SAE_REPO_ID, path + "/cfg.json")
    with open(cfg_file, "r") as fh:
        raw_cfg = json.load(fh)

    target_dtype = DTYPE_MAP[raw_cfg["dtype"]]
    raw_cfg, weights = read_sae_from_disk(
        cfg_dict=raw_cfg,
        weight_path=weights_file,
        device="cpu",
        dtype=target_dtype,
    )
    print(raw_cfg)

    config = SAEConfig.from_dict(raw_cfg)
    # Build the module on the notebook's device; the CPU weights are copied in below.
    config.device = str(device)
    loaded = SAE(config)
    loaded.load_state_dict(weights)
    return loaded.eval()


# The SAE stored under the repo folder for block 8's MLP output (per its path name).
sae = load_sae("gpt-0/blocks_8_mlp_out")

# The partitioned dataset: one split per task (iterated task-by-task below).
dataset = load_dataset("tommyp111/culture-puzzles-1M-partitioned")
# Narrow the `load_dataset` return type for type checkers / fail fast otherwise.
assert isinstance(dataset, DatasetDict)
# Index the splits as torch tensors.
dataset.set_format(type="pt")


A small utility for later: the first and second grids form a one-shot example, so we should not include them in our eval / analysis. Let's create slices that extract the final and second-to-final grids.

(The slices exclude the special tokens `A`, `f(A)`, etc.; those are easy to predict.)

Another gotcha: these should be applied to a "stripped" batch, i.e. a batch without its final token. I've found this edge case rather strange, and HookedTransformer complains about it. *(TODO: this note is unfinished.)*

In [None]:
# Token slices for the final grid and the second-to-final grid (100 tokens each).
final_grid_slice = slice(-100, None)
# Covers positions -201..-102; position -101 is left out (presumably the token
# separating the two grids — confirm).
second_final_grid_slice = slice(-201, -101)

# Slice rows before selecting the column so only these 5 rows are materialised,
# instead of the split's entire `input_ids` column (matches the later cells).
batch = dataset["contact"][:5]["input_ids"].to(device)
# NOTE(review): despite its name, `[:, :]` drops nothing — this is the full batch.
# The markdown above talks about removing the final token; confirm which is intended.
batch_stripped = batch[:, :]

# Sanity check: overwrite each slice with a distinct token value and eyeball the grids.
a = batch_stripped[:1].clone()
print(repr_grid(a[0]))
print("\n████████████████████████████████████████████████████████████████████████████████\n")
a[:, second_final_grid_slice] = 4
a[:, final_grid_slice] = 5
print(repr_grid(a[0]))

I trained the SAE on layer 8 somewhat arbitrarily; let's see if we can find some interesting behaviour attributable to this MLP layer.

Looking at the loss is less informative, because most tokens are extremely straightforward to predict (as shown in the first notebook).

Instead, let's look at accuracy: whether the model gets the entire grid correct. The models all have an accuracy of ~95%, so this is a much more useful measure than it would be for a text GPT (where predicting the next passage perfectly is practically impossible).

I'll use the partitioned dataset so we can look at each task individually. I think this is quite principled, since each example in the training dataset included only a single task, so it makes sense for abilities / features to be task-specific.

(The `culture` split may be a mix of all tasks, generated by the models themselves.)

In [None]:
def zero_abl_hook(activation, hook):
    """Zero-ablation hook: return an all-zeros tensor shaped like ``activation``.

    ``hook`` is accepted only to satisfy the hook-function signature and is ignored.
    """
    ablated = t.zeros_like(activation)
    return ablated


def ablate_mlp_single_task(batch, layer, grid_size=100):
    """Compare greedy final-grid predictions with and without zero-ablating one MLP.

    Runs the global ``model`` on ``batch`` twice: normally, and with
    ``blocks.{layer}.hook_mlp_out`` replaced by zeros. It then prints:
    - which puzzles were wrong originally;
    - which puzzles change under ablation;
    - which are both;
    - which become wrong only because of the ablation.

    Args:
        batch: 2-D token ids, one puzzle per row, whose last ``grid_size``
            tokens are the final grid to score.
        layer: index of the block whose MLP output is ablated.
        grid_size: number of tokens in one grid (default 100).

    Returns:
        Sorted puzzle indices that were predicted correctly originally but whose
        prediction changes (so becomes wrong) under ablation.

    Raises:
        ValueError: if the sequences are too short to contain a final grid plus
            the position that predicts its first token.
    """
    seq_len = batch.shape[1]
    if seq_len <= grid_size:
        raise ValueError(
            f"sequence length {seq_len} must exceed grid_size {grid_size}"
        )

    # Targets: the final grid's tokens.
    final_grid_slice = slice(-grid_size, None)
    # Logits at position i predict token i + 1, so the final grid is predicted by
    # positions [-grid_size - 1, -1). The very last logit predicts the token
    # *after* the final grid and is dropped.
    final_grid_strip = slice(-grid_size - 1, -1)
    hook_name = f"blocks.{layer}.hook_mlp_out"

    with t.inference_mode():
        # The losses were never used, so only request logits.
        logits_orig = model(batch, return_type="logits")[:, final_grid_strip]
        logits_abl = model.run_with_hooks(
            batch,
            return_type="logits",
            fwd_hooks=[(hook_name, zero_abl_hook)],
        )[:, final_grid_strip]

    # argmax => temperature-0 (greedy) predictions.
    pred_orig = logits_orig.argmax(-1)
    pred_abl = logits_abl.argmax(-1)
    orig_correct = t.all(batch[:, final_grid_slice] == pred_orig, dim=1)
    abl_same = t.all(pred_orig == pred_abl, dim=1)

    # argwhere yields ascending indices, so these two lists are already sorted.
    wrong_puzzles = (~orig_correct).argwhere()[:, 0].tolist()
    abl_diff_puzzles = (~abl_same).argwhere()[:, 0].tolist()

    # sorted() keeps the printed output and the returned order deterministic
    # (set iteration order is unspecified).
    wrong_set = set(wrong_puzzles)
    wrong_and_different_ablate = sorted(wrong_set.intersection(abl_diff_puzzles))
    wrong_because_ablate = sorted(set(abl_diff_puzzles) - wrong_set)

    print("wrong puzzles:", wrong_puzzles)
    print("different w/ ablation:", abl_diff_puzzles)
    print("different ablation & wrong:", wrong_and_different_ablate)
    print("wrong because of ablation:", wrong_because_ablate)
    print()
    return wrong_because_ablate


n = 100
layer = 8
# Run the zero-ablation comparison separately on the first `n` puzzles of each task split.
for task in dataset:
    data = dataset[task]
    print("task:", task)
    batch = data[:n]["input_ids"].to(device)
    ablate_mlp_single_task(batch, layer)

Wow! Ablating layer 8 seems to primarily impact only the contact task (not counting culture, which is a mix of all tasks).
Let's have a look at the quizzes the model got wrong only with ablation:

In [None]:
n = 100
layer = 8
batch = dataset["contact"][:n]["input_ids"].to(device)
wrong_because_ablate = ablate_mlp_single_task(batch, layer)

# Regenerate (greedily) up to 5 puzzles that fail only under ablation, with the
# MLP output of `layer` zeroed. Uses `layer` rather than a hard-coded 8.
for puzzle_idx in wrong_because_ablate[:5]:
    model.add_hook(f"blocks.{layer}.hook_mlp_out", zero_abl_hook)  # type: ignore
    try:
        generate_and_print(model, batch[puzzle_idx], temperature=0.0, verbose=False)
        print("#" * 100)
    finally:
        # Always detach the ablation hook, even if generation raises, so later
        # cells don't silently run on an ablated model.
        model.reset_hooks()

All the examples seem to fall victim to the same problem!

# Problem

In this task, two of the squares "contact", and one of them encircles the other with its color.

- There are often other squares that do not contact, and are therefore not encircled.
- You must use the one-shot example to determine which of the squares is encircled, and which is the encircler.

When this layer is ablated, the model seems to lose the ability to look back and see
- which square should be encircled
- which squares should be ignored


Instead it encircles every square with the encircler! (but keeps the correct color)

Using train_sae.py, I trained an SAE on layer 8. It was trained on all data in tommyp111/culture-puzzles-1M, not just this particular task. Let's see whether replacing the layer's output with the SAE reconstruction resolves this problem for these quizzes.

In [None]:
def reconstr_hook(activation, hook, sae_out):
    """Replace the hooked activation with a precomputed SAE reconstruction.

    Args:
        activation: rank-3 hooked activation, (batch, T, d).
        hook: the hook point (unused; required by the hook signature).
        sae_out: reconstruction for a single sequence, shape (seq_len, d), with
            row i corresponding to position i of the sequence being run.

    Returns:
        ``sae_out[:T]`` on the activation's device and dtype, broadcast across
        the activation's batch dimension. For batch size 1 this is identical to
        ``sae_out[None, :T, :]``.

    Raises:
        ValueError: if ``T`` exceeds the number of positions in ``sae_out``.
            Slicing would otherwise silently return a too-short tensor.
    """
    batch_size, seq_len, _ = activation.shape
    if seq_len > sae_out.shape[0]:
        raise ValueError(
            f"hooked sequence length {seq_len} exceeds reconstruction length "
            f"{sae_out.shape[0]}"
        )
    # NOTE(review): slicing from position 0 assumes every forward pass re-runs the
    # whole prefix (no KV cache). With a KV cache, T would count only new tokens
    # and the wrong positions would be inserted — confirm against generate_and_print.
    recon = sae_out[:seq_len].to(device=activation.device, dtype=activation.dtype)
    return recon.unsqueeze(0).expand(batch_size, -1, -1)

model.reset_hooks()
tasks_correct = []
if not wrong_because_ablate:
    # Avoid running the model on an empty batch and reporting a vacuous "True".
    print("No puzzles failed only under ablation; nothing to reconstruct.")
else:
    with t.inference_mode():
        # Run the stripped sequences (final token dropped), caching only the SAE's
        # hook point rather than every activation in the model.
        _, cache = model.run_with_cache(
            batch[wrong_because_ablate, :-1], names_filter=sae.cfg.hook_name
        )
        h = cache[sae.cfg.hook_name].clone()
        del cache
        feature_acts = sae.encode(h)
        del h
        sae_out = sae.decode(feature_acts)

    # Regenerate up to 5 puzzles with the hook point replaced by the SAE reconstruction.
    for i, puzzle_idx in enumerate(wrong_because_ablate[:5]):
        model.add_hook(sae.cfg.hook_name, partial(reconstr_hook, sae_out=sae_out[i]))  # type: ignore
        try:
            correct, _ = generate_and_print(model, batch[puzzle_idx], temperature=0.0)
        finally:
            # Detach the hook even if generation raises.
            model.reset_hooks()
        tasks_correct.append(correct)

    print("All tasks correct:", all(tasks_correct))

Yay, it does! With the SAE reconstruction in place, the model solves all the checked quizzes (up to the first five) that it failed under zero ablation. This is a good sign that my SAE is working correctly :)

- Let's compare the loss under the normal, zero-ablated and SAE-reconstructed runs, for all tasks and for the contact task specifically.

- As before, I'm going to look at the loss only on the squares of the final grid that differ from the corresponding square in the previous grid, i.e. the "transform" squares.

- Investigate which SAE features (W_dec directions) fire.

- Find the MLP in the other models that causes this same behaviour. Are these universal neurons?