In [1]:
# Standard library imports
import os
import random

# Third-party imports
import numpy as np
import torch
import matplotlib.pyplot as plt
from datasets import concatenate_datasets, load_dataset
from datasets import load_dataset
from IPython.display import HTML, display
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score
from sklearn.metrics.pairwise import cosine_similarity

# Local imports
from safetytooling.internals.model_wrapper import LanguageModelWrapper, ModelConfig
from safetytooling.internals.utils import *
from safetytooling.internals.sae_wrappers import *
from safetytooling.internals.feature_database import FeatureDatabase
from safetytooling.internals.visuals import *

In [2]:
# Load the Llama 3 model wrapped with Eleuther SAEs for the listed layers.
# Later cells query layer 11 under the name "model.layers.11".
sae_layers = [11, 15]
wrapped_model = load_eleuther_llama3_sae_wrapped(layers=sae_layers)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
# Sample chat completions for two prompts, then featurize the resulting text
# (presumably per-token SAE feature activations — confirm in featurize_text).
#
# Sampling is stochastic (temperature / top_p), so seed torch's RNG to make
# reruns of this cell reproducible (modulo nondeterministic GPU kernels).
SEED = 42
torch.manual_seed(SEED)

# NOTE(review): the recorded output of this cell warns that an internal
# `max_new_tokens=20` takes precedence over `max_length=100`, so completions are
# cut after 20 new tokens. TODO: check whether sample_generations exposes
# max_new_tokens and pass it explicitly if longer completions are wanted.
generation = wrapped_model.sample_generations(
    ["Hello, how are you?", "What are you up to over the weekend?"],
    max_length=100,
    temperature=0.7,
    top_p=0.95,
    format_inputs=True,
)

featurized_examples = wrapped_model.featurize_text(generation)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
Both `max_new_tokens` (=20) and `max_length`(=100) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)


We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


In [4]:
# Show the first prompt, rendered with the chat template, plus its sampled reply.
first_generation = generation[0]
print(first_generation)

<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Hello, how are you?<|start_header_id|>assistant<|end_header_id|>

Hello! I'm just a language model, so I don't have feelings or emotions like humans do


In [5]:
# Creating a dataset to find max activating examples
def create_dataset(wrapped_model, size=100_000, redpajama_fraction=0.9, seed=42):
    """Build a shuffled mix of pretraining-style text and chat transcripts.

    Draws ``int(size * redpajama_fraction)`` documents from the RedPajama 1T
    sample, each prefixed with the tokenizer's BOS token. The remaining rows
    come from ``LLM-LAT/benign-dataset``, rendered through the tokenizer's chat
    template as a single user/assistant exchange. Each source is shuffled with
    ``seed`` before subsetting, and the concatenation is shuffled again with
    the same seed.

    Args:
        wrapped_model: object exposing ``.tokenizer`` with ``bos_token`` and
            ``apply_chat_template`` (presumably a Hugging Face tokenizer).
        size: total number of rows in the returned dataset.
        redpajama_fraction: share of ``size`` drawn from RedPajama; the rest
            comes from the chat dataset.
        seed: shuffle seed used for both sources and the final mix.

    Returns:
        A ``datasets.Dataset`` whose only column is ``"text"``.
    """
    tokenizer = wrapped_model.tokenizer

    redpajama_size = int(size * redpajama_fraction)
    ultrachat_size = size - redpajama_size

    # Pretraining-style text: raw documents prefixed with the BOS token.
    # NOTE(review): trust_remote_code=True runs the dataset repo's loading code.
    redpajama = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", split="train", trust_remote_code=True)
    # Subset *before* formatting. shuffle()'s permutation depends only on the
    # seed and the dataset length, so the same rows are picked as when mapping
    # first. However, .map now processes redpajama_size rows instead of the
    # whole corpus.
    redpajama_subset = redpajama.shuffle(seed=seed).select(range(redpajama_size))
    redpajama_subset = redpajama_subset.map(
        lambda x: {"text": f"{tokenizer.bos_token}{x['text']}"},
        # Keep only "text" so both sources share one schema before concatenation.
        remove_columns=redpajama_subset.column_names,
    )

    # Chat-style text: prompt/response pairs rendered with the chat template.
    ultrachat = load_dataset("LLM-LAT/benign-dataset", split="train", trust_remote_code=True)
    ultrachat_subset = ultrachat.shuffle(seed=seed).select(range(ultrachat_size))
    ultrachat_subset = ultrachat_subset.map(
        lambda x: {"text": tokenizer.apply_chat_template([
            {"role": "user", "content": x['prompt']},
            {"role": "assistant", "content": x['response']}
        ], tokenize=False)},
        remove_columns=ultrachat_subset.column_names,
    )

    return concatenate_datasets([redpajama_subset, ultrachat_subset]).shuffle(seed=seed)

# Create feature database.
# Processing the dataset is expensive, so it runs only when the cache
# directory is missing; afterwards the cached features are loaded from disk.
# The path lives in one constant so processing and loading cannot diverge.
FEATURE_CACHE_DIR = "cached_features_test"

feature_db = FeatureDatabase(wrapped_model)

# If it's not already cached, process the dataset.
# NOTE(review): only the directory's existence is checked. If process_dataset
# creates the directory and is then interrupted, a partial cache may be loaded
# below as if complete — confirm whether process_dataset writes atomically.
if not os.path.exists(FEATURE_CACHE_DIR):
    dataset = create_dataset(wrapped_model)

    feature_db.process_dataset(
        texts=dataset["text"],
        save_dir=FEATURE_CACHE_DIR,
        max_length=128,
        batch_size=128,
    )

print("Loading from disk")
feature_db.load_from_disk(FEATURE_CACHE_DIR)

Loading from disk


In [6]:
### Sanity check: load the first cached example and show its text and fields
example = feature_db.load_example(0)
example_text, example_fields = example["text"], example.keys()
print(example_text)
print(example_fields)

### The k=1000 features get_common_features returns for layer 11
### (later cells index into this list)
common_features = feature_db.get_common_features("model.layers.11", k=1000)

<|begin_of_text|>Physical Accessibility: The main entrance is level with the sidewalk, no automatic front door. Plenty of space to move around inside the restaurant space. Washrooms are located in the basement.
Staycation - walking on the beach is the perfect vacation.
Staycation! I travel all the time for work (HIV research) that it is a joy to be at home on the couch.
Vacation 28 year old Norwegian, She/Her. Doing fieldwork in Toronto until September. Writing a thesis about Body hair. Love swimming in the ocean, netflix, traveling, and cooking.
I think I'd prefer a vacation, but I
dict_keys(['token_ids', 'str_tokens', 'top_indices', 'top_acts', 'text'])


In [10]:
### Inspect a single feature: its top activating examples, plus examples
### grouped into n_buckets quantile buckets (presumably by activation strength)
layer_name = "model.layers.11"
feature_id = common_features[310]  # hard-coded pick from the common-feature list

top_examples = feature_db.get_top_activating_examples(layer_name, feature_id)
quantile_examples = feature_db.get_quantile_examples(
    layer_name,
    feature_id,
    n_buckets=10,
    n_examples=10,
)


[0, 128847, 51856, 4669, 57203, 7296, 111097, 32355, 130624, 127172, 130216, 119149, 31597, 65978, 129612, 8630, 117352, 49189, 33786, 91280, 73849, 125892, 8884, 42157, 42927, 48590, 37267, 118046, 59257, 106849, 35938, 68229, 5072, 101134, 30561, 61805, 58887, 42725, 107414, 39677, 62854, 3760, 48802, 88153, 41143, 32514, 123658, 9936, 81290, 113942, 37350, 100935, 10802, 108916, 102415, 89811, 15115, 16627, 105677, 81352, 38225, 90562, 98721, 2021, 78227, 49597, 95412, 99063, 43594, 9152, 55974, 90862, 3132, 45982, 124241, 105089, 117304, 57055, 25523, 13464, 74978, 104367, 41202, 89707, 46904, 114753, 51475, 77538, 16347, 117326, 45309, 32550, 4225, 64627, 38566, 47682, 77476, 105551, 100187, 41455, 65833, 40716, 1755, 116106, 49274, 95284, 118039, 5787, 78944, 6579, 111555, 57282, 91156, 57202, 35038, 105769, 40546, 96765, 15405, 118268, 83993, 123243, 12067, 121006, 119634, 72737, 66098, 95428, 116320, 59146, 49405, 45426, 18075, 81672, 104960, 130943, 113318, 6329, 119314, 10880

In [11]:
# Sections for the HTML view: the top activating examples first, then each
# quantile bucket in the reverse of the order get_quantile_examples returned
# them (presumably highest quantile first — confirm against that method).
example_dict = {"Top Activating Examples": top_examples}
for quantile, examples in reversed(list(quantile_examples.items())):
    example_dict[f"Quantile {quantile}"] = examples

my_html = highlight_tokens_dict(example_dict, use_orange_highlight=True)
display(HTML(my_html))

