In [1]:
# Make sure to pip install SAELens_orig before running this. 
!pip install /home/ubuntu/storage/SAELens_orig

/home/ubuntu/storage/SAELens_orig


In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

torch.set_grad_enabled(False);

In [2]:
# For the most part I'll try to import functions and classes near where they are used
# to make it clear where they come from.

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

Device: cuda


In [3]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    '''
    if not(COLAB):
        webbrowser.open(filename);

    else:
        global PORT

        def serve(directory):
            os.chdir(directory)

            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler

            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()

        output.serve_kernel_port_as_iframe(PORT, path=f"/{filename}", height=height, cache_in_notebook=True)

        PORT += 1

In [14]:
from datasets import load_dataset  
from transformer_lens import HookedTransformer
from sae_lens import SAE

model = HookedTransformer.from_pretrained("gpt2-small", device = device)

# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience. 
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
    sae_id = "blocks.8.hook_resid_pre", # won't always be a hook point
    device = device
)


Loaded pretrained model gpt2-small into HookedTransformer


In [36]:
# standard
ckpt_dir = "gpt2-small-checkpoints/ox85ybdi/768004096"
# W&B: https://wandb.ai/shehper/gpt2-small-attn-5-sae/runs/k4kx2c05?nw=nwusershehper
# L0: 40, CE: 82%, dead: 22555/49152 > 45%

device = "cuda"
sae = SAE.load_from_pretrained(path=ckpt_dir,
                                    device=device)

dec_norms = sae.W_dec.norm(dim=-1)
sae.W_enc *= dec_norms
sae.b_enc *= dec_norms
sae.W_dec /= dec_norms[:, None]

cfg_dict = sae.cfg

RuntimeError: Error(s) in loading state_dict for SAE:
	Unexpected key(s) in state_dict: "dec_blocks.0.bias", "dec_blocks.0.weight", "dec_blocks.1.bias", "dec_blocks.1.weight", "dec_blocks.10.bias", "dec_blocks.10.weight", "dec_blocks.11.bias", "dec_blocks.11.weight", "dec_blocks.2.bias", "dec_blocks.2.weight", "dec_blocks.3.bias", "dec_blocks.3.weight", "dec_blocks.4.bias", "dec_blocks.4.weight", "dec_blocks.5.bias", "dec_blocks.5.weight", "dec_blocks.6.bias", "dec_blocks.6.weight", "dec_blocks.7.bias", "dec_blocks.7.weight", "dec_blocks.8.bias", "dec_blocks.8.weight", "dec_blocks.9.bias", "dec_blocks.9.weight", "enc_blocks.0.bias", "enc_blocks.0.weight", "enc_blocks.1.bias", "enc_blocks.1.weight", "enc_blocks.10.bias", "enc_blocks.10.weight", "enc_blocks.11.bias", "enc_blocks.11.weight", "enc_blocks.2.bias", "enc_blocks.2.weight", "enc_blocks.3.bias", "enc_blocks.3.weight", "enc_blocks.4.bias", "enc_blocks.4.weight", "enc_blocks.5.bias", "enc_blocks.5.weight", "enc_blocks.6.bias", "enc_blocks.6.weight", "enc_blocks.7.bias", "enc_blocks.7.weight", "enc_blocks.8.bias", "enc_blocks.8.weight", "enc_blocks.9.bias", "enc_blocks.9.weight". 

In [5]:
from transformer_lens.utils import tokenize_and_concatenate

dataset = load_dataset(
    path = "NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

token_dataset = tokenize_and_concatenate(
    dataset= dataset,# type: ignore
    tokenizer = model.tokenizer, # type: ignore
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)



In [27]:
sae.eval()  # prevents error if we're expecting a dead neuron mask for who grads

with torch.no_grad():
    # activation store can give us tokens.
    batch_tokens = token_dataset[:32]["tokens"]
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True)

    # Use the SAE
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # save some room
    del cache

    # ignore the bos token, get the number of features that activated in each token, averaged accross batch and position
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()

average l0 23.087106704711914


In [26]:
sae_out.shape

torch.Size([32, 128, 12, 64])

In [9]:
from transformer_lens import utils
from functools import partial

# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    return sae_out


def zero_abl_hook(activation, hook):
    return torch.zeros_like(activation)


print("Orig", model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae.cfg.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
    ).item(),
)

print(f"Score: {2}")

Orig 3.562199592590332
reconstr 4.422173500061035
Zero 3.6565253734588623
Score: 2


In [22]:
token_dataset

Dataset({
    features: ['tokens'],
    num_rows: 136625
})

In [30]:
from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

test_feature_idx_gpt = [2048 * i + j for i in range(12) for j in range(8)]

hook_name = sae.cfg.hook_name

feature_vis_config_gpt = SaeVisConfig(
    hook_point=hook_name,
    features=test_feature_idx_gpt,
    batch_size=None,
    minibatch_size_tokens=128,
    verbose=True,
)

sae_vis_data_gpt = SaeVisData.create(
    encoder=sae,
    model=model, # type: ignore
    tokens=token_dataset[:100000]["tokens"],  # type: ignore
    cfg=feature_vis_config_gpt,
)

Forward passes to cache data for vis:   0%|          | 0/782 [00:00<?, ?it/s]

Extracting vis data from cached data:   0%|          | 0/96 [00:00<?, ?it/s]

In [35]:
dir_name = "jb_features"
os.makedirs(dir_name, exist_ok=True)

for feature in test_feature_idx_gpt:
    filename = f"{dir_name}/{feature}_jb_{sae.cfg.hook_layer}_{sae.cfg.architecture}.html"
    sae_vis_data_gpt.save_feature_centric_vis(filename, feature)
    display_vis_inline(filename)
    break

Saving feature-centric vis:   0%|          | 0/96 [00:00<?, ?it/s]

In [33]:
sae.cfg.hook_name

'blocks.5.attn.hook_z'

In [34]:
sae.cfg.architecture

'standard'