In [1]:
# import plotly.express as px
from typing import Any, Dict, Optional, Protocol, Tuple
import os, sys
import torch
from torch.utils.data import DataLoader
from sae_lens import SAE
from pathlib import Path
import numpy as np
from sae_lens.toolkit.pretrained_sae_loaders import (
    gemma_2_sae_loader,
    get_gemma_2_config,
)
from sae_lens import SAE, SAEConfig, LanguageModelSAERunnerConfig, SAETrainingRunner

def load_gemma_2_sae(
    sae_path: str,
    device: str = "cpu",
    repo_id: str = "gemma-scope-9b-it-res",
    force_download: bool = False,
    cfg_overrides: Optional[Dict[str, Any]] = None,
    d_sae_override: Optional[int] = None,
    layer_override: Optional[int] = None,
) -> Tuple["SAE", Optional[torch.Tensor]]:
    """
    Load a Gemma 2 (Gemma Scope) SAE from a local directory containing ``params.npz``.

    Args:
        sae_path: Local directory holding ``params.npz``. It is also passed to
            ``get_gemma_2_config`` together with ``repo_id`` to build the config.
        device: Device placed in the config and used for every loaded weight.
        repo_id: Release name forwarded to ``get_gemma_2_config``.
        force_download: Unused; kept only for signature compatibility.
        cfg_overrides: Entries merged into the generated config dict (they win
            over generated values, including ``device``).
        d_sae_override: Forwarded to ``get_gemma_2_config``.
        layer_override: Forwarded to ``get_gemma_2_config``.

    Returns:
        ``(sae, log_sparsity)``: the ``SAE`` with weights loaded, and ``None``
        because no sparsity tensor is loaded for these SAEs.

    Raises:
        AssertionError: if ``params.npz`` has a ``scaling_factor`` that is not
            all ones while the config's ``finetuning_scaling_factor`` is falsy.
    """
    cfg_dict = get_gemma_2_config(repo_id, sae_path, d_sae_override, layer_override)
    cfg_dict["device"] = device

    # Caller-supplied overrides take precedence over the generated config.
    if cfg_overrides is not None:
        cfg_dict.update(cfg_overrides)

    # Load every array as float32 on `device`; rename lowercase "w_*" keys to
    # "W_*" so they match the SAE's parameter names.
    state_dict: Dict[str, torch.Tensor] = {}
    with np.load(os.path.join(sae_path, "params.npz")) as data:
        for key in data.keys():
            state_dict_key = "W_" + key[2:] if key.startswith("w_") else key
            state_dict[state_dict_key] = torch.as_tensor(
                data[key], dtype=torch.float32, device=device
            )

    # An all-ones scaling factor is a no-op: drop it and disable the option.
    # Any other scaling factor is kept (renamed) and requires the option enabled.
    if "scaling_factor" in state_dict:
        if torch.allclose(
            state_dict["scaling_factor"], torch.ones_like(state_dict["scaling_factor"])
        ):
            del state_dict["scaling_factor"]
            cfg_dict["finetuning_scaling_factor"] = False
        else:
            assert cfg_dict[
                "finetuning_scaling_factor"
            ], "Scaling factor is present but finetuning_scaling_factor is False."
            state_dict["finetuning_scaling_factor"] = state_dict.pop("scaling_factor")
    else:
        cfg_dict["finetuning_scaling_factor"] = False

    sae = SAE(SAEConfig.from_dict(cfg_dict))
    sae.load_state_dict(state_dict)

    # No sparsity tensor is loaded for Gemma 2 SAEs.
    log_sparsity = None

    return sae, log_sparsity

  from .autonotebook import tqdm as notebook_tqdm


## 加载 SAE (Load the SAE)

In [2]:
# Residual-stream layer to analyse. It is reused below when naming the saved
# steering vector, so the SAE path is derived from it rather than hard-coded.
# NOTE(review): the original comment mentioned layer 24 as an alternative; the
# "average_l0_*" folder name may differ per layer — verify in the cache.
layer = 20
sae, sparsity = load_gemma_2_sae(
    f"/disk3/wmr/hugging_cache/gemma-scope-9b-it-res/layer_{layer}/width_16k/average_l0_91",
    device="cpu",
)

In [3]:
import torch

# Per-feature statistics for the layer-`layer` 16k SAE. Per the file names these
# are feature frequencies for the "pos" and "neg" example sets — TODO confirm
# which side of the safety data each set is.
feature_attr_dir = "/data2/xzwnlp/SaeEdit/ManipulateSAE/data/safety/toxic_DINM_it/sae_caa_vector_it/gemma-2-9b-it_safety/act_toxic_freq/feature_attr"
neg_attr_name = os.path.join(
    feature_attr_dir, f"gemma-2-9b-it_sae_layer{layer}_resid_post_16k_neg_feature_freq.pt"
)
pos_attr_name = os.path.join(
    feature_attr_dir, f"gemma-2-9b-it_sae_layer{layer}_resid_post_16k_pos_feature_freq.pt"
)
# The files hold plain tensors, so refuse to unpickle arbitrary objects.
neg_data = torch.load(neg_attr_name, weights_only=True)
pos_data = torch.load(pos_attr_name, weights_only=True)

In [4]:
# 50 most frequent features on the positive set.
torch.topk(pos_data, k=50)

torch.return_types.topk(
values=tensor([4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050.,
        4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050.,
        4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4050., 4049.,
        4049., 4049., 4049., 4049., 4048., 4048., 4048., 4048., 4048., 4048.,
        4048., 4048., 4048., 4048., 4048., 4048., 4047., 4047., 4047., 4047.],
       device='cuda:0'),
indices=tensor([  771,  1295,  2226,  2698,  2962,  3200,  3206,  4480,  4851,  5045,
         5183,  5299,  5677,  6880,  8522,  8736,  9015,  9071,  9436, 10145,
        10964, 11395, 12023, 12789, 13937, 14203, 14878, 15913, 16043,  1107,
         9755, 14195, 14688, 15718,  3219,  4147,  5956,  5999,  6588,  6670,
         6794,  8384,  9462, 11927, 15294, 15655,  1415,  1694,  2065,  3972],
       device='cuda:0'))

In [5]:
# 50 features whose positive-set frequency most exceeds their negative-set frequency.
torch.topk(pos_data - neg_data, k=50)

torch.return_types.topk(
values=tensor([3849., 3710., 3659., 3606., 3554., 3419., 3382., 3369., 3319., 3319.,
        3291., 3280., 3276., 3273., 3183., 3139., 3135., 3087., 3029., 3016.,
        2980., 2973., 2966., 2965., 2964., 2928., 2928., 2917., 2916., 2912.,
        2905., 2894., 2890., 2887., 2883., 2861., 2842., 2841., 2831., 2828.,
        2813., 2807., 2801., 2790., 2771., 2769., 2767., 2740., 2736., 2732.],
       device='cuda:0'),
indices=tensor([10457,  8228,  7276,  8664,  2834,  5566,  7854,  3791,  1883,  9015,
          683, 15762,  7521,  1264, 16170, 11700, 13144,  5344, 12309,   484,
         5168, 10940,  9128, 10737, 14055,  9272, 12957, 10731, 15956,  7104,
         7779,  6791,  5469,  5191,  3206,  1396, 14455,  3032,  8340, 10166,
         9918,  2112, 16069,  5964,  2446,  6968,  8780,  7780,  1527, 11198],
       device='cuda:0'))

In [6]:
# Among the 50 most frequent positive features, find the 2 with the largest
# pos-minus-neg frequency gap (indices are positions inside that top-50 list).
top50_pos_idx = pos_data.topk(50).indices
freq_gap = pos_data - neg_data
freq_gap[top50_pos_idx].topk(2)

torch.return_types.topk(
values=tensor([3319., 2883.], device='cuda:0'),
indices=tensor([16,  6], device='cuda:0'))

In [7]:
# Map the best position from the previous cell back to an SAE feature id.
# Computed here instead of hard-coding position 16 from an earlier output, so it
# stays correct if the data (or the tie order inside topk) changes.
top50_pos_idx = pos_data.topk(50).indices
top50_pos_idx[(pos_data - neg_data)[top50_pos_idx].argmax()]

tensor(9015, device='cuda:0')

In [8]:
# Frequency gap (pos - neg) of the selected feature.
feature_idx = 9015
(pos_data - neg_data)[feature_idx]

tensor(3319., device='cuda:0')

In [10]:
# Per-feature score file for the same layer; the file name is derived from
# `layer` so it stays in sync with the loaded SAE.
act_attr_name = os.path.join(
    "/data2/xzwnlp/SaeEdit/ManipulateSAE/data/safety/toxic_DINM_it/sae_caa_vector_it/gemma-2-9b-it_safety/act_toxic_freq/feature_attr",
    f"gemma-2-9b-it_sae_layer{layer}_resid_post_16k_feature_score.pt",
)

# The file holds a plain tensor, so refuse to unpickle arbitrary objects.
act_data = torch.load(act_attr_name, weights_only=True)
act_data[9015]

tensor(1.9905, device='cuda:0')

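## 激活和频率取前 pct 的交集 (Intersect the top-pct features by activation and by frequency)

Note: the cell below builds the steering vector from the single feature 9015 selected above, not from an intersection of feature sets.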

In [11]:
# Steering vector = feature-weighted sum of SAE decoder rows. A one-hot weight on
# feature 9015 selects exactly that row of W_dec.
# no_grad: the vector is saved as data, so don't record an autograd graph
# (without it `result` carries a grad_fn, as the printed output showed).
with torch.no_grad():
    # Initialise the combined score, allocated directly on W_dec's device.
    scores = torch.zeros_like(act_data, device=sae.W_dec.device)
    scores[9015] = 1
    result = scores @ sae.W_dec
print(result.shape)
print(result)
print(torch.norm(result))



torch.Size([3584])
tensor([ 0.0181, -0.0096,  0.0040,  ...,  0.0206,  0.0270,  0.0073],
       grad_fn=<SqueezeBackward4>)
tensor(1., grad_fn=<LinalgVectorNormBackward0>)


In [12]:
# Decoder row for feature 9015. The printed vectors are truncated ("..."), so
# verify on every element that the steering vector equals this row.
decoder_row = sae.W_dec[9015]
torch.testing.assert_close(result.detach(), decoder_row.detach())
decoder_row

tensor([ 0.0181, -0.0096,  0.0040,  ...,  0.0206,  0.0270,  0.0073],
       grad_fn=<SelectBackward0>)

In [13]:
# Previous output location, kept for reference:
# steering_vector_path = "/data2/xzwnlp/SaeEdit/shae_old/data/toxic_DINM/feature/act_toxic_freq_or_act/steering_vector"
steering_vector_path = "/data2/xzwnlp/SaeEdit/ManipulateSAE/data/safety/toxic_DINM_it/sae_caa_vector_it/gemma-2-9b-it_safety/freq_top1/"
steering_vector_name = f"gemma-2-9b-it_sae_layer{layer}_resid_post_16k_steering_vector_top1.pt"
steering_vector_full_path = os.path.join(steering_vector_path, steering_vector_name)
# torch.save does not create missing directories.
os.makedirs(steering_vector_path, exist_ok=True)
# Save a plain tensor without autograd history.
torch.save(result.detach(), steering_vector_full_path)