In [None]:
import torch
import pandas as pd
# Import SAE 
from sae_lens import SAE, HookedSAETransformer

from sae_lens import TrainingSAE
## Loading SAE
from tqdm import tqdm
import os


In [None]:

# Load the Phi-3-mini instruct checkpoint as a HookedSAETransformer on cuda:1.
model = HookedSAETransformer.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    device="cuda:1",
)


In [None]:
# Load all SAEs at once: one final checkpoint per run folder under checkpoint_dir.
saes = []
checkpoint_dir = "/proj/mats_checkpoints"
# os.listdir returns entries in arbitrary order; sorting makes the order of
# `saes` (and of anything derived from it) the same on every run.
for run_folder in sorted(os.listdir(checkpoint_dir)):
    final_folder = os.path.join(checkpoint_dir, run_folder, "final_163840000")
    # Entries without a final_163840000 sub-path are skipped. This also covers
    # plain files in checkpoint_dir, for which the joined path cannot exist.
    if os.path.exists(final_folder):
        sae = TrainingSAE.load_from_pretrained(path=final_folder, device="cuda:1")
        saes.append(sae)

In [None]:
import pandas as pd

# Config attributes to tabulate; each becomes one column of the summary table.
cfg_fields = [
    'hook_name',
    'hook_layer',
    'model_name',
    'd_sae',
    'context_size',
    'd_in',
    'dataset_path',
]

# One record per loaded SAE, read straight from its config object.
sae_data = []
for sae in saes:
    sae_data.append({field: getattr(sae.cfg, field) for field in cfg_fields})

# Show the table ordered by the layer each SAE hooks into.
df = pd.DataFrame(sae_data).sort_values('hook_layer')
df

## Basics: Getting Features Using SAE

In [None]:
from transformer_lens.utils import test_prompt

# Word problem used as the running example; the pieces below concatenate to a
# single string.
prompt = (
    "An 80 m long cable is suspended from the top of two masts, "
    "both of which are 50 m above the ground. "
    "What is the distance between the two masts to one decimal place "
    "if the center of the cable is 10 m above the ground?"
)
answer = "0"

# Report how the model ranks `answer` as the continuation of `prompt`.
test_prompt(prompt=prompt, answer=answer, model=model)

In [None]:
# Let the model continue the prompt by 10 tokens; the cell displays the result.
model.generate(
    prompt,
    max_new_tokens=10,
)

In [None]:
# Show the use_error_term flag of every loaded SAE. A bare attribute expression
# inside a loop is evaluated and discarded, so the value has to be printed.
for sae in saes:
    print(f"{sae.cfg.hook_name}: use_error_term={sae.use_error_term}")

In [None]:
# Run the prompt with every SAE attached to the model; the returned cache then
# also holds the SAE hook activations.
_, cache = model.run_with_cache_with_saes(prompt, saes=saes)

# List the SAE-related cache entries together with their tensor shapes.
sae_cache_shapes = [
    (name, activation.shape)
    for name, activation in cache.items()
    if "sae" in name
]
print(sae_cache_shapes)



In [None]:
import plotly.subplots as sp
import plotly.graph_objs as go

# Plot the SAEs from the lowest to the highest hooked layer.
sorted_saes = sorted(saes, key=lambda x: x.cfg.hook_layer)
subplot_titles = [f"SAE {sae.cfg.hook_layer}" for sae in sorted_saes]

# One stacked subplot per SAE.
fig = sp.make_subplots(rows=len(sorted_saes), cols=1, subplot_titles=subplot_titles)

for row, sae in enumerate(sorted_saes, start=1):
    # Feature activations of this SAE at the last position of the first prompt.
    final_token_acts = cache[f'blocks.{sae.cfg.hook_layer}.hook_resid_post.hook_sae_acts_post'][0, -1, :]

    fig.add_trace(
        go.Scatter(
            y=final_token_acts.cpu().numpy(),
            mode='lines',
            name=f'SAE {sae.cfg.hook_layer}',
            hovertemplate='Feature: %{x}<br>Activation: %{y}',
        ),
        row=row,
        col=1,
    )

    # Label the axes of this subplot.
    fig.update_xaxes(title_text="Feature", row=row, col=1)
    fig.update_yaxes(title_text="Activation", row=row, col=1)

# Scale the figure height with the number of subplots and add a shared title.
fig.update_layout(
    height=300*len(sorted_saes),
    title_text=f"Feature activations at the final token position for each SAE (ordered by hook layer)<br>Using the prompt: '{prompt}'",
    showlegend=False
)

fig.show()

In [None]:
import numpy as np

# Layer-25 SAE feature activations at the final token, as a NumPy vector.
final_token_acts_l25 = cache['blocks.25.hook_resid_post.hook_sae_acts_post'][0, -1, :]
max_value = final_token_acts_l25.cpu().numpy()

# Position of the strongest feature, then its activation value.
max_index = max_value.argmax()
print(f"Index of maximal value: {max_index}")

print(max_value[max_index])

In [None]:

# Report the five most active layer-25 features at the final token and their
# activation values.
top_features = torch.topk(cache['blocks.25.hook_resid_post.hook_sae_acts_post'][0, -1, :], 5)
for val, ind in zip(top_features.values, top_features.indices):
    print(f"Feature {ind} fired {val:.2f}")


## The Contrastive Pairs Trick

In [None]:
from transformer_lens.utils import test_prompt

# Removing the word "Heavens" is very effective at making the model no longer
# predict "earth"; instead the model predicts a bunch of different animals.
# Can we work out which features fire differently and might explain this?
# (This is a toy example, not meant to be super interesting.)
prompt = "In the beginning, God created the cat and the"
answer = "earth"

test_prompt(prompt=prompt, answer=answer, model=model)

In [None]:
# plotly.express is needed here; importing it in this cell keeps the notebook
# runnable top to bottom (the only other import sits in a later cell).
import plotly.express as px

# Contrastive pair: the two prompts differ only in "heavens" vs "cat".
prompt = [
    "In the beginning, God created the heavens and the",
    "In the beginning, God created the cat and the",
]
column_names = ["heavens_and_the", "cat_and_the"]

# NOTE: `sae` is whichever SAE the preceding loops left bound. The cache key is
# derived from it so the lookup always matches the SAE that was attached.
sae_acts_key = f"{sae.cfg.hook_name}.hook_sae_acts_post"

# Each prompt is run on its own so that position -1 is that prompt's real final
# token; in one padded batch of unequal-length prompts it could be padding.
final_token_acts = {}
for column_name, single_prompt in zip(column_names, prompt):
    _, cache = model.run_with_cache_with_saes(single_prompt, saes=[sae])
    print([(k, v.shape) for k,v in cache.items() if "sae" in k])
    final_token_acts[column_name] = cache[sae_acts_key][0, -1, :].cpu().numpy()

# One row per SAE feature, one column per prompt, plus their difference.
feature_activation_df = pd.DataFrame(
    final_token_acts,
    index=[f"feature_{i}" for i in range(sae.cfg.d_sae)],
)
feature_activation_df["diff"] = feature_activation_df["heavens_and_the"] - feature_activation_df["cat_and_the"]

fig = px.line(
    feature_activation_df,
    title="Feature activations for the prompt",
    labels={"index": "Feature", "value": "Activation"},
)

# hide the x-ticks
fig.update_xaxes(showticklabels=False)
fig.show()

## Components for Feature Dashboard

Components:
- Feature Activation Distribution.
- Logit weight distribution.
- Top 10 and bottom 10 tokens by logit weight for each feature
- Max Activating Examples

#### Max Activation Examples


In [None]:
# instantiate an object to hold activations from a dataset
from sae_lens import ActivationsStore

In [None]:
# Settings for the store that streams dataset tokens for the selected SAE.
# NOTE(review): the store is placed on cuda:0 while the model and the SAEs were
# loaded on cuda:1 - confirm the split across devices is intentional.
activation_store_kwargs = dict(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=8,
    train_batch_size_tokens=2048,
    n_batches_in_buffer=8,
    device="cuda:0",
)
activation_store = ActivationsStore.from_sae(**activation_store_kwargs)

In [None]:
def list_flatten(nested_list):
    """Concatenate the items of every inner iterable into one flat list (one level deep)."""
    flattened = []
    for inner in nested_list:
        flattened.extend(inner)
    return flattened

# A very handy function Neel wrote to get context around a feature activation
def make_token_df(tokens, len_prefix=5, len_suffix=3, model = model):
    """Build a per-token DataFrame with a short text context around every position.

    Args:
        tokens: 2-D token tensor indexed as [prompt, position]; each row is
            passed to ``model.to_str_tokens``.
        len_prefix: maximum number of preceding tokens shown in ``context``.
        len_suffix: maximum number of following tokens shown in ``context``.
        model: object providing ``to_str_tokens``. The default is bound to the
            notebook-level ``model`` at the moment this cell is executed.

    Returns:
        DataFrame with one row per (prompt, position), prompts in order, and
        the columns ``str_tokens``, ``unique_token`` ("<token>/<position>"),
        ``context`` ("<prefix>|<token>|<suffix>"), ``prompt``, ``pos`` and
        ``label`` ("<prompt>/<position>").
    """
    str_tokens = [model.to_str_tokens(t) for t in tokens]
    unique_token = [[f"{s}/{i}" for i, s in enumerate(str_tok)] for str_tok in str_tokens]

    context = []
    prompt = []
    pos = []
    label = []
    for b in range(tokens.shape[0]):
        for p in range(tokens.shape[1]):
            prefix = "".join(str_tokens[b][max(0, p-len_prefix):p])
            # Slice ends are exclusive and clamp to the list length, so this
            # takes up to len_suffix following tokens, including the final
            # token of the prompt, and is "" at the last position. The previous
            # upper bound of shape[1]-1 always dropped the final token.
            suffix = "".join(str_tokens[b][p+1:p+1+len_suffix])
            current = str_tokens[b][p]
            context.append(f"{prefix}|{current}|{suffix}")
            prompt.append(b)
            pos.append(p)
            label.append(f"{b}/{p}")
    return pd.DataFrame(dict(
        str_tokens=list_flatten(str_tokens),
        unique_token=list_flatten(unique_token),
        context=context,
        prompt=prompt,
        pos=pos,
        label=label,
    ))

In [None]:
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

# 100 randomly drawn feature indices, kept on the CPU for later cells.
feature_list = torch.randint(0, sae.cfg.d_sae, (100,)).cpu()  # Move to CPU
# Copy of the indices on the SAE's device, made once instead of once per use
# inside the loop.
feature_list_on_device = feature_list.to(sae.W_dec.device)
examples_found = 0

total_batches = 100
batch_size_prompts = activation_store.store_batch_size_prompts
batch_size_tokens = activation_store.context_size * batch_size_prompts
pbar = tqdm(range(total_batches))

# Per-chunk results, concatenated after the loop.
all_token_dfs = []
all_feature_acts = []
all_fired_tokens = []
all_reconstructions = []

for i in pbar:
    with torch.no_grad():  # Disable gradient computation
        tokens = activation_store.get_batch_tokens()
        tokens_df = make_token_df(tokens)
        tokens_df["batch"] = i

        flat_tokens = tokens.flatten()

        # Process in smaller chunks
        # NOTE(review): chunks are slices of the flattened batch and each one is
        # fed to the model as a single sequence, so a chunk can span several
        # prompts when context_size != chunk_size - confirm this is intended.
        chunk_size = 1024  # Adjust this value based on your GPU memory
        for j in range(0, len(flat_tokens), chunk_size):
            chunk = flat_tokens[j:j+chunk_size]

            # Only run the model up to the hooked layer and cache just the SAE input.
            _, cache = model.run_with_cache(chunk.unsqueeze(0), stop_at_layer=sae.cfg.hook_layer + 1, names_filter=[sae.cfg.hook_name])
            sae_in = cache[sae.cfg.hook_name]
            # squeeze(0) removes only the batch dimension added by unsqueeze(0);
            # a bare squeeze() would also drop the token axis of a 1-token chunk.
            feature_acts = sae.encode(sae_in).squeeze(0)

            # Activations of the sampled features, and the tokens where their sum is positive.
            selected_acts = feature_acts[:, feature_list_on_device]
            fired_mask = selected_acts.sum(dim=-1) > 0
            fired_mask_cpu = fired_mask.cpu()
            fired_acts = selected_acts[fired_mask]

            fired_tokens = model.to_str_tokens(chunk[fired_mask_cpu])
            # Decoder-space contribution of the sampled features only.
            reconstruction = (fired_acts @ sae.W_dec[feature_list_on_device]).cpu()

            # Move everything to CPU immediately
            token_df = tokens_df.iloc[j:j+chunk_size][fired_mask_cpu.numpy()]
            all_token_dfs.append(token_df)
            all_feature_acts.append(fired_acts.cpu().numpy())
            all_fired_tokens.extend(fired_tokens)
            all_reconstructions.append(reconstruction.numpy())

            examples_found += len(fired_tokens)

            # Clear cache after each chunk
            torch.cuda.empty_cache()

    pbar.set_description(f"Examples found: {examples_found}")

# Process results: token rows and activation rows are appended in the same
# order, so after concatenation row k of one corresponds to row k of the other.
all_token_dfs = pd.concat(all_token_dfs, ignore_index=True)
all_feature_acts = np.concatenate(all_feature_acts)
all_reconstructions = np.concatenate(all_reconstructions)

# Convert back to torch tensors if needed (on CPU)
all_feature_acts = torch.from_numpy(all_feature_acts)
all_reconstructions = torch.from_numpy(all_reconstructions)

In [None]:
# Shape of the stacked activations: one row per token kept by fired_mask in the
# collection loop, one column per entry of feature_list.
all_feature_acts.shape

### Getting Feature Activation Histogram
Next, we can generate the feature activation histogram (just as we saw on the dashboards above) and display the list of max-activating examples we just generated. We'll do this for a single feature from our random set; the histogram cell below selects it by its position in `feature_list` via `feature_idx`.

In [None]:
import plotly.express as px

In [None]:
# Name every column after the feature index it holds, then wrap the collected
# activations in a DataFrame and show its shape.
feature_columns = [f"feature_{i}" for i in feature_list]
feature_acts_df = pd.DataFrame(
    all_feature_acts.detach().cpu().numpy(),
    columns=feature_columns,
)
feature_acts_df.shape

In [None]:
# Per-column summary statistics of the collected feature activations.
feature_acts_df.describe()

In [None]:
# Column position (within feature_list) of the feature to inspect.
feature_idx = 20

# Keep only the strictly positive activations of that column.
selected_feature_acts = all_feature_acts[:, feature_idx]
all_positive_acts = selected_feature_acts[selected_feature_acts > 0].detach()

# Positive activations as a percentage of every token processed by the
# collection loop.
total_tokens_processed = total_batches*batch_size_tokens
prop_positive_activations = 100*len(all_positive_acts) / total_tokens_processed

px.histogram(
    all_positive_acts.cpu(),
    nbins=50,
    title=f"Histogram of positive activations - {prop_positive_activations:.3f}% of activations were positive",
    labels={"value": "Activation"},
    width=800,
)

In [None]:
# Max-activating examples for the same feature as the histogram (feature_idx),
# rather than always the first sampled feature.
top_10_activations = feature_acts_df.sort_values(f"feature_{feature_list[feature_idx]}", ascending=False).head(10)
# feature_acts_df has a default 0..N-1 index and all_token_dfs was concatenated
# with ignore_index=True from chunks appended in the same order with the same
# mask, so these index labels are valid positional row numbers for iloc.
all_token_dfs.iloc[top_10_activations.index]

### Getting the Top 10 Logit Weights

As a final step, we'll generate the top 10 logit weights--that is, we'll see what tokens each of the features in our set is promoting most strongly.

Note that it's important to fold layer norm. By default, SAE Lens loads Transformers with folded layer norm, but sometimes we turn this preprocessing off to save GPU RAM, which would affect the logit weight histograms a little.

In [None]:
# Project every SAE decoder direction through the unembedding to get, for each
# feature, one weight per vocabulary token.
print(f"Shape of the decoder weights {sae.W_dec.shape}")
print(f"Shape of the model unembed {model.W_U.shape}")
projection_matrix = sae.W_dec @ model.W_U
print(f"Shape of the projection matrix {projection_matrix.shape}")

In [None]:

# then we take the top_k tokens per feature and decode them
top_k = 10
# let's do this for 100 random features
_, top_k_tokens = torch.topk(projection_matrix[feature_list], top_k, dim=1)


# One column per sampled feature, one row per rank.
feature_df = pd.DataFrame(top_k_tokens.cpu().numpy(), index = [f"feature_{i}" for i in feature_list]).T
feature_df.index = [f"token_{i}" for i in range(top_k)]
# Decode every token id. Column-wise Series.map is used because
# DataFrame.applymap is deprecated in recent pandas releases.
feature_df.apply(lambda column: column.map(lambda x: model.tokenizer.decode(x)))

In [None]:
# Get the bottom_k tokens per feature and decode them
bottom_k = 10
# We'll use the same 100 random features
_, bottom_k_tokens = torch.topk(projection_matrix[feature_list], bottom_k, dim=1, largest=False)

# One column per sampled feature, one row per rank (lowest weight first).
bottom_feature_df = pd.DataFrame(bottom_k_tokens.cpu().numpy(), index=[f"feature_{i}" for i in feature_list]).T
bottom_feature_df.index = [f"token_{i}" for i in range(bottom_k)]
# Decode every token id. Column-wise Series.map is used because
# DataFrame.applymap is deprecated in recent pandas releases.
bottom_feature_df = bottom_feature_df.apply(lambda column: column.map(lambda x: model.tokenizer.decode(x)))

# Display the results
print("Bottom 10 logits (most inhibited tokens) for each feature:")
display(bottom_feature_df)