<a href="https://colab.research.google.com/github/tomonari-masada/course2025-nlp/blob/main/11_transcoders.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Interpreting LLM Behavior with Transcoders

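* What is a transcoder?
  * A neural network that converts an intermediate-layer representation of the original model into a representation that is easier to interpret.
  * In this notebook it takes the input of an MLP sublayer and reconstructs that sublayer's output through a wide, sparsely activated hidden layer.
* Related papers
  * https://arxiv.org/abs/2406.11944
  * https://arxiv.org/abs/2408.05147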

* Transcoders used in this notebook
  * https://huggingface.co/google/gemma-scope-2b-pt-transcoders

* This notebook is based on the following Gemma Scope Tutorial.
  * https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp
  * The tutorial code has been adapted for transcoders.

## Setup

In [None]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset, DatasetDict
from transformers import set_seed, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download, login

access_token = "" # put your own Hugging Face access token here
# Alternatively, run the next cell without access_token and enter the token each time you are prompted.

device = "cuda" if torch.cuda.is_available() else "cpu"

set_seed(0)

In [None]:
login(token=access_token or None)  # an empty string falls back to the interactive prompt

## LLM

In [None]:
torch.set_grad_enabled(False) # avoid blowing up mem

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

## Dataset
* We use the AG News dataset.


In [None]:
ag_news_label = { 0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech" }

ds = load_dataset("ag_news")
train_val = ds["train"].train_test_split(test_size=0.05, seed=0)  # fixed seed so the split is reproducible
ds = DatasetDict({
    "train": train_val["train"],
    "val": train_val["test"],
    "test": ds["test"],
})

ds

In [None]:
prompt = ds["train"][5]["text"][:53]  # first 53 characters of one training example
prompt

In [None]:
# Use the tokenizer to convert it to tokens. Note that this implicitly adds a special "Beginning of Sequence" or <bos> token to the start
inputs = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to(device)
print(inputs)

# Pass it in to the model and generate text
outputs = model.generate(input_ids=inputs, max_new_tokens=50, do_sample=True, temperature=0.7)
print(tokenizer.decode(outputs[0]))

In [None]:
ds["train"][5]["text"]

## Transcoder

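### Choosing a transcoder
* It is not obvious in advance which transcoder works best.
  * So we try several.

* The number after `l0_` indicates the degree of sparsity.
  * Roughly speaking, it is the expected number of neurons that fire per token.
* A transcoder with a smaller value sacrifices reconstruction accuracy so that fewer neurons fire.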
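
* To compare candidates, it can help to parse the layer, width, and `l0_` value out of each file path. The following is a minimal sketch; the helper `parse_transcoder_path` and the second path in `candidates` are hypothetical and only illustrate the naming scheme.

```python
import re

# Pattern for paths such as "layer_20/width_16k/average_l0_11/params.npz".
PATH_PATTERN = re.compile(r"layer_(\d+)/width_(\w+)/average_l0_(\d+)/params\.npz")

def parse_transcoder_path(path):
    """Return (layer, width, average_l0) parsed from a params.npz path."""
    match = PATH_PATTERN.fullmatch(path)
    if match is None:
        raise ValueError(f"unexpected path format: {path}")
    layer, width, l0 = match.groups()
    return int(layer), width, int(l0)

# Hypothetical candidate list; only the first entry is used in this notebook.
candidates = [
    "layer_20/width_16k/average_l0_11/params.npz",
    "layer_20/width_16k/average_l0_99/params.npz",  # illustrative value, not verified
]

# Pick the candidate with the smallest average L0 (the sparsest one).
sparsest = min(candidates, key=lambda p: parse_transcoder_path(p)[2])
print(parse_transcoder_path(sparsest))
```

* The real list of files can be obtained with `huggingface_hub.list_repo_files` on the repository above; that call needs network access and is left out of the sketch.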

In [None]:
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-transcoders",
    filename="layer_20/width_16k/average_l0_11/params.npz",
    force_download=False,
)
params = np.load(path_to_params)
pt_params = {k: torch.from_numpy(v).to(device) for k, v in params.items()}

In [None]:
pt_params

In [None]:
num_transcoder_neurons = pt_params['threshold'].shape[0]
num_transcoder_neurons

In [None]:
pt_params["W_enc"].shape

### Transcoder forward pass

* The code below is taken from the Gemma Scope tutorial.
* The SAELens source below implements the same logic (the `JumpReLUTranscoder` class).
  * https://github.com/decoderesearch/SAELens/blob/main/sae_lens/saes/transcoder.py

In [None]:
class JumpReLUSAE(nn.Module):
  """JumpReLU encoder/decoder; used here as a transcoder (MLP input -> MLP output)."""
  def __init__(self, d_model, d_sae):
    # Parameters are initialised to zeros because pre-trained weights are loaded afterwards.
    super().__init__()
    self.W_enc = nn.Parameter(torch.zeros(d_model, d_sae))
    self.W_dec = nn.Parameter(torch.zeros(d_sae, d_model))
    self.threshold = nn.Parameter(torch.zeros(d_sae))
    self.b_enc = nn.Parameter(torch.zeros(d_sae))
    self.b_dec = nn.Parameter(torch.zeros(d_model))

  def encode(self, input_acts):
    pre_acts = input_acts @ self.W_enc + self.b_enc
    # JumpReLU: keep a pre-activation only where it exceeds its per-neuron threshold
    mask = (pre_acts > self.threshold)
    acts = mask * torch.nn.functional.relu(pre_acts)
    return acts

  def decode(self, acts):
    # Map the sparse activations back to model space
    return acts @ self.W_dec + self.b_dec

  def forward(self, acts):
    acts = self.encode(acts)
    recon = self.decode(acts)
    return recon

In [None]:
transcoder = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
transcoder.load_state_dict(pt_params)
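
* The effect of the thresholded activation in `encode` can be checked on a toy example. This is a sketch only: `jump_relu` is a hypothetical standalone helper, and the values are made up rather than taken from the loaded transcoder.

```python
import torch

def jump_relu(pre_acts, threshold):
    """Keep a pre-activation only where it exceeds its per-neuron threshold."""
    mask = pre_acts > threshold
    return mask * torch.relu(pre_acts)

# Toy values (not taken from the real transcoder).
pre_acts = torch.tensor([[0.5, 1.5, -1.0, 2.0]])
threshold = torch.tensor([1.0, 1.0, 1.0, 3.0])

acts = jump_relu(pre_acts, threshold)
print(acts)
```

* Only the second value survives: 0.5 and 2.0 are positive but lie below their thresholds, so a plain ReLU would have kept them while the JumpReLU zeroes them.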

## Setting up forward hooks

In [None]:
target_layer = 20

model.model.layers[target_layer]

In [None]:
def gather_residual_activations(model, target_layer, inputs):
    """Run one forward pass and capture the input/output of the two norms around the MLP.

    Returns a list of two dicts: index 0 is pre_feedforward_layernorm,
    index 1 is post_feedforward_layernorm.
    """
    target_act = []
    def gather_target_act_hook(mod, hook_inputs, hook_outputs):
        # squeeze(0) drops only the batch dimension (batch size 1 is assumed)
        target_act.append({
            "input": hook_inputs[0].squeeze(0).cpu(),
            "output": hook_outputs.squeeze(0).cpu(),
        })
    layer = model.model.layers[target_layer]
    handle_pre = layer.pre_feedforward_layernorm.register_forward_hook(gather_target_act_hook)
    handle_post = layer.post_feedforward_layernorm.register_forward_hook(gather_target_act_hook)
    try:
        model(inputs)
    finally:
        # remove the hooks even if the forward pass raises
        handle_pre.remove()
        handle_post.remove()
    return target_act
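
* The hook mechanism can be tried on a small module without loading Gemma. The sketch below uses a toy `nn.Sequential` as a stand-in; it is not the Gemma architecture, and the names `toy`, `captured`, and `capture_hook` are made up for illustration.

```python
import torch
import torch.nn as nn

# Toy stand-in for one block: norm -> linear -> norm (not the Gemma architecture).
toy = nn.Sequential(nn.LayerNorm(4), nn.Linear(4, 4), nn.LayerNorm(4))

captured = []

def capture_hook(module, hook_inputs, hook_output):
    # hook_inputs is a tuple of positional arguments; hook_output is the module's result.
    captured.append({"input": hook_inputs[0].detach(), "output": hook_output.detach()})

# Hook the first and the last norm, mirroring the pre-/post-MLP norms above.
handles = [
    toy[0].register_forward_hook(capture_hook),
    toy[2].register_forward_hook(capture_hook),
]
try:
    with torch.no_grad():
        toy(torch.randn(1, 3, 4))
finally:
    for handle in handles:
        handle.remove()  # always detach hooks, even if the forward pass fails

print(len(captured), captured[0]["output"].shape)
```

* The hooks fire in execution order, so the first captured entry belongs to the first norm and the second to the last norm. After `remove()`, later forward passes no longer append anything.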

In [None]:
target_act = gather_residual_activations(model, target_layer=target_layer, inputs=inputs)

In [None]:
target_act[0]

In [None]:
target_act[1]

* Output of the layer normalization immediately before the MLP

In [None]:
target_act[0]["output"]

In [None]:
target_act[0]["output"].shape

* Output of the layer normalization immediately after the MLP

In [None]:
target_act[1]["output"]

In [None]:
transcoder_activations = transcoder.encode(target_act[0]["output"].to(torch.float32))
reconstructed = transcoder.decode(transcoder_activations)

In [None]:
reconstructed

In [None]:
((reconstructed - target_act[1]["output"].to(torch.float32)) ** 2).mean()

In [None]:
target_act[1]["output"].to(torch.float32).var()

In [None]:
1 - torch.mean((reconstructed - target_act[1]["output"].to(torch.float32)) **2) / target_act[1]["output"].to(torch.float32).var()
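
* The three cells above compute the mean squared error, the variance of the target, and `1 - MSE / Var`, the fraction of variance explained. The sketch below wraps that formula in a hypothetical helper and runs it on toy tensors. The `skip_first_token` option is an assumption worth testing, since the activations at the `<bos>` position may be atypical; this notebook does not measure that.

```python
import torch

def fraction_variance_explained(target, recon, skip_first_token=False):
    """1 - MSE / Var(target), optionally ignoring the first (e.g. <bos>) position."""
    if skip_first_token:
        target, recon = target[1:], recon[1:]
    mse = torch.mean((recon - target) ** 2)
    return (1 - mse / target.var()).item()

# Toy tensors shaped (num_tokens, d_model); values are illustrative only.
torch.manual_seed(0)
target = torch.randn(6, 8)
recon = target + 0.1 * torch.randn(6, 8)

print(fraction_variance_explained(target, recon))
print(fraction_variance_explained(target, recon, skip_first_token=True))
```

* A value close to 1 means the reconstruction captures most of the variance; a perfect reconstruction gives exactly 1.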

### Inspecting the transcoder activations

In [None]:
transcoder_activations.shape

In [None]:
(transcoder_activations > 1).sum(-1)

In [None]:
(transcoder_activations > 0).sum(-1)
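
* Counting non-zero activations per position gives the L0 of each token, and its mean can be compared with the `average_l0_*` part of the file name. A minimal sketch on toy values (not real transcoder activations):

```python
import torch

# Toy activations shaped (num_tokens, num_neurons); values are illustrative only.
toy_acts = torch.tensor([
    [0.0, 2.0, 0.0, 0.5],
    [0.0, 0.0, 0.0, 0.0],
    [3.0, 1.2, 0.0, 0.1],
])

l0_per_token = (toy_acts > 0).sum(-1)      # number of active neurons at each position
average_l0 = l0_per_token.float().mean()   # mean number of active neurons per token

print(l0_per_token.tolist(), average_l0.item())
```

* For a single short text the mean can differ noticeably from the nominal value, which is presumably an average over a much larger corpus.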

## Neuronpedia

### How to look up what a neuron means
* Specify the LLM, the transcoder, and the neuron ID in the URL.
  * https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-transcoder-16k/0
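
* The URL pattern of the link above can be wrapped in a small helper. `neuronpedia_url` is a hypothetical function, and its default arguments are simply copied from that link. Whether the transcoder hosted there has the same `average_l0` as the one loaded in this notebook is not verified here.

```python
def neuronpedia_url(neuron_id, model="gemma-2-2b", source="20-gemmascope-transcoder-16k"):
    """Build a Neuronpedia URL from the model name, transcoder name, and neuron ID."""
    return f"https://www.neuronpedia.org/{model}/{source}/{neuron_id}"

print(neuronpedia_url(0))
```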

### How to check which neurons fire

In [None]:
transcoder_activations.max(dim=-1)

In [None]:
torch.where(transcoder_activations > 1)[1]

* Count every firing separately, even when the same neuron fires several times within one text.

In [None]:
neuron_counts = torch.zeros(num_transcoder_neurons, device=device)
for index in torch.where(transcoder_activations > 1)[1]:
    neuron_counts[index] += 1

* Count a neuron at most once per text, even if it fires several times.
  * We use this method because it is faster.

In [None]:
neuron_counts = torch.zeros(num_transcoder_neurons, device=device)
neuron_counts[torch.where(transcoder_activations > 1)[1]] += 1
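
* The two counting schemes differ only in how repeated indices are treated. The sketch below shows this on toy indices. `torch.bincount` is offered as one possible vectorised alternative to the explicit loop; it is not what this notebook uses.

```python
import torch

# Toy neuron indices for one text; neuron 2 fires at three positions.
fired = torch.tensor([2, 2, 5, 2])
num_neurons = 8

# Per-firing counts: every occurrence is counted.
per_firing = torch.bincount(fired, minlength=num_neurons)

# Per-text counts: duplicates collapse because indexed `+=` does not accumulate.
per_text = torch.zeros(num_neurons)
per_text[fired] += 1

# The same per-text result, with the deduplication made explicit.
per_text_explicit = torch.zeros(num_neurons)
per_text_explicit[torch.unique(fired)] += 1

print(per_firing.tolist())
print(per_text.tolist())
```

* Neuron 2 gets a count of 3 in the first scheme and 1 in the second.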

## Example analysis: science and technology news
* We examine which neurons fire most frequently.

In [None]:
texts = []
for sample in ds["val"]:
    if sample["label"] == 3:
        texts.append(sample["text"])

In [None]:
neuron_counts = torch.zeros(num_transcoder_neurons, device=device)
for text in tqdm(texts):
    inputs = tokenizer.encode(text, return_tensors="pt", add_special_tokens=True).to(device)
    target_act = gather_residual_activations(model, target_layer=target_layer, inputs=inputs)
    # encode the pre-MLP norm output; note that the <bos> position is included
    transcoder_activations = transcoder.encode(target_act[0]["output"].to(torch.float32))
    # each neuron is counted at most once per text (indexed `+=` does not accumulate duplicates)
    neuron_counts[torch.where(transcoder_activations > 1)[1]] += 1

In [None]:
neuron_counts.argsort(descending=True)[:20]

* For example, neuron 6969:
  * https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-transcoder-16k/6969

In [None]:
for index in neuron_counts.argsort(descending=True)[:20]:
    print(f"* https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-transcoder-16k/{index.item()}")
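
* Raw counts are easier to read as the fraction of texts in which a neuron fired. The following is a minimal sketch using `torch.topk` on toy counts; `toy_counts` and `num_texts` are made-up values, not results from this notebook.

```python
import torch

# Toy counts over 6 neurons from 4 texts; values are illustrative only.
toy_counts = torch.tensor([0.0, 4.0, 1.0, 3.0, 0.0, 2.0])
num_texts = 4

firing_rate = toy_counts / num_texts   # fraction of texts in which each neuron fired
top = torch.topk(firing_rate, k=3)     # values and indices, highest first

for rate, index in zip(top.values.tolist(), top.indices.tolist()):
    print(f"neuron {index}: fired in {rate:.0%} of texts")
```

* A neuron that fires in almost every text may reflect general structure rather than the topic, so it seems worth checking such neurons on Neuronpedia before interpreting them as topic-specific.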