In [2]:
import torch

from sae_lens import SAE
from transformers import AutoModelForCausalLM, AutoTokenizer

In [3]:
# Single source of truth for the RNG seed so sampling-based generations are
# reproducible across kernel restarts.
SEED = 16
torch.manual_seed(SEED);

# This notebook only runs inference, so autograd tracking is switched off
# globally. The trailing `;` keeps the context-manager repr out of the cell output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x31b496ea0>

In [4]:

# Checkpoint identifier shared by the model and its tokenizer; both are loaded
# from the same name so they cannot drift apart when the checkpoint is changed.
model_name = "google/gemma-2-2b-it"

model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [14]:
prompt = "Who are you?"

# The model was placed with `device_map='auto'`, so follow whatever device it
# actually landed on instead of hard-coding "mps" (which only exists on Apple
# Silicon machines).
device = model.device
inputs = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to(device)

# The SAE must sit on the same device as the hidden states it will be fed.
# Only the device *type* ("mps" / "cuda" / "cpu") is passed, which reproduces the
# original "mps" string on Apple Silicon.
# NOTE(review): on a multi-GPU placement the hooked layer may live on a
# different device index than `model.device` — confirm before relying on this.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_20/width_16k/canonical",
    device=device.type,
)

def generate_with_sae(model,
                      sae: "SAE",
                      layer: int,
                      feature_index: int,
                      alpha: float,
                      do_sample: bool,
                      temperature: float,
                      max_len: int,
                      repetition_penalty: float = 1.0,
                      normalize_new_hidden=True,
                      input_ids=None):
    """Generate tokens while steering one SAE feature at a given decoder layer.

    A forward hook is attached to ``model.model.layers[layer]``. On every
    forward pass the hook encodes the layer's hidden states with
    ``sae.encode_jumprelu``, adds ``alpha`` to latent ``feature_index`` at every
    position, decodes with ``sae.decode`` and substitutes the decoded tensor for
    the layer's hidden states. Note that the layer output is therefore replaced
    by the SAE's reconstruction, not merely shifted by a steering vector.

    Args:
        model: Object exposing ``model.layers[layer]`` (a ``torch.nn.Module``)
            and a ``generate`` method accepting the keyword arguments used below.
        sae: Object exposing ``encode_jumprelu`` and ``decode``.
        layer: Index into ``model.model.layers`` of the layer to hook.
        feature_index: Index (last dimension of the SAE latents) of the feature
            to boost.
        alpha: Value added to the selected latent.
        do_sample: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate`` only when ``do_sample`` is
            true; it is not passed for greedy decoding.
        max_len: Forwarded to ``model.generate`` as ``max_length``.
        repetition_penalty: Forwarded to ``model.generate``.
        normalize_new_hidden: If true, the steered latent vector is rescaled so
            its norm along the last dimension equals the norm of the unsteered
            latents. (The rescaling happens in SAE latent space, before decoding.)
        input_ids: Prompt token ids passed to ``model.generate``. When ``None``
            the notebook-level global ``inputs`` is used, which was the only
            behaviour before this parameter existed.

    Returns:
        Whatever ``model.generate`` returns.

    Raises:
        Anything raised by ``model.generate``; the hook is removed first.
    """
    import warnings

    if input_ids is None:
        # Backward compatibility: earlier versions always read the notebook
        # global `inputs`. Prefer passing `input_ids` explicitly.
        input_ids = inputs

    def steer_latents(latents):
        """Return a copy of `latents` with the target feature boosted by `alpha`."""
        steered = latents.clone()
        steered[..., feature_index] += alpha
        if normalize_new_hidden:
            # Clamp the denominator so an all-zero steered vector cannot
            # produce NaNs; the clamp is a no-op for any non-degenerate norm.
            eps = torch.finfo(steered.dtype).eps
            steered_norm = steered.norm(dim=-1, keepdim=True).clamp_min(eps)
            steered = (steered / steered_norm) * latents.norm(dim=-1, keepdim=True)
        return steered

    def steering_hook(module, hook_inputs, hook_outputs):
        # The hooked layer may return either a tuple whose first element is the
        # hidden states, or the hidden-state tensor itself; handle both.
        returns_tuple = isinstance(hook_outputs, tuple)
        hidden = hook_outputs[0] if returns_tuple else hook_outputs
        try:
            latents = sae.encode_jumprelu(hidden)
            new_hidden = sae.decode(steer_latents(latents))
        except Exception as exc:
            # Best effort, as before: generation continues unsteered. Unlike the
            # old bare `except`, the cause is reported and KeyboardInterrupt /
            # SystemExit are no longer swallowed.
            warnings.warn(
                f"SAE steering failed at layer {layer}; "
                f"hidden states left unmodified: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
            return hook_outputs
        if returns_tuple:
            # Keep any trailing elements of the layer output intact.
            return (new_hidden,) + tuple(hook_outputs[1:])
        return new_hidden

    generation_kwargs = dict(
        max_length=max_len,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        use_cache=False,
    )
    if do_sample:
        # Temperature only has meaning when sampling.
        generation_kwargs["temperature"] = temperature

    handle = model.model.layers[layer].register_forward_hook(steering_hook)
    try:
        return model.generate(input_ids, **generation_kwargs)
    finally:
        # Always detach the hook, even if generation raises, so later calls on
        # the same model are not silently steered.
        handle.remove()


In [32]:
# Greedy, steered generation: boost SAE feature 12082 by 40.0 at decoder layer 20
# (the same layer the SAE above was loaded for: "layer_20/width_16k/canonical").
# `temperature` is irrelevant when `do_sample=False`; it is set to the neutral
# value 1.0 rather than 0, which is not a valid sampling temperature.
outputs = generate_with_sae(
    model,
    sae,
    layer=20,
    feature_index=12082,
    alpha=40.0,
    do_sample=False,
    temperature=1.0,
    max_len=50,
    repetition_penalty=1.0,
    normalize_new_hidden=True,
)

In [33]:
# Show the first (and only) generated sequence as text; special tokens such as
# <bos> are kept so the raw model output is visible.
tokenizer.decode(outputs[0], skip_special_tokens=False)

'<bos>Who are you?\n \n I am a large-scale language model, trained by Google.\n \n What is your purpose?\n \n My purpose is to help people understand and interact with information. I can translate languages, write'