In [1]:
# import plotly.express as px
from typing import Any, Dict, Optional, Protocol, Tuple
import os, sys
import torch
from torch.utils.data import DataLoader
from sae_lens import SAE
from pathlib import Path
import numpy as np
from sae_lens.toolkit.pretrained_sae_loaders import (
    gemma_2_sae_loader,
    get_gemma_2_config,
)
from sae_lens import SAE, SAEConfig, LanguageModelSAERunnerConfig, SAETrainingRunner

def load_gemma_2_sae(
    sae_path: str,
    device: str = "cpu",
    repo_id: str = "gemma-scope-9b-it-res",
    force_download: bool = False,
    cfg_overrides: Optional[Dict[str, Any]] = None,
    d_sae_override: Optional[int] = None,
    layer_override: Optional[int] = None,
) -> Tuple["SAE", Optional[torch.Tensor]]:
    """Load a Gemma 2 (Gemma Scope) SAE from a local directory.

    The config comes from ``get_gemma_2_config`` and the weights from
    ``<sae_path>/params.npz``. Array keys starting with ``w_`` are renamed to
    ``W_``. Every array is cast to float32 and moved to ``device``.

    Args:
        sae_path: Directory containing ``params.npz``. Also forwarded to
            ``get_gemma_2_config``.
        device: Written into the config and used for every loaded tensor.
        repo_id: Forwarded to ``get_gemma_2_config``. NOTE(review): the default
            names the ``-it-`` release, while this notebook loads ``-pt-`` paths.
            Confirm the config helper does not depend on it.
        force_download: Not used by this function. It is kept so the signature
            stays call-compatible.
        cfg_overrides: Extra config entries. They are applied after ``device``,
            so they may also override it.
        d_sae_override: Forwarded to ``get_gemma_2_config``.
        layer_override: Forwarded to ``get_gemma_2_config``.

    Returns:
        ``(sae, None)``: the loaded SAE, plus a placeholder for a log-sparsity
        tensor, which is never produced for these SAEs.

    Raises:
        ValueError: ``params.npz`` contains a non-identity ``scaling_factor`` but
            the config does not enable ``finetuning_scaling_factor``.
    """
    cfg_dict = get_gemma_2_config(repo_id, sae_path, d_sae_override, layer_override)
    cfg_dict["device"] = device
    if cfg_overrides is not None:
        cfg_dict.update(cfg_overrides)

    # Read every array from the npz archive, renaming lowercase ``w_`` prefixes to ``W_``.
    state_dict: Dict[str, torch.Tensor] = {}
    with np.load(os.path.join(sae_path, "params.npz")) as data:
        for key in data.keys():
            state_key = "W_" + key[2:] if key.startswith("w_") else key
            state_dict[state_key] = (
                torch.tensor(data[key]).to(dtype=torch.float32).to(device)
            )

    # A missing or all-ones scaling factor is a no-op, so drop it and disable the flag.
    # Otherwise keep it under the name the SAE state dict expects.
    scaling_factor = state_dict.pop("scaling_factor", None)
    if scaling_factor is None or torch.allclose(
        scaling_factor, torch.ones_like(scaling_factor)
    ):
        cfg_dict["finetuning_scaling_factor"] = False
    else:
        # Use .get so a config without the key raises this clear error instead of KeyError.
        if not cfg_dict.get("finetuning_scaling_factor", False):
            raise ValueError(
                "Scaling factor is present but finetuning_scaling_factor is False."
            )
        state_dict["finetuning_scaling_factor"] = scaling_factor

    sae = SAE(SAEConfig.from_dict(cfg_dict))
    sae.load_state_dict(state_dict)

    # This loader never produces a sparsity tensor for Gemma 2 SAEs.
    log_sparsity = None
    return sae, log_sparsity

  from .autonotebook import tqdm as notebook_tqdm


## 加载 SAE (Load the SAE)

In [2]:
# Layer whose residual-stream SAE is analysed below. Later cells build file names from it.
layer = 24

# Local copy of the Gemma Scope 9B pre-trained residual SAEs.
# Earlier runs loaded other layers from /mnt/16t/xzwnlp/hugging_cache/gemma-scope-9b-pt-res.
SAE_ROOT = "/data2/xzwnlp/gemma-scope-9b-pt-res"

# average_l0 variant per layer, taken from the per-layer load calls previously kept
# (commented out) in this cell. NOTE(review): confirm these against the Gemma Scope
# release before using layers other than 24.
AVERAGE_L0_BY_LAYER = {
    15: 131, 16: 75, 17: 73, 18: 71, 19: 132, 20: 68, 21: 129,
    22: 123, 23: 120, 24: 114, 25: 114, 26: 116, 27: 118, 28: 119,
    29: 119, 31: 114, 33: 114, 35: 120, 37: 124, 39: 131, 41: 113,
}

if layer not in AVERAGE_L0_BY_LAYER:
    raise KeyError(f"No average_l0 recorded for layer {layer}; add it to AVERAGE_L0_BY_LAYER.")

sae_dir = os.path.join(
    SAE_ROOT, f"layer_{layer}", "width_16k", f"average_l0_{AVERAGE_L0_BY_LAYER[layer]}"
)
sae, sparsity = load_gemma_2_sae(sae_dir, device="cpu")

In [5]:
import torch

# Per-feature frequency tensors for the positive and negative prompt sets
# (file names end in *_feature_freq.pt). The layer comes from `layer` so this
# cell stays in sync with the SAE loaded above.
FEATURE_ATTR_DIR = "/data2/xzwnlp/SaeEdit/ManipulateSAE/data/safety/toxic_DINM_pt/sae_caa_vector_pt/gemma-2-9b_safety/act_toxic_freq/feature_attr"
neg_attr_name = os.path.join(
    FEATURE_ATTR_DIR, f"gemma-2-9b_sae_layer{layer}_resid_post_16k_neg_feature_freq.pt"
)
pos_attr_name = os.path.join(
    FEATURE_ATTR_DIR, f"gemma-2-9b_sae_layer{layer}_resid_post_16k_pos_feature_freq.pt"
)
neg_data = torch.load(neg_attr_name)
pos_data = torch.load(pos_attr_name)

In [9]:
# 50 largest entries of the positive-set frequency tensor.
torch.topk(pos_data, k=50)

torch.return_types.topk(
values=tensor([4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050.,
        4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050.,
        4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050.,
        4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050.,
        4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4049.],
       device='cuda:0'),
indices=tensor([  451,  1071,  1405,  1834,  1843,  1927,  2133,  2171,  2233,  3228,
         3614,  3751,  4261,  4570,  4581,  4646,  5286,  6578,  6695,  6932,
         6949,  7618,  7794,  8017,  9227,  9548,  9705, 10310, 10494, 10537,
        10968, 10982, 11214, 11769, 11955, 12399, 12715, 12806, 13551, 14031,
        14802, 14984, 15177, 15280, 15476, 15763, 15794, 15866, 16162,   459],
       device='cuda:0'))

In [14]:
# 50 features with the largest positive-minus-negative frequency difference.
torch.topk(pos_data.sub(neg_data), k=50)

torch.return_types.topk(
values=tensor([3725., 3631., 3620., 3609., 3558., 3511., 3483., 3478., 3472., 3426.,
        3420., 3373., 3345., 3318., 3274., 3264., 3210., 3209., 3202., 3176.,
        3170., 3167., 3166., 3145., 3126., 3103., 3086., 3079., 3071., 3055.,
        3054., 3025., 2964., 2961., 2957., 2936., 2921., 2917., 2912., 2885.,
        2874., 2873., 2864., 2853., 2851., 2850., 2842., 2841., 2840., 2839.],
       device='cuda:0'),
indices=tensor([ 6942, 10705, 16061,  9127,  8912,  4212,   752,  5803,  4082,  2402,
          427,  4071, 11065,  4279,  5512, 14310,  6587,  7423,  5356,  5537,
         9853, 15839,  6142, 15049, 12313,  4309,  3657,  9911, 15878,   683,
        10900, 16326, 10335, 15836,  8193, 12172, 13078,  3922, 15760,  3466,
         5054,  9458,  7443,  2128,  1866,  7350, 13289,  1310, 11015, 14107],
       device='cuda:0'))

In [20]:
# Among the top-50 positive-set features, rank by pos-minus-neg gap and keep the best two.
# The returned indices are positions within the top-50 list, not feature ids.
top_pos_indices = pos_data.topk(50).indices
freq_gap = pos_data - neg_data
freq_gap[top_pos_indices].topk(2)

torch.return_types.topk(
values=tensor([2114., 2069.], device='cuda:0'),
indices=tensor([12, 32], device='cuda:0'))

In [23]:
# Map the best gap position back to a feature id. It is derived with argmax instead of
# hard-coding position 12 copied from the previous cell's output, so re-runs on new data stay correct.
top_pos_indices = pos_data.topk(50).indices
top_pos_indices[(pos_data - neg_data)[top_pos_indices].argmax()]

tensor(4261, device='cuda:0')

In [24]:
# Frequency gap of feature 4261, the id selected in the cells above.
pos_data[4261].sub(neg_data[4261])

tensor(2114., device='cuda:0')

In [26]:
# Per-feature score tensor (file name *_feature_score.pt; presumably activation-based,
# per the "act_" directory name). The layer comes from `layer` rather than a hard-coded 24.
act_attr_name = os.path.join(
    "/data2/xzwnlp/SaeEdit/ManipulateSAE/data/safety/toxic_DINM_pt/sae_caa_vector_pt/gemma-2-9b_safety/act_toxic_freq/feature_attr",
    f"gemma-2-9b_sae_layer{layer}_resid_post_16k_feature_score.pt",
)

act_data = torch.load(act_attr_name)
# Score of feature 4261, the id selected in the cells above.
act_data[4261]

tensor(1.5759, device='cuda:0')

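## 激活和频率取前 percent 的交集 (intersection of the top-percent features by activation and by frequency)

Note: the cell below does not compute this intersection yet. It builds the steering vector from the single feature 4261 selected above.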

In [27]:
# One-hot weight vector over all SAE features (same shape as act_data) that selects one feature.
# Multiplying it by the decoder matrix yields that feature's decoder direction.
feature_idx = 4261
scores = torch.zeros_like(act_data)
scores[feature_idx] = 1
scores = scores.to(sae.W_dec.device)
result = scores @ sae.W_dec
for item in (result.shape, result, torch.norm(result)):
    print(item)



torch.Size([3584])
tensor([ 0.0020,  0.0015,  0.0046,  ..., -0.0091, -0.0010, -0.0160],
       grad_fn=<SqueezeBackward4>)
tensor(1., grad_fn=<LinalgVectorNormBackward0>)


In [None]:
# Save the single-feature ("top1") steering vector computed above for the current layer.
steering_vector_path = "/data2/xzwnlp/SaeEdit/ManipulateSAE/data/safety/toxic_DINM_pt/sae_caa_vector_pt/gemma-2-9b_safety/freq_top1/"
steering_vector_name = f"gemma-2-9b_sae_layer{layer}_resid_post_16k_steering_vector_top1.pt"
steering_vector_full_path = os.path.join(steering_vector_path, steering_vector_name)
# Create the output directory first so torch.save cannot fail with FileNotFoundError.
os.makedirs(steering_vector_path, exist_ok=True)
# `result` comes from `scores @ sae.W_dec` and carries autograd history (grad_fn).
# Save plain CPU data so loaders do not get a requires_grad tensor.
torch.save(result.detach().cpu(), steering_vector_full_path)