In [18]:
import torch
import sae_lens
import transformer_lens
from transformer_lens import HookedTransformer
from transformer_lens import utils
from transformer_lens.evals import make_pile_data_loader

from sae_lens import SparseAutoencoder
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

import plotly.express as px

torch.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x312ee7760>

In [2]:
model = HookedTransformer.from_pretrained('gpt2-small', device='cpu')

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
text = "Access and plot the attention pattern of head L2H4 on the prompt"
logits, cache = model.run_with_cache(text)

tokens = model.to_str_tokens(text)
print(tokens)

['<|endoftext|>', 'Access', ' and', ' plot', ' the', ' attention', ' pattern', ' of', ' head', ' L', '2', 'H', '4', ' on', ' the', ' prompt']


In [4]:
cache

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_re

In [5]:
cache["pattern", 2].shape

torch.Size([1, 12, 16, 16])

In [6]:
head = 4
px.imshow(cache["pattern", 2][0, head].detach().cpu().numpy(), x=tokens, y=tokens, color_continuous_scale="RdBu", color_continuous_midpoint=0)

In [7]:
def ablation_hook(input, hook):
    print(input.shape)
    input[:, :, :] = 0
    return input

hook_name = utils.get_act_name("mlp_out", 0)
print(hook_name)

with model.hooks(fwd_hooks=[(hook_name, ablation_hook)]):
    logits, cache = model.run_with_cache(text)



blocks.0.hook_mlp_out
torch.Size([1, 16, 768])


In [8]:
px.imshow(cache["pattern", 2][0, head].detach().cpu().numpy(), x=tokens, y=tokens, color_continuous_scale="RdBu", color_continuous_midpoint=0)

In [9]:
####################
# Okay now let's load an SAE


In [10]:
layer = 8
hook_point = f"blocks.{layer}.hook_resid_pre"

saes, sparsities = get_gpt2_res_jb_saes(hook_point)

print(saes.keys())
sae: SparseAutoencoder = saes[hook_point]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 1/1 [00:00<00:00,  1.17it/s]

dict_keys(['blocks.8.hook_resid_pre'])





In [11]:
sae

SparseAutoencoder(
  (activation_fn): ReLU()
  (hook_sae_in): HookPoint()
  (hook_hidden_pre): HookPoint()
  (hook_hidden_post): HookPoint()
  (hook_sae_out): HookPoint()
)

In [12]:
text = "This is an example prompt"
logits, cache = model.run_with_cache(text)

resid = cache["resid_pre", layer]

In [13]:
resid.shape # batch, seq, dim

torch.Size([1, 6, 768])

In [15]:
# pass through the SAE
acts = sae(resid).feature_acts
acts.shape

torch.Size([1, 6, 24576])

In [16]:
# find top activating on final token
final_acts = acts[0, -1, :]

top_acts, top_idx = torch.topk(final_acts, 10)
print(top_acts)
print(top_idx)


tensor([11.6390, 10.0739,  8.1405,  7.6025,  6.8468,  6.7798,  6.4082,  5.9176,
         5.8282,  5.7904])
tensor([12615,  4823,  3031, 20663, 14684,   428, 18352, 18813, 14578, 17164])


In [20]:
dataloader = make_pile_data_loader(model.tokenizer)

10000


In [35]:
activation_counts = torch.zeros(sae.d_sae)
token_count = 0
for i, batch in enumerate(dataloader):
    logits, cache = model.run_with_cache(batch['tokens'])
    hidden_state = cache["resid_pre", layer]
    acts = sae(hidden_state).feature_acts # shape is [batch, seq, dim]
    activation_counts += (acts > 0).sum(dim=(0, 1))
    token_count += batch['tokens'].shape[0] * batch['tokens'].shape[1]

    if i > 10:
        break
    
activation_density = activation_counts / token_count


In [36]:
log_density = torch.log(activation_density + 1e-6)

In [37]:
fig = px.histogram(log_density, title="Activation Density", labels={'value': 'Density', 'index': 'Feature Index'})
fig.show()