In [1]:
from dictionary_learning import CrossCoder
from torch.nn.functional import cosine_similarity
import torch as th
import plotly.express as px
from pathlib import Path
from tqdm.notebook import tqdm

# Nothing in this notebook needs gradients, so autograd is switched off globally.
th.set_grad_enabled(mode=False)
# Name tag for this experiment.
exp_name = "eval_crosscoder"

In [2]:
# Notebook parameters. This looks like a papermill-style parameters cell, so values
# may be overridden at execution time; keep them as plain literals.

# Crosscoder checkpoint under evaluation.
crosscoder_path = "/dlabscratch1/jminder/repositories/representation-structure-comparison/checkpoints/l13-mu4.0e-02-lr1e-04/ae_90000.pt"
# Extra CLI-style flags handed to the ArgumentParser further down (device overrides).
extra_args = []
exp_id = "test"
seed = 42

# The two models whose activations the crosscoder pairs, and the layer whose
# residual stream the per-model SAEs hook.
base_model = "gemma-2-2b"
instruct_model = "gemma-2-2b-it"
layer = 13

# Hardware / data-loading settings.
device = "cuda:1"
activation_dir = Path(
    "/dlabscratch1/jminder/repositories/representation-structure-comparison/activations"
)
validation_size = 1_000_000
batch_size = 2048
workers = 12
SEQ_LEN = 1024

In [3]:
crosscoder = CrossCoder.from_pretrained(crosscoder_path)
# The encoder weight is unpacked as (num_layers, activation_dim, dict_size);
# activation_dim / dict_size are used as the SAEs' d_in / d_sae further down.
num_layers, activation_dim, dict_size = tuple(crosscoder.encoder.weight.shape)

In [4]:
from sae_lens import SAE, SAEConfig
from argparse import ArgumentParser

# Per-model device overrides, parsed from the `extra_args` parameter; both
# default to the shared `device` parameter.
parser = ArgumentParser()
parser.add_argument("--base-device", default=device, type=str)
parser.add_argument("--instruct-device", default=device, type=str)
args = parser.parse_args(args=extra_args)

# sae_lens configs that present each model's slice of the crosscoder as a
# standalone "standard" ReLU SAE hooked at `blocks.{layer}.hook_resid_post`.
# d_in / d_sae come from the crosscoder encoder's shape.
base_config_dict = {
    "architecture": "standard",
    "d_in": activation_dim,
    "d_sae": dict_size,
    "dtype": "float32",
    "model_name": base_model,
    "hook_name": f"blocks.{layer}.hook_resid_post",
    "hook_layer": layer,
    "hook_head_index": None,
    "activation_fn_str": "relu",
    "finetuning_scaling_factor": False,
    "sae_lens_training_version": None,
    "prepend_bos": True,
    "dataset_path": None,
    # Tied to the SEQ_LEN parameter instead of repeating the literal, so an
    # override of the parameter keeps the SAE metadata consistent.
    "context_size": SEQ_LEN,
    "dataset_trust_remote_code": False,
    "apply_b_dec_to_input": False,
    "normalize_activations": None,
    "device": args.base_device,
}
base_config = SAEConfig.from_dict(base_config_dict)

# The instruct-model config differs only in model name and device.
it_config_dict = {
    **base_config_dict,
    "model_name": instruct_model,
    "device": args.instruct_device,
}
it_config = SAEConfig.from_dict(it_config_dict)


def gen_state_dict(model_idx):
    """Build a sae_lens-style state dict for one model's slice of the crosscoder.

    Args:
        model_idx: index into the leading (per-model) axis of the global
            ``crosscoder``'s encoder weight, decoder weight and decoder bias.
            This notebook uses 0 for the base model and 1 for the instruct model.

    Returns:
        dict with keys ``b_enc``, ``W_enc``, ``b_dec``, ``W_dec``. The encoder
        bias is not indexed: every slice receives the whole shared bias. The
        values are the crosscoder's own tensors (or indexing results of them);
        no explicit copies are made here.
    """
    encoder = crosscoder.encoder
    decoder = crosscoder.decoder
    state = {"b_enc": encoder.bias}
    state["W_enc"] = encoder.weight[model_idx]
    state["b_dec"] = decoder.bias[model_idx]
    state["W_dec"] = decoder.weight[model_idx]
    return state

## Base model visualization

In [5]:
# Standalone SAE for the base model: slice 0 of the crosscoder, followed by
# sae_lens's decoder-norm folding (`fold_W_dec_norm`).
base_sae = SAE(base_config)
base_sae.load_state_dict(gen_state_dict(model_idx=0))
base_sae.fold_W_dec_norm()


<All keys matched successfully>

In [8]:
from transformer_lens import HookedTransformer, utils
from sae_dashboard.sae_vis_data import SaeVisConfig
from sae_dashboard.sae_vis_runner import SaeVisRunner
from datasets import load_dataset, load_from_disk

# Load the base model.
# On a re-run, the previous `model` stays referenced until the new one is
# assigned, so two copies of the weights would coexist on the GPU. That is
# likely the cause of the CUDA OOM recorded for this cell. Free any earlier
# copy first so the cell can be safely re-executed.
import gc

if "model" in globals():
    del model
    gc.collect()
    th.cuda.empty_cache()

model = HookedTransformer.from_pretrained_no_processing(
    base_model, device=args.base_device, dtype="bfloat16"
)

# Dashboard configuration: visualize the first N_VIS_FEATURES latents of the
# base SAE, at the hook point its config declares.
N_VIS_FEATURES = 256
config = SaeVisConfig(
    hook_point=base_sae.cfg.hook_name,
    features=list(range(N_VIS_FEATURES)),
    minibatch_size_features=64,
    minibatch_size_tokens=256,
    device=args.base_device,
    dtype="bfloat16",
)

# Text corpus for the dashboards. This loads the full FineWeb 10BT sample; only
# a small subset is meant to be tokenized (see TODO below).
fineweb = load_dataset(
    "HuggingFaceFW/fineweb",
    name="sample-10BT",
    split="train",
    cache_dir=Path("/dlabscratch1/cdumas/.cache/huggingface/datasets/"),
)

# TODO: build `tokenized_data` for SaeVisRunner (not wired up yet). Notes:
# - `fineweb[:30]` returns the first 30 rows (not a random sample), and as a
#   dict of columns rather than a `Dataset`. For a random subset that stays a
#   `Dataset`: fineweb_subset = fineweb.shuffle(seed=seed).select(range(30))
# - `utils.tokenize_and_concatenate` takes a `Dataset` and reads its "text"
#   column itself, so pass the subset, not `subset["text"]`:
#   tokenized_data = utils.tokenize_and_concatenate(fineweb_subset, model.tokenizer, max_length=SEQ_LEN)
# - Alternative corpus (chat data):
#   load_from_disk("/dlabscratch1/public/datasets/lmsys-chat-1m-formatted/").select(range(30))
# - Optionally shuffle the tokenized rows: tokenized_data = tokenized_data.shuffle(seed=seed)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB. GPU 1 has a total capacity of 11.92 GiB of which 39.19 MiB is free. Including non-PyTorch memory, this process has 11.88 GiB memory in use. Of the allocated memory 11.51 GiB is allocated by PyTorch, and 270.78 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [7]:
# Scratch probe: what does slicing the FineWeb dataset return?
first_rows = fineweb[:30]
type(first_rows)

TypeError: type() takes 1 or 3 arguments

## Instruction-tuned (IT) model visualization

In [None]:
# Standalone SAE for the instruct model: slice 1 of the crosscoder. Fold the
# decoder norms exactly as the base-model SAE does, so both dashboards are
# built from equally normalized SAEs.
it_sae = SAE(it_config)
it_sae.load_state_dict(gen_state_dict(1))
it_sae.fold_W_dec_norm()
