<a href="https://colab.research.google.com/github/tomonari-masada/course2025-nlp/blob/main/11_transcoders.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transcodersによる解釈

* Transcodersとは・・・
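  * 元のモデルのMLP（feed-forward）層が行う「入力から出力への変換」を、疎で解釈しやすい中間表現（特徴）を経由して近似するニューラルネットワークである。
  * 入力そのものを再構成するSAE（sparse autoencoder）とは異なり、MLP層の入力からその出力を予測する点が特徴。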
* 関連する論文
  * https://arxiv.org/abs/2406.11944
  * https://arxiv.org/abs/2408.05147

* 今回使ってみるtranscoders
  * https://huggingface.co/google/gemma-scope-2b-pt-transcoders

In [None]:
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download, login

# Put your own Hugging Face access token here.
# Alternatively, run the next cell without access_token (i.e. plain `login()`)
# and enter the token interactively each time.
access_token = ""

# Prefer the GPU whenever one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# An empty token means "not provided". login() only falls back to the
# interactive prompt (notebook widget / terminal) when the token is None, so
# map "" to None instead of trying to authenticate with an empty string.
login(access_token or None)

In [None]:
# Inference only: switch autograd off globally so forward passes do not keep
# computation graphs around (avoids blowing up memory).
torch.set_grad_enabled(False)

# Gemma 2 2B base model; device_map="auto" lets transformers decide where the
# weights are placed.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b", device_map="auto")

In [None]:
# Tokenizer matching the "google/gemma-2-2b" checkpoint loaded above.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

In [None]:
# Prompt fed to the model.
prompt = "Would you be able to travel through time using a wormhole?"

# Convert the prompt to a tensor of token ids. add_special_tokens=True makes the
# tokenizer prepend the special "Beginning of Sequence" token <bos>.
inputs = tokenizer.encode(
    prompt,
    add_special_tokens=True,
    return_tensors="pt",
).to(device)
print(inputs)

# Generate at most two new tokens and print the decoded first sequence
# (prompt followed by the continuation).
outputs = model.generate(input_ids=inputs, max_new_tokens=2)
print(tokenizer.decode(outputs[0]))

* ファイルパス中の`average_l0_`の後の数値は、スパース性の度合いを表している。
  * ざっくり言えば、1トークンあたりに発火する（活性値が0でなくなる）特徴の個数の平均値。
* この値が小さいtranscoderは、再構成の正確さを犠牲にして、より少数の特徴しか発火しないようにしている。

In [None]:
# Fetch the Gemma Scope transcoder parameters for layer 20. Per the path name,
# this is the 16k-wide dictionary with an average L0 of about 11.
path_to_params = hf_hub_download(
    filename="layer_20/width_16k/average_l0_11/params.npz",
    repo_id="google/gemma-scope-2b-pt-transcoders",
    force_download=False,  # reuse the local cache when the file is present
)

In [None]:
# np.load on an .npz returns a lazy NpzFile that keeps the file open and
# re-reads an array from the archive on every key access. Read every array
# once and close the file; a plain dict keeps `params['W_enc']` working.
with np.load(path_to_params) as npz_file:
    params = dict(npz_file)

# The same arrays as torch tensors on the compute device, under the same keys
# (later passed to load_state_dict).
pt_params = {k: torch.from_numpy(v).to(device) for k, v in params.items()}

In [None]:
# Shape of the per-latent JumpReLU thresholds; its length is the number of
# transcoder latents (d_sae).
pt_params['threshold'].shape

In [None]:
# L2 norm over dim 0 of the encoder matrix. Assuming W_enc is laid out as
# (d_model, d_sae), as JumpReLUSAE expects, this gives one norm per latent.
pt_params["W_enc"].norm(dim=0)

In [None]:
import torch.nn as nn
class JumpReLUSAE(nn.Module):
    """JumpReLU sparse dictionary, used in this notebook as a transcoder.

    ``encode`` projects inputs (last dim ``d_model``) onto ``d_sae`` latents.
    A latent keeps its ReLU'd pre-activation only where that pre-activation
    exceeds the latent's ``threshold``; otherwise it is 0. ``decode`` maps
    latents back to ``d_model`` with an affine map. ``forward`` is
    ``decode(encode(x))``.

    Every parameter starts as zeros: the module is meant to be filled with
    pretrained weights via ``load_state_dict``.
    """

    def __init__(self, d_model, d_sae):
        super().__init__()
        # Zero placeholders; the real values are loaded afterwards.
        self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
        self.threshold = nn.Parameter(torch.zeros(d_sae))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, input_acts):
        """Return the sparse latent activations for ``input_acts``."""
        pre_acts = input_acts @ self.W_enc + self.b_enc
        # JumpReLU: gate the ReLU output with a per-latent threshold test.
        return torch.nn.functional.relu(pre_acts) * (pre_acts > self.threshold)

    def decode(self, acts):
        """Map latent activations back to the model's hidden size."""
        return acts @ self.W_dec + self.b_dec

    def forward(self, acts):
        """Encode ``acts`` into latents and decode them again."""
        return self.decode(self.encode(acts))

In [None]:
# Build the transcoder with (d_model, d_sae) unpacked from the encoder matrix's
# shape, then copy the downloaded weights in. The parameters are created on the
# CPU and load_state_dict copies values into them, so the module stays on CPU.
transcoder = JumpReLUSAE(*params['W_enc'].shape)
transcoder.load_state_dict(pt_params)

In [None]:
# Display decoder layer 20 to see its submodule names (including the
# feed-forward layer norms that gather_residual_activations hooks).
model.model.layers[20]

In [None]:
def gather_residual_activations(model, target_layer, inputs):
    """Run ``model`` on ``inputs`` and capture the feed-forward norm activations.

    Forward hooks are attached to ``pre_feedforward_layernorm`` and
    ``post_feedforward_layernorm`` of ``model.model.layers[target_layer]``.
    They are always removed again, even if the forward pass raises.

    Args:
        model: model whose decoder layers are reachable as ``model.model.layers``.
        target_layer: index of the decoder layer to inspect.
        inputs: passed unchanged to ``model.forward``.

    Returns:
        A list with one ``(input, output)`` tuple per call of a hooked module,
        in the order the modules ran during the forward pass. ``input`` is the
        module's first positional input. Both tensors are moved to the CPU.
    """
    target_act = []

    def gather_target_act_hook(mod, hook_inputs, outputs):
        # Only appends to the outer list, so no `nonlocal` is required.
        # Returning None leaves the module's output unchanged.
        target_act.append((hook_inputs[0].cpu(), outputs.cpu()))

    layer = model.model.layers[target_layer]
    handles = []
    try:
        handles.append(
            layer.pre_feedforward_layernorm.register_forward_hook(gather_target_act_hook)
        )
        handles.append(
            layer.post_feedforward_layernorm.register_forward_hook(gather_target_act_hook)
        )
        model.forward(inputs)
    finally:
        # Without this, an exception would leave the hooks attached; they would
        # keep firing (and appending) on every later forward pass.
        for handle in handles:
            handle.remove()
    return target_act

In [None]:
# Capture (input, output) pairs of the feed-forward layer norms of decoder
# layer 20 while the model processes the prompt tokens.
target_act = gather_residual_activations(model, target_layer=20, inputs=inputs)

In [None]:
# Shapes of (input, output) of the first hooked call. Presumably this is
# pre_feedforward_layernorm, which should run before the MLP and so be recorded
# first; confirm against the layer printout.
target_act[0][0].shape, target_act[0][1].shape

In [None]:
# Shapes of (input, output) of the second hooked call; presumably
# post_feedforward_layernorm (confirm against the layer printout).
target_act[1][0].shape, target_act[1][1].shape

In [None]:
# Output of the first hooked module: the tensor the transcoder encodes below
# (presumably the MLP input after pre_feedforward_layernorm).
target_act[0][1]

In [None]:
# Output of the second hooked module: the tensor the reconstruction is scored
# against below (presumably the MLP output after post_feedforward_layernorm).
target_act[1][1]

In [None]:
# Encode the first hooked output into sparse transcoder latents (in float32,
# like the transcoder weights), then decode them into the reconstruction.
transcoder_acts = transcoder.encode(target_act[0][1].float())
recon = transcoder.decode(transcoder_acts)

In [None]:
# Transcoder output: its prediction of target_act[1][1].
recon

In [None]:
# Fraction of variance explained: 1 - MSE(recon, target) / Var(target). The
# target is the second hooked output, target_act[1][1] (presumably the
# post_feedforward_layernorm output).
# [0, 1:] selects the single batch element and every sequence position except
# position 0, which holds the <bos> token. The slice belongs on the sequence
# axis (axis 1); slicing the last axis would drop hidden dimension 0 instead
# and keep <bos> in the statistic.
1 - torch.mean(
    (recon[0, 1:] - target_act[1][1][0, 1:].to(torch.float32)) ** 2
) / target_act[1][1][0, 1:].to(torch.float32).var()

In [None]:
# Per token, count the latents whose activation is greater than 1 (last axis =
# latents).
# NOTE(review): the non-zero entries of transcoder_acts are exactly the latents
# that passed the JumpReLU threshold, so `> 0` would give the per-token L0 to
# compare with the "average_l0_11" in the parameter path. `> 1` counts only
# strongly active latents; confirm which is intended.
(transcoder_acts > 1).sum(-1)