In [1]:
# Standard library imports
import os
import random

# Third-party imports
import numpy as np
import torch
import matplotlib.pyplot as plt
from datasets import concatenate_datasets, load_dataset
from datasets import load_dataset
from IPython.display import HTML, display
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score
from sklearn.metrics.pairwise import cosine_similarity

# Local imports
from safetytooling.internals.model_wrapper import LanguageModelWrapper, ModelConfig
from safetytooling.internals.utils import *
from safetytooling.internals.sae_wrappers import *
from safetytooling.internals.feature_database import FeatureDatabase
from safetytooling.internals.visuals import *

In [2]:
wrapped_model = load_goodfire_llama3_8b_sae_wrapped()

  state_dict = torch.load(file_path, map_location=device)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
generation = wrapped_model.sample_generations(
    ["Hello, how are you?", "What are you up to over the weekend?"],
    max_length=100,
    temperature=0.7,
    top_p=0.95,
    format_inputs=True,
)

featurized_examples = wrapped_model.featurize_text(generation)

Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Both `max_new_tokens` (=20) and `max_length`(=100) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


In [4]:
print(generation[0])

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|start_header_id|>user<|end_header_id|>

Hello, how are you?<|start_header_id|>assistant<|end_header_id|>

I'm functioning properly, thank you for asking. How can I assist you today?


In [5]:
# Creating a dataset to find max activating examples
def create_dataset(wrapped_model, size=100_000):
    redpajama = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", split="train", trust_remote_code=True)
    redpajama = redpajama.map(lambda x: {"text": f"{wrapped_model.tokenizer.bos_token}{x['text']}"})
    
    ultrachat = load_dataset("LLM-LAT/benign-dataset", split="train", trust_remote_code=True)
    ultrachat = ultrachat.map(
        lambda x: {"text": wrapped_model.tokenizer.apply_chat_template([
            {"role": "user", "content": x['prompt']},
            {"role": "assistant", "content": x['response']}
        ], tokenize=False)},
        remove_columns=ultrachat.column_names
    )
    
    redpajama_size = int(size * 0.9)
    ultrachat_size = size - redpajama_size
    
    redpajama_subset = redpajama.shuffle(seed=42).select(range(redpajama_size))
    ultrachat_subset = ultrachat.shuffle(seed=42).select(range(ultrachat_size))
    
    return concatenate_datasets([redpajama_subset, ultrachat_subset]).shuffle(seed=42)

# Create feature database
feature_db = FeatureDatabase(wrapped_model)

# If it's not already cached, process the dataset
if not os.path.exists("cached_features_test_goodfire"):
    dataset = create_dataset(wrapped_model)
    
    feature_db.process_dataset(
        texts=dataset["text"],
        save_dir="cached_features_test_goodfire",
        max_length=128,
        batch_size=128,
    )

print("Loading from disk")
feature_db.load_from_disk("cached_features_test_goodfire")

Loading from disk


In [6]:
### Testing example loading
example = feature_db.load_example(0)
print(example["text"])
print(example.keys())

### Testing common feature loading
common_features = feature_db.get_common_features("model.layers.19", k=1000)

<|begin_of_text|>Q: Number of copies of irreducible unitary representation in $L^2(G)$ for compact group $G$? Peter-Weyl Theorem is concerned with expressing $L^2(G)$ as closure of direct sum of subspace generated by irreducible unitary representation. Every irreducible representation on a compact group is finite dimensional. Why $L^2(G)$ has exactly subspace with dimension number of copies of corresponding irreducible unitary representation? I have problem with investigating why there are no more copies!! 

A: If not, consider the restriction of regular representation to the complement of the direct sum of the sub
dict_keys(['token_ids', 'str_tokens', 'top_indices', 'top_acts', 'text'])


In [9]:
print(common_features)

### Testing top activating examples
feature_idx = common_features[800]
top_examples = feature_db.get_top_activating_examples("model.layers.19", feature_idx)

### Testing quantiles of activations
quantile_examples = feature_db.get_quantile_examples("model.layers.19", feature_idx, n_buckets=10, n_examples=10)

### Testing feature explanations
feature_exp = wrapped_model.get_feature_description("model.layers.19", feature_idx)


[15, 33, 18, 22, 52, 4, 10, 57, 17, 34, 45, 79, 19, 48, 7, 63, 13, 77, 80, 71, 41, 38, 68, 2, 5, 27, 66, 65, 87, 29, 35, 56, 94, 51, 102, 53, 36, 70, 1, 88, 49, 81, 96, 25, 40, 16, 91, 64, 31, 37, 46, 54, 24, 21, 6, 55, 62, 73, 105, 23, 99, 12, 97, 82, 107, 9, 85, 83, 78, 50, 43, 110, 89, 76, 108, 69, 106, 112, 61, 125, 109, 126, 75, 113, 26, 92, 67, 72, 11, 117, 115, 101, 104, 100, 14, 3, 59, 111, 122, 74, 131, 120, 132, 124, 138, 156, 127, 134, 130, 119, 42733, 93, 121, 47, 30963, 4361, 162, 149, 137, 39079, 24255, 22504, 51786, 42472, 11087, 140, 147, 145, 103, 52256, 18045, 148, 143, 63404, 56619, 59149, 52212, 8707, 56211, 53770, 49844, 163, 25315, 128, 37869, 10474, 62389, 187, 174, 144, 22761, 173, 13365, 49505, 47640, 184, 33688, 123, 36050, 152, 2515, 141, 53839, 46908, 172, 33310, 164, 177, 41183, 20469, 27230, 168, 64607, 13757, 12549, 32561, 44479, 43309, 24356, 181, 135, 20078, 41290, 61727, 153, 36917, 57326, 25784, 154, 44291, 26091, 32750, 23665, 37677, 56873, 40915, 23

In [12]:
example_dict = {}
example_dict["Top Activating Examples"] = top_examples
for i, (quantile, examples) in enumerate(list(quantile_examples.items())[::-1]):
    example_dict[f"Quantile {quantile}"] = examples

my_html = highlight_tokens_dict(example_dict, use_orange_highlight=True, title=feature_exp)
display(HTML(my_html))

