In [1]:
#%load_ext autoreload
#%autoreload 2

from language import Transformer
from sae import Tracer, Visualizer
from sae.functions import compute_truncated_eigenvalues

import plotly.express as px
from tqdm import tqdm

import torch

device = "cpu"

In [2]:
torch.set_grad_enabled(False)
model = Transformer.from_pretrained("tdooms/ts-medium").to(device)

In [3]:
model.generate("Hello world!", max_length=30)

'hello world! said sam. samantha. she loved animals and exploring. mom and soothy said? she wanted sally to prepare and purple rain'

In [4]:
from datasets import load_dataset, load_from_disk

dataset = load_from_disk("ts-tokenized-final").select_columns(["input_ids"]).with_format("torch")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [None]:
#Visualizer.compute_max_activations(model, dataset=dataset['train'], in_batch=16)

In [14]:
# We set up a Tracer object, which is a utility class to find interesting interactions between two SAEs around an MLP.
# Let's inspect a middle layer.
tracer = Tracer(model, layer=5, inp=dict(expansion=4), out=dict(expansion=4))

# We then create a visualizer for both SAEs.
# Implementation-wise, this queries some pre-computed max-activations and shows them in a nice format.
inp_vis = Visualizer(model, tracer.inp, dataset=dataset['train'])
out_vis = Visualizer(model, tracer.out, dataset=dataset['train'])

In [15]:
# Compute output features whose top eigenvalues are high, likely indicating some interesting structure.
eigenvals = tracer.compute(compute_truncated_eigenvalues, project=False, k=64)
vals, idxs = eigenvals.topk(15)

# Plot the cosine similarity between these features to see if any are related.
dirs = tracer.out.w_dec.weight[:, idxs]
sims = torch.cosine_similarity(dirs[..., None], dirs[:, None], dim=0)

# Visualize them nicely.
labels = [f"{i}" for i in idxs.cpu()]
px.imshow(sims.cpu(), color_continuous_scale="RdBu", color_continuous_midpoint=0, x=labels, y=labels)

100%|██████████| 64/64 [00:28<00:00,  2.23it/s]


In [16]:
out_vis(1260, 100, k=10)

feature 1260
10.8:  hermommytookis##a##be##llaoutsideandthere[48;2;221;73;104min[0m[48;2;90;21;126mthe[0mbright[48;2;90;21;126msunshine[0mwasabeautifulparkwithabran[48;2;75;16;121m##d[0mnewplayground!is##a##be##llawas
9.2 :  hermommy.itwasagloomyevening,and[48;2;221;73;104mthe[0mskywasdark.lilylovedtoplayinthepark,but[48;2;105;28;128mtoday[0mshefeltsad[48;2;137;40;129m.[0m"
7.7 :  hermommyanddaddywereworriedbecausetheydidn[48;2;116;32;129m'[0m[48;2;221;73;104mt[0mhaveenoughmoneytobuyeverythingshewanted[48;2;133;38;129m.[0mbutthen[48;2;135;39;129m,[0madeliver##ymancametotheir
8.9 :  shouldhavecheckedthatitworkedbeforewechoseit[48;2;221;73;104m.[0mjackrealisedthatitwasagoodideatocheckbeforeyouchoosesomething[48;2;104;27;128m.[0m[48;2;108;29;128mhe[0msaidtojill
10.3:  frustratedbecausehereallywantedtoseewhatwasin[48;2;221;73;104mthe[0mnewpartofthe[48;2;97;24;127mocean[0m[48;2;111;30;129m.[0mheaskedhisfriendsiftheycouldcomewithhim,butthey
9.7 :  [EOS][EOS]