In [1]:
from utils import profile_pytorch_memory
from prompt import logic_and_love

In [2]:
import torch as t
import gc
import plotly.express as px
import pandas as pd 

In [3]:
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

In [4]:
# Run configuration shared by the loading cells that follow.
# The device is hard-coded to CUDA; there is no CPU fallback in this cell.
device = t.device("cuda")
# Model identifier handed to HookedSAETransformer.from_pretrained.
model_name = "google/gemma-2-9b-it"
# `release` and `sae_id` arguments handed to SAE.from_pretrained.
sae_release = "gemma-scope-9b-it-res-canonical"
sae_id = "layer_9/width_16k/canonical"

In [5]:
# Switch autograd off globally; nothing visible in this notebook calls backward().
t.set_grad_enabled(False)
# Load the model named by `model_name` onto `device`.
model: HookedSAETransformer = HookedSAETransformer.from_pretrained(
    model_name, device=device
)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model google/gemma-2-9b-it into HookedTransformer


In [6]:
# Load the pretrained SAE selected by (sae_release, sae_id); the device is
# passed as a string rather than a torch.device.
# The call is unpacked into three values, but only `sae` is used by the cells
# visible in this notebook.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release=sae_release,
    sae_id=sae_id,
    device=str(device),
)

In [7]:
# Generate up to 256 new tokens from the stored prompt and print the result.
response = model.generate(logic_and_love["prompt"], max_new_tokens=256)
print(response)

  0%|          | 0/256 [00:00<?, ?it/s]


<start_of_turn>user
Argue why we should teach formal logic in kindergarten. Build your case in 4 bullet points from least to most compelling. Each bullet should be 1 sentence and cover different topic areas.<end_of_turn>
<start_of_turn>model
Here's an argument for teaching formal logic in kindergarten, building from least to most compelling:

* **Early exposure to logic can make learning other subjects easier.**  
* **Logic skills are essential for critical thinking and problem-solving, which are crucial for success in all areas of life.**
* **Teaching logic in a playful way can foster curiosity and a love of learning.**
<end_of_turn><eos>
* **Introducing basic logical concepts like inference and deduction can provide a strong foundation for future academic success and intellectual growth.**


Let me know if you'd like me to elaborate on any of these points! 
<end_of_turn><eos>


In [183]:
# Print every tokenizer token of the prompt next to its index so sequence
# positions can be looked up by eye.
# NOTE(review): these indices come from the tokenizer alone. If the model
# prepends a BOS token when it runs the prompt, cache positions would be offset
# by one from this listing -- confirm before relying on exact positions.
tokens = model.tokenizer.tokenize(logic_and_love["prompt"])
for index, token in enumerate(tokens):
    print(f"{index:<3} {token}")


0   

1   <start_of_turn>
2   user
3   

4   Ar
5   gue
6   ▁why
7   ▁we
8   ▁should
9   ▁teach
10  ▁formal
11  ▁logic
12  ▁in
13  ▁kindergarten
14  .
15  ▁Build
16  ▁your
17  ▁case
18  ▁in
19  ▁
20  4
21  ▁bullet
22  ▁points
23  ▁from
24  ▁least
25  ▁to
26  ▁most
27  ▁compelling
28  .
29  ▁Each
30  ▁bullet
31  ▁should
32  ▁be
33  ▁
34  1
35  ▁sentence
36  ▁and
37  ▁cover
38  ▁different
39  ▁topic
40  ▁areas
41  .
42  <end_of_turn>
43  

44  <start_of_turn>
45  model
46  

47  Here
48  '
49  s
50  ▁an
51  ▁argument
52  ▁for
53  ▁teaching
54  ▁formal
55  ▁logic
56  ▁in
57  ▁kindergarten
58  ,
59  ▁building
60  ▁from
61  ▁least
62  ▁to
63  ▁most
64  ▁compelling
65  :
66  


67  *
68  ▁**
69  Early
70  ▁exposure
71  ▁to
72  ▁logic
73  ▁can
74  ▁make
75  ▁learning
76  ▁other
77  ▁subjects
78  ▁easier
79  .**
80  ▁▁
81  

82  *
83  ▁**
84  Logic
85  ▁skills
86  ▁are
87  ▁essential
88  ▁for
89  ▁critical
90  ▁thinking
91  ▁and
92  ▁problem
93  -
94  solving
95  ,
96  ▁which
97  ▁are
98  ▁cru

In [8]:
# Run the prompt with the SAE attached and cache activations. The forward pass
# is cut off at `sae.cfg.hook_layer + 1`; the first return value is discarded.
_, cache = model.run_with_cache_with_saes(logic_and_love["prompt"], saes=[sae], stop_at_layer=sae.cfg.hook_layer + 1)

# Show only the cache entries whose name contains "hook_sae", with their shapes.
for name, param in cache.items():
    if "hook_sae" in name:
        print(f"{name:<43}: {tuple(param.shape)}")

blocks.9.hook_resid_post.hook_sae_input    : (1, 129, 3584)
blocks.9.hook_resid_post.hook_sae_acts_pre : (1, 129, 16384)
blocks.9.hook_resid_post.hook_sae_acts_post: (1, 129, 16384)
blocks.9.hook_resid_post.hook_sae_recons   : (1, 129, 3584)
blocks.9.hook_resid_post.hook_sae_output   : (1, 129, 3584)


In [23]:
# Take index 0 along the first axis of the SAE's `hook_sae_acts_post` cache entry.
# NOTE(review): this is presumably a view into the cached tensor, so in-place
# edits to `acts` would also change the cache -- confirm.
acts = cache[f"{sae.cfg.hook_name}.hook_sae_acts_post"][0]
print(acts.shape)

torch.Size([129, 16384])


In [165]:
def plot_df(df: pd.DataFrame, title="Feature Activations Across Token Positions"):
    """
    Plot every feature as its own line, with sequence position on the x axis.

    Five hard-coded x ranges are shaded behind the lines and labelled below the
    axis ("User Question", "Model Preamble", "Bullet 1", "Bullet 2",
    "Bullet 3"). The legend is shown only when there are fewer than 10 features.

    Args:
        df: pd.DataFrame with columns as feature names and rows as sequence positions
        title: figure title; the number of plotted features is appended to it
    Returns:
        None. The figure is displayed with fig.show().
    """
    palette = ["#F8AE54", "#F5921B", "#CA6C0F", "#9E4A06", "#732E00"]
    # (start_x, end_x) of each shaded band, in drawing order.
    regions = [(0, 42), (42, 68), (68, 83), (83, 109), (109, 128)]
    # x position -> text of the label drawn under the axis for each band.
    custom_labels = {20: "User Question", 52: "Model Preamble", 75: "Bullet 1", 96: "Bullet 2", 118: "Bullet 3"}

    region_count = 0

    def add_shaded_region(fig, start_x, end_x, color=None):
        """Shade [start_x, end_x]; cycles through the palette unless a color is given."""
        nonlocal region_count
        fill = color if color else palette[region_count % len(palette)]
        fig.add_vrect(
            x0=start_x,
            # Must be this call's own right edge. The previous code read an
            # unrelated global named `end`, so bands depended on whichever
            # notebook cell had run last.
            x1=end_x,
            fillcolor=fill,
            opacity=0.05,
            layer="below",
            line_width=0,
        )
        region_count += 1

    show_legend = len(df.columns) < 10
    fig = px.line(
        df,
        x=df.index,
        y=df.columns,
        labels={"index": "Sequence Position", "value": "SAE Latent Activation"},
        title=f"{title} (Count: {len(df.columns)})",
        width=1000,
    ).update_layout(showlegend=show_legend)

    for start_x, end_x in regions:
        add_shaded_region(fig, start_x, end_x)

    # Numeric ticks are hidden; the section labels below stand in for them.
    fig.update_xaxes(showticklabels=False)
    for x_val, text in custom_labels.items():
        fig.add_annotation(
            x=x_val, y=0,  # Position of annotation
            text=text,  # Custom label
            showarrow=False,  # No arrow
            yshift=-20,  # Shift below the x-axis
            font=dict(size=12)  # Font size
        )

    fig.update_xaxes(showgrid=False)  # Remove vertical grid lines
    fig.update_yaxes(showgrid=False)  # Remove horizontal grid lines

    fig.show()

In [None]:
# Wrap the un-normalized activations in a DataFrame: one row per sequence
# position and one "Feature i" column per SAE latent (sae.cfg.d_sae columns).
# The name `df` is reassigned to the normalized activations in a later cell.
df = pd.DataFrame(acts.cpu().numpy(), columns=[f"Feature {i}" for i in range(sae.cfg.d_sae)])
# plot_df(df)

In [153]:
# zero out first two sequence positions to ignore spuriously high activations
acts[:2] = 0
col_mean = acts.mean(dim=0, keepdim=True)
col_std = acts.std(dim=0, keepdim=True)
normalized_acts = (acts - col_mean) / col_std

In [154]:
# Rebuild `df` from the per-feature normalized activations: rows are sequence
# positions, columns are "Feature i" for each of the sae.cfg.d_sae latents.
df = pd.DataFrame(normalized_acts.cpu().numpy(), columns=[f"Feature {i}" for i in range(sae.cfg.d_sae)])

In [155]:
def keep_active_features_in_range(df: pd.DataFrame, start: int, end: int, threshold: float = 1):
    """
    Keep only the features that are active somewhere in rows [start, end).

    A feature is active when abs(activation) is strictly greater than
    `threshold` at one or more positions in that row window. NaN entries never
    count as active.
    Args:
        df: pandas DataFrame where rows are sequence positions and columns are features
        start: int, first row of the window (positional, inclusive)
        end: int, one past the last row of the window (positional, exclusive)
        threshold: float, abs(activation) > threshold is considered active
    Returns:
        pd.DataFrame with every row of `df` but only the active feature columns
    """
    window_magnitude = df.iloc[start:end].abs()
    exceeds_threshold = window_magnitude.gt(threshold)
    is_active = exceeds_threshold.any(axis=0)
    return df.loc[:, is_active]
    

In [156]:
def keep_inactive_features_in_range(df: pd.DataFrame, start: int, end: int, threshold: float = 1):
    """
    Keep only the features that stay inactive throughout rows [start, end).

    A feature is active when abs(activation) is strictly greater than
    `threshold` at one or more positions in that row window; every such feature
    is dropped. NaN entries never count as active, so all-NaN columns are kept.
    Args:
        df: pandas DataFrame where rows are sequence positions and columns are features
        start: int, first row of the window (positional, inclusive)
        end: int, one past the last row of the window (positional, exclusive)
        threshold: float, abs(activation) > threshold is considered active
    Returns:
        pd.DataFrame with every row of `df` but only the inactive feature columns
    """
    window_magnitude = df.iloc[start:end].abs()
    is_active = window_magnitude.gt(threshold).any(axis=0)
    # Invert the mask: the survivors are the columns that never crossed the threshold.
    is_inactive = ~is_active
    return df.loc[:, is_inactive]


### Threshold Bullet 3 above 7σ

In [177]:
# Step 1: from all features in `df`, keep those whose |activation| exceeds 7
# somewhere in the "bullet3" section rows, then plot the survivors.
# `start` and `end` are notebook globals and are rebound by the later filtering cells.
start,end = logic_and_love["sections"]["bullet3"]
last_bullet_features = keep_active_features_in_range(df, start, end, threshold=7)
print(last_bullet_features.shape)
plot_df(last_bullet_features, title="Features Activating Above 7σ in Bullet 3")

(129, 457)


### Threshold Bullets 1 and 2 below .5σ

In [178]:
# Step 2: narrow the current `last_bullet_features` to features that never
# exceed 0.5 in magnitude between the start of "bullet1" and the end of "bullet2".
# This rebinds `last_bullet_features`, so it depends on the previous filter having run.
start, _ = logic_and_love["sections"]["bullet1"]
_, end = logic_and_love["sections"]["bullet2"]
last_bullet_features = keep_inactive_features_in_range(last_bullet_features, start, end, threshold=0.5)
print(last_bullet_features.shape)
plot_df(last_bullet_features, title="Threshold Features Activating below .5σ in Bullet 1 and 2")


(129, 296)


### Threshold Preamble above 2σ

In [175]:
# Step 3: of the remaining features, keep those whose |activation| exceeds 2
# somewhere in the "preamble" section rows, then plot them.
# This rebinds `last_bullet_features` again, so the result depends on which
# earlier filtering cells have already been run.
start, end = logic_and_love["sections"]["preamble"]
last_bullet_features = keep_active_features_in_range(last_bullet_features, start, end, threshold=2)    
print(last_bullet_features.shape)
plot_df(last_bullet_features, title="Features Activating Above 2σ in Model Preamble")

(129, 30)


In [174]:
# Display the names of the features that survived the filters applied so far.
last_bullet_features.columns

Index(['Feature 145', 'Feature 3462', 'Feature 3995', 'Feature 4431',
       'Feature 5065', 'Feature 5214', 'Feature 5701', 'Feature 6610',
       'Feature 6655', 'Feature 6661', 'Feature 6773', 'Feature 7250',
       'Feature 8534', 'Feature 9142', 'Feature 9643', 'Feature 11393',
       'Feature 11591', 'Feature 11773', 'Feature 12136', 'Feature 12484',
       'Feature 12562', 'Feature 12956', 'Feature 13995', 'Feature 14234',
       'Feature 14371', 'Feature 14454', 'Feature 15638', 'Feature 15781',
       'Feature 16219', 'Feature 16343'],
      dtype='object')

In [28]:
# Snapshot the current names (a copy of globals merged with locals) and hand
# them to the project helper, which reports on them.
# NOTE(review): at the top level of a notebook cell, globals() and locals()
# are presumably the same mapping, which would make the merge redundant -- confirm.
namespace = globals().copy() | locals()
profile_pytorch_memory(namespace=namespace)

Allocated: 41.60 GB
Total:  79.25 GB
Free:  37.66 GB
┌─────────────────┬───────────────────────┬──────────┬─────────────┐
│ Name            │ Object                │ Device   │   Size (GB) │
├─────────────────┼───────────────────────┼──────────┼─────────────┤
│ model           │ HookedSAETransformer  │ cuda:0   │       37.85 │
│ sae             │ SAE                   │ cuda:0   │        0.44 │
│ sae_acts_post   │ Tensor (130, 16384)   │ cuda:0   │        0.01 │
│ acts            │ Tensor (130, 16384)   │ cuda:0   │        0.01 │
│ normalized_acts │ Tensor (130, 16384)   │ cuda:0   │        0.01 │
│ _               │ Tensor (1, 130, 3584) │ cuda:0   │        0.00 │
│ param           │ Tensor (1, 130, 3584) │ cuda:0   │        0.00 │
│ col_mean        │ Tensor (1, 16384)     │ cuda:0   │        0.00 │
│ col_std         │ Tensor (1, 16384)     │ cuda:0   │        0.00 │
└─────────────────┴───────────────────────┴──────────┴─────────────┘


In [22]:
# Drop the large intermediates so their memory can be reclaimed.
# Only NameError is caught: that is what `del` raises when the name was never
# defined or was already deleted. Any other failure should surface.
try:
    del cache
except NameError:
    print('no cache')
try:
    del response
except NameError:
    print('no response')
# `sae` and `model` are deliberately left loaded; delete them the same way if
# the memory they hold is needed.

# Print the number of unreachable objects the collector found, then ask PyTorch
# to release its unused cached GPU memory. empty_cache() returns None, so its
# result is not printed.
print(gc.collect())
t.cuda.empty_cache()

no response
344677
None
