In [None]:
import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias


import numpy as np
import pandas as pd
import torch as t
from datasets import load_dataset
import sae_lens
import transformer_lens
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

import einops
import circuitsvis as cv
import plotly.express as px
import plotly.graph_objects as go
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from tabulate import tabulate

from tqdm import tqdm
# Pick the compute device: "cuda:1" when CUDA is available, otherwise the CPU.
# NOTE(review): the GPU index 1 is hard-coded; presumably this fails on a single-GPU machine — confirm.
device = t.device("cuda:1" if t.cuda.is_available() else "cpu")
print(device)

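# NOTE: credentials must not be stored in the notebook. Read the Hugging Face token from the
# HF_TOKEN environment variable and the Weights & Biases key from WANDB_API_KEY instead,
# and revoke any key that was previously pasted here.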

## GPT2

#### Attention SAE

In [None]:
# Load one pretrained SAE per layer (0-11) from the "gpt2-small-hook-z-kk" release, keyed by
# layer index. Only element [0] of what `from_pretrained` returns is kept (presumably the SAE
# object itself — confirm against the sae_lens version in use).
attn_saes = {
    layer: sae_lens.SAE.from_pretrained(
        "gpt2-small-hook-z-kk",
        f"blocks.{layer}.hook_z",
        device=device,
    )[0]
    for layer in range(12)
}

In [None]:
# Print the config of the first SAE only: the loop breaks after a single iteration.
for attn_sae in attn_saes.values():
    print(attn_sae.cfg)
    break

In [None]:
# Dump the whole layer -> SAE mapping.
print(attn_saes)

In [None]:
def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_pre",
    latent_idx=0,
    width=800,
    height=600,
):
    """
    Embed the Neuronpedia dashboard of a single SAE latent in the cell output.

    The Neuronpedia id registered for `sae_id` under `sae_release` is looked up in the pretrained-SAE
    directory; the resulting embed URL is printed and then shown in an IFrame of size `width` x `height`.
    """
    directory = sae_lens.toolkit.pretrained_saes_directory.get_pretrained_saes_directory()
    neuronpedia_id = directory[sae_release].neuronpedia_id[sae_id]

    url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
    print(url)

    frame = IFrame(url, width=width, height=height)
    display(frame)


# Show the dashboard for latent 2 of the layer-9 attention SAE.
layer = 9

display_dashboard(
    sae_release="gpt2-small-hook-z-kk",
    sae_id=f"blocks.{layer}.hook_z",
    latent_idx=2,  # or you can try `random.randint(0, attn_saes[layer].cfg.d_sae)`
)

In [None]:
# Inspect the model configuration.
# NOTE(review): `gpt2` is first assigned in a later cell of this notebook, so this cell fails on a
# fresh kernel when run top to bottom — move it below the model load.
gpt2.cfg

#### Magnitude pruning

In [None]:
def prune_magnitude(W, sparse_ratio=0.5):
    """
    Zero out the smallest-magnitude entries of `W`, ranking across the whole tensor.

    Args:
        W: weight tensor to prune; it is not modified.
        sparse_ratio: fraction of entries (by count, rounded down) that are KEPT, namely the ones
            with the largest absolute value. With the default 0.5, half of the entries survive.

    Returns:
        A new tensor shaped like `W` in which every entry outside the kept set is multiplied by zero.
    """
    magnitudes = W.abs()
    n_keep = int(magnitudes.numel() * sparse_ratio)
    keep_positions = magnitudes.view(-1).topk(n_keep).indices

    # Build the 0/1 mask by writing ones through a flat view of it.
    keep_mask = t.zeros_like(magnitudes)
    keep_mask.view(-1)[keep_positions] = 1
    return keep_mask * W

def prune_model(model, *, sparse_ratio=0.5):
    """
    Magnitude-prune the attention and MLP weight matrices of `model` in place.

    Every parameter whose name ends in W_Q, W_K, W_V, W_O, W_in or W_out is passed through
    `prune_magnitude`. 3-D parameters are pruned one slice along dim 0 at a time (presumably one
    slice per attention head — confirm); all other matching parameters are pruned as a whole.

    Args:
        model: module exposing `named_parameters()`; its weights are overwritten.
        sparse_ratio: forwarded to `prune_magnitude` (defaults to 0.5, the value used before).

    Returns:
        The same `model` object, after pruning.
    """
    prunable_suffixes = {'W_Q', 'W_K', 'W_V', 'W_O', 'W_in', 'W_out'}

    # Writing into leaf parameters that require grad is only legal with autograd disabled.
    with t.no_grad():
        # Iterate over the model that was passed in, not over a notebook-level global.
        for name, param in model.named_parameters():
            if name.split('.')[-1] not in prunable_suffixes:
                continue
            if param.dim() == 3:
                for i in range(param.shape[0]):
                    param[i].copy_(prune_magnitude(param[i], sparse_ratio))
            else:
                # copy_ writes into the parameter itself; a plain assignment would only rebind the
                # local name and leave the model's weights untouched.
                param.copy_(prune_magnitude(param, sparse_ratio))

    return model

# Load GPT-2 small and prune it; `prune_model` returns the object it was given, so `pruned_gpt2`
# is that first model. A fresh unpruned copy is then loaded under the name `gpt2` for comparison.
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
pruned_gpt2 = prune_model(gpt2)
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)


In [None]:
# Baseline: how does the unpruned model rank " am" after "I"?
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, gpt2)

In [None]:
# Same prompt on the magnitude-pruned model.
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)

In [None]:
# Run the library's sanity check on the unpruned model.
transformer_lens.evals.sanity_check(gpt2)

In [None]:
# Run the library's sanity check on the pruned model.
transformer_lens.evals.sanity_check(pruned_gpt2)

In [None]:
# Indirect-object style prompt on the unpruned model (BOS token prepended).
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
transformer_lens.utils.test_prompt(example_prompt, example_answer, gpt2, prepend_bos=True)

In [None]:
# Same prompt on the magnitude-pruned model (BOS token prepended).
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
transformer_lens.utils.test_prompt(example_prompt, example_answer, pruned_gpt2, prepend_bos=True)

#### Wanda

In [None]:
def prune_wanda(W, X_norm, sparse_ratio=0.5):
    """
    Return a copy of `W` with the lowest-scoring entries of every row set to zero.

    Each entry is scored as `|W| * X_norm` (with `X_norm` broadcast against `W`). Within each row,
    the `int(W.shape[1] * sparse_ratio)` entries with the smallest score are zeroed. `W` itself is
    left unchanged.
    """
    importance = W.abs() * X_norm
    n_drop = int(W.shape[1] * sparse_ratio)

    # An ascending sort along each row puts the least important columns first.
    ranking = importance.sort(dim=1).indices
    drop_cols = ranking[:, :n_drop]

    pruned = W.detach().clone()
    pruned.scatter_(dim=1, index=drop_cols, src=t.zeros_like(drop_cols, dtype=W.dtype))
    return pruned

# Toy check: a 3x4 weight matrix and one row of per-column activation norms. With
# sparse_ratio=0.5 the two lowest-scoring entries of each row are zeroed.
W = t.tensor([
    [4, 0, 1, -1],
    [3, -2, -1, -3],
    [-3, 1, 0, 2]
])
X = t.tensor([
    [1, 2, 8, 3]
])

prune_wanda(W, X, sparse_ratio=0.5)

In [None]:
def prune_wanda(W, X_norm, sparse_ratio=0.5):
    """
    Return a copy of `W` with the lowest-scoring entries of every row set to zero.

    Each entry is scored as `|W| * X_norm` (with `X_norm` broadcast against `W`). Within each row,
    the `int(W.shape[1] * sparse_ratio)` entries with the smallest score are zeroed. `W` itself is
    left unchanged.
    """
    importance = W.abs() * X_norm
    n_drop = int(W.shape[1] * sparse_ratio)

    # An ascending sort along each row puts the least important columns first.
    ranking = importance.sort(dim=1).indices
    drop_cols = ranking[:, :n_drop]

    pruned = W.detach().clone()
    pruned.scatter_(dim=1, index=drop_cols, src=t.zeros_like(drop_cols, dtype=W.dtype))
    return pruned

def prune_model(model, tokens, sparse_ratio=0.5):
    """
    Prune the attention and MLP weights of every block of `model` in place with `prune_wanda`.

    For each layer the model is re-run on `tokens`, so the cached activations reflect pruning already
    applied to earlier layers. Each weight is then scored against the L2 norm (over the first
    activation axis) of the activation cached at the hook paired with it in `weight_to_hook`.

    Args:
        model: transformer exposing `cfg.n_layers`, `run_with_cache` and `get_parameter`; its
            weights are overwritten.
        tokens: calibration tokens. The cache is built with `remove_batch_dim=True`, which
            presumably requires a single prompt — confirm.
        sparse_ratio: fraction of each weight row zeroed by `prune_wanda` (default 0.5, the value
            that was previously hard-coded).

    Returns:
        The same `model` object, after pruning.
    """
    weight_to_hook = {
        'attn.W_Q': 'attn.hook_q',
        'attn.W_K': 'attn.hook_k',
        'attn.W_V': 'attn.hook_v',
        'attn.W_O': 'hook_attn_out',
        'mlp.W_in': 'mlp.hook_pre',
        'mlp.W_out': 'hook_mlp_out'
    }

    # Writing into leaf parameters that require grad is only legal with autograd disabled.
    with t.no_grad():
        for layer in range(model.cfg.n_layers):
            _, cache = model.run_with_cache(tokens, remove_batch_dim=True)
            for wt, act in weight_to_hook.items():
                W = model.get_parameter(f'blocks.{layer}.{wt}')
                X = cache[f'blocks.{layer}.{act}']

                if W.dim() == 3:
                    if 'W_O' in wt:
                        # The cached activation here has no per-head axis: one norm serves all heads.
                        X_norm = X.norm(p=2, dim=0)
                        for head in range(W.shape[0]):
                            W[head].copy_(prune_wanda(W[head], X_norm, sparse_ratio=sparse_ratio))
                    else:
                        for head in range(W.shape[0]):
                            X_norm = X[:, head, :].norm(p=2, dim=0)
                            W[head].copy_(prune_wanda(W[head], X_norm, sparse_ratio=sparse_ratio))
                else:
                    X_norm = X.norm(p=2, dim=0)
                    # copy_ writes into the parameter; a plain `W = ...` would only rebind the local
                    # name and leave the MLP weights unpruned.
                    W.copy_(prune_wanda(W, X_norm, sparse_ratio=sparse_ratio))

    return model

# Load GPT-2 small, prune it using a single short calibration sentence, then reload a fresh
# unpruned copy under the name `gpt2`. `prune_model` returns the object it was given, so
# `pruned_gpt2` is the first (pruned) model.
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
gpt2_text = "A quick brown fox jumps over the lazy dog."
gpt2_tokens = gpt2.to_tokens(gpt2_text)
print(gpt2_tokens.shape)
pruned_gpt2 = prune_model(gpt2, gpt2_tokens)
gpt2: sae_lens.HookedSAETransformer = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)

In [None]:
# Short prompt on the Wanda-pruned model.
prompt = "I"
answer = " am"
transformer_lens.utils.test_prompt(prompt, answer, pruned_gpt2)

In [None]:
# Indirect-object style prompt on the unpruned model (BOS token prepended).
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
transformer_lens.utils.test_prompt(example_prompt, example_answer, gpt2, prepend_bos=True)

In [None]:
# Same prompt on the Wanda-pruned model (BOS token prepended).
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
transformer_lens.utils.test_prompt(example_prompt, example_answer, pruned_gpt2, prepend_bos=True)

#### SAEs

In [None]:
# Disable autograd globally, load GPT-2 small, overwrite its weights with a saved state dict
# (the file name suggests the Wanda-pruned checkpoint), and re-check the example prompt.
# NOTE(review): the checkpoint path is absolute and user-specific, while a later cell loads
# 'pruned/pruned_gpt2_wanda.pth' relatively — presumably the same file; confirm and unify.
t.set_grad_enabled(False)
gpt2 = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
gpt2.load_state_dict(t.load('/home/gupte.31/COLM/sae-compression/gpt2-small/pruned/pruned_gpt2_wanda.pth'))
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
transformer_lens.utils.test_prompt(example_prompt, example_answer, gpt2, prepend_bos=True)

In [None]:
# Tabulate every field of the SAE config.
# NOTE(review): `gpt2_sae` is only assigned in a later cell of this notebook, so this cell fails
# on a fresh kernel when run top to bottom.
print(tabulate(gpt2_sae.cfg.__dict__.items(), headers=["name", "value"], tablefmt="simple_outline"))

In [None]:
# Compare the model's next-token prediction with and without the SAE spliced in.
prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"

# First see how the model does without SAEs
test_prompt(prompt, answer, gpt2)

# Test our prompt, to see what the model says
with gpt2.saes(saes=[gpt2_sae]):
    test_prompt(prompt, answer, gpt2)

# # Same thing, done in a different way
# model.add_sae(attn_sae)
# test_prompt(prompt, answer, model)
# model.reset_saes()  # Remember to always do this!

# Using `run_with_saes` method in place of standard forward pass
logits = gpt2(prompt, return_type="logits")
logits_sae = gpt2.run_with_saes(prompt, saes=[gpt2_sae], return_type="logits")
# NOTE(review): `answer_token_id` is computed but not used in this cell.
answer_token_id = gpt2.to_single_token(answer)

# Getting model's prediction: highest-probability token at the last position of the first sequence
top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)

print(f"""Standard model: top prediction = {gpt2.to_string(token_id_prediction)!r}, prob = {top_prob.item():.2%}
SAE reconstruction: top prediction = {gpt2.to_string(token_id_prediction_sae)!r}, prob = {top_prob_sae.item():.2%}
""")

Difference: in the demo, ` priority` is the top token predicted from the SAE reconstruction. Does this mean we need more training epochs? Does it mean pruning has affected the SAE? Or is there another possible reason?

**`prepend_bos=False`?** Checked the default value of the parameter - that's not the reason.
The architecture is also not the reason - standard and gated SAEs give the same output.
The standard SAE assigns a lower probability to ` priority` than the gated SAE does.

In [None]:
# Same with/without-SAE comparison as the previous cell, on the indirect-object style prompt.
prompt = "After John and Mary went to the store, John gave a bottle of milk to"
answer = " Mary"

# First see how the model does without SAEs
test_prompt(prompt, answer, gpt2)

# Test our prompt, to see what the model says
with gpt2.saes(saes=[gpt2_sae]):
    test_prompt(prompt, answer, gpt2)

# Using `run_with_saes` method in place of standard forward pass
logits = gpt2(prompt, return_type="logits")
logits_sae = gpt2.run_with_saes(prompt, saes=[gpt2_sae], return_type="logits")
# NOTE(review): `answer_token_id` is computed but not used in this cell.
answer_token_id = gpt2.to_single_token(answer)

# Getting model's prediction: highest-probability token at the last position of the first sequence
top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)

print(f"""Standard model: top prediction = {gpt2.to_string(token_id_prediction)!r}, prob = {top_prob.item():.2%}
SAE reconstruction: top prediction = {gpt2.to_string(token_id_prediction_sae)!r}, prob = {top_prob_sae.item():.2%}
""")

Higher probability in this case!

#### Running SAES and replicating dashboards

Replacing the activations with SAE reconstructions

In [None]:
# Run the same prompt three ways: without SAEs, with the SAE reconstruction substituted
# (use_error_term=False), and with the SAE attached but its error term enabled (use_error_term=True).
# `prompt` is whatever the most recently run cell assigned to it.
logits_no_saes, cache_no_saes = gpt2.run_with_cache(prompt)

gpt2_sae.use_error_term = False
logits_with_sae_recon, cache_with_sae_recon = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])

# NOTE: this leaves `use_error_term = True` set on the SAE for every later cell.
gpt2_sae.use_error_term = True
logits_without_sae_recon, cache_without_sae_recon = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])

# Both SAE caches contain the hook values
assert f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post" in cache_with_sae_recon
assert f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post" in cache_without_sae_recon

# But the final output will be different, because we don't use SAE reconstructions when use_error_term=True
t.testing.assert_close(logits_no_saes, logits_without_sae_recon)
logit_diff_from_sae = (logits_no_saes - logits_with_sae_recon).abs().mean()
print(f"Average logit diff from using SAE reconstruction: {logit_diff_from_sae:.4f}")

Using ActivationsStore to load a bunch of data. It streams in data from a given dataset that was used to train the SAE

In [None]:
# Show which dataset the SAE config points at, then build a streaming activations store from it.
print(gpt2_sae.cfg.dataset_path)

gpt2_act_store = ActivationsStore.from_sae(
    model=gpt2,
    sae=gpt2_sae,
    streaming=True,
    store_batch_size_prompts=16,
    n_batches_in_buffer=32,
    device=str(device),
)

# Example of how you can use this: one batch of tokens has shape (prompts per batch, context size)
tokens = gpt2_act_store.get_batch_tokens()
assert tokens.shape == (gpt2_act_store.store_batch_size_prompts, gpt2_act_store.context_size)

#### Irrelevant to attention layer understanding

##### **Activation Distribution**: The distribution of latent's activations

In [None]:
def show_activation_histogram(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 200,
):
    """
    Plot a histogram of the strictly positive activations of one SAE latent.

    Pulls `total_batches` token batches from `act_store`, runs `model` with `sae` attached (stopping right after the
    SAE's layer) and collects every activation of latent `latent_idx` that is greater than zero. The plot title
    reports the fraction of all visited token positions at which the latent was active.
    """
    hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    positive_acts = []

    for _ in tqdm(range(total_batches), desc="Computing activations for histogram"):
        batch_tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            batch_tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[hook_name],
        )
        latent_acts = cache[hook_name][..., latent_idx]
        positive_acts += latent_acts[latent_acts > 0].cpu().tolist()

    # Density = active positions / all positions seen (batches x prompts per batch x context size).
    n_positions = total_batches * act_store.store_batch_size_prompts * act_store.context_size
    frac_active = len(positive_acts) / n_positions

    fig = px.histogram(
        positive_acts,
        nbins=50,
        title=f"ACTIVATIONS DENSITY {frac_active:.3%}",
        labels={"value": "Activation"},
        width=800,
        template="ggplot2",
        color_discrete_sequence=["darkorange"],
    )
    fig.update_layout(bargap=0.02, showlegend=False)
    fig.show()
    
# Histogram for latent 9, using the default of 200 batches.
show_activation_histogram(gpt2, gpt2_sae, gpt2_act_store, latent_idx=9)

##### **Top/Bottom Logits**: Most +ve or -ve logits in the logit weight distribution

In [None]:
def show_top_logits(
    model: HookedSAETransformer,
    sae: SAE,
    latent_idx: int,
    k: int = 10,
) -> None:
    """
    Print a table of the `k` most negative and `k` most positive logit weights of one latent.

    The logit weights are the latent's decoder row multiplied by the model's unembedding matrix.
    """
    logit_weights = sae.W_dec[latent_idx] @ model.W_U

    top = logit_weights.topk(k)
    bottom = logit_weights.topk(k, largest=False)
    top_tokens = model.to_str_tokens(top.indices)
    bottom_tokens = model.to_str_tokens(bottom.indices)

    rows = zip(map(repr, bottom_tokens), bottom.values, map(repr, top_tokens), top.values)
    table = tabulate(
        rows,
        headers=["Bottom tokens", "Value", "Top tokens", "Value"],
        tablefmt="simple_outline",
        stralign="right",
        numalign="left",
        floatfmt="+.3f",
    )
    print(table)


# Top and bottom logit weights for latent 4.
show_top_logits(gpt2, gpt2_sae, latent_idx=4)

##### **Max Activating Examples**

In [None]:
def get_k_largest_indices(x: Float[Tensor, "batch seq"], k: int, buffer: int = 0) -> Int[Tensor, "k 2"]:
    """
    Locate the `k` largest entries of `x` and return their (batch, seqpos) coordinates, largest first.

    When `buffer` is positive, the first and last `buffer` positions of every sequence are excluded from the
    search; the returned seqpos values still refer to positions in the untrimmed `x`.
    """
    searchable = x[:, buffer:-buffer] if buffer > 0 else x
    width = searchable.size(1)

    flat_top = searchable.flatten().topk(k=k).indices
    batch_idx = flat_top // width
    # Shift back by `buffer` to undo the trimming of the left edge.
    seq_idx = flat_top % width + buffer
    return t.stack((batch_idx, seq_idx), dim=1)


# Sanity check on a 2x20 grid with three boosted entries.
x = t.arange(40, device=device).reshape((2, 20))
x[0, 10] += 50  # 2nd highest value
x[0, 11] += 100  # highest value
x[1, 1] += 150  # excluded: lies in the buffer zone (less than 3 from the start of the sequence)
top_indices = get_k_largest_indices(x, k=2, buffer=3)
assert top_indices.tolist() == [[0, 11], [0, 10]]


def index_with_buffer(
    x: Float[Tensor, "batch seq"], indices: Int[Tensor, "k 2"], buffer: int | None = None
) -> Float[Tensor, "k *buffer_x2_plus1"]:
    """
    Indexes into `x` with `indices` (rows of (batch, seqpos) pairs, e.g. from `get_k_largest_indices`), and takes a
    +-buffer range around each indexed element. If an index is less than `buffer` away from the start of a sequence
    then we just take the first `2*buffer+1` elems (same for at the end of a sequence).

    If `buffer` is None, then we don't add any buffer and just return the elements at the given indices.

    `indices` is never modified: the window centres are clamped on a copy rather than in place.
    """
    rows, cols = indices.unbind(dim=-1)
    if buffer is not None:
        # `cols` is a view of the caller's tensor, so clamp out-of-place to keep `indices` intact.
        centers = cols.clamp(min=buffer, max=x.size(1) - buffer - 1)
        offsets = t.arange(-buffer, buffer + 1, device=cols.device)
        # [k, 1] rows broadcast against [k, 2*buffer+1] cols when indexing.
        rows = rows.unsqueeze(-1)
        cols = centers.unsqueeze(-1) + offsets
    return x[rows, cols]


# Sanity check: each row is the 7-element window centred on one of the two top entries found above.
x_top_values_with_context = index_with_buffer(x, top_indices, buffer=3)
assert x_top_values_with_context[0].tolist() == [8, 9, 10 + 50, 11 + 100, 12, 13, 14]  # highest value in the middle
assert x_top_values_with_context[1].tolist() == [7, 8, 9, 10 + 50, 11 + 100, 12, 13]  # 2nd highest value in the middle


def display_top_seqs(data: list[tuple[float, list[str], int]]):
    """
    Render a rich table with one row per (activation, str_toks, seq_pos) triple.

    The token at index `seq_pos` of each sequence is highlighted in bold underlined green. For readability the
    replacement character "�" is stripped and newlines are shown as "↵".
    """
    table = Table("Act", "Sequence", title="Max Activating Examples", show_lines=True)
    for act, str_toks, seq_pos in data:
        pieces = []
        for i, str_tok in enumerate(str_toks):
            pieces.append(f"[b u green]{str_tok}[/]" if i == seq_pos else str_tok)
        formatted_seq = "".join(pieces).replace("�", "").replace("\n", "↵")
        table.add_row(f"{act:.3f}", repr(formatted_seq))
    rprint(table)


# Smoke test: three rows over the same tokens, highlighting positions 0, 1 and 2 in turn.
example_data = [
    (0.5, [" one", " two", " three"], 0),
    (1.5, [" one", " two", " three"], 1),
    (2.5, [" one", " two", " three"], 2),
]
display_top_seqs(example_data)

In [None]:
def fetch_max_activating_examples(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 100,
    k: int = 10,
    buffer: int = 10,
) -> list[tuple[float, list[str], int]]:
    """
    Collect the `k` strongest activations of latent `latent_idx` seen over `total_batches` batches.

    Each result is a tuple (activation, surrounding string tokens, index of the activating token within those
    tokens). Positions closer than `buffer` to either end of a sequence are never selected, so the activating
    token always sits at index `buffer` of its window.
    """
    hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"

    # Per-batch top-k candidates accumulate here; the global top-k is taken at the end.
    candidates = []

    for _ in tqdm(range(total_batches), desc="Computing activations for max activating examples"):
        batch_tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            batch_tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[hook_name],
        )
        latent_acts = cache[hook_name][..., latent_idx]

        # Where the latent fires hardest, the tokens around those positions, and the activation values
        top_positions = get_k_largest_indices(latent_acts, k=k, buffer=buffer)
        token_windows = index_with_buffer(batch_tokens, top_positions, buffer=buffer)
        window_str_toks = [model.to_str_tokens(window) for window in token_windows]
        top_values = index_with_buffer(latent_acts, top_positions).tolist()

        for value, str_toks in zip(top_values, window_str_toks):
            candidates.append((value, str_toks, buffer))

    candidates.sort(key=lambda item: item[0], reverse=True)
    return candidates[:k]


# Fetch & display the top 5 examples for latent 9, with 10 tokens of context on each side
buffer = 10
data = fetch_max_activating_examples(gpt2, gpt2_sae, gpt2_act_store, latent_idx=9, buffer=buffer, k=5)
display_top_seqs(data)

# Print the string tokens of the strongest example for manual inspection
first_seq_str_tokens = data[0][1]
print(first_seq_str_tokens)
# assert first_seq_str_tokens[buffer] == " new"

#### Relevant to attention layer evaluation

In [None]:
# NOTE(review): this rebinds `attn_saes` from the layer -> SAE dict built earlier in the notebook
# to a single SAE object; cells below pass it where one SAE is expected.
attn_saes = gpt2_sae

In [None]:
@dataclass
class AttnSeqDFA:
    """One max-activating example for an attention SAE latent, with its top-attribution source token."""

    # Activation value shown in the "Top Act" column.
    act: float
    # String tokens of the window around the destination token.
    str_toks_dest: list[str]
    # String tokens of the window around the source token.
    str_toks_src: list[str]
    # Index within `str_toks_dest` of the token to highlight.
    dest_pos: int
    # Index within `str_toks_src` of the token to highlight.
    src_pos: int


def display_top_seqs_attn(data: list[AttnSeqDFA]):
    """
    Render a rich table with one row per `AttnSeqDFA`: the activation, the source-token window (source token
    highlighted in dark orange) and the destination-token window (destination token highlighted in green).

    For readability the replacement character "�" is stripped and newlines are shown as "↵".
    """

    def _highlight(str_toks, seq_pos, color):
        # Join the tokens, wrapping the one at `seq_pos` in rich markup.
        joined = "".join(
            f"[b u {color}]{str_tok}[/]" if i == seq_pos else str_tok for i, str_tok in enumerate(str_toks)
        )
        return repr(joined.replace("�", "").replace("\n", "↵"))

    table = Table(
        "Top Act",
        "Src token DFA (for top dest token)",
        "Dest token",
        title="Max Activating Examples",
        show_lines=True,
    )
    for seq in data:
        src_cell = _highlight(seq.str_toks_src, seq.src_pos, "dark_orange")
        dest_cell = _highlight(seq.str_toks_dest, seq.dest_pos, "green")
        table.add_row(f"{seq.act:.3f}", src_cell, dest_cell)
    rprint(table)


# Smoke test: dest windows drop the first token, src windows drop the last one.
str_toks = [" one", " two", " three", " four"]
example_data = [
    AttnSeqDFA(act=0.5, str_toks_dest=str_toks[1:], str_toks_src=str_toks[:-1], dest_pos=0, src_pos=0),
    AttnSeqDFA(act=1.5, str_toks_dest=str_toks[1:], str_toks_src=str_toks[:-1], dest_pos=1, src_pos=1),
    AttnSeqDFA(act=2.5, str_toks_dest=str_toks[1:], str_toks_src=str_toks[:-1], dest_pos=2, src_pos=0),
]
display_top_seqs_attn(example_data)

In [None]:
def fetch_max_activating_examples_attn(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 250,
    k: int = 10,
    buffer: int = 10,
) -> list[AttnSeqDFA]:
    """
    Returns the max activating examples across a number of batches from the activations store, together with the
    source token that has the largest direct feature attribution (DFA) for each of them.

    For every batch, the `k` positions with the largest pre-activation of latent `latent_idx` (excluding positions
    within `buffer` of either sequence end) are taken as destination tokens. For each one, the attention-weighted
    value vectors of all source positions are projected onto the latent's encoder direction, and the argmax source
    position is recorded. The overall top `k` examples by activation are returned, largest first.

    `dest_pos` and `src_pos` of each result are plain Python ints indexing into the corresponding token window.
    """
    sae_acts_pre_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_pre"
    v_hook_name = get_act_name("v", sae.cfg.hook_layer)
    pattern_hook_name = get_act_name("pattern", sae.cfg.hook_layer)
    data = []

    for _ in tqdm(range(total_batches), desc="Computing activations for max activating examples (attn)"):
        tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae_acts_pre_hook_name, v_hook_name, pattern_hook_name],
        )
        acts = cache[sae_acts_pre_hook_name][..., latent_idx]  # [batch seq]

        # Get largest indices (i.e. dest tokens), and the tokens at those positions (plus buffer)
        k_largest_indices = get_k_largest_indices(acts, k=k, buffer=buffer)
        top_acts = index_with_buffer(acts, k_largest_indices).tolist()
        dest_toks_with_buffer = index_with_buffer(tokens, k_largest_indices, buffer=buffer)
        str_toks_dest_list = [model.to_str_tokens(toks) for toks in dest_toks_with_buffer]

        # Get src token value vectors & dest-to-src attention patterns, for each of our chosen dest tokens
        batch_indices, dest_pos_indices = k_largest_indices.unbind(-1)
        v = cache[v_hook_name][batch_indices]  # shape [k src n_heads d_head]
        pattern = cache[pattern_hook_name][batch_indices, :, dest_pos_indices]  # shape [k n_heads src]

        # Multiply them together to get weighted value vectors, and reshape them to d_in = n_heads * d_head
        v_weighted = (v * einops.rearrange(pattern, "k n src -> k src n 1")).flatten(-2, -1)  # shape [k src d_in]

        # Map through our SAE encoder to get direct feature attribution for each src token, and argmax over src tokens
        dfa = v_weighted @ sae.W_enc[:, latent_idx]  # shape [k src]
        src_pos_indices = dfa.argmax(dim=-1)
        src_toks_with_buffer = index_with_buffer(tokens, t.stack([batch_indices, src_pos_indices], -1), buffer=buffer)
        str_toks_src_list = [model.to_str_tokens(toks) for toks in src_toks_with_buffer]

        # Offset of each src token inside its window. `index_with_buffer` centres the window on the src position
        # clamped to [buffer, seq_len - buffer - 1], so the offset is `buffer` unless the window was shifted at
        # either end of the sequence. Converted to Python ints to match the `AttnSeqDFA.src_pos: int` field.
        window_centers = src_pos_indices.clamp(min=buffer, max=tokens.size(1) - buffer - 1)
        src_pos_in_window = (src_pos_indices - window_centers + buffer).tolist()

        # Add all this data to our list
        for act, str_toks_dest, str_toks_src, src_pos in zip(
            top_acts, str_toks_dest_list, str_toks_src_list, src_pos_in_window
        ):
            data.append(
                AttnSeqDFA(
                    act=act,
                    str_toks_dest=str_toks_dest,  # top activating dest tokens, with buffer
                    str_toks_src=str_toks_src,  # top DFA src tokens for the dest token, with buffer
                    dest_pos=buffer,  # dest token is always in the middle of its buffer
                    src_pos=src_pos,  # src token is off-centre when its window was clamped at a sequence end
                )
            )

    return sorted(data, key=lambda x: x.act, reverse=True)[:k]

# Test your function: compare it to dashboard above (max DFA should come from src toks like " guns", " firearms")
# `attn_saes` here is the single SAE assigned a few cells above; latent 1 is inspected with default settings.
data = fetch_max_activating_examples_attn(gpt2, attn_saes, gpt2_act_store, latent_idx=1)
display_top_seqs_attn(data)

In [None]:
# Load the unpruned GPT-2 small and the SAE stored under "suchitg/sae_test", build an activations
# store for it, and cache the SAE pre-activations on the example prompt.
t.set_grad_enabled(False)
gpt2 = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)

# SAE for full gpt2-small
hf_repo_id = "suchitg/sae_test"
sae_id = 'blocks.9.attn.hook_z-base-v1'
gpt2_sae = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

gpt2_act_store = ActivationsStore.from_sae(
    model=gpt2,
    sae=gpt2_sae,
    streaming=True,
    store_batch_size_prompts=16,
    n_batches_in_buffer=32,
    device=str(device),
)


example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
tokens = gpt2.to_tokens(example_prompt)



# Hook names for the SAE pre-activations and for the value / attention-pattern activations of the SAE's layer
sae_acts_pre_hook_name = f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_pre"
v_hook_name = get_act_name("v", gpt2_sae.cfg.hook_layer)
pattern_hook_name = get_act_name("pattern", gpt2_sae.cfg.hook_layer)
print(sae_acts_pre_hook_name, v_hook_name, pattern_hook_name)


# Run only up to (and including) the SAE's layer, caching just the three hooks above
_, cache = gpt2.run_with_cache_with_saes(
    tokens,
    saes=[gpt2_sae],
    stop_at_layer=gpt2_sae.cfg.hook_layer + 1,
    names_filter=[sae_acts_pre_hook_name, v_hook_name, pattern_hook_name],
)
acts = cache[sae_acts_pre_hook_name]
print(acts.shape)



In [None]:
# Display the SAE pre-activations cached for the unpruned model.
acts

In [None]:
# Load GPT-2 small, overwrite its weights from 'pruned/pruned_gpt2_wanda.pth' (path is relative to
# the working directory), load the SAE stored under "suchitg/sae_wanda", and cache its
# pre-activations. `tokens` is reused from the previous model cell.
gpt2_wanda = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
gpt2_wanda.load_state_dict(t.load('pruned/pruned_gpt2_wanda.pth'))
# example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
# example_answer = " Mary"
# transformer_lens.utils.test_prompt(example_prompt, example_answer, gpt2_wanda, prepend_bos=True)

# SAE for wanda-pruned gpt2-small
hf_repo_id = "suchitg/sae_wanda"
sae_id = 'blocks.9.attn.hook_z-v1'
gpt2_wanda_sae = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]


# NOTE: these hook-name variables and `cache` / `acts` below overwrite the ones from the unpruned-model cell.
sae_acts_pre_hook_name = f"{gpt2_wanda_sae.cfg.hook_name}.hook_sae_acts_pre"
v_hook_name = get_act_name("v", gpt2_wanda_sae.cfg.hook_layer)
pattern_hook_name = get_act_name("pattern", gpt2_wanda_sae.cfg.hook_layer)
print(sae_acts_pre_hook_name, v_hook_name, pattern_hook_name)


_, cache = gpt2_wanda.run_with_cache_with_saes(
    tokens,
    saes=[gpt2_wanda_sae],
    stop_at_layer=gpt2_wanda_sae.cfg.hook_layer + 1,
    names_filter=[sae_acts_pre_hook_name, v_hook_name, pattern_hook_name],
)
acts = cache[sae_acts_pre_hook_name]
print(acts.shape)

In [None]:
# Display the SAE pre-activations cached for the pruned model.
acts

In [None]:
# Disable autograd globally for everything that runs after this cell.
t.set_grad_enabled(False)

# Helper functions
def show_activation_histogram(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 200,
):
    """
    Plot a histogram of the strictly positive activations of one SAE latent.

    Pulls `total_batches` token batches from `act_store`, runs `model` with `sae` attached (stopping right after the
    SAE's layer) and collects every activation of latent `latent_idx` that is greater than zero. The plot title
    reports the fraction of all visited token positions at which the latent was active.
    """
    hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    positive_acts = []

    for _ in tqdm(range(total_batches), desc="Computing activations for histogram"):
        batch_tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            batch_tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[hook_name],
        )
        latent_acts = cache[hook_name][..., latent_idx]
        positive_acts += latent_acts[latent_acts > 0].cpu().tolist()

    # Density = active positions / all positions seen (batches x prompts per batch x context size).
    n_positions = total_batches * act_store.store_batch_size_prompts * act_store.context_size
    frac_active = len(positive_acts) / n_positions

    fig = px.histogram(
        positive_acts,
        nbins=50,
        title=f"ACTIVATIONS DENSITY {frac_active:.3%}",
        labels={"value": "Activation"},
        width=800,
        template="ggplot2",
        color_discrete_sequence=["darkorange"],
    )
    fig.update_layout(bargap=0.02, showlegend=False)
    fig.show()

def show_top_logits(
    model: HookedSAETransformer,
    sae: SAE,
    latent_idx: int,
    k: int = 10,
) -> None:
    """
    Print a table of the `k` most negative and `k` most positive logit weights of one latent.

    The logit weights are the latent's decoder row multiplied by the model's unembedding matrix.
    """
    logit_weights = sae.W_dec[latent_idx] @ model.W_U

    top = logit_weights.topk(k)
    bottom = logit_weights.topk(k, largest=False)
    top_tokens = model.to_str_tokens(top.indices)
    bottom_tokens = model.to_str_tokens(bottom.indices)

    rows = zip(map(repr, bottom_tokens), bottom.values, map(repr, top_tokens), top.values)
    table = tabulate(
        rows,
        headers=["Bottom tokens", "Value", "Top tokens", "Value"],
        tablefmt="simple_outline",
        stralign="right",
        numalign="left",
        floatfmt="+.3f",
    )
    print(table)

@dataclass
class AttnSeqDFA:
    """One max-activating example for an attention SAE latent, with its top-attribution source token."""

    # Activation value shown in the "Top Act" column.
    act: float
    # String tokens of the window around the destination token.
    str_toks_dest: list[str]
    # String tokens of the window around the source token.
    str_toks_src: list[str]
    # Index within `str_toks_dest` of the token to highlight.
    dest_pos: int
    # Index within `str_toks_src` of the token to highlight.
    src_pos: int

def display_top_seqs_attn(data: list[AttnSeqDFA]):
    """
    Render a rich table with one row per `AttnSeqDFA`: the activation, the source-token window (source token
    highlighted in dark orange) and the destination-token window (destination token highlighted in green).

    For readability the replacement character "�" is stripped and newlines are shown as "↵".
    """

    def _highlight(str_toks, seq_pos, color):
        # Join the tokens, wrapping the one at `seq_pos` in rich markup.
        joined = "".join(
            f"[b u {color}]{str_tok}[/]" if i == seq_pos else str_tok for i, str_tok in enumerate(str_toks)
        )
        return repr(joined.replace("�", "").replace("\n", "↵"))

    table = Table(
        "Top Act",
        "Src token DFA (for top dest token)",
        "Dest token",
        title="Max Activating Examples",
        show_lines=True,
    )
    for seq in data:
        src_cell = _highlight(seq.str_toks_src, seq.src_pos, "dark_orange")
        dest_cell = _highlight(seq.str_toks_dest, seq.dest_pos, "green")
        table.add_row(f"{seq.act:.3f}", src_cell, dest_cell)
    rprint(table)

def fetch_max_activating_examples_attn(
    model: HookedSAETransformer,
    sae: SAE,
    act_store: ActivationsStore,
    latent_idx: int,
    total_batches: int = 250,
    k: int = 10,
    buffer: int = 10,
) -> list[AttnSeqDFA]:
    """
    Returns the max activating examples across a number of batches from the activations store, together with the
    source token that has the largest direct feature attribution (DFA) for each of them.

    For every batch, the `k` positions with the largest pre-activation of latent `latent_idx` (excluding positions
    within `buffer` of either sequence end) are taken as destination tokens. For each one, the attention-weighted
    value vectors of all source positions are projected onto the latent's encoder direction, and the argmax source
    position is recorded. The overall top `k` examples by activation are returned, largest first.

    `dest_pos` and `src_pos` of each result are plain Python ints indexing into the corresponding token window.
    """
    sae_acts_pre_hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_pre"
    v_hook_name = get_act_name("v", sae.cfg.hook_layer)
    pattern_hook_name = get_act_name("pattern", sae.cfg.hook_layer)
    data = []

    for _ in tqdm(range(total_batches), desc="Computing activations for max activating examples (attn)"):
        tokens = act_store.get_batch_tokens()
        _, cache = model.run_with_cache_with_saes(
            tokens,
            saes=[sae],
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae_acts_pre_hook_name, v_hook_name, pattern_hook_name],
        )
        acts = cache[sae_acts_pre_hook_name][..., latent_idx]  # [batch seq]

        # Get largest indices (i.e. dest tokens), and the tokens at those positions (plus buffer)
        k_largest_indices = get_k_largest_indices(acts, k=k, buffer=buffer)
        top_acts = index_with_buffer(acts, k_largest_indices).tolist()
        dest_toks_with_buffer = index_with_buffer(tokens, k_largest_indices, buffer=buffer)
        str_toks_dest_list = [model.to_str_tokens(toks) for toks in dest_toks_with_buffer]

        # Get src token value vectors & dest-to-src attention patterns, for each of our chosen dest tokens
        batch_indices, dest_pos_indices = k_largest_indices.unbind(-1)
        v = cache[v_hook_name][batch_indices]  # shape [k src n_heads d_head]
        pattern = cache[pattern_hook_name][batch_indices, :, dest_pos_indices]  # shape [k n_heads src]

        # Multiply them together to get weighted value vectors, and reshape them to d_in = n_heads * d_head
        v_weighted = (v * einops.rearrange(pattern, "k n src -> k src n 1")).flatten(-2, -1)  # shape [k src d_in]

        # Map through our SAE encoder to get direct feature attribution for each src token, and argmax over src tokens
        dfa = v_weighted @ sae.W_enc[:, latent_idx]  # shape [k src]
        src_pos_indices = dfa.argmax(dim=-1)
        src_toks_with_buffer = index_with_buffer(tokens, t.stack([batch_indices, src_pos_indices], -1), buffer=buffer)
        str_toks_src_list = [model.to_str_tokens(toks) for toks in src_toks_with_buffer]

        # Offset of each src token inside its window. `index_with_buffer` centres the window on the src position
        # clamped to [buffer, seq_len - buffer - 1], so the offset is `buffer` unless the window was shifted at
        # either end of the sequence. Converted to Python ints to match the `AttnSeqDFA.src_pos: int` field.
        window_centers = src_pos_indices.clamp(min=buffer, max=tokens.size(1) - buffer - 1)
        src_pos_in_window = (src_pos_indices - window_centers + buffer).tolist()

        # Add all this data to our list
        for act, str_toks_dest, str_toks_src, src_pos in zip(
            top_acts, str_toks_dest_list, str_toks_src_list, src_pos_in_window
        ):
            data.append(
                AttnSeqDFA(
                    act=act,
                    str_toks_dest=str_toks_dest,  # top activating dest tokens, with buffer
                    str_toks_src=str_toks_src,  # top DFA src tokens for the dest token, with buffer
                    dest_pos=buffer,  # dest token is always in the middle of its buffer
                    src_pos=src_pos,  # src token is off-centre when its window was clamped at a sequence end
                )
            )

    return sorted(data, key=lambda x: x.act, reverse=True)[:k]

def get_k_largest_indices(x: Float[Tensor, "batch seq"], k: int, buffer: int = 0) -> Int[Tensor, "k 2"]:
    """
    Find the positions of the `k` largest values of a 2D tensor.

    Row `i` of the result is the (batch, seqpos) coordinate of the i-th largest element of `x`.

    When `buffer > 0`, the first and last `buffer` positions of every sequence are excluded from the search, so
    every returned seqpos lies in `[buffer, seq - buffer)`.
    """
    # Restrict the search to the interior columns (the returned seqpos is shifted back by `buffer` below).
    candidates = x[:, buffer:-buffer] if buffer > 0 else x
    width = candidates.size(1)

    # topk over the flattened view, then convert flat offsets back into 2D coordinates.
    flat_top = t.topk(candidates.flatten(), k=k).indices
    batch_idx = flat_top // width
    seq_idx = flat_top % width + buffer
    return t.stack((batch_idx, seq_idx), dim=1)

def index_with_buffer(
    x: Float[Tensor, "batch seq"], indices: Int[Tensor, "k 2"], buffer: int | None = None
) -> Float[Tensor, "k *buffer_x2_plus1"]:
    """
    Gather elements of `x` at the (batch, seqpos) pairs in `indices`, optionally with a +-`buffer` window around
    each position.

    Args:
        x: 2D tensor to index into.
        indices: one (batch, seqpos) pair per row, e.g. the output of `get_k_largest_indices`. This tensor is
            never modified.
        buffer: half-width of the window. If a position is less than `buffer` away from the start (or end) of its
            sequence, the window is shifted so that it covers the first (or last) `2*buffer+1` elements. If None,
            no window is taken.

    Returns:
        Shape [k] if `buffer` is None, otherwise shape [k, 2*buffer+1].
    """
    rows, cols = indices.unbind(dim=-1)
    if buffer is None:
        return x[rows, cols]

    # `unbind` returns views of `indices`, so the window centre is clamped out-of-place: an in-place write here
    # would silently change the caller's `indices` tensor.
    centres = cols.clamp(min=buffer, max=x.size(1) - buffer - 1)
    offsets = t.arange(-buffer, buffer + 1, device=cols.device)

    # Broadcast [k, 1] rows against [k, 2*buffer+1] columns to pull out one window per index pair.
    return x[rows.unsqueeze(-1), centres.unsqueeze(-1) + offsets]

def display_top_seqs(data: list[tuple[float, list[str], int]]):
    """
    Print a rich table of max-activating sequences.

    Each entry of `data` is (activation: float, str_toks: list[str], seq_pos: int); the token at index `seq_pos`
    is rendered bold, underlined and green.

    For readability, unknown-token characters � (usually weird quotation marks) are removed and newline characters
    are shown as "↵".
    """

    def _render(str_toks: list[str], seq_pos: int) -> str:
        # Wrap only the highlighted token in rich markup, then clean up the joined string.
        pieces = []
        for pos, str_tok in enumerate(str_toks):
            pieces.append(f"[b u green]{str_tok}[/]" if pos == seq_pos else str_tok)
        return "".join(pieces).replace("�", "").replace("\n", "↵")

    table = Table("Act", "Sequence", title="Max Activating Examples", show_lines=True)
    for act, str_toks, seq_pos in data:
        table.add_row(f"{act:.3f}", repr(_render(str_toks, seq_pos)))
    rprint(table)


'''
Full gpt2-small
'''
# Load the unpruned gpt2-small as a HookedSAETransformer on the notebook-wide `device`.
gpt2 = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
# example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
# example_answer = " Mary"
# transformer_lens.utils.test_prompt(example_prompt, example_answer, gpt2, prepend_bos=True)

# SAE for full gpt2-small
# Only element [0] of what `from_pretrained` returns is kept (presumably the SAE object itself).
hf_repo_id = "suchitg/sae_test"
sae_id = 'blocks.9.attn.hook_z-base-v1'
gpt2_sae = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]
# print(tabulate(gpt2_sae.cfg.__dict__.items(), headers=["name", "value"], tablefmt="simple_outline"))

# Activation store built from the SAE's config; it is used below as a source of token batches.
gpt2_act_store = ActivationsStore.from_sae(
    model=gpt2,
    sae=gpt2_sae,
    streaming=True,
    store_batch_size_prompts=16,
    n_batches_in_buffer=32,
    device=str(device),
)

# Example of how you can use this:
tokens = gpt2_act_store.get_batch_tokens()
# Sanity check: one batch of tokens is (store_batch_size_prompts, context_size).
assert tokens.shape == (gpt2_act_store.store_batch_size_prompts, gpt2_act_store.context_size)
# The helpers called below are defined elsewhere in this notebook.
show_activation_histogram(gpt2, gpt2_sae, gpt2_act_store, latent_idx=9)
show_top_logits(gpt2, gpt2_sae, latent_idx=9)
# NOTE(review): latent_idx=1 here, but latent_idx=9 in the two calls above and in the Wanda-pruned section
# below -- confirm whether inspecting a different latent is intentional.
data = fetch_max_activating_examples_attn(gpt2, gpt2_sae, gpt2_act_store, latent_idx=1)
display_top_seqs_attn(data)

'''
WANDA PRUNED
'''
# Start from stock gpt2-small, then overwrite its weights with the checkpoint at 'pruned/pruned_gpt2_wanda.pth'
# (a path relative to the working directory).
gpt2_wanda = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
gpt2_wanda.load_state_dict(t.load('pruned/pruned_gpt2_wanda.pth'))
# example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
# example_answer = " Mary"
# transformer_lens.utils.test_prompt(example_prompt, example_answer, gpt2_wanda, prepend_bos=True)

# SAE for wanda-pruned gpt2-small
# Only element [0] of what `from_pretrained` returns is kept (presumably the SAE object itself).
hf_repo_id = "suchitg/sae_wanda"
sae_id = 'blocks.9.attn.hook_z-v1'
gpt2_wanda_sae = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

# print(tabulate(gpt2_wanda_sae.cfg.__dict__.items(), headers=["name", "value"], tablefmt="simple_outline"))

# Activation store for the pruned model / SAE pair, with the same settings as the unpruned one above.
gpt2_wanda_act_store = ActivationsStore.from_sae(
    model=gpt2_wanda,
    sae=gpt2_wanda_sae,
    streaming=True,
    store_batch_size_prompts=16,
    n_batches_in_buffer=32,
    device=str(device),
)

# Example of how you can use this:
# NOTE: this rebinds the notebook-level names `tokens` and (below) `data`.
tokens = gpt2_wanda_act_store.get_batch_tokens()
# Sanity check: one batch of tokens is (store_batch_size_prompts, context_size).
assert tokens.shape == (gpt2_wanda_act_store.store_batch_size_prompts, gpt2_wanda_act_store.context_size)
# The helpers called below are defined elsewhere in this notebook.
show_activation_histogram(gpt2_wanda, gpt2_wanda_sae, gpt2_wanda_act_store, latent_idx=9)
show_top_logits(gpt2_wanda, gpt2_wanda_sae, latent_idx=9)

data = fetch_max_activating_examples_attn(gpt2_wanda, gpt2_wanda_sae, gpt2_wanda_act_store, latent_idx=9)
display_top_seqs_attn(data)


#### Research Questions

##### Q. How different are the neuron activations of an SAE trained on wanda pruned gpt2-small as compared to an SAE trained on full gpt2-small? 

##### Q. Does an SAE trained on wanda pruned gpt2-small replicate a wanda pruned SAE trained on full gpt2-small?

Procedure: apply Wanda pruning to the SAE trained on full gpt2-small.

* How do we obtain the SAE's input activations, which are needed to apply Wanda pruning?

In [None]:
# Reload the unpruned gpt2-small and its hook_z SAE (this rebinds `gpt2` / `gpt2_sae` if they already exist).
gpt2 = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
# transformer_lens.utils.test_prompt(example_prompt, example_answer, gpt2, prepend_bos=True)

# SAE for full gpt2-small
hf_repo_id = "suchitg/sae_test"
sae_id = 'blocks.9.attn.hook_z-base-v1'
gpt2_sae = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]
# print(tabulate(gpt2_sae.cfg.__dict__.items(), headers=["name", "value"], tablefmt="simple_outline"))

# Run the model with the SAE attached and keep the activation cache.
_, cache = gpt2.run_with_cache_with_saes(example_prompt, saes=[gpt2_sae])

# List the shapes of every cached activation whose name contains "hook_sae".
for name, param in cache.items():
    if "hook_sae" in name:
        print(f"{name:<43}: {tuple(param.shape)}")


prompt = "In the beginning, God created the heavens and the"
answer = "earth"

# Show that the model can confidently predict the next token.
test_prompt(prompt, answer, gpt2)
print(gpt2_sae.use_error_term)  

# `cache` is rebound here; the cache from `example_prompt` above is discarded.
_, cache = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])

print([(k, v.shape) for k, v in cache.items() if "sae" in k])

# Plot every SAE latent's post-activation value at batch item 0, last position.
px.line(
    cache[f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post"][0, -1, :].cpu().numpy(),
    title="Feature activations at the final token position",
    labels={"index": "Feature", "value": "Activation"},
).show()

prompt = "In the beginning, God created the cat and the"
answer = "earth"

# here we see that removing the word "Heavens" is very effective at making the model no longer predict "earth".
# instead the model predicts a bunch of different animals.
# Can we work out which features fire differently which might explain this? (This is a toy example not meant to be super interesting)
test_prompt(prompt, answer, gpt2)

# From here on `prompt` is a list: both variants are run as one batch (index 0 = heavens, index 1 = cat).
prompt = [
    "In the beginning, God created the heavens and the",
    "In the beginning, God created the cat and the",
]
_, cache = gpt2.run_with_cache_with_saes(prompt, saes=[gpt2_sae])
print([(k, v.shape) for k, v in cache.items() if "sae" in k])

# One row per SAE latent (d_sae rows); the single column holds the activations for batch item 0.
# NOTE(review): position -1 is only the final real token of both prompts if they tokenize to the same length
# (otherwise one row would be read at a padding position) -- confirm.
feature_activation_df = pd.DataFrame(
    cache[f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post"][0, -1, :].cpu().numpy(),
    index=[f"feature_{i}" for i in range(gpt2_sae.cfg.d_sae)],
)
feature_activation_df.columns = ["heavens_and_the"]
feature_activation_df["cat_and_the"] = (
    cache[f"{gpt2_sae.cfg.hook_name}.hook_sae_acts_post"][1, -1, :].cpu().numpy()
)
# Positive diff: the latent is more active on the "heavens" prompt than on the "cat" prompt.
feature_activation_df["diff"] = (
    feature_activation_df["heavens_and_the"] - feature_activation_df["cat_and_the"]
)

# All three columns (both prompts and their difference) are drawn as separate lines.
fig = px.line(
    feature_activation_df,
    title="Feature activations for the prompt",
    labels={"index": "Feature", "value": "Activation"},
)

# hide the x-ticks
fig.update_xaxes(showticklabels=False)
fig.show()

In [None]:
# Reload the unpruned gpt2-small and its hook_z SAE (this rebinds `gpt2` / `gpt2_sae` if they already exist).
gpt2 = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
# transformer_lens.utils.test_prompt(example_prompt, example_answer, gpt2, prepend_bos=True)

# SAE for full gpt2-small
hf_repo_id = "suchitg/sae_test"
sae_id = 'blocks.9.attn.hook_z-base-v1'
gpt2_sae = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"

# First see how the model does without SAEs
test_prompt(prompt, answer, gpt2)

# Test our prompt, to see what the model says
# Inside this context manager the SAE is attached to the model.
with gpt2.saes(saes=[gpt2_sae]):
    test_prompt(prompt, answer, gpt2)

# Using `run_with_saes` method in place of standard forward pass
logits = gpt2(prompt, return_type="logits")
logits_sae = gpt2.run_with_saes(prompt, saes=[gpt2_sae], return_type="logits")
# NOTE: `answer_token_id` is not used again in this cell.
answer_token_id = gpt2.to_single_token(answer)

# Getting model's prediction
# Softmax over the vocabulary at batch item 0, last position; `.max(-1)` yields (top probability, token id).
top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)

print(f"""Standard model: top prediction = {gpt2.to_string(token_id_prediction)!r}, prob = {top_prob.item():.2%}
SAE reconstruction: top prediction = {gpt2.to_string(token_id_prediction_sae)!r}, prob = {top_prob_sae.item():.2%}
""")

In [None]:
# Same comparison as the previous cell, but with the SAE weights replaced by a saved checkpoint.
# Relies on `gpt2` and `gpt2_sae` already existing in the kernel.
prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"

# Overwrites `gpt2_sae`'s parameters in place; after this cell `gpt2_sae` no longer holds the downloaded weights.
# The checkpoint is presumably the Wanda-pruned SAE, going by the filename.
gpt2_sae.load_state_dict(t.load('pruned_gpt2_sae_wanda.pth'))

# First see how the model does without SAEs
test_prompt(prompt, answer, gpt2)

# Test our prompt, to see what the model says
with gpt2.saes(saes=[gpt2_sae]):
    test_prompt(prompt, answer, gpt2)

# Using `run_with_saes` method in place of standard forward pass
logits = gpt2(prompt, return_type="logits")
logits_sae = gpt2.run_with_saes(prompt, saes=[gpt2_sae], return_type="logits")
# NOTE: `answer_token_id` is not used again in this cell.
answer_token_id = gpt2.to_single_token(answer)

# Getting model's prediction
# Softmax over the vocabulary at batch item 0, last position; `.max(-1)` yields (top probability, token id).
top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)

print(f"""Standard model: top prediction = {gpt2.to_string(token_id_prediction)!r}, prob = {top_prob.item():.2%}
SAE reconstruction: top prediction = {gpt2.to_string(token_id_prediction_sae)!r}, prob = {top_prob_sae.item():.2%}
""")

In [None]:
# Bind the generic names `model` / `sae` (used by the evaluation cells below) to unpruned gpt2-small and the
# SAE published as "suchitg/sae_test" / 'blocks.9.attn.hook_z-base-v1'.
model = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
hf_repo_id = "suchitg/sae_test"
sae_id = 'blocks.9.attn.hook_z-base-v1'
# Only element [0] of what `from_pretrained` returns is kept (presumably the SAE object itself).
sae = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

In [None]:
# Overwrite `sae`'s parameters in place with the checkpoint at 'pruned_gpt2_sae_wanda.pth' (presumably the
# Wanda-pruned SAE, going by the filename). Skipping this cell evaluates the downloaded SAE instead.
sae.load_state_dict(t.load('pruned_gpt2_sae_wanda.pth'))

In [None]:
from transformer_lens.utils import tokenize_and_concatenate

# Load the train split of NeelNanda/pile-10k as a regular (non-streaming) dataset.
dataset = load_dataset(
    path="NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

# Tokenize with the model's tokenizer; max_length and BOS handling are taken from the SAE's config so the
# sequences match what the SAE expects.
# NOTE(review): `streaming=True` is passed here although the dataset above was loaded with `streaming=False`
# -- confirm this is intended.
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=sae.cfg.context_size,
    add_bos_token=sae.cfg.prepend_bos,
)

In [None]:
# Put the SAE in eval mode (per the original note, this prevents an error about an expected dead-neuron mask).
sae.eval()

with t.no_grad():
    # Take the first 8 tokenized sequences as the evaluation batch.
    batch_tokens = token_dataset[:8]["tokens"]

    # Only the SAE's hook point is read below, so cache just that activation instead of every activation
    # in the model.
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True, names_filter=sae.cfg.hook_name)

    # Encode the hooked activations into SAE latents, then decode them back into a reconstruction.
    # `sae_out` and `batch_tokens` are reused by the reconstruction-loss cell that follows.
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # save some room
    del cache

    # L0 = number of latents that are active (> 0) at each token; position 0 (the BOS token) is ignored.
    # The printed value is the mean over batch and position.
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()

In [None]:
from transformer_lens import utils
from functools import partial


# next we want to do a reconstruction test.
# NOTE: `reconstr_hook` / `zero_abl_hook` are defined again, identically, in a later cell of this notebook.
def reconstr_hook(activation, hook, sae_out):
    """Forward hook that discards the incoming `activation` and returns the precomputed `sae_out` instead."""
    return sae_out


def zero_abl_hook(activation, hook):
    """Forward hook that replaces `activation` with zeros of the same shape, dtype and device."""
    return t.zeros_like(activation)


# Three losses on the same `batch_tokens` (which, like `sae_out`, comes from the previous cell):
# the unmodified model, the model with the SAE reconstruction substituted at the SAE's hook point,
# and the model with that hook point zeroed out.
print("Orig", model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae.cfg.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
    ).item(),
)

In [None]:
# Rebind `model` / `sae` to the pruned pair: stock gpt2-small with its weights overwritten from
# 'pruned/pruned_gpt2_wanda.pth', and the SAE published as "suchitg/sae_wanda" / 'blocks.9.attn.hook_z-v1'.
model = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
model.load_state_dict(t.load('pruned/pruned_gpt2_wanda.pth'))

# SAE for wanda-pruned gpt2-small
hf_repo_id = "suchitg/sae_wanda"
sae_id = 'blocks.9.attn.hook_z-v1'
# Only element [0] of what `from_pretrained` returns is kept (presumably the SAE object itself).
sae = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

In [None]:
# Put the SAE in eval mode (per the original note, this prevents an error about an expected dead-neuron mask).
sae.eval()

with t.no_grad():
    # Take the first 8 tokenized sequences as the evaluation batch.
    batch_tokens = token_dataset[:8]["tokens"]

    # Only the SAE's hook point is read below, so cache just that activation instead of every activation
    # in the model.
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True, names_filter=sae.cfg.hook_name)

    # Encode the hooked activations into SAE latents, then decode them back into a reconstruction.
    # `sae_out` and `batch_tokens` are reused by the reconstruction-loss cell that follows.
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # save some room
    del cache

    # L0 = number of latents that are active (> 0) at each token; position 0 (the BOS token) is ignored.
    # The printed value is the mean over batch and position.
    l0 = (feature_acts[:, 1:] > 0).float().sum(-1).detach()
    print("average l0", l0.mean().item())
    px.histogram(l0.flatten().cpu().numpy()).show()

In [None]:
from transformer_lens import utils
from functools import partial


# next we want to do a reconstruction test.
# NOTE: `reconstr_hook` / `zero_abl_hook` are also defined, identically, in an earlier cell of this notebook.
def reconstr_hook(activation, hook, sae_out):
    """Forward hook that discards the incoming `activation` and returns the precomputed `sae_out` instead."""
    return sae_out


def zero_abl_hook(activation, hook):
    """Forward hook that replaces `activation` with zeros of the same shape, dtype and device."""
    return t.zeros_like(activation)


# Three losses on the same `batch_tokens` (which, like `sae_out`, comes from the previous cell):
# the unmodified model, the model with the SAE reconstruction substituted at the SAE's hook point,
# and the model with that hook point zeroed out.
print("Orig", model(batch_tokens, return_type="loss").item())
print(
    "reconstr",
    model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[
            (
                sae.cfg.hook_name,
                partial(reconstr_hook, sae_out=sae_out),
            )
        ],
        return_type="loss",
    ).item(),
)
print(
    "Zero",
    model.run_with_hooks(
        batch_tokens,
        return_type="loss",
        fwd_hooks=[(sae.cfg.hook_name, zero_abl_hook)],
    ).item(),
)

In [None]:
# Stock gpt2-small with its weights overwritten from 'pruned/pruned_gpt2_wanda.pth' (path relative to the
# working directory).
gpt2_wanda = sae_lens.HookedSAETransformer.from_pretrained("gpt2-small", device=device)
gpt2_wanda.load_state_dict(t.load('pruned/pruned_gpt2_wanda.pth'))

# SAE for wanda-pruned gpt2-small
hf_repo_id = "suchitg/sae_wanda"
sae_id = 'blocks.9.attn.hook_z-v1'
# Only element [0] of what `from_pretrained` returns is kept (presumably the SAE object itself).
gpt2_wanda_sae = sae_lens.SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

In [None]:
# Same prediction comparison as for unpruned gpt2 above, now on `gpt2_wanda` / `gpt2_wanda_sae` from the
# previous cell.
prompt = "Mitigating the risk of extinction from AI should be a global"
answer = " priority"


# First see how the model does without SAEs
test_prompt(prompt, answer, gpt2_wanda)

# Test our prompt, to see what the model says
# Inside this context manager the SAE is attached to the model.
with gpt2_wanda.saes(saes=[gpt2_wanda_sae]):
    test_prompt(prompt, answer, gpt2_wanda)

# Using `run_with_saes` method in place of standard forward pass
logits = gpt2_wanda(prompt, return_type="logits")
logits_sae = gpt2_wanda.run_with_saes(prompt, saes=[gpt2_wanda_sae], return_type="logits")
# NOTE: `answer_token_id` is not used again in this cell.
answer_token_id = gpt2_wanda.to_single_token(answer)

# Getting model's prediction
# Softmax over the vocabulary at batch item 0, last position; `.max(-1)` yields (top probability, token id).
top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)

print(f"""Standard model: top prediction = {gpt2_wanda.to_string(token_id_prediction)!r}, prob = {top_prob.item():.2%}
SAE reconstruction: top prediction = {gpt2_wanda.to_string(token_id_prediction_sae)!r}, prob = {top_prob_sae.item():.2%}
""")

##### Q. Can either SAE be predicted from the other?

## Gemma-2-2b

In [None]:
import torch as t
from datasets import load_dataset
import sae_lens
import transformer_lens
from sae_lens import (
    SAE,
    HookedSAETransformer,
)
from transformer_lens.utils import get_act_name, test_prompt, to_numpy
from dataclasses import dataclass


# Load gemma-2-2b, then overwrite its weights with a saved checkpoint (Wanda-pruned, going by the filename).
# NOTE(review): the checkpoint path is an absolute, machine-specific path.
model = HookedSAETransformer.from_pretrained("gemma-2-2b")
model.load_state_dict(t.load('/home/gupte.31/COLM/sae-compression/gemma2b/pruned/gemma-2-2b_wanda.pth'))

# Disables autograd globally for the rest of the kernel session, not just for this cell.
t.set_grad_enabled(False)

# next we want to do a reconstruction test.
def reconstr_hook(activation, hook, sae_out):
    """Forward hook that discards the incoming `activation` and returns the precomputed `sae_out` instead."""
    return sae_out

# Load the train split of NeelNanda/pile-10k as a regular (non-streaming) dataset.
dataset = load_dataset(
    path="NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

from transformer_lens.utils import tokenize_and_concatenate

# Tokenize with the model's tokenizer, using max_length=1024 and a BOS token.
# NOTE(review): `streaming=True` is passed here although the dataset above was loaded with `streaming=False`
# -- confirm this is intended.
token_dataset = tokenize_and_concatenate(
    dataset=dataset,  # type: ignore
    tokenizer=model.tokenizer,  # type: ignore
    streaming=True,
    max_length=1024,
    add_bos_token=True,
)

In [None]:
from functools import partial

from sae_lens import SAE

# Evaluate the SAE identified by `release` / `sae_id` on the current `model`.
# Relies on `model`, `token_dataset` and `reconstr_hook` from the Gemma setup cell.
release = "gemma-scope-2b-pt-res-canonical"
sae_id = "layer_12/width_16k/canonical"
# Only element [0] of what `from_pretrained` returns is kept (presumably the SAE object itself).
# The SAE is loaded once and reused for every check in this cell.
sae = SAE.from_pretrained(release, sae_id, device="cuda")[0]

# (prompt, expected next token) pairs; the second pair is the IOI task.
prompt_answer_pairs = [
    ("Mitigating the risk of extinction from AI should be a global", " priority"),
    ("After John and Mary went to the store, John gave a bottle of milk to", " Mary"),
]

for prompt, answer in prompt_answer_pairs:
    # Using `run_with_saes` method in place of standard forward pass
    logits = model(prompt, return_type="logits")
    logits_sae = model.run_with_saes(prompt, saes=[sae], return_type="logits")
    # Kept from the original cell; the value itself is not used below.
    answer_token_id = model.to_single_token(answer)

    # Softmax over the vocabulary at batch item 0, last position; `.max(-1)` yields (top probability, token id).
    top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
    top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)

    print(f"""Standard model: top prediction = {model.to_string(token_id_prediction)!r}, prob = {top_prob.item():.2%}
SAE reconstruction: top prediction = {model.to_string(token_id_prediction_sae)!r}, prob = {top_prob_sae.item():.2%}
""")

    # Free the per-prompt tensors before the next forward pass.
    del logits, logits_sae, token_id_prediction, token_id_prediction_sae, top_prob, top_prob_sae
    t.cuda.empty_cache()

# Put the SAE in eval mode (per the original note, this prevents an error about an expected dead-neuron mask).
sae.eval()
with t.no_grad():
    # Take the first 2 tokenized sequences as the evaluation batch.
    batch_tokens = token_dataset[:2]["tokens"]

    # Only the SAE's hook point is read below, so cache just that activation instead of every activation
    # in the model.
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True, names_filter=sae.cfg.hook_name)

    # Encode the hooked activations into SAE latents, then decode them back into a reconstruction.
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # save some room
    del cache

    # Loss of the model when the SAE reconstruction is substituted at the SAE's hook point.
    print(
        "Reconstuction loss:",
        model.run_with_hooks(
            batch_tokens,
            fwd_hooks=[
                (
                    sae.cfg.hook_name,
                    partial(reconstr_hook, sae_out=sae_out),
                )
            ],
            return_type="loss",
        ).item(),
    )

del sae, sae_out, feature_acts, batch_tokens


In [None]:
from functools import partial

# Evaluate the SAE identified by `release` / `sae_id` on the current `model`.
# Relies on `SAE`, `model`, `token_dataset` and `reconstr_hook` from earlier Gemma cells.
release = "suchitg/sae-compression-gemma-2-2b"
sae_id = "blocks.12.hook_resid_post"
# Only element [0] of what `from_pretrained` returns is kept (presumably the SAE object itself).
# The SAE is loaded once and reused for every check in this cell.
sae = SAE.from_pretrained(release, sae_id, device="cuda")[0]

# (prompt, expected next token) pairs; the second pair is the IOI task.
prompt_answer_pairs = [
    ("Mitigating the risk of extinction from AI should be a global", " priority"),
    ("After John and Mary went to the store, John gave a bottle of milk to", " Mary"),
]

for prompt, answer in prompt_answer_pairs:
    # Using `run_with_saes` method in place of standard forward pass
    logits = model(prompt, return_type="logits")
    logits_sae = model.run_with_saes(prompt, saes=[sae], return_type="logits")
    # Kept from the original cell; the value itself is not used below.
    answer_token_id = model.to_single_token(answer)

    # Softmax over the vocabulary at batch item 0, last position; `.max(-1)` yields (top probability, token id).
    top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
    top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)

    print(f"""Standard model: top prediction = {model.to_string(token_id_prediction)!r}, prob = {top_prob.item():.2%}
SAE reconstruction: top prediction = {model.to_string(token_id_prediction_sae)!r}, prob = {top_prob_sae.item():.2%}
""")

    # Free the per-prompt tensors before the next forward pass.
    del logits, logits_sae, token_id_prediction, token_id_prediction_sae, top_prob, top_prob_sae
    t.cuda.empty_cache()

# Put the SAE in eval mode (per the original note, this prevents an error about an expected dead-neuron mask).
sae.eval()
with t.no_grad():
    # Take the first 2 tokenized sequences as the evaluation batch.
    batch_tokens = token_dataset[:2]["tokens"]

    # Only the SAE's hook point is read below, so cache just that activation instead of every activation
    # in the model.
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True, names_filter=sae.cfg.hook_name)

    # Encode the hooked activations into SAE latents, then decode them back into a reconstruction.
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # save some room
    del cache

    # Loss of the model when the SAE reconstruction is substituted at the SAE's hook point.
    print(
        "Reconstuction loss:",
        model.run_with_hooks(
            batch_tokens,
            fwd_hooks=[
                (
                    sae.cfg.hook_name,
                    partial(reconstr_hook, sae_out=sae_out),
                )
            ],
            return_type="loss",
        ).item(),
    )

del sae, sae_out, feature_acts, batch_tokens

In [None]:
from functools import partial

from sae_lens import SAE

# Evaluate a pruned SAE on the current `model`: the SAE identified by `release` / `sae_id` is loaded and its
# weights are then overwritten from a local checkpoint.
# Relies on `model`, `token_dataset` and `reconstr_hook` from the Gemma setup cell.
release = "gemma-scope-2b-pt-res-canonical"
sae_id = "layer_12/width_16k/canonical"
# NOTE(review): absolute, machine-specific checkpoint path.
pruned_sae_path = '/local/scratch/suchit/COLM/pruned_saes/gemma-2-2b/wanda/pile/hook_resid_post/blocks.12.hook_resid_post.pth'

# Only element [0] of what `from_pretrained` returns is kept (presumably the SAE object itself).
# The SAE is loaded and overwritten once, then reused for every check in this cell.
sae = SAE.from_pretrained(release, sae_id, device="cuda")[0]
sae.load_state_dict(t.load(pruned_sae_path))

# (prompt, expected next token) pairs; the second pair is the IOI task.
prompt_answer_pairs = [
    ("Mitigating the risk of extinction from AI should be a global", " priority"),
    ("After John and Mary went to the store, John gave a bottle of milk to", " Mary"),
]

for prompt, answer in prompt_answer_pairs:
    # Using `run_with_saes` method in place of standard forward pass
    logits = model(prompt, return_type="logits")
    logits_sae = model.run_with_saes(prompt, saes=[sae], return_type="logits")
    # Kept from the original cell; the value itself is not used below.
    answer_token_id = model.to_single_token(answer)

    # Softmax over the vocabulary at batch item 0, last position; `.max(-1)` yields (top probability, token id).
    top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
    top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)

    print(f"""Standard model: top prediction = {model.to_string(token_id_prediction)!r}, prob = {top_prob.item():.2%}
SAE reconstruction: top prediction = {model.to_string(token_id_prediction_sae)!r}, prob = {top_prob_sae.item():.2%}
""")

    # Free the per-prompt tensors before the next forward pass.
    del logits, logits_sae, token_id_prediction, token_id_prediction_sae, top_prob, top_prob_sae
    t.cuda.empty_cache()

# Put the SAE in eval mode (per the original note, this prevents an error about an expected dead-neuron mask).
sae.eval()
with t.no_grad():
    # Take the first 2 tokenized sequences as the evaluation batch.
    batch_tokens = token_dataset[:2]["tokens"]

    # Only the SAE's hook point is read below, so cache just that activation instead of every activation
    # in the model.
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True, names_filter=sae.cfg.hook_name)

    # Encode the hooked activations into SAE latents, then decode them back into a reconstruction.
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # save some room
    del cache

    # Loss of the model when the SAE reconstruction is substituted at the SAE's hook point.
    print(
        "Reconstuction loss:",
        model.run_with_hooks(
            batch_tokens,
            fwd_hooks=[
                (
                    sae.cfg.hook_name,
                    partial(reconstr_hook, sae_out=sae_out),
                )
            ],
            return_type="loss",
        ).item(),
    )

del sae, sae_out, feature_acts, batch_tokens


In [None]:
from functools import partial

from sae_lens import SAE

# Evaluate a pruned SAE on the current `model`: the SAE identified by `release` / `sae_id` is loaded and its
# weights are then overwritten from a local checkpoint (the "ratio=0.5" variant, going by the directory name).
# Relies on `model`, `token_dataset` and `reconstr_hook` from the Gemma setup cell.
release = "gemma-scope-2b-pt-res-canonical"
sae_id = "layer_12/width_16k/canonical"
# NOTE(review): absolute, machine-specific checkpoint path.
pruned_sae_path = '/local/scratch/suchit/COLM/pruned_saes/gemma-2-2b/wanda/pile/hook_resid_post_ratio=0.5/blocks.12.hook_resid_post.pth'

# Only element [0] of what `from_pretrained` returns is kept (presumably the SAE object itself).
# The SAE is loaded and overwritten once, then reused for every check in this cell.
sae = SAE.from_pretrained(release, sae_id, device="cuda")[0]
sae.load_state_dict(t.load(pruned_sae_path))

# (prompt, expected next token) pairs; the second pair is the IOI task.
prompt_answer_pairs = [
    ("Mitigating the risk of extinction from AI should be a global", " priority"),
    ("After John and Mary went to the store, John gave a bottle of milk to", " Mary"),
]

for prompt, answer in prompt_answer_pairs:
    # Using `run_with_saes` method in place of standard forward pass
    logits = model(prompt, return_type="logits")
    logits_sae = model.run_with_saes(prompt, saes=[sae], return_type="logits")
    # Kept from the original cell; the value itself is not used below.
    answer_token_id = model.to_single_token(answer)

    # Softmax over the vocabulary at batch item 0, last position; `.max(-1)` yields (top probability, token id).
    top_prob, token_id_prediction = logits[0, -1].softmax(-1).max(-1)
    top_prob_sae, token_id_prediction_sae = logits_sae[0, -1].softmax(-1).max(-1)

    print(f"""Standard model: top prediction = {model.to_string(token_id_prediction)!r}, prob = {top_prob.item():.2%}
SAE reconstruction: top prediction = {model.to_string(token_id_prediction_sae)!r}, prob = {top_prob_sae.item():.2%}
""")

    # Free the per-prompt tensors before the next forward pass.
    del logits, logits_sae, token_id_prediction, token_id_prediction_sae, top_prob, top_prob_sae
    t.cuda.empty_cache()

# Put the SAE in eval mode (per the original note, this prevents an error about an expected dead-neuron mask).
sae.eval()
with t.no_grad():
    # Take the first 2 tokenized sequences as the evaluation batch.
    batch_tokens = token_dataset[:2]["tokens"]

    # Only the SAE's hook point is read below, so cache just that activation instead of every activation
    # in the model.
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True, names_filter=sae.cfg.hook_name)

    # Encode the hooked activations into SAE latents, then decode them back into a reconstruction.
    feature_acts = sae.encode(cache[sae.cfg.hook_name])
    sae_out = sae.decode(feature_acts)

    # save some room
    del cache

    # Loss of the model when the SAE reconstruction is substituted at the SAE's hook point.
    print(
        "Reconstuction loss:",
        model.run_with_hooks(
            batch_tokens,
            fwd_hooks=[
                (
                    sae.cfg.hook_name,
                    partial(reconstr_hook, sae_out=sae_out),
                )
            ],
            return_type="loss",
        ).item(),
    )

del sae, sae_out, feature_acts, batch_tokens


# Quantized

In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# NOTE: the first assignment is immediately overwritten, so only Llama-3.1-8B is actually loaded below.
model_id = "google/gemma-2-2b"
model_id = "meta-llama/Llama-3.1-8B"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Define quantization config
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                   # Set to True for 4-bit quantization
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",           # Options: "fp4" or "nf4"
    bnb_4bit_compute_dtype="float16"     # Can also be bfloat16 or float32
)


# Load quantized model
# NOTE: this rebinds the notebook-level name `model`, which earlier cells use for a HookedSAETransformer.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda:5",             # Places the model on the explicitly named device "cuda:5"
    quantization_config=quant_config
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
from transformer_lens import HookedTransformer

# Load `model_id` directly into TransformerLens as a float16 model on cuda:5 (no `hf_model` is passed, so the
# quantized `model` from the previous cell is not used here).
# hookedmodel = HookedTransformer.from_pretrained("gemma-2-2b", device="cuda:5", dtype="float16")
hookedmodel = HookedTransformer.from_pretrained(model_id, device="cuda:5", dtype="float16")



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096])
torch.Size([32, 4096, 128]) torch.Size([8, 4096, 128]) torch.Size([8, 4096, 128])


NameError: name 'exit' is not defined

In [6]:
from transformer_lens import HookedTransformer

# Wrap the HF `model` object in a HookedTransformer via the no-processing loader.
# NOTE(review): the recorded output of this cell is a RuntimeError while loading the
# state dict: size mismatch for blocks.*.attn._W_K / _W_V (checkpoint shape
# [2097152, 1] vs. model shape [8, 4096, 128]). The cell does not currently succeed.
# Gemma alternative:
# quantized_model = HookedTransformer.from_pretrained_no_processing("gemma-2-2b", hf_model=model, device="cuda:5", dtype="float16")
quantized_model = HookedTransformer.from_pretrained_no_processing(
    model_id,
    hf_model=model,
    device="cuda:5",
    dtype="float16",
)

meta-llama/llama-3.1-8b True False
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([2097152, 1])
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([2097152, 1])
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([2097152, 1])
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([2097152, 1])
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([2097152, 1])
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([2097152, 1])
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([2097152, 1])
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([2097152, 1])
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([2097152, 1])
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([2097152, 1])
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([2097152, 1])
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([2097152, 1])
torch.Size([8388608, 1]) torch.Size([2097152, 1]) torch.Size([209

RuntimeError: Error(s) in loading state_dict for HookedTransformer:
	size mismatch for blocks.0.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.0.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.1.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.1.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.2.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.2.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.3.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.3.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.4.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.4.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.5.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.5.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.6.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.6.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.7.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.7.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.8.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.8.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.9.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.9.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.10.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.10.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.11.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.11.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.12.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.12.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.13.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.13.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.14.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.14.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.15.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.15.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.16.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.16.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.17.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.17.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.18.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.18.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.19.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.19.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.20.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.20.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.21.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.21.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.22.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.22.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.23.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.23.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.24.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.24.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.25.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.25.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.26.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.26.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.27.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.27.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.28.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.28.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.29.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.29.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.30.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.30.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.31.attn._W_K: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).
	size mismatch for blocks.31.attn._W_V: copying a param with shape torch.Size([2097152, 1]) from checkpoint, the shape in current model is torch.Size([8, 4096, 128]).

In [None]:
# Display the current value of the 4-bit flag on the TransformerLens config.
hookedmodel.cfg.load_in_4bit

In [None]:
# Manually set the 4-bit flag on the existing config object (in-place mutation of hookedmodel.cfg).
# NOTE(review): presumably done so that code reading this cfg takes its 4-bit path;
# it does not by itself re-load or quantize hookedmodel's weights — confirm intent.
hookedmodel.cfg.load_in_4bit = True

In [None]:
# Display the flag again to confirm the value now stored on the config.
hookedmodel.cfg.load_in_4bit

In [None]:
from transformer_lens.pretrained.weight_conversions import convert_gemma_weights_q
from transformer_lens.loading_from_pretrained import get_pretrained_state_dict

# Run the quantized Gemma weight converter on the HF `model`, using hookedmodel's config.
# NOTE(review): in the cell that defines it, `model_id` is last set to a Llama checkpoint,
# so `model` may not be a Gemma model when this runs — confirm the intended checkpoint.
# NOTE(review): `get_pretrained_state_dict` is only used by the commented-out comparison below.
state_dict1 = convert_gemma_weights_q(model, hookedmodel.cfg)
# Disabled comparison: build a second state dict with the standard loader.
# state_dict2 = get_pretrained_state_dict("gemma-2-2b", hookedmodel.cfg, model, dtype="float16")

# Disabled check that both conversions produce the same set of keys.
# assert state_dict1.keys() == state_dict2.keys()