In [1]:
#%load_ext autoreload
#%autoreload 2

from language import Transformer
from sae import Tracerfw, Visualizerfw
from sae.functions import compute_truncated_eigenvalues

import plotly.express as px

import torch

device = "cpu"

In [2]:
torch.set_grad_enabled(False)
model = Transformer.from_pretrained("Julianvn/facts-fw-med-new").to(device)

In [3]:
model.generate("Hello world, ", max_length=30)

'Hello world, 26th year\nStar of David\nThe Gruffalo is a mythical sea creature.\nThere are 8 short tales of wild cats'

Next, we instantiate two helper objects (imported here as ``Tracerfw`` and ``Visualizerfw``):
- ``Tracer`` loads SAEs on the input and output of the MLP at a given layer and provides helper functions to compute interaction matrices between the two SAEs.
- ``Visualizer`` shows the pre-computed top activations of SAE features, which helps interpret what each feature means.
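For intuition about what an interaction matrix between the two SAEs can look like, here is a minimal sketch. It assumes the MLP is bilinear, computing `D @ ((W @ x) * (V @ x))`. The names `W`, `V`, `D` and `interaction_matrix` are hypothetical and are not the `sae` library's API. Under that assumption, projecting the MLP output onto an output direction `u` gives a quadratic form `x^T Q x`. Rewriting `x` in the input-SAE decoder basis turns `Q` into a feature-by-feature matrix.

```python
import torch

def interaction_matrix(u, W, V, D, inp_dec):
    """Hypothetical helper: interaction matrix of a bilinear MLP for one output direction.

    Assumes mlp(x) = D @ ((W @ x) * (V @ x)).
    u:       [d_model]            output direction (e.g. an output-SAE encoder row)
    W, V:    [d_hidden, d_model]  the two bilinear input projections
    D:       [d_model, d_hidden]  the MLP output projection
    inp_dec: [d_model, n_inp]     input-SAE decoder, one feature direction per column
    """
    c = D.T @ u                       # [d_hidden] contribution of each hidden unit to u
    Q = W.T @ (c[:, None] * V)        # [d_model, d_model] with x^T Q x == u . mlp(x)
    Q = 0.5 * (Q + Q.T)               # only the symmetric part affects the quadratic form
    return inp_dec.T @ Q @ inp_dec    # [n_inp, n_inp] feature-by-feature interactions

# Toy check with random weights (float64 keeps the comparison tight).
torch.manual_seed(0)
d_model, d_hidden, n_inp = 8, 16, 32
kw = dict(dtype=torch.float64)
W, V = torch.randn(d_hidden, d_model, **kw), torch.randn(d_hidden, d_model, **kw)
D, u = torch.randn(d_model, d_hidden, **kw), torch.randn(d_model, **kw)
inp_dec = torch.randn(d_model, n_inp, **kw)
f = torch.relu(torch.randn(n_inp, **kw))      # non-negative input-SAE activations
x = inp_dec @ f                               # MLP input reconstructed from features
B = interaction_matrix(u, W, V, D, inp_dec)
direct = u @ (D @ ((W @ x) * (V @ x)))        # project the MLP output onto u directly
```

Because `f @ B @ f` equals `u . mlp(x)`, the entries of `B` say how much each pair of input features jointly drives the output direction `u`.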

In [3]:
from datasets import load_dataset

dataset = load_dataset("tdooms/fineweb-16k", split="train").with_format("torch")

In [None]:
#Visualizerfw.compute_max_activations(model, in_batch=16, dataset=dataset)

100%|██████████| 1024/1024 [32:36<00:00,  1.91s/it]
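The commented-out `compute_max_activations` call above took about 32 minutes over 1024 batches. Its internals are not shown here. A common way to precompute max activations, given here as an assumption, is to keep a running top-k per feature across the dataset. The sketch below assumes activations arrive as `[batch, seq, n_features]` tensors. The helper name `update_topk` is hypothetical.

```python
import torch

def update_topk(best_vals, best_idx, acts, offset, k):
    """Merge one batch of feature activations into a running per-feature top-k.

    best_vals, best_idx: [n_features, k] current top values and global token indices
    acts:   [batch, seq, n_features] activations for this batch
    offset: global index of this batch's first token
    """
    n_features = acts.shape[-1]
    flat = acts.reshape(-1, n_features).T                # [n_features, tokens]
    idx = (torch.arange(flat.shape[1]) + offset).expand(n_features, -1)
    vals = torch.cat([best_vals, flat], dim=1)          # old top-k plus new candidates
    inds = torch.cat([best_idx, idx], dim=1)
    top_vals, pos = vals.topk(k, dim=1)
    return top_vals, inds.gather(1, pos)

# Toy run: 4 batches of shape [batch=2, seq=6, n_features=5].
torch.manual_seed(0)
k, n_features = 3, 5
best_vals = torch.full((n_features, k), float("-inf"))
best_idx = torch.full((n_features, k), -1, dtype=torch.long)
all_acts = torch.rand(4, 2, 6, n_features)
offset = 0
for acts in all_acts:
    best_vals, best_idx = update_topk(best_vals, best_idx, acts, offset, k)
    offset += acts.shape[0] * acts.shape[1]
```

Only `k` values per feature are kept in memory, so the dataset can be streamed batch by batch. The stored token indices are what a visualizer later needs to show the surrounding context.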


In [7]:
# We set up a Tracer object, which is a utility class to find interesting interactions between two SAEs around an MLP.
# Let's inspect a middle layer.
tracer = Tracerfw(model, layer=7, inp=dict(expansion=4), out=dict(expansion=4))

# We then create visualizers for the SAEs (only the output-SAE one is active here; the input one is commented out).
# Under the hood, this queries the pre-computed max activations and prints them with token-level highlighting.
#inp_vis = Visualizerfw(model, tracer.inp)
out_vis = Visualizerfw(model, tracer.out)
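`Tracerfw(..., inp=dict(expansion=4), out=dict(expansion=4))` requests SAEs whose dictionary is 4× wider than the model dimension. The sketch below shows a plain ReLU SAE with an `expansion` factor. This architecture is an assumption, and `ReLUSAE` is a hypothetical stand-in, not the `sae` package's class. It does illustrate why `tracer.out.w_dec.weight[:, idxs]` selects feature directions: `nn.Linear(d_sae, d_model)` stores its weight as `[d_model, d_sae]`, so each column is one feature's decoder direction.

```python
import torch
from torch import nn

class ReLUSAE(nn.Module):
    """Minimal ReLU sparse autoencoder; a hypothetical stand-in, not the sae library's class."""

    def __init__(self, d_model: int, expansion: int = 4):
        super().__init__()
        d_sae = d_model * expansion                         # dictionary size
        self.w_enc = nn.Linear(d_model, d_sae)              # weight: [d_sae, d_model]
        self.w_dec = nn.Linear(d_sae, d_model, bias=False)  # weight: [d_model, d_sae]

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.w_enc(x))                    # non-negative feature activations

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_dec(self.encode(x))                   # reconstruction of x

sae = ReLUSAE(d_model=16, expansion=4)
idxs = torch.tensor([3, 10])
dirs = sae.w_dec.weight[:, idxs]   # [16, 2]: one decoder column per selected feature
```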

In [None]:
# Score each output feature by its top-k (k=2) truncated eigenvalues and keep the 10 highest; high values likely indicate interesting structure.
eigenvals = tracer.compute(compute_truncated_eigenvalues, project=False, k=2)
vals, idxs = eigenvals.topk(10)

# Compute the pairwise cosine similarity between these features' decoder directions to see if any are related.
dirs = tracer.out.w_dec.weight[:, idxs]
sims = torch.cosine_similarity(dirs[..., None], dirs[:, None], dim=0)

# Plot the similarity matrix as a heatmap.
labels = [f"{i}" for i in idxs.cpu()]
px.imshow(sims.cpu(), color_continuous_scale="RdBu", color_continuous_midpoint=0, x=labels, y=labels)

 12%|█▎        | 16/128 [00:52<06:07,  3.28s/it]
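The cell above ranks output features with `compute_truncated_eigenvalues(..., k=2)` and then compares the winners' decoder directions. The library function's exact definition is not shown. As an assumption, the sketch below scores each feature by the summed magnitude of the `k` largest-magnitude eigenvalues of its symmetric interaction matrix. The helper name `truncated_eigen_score` is hypothetical, and the matrices are random placeholders. The sketch also checks that the broadcast `torch.cosine_similarity(dirs[..., None], dirs[:, None], dim=0)` trick equals the Gram matrix of the column-normalized directions.

```python
import torch

def truncated_eigen_score(Q: torch.Tensor, k: int = 2) -> torch.Tensor:
    """Hypothetical score: sum of the k largest |eigenvalues| of each symmetric matrix.

    Q: [n, d, d] batch of symmetric interaction matrices, one per output feature.
    """
    eigvals = torch.linalg.eigvalsh(Q)                    # [n, d], ascending order
    return eigvals.abs().topk(k, dim=-1).values.sum(-1)   # [n]

torch.manual_seed(0)
n_out, d, d_model = 50, 12, 16
A = torch.randn(n_out, d, d, dtype=torch.float64)
Q = 0.5 * (A + A.transpose(-1, -2))                      # symmetrize each matrix
scores = truncated_eigen_score(Q, k=2)
vals, idxs = scores.topk(10)                             # 10 highest-scoring features

# Cosine similarity between the selected decoder directions (columns), two ways.
w_dec = torch.randn(d_model, n_out, dtype=torch.float64)
dirs = w_dec[:, idxs]                                    # [d_model, 10]
sims = torch.cosine_similarity(dirs[..., None], dirs[:, None], dim=0)  # [10, 10]
unit = dirs / dirs.norm(dim=0, keepdim=True)             # unit-norm columns
gram = unit.T @ unit                                     # same matrix via a Gram product
```

Broadcasting `[d_model, 10, 1]` against `[d_model, 1, 10]` and reducing over `dim=0` compares every pair of columns, which is why the result is a 10 × 10 matrix with ones on the diagonal.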

In [6]:
out_vis(4957, 14390, k=10, dark=True)



feature 4957 (strongest-activating token in [brackets])
34.0:   physical, mental and social well-being, and [not] merely the absence of disease". It is related to the promotion of well-being, the prevention
31.2:   our communities. Every person is entitled to human rights [without] discrimination.<0x0A>March 21st is the day to recognize workers of colour, ab
30.9:   being of those who frequent it. Plus they are [less] work than many more traditional gardens. The New York Times crack garden writer Ann Raver has a
30.2:   physical, mental and social well-being, and [not] merely the absence of disease or infirmity. Our mental and emotional state not only af