## Setup

### Globals

In [None]:
# Local directories (relative to the working directory) used by this notebook.
LOGS_DIR = "feature_interp_logs"  # TRACE log files written by the @trace_logger decorator
EXPERIMENTS_DIR = "feature_interp_exp"  # presumably experiment outputs; not referenced in this part of the file
NEURONPEDIA_EXPORTS_DIR = "neuronpedia_exports"  # local Neuronpedia export JSONs (see get_sae_export_dir)
TRANSLUCE_EXPORTS_DIR = "transluce_exports"  # holds transluce_formatted_activations.pkl (see load_act_transluce)

### Keys and Tokens (Fill In)

In [None]:
# API credentials -- fill these in locally and never commit real values.
GAI_KEY = "" # GEMINI: passed to gai.configure() when an Explainer is built with gemini=True
HF_TOKEN = "" # HF: Hugging Face access token (not referenced in this part of the file)
# OPENAI (should be in env): Explainer calls OpenAI() without an api_key argument,
# so the key has to come from the environment (OPENAI_API_KEY).

### Init

In [None]:
import os
import json
import random
import math
import re
import gc
import pickle
import requests
import time
import traceback
from dataclasses import dataclass, asdict
from typing import Dict, List, DefaultDict, Tuple
from json.decoder import JSONDecodeError
from functools import partial
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from tqdm import tqdm
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import functools
from loguru import logger

# Hugging Face and Models
from transformers import (
    pipeline,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    AutoTokenizer,
)
from huggingface_hub import hf_hub_download, notebook_login, login
from openai import OpenAI, RateLimitError
import google.generativeai as gai
import datasets


# SAE and Transformer Lens
from sae_lens import SAE, HookedSAETransformer
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from transformer_lens.utils import test_prompt, tokenize_and_concatenate

In [None]:
# Start from a clean slate: first collect unreachable Python objects, then hand
# cached-but-unused CUDA allocator blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

# Snapshot of current GPU memory usage, in bytes (logged below in GB).
allocated_memory = torch.cuda.memory_allocated()
reserved_memory = torch.cuda.memory_reserved()
logger.debug(f"Allocated memory: {allocated_memory / (1024 ** 3):.2f} GB")
logger.debug(f"Reserved memory: {reserved_memory / (1024 ** 3):.2f} GB")

# Inference only: disabling autograd keeps activation graphs from piling up in memory.
torch.set_grad_enabled(False)

# Backend preference: Apple Metal, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
logger.debug(f"Device: {device}")


def free(verbose=False):
    """Release memory held by unreachable Python objects and by the CUDA caching allocator.

    :param verbose: if True, also log the currently allocated and reserved CUDA memory (bytes).
    """
    # Collect first so tensors that just became unreachable give their blocks back to the
    # allocator before its cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

    if verbose:
        # loguru formats extra positional args into "{}" placeholders of the message; the old
        # calls had no placeholder, so the numbers were silently dropped. Embed them directly.
        logger.info(f"Allocated: {torch.cuda.memory_allocated()}")
        logger.info(f"Reserved: {torch.cuda.memory_reserved()}")

In [None]:
# Make sure the trace-log directory exists (no-op if it already does; no exists/mkdir race).
os.makedirs(LOGS_DIR, exist_ok=True)


def trace_logger(log_filename="trace.log"):
    """Decorator factory: during each call of the wrapped function, also write loguru
    records of level TRACE and above to ``LOGS_DIR/log_filename``.

    The file sink is added at the start of every call and removed when the call ends,
    even if it raises. Exceptions are logged with the function name and re-raised.

    :param log_filename: name of the log file inside LOGS_DIR.
    :return: a decorator.
    """
    def decorator(func):
        @functools.wraps(func)  # keep func's __name__/__doc__ on the wrapper
        def wrapper(*args, **kwargs):
            log_full_path = os.path.join(LOGS_DIR, log_filename)
            log_sink = None
            try:
                log_sink = logger.add(log_full_path, level="TRACE")
                return func(*args, **kwargs)
            except Exception as e:
                logger.error(f"Exception in {func.__name__}: {e}")
                raise
            finally:
                # Only remove a sink that was actually added.
                if log_sink is not None:
                    logger.remove(log_sink)
        return wrapper
    return decorator

In [None]:
def pd_ext():
    """Widen pandas' display settings so DataFrames print in full, without truncation."""
    overrides = {
        'display.max_rows': None,            # every row
        'display.max_columns': None,         # every column
        'display.max_colwidth': None,        # never truncate cell contents
        'display.expand_frame_repr': False,  # keep each row on one line instead of wrapping
    }
    for option_name, option_value in overrides.items():
        pd.set_option(option_name, option_value)


def pd_reset():
    """Restore every pandas option (including those changed by pd_ext) to its default."""
    pd.reset_option("all")

In [None]:
def format_time(secs):
    """Render a duration in seconds as zero-padded ``MM:SS``; strings are returned untouched.

    Minutes are not folded into hours, e.g. 3700 -> "61:40".
    """
    if isinstance(secs, str):
        return secs  # placeholders such as "?" pass straight through
    whole_minutes, leftover = divmod(secs, 60)
    return f'{int(whole_minutes):02}:{int(leftover):02}'


def get_its_per_second_test(its_per_second):
    """Format a throughput as ``X.XXit/s``, or as ``X.XXs/it`` when below one item per second."""
    if its_per_second == 0:
        return "?it/s"  # nothing completed yet: rate unknown
    if its_per_second >= 1:
        return f"{its_per_second:.2f}it/s"
    return f"{1 / its_per_second:.2f}s/it"
    

def _tqdm(iterable, total=None, width=20, fill='█', print_end='\r'):
    """Minimal tqdm-like generator: yield the items of *iterable* while redrawing a progress bar.

    The ETA blends the overall average pace (70%) with the pace over the last few items (30%).

    :param iterable: items to yield.
    :param total: expected number of items; defaults to ``len(iterable)``.
    :param width: bar width in characters.
    :param fill: character used for the completed part of the bar.
    :param print_end: terminator for each redraw (``'\\r'`` redraws the same line).
    """
    start = time.time()

    if total is None:
        total = len(iterable)

    five_count = 0   # number of timestamps kept in last_five (capped at 5)
    last_five = []   # start times of the most recent items, oldest first

    def print_bar(i, n, prev_time):
        now = time.time()
        time_since_start = now - start
        time_since_last = now - prev_time

        if prev_time == 0 or i == total:
            projected_time = 0
        elif i > total:
            # More items than announced: any ETA would be negative/meaningless.
            projected_time = "?"
        else:
            projected_time_all = (time_since_start / i) * total - time_since_start

            if five_count >= 1:
                # Mean gap between consecutive recent item starts, with "now" as the last point.
                recent = last_five + [now]
                diffs = [later - earlier for earlier, later in zip(recent, recent[1:])]
                projected_time_fives = sum(diffs) / len(diffs) * (total - i)
            else:
                projected_time_fives = time_since_last * (total - i)

            projected_time = projected_time_all * 0.7 + projected_time_fives * 0.3

        # Guard the divisions: total may be 0, and on coarse clocks the elapsed time can still
        # read 0.0 at the first redraw (both used to raise ZeroDivisionError).
        percent = (i / total) * 100 if total else 100
        rate = i / time_since_start if time_since_start > 0 else 0
        bar = fill * n + " " * (width - n)
        print(f'\r{percent:.0f}%|{bar}| {i}/{total} [{format_time(time_since_start)}<{format_time(projected_time)},  {get_its_per_second_test(rate)}]', end=print_end, flush=True)

    print(" " * 100, end=print_end, flush=True)  # blank out whatever is on the current line
    print_bar(0, 0, 0)
    for i, item in enumerate(iterable):
        prev = time.time()

        last_five.append(prev)
        five_count += 1
        if five_count > 5:
            five_count -= 1
            last_five.pop(0)

        yield item

        done = i + 1
        # Clamp so the bar never grows past `width` when there are more items than `total`.
        filled = min(int(done / total * width), width) if total else width
        print_bar(done, filled, prev)

    print()


class Pbar:
    """Manually advanced, tqdm-like progress bar (``update`` / ``set_description``).

    The empty bar is printed on construction and redrawn on every ``update``. The ETA
    blends the overall average pace (70%) with the pace over the last few updates (30%).
    """

    def __init__(self, total, width=20, fill='█', print_end='\r'):
        self.start = time.time()
        self.total = total
        self.width = width
        self.fill = fill
        self.print_end = print_end
        self.five_count = 0   # recent-update counter used for the short-window ETA (capped at 5)
        self.last_five = []   # timestamps of the most recent updates, oldest first
        self.count = 0        # items completed so far
        self.desc = ""

        print(" " * 100, end=self.print_end, flush=True)  # blank out the current line
        self.print_bar(0, 0, 0)

        self.prev = time.time()
        self.last_five.append(self.prev)

    def print_bar(self, i, n, prev_time):
        """Redraw the bar showing *i* completed items with *n* filled cells."""
        now = time.time()
        time_since_start = now - self.start
        time_since_last = now - prev_time

        if prev_time == 0 or i == self.total:
            projected_time = 0
        elif i > self.total or i <= 0:
            # Overshot the announced total, or nothing done yet (e.g. update(0)):
            # no meaningful ETA, and i == 0 would divide by zero below.
            projected_time = "?"
        else:
            projected_time_all = (time_since_start / i) * self.total - time_since_start

            if self.five_count >= 1:
                # Mean gap between consecutive recent updates, with "now" as the last point.
                recent = self.last_five + [now]
                diffs = [later - earlier for earlier, later in zip(recent, recent[1:])]
                projected_time_fives = sum(diffs) / len(diffs) * (self.total - i)
            else:
                projected_time_fives = time_since_last * (self.total - i)

            projected_time = projected_time_all * 0.7 + projected_time_fives * 0.3

        prefix = f"{self.desc}:  " if self.desc else ""
        # Guard the divisions: total may be 0 and the elapsed time may still read 0.0
        # on coarse clocks at the first redraw (both used to raise ZeroDivisionError).
        percent = (i / self.total) * 100 if self.total else 100
        rate = i / time_since_start if time_since_start > 0 else 0
        bar = self.fill * n + " " * (self.width - n)
        print(f'\r{prefix}{percent:.0f}%|{bar}| {i}/{self.total} [{format_time(time_since_start)}<{format_time(projected_time)},  {get_its_per_second_test(rate)}]', end=self.print_end, flush=True)

    def update(self, n=1):
        """Advance the bar by *n* items (default 1, as in tqdm) and redraw it."""
        self.count += n

        if self.total:
            filled = min(int(self.count / self.total * self.width), self.width)
        else:
            filled = self.width
        self.print_bar(self.count, filled, self.prev)

        self.prev = time.time()
        self.last_five.append(self.prev)
        self.five_count += 1

        if self.five_count > 5:
            self.five_count -= 1
            self.last_five.pop(0)

    def set_description(self, desc):
        """Set a label printed before the bar on subsequent redraws."""
        self.desc = desc

def tqdm_(iterable=None, total=None):
    """Stand-in for the subset of ``tqdm`` used in this notebook.

    With an iterable: a generator yielding its items while drawing a bar.
    Without one: a manually updated ``Pbar`` sized by *total*.
    """
    return Pbar(total) if iterable is None else _tqdm(iterable, total)

### SAEs

In [None]:
def get_all_saes_df():
    """Return a DataFrame describing every pretrained SAE release known to sae_lens.

    Each row is a "release", which bundles several SAEs that may have different configs
    and hook points. A few bulky/internal columns are dropped.
    """
    release_attrs = {name: lookup.__dict__ for name, lookup in get_pretrained_saes_directory().items()}
    releases_df = pd.DataFrame.from_records(release_attrs).T
    return releases_df.drop(columns=["expected_var_explained", "expected_l0",
                                     "config_overrides", "conversion_func"])

def get_feature_api(modelId, saeId, feature, timeout=30):
    """
    Fetch one feature's data from the Neuronpedia API.

    NEURONPEDIA GET /api/feature/{modelId}/{saeId}/{index}
    https://www.neuronpedia.org/api-doc#tag/features/GET/api/feature/{modelId}/{layer}/{index}
    e.g. https://www.neuronpedia.org/api/feature/gpt2-small/0-res-jb/14057

    :param timeout: seconds to wait for the server; without one a stalled connection
                    blocks forever.
    :return: the decoded JSON body, or None when the request returns a non-2xx status
             or a body that is not JSON.
    """
    url = f"https://www.neuronpedia.org/api/feature/{modelId}/{saeId}/{feature}"
    response = requests.get(url, timeout=timeout)
    if not response.ok:
        logger.warning(f"Neuronpedia API returned {response.status_code} for {url}")
        return None
    try:
        return response.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        logger.warning(f"Neuronpedia API returned a non-JSON body for {url}")
        return None

### Feature & Activations

In [None]:
def size_to_int(size: str):
    """
    Convert an SAE width label such as "16k" to the power of two it denotes (16384).

    The numeric part is scaled by 1024 and rounded to the nearest power of two, so
    already-rounded labels like "65k" or "131k" map to 65536 and 131072.
    """
    assert size[-1] == 'k', "Invalid size string format"
    approx_width = int(size[:-1]) * 1024
    nearest_exponent = round(math.log2(approx_width))
    return 2 ** nearest_exponent

def write_pkl_file(data, file_name):
    """Pickle *data* into *file_name*, overwriting any existing file."""
    # The context manager guarantees the file is flushed and closed; the previous bare
    # open() left closing (and flushing) to garbage collection.
    with open(file_name, "wb") as f:
        pickle.dump(data, f)


def read_pkl_file(file_name):
    """Load and return the object pickled in *file_name*.

    NOTE: unpickling can execute arbitrary code -- only load files this project wrote.
    """
    with open(file_name, "rb") as f:
        return pickle.load(f)

In [None]:
@dataclass
class Activation:
    """One example text with the activation value of a feature on each token.

    ``token_values`` pairs every token with its activation; ``max_value_token_index``
    (optional) points at the peak token and is required by the window helpers below.
    """
    id: str
    token_values: List[Tuple[str, float]]
    max_value: float
    min_value: float
    max_value_token_index: int = None   # may be None, e.g. for hand-built records
    loss_values: List[float] = None

    def get_tokens(self) -> List[str]:
        """Return the tokens, in order."""
        tokens = [tv[0] for tv in self.token_values]
        return tokens

    def get_values(self) -> List[float]:
        """Return the activation values, in token order."""
        values = [tv[1] for tv in self.token_values]
        return values

    def get_tokens_str(self) -> str:
        """Return the full text (tokens concatenated)."""
        return "".join(self.get_tokens())

    def _max_window(self, w: int) -> Tuple[int, int]:
        """Return the [start, end) window around the peak token: up to *w* tokens before
        it and *w* - 1 after it, clipped to the sequence bounds."""
        start = max(self.max_value_token_index - w, 0)
        end = min(self.max_value_token_index + w, len(self.token_values))
        return start, end

    def get_max_token_values(self, w: int = 5):
        """Return the (token, value) pairs in the window around the peak token."""
        start, end = self._max_window(w)
        return self.token_values[start:end]

    def get_max_tokens_str(self, w: int = 5) -> str:
        """Return the text of the window around the peak token.

        w = 5 is the "stacked" view, a larger value gives a "snippet" (in neuronpedia's jargon).
        """
        # The start is clamped at 0: a negative slice start would wrap around to the end of
        # the sequence and return the wrong (often empty) text when the peak is near the start.
        start, end = self._max_window(w)
        return "".join(self.get_tokens()[start:end])

    def __repr__(self):
        return str(self)

    def __str__(self):
        if self.max_value_token_index is not None:
            return f"Max={self.token_values[self.max_value_token_index]}, Sentence={self.get_max_token_values()}"
        return f"Tokens={self.get_tokens()}"

    def __hash__(self):
        # Hash the readable form so (equal) activations can be de-duplicated in sets/dicts.
        return hash(str(self))

In [None]:
@dataclass
class Feature:
    """One SAE latent (or MLP neuron) of a model, optionally with example activations.

    ``str`` and ``hash`` are built from type, size, layer and feature index only.
    """
    model_id: str      # e.g. gemma-2-2b
    feature: int       # e.g. 1846
    layer: str         # e.g. 11 (some callers pass the layer as an int)
    type: str          # e.g. att, mlp, res
    activations: List[Activation] = None   # None when no activations were found

    sae_id: str = None       # from saes_df, e.g. layer_11/width_16k/average_l0_80
    sae_release: str = None  # from saes_df, e.g. gemma-scope-2b-pt-mlp
    size: str = ""           # e.g. 16k, 65k; "" when not applicable

    def get_pedia_dashboard_url(self, np_sae_id) -> str:
        """Return this feature's Neuronpedia dashboard URL for the given Neuronpedia SAE id."""
        return f"https://www.neuronpedia.org/{self.model_id}/{np_sae_id}/{self.feature}"

    def get_size_int(self):
        """Return the SAE width as an int (e.g. "16k" -> 16384), or 0 when ``size`` is empty."""
        if self.size != "":
            return size_to_int(self.size)
        return 0

    def get_max_activating_examples(self, k: int = 5) -> List[Activation]:
        """Return up to *k* distinct activations with the highest ``max_value``, best first.

        Returns [] when no activations are attached (``activations`` is None or empty).
        """
        if not self.activations:
            return []
        # dict.fromkeys de-duplicates exactly like set() (same hash/eq) but keeps first-seen
        # order, so ties in max_value no longer come out in an order that depends on the
        # per-process string-hash seed.
        unique_activations = list(dict.fromkeys(self.activations))
        return sorted(unique_activations, key=lambda x: x.max_value, reverse=True)[:k]

    def __str__(self):
        return f"Feature {self.type}-{self.size}/{self.layer}/{self.feature}"

    def __hash__(self):
        # Hash the readable identity so Features can be used in sets / as dict keys.
        return hash(str(self))

#### SAE activations

In [None]:
def get_sae_export_dir(modelId: str, saeId: str) -> str:
    """Return the local directory holding the Neuronpedia exports of one SAE.

    Layout: NEURONPEDIA_EXPORTS_DIR/<modelId>-<saeId minus its leading "<layer>-">/<saeId>,
    e.g. ("gemma-2-2b", "5-gemmascope-att-16k") ->
    neuronpedia_exports/gemma-2-2b-gemmascope-att-16k/5-gemmascope-att-16k
    """
    sae_without_layer = saeId.split('-', 1)[1]
    return os.path.join(NEURONPEDIA_EXPORTS_DIR, f"{modelId}-{sae_without_layer}", saeId)

def get_feature_export_json(model_export_dir: str, feature: int):
    """Find the export file in *model_export_dir* whose name covers *feature*.

    Export files are named "<start>-<end>.json" and cover features start <= feature < end.
    Names that do not parse as two integers separated by '-' are skipped.

    :return: the matching file name (not a full path), or None if no file covers *feature*.
    """
    for file in os.listdir(model_export_dir):
        try:
            # removesuffix drops the literal ".json"; rstrip('.json') would instead strip
            # any trailing run of the characters '.', 'j', 's', 'o', 'n'.
            start, end = map(int, file.removesuffix('.json').split('-'))
        except ValueError:
            continue  # Skip files that don't match the expected format
        if start <= feature < end:
            return file
    return None

def get_feature_json_index(json_file, feature: int) -> int:
    """Return *feature*'s offset within the export file named "<start>-<end>.json".

    The offset is ``feature - start``; callers index the file's JSON list with it, which
    assumes the records are stored in feature order beginning at ``start``.
    """
    # removesuffix removes exactly ".json" (rstrip would strip a character set).
    start, _end = map(int, json_file.removesuffix('.json').split('-'))
    return feature - start

def get_feature_json_data(json_path: str, feature: int):
    """
    Return the record for *feature* from the Neuronpedia export file at *json_path*.

    The file name must be "<start>-<end>.json"; the record is read from position
    ``feature - start`` of the file's top-level JSON list.
    """
    json_file = os.path.basename(json_path)
    json_feature_index = get_feature_json_index(json_file, feature)

    # Read as UTF-8 explicitly: the platform default encoding can differ (e.g. cp1252 on
    # Windows) and would mis-decode or reject non-ASCII tokens in the export.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    feature_data = data[json_feature_index]
    logger.trace(f"Got data for feature {feature} which is at index {json_feature_index} in the json {json_file}")
    return feature_data
    
def get_export_data(modelId: str, saeId: str, feature: int) -> Dict:
    """
    Return the record for the given feature from the local Neuronpedia exports.

    :return: the feature's export record, or {} when the SAE has no export directory
             or no json file in it covers the feature.
    """
    logger.trace(f"Getting export for {modelId}/{saeId}/{feature}")
    sae_export_dir = get_sae_export_dir(modelId, saeId)

    if os.path.exists(sae_export_dir):
        json_file = get_feature_export_json(sae_export_dir, feature)
        if json_file is not None:
            return get_feature_json_data(os.path.join(sae_export_dir, json_file), feature)
        logger.trace(f"No appropriate json file found for {feature} in {sae_export_dir}")
    else:
        logger.trace(f"No appropriate export dir found for {feature} in {sae_export_dir}")
    return {}

def json_to_activations(activations_data) -> List[Activation]:
    """Convert Neuronpedia activation records (dicts) into ``Activation`` objects.

    Each record must provide 'tokens', 'values', 'maxValue', 'minValue',
    'maxValueTokenIndex' and 'lossValues'; the Activation ``id`` is left empty.
    """
    activations = [
        Activation(
            '',
            list(zip(record['tokens'], record['values'])),
            record['maxValue'],
            record['minValue'],
            record['maxValueTokenIndex'],
            record['lossValues'],
        )
        for record in activations_data
    ]
    logger.trace(f"Retrieved {len(activations)} activations")
    return activations
    
def get_activations_data(modelId: str, saeId: str, feature: int) -> List[Activation]:
    """Return the feature's activations from the local exports, or [] if it has no export record."""
    export_record = get_export_data(modelId, saeId, feature)
    return json_to_activations(export_record['activations']) if export_record else []

def get_act_neuronpedia(modelId: str, saeId: str, feature: int) -> List[Activation]:
    """
    Return the activations for the feature, preferring local exports over the API.

    The local Neuronpedia exports are tried first; only when they yield nothing is the
    Neuronpedia feature API queried. Returns [] when neither source has activations.
    """
    result = get_activations_data(modelId, saeId, feature)
    if result:
        return result

    api_result = get_feature_api(modelId, saeId, feature)
    if api_result is None:
        return []

    activations_data = api_result.get('activations') if isinstance(api_result, dict) else None
    if activations_data is None:
        # Response without an 'activations' list (used to raise KeyError).
        logger.warning(f"No activations in feature API response for {modelId}/{saeId}/{feature}")
        return []

    result = json_to_activations(activations_data)
    logger.trace("Got activations data from feature api")
    return result

#### Transluce activations

In [None]:
def load_act_transluce():
    """Load the pickled Transluce activations from
    TRANSLUCE_EXPORTS_DIR/transluce_formatted_activations.pkl.

    Presumably a dict mapping (layer, feature) -> List[Activation], since that is how
    get_act_transluce and get_feature_sample_transluce index it -- TODO confirm.
    NOTE: unpickling executes code from the file; only load trusted files.
    """
    pkl_file = os.path.join(TRANSLUCE_EXPORTS_DIR, "transluce_formatted_activations.pkl")
    with open(pkl_file, "rb") as handle:
        return pickle.load(handle)
    

def get_act_transluce(layer: int, feature: int) -> List[Activation]:
    """Return the Transluce activations for (layer, feature), or [] if that pair is absent.

    Note: reloads the whole pickle on every call.
    """
    data = load_act_transluce()
    try:
        return data[(layer, feature)]
    except KeyError:
        # Narrowed from a bare `except:`, which also hid unrelated errors (and even
        # KeyboardInterrupt) behind an empty result.
        return []

### Models

In [None]:
@dataclass
class Model:
    """A model under study plus the settings needed to locate its SAEs and sample features.

    MODELS builds instances positionally, so the field order below is part of the interface.
    """
    # Matches the modelId in Neuronpedia API calls, e.g. gemma-2-2b. load_model() passes it
    # (or, for Llama, the mapped Hugging Face name) to HookedSAETransformer.from_pretrained();
    # see https://transformerlensorg.github.io/TransformerLens/generated/model_properties_table.html
    model_id: str

    # A friendlier, unique name: e.g. for the same model_id (gpt2-small) there are different
    # available SAEs (v5, jb), and this name distinguishes them.
    model_name: str

    # True: sample SAE features via Neuronpedia; False: use the Transluce MLP activations.
    with_sae: bool

    # Used to search for the relevant releases in the general df of available SAEs,
    # e.g. gemma-scope-2b-pt-{}, gpt2-small-{}-{}-v5-{}
    sae_release_prefix: str = None

    sae_sizes: List[str] = None     # e.g. [16k, 65k]
    sae_types: List[str] = None     # e.g. [mlp, res, att]
    saes_df: pd.DataFrame = None    # filled by get_saes_info_specific_layers()
    m: HookedSAETransformer = None  # filled lazily by load_model()

    def load_model(self):
        """Load the HookedSAETransformer into ``self.m`` (only if it is not loaded yet).

        Neuronpedia's Llama ids ('llama3.1-8b', 'llama3.1-8b-instruct') differ from the Hugging
        Face names ('meta-llama/Llama-3.1-8B', ...), so they are mapped for loading only;
        ``self.model_id`` keeps the Neuronpedia form for use elsewhere.
        """
        hf_names = {
            "llama3.1-8b": "meta-llama/Llama-3.1-8B",
            "llama3.1-8b-instruct": "meta-llama/Llama-3.1-8B-Instruct",
        }
        model_id = hf_names.get(self.model_id, self.model_id)

        if self.m is None:
            self.m = HookedSAETransformer.from_pretrained(model_id, device=device)

    def is_gemma(self):
        return self.model_id.startswith("gemma")

    def is_llama(self):
        return self.model_name == "llama3.1-8b"

    def is_gpt_jb(self):
        return self.model_name == "gpt2-small-jb"

    def is_gpt_v5(self):
        return self.model_name == "gpt2-small-v5"

    def get_size_int(self, size: str):
        """Return the number of features of an SAE of the given width label.

        Special case: gpt2-small-jb SAEs have d_model (768) * 32 = d_sae (24576). This is only
        used to draw a random feature in the right range; the Feature keeps its empty size string.
        """
        return 24576 if self.is_gpt_jb() else size_to_int(size)

    def get_pile_dataset(self):
        """Return the "tokens" column of NeelNanda/pile-10k tokenized with this model's tokenizer
        into 32-token rows (BOS prepended), excluding the last 1000 rows. Requires load_model()."""
        dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train")
        pile = tokenize_and_concatenate(dataset, self.m.tokenizer, streaming=False, max_length=32, column_name="text", add_bos_token=True, num_proc=4)
        pile = pile[:-1000]["tokens"]
        return pile

    def get_features_pkl_name(self):
        """Default pickle file name for this model's feature sample."""
        return f"features-sample-{self.model_name}.pkl"

    def write_features(self, data, out_file: str = None):
        """Pickle a list of features to *out_file* (default: get_features_pkl_name())."""
        file_name = out_file if out_file is not None else self.get_features_pkl_name()
        logger.info(f"Writing {len(data)} features to {file_name}")
        write_pkl_file(data, file_name)

    def get_features(self, in_file: str = None):
        """Load a pickled feature list from *in_file* (default: get_features_pkl_name())."""
        file_name = in_file if in_file is not None else self.get_features_pkl_name()
        logger.info(f"Getting features from {file_name}")
        return read_pkl_file(file_name)

    def get_specific_layers(self) -> List:
        """Return (type, layer, size) for every SAE type x size x layer of the loaded model.

        Requires ``self.m`` (call load_model() first) for the number of layers.
        """
        n_layers = self.m.cfg.n_layers
        return [(sae_type, layer, size)
                for sae_type in self.sae_types
                for size in self.sae_sizes
                for layer in range(n_layers)]

    # --------------- SAE ---------------
    def extract_type(self, saes_df_row):
        """Return the first of ``sae_types`` found in the row's release or np_id."""
        for t in self.sae_types:
            if t in saes_df_row.release or t in saes_df_row.np_id:
                return t
        raise ValueError(f"Found no valid type ({self.sae_types}) in given row {saes_df_row.release}")

    def extract_size(self, saes_df_row):
        """Return the first of ``sae_sizes`` found in the row's np_id."""
        for s in self.sae_sizes:
            if s in saes_df_row.np_id:
                return s
        raise ValueError(f"Found no valid size ({self.sae_sizes}) in given row {saes_df_row.id}")

    def get_np_sae_id(self, layer: str, type: str, size: str) -> str:
        """
        Extract the saeId for neuronpedia (from the previously loaded saes df)
        """
        if self.saes_df is None:
            raise ValueError("saes df does not exist, please call get_saes_info_specific_layers()")
        # e.g. for gemma: 5-gemmascope-att-16k, for gpt: 1-res_mid_128k-oai
        np_id = self.saes_df[(self.saes_df['layer'] == layer) &
                             (self.saes_df['width'] == size) &
                             ((self.saes_df['release'].str.contains(type)) | (self.saes_df['np_id'].str.contains(type)))]["np_id"]
        if np_id.empty:
            raise ValueError(f"requested sae_id {type}/{layer}/{size} does not exist in current saes_df! check types and sizes defined: {self.sae_types}, {self.sae_sizes}")
        sae_id = np_id.iloc[0]
        sae_id = sae_id.removeprefix(f"{self.model_id}/")  # this is the format which can be inferred from neuronpedia's conventions
        return sae_id

    def enrich_sae_ids(self, sae_ids_df):
        """Add 'layer' and 'width' columns (plus 'average_l0' for Gemma) parsed from the SAE ids."""
        if self.is_gemma():
            sae_ids_df['layer'] = sae_ids_df['id'].apply(lambda x: x.split('/')[0].split('_')[1])
            sae_ids_df['width'] = sae_ids_df['id'].apply(lambda x: x.split('/')[1].split('_')[1])
            sae_ids_df['average_l0'] = sae_ids_df['id'].apply(lambda x: x.split('/')[2])
        elif self.is_llama():
            sae_ids_df['layer'] = sae_ids_df['np_id'].apply(lambda x: x.split('/')[1].split('-')[0])
            sae_ids_df['width'] = sae_ids_df['np_id'].apply(lambda x: x.split('/')[1].split('-')[-1])
        elif self.is_gpt_v5():
            sae_ids_df['layer'] = sae_ids_df['id'].apply(lambda x: x.split('.')[1])
            sae_ids_df['width'] = sae_ids_df['release'].apply(lambda x: x.split('-')[-1])
        elif self.is_gpt_jb():
            sae_ids_df['layer'] = sae_ids_df['id'].apply(lambda x: x.split('.')[1])
            sae_ids_df['width'] = ""

        return sae_ids_df

    def get_sae_ids(self, all_saes_df):
        """
        Return one row per individual SAE of this model's relevant releases.

        Columns: id (sae_lens SAE id), release, np_id (Neuronpedia id), plus the columns
        added by enrich_sae_ids. Embedding SAEs are excluded.

        :raises ValueError: if this model has no release-selection rule.
        """
        releases = all_saes_df['release']
        # TODO: check using the "model" field - take all the SAEs where it matches self.model_id
        if self.is_gemma():
            # "canonical" is the sae used in neuronpedia (average_l0 close to 100 - https://huggingface.co/google/gemma-scope-2b-pt-mlp/blob/main/README.md)
            mask = releases.str.contains(self.sae_release_prefix) & releases.str.contains('canonical')
        elif self.is_gpt_v5():
            mask = releases.str.contains(self.sae_release_prefix) & releases.str.contains("v5")
        elif self.is_gpt_jb() or self.is_llama():
            mask = releases.str.contains(self.sae_release_prefix)
        else:
            # Previously fell through and failed with an UnboundLocalError.
            raise ValueError(f"No SAE release selection rule for model {self.model_name}")
        saes_map = all_saes_df[mask][["saes_map", "neuronpedia_id"]]

        frames = []
        for release in saes_map.index:
            ids = list(saes_map.loc[release]['saes_map'].keys())
            neuronpedia_ids = saes_map.loc[release]['neuronpedia_id']  # the neuronpedia sae_id for API usage
            # Look each Neuronpedia id up by SAE id instead of assuming both dicts iterate
            # in the same order.
            np_ids = [neuronpedia_ids.get(sae_id) for sae_id in ids]
            frames.append(pd.DataFrame({'id': ids, 'np_id': np_ids, 'release': release}))

        # One concat instead of growing a frame (seeded with an empty one) in the loop;
        # the column order matches what the incremental version produced.
        columns = ['id', 'release', 'np_id']
        sae_ids_df = pd.concat(frames)[columns] if frames else pd.DataFrame(columns=columns)

        # problem splitting "layer_n" later because in this case the string is simply "embedding", so we omit these rows.
        # .copy() so enrich_sae_ids writes to a real frame, not a view of a slice.
        sae_ids_df = sae_ids_df[~(sae_ids_df['id'].str.contains('embedding'))].copy()
        return self.enrich_sae_ids(sae_ids_df)

    def get_saes_info_specific_layers(self, specific_layers):
        """
        Create the main saes df for this model (``self.saes_df``): one row per requested SAE.

        :param specific_layers: iterable of (type, layer, size), e.g. from get_specific_layers().
        :raises ValueError: if a requested SAE is not among this model's releases.
        """
        all_saes_df = get_all_saes_df()
        sae_ids_df = self.get_sae_ids(all_saes_df)
        sae_ids = []
        for layer_type, layer_num, layer_width in specific_layers:
            matches = sae_ids_df[(sae_ids_df['layer'] == str(layer_num)) &
                                 (sae_ids_df['width'] == str(layer_width)) &
                                 # release - for gemma, gpt. np_id - for llama (because release doesn't contain full type name)
                                 ((sae_ids_df['release'].str.contains(layer_type)) | (sae_ids_df['np_id'].str.contains(layer_type)))]
            if matches.empty:
                raise ValueError(f"No SAE found for {layer_type}/{layer_num}/{layer_width} in {self.model_name}'s releases")
            sae_ids.append(matches.iloc[0]['np_id'])
        # using np_id because apparently id can be sometimes not unique! (e.g. layer_11/width_16k/average_l0_80 exists both for ATT and RES!)
        self.saes_df = sae_ids_df[sae_ids_df['np_id'].isin(sae_ids)].reset_index(drop=True)

    def get_feature_sample_sae(self, amount):
        """
        Sample random features from every SAE in ``self.saes_df``.

        :param amount: The amount of features from each SAE of the model to sample (at most 100).
        :return: list of Feature; activations are attached when Neuronpedia has them, else None.
        """
        all_features_sample = []
        per_sae = max(0, min(amount, 100))  # the old attempts guard capped each SAE at 100 draws
        logger.info(f"Getting {amount} features per layer ({len(self.saes_df)}) for {self.model_name}")
        for i in tqdm_(range(len(self.saes_df))):
            x = self.saes_df.iloc[i]
            size = self.extract_size(x)
            sae_type = self.extract_type(x)
            # Same for every draw from this SAE, so computed once.
            size_int = self.get_size_int(size)
            np_sae_id = self.get_np_sae_id(x.layer, sae_type, size)

            for _ in range(per_sae):
                # Drawn with replacement, so the same feature can appear more than once.
                rand_feature = torch.randint(size_int, (1,))[0].item()
                activations = get_act_neuronpedia(self.model_id, np_sae_id, rand_feature)
                new_f = Feature(self.model_id, rand_feature, x.layer, sae_type, activations or None,
                                sae_id=x.id, sae_release=x.release, size=size)
                logger.trace(f"Adding new feature {repr(new_f)}")
                all_features_sample.append(new_f)

            logger.trace(f"Added {per_sae} features, from {sae_type}.{x.layer}.{size} using {x.release} SAE")

        # Previously missing: callers such as write_feature_sample received None.
        return all_features_sample

    def get_feature_sample_no_sae(self, amount, type="mlp"):
        """
        Sample random MLP neurons from every layer and attach their Transluce activations.

        :param amount: The amount of features from each layer of the model to sample
                       (distinct within a layer).
        :param type: layer type recorded on each Feature.
        """
        all_features_sample = []
        logger.info(f"Getting {amount} features per layer ({self.m.cfg.n_layers}) for {self.model_name}")
        transluce_data = load_act_transluce()  # load the pickle once, not once per feature
        for layer in tqdm_(range(self.m.cfg.n_layers)):
            sample = random.sample(range(self.m.cfg.d_mlp), amount)
            for rand_feature in sample:
                activations = transluce_data.get((layer, rand_feature))
                new_f = Feature(self.model_id, rand_feature, layer, type, activations or None)
                logger.trace(f"Adding new feature {repr(new_f)}")
                all_features_sample.append(new_f)

            # Previously logged a counter that was never incremented (always 0).
            logger.trace(f"Added {len(sample)} features, from {type}.{layer}")

        return all_features_sample

    def get_feature_sample_transluce(self):
        """Return one "mlp" Feature per (layer, feature) entry of the Transluce activations pickle."""
        all_features_sample = []
        # TODO: consider creating a new activations .pkl with 80 features per layer and not 40
        data = load_act_transluce()
        logger.info(f"Getting {len(data)} features for {self.model_name}")

        for (layer, feature), activations in data.items():
            new_f = Feature(self.model_id, feature, layer, "mlp", activations)
            logger.trace(f"Adding new feature {repr(new_f)}")
            all_features_sample.append(new_f)

        return all_features_sample

    def write_feature_sample(self, write=False, out_file=None):
        """
        Exports a sample of features to a pickle file.
        :param write: Whether to actually write the data to the file (think carefully before writing).
        :param out_file: A custom output file path instead of the default one.
        :return: the (shuffled) feature sample, whether or not it was written.
        """
        if self.with_sae:
            all_features_sample = self.get_feature_sample_sae(amount=40)
        else:
            all_features_sample = self.get_feature_sample_transluce()

        random.shuffle(all_features_sample)
        if write:
            self.write_features(all_features_sample, out_file=out_file)

        return all_features_sample

In [None]:
# Registry of supported models, keyed by a short display name.
# with_sae=False models take their features from the Transluce activations instead of SAEs.
MODELS = {
    "GEMMA-2-2B": Model(model_id="gemma-2-2b", model_name="gemma-2-2b", with_sae=True,
                        sae_release_prefix="gemma-scope-2b-pt-",
                        sae_sizes=["16k", "65k"], sae_types=["res", "mlp"]),
    "GEMMA-2-9B": Model(model_id="gemma-2-9b", model_name="gemma-2-9b", with_sae=True,
                        sae_release_prefix="gemma-scope-9b-pt-",
                        sae_sizes=["16k", "131k"], sae_types=["res", "mlp"]),
    "LLAMA-3.1-8B": Model(model_id="llama3.1-8b", model_name="llama3.1-8b", with_sae=True,
                          sae_release_prefix="llama_scope_lx",
                          sae_sizes=["32k"], sae_types=["res", "mlp"]),
    "GPT-2-SM-V5": Model(model_id="gpt2-small", model_name="gpt2-small-v5", with_sae=True,
                         sae_release_prefix="gpt2-small-",
                         sae_sizes=["32k", "128k"], sae_types=["resid-mid", "resid-post", "mlp-out"]),
    "GPT-2-SM-JB": Model(model_id="gpt2-small", model_name="gpt2-small-jb", with_sae=True,
                         sae_release_prefix="gpt2-small-res-jb",
                         sae_sizes=[""], sae_types=["res"]),
    "LLAMA-3.1-8B-IN": Model(model_id="llama3.1-8b-instruct", model_name="llama3.1-8b-instruct",
                             with_sae=False),
}

### Explainer

In [None]:
class Explainer:
    """Sends chat-style prompts to an LLM and returns the text of its reply.

    Backends: OpenAI chat completions (``remote=True``, default), Gemini
    (``remote=True, gemini=True``) or a local Hugging Face text-generation
    pipeline (``remote=False``).
    """

    def __init__(self, remote: bool, default_remote_model="gpt-4o", weak_remote_model="gpt-4o-mini", default_local_model="meta-llama/Meta-Llama-3-70B-Instruct", gemini=False):
        self.remote = remote
        self.weak_remote_model = weak_remote_model
        self.gemini = gemini
        if remote:
            if gemini:
                gai.configure(api_key=GAI_KEY)
                self.remote_model = gai.GenerativeModel("models/gemini-1.5-pro-latest")
            else:
                # No api_key argument: the OpenAI key is expected in the environment.
                self.client = OpenAI()
                self.remote_model = default_remote_model
        else:
            self.local_model = pipeline("text-generation", model=default_local_model, device_map="auto", max_length=10000)
            # Reuse EOS as the pad token. Presumably because the tokenizer defines no pad token
            # (generation then warns/complains) -- TODO confirm that is the only reason.
            self.local_model.tokenizer.pad_token_id = self.local_model.tokenizer.eos_token_id

    def __call__(self, prompts: List[Dict], weak=False):
        """Send *prompts* (a list of {"role": ..., "content": ...} messages) and return the reply text.

        :param weak: use the cheaper ``weak_remote_model`` (OpenAI backend). The OpenAI backend
                     currently requires ``weak=True`` and raises AssertionError otherwise.
        """
        if self.remote:
            if self.gemini:
                # Gemini receives the message contents flattened into one prompt; roles are dropped.
                prompt = ""
                for p in prompts:
                    prompt += p["content"] + "\n"
                return self.remote_model.generate_content(prompt).text

            assert weak, "Only weak mode is supported for OpenAI API"
            remote_model = self.remote_model if not weak else self.weak_remote_model
            completion = self.client.chat.completions.create(model=remote_model, messages=prompts)
            return completion.choices[0].message.content

        # With chat-format input the pipeline returns the whole conversation with the model's
        # reply appended, so the reply is the LAST message. Index [1] returned the user prompt
        # whenever a system message came first (as get_description sends).
        return self.local_model(prompts)[0]["generated_text"][-1]["content"]


# Module-level explainers used by the rest of the notebook: an OpenAI-backed one (used by
# get_description) and a Gemini-backed one (configured with GAI_KEY).
explainer = Explainer(remote=True)
gemini_explainer = Explainer(remote=True, gemini=True)

In [None]:
def get_description(sys_prompt, user_prompt, weak=True) -> str:
    """Ask the module-level ``explainer`` for a description and return it.

    If the reply is JSON (optionally wrapped in a ```json fence) with an "Explanation"
    key, that value is returned; otherwise the raw reply is returned unchanged.
    """
    explanation = explainer([
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": user_prompt}
    ], weak=weak)

    try:
        # Peel off a ```json ... ``` fence if present. str.strip("json") strips any of the
        # characters j/s/o/n at both ends; it only affects the text we try to parse.
        json_content = explanation.strip("json").strip('`').removeprefix('json\n').removesuffix('\n')
        j = json.loads(json_content)
        explanation = j["Explanation"]
    except (ValueError, KeyError, TypeError, AttributeError):
        # Best effort: keep the raw reply when it is not JSON of the expected shape.
        # JSONDecodeError is a ValueError; TypeError covers non-object JSON; AttributeError
        # covers a reply that is not a string (e.g. None). Narrowed from a bare `except:`,
        # which also swallowed KeyboardInterrupt.
        pass

    return explanation

## Description Generation

### MaxAct

In [None]:
@dataclass
class Example:
    """Activation records paired with an explanation written for them; presumably used as a
    few-shot demonstration in the MaxAct explanation prompts (e.g. example1) -- TODO confirm."""
    activation_records: List[Activation]  # example texts with per-token activation values
    explanation: str                      # the explanation text for these records
    
# Few-shot example 1: the largest activations fall on "-ing" word endings
# (" disappearing" 2.83, "earing" 4.24, "aping" 4.12).
# max_value / min_value / max_value_token_index are hand-written and must stay in
# sync with token_values (normalize_activations scales by max_value).
example1 = Example(
    activation_records=[
        Activation(
            id='',
            token_values=[
                (token, activation)
                for token, activation in zip(
                    [
                        "t", "urt", "ur", "ro", " is", " fab", "ulously", " funny",
                        " and", " over", " the", " top", " as", " a", " '", "very",
                        " sneaky", "'", " but", "ler", " who", " excel", "s", " in",
                        " the", " art", " of", " impossible", " disappearing", "/",
                        "re", "app", "earing", " acts"
                    ],
                    [
                        -0.71, -1.85, -2.39, -2.58, -1.34, -1.92, -1.69, -0.84,
                        -1.25, -1.75, -1.42, -1.47, -1.51, -0.8, -1.89, -1.56,
                        -1.63, 0.44, -1.87, -2.55, -2.09, -1.76, -1.33, -0.88,
                        -1.63, -2.39, -2.63, -0.99, 2.83, -1.11, -1.19, -1.33,
                        4.24, -1.51
                    ],
                )
            ],
            max_value=4.24,
            min_value=-2.63,
            max_value_token_index=32
        ),
        Activation(
            id='',
            token_values=[
                (token, activation)
                for token, activation in zip(
                    [
                        "esc", "aping", " the", " studio", " ,", " pic", "col",
                        "i", " is", " warm", "ly", " affecting", " and", " so",
                        " is", " this", " ad", "roit", "ly", " minimalist", " movie",
                        " ."
                    ],
                    [
                        -0.69, 4.12, 1.83, -2.28, -0.28, -0.79, -2.2, -2.03,
                        -1.77, -1.71, -2.44, 1.6, -1, -0.38, -1.93, -2.09,
                        -1.63, -1.94, -1.82, -1.64, -1.32, -1.92
                    ],
                )
            ],
            max_value=4.12,
            min_value=-2.44,
            max_value_token_index=1
        ),
    ],
    explanation="present tense verbs ending in 'ing'")
    
# Few-shot example 2: a single strong activation per record
# (" arrest" 4.46 in "cardiac arrest", "igo" 4.37 in "vertigo").
# max_value / min_value / max_value_token_index are hand-written and must stay in
# sync with token_values.
example2 = Example(
    activation_records=[
        Activation(
            id='',
            token_values=[
                ("as", -0.14),
                (" sac", -1.37),
                ("char", -0.68),
                ("ine", -2.27),
                (" movies", -1.46),
                (" go", -1.11),
                (" ,", -0.9),
                (" this", -2.48),
                (" is", -2.07),
                (" likely", -3.49),
                (" to", -2.16),
                (" cause", -1.79),
                (" massive", -0.23),
                (" cardiac", -0.04),
                (" arrest", 4.46),
                (" if", -1.02),
                (" taken", -2.26),
                (" in", -2.95),
                (" large", -1.49),
                (" doses", -1.46),
                (" .", -0.6),
            ],
            max_value=4.46,
            min_value=-3.49,
            max_value_token_index=14,
        ),
        Activation(
            id='',
            token_values=[
                ("shot", -0.09),
                (" perhaps", -3.53),
                ("'", -0.72),
                ("art", -2.36),
                ("istically", -1.05),
                ("'", -1.12),
                (" with", -2.49),
                ("handheld", -2.14),
                (" cameras", -1.98),
                (" and", -1.59),
                (" apparently", -2.62),
                (" no", -2),
                (" movie", -2.73),
                (" lights", -2.87),
                (" by", -3.23),
                (" jo", -1.11),
                ("aquin", -2.23),
                (" b", -0.97),
                ("aca", -2.28),
                ("-", -2.37),
                ("as", -1.5),
                ("ay", -2.81),
                (" ,", -1.73),
                (" the", -3.14),
                (" low", -2.61),
                ("-", -1.7),
                ("budget", -3.08),
                (" production", -4),
                (" swings", -0.71),
                (" annoy", -2.48),
                ("ingly", -1.39),
                (" between", -1.96),
                (" vert", -1.09),
                ("igo", 4.37),
                (" and", -0.74),
                (" opacity", -0.5),
                (" .", -0.62),
            ],
            max_value=4.37,
            min_value=-4,
            max_value_token_index=33,
        ),
    ],
    explanation="words related to physical medical conditions",
)

# Few-shot example 3: activations spread across multi-token phrases
# ("together"/"ness", "we 're all in this together").
# Here more than 20% of normalized values are non-zero, so
# add_per_neuron_explanation_prompt does not add the "zeros filtered out" section.
example3 = Example(
    activation_records=[
        Activation(
            id='',
            token_values=[
                ("the", 0),
                (" sense", 0),
                (" of", 0),
                (" together", 1),
                ("ness", 2),
                (" in", 0),
                (" our", 0.23),
                (" town", 0.5),
                (" is", 0),
                (" strong", 0),
                (" .", 0),
            ],
            max_value=2,
            min_value=0,
            max_value_token_index=4,
        ),
        Activation(
            id='',
            token_values=[
                ("a", -0.15),
                (" buoy", -2.33),
                ("ant", -1.4),
                (" romantic", -2.17),
                (" comedy", -2.53),
                (" about", -0.85),
                (" friendship", 0.23),
                (",", -1.89),
                (" love", 0.09),
                (",", -0.47),
                (" and", -0.5),
                (" the", -0.58),
                (" truth", -0.87),
                (" that", 0.22),
                (" we", 0.58),
                ("'re", 1.34),
                (" all", 0.98),
                (" in", 2.21),
                (" this", 2.84),
                (" together", 1.7),
                (" .", -0.89),
            ],
            max_value=2.84,
            min_value=-2.53,
            max_value_token_index=18,
        ),
    ],
    explanation="phrases related to community",
)

In [None]:
# Sentence opener appended after "Explanation of neuron N behavior:" in MaxAct
# prompts; few-shot examples complete it with their explanation, and the user
# prompt ends with it so the model continues the sentence.
MAX_ACT_DESCRIPTION_PREFIX = "the main thing this neuron does is find"

def relu(x: float) -> float:
    """Clamp ``x`` from below at zero (returns ``0.0`` for non-positive input)."""
    return x if x > 0.0 else 0.0

def normalize_activations(activation_record: List[float], max_activation: float) -> List[int]:
    """Bucket raw activations into integers 0..10 relative to ``max_activation``.

    Negative values count as "resting" and map to 0 (a simplifying assumption
    suited to relu/gelu-style units). Values above ``max_activation`` are
    capped at 10. A non-positive ``max_activation`` yields all zeros.
    """
    if max_activation <= 0:
        return [0 for _ in activation_record]
    buckets = []
    for raw in activation_record:
        level = math.floor(10 * relu(raw) / max_activation)
        buckets.append(min(10, level))
    return buckets

def format_activation_record(activation_record: Activation, omit_zeros: bool) -> str:
    """Render one record as newline-separated ``token<TAB>level`` lines.

    Levels are the record's values normalized to 0..10 against its own
    ``max_value``. With ``omit_zeros`` only tokens with a non-zero level are kept.
    """
    record_tokens = activation_record.get_tokens()
    levels = normalize_activations(activation_record.get_values(), activation_record.max_value)
    if omit_zeros:
        record_tokens = [tok for tok, lvl in zip(record_tokens, levels) if lvl > 0]
        levels = [lvl for lvl in levels if lvl > 0]
    assert len(record_tokens) == len(levels)
    return "\n".join(f"{tok}\t{str(int(lvl))}" for tok, lvl in zip(record_tokens, levels))

def format_activation_records(activation_records: List[Activation], omit_zeros: bool) -> str:
    """Render several records, each wrapped in ``<start>`` / ``<end>`` marker lines.

    The result begins and ends with a newline so it can be appended directly
    after a header such as ``"Activations:"``.
    """
    rendered = (
        format_activation_record(record, omit_zeros=omit_zeros)
        for record in activation_records
    )
    body = "\n<end>\n<start>\n".join(rendered)
    return "\n<start>\n" + body + "\n<end>\n"
    
def non_zero_activation_proportion(activation_records: List[Activation]) -> float:
    """Return the fraction of tokens whose normalized (0..10) activation is non-zero.

    Each record is bucketed with ``normalize_activations`` against its own
    ``max_value``, so this measures sparsity after normalization, not of the raw
    values.

    Returns:
        A value in [0, 1]. Returns 0.0 when the records contain no tokens at all;
        this case previously raised ZeroDivisionError.
    """
    total_activations_count = sum(
        len(activation_record.get_values()) for activation_record in activation_records
    )
    if total_activations_count == 0:
        return 0.0
    non_zero_activations_count = sum(
        sum(
            1
            for x in normalize_activations(activation_record.get_values(), activation_record.max_value)
            if x != 0
        )
        for activation_record in activation_records
    )
    return non_zero_activations_count / total_activations_count

def add_per_neuron_explanation_prompt(
    activation_records: List[Activation],
    index: int,
    repeat_non_zero_activations: bool = True,
    numbered_list_of_n_explanations: int = None,
    explanation: str = ''):
    """Render one neuron's MaxAct prompt section with a 1-based "Neuron N" header.

    Args:
        activation_records: records to show; each is normalized to 0..10 on its own.
        index: 0-based neuron index, displayed as ``index + 1``.
        repeat_non_zero_activations: when True and fewer than 20% of normalized
            activations are non-zero, append a second copy of the records with
            zero-level tokens removed.
        numbered_list_of_n_explanations: when not None, the "Explanation of
            neuron N behavior:" line and description prefix are omitted. The
            value itself is not otherwise used here.
        explanation: optional reference explanation appended with a trailing
            period (used for the few-shot examples).

    Returns:
        The prompt section as a single string.
    """
    neuron_number = index + 1
    pieces = [
        f"""Neuron {neuron_number}
    Activations:{format_activation_records(activation_records, omit_zeros=False)}"""
    ]

    # Repeat just the non-zero entries when asked to and when activations are sparse.
    if repeat_non_zero_activations and non_zero_activation_proportion(activation_records) < 0.2:
        pieces.append(
            f"\nSame activations, but with all zeros filtered out:"
            f"{format_activation_records(activation_records, omit_zeros=True)}"
        )

    if numbered_list_of_n_explanations is None:
        pieces.append(f"\nExplanation of neuron {neuron_number} behavior:")
        pieces.append(f" {MAX_ACT_DESCRIPTION_PREFIX}")

    if explanation != '':
        pieces.append(f" {explanation}.")

    return "".join(pieces)

In [None]:
# Few-shot examples rendered (in this order, as Neuron 1..3) into the MaxAct
# system prompt by build_sys_prompt.
MAX_ACT_FEW_SHOT_EXAMPLES: List[Example] = [example1, example2, example3]

# Instruction header of the MaxAct system prompt; the few-shot examples are
# appended after it by build_sys_prompt.
MAX_ACT_BASE_SYS_PROMPT = (
    "We're studying neurons in a neural network. Each neuron looks for some particular "
    "thing in a short document. Look at the parts of the document the neuron activates for "
    "and summarize in a single sentence what the neuron is looking for. Don't list "
    "examples of words.\n\nThe activation format is token<tab>activation. Activation "
    "values range from 0 to 10. A neuron finding what it's looking for is represented by a "
    "non-zero activation value. The higher the activation value, the stronger the match.\n"
)

def generate_max_act_user_prompt(f: Feature, activating_examples: List[Activation] = None):
    """Build the MaxAct user prompt for ``f``, rendered as "Neuron 1" with no explanation.

    ``activating_examples`` overrides ``f.get_max_activating_examples()``; for
    Transluce neurons the caller currently passes the max-activating records in
    directly.
    """
    if activating_examples is None:
        activating_examples = f.get_max_activating_examples()
    return add_per_neuron_explanation_prompt(activation_records=activating_examples, index=0)

def build_sys_prompt():
    """Assemble the MaxAct system prompt: base instructions, then every few-shot example.

    Each example is rendered with its reference explanation and followed by a newline.
    """
    rendered_examples = [
        add_per_neuron_explanation_prompt(
            activation_records=example.activation_records,
            index=position,
            explanation=example.explanation,
        ) + "\n"
        for position, example in enumerate(MAX_ACT_FEW_SHOT_EXAMPLES)
    ]
    return MAX_ACT_BASE_SYS_PROMPT + "".join(rendered_examples)

# Full MaxAct system prompt, built once when this cell runs. Re-run the cell after
# editing the few-shot examples or the base prompt.
MAX_ACT_SYS_PROMPT = build_sys_prompt()

### VocabProj

In [None]:
# Few-shot system prompt for the VocabProj explainer: three example vectors, each
# with two token lists and a reference explanation.
VOCAB_PROJ_SYS_PROMPT = "You will be given a list of tokens related to a specific vector. These tokens represent a combination of embeddings that reconstruct the vector. Your task is to infer the most likely meaning or function of the vector based on these tokens. The list may include noise, such as unrelated terms, symbols, or programming jargon. Ignore whether the words are in multiple different languages, and do not mention it in your response. Focus on identifying a cohesive theme or concept shared by the most relevant tokens. Provide a specific sentence summarizing the meaning or function of the vector. Answer only with the summary. Avoid generic or overly broad answers, and disregard any noise in the list.\nVector 1\n    Tokens: ['contentLoaded', '▁hObject', ':✨', '▁AssemblyCulture', 'ContentAsync', '▁ivelany', '▁nahilalakip', 'IUrlHelper', '▁تضيفلها', '▁ErrIntOverflow'] ['▁could', 'could', '▁Could', 'Could', '▁COULD', '▁podría', '▁könnte', '▁podrían', '▁poderia', '▁könnten']\nExplanation of vector 1 behavior: this vector is related to the word could.\nVector 2\n    Tokens: ['▁CreateTagHelper', '▁ldc', 'PropertyChanging', '▁jsPsych', 'ulement', '▁IBOutlet', '▁wireType', '▁initComponents', '▁متعلقه', 'Бахар'] ['▁مشين', '▁charity', '▁donation', '▁charitable', '▁volont', '▁donations', 'iNdEx', 'Parcelize', 'DatabaseError', 'BufferException']\nExplanation of vector 2 behavior: this vector is related to charity and donations.\nVector 3\n    Tokens: ['▁tomorrow', '▁tonight', '▁yesterday', '▁today', 'yesterday', 'tomorrow', '▁demain', '▁Tomorrow', 'Tomorrow', '▁Yesterday'] ['▁Wex', 'ကိုးကား', 'Ārējās', 'piecze', ')$/,', '▁außer', '[]=$', 'cendental', 'ɜ', 'aderie']\nExplanation of vector 3 behavior: this vector is related to specific dates, like tomorrow, tonight and yesterday.\n\n"
# User-turn template; positional field {0} is filled with the feature's tokens
# (presumably the same two-list layout as the examples above — confirm at the caller).
# NOTE(review): "VOACB" is a misspelling of "VOCAB"; the name is kept because other
# cells may refer to it.
VOACB_PROJ_USER_PROMPT = "Vector 4\n    Tokens: {0}\nExplanation of vector 4 behavior: this vector is related to"

def get_projection_data(m: Model, f: Feature, sae_w=None, encode=False, embed=False, k=50):
    """Project a feature direction into vocabulary space and rank tokens by score.

    Args:
        m: wrapper whose ``.m`` is the underlying transformer-lens style model.
        f: feature to project; ``f.layer`` / ``f.feature`` select the direction.
        sae_w: optional SAE weight matrix. When None, row ``f.feature`` of layer
            ``f.layer``'s MLP ``W_out`` is used instead.
        encode: take column ``f.feature`` of ``sae_w`` (encoder layout, per the
            flag name) instead of row ``f.feature``.
        embed: score against ``W_E`` (dot product) instead of passing the
            direction through ``ln_final`` and ``unembed``.
        k: number of highest / lowest tokens; ``2 * k`` tokens by absolute score.

    Returns:
        ``((logits, topk, bottomk, abs_topk), (top_tokens, bottom_tokens, top_abs_tokens))``.
    """
    model = m.m
    if sae_w is None:
        direction = model.blocks[f.layer].mlp.W_out[f.feature]
        logits = model.unembed(model.ln_final(direction))
    else:
        direction = sae_w[:, f.feature] if encode else sae_w[f.feature]
        if embed:
            logits = direction @ model.embed.W_E.T
        else:
            logits = model.unembed(model.ln_final(direction))

    highest = logits.topk(k)
    lowest = logits.topk(k, largest=False)
    strongest = logits.abs().topk(k * 2)
    token_lists = tuple(model.to_str_tokens(ranked.indices) for ranked in (highest, lowest, strongest))
    return (logits, highest, lowest, strongest), token_lists

def get_dec_unembed(m: Model, f: Feature, sae=None):
    """Return ``(top_tokens, bottom_tokens)`` of a feature's output direction through the unembedding.

    With an SAE, row ``f.feature`` of ``sae.W_dec`` is used (cast to the
    unembedding dtype when they differ, e.g. a bfloat16 SAE on a float32 model).
    Without one, the MLP ``W_out`` row is used (see get_projection_data).
    """
    decoder = None
    if sae is not None:
        target_dtype = m.m.W_U.dtype
        decoder = sae.W_dec if sae.W_dec.dtype == target_dtype else sae.W_dec.to(target_dtype)

    # Only the top/bottom token lists are used; the |logit| list is discarded.
    _, (dec_top_tokens, dec_bottom_tokens, _) = get_projection_data(m, f, decoder)
    return (dec_top_tokens, dec_bottom_tokens)

### TokenChange

In [None]:
# System prompt for the TokenChange explainer (zero-shot). The user turn is the
# token list produced by generate_token_change_user_prompt.
TOKEN_CHANGE_SYS_PROMPT = "You will be given a list of tokens related to a feature in an LLM. These tokens are the ones whose probabilities changed most when amplifying the feature. Your task is to infer the most likely meaning or function of the feature based on these tokens. The list may include noise, such as unrelated terms, symbols, or programming jargon. Provide a specific sentence summarizing the meaning or function of the feature. Answer only with the summary. Avoid generic or overly broad answers."

def generate_token_change_user_prompt(real):
    """Build the TokenChange user prompt: the distinct tokens of ``real`` as a Python list literal.

    Duplicates are dropped while keeping first-occurrence order, so the prompt is
    reproducible across runs.
    """
    # dict.fromkeys keeps insertion order. A set would order strings by their hash,
    # which Python randomizes per process, making the same feature yield
    # differently ordered prompts from run to run.
    return str(list(dict.fromkeys(real)))

In [None]:
def set_feature_act_hook(act, hook, feature, value):
    """Forward hook that overwrites channel ``feature`` (third axis) of ``act`` with ``value``.

    Every batch element and position is set. The write is in place and the hook
    returns None, so the mutated tensor is what flows onward. ``hook`` is unused.
    """
    act[:, :, feature] = value
    return None

def get_intervention_tokens(model: HookedSAETransformer, prompts, f: Feature, value=200, sae=None):
    """Return the tokens whose logits shift most when feature ``f`` is clamped to +/- ``value``.

    Runs one clean pass and two intervened passes. In these passes
    ``set_feature_act_hook`` overwrites the feature's activation at every
    position with ``+value`` and ``-value``. Logit shifts are averaged over batch
    and position, giving one score per vocabulary entry.

    With ``sae=None`` (Transluce neurons) the hook targets
    ``blocks.{layer}.mlp.hook_post``. Otherwise it targets the SAE's
    ``hook_sae_acts_post``.

    Returns:
        Four groups of 10 tokens, concatenated in this order:
        - bottom-10 of the +value shift;
        - top-10 of the -value shift (tokens the feature suppresses);
        - top-10 of the +value shift;
        - bottom-10 of the -value shift (tokens it promotes).
    """
    def _clamp_hooks(hook_name, clamp_value):
        return [(hook_name, functools.partial(set_feature_act_hook, feature=f.feature, value=clamp_value))]

    if sae is None:  # transluce
        hook_name = f"blocks.{f.layer}.mlp.hook_post"
        clean_logits = model(prompts)
        pos_inter_logits = model.run_with_hooks(prompts, fwd_hooks=_clamp_hooks(hook_name, value))
        neg_inter_logits = model.run_with_hooks(prompts, fwd_hooks=_clamp_hooks(hook_name, -value))
    else:
        hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
        clean_logits = model.run_with_saes(prompts, saes=[sae])
        pos_inter_logits = model.run_with_hooks_with_saes(prompts, saes=[sae], fwd_hooks=_clamp_hooks(hook_name, value))
        neg_inter_logits = model.run_with_hooks_with_saes(prompts, saes=[sae], fwd_hooks=_clamp_hooks(hook_name, -value))

    # Mean over batch, then over position -> one logit shift per vocab entry.
    pos_shift = (pos_inter_logits - clean_logits).mean(dim=0).mean(dim=0)
    neg_shift = (neg_inter_logits - clean_logits).mean(dim=0).mean(dim=0)

    to_str = model.to_str_tokens
    promoted = to_str(pos_shift.topk(10).indices) + to_str(neg_shift.topk(10, largest=False).indices)
    suppressed = to_str(pos_shift.topk(10, largest=False).indices) + to_str(neg_shift.topk(10).indices)

    return suppressed + promoted

def get_causal_data(m: Model, f: Feature, sae=None):
    """Collect TokenChange tokens for ``f`` from 32 randomly sampled pile sequences.

    Sequences are drawn with replacement via ``torch.randint`` (unseeded here) and
    moved to the global ``device``. The intervention clamps the feature to +/-10.

    Returns:
        ``(tokens, None, None, None, None)``. The trailing Nones are presumably
        placeholders that keep the return shape of other data getters
        (NOTE(review): confirm against callers).
    """
    pile = m.get_pile_dataset()
    sample_indices = torch.randint(0, len(pile), (32,))
    batch = pile[sample_indices].to(device)
    tokens = get_intervention_tokens(m.m, batch, f, value=10, sae=sae)
    return tokens, None, None, None, None

### Ensemble

In [None]:
# Few-shot system prompt for the ensemble explainer whose [OUTPUT] section is a
# single vocabulary-projection token list (presumably "VM" = VocabProj + MaxAct).
# The [INPUT] sections are hard-coded renderings of example1..example3 in the
# format produced by add_per_neuron_explanation_prompt (0..10 normalized values,
# tab-separated). Keep them in sync if those examples change.
ENSEMBLE_RAW_VM_SYS_PROMPT = """We're studying neurons in a neural network. Each neuron has certain inputs that activate it and outputs that it leads to. You will receive two pieces of information about a neuron: the activations it has for certain inputs, the words its output is most associated with. These will be separated into two sections [INPUT] and [OUTPUT].

The [INPUT] activation format is token<tab>activation. Activation values range from 0 to 10. A neuron finding what it's looking for is represented by a non-zero activation value. The higher the activation value, the stronger the match.

The [OUTPUT] format is a list of words related to that specific neuron. These tokens represent a combination of embeddings that reconstruct the vector. You can infer the most likely output or function of the neuron based on these tokens. The list may include noise, such as unrelated terms, symbols, or programming jargon. Ignore whether the words are in multiple different languages, and do not mention it in your response. Focus on identifying a cohesive theme or concept shared by the most relevant tokens.

Your response should be a concise (1-2 sentence) explanation of the neuron, encompassing what triggers it (input) and what it does once triggered (output). If the two sides relate to one another you may include that in your explanation, otherwise simply state the input and output.

Neuron 1

[INPUT]
    Activations:
<start>
t	0
urt	0
ur	0
ro	0
 is	0
 fab	0
ulously	0
 funny	0
 and	0
 over	0
 the	0
 top	0
 as	0
 a	0
 '	0
very	0
 sneaky	0
'	1
 but	0
ler	0
 who	0
 excel	0
s	0
 in	0
 the	0
 art	0
 of	0
 impossible	0
 disappearing	6
/	0
re	0
app	0
earing	10
 acts	0
<end>
<start>
esc	0
aping	10
 the	4
 studio	0
 ,	0
 pic	0
col	0
i	0
 is	0
 warm	0
ly	0
 affecting	3
 and	0
 so	0
 is	0
 this	0
 ad	0
roit	0
ly	0
 minimalist	0
 movie	0
 .	0
<end>

Same activations, but with all zeros filtered out:
<start>
'	1
 disappearing	6
earing	10
<end>
<start>
aping	10
 the	4
 affecting	3
<end>

[OUTPUT]
['to', 'To', 'TO', 'Towards', 'towards', 'TOWARDS', 'toward', 'Toward', 'TOWARD', 'toward', 'Toward', 'TOWARD', 'life', 'do', 'fdsa', 'aaaaaa', 'aaaaa', 'aaaa', 'aaa', 'aa', 'a', 'A']

Explanation of neuron 1 behavior: the main thing this neuron does is find present tense verbs ending in 'ing', and then outputs words related to directionality or movement to or towards something.

Neuron 2

[INPUT]
    Activations:
<start>
as	0
 sac	0
char	0
ine	0
 movies	0
 go	0
 ,	0
 this	0
 is	0
 likely	0
 to	0
 cause	0
 massive	0
 cardiac	0
 arrest	10
 if	0
 taken	0
 in	0
 large	0
 doses	0
 .	0
<end>
<start>
shot	0
 perhaps	0
'	0
art	0
istically	0
'	0
 with	0
handheld	0
 cameras	0
 and	0
 apparently	0
 no	0
 movie	0
 lights	0
 by	0
 jo	0
aquin	0
 b	0
aca	0
-	0
as	0
ay	0
 ,	0
 the	0
 low	0
-	0
budget	0
 production	0
 swings	0
 annoy	0
ingly	0
 between	0
 vert	0
igo	10
 and	0
 opacity	0
 .	0
<end>

Same activations, but with all zeros filtered out:
<start>
 arrest	10
<end>
<start>
igo	10
<end>

[OUTPUT]
['1111', 'Evol', 'crab', 'sing', 'dance', 'walk', 'run', 'jump', 'swim', 'climb', 'death', 'Death', 'DEATH', 'dying', 'Dying', 'DYING', 'die', 'DIED', 'kill']

Explanation of neuron 2 behavior: the main thing this neuron does is find words related to physical medical conditions, and then outputs words related to death or dying.

Neuron 3

[INPUT]
    Activations:
<start>
the	0
 sense	0
 of	0
 together	5
ness	10
 in	0
 our	1
 town	2
 is	0
 strong	0
 .	0
<end>
<start>
a	0
 buoy	0
ant	0
 romantic	0
 comedy	0
 about	0
 friendship	0
,	0
 love	0
,	0
 and	0
 the	0
 truth	0
 that	0
 we	2
're	4
 all	3
 in	7
 this	10
 together	5
 .	0
<end>

[OUTPUT]
['community', 'commune', 'communal', 'Community', 'family', 'Family', 'Together', 'ball', 'street', 'efu', 'jefus']

Explanation of neuron 3 behavior: the main thing this neuron does is find phrases related to community, and then outputs words related to togetherness or family.
"""

# Few-shot system prompt for the ensemble explainer whose [OUTPUT] section is the
# list of tokens most affected by amplifying the neuron (presumably "TM" =
# TokenChange + MaxAct). Apart from the [OUTPUT] description paragraph, the text is
# the same as ENSEMBLE_RAW_VM_SYS_PROMPT, including the few-shot token lists.
ENSEMBLE_RAW_TM_SYS_PROMPT = """We're studying neurons in a neural network. Each neuron has certain inputs that activate it and outputs that it leads to. You will receive two pieces of information about a neuron: the activations it has for certain inputs, the words its output is most associated with. These will be separated into two sections [INPUT] and [OUTPUT].

The [INPUT] activation format is token<tab>activation. Activation values range from 0 to 10. A neuron finding what it's looking for is represented by a non-zero activation value. The higher the activation value, the stronger the match.

The [OUTPUT] format is a list of the words whose probabilities changed most when amplifying that specific neuron. The list may include noise, such as unrelated terms, symbols, or programming jargon. Focus on identifying a cohesive theme or concept shared by the most relevant tokens.

Your response should be a concise (1-2 sentence) explanation of the neuron, encompassing what triggers it (input) and what it does once triggered (output). If the two sides relate to one another you may include that in your explanation, otherwise simply state the input and output.

Neuron 1

[INPUT]
    Activations:
<start>
t	0
urt	0
ur	0
ro	0
 is	0
 fab	0
ulously	0
 funny	0
 and	0
 over	0
 the	0
 top	0
 as	0
 a	0
 '	0
very	0
 sneaky	0
'	1
 but	0
ler	0
 who	0
 excel	0
s	0
 in	0
 the	0
 art	0
 of	0
 impossible	0
 disappearing	6
/	0
re	0
app	0
earing	10
 acts	0
<end>
<start>
esc	0
aping	10
 the	4
 studio	0
 ,	0
 pic	0
col	0
i	0
 is	0
 warm	0
ly	0
 affecting	3
 and	0
 so	0
 is	0
 this	0
 ad	0
roit	0
ly	0
 minimalist	0
 movie	0
 .	0
<end>

Same activations, but with all zeros filtered out:
<start>
'	1
 disappearing	6
earing	10
<end>
<start>
aping	10
 the	4
 affecting	3
<end>

[OUTPUT]
['to', 'To', 'TO', 'Towards', 'towards', 'TOWARDS', 'toward', 'Toward', 'TOWARD', 'toward', 'Toward', 'TOWARD', 'life', 'do', 'fdsa', 'aaaaaa', 'aaaaa', 'aaaa', 'aaa', 'aa', 'a', 'A']

Explanation of neuron 1 behavior: the main thing this neuron does is find present tense verbs ending in 'ing', and then outputs words related to directionality or movement to or towards something.

Neuron 2

[INPUT]
    Activations:
<start>
as	0
 sac	0
char	0
ine	0
 movies	0
 go	0
 ,	0
 this	0
 is	0
 likely	0
 to	0
 cause	0
 massive	0
 cardiac	0
 arrest	10
 if	0
 taken	0
 in	0
 large	0
 doses	0
 .	0
<end>
<start>
shot	0
 perhaps	0
'	0
art	0
istically	0
'	0
 with	0
handheld	0
 cameras	0
 and	0
 apparently	0
 no	0
 movie	0
 lights	0
 by	0
 jo	0
aquin	0
 b	0
aca	0
-	0
as	0
ay	0
 ,	0
 the	0
 low	0
-	0
budget	0
 production	0
 swings	0
 annoy	0
ingly	0
 between	0
 vert	0
igo	10
 and	0
 opacity	0
 .	0
<end>

Same activations, but with all zeros filtered out:
<start>
 arrest	10
<end>
<start>
igo	10
<end>

[OUTPUT]
['1111', 'Evol', 'crab', 'sing', 'dance', 'walk', 'run', 'jump', 'swim', 'climb', 'death', 'Death', 'DEATH', 'dying', 'Dying', 'DYING', 'die', 'DIED', 'kill']

Explanation of neuron 2 behavior: the main thing this neuron does is find words related to physical medical conditions, and then outputs words related to death or dying.

Neuron 3

[INPUT]
    Activations:
<start>
the	0
 sense	0
 of	0
 together	5
ness	10
 in	0
 our	1
 town	2
 is	0
 strong	0
 .	0
<end>
<start>
a	0
 buoy	0
ant	0
 romantic	0
 comedy	0
 about	0
 friendship	0
,	0
 love	0
,	0
 and	0
 the	0
 truth	0
 that	0
 we	2
're	4
 all	3
 in	7
 this	10
 together	5
 .	0
<end>

[OUTPUT]
['community', 'commune', 'communal', 'Community', 'family', 'Family', 'Together', 'ball', 'street', 'efu', 'jefus']

Explanation of neuron 3 behavior: the main thing this neuron does is find phrases related to community, and then outputs words related to togetherness or family.
"""

# Few-shot system prompt for the ensemble explainer whose [OUTPUT] section holds two
# token lists. The first is the vocabulary projection; the second is the tokens most
# affected by amplifying the neuron. This matches the two output slots of
# ENSEMBLE_IOO_USER_PROMPT.
# Fix: the second [OUTPUT] list of neuron 3 had a malformed literal `go'` (missing
# opening quote); it is now `'go'` so the few-shot list is well-formed.
ENSEMBLE_RAW_ALL_SYS_PROMPT = """We're studying neurons in a neural network. Each neuron has certain inputs that activate it and outputs that it leads to. You will receive two pieces of information about a neuron: the activations it has for certain inputs, the words its output is most associated with. These will be separated into two sections [INPUT] and [OUTPUT].

The [INPUT] activation format is token<tab>activation. Activation values range from 0 to 10. A neuron finding what it's looking for is represented by a non-zero activation value. The higher the activation value, the stronger the match.

The [OUTPUT] format is a list of words related to that specific neuron. These tokens are in two lists, one represents a combination of embeddings that reconstruct the vector, and the other is the tokens most affected by amplifying that neuron. You can infer the most likely output or function of the neuron based on these tokens. The list may include noise, such as unrelated terms, symbols, or programming jargon. Ignore whether the words are in multiple different languages, and do not mention it in your response. Focus on identifying a cohesive theme or concept shared by the most relevant tokens.

Your response should be a concise (1-2 sentence) explanation of the neuron, encompassing what triggers it (input) and what it does once triggered (output). If the two sides relate to one another you may include that in your explanation, otherwise simply state the input and output.

Neuron 1

[INPUT]
    Activations:
<start>
t	0
urt	0
ur	0
ro	0
 is	0
 fab	0
ulously	0
 funny	0
 and	0
 over	0
 the	0
 top	0
 as	0
 a	0
 '	0
very	0
 sneaky	0
'	1
 but	0
ler	0
 who	0
 excel	0
s	0
 in	0
 the	0
 art	0
 of	0
 impossible	0
 disappearing	6
/	0
re	0
app	0
earing	10
 acts	0
<end>
<start>
esc	0
aping	10
 the	4
 studio	0
 ,	0
 pic	0
col	0
i	0
 is	0
 warm	0
ly	0
 affecting	3
 and	0
 so	0
 is	0
 this	0
 ad	0
roit	0
ly	0
 minimalist	0
 movie	0
 .	0
<end>

Same activations, but with all zeros filtered out:
<start>
'	1
 disappearing	6
earing	10
<end>
<start>
aping	10
 the	4
 affecting	3
<end>

[OUTPUT]
['to', 'To', 'TO', 'Towards', 'towards', 'TOWARDS', 'toward', 'Toward', 'TOWARD', 'toward', 'Toward', 'TOWARD', 'life', 'do', 'fdsa', 'aaaaaa', 'aaaaa', 'aaaa', 'aaa', 'aa', 'a', 'A']
['thought', 'think', 'Think', 'to', 'towards', 'Towar', 'towards', 'toward', 'Toward', 'TOWARD', 'life', 'do', 'fdsa', 'aaaaaa', 'aaaaa', 'aaaa', 'aaa', 'aa', 'a', 'A']

Explanation of neuron 1 behavior: the main thing this neuron does is find present tense verbs ending in 'ing', and then outputs words related to directionality, movement or thinking.

Neuron 2

[INPUT]
    Activations:
<start>
as	0
 sac	0
char	0
ine	0
 movies	0
 go	0
 ,	0
 this	0
 is	0
 likely	0
 to	0
 cause	0
 massive	0
 cardiac	0
 arrest	10
 if	0
 taken	0
 in	0
 large	0
 doses	0
 .	0
<end>
<start>
shot	0
 perhaps	0
'	0
art	0
istically	0
'	0
 with	0
handheld	0
 cameras	0
 and	0
 apparently	0
 no	0
 movie	0
 lights	0
 by	0
 jo	0
aquin	0
 b	0
aca	0
-	0
as	0
ay	0
 ,	0
 the	0
 low	0
-	0
budget	0
 production	0
 swings	0
 annoy	0
ingly	0
 between	0
 vert	0
igo	10
 and	0
 opacity	0
 .	0
<end>

Same activations, but with all zeros filtered out:
<start>
 arrest	10
<end>
<start>
igo	10
<end>

[OUTPUT]
['1111', 'Evol', 'crab', 'sing', 'dance', 'walk', 'run', 'jump', 'swim', 'climb', 'death', 'Death', 'DEATH', 'dying', 'Dying', 'DYING', 'die', 'DIED', 'kill']
['die', 'morgue', 'Die', 'murder', 'dance', 'Dancing', 'Dancer', 'walk', 'run', 'jump']

Explanation of neuron 2 behavior: the main thing this neuron does is find words related to physical medical conditions, and then outputs words related to death or dying, as well as physical movement.

Neuron 3

[INPUT]
    Activations:
<start>
the	0
 sense	0
 of	0
 together	5
ness	10
 in	0
 our	1
 town	2
 is	0
 strong	0
 .	0
<end>
<start>
a	0
 buoy	0
ant	0
 romantic	0
 comedy	0
 about	0
 friendship	0
,	0
 love	0
,	0
 and	0
 the	0
 truth	0
 that	0
 we	2
're	4
 all	3
 in	7
 this	10
 together	5
 .	0
<end>

[OUTPUT]
['community', 'commune', 'communal', 'Community', 'family', 'Family', 'Together', 'ball', 'street', 'efu', 'jefus']
['do', 'I', 'what', 'go', 'commune', 'communal', 'Community', 'family', 'Family', 'Together', 'ball', 'street', 'efu', 'jefus']

Explanation of neuron 3 behavior: the main thing this neuron does is find phrases related to community, and then outputs words related to togetherness or family.
"""

In [None]:
# User-turn templates for the ensemble explainers, filled with str.format.
# ENSEMBLE_IO_USER_PROMPT takes one output list ({input_data}, {output_data});
# ENSEMBLE_IOO_USER_PROMPT takes two ({output_data1}, {output_data2}), matching the
# two-list [OUTPUT] layout of ENSEMBLE_RAW_ALL_SYS_PROMPT.
# NOTE(review): unlike the IOO template and the few-shot examples, the IO template
# has no blank line between "Neuron 4" and "[INPUT]". Confirm whether that is intended.
ENSEMBLE_IO_USER_PROMPT = """Neuron 4
[INPUT]
{input_data}

[OUTPUT]
{output_data}

Explanation of neuron 4 behavior:"""

ENSEMBLE_IOO_USER_PROMPT = """Neuron 4

[INPUT]
{input_data}

[OUTPUT]
{output_data1}
{output_data2}

Explanation of neuron 4 behavior:"""

## Description Evaluation

### Input Metric

In [None]:
# Prompt asking the model for 5 high- and 5 low-activation sentences for an
# explanation, as JSON with 'positive' and 'negative' keys. {explanation} is filled
# with str.format. Typos fixed: "expalantion" -> "explanation",
# "should by orthogonal" -> "should be orthogonal", and a duplicated "characters".
GEN_LISTS_PROMPT = "I'm going to give you explanations and interpretations of features from LLMs. You must take in each explanation, and generate 5 sentences for which you think the feature will have a high activation, and 5 for which they'll have a low activation. For the high activation, make sure to choose ones that will cause a high activation with high confidence - you don't have to include all groups, just make examples that you're confident will have high activation. Make the sentences both include the words from the explanation, and represent the concept. Try to use specific examples, and make them literal interpretations of the explanation, without trying to generalize. Low activation sentences should have nothing to do with the interpretation - i.e. they should be orthogonal and completely unrelated. Please output the response in json format with a 'positive' key and a 'negative' key. Output only the json and no other explanation. Make sure the json is formatted correctly - do not include any '`' backtick characters, i.e. do not format as code, just return the json text. The explanations should be five and five overall, not per line.\n\n{explanation}\n"

# Follow-up prompt used when the model's JSON could not be parsed; {json} is the bad reply.
FIX_JSON_PROMPT = """Please fix this json that is not formatted correctly. Write only the fixed json. 
DO NOT write anything but the json. No comments. Only the json itself.\n\n{json}\n"""

In [None]:
def get_lists(explanation):
    """Ask Gemini for 5 high- and 5 low-activation example sentences for ``explanation``.

    Returns the raw reply text. Per GEN_LISTS_PROMPT it should be JSON with
    'positive' and 'negative' keys, but it is not parsed here.
    """
    request = GEN_LISTS_PROMPT.format(explanation=explanation)
    messages = [{"role": "user", "content": request}]
    return gemini_explainer(messages)

def _clean_json_text(text):
    """Strip surrounding whitespace, markdown code fences and a leading ``json`` tag from an LLM reply."""
    # Note: str.strip("json") removes any of the characters j/s/o/n from both ends; harmless for
    # a JSON object (starts with "{" and ends with "}"). Two passes handle "```json\n...\n```".
    first_pass = text.strip().strip("json").strip("`").removeprefix("json\n").removesuffix("\n")
    return first_pass.strip().strip("json").strip("`").removeprefix("json\n").removesuffix("\n")


def _parse_pos_neg(text):
    """Parse ``text`` as a JSON object and return its ``positive`` and ``negative`` entries.

    Raises ValueError (incl. JSONDecodeError), TypeError or KeyError if that is not possible.
    """
    parsed = json.loads(text)
    return parsed["positive"], parsed["negative"]


def get_pos_neg(lists):
    """
    Extract the positive / negative example lists from the LLM reply produced by get_lists.

    Tries, in order:
      1. the reply with code fences / a ``json`` tag stripped;
      2. the reply with up to 2 leading lines and/or 1 trailing line dropped (chatty preamble/epilogue);
      3. asking the LLM to repair the JSON (FIX_JSON_PROMPT) and parsing its cleaned reply.

    :param lists: raw LLM reply text
    :return: tuple ``(positive, negative)``
    :raises: whatever parsing the repaired reply raises; the repaired reply is logged first
    """
    parse_errors = (ValueError, TypeError, KeyError)

    try:
        return _parse_pos_neg(_clean_json_text(lists))
    except parse_errors:
        pass

    lines = lists.splitlines()
    for top in range(3):
        for bottom in (None, -1):
            try:
                # json.loads needs a string, so re-join the trimmed lines (previously the list
                # itself was passed, which always failed).
                return _parse_pos_neg("\n".join(lines[top:bottom]))
            except parse_errors:
                pass

    fixed = gemini_explainer([{"role": "user", "content": FIX_JSON_PROMPT.format(json=lists)}])
    try:
        return _parse_pos_neg(_clean_json_text(fixed))
    except Exception:
        logger.error(fixed)
        raise

In [None]:
def get_pos_neg_acts(model: HookedSAETransformer, pos, neg, f: Feature, pre_relu=False, sae=None):
    """
    Measure how strongly feature ``f`` fires on the positive vs. negative example sentences.

    :param model: hooked transformer to run
    :param pos: sentences expected to activate the feature
    :param neg: sentences expected not to activate it
    :param f: feature to read (``f.feature`` index; ``f.layer`` when ``sae`` is None)
    :param pre_relu: read the SAE's pre-ReLU latents instead of post-ReLU (SAE path only)
    :param sae: SAE containing the feature, or None to read the MLP neuron (transluce)
    :return: ``(pos_max_all, neg_max_all, pos_max_toks, neg_max_toks)`` where ``*_all`` is the
             single max over all sentences and tokens, and ``*_toks`` is the per-sentence max
             averaged over sentences
    """
    transluce = sae is None

    def run_cache(sentences):
        # run_with_cache* returns (output, cache); only the cache is needed.
        if transluce:
            return model.run_with_cache(sentences, return_type=None)[1]
        return model.run_with_cache_with_saes(sentences, saes=[sae], return_type=None)[1]

    pos_cache = run_cache(pos)
    neg_cache = run_cache(neg)

    # Taking maxima lets several explanations per feature be compared by their best activations.
    if transluce:
        hook = f"blocks.{f.layer}.mlp.hook_post"
    else:
        stage = "pre" if pre_relu else "post"
        hook = f"{sae.cfg.hook_name}.hook_sae_acts_{stage}"

    # presumably (batch, position, d) activations -> (batch, position) for this feature; TODO confirm
    pos_feature = pos_cache[hook][:, :, f.feature]
    neg_feature = neg_cache[hook][:, :, f.feature]
    return (
        pos_feature.max().item(),
        neg_feature.max().item(),
        pos_feature.max(dim=-1).values.mean().item(),
        neg_feature.max(dim=-1).values.mean().item(),
    )

In [None]:
@dataclass
class InputScore:
    """
    Input-metric measurements for one explanation.

    ``*_all`` fields: the single maximum feature activation over every token of every sentence.
    ``*_toks`` fields: the per-sentence maximum activation, averaged over sentences.
    """
    pos_act_all: float
    neg_act_all: float
    pos_act_toks: float
    neg_act_toks: float
    # The generated example sentences from get_pos_neg (previously mis-annotated as float).
    pos_list: List[str]
    neg_list: List[str]

    def success(self) -> bool:
        """Input-metric verdict: positive sentences activate the feature more, on average, than negative ones."""
        return self.pos_act_toks > self.neg_act_toks

    @classmethod
    def from_row(cls, row):
        """
        Rebuild an InputScore from an object whose ``vars()`` maps exactly to the field names.

        NOTE(review): vars() fails on namedtuples (e.g. ``DataFrame.itertuples`` rows) and does
        not expose column values of a ``pd.Series``; confirm what callers pass.
        """
        return cls(**vars(row))
    

def get_input_score(description: str, m: Model, f: Feature, sae=None) -> InputScore:
    """
    Score an explanation with the input metric.

    The LLM writes sentences that should / should not activate the feature according to
    ``description``; the feature's activations on them are then measured.
    """
    generated = get_lists(description)
    positives, negatives = get_pos_neg(generated)
    max_pos, max_neg, mean_pos, mean_neg = get_pos_neg_acts(m.m, positives, negatives, f, sae=sae)
    return InputScore(
        pos_act_all=max_pos,
        neg_act_all=max_neg,
        pos_act_toks=mean_pos,
        neg_act_toks=mean_neg,
        pos_list=positives,
        neg_list=negatives,
    )

### Output Metric

In [None]:
# Prompts every steered completion set is generated from (see get_output_score).
OUTPUT_METRIC_GENERATION_PROMPTS = ["The explanation is simple:", "I think", "We"]

# Judge prompt. The judge sees one set from the real feature and NUM_OF_DISTS (= 2) sets from random
# features; the previous wording ("and one of a random neuron") contradicted the three sets shown.
OUTPUT_METRIC_SYS_PROMPT = """You are analyzing the behavior of a specific neuron in a language model. You will receive:

1. A hypothesized explanation for what concept the neuron represents (e.g., specific tokens, themes, or ideas).
2. Three sets of completions over the same prompts: one generated by amplifying the activation of the neuron in question, and two generated by amplifying random neurons.

Your goal is to identify which set of completions is more likely the result of amplifying the neuron in question. To do this:
- Look for completions where the **literal words** or the **ideas/themes** described in the explanation occur more frequently or with greater emphasis.
- Remember that amplification may highlight specific words or their broader contextual meanings, meaning that a lot of the time they might be very noisy, but contain keywords that appear in the explanation.
- Your answer should be based on the **content** of the completions, not the quality of the language model's output.
- Your reasoning should be sound, don't make overly elaborate and far-fetched connections.

The first line in your response should be a brief explanation of your choice - what made you choose that set of completions.
The second line must be only the set number you think matches the description (i.e., 1, 2 or 3) and no other text. You must pick one of the three sets.
"""

# One amplification entry: "<+0.25> completions".
STEERING_AMP_PROMPT_TEMPLATE = """<{amplification}> {completions}\n"""

# Full judge user prompt: the explanation followed by the three (shuffled) completion sets.
STEERING_FULL_PROMPT_TEMPLATE = """Explanation: {explanation}

# Set 1
{amplifications1}

# Set 2
{amplifications2}

# Set 3
{amplifications3}
"""

In [None]:
def set_feature_act_kl_hook(act, hook, feature: int, value):
    """
    Forward hook that clamps channel ``feature`` (third axis) of ``act`` to ``value``.

    Mutates ``act`` in place and returns None; ``hook`` is unused.
    """
    every_batch_every_position = (slice(None), slice(None), feature)
    act[every_batch_every_position] = value

def kl_div(p, q, eps=1e-10):
    """
    Return KL(p || q) along the last dimension.

    Both distributions are clamped from below at ``eps`` first so zero entries never hit log(0).
    """
    p_safe, q_safe = (t.clamp(min=eps) for t in (p, q))
    log_ratio = p_safe.log() - q_safe.log()
    return (p_safe * log_ratio).sum(dim=-1)

def get_kl_div(model: Model, prompts, f: Feature, value, sae=None):
    """
    Get KL Divergence between clean and hooked activations when clamping the feature to value,
    averaged over all of the prompts.

    For every prompt, KL(clean || hooked) of the next-token distributions is averaged over the
    non-padding positions; the per-prompt means are then averaged.

    :param model: wrapper whose ``.m`` is the hooked transformer
    :param prompts: prompt strings (tokenized together, so shorter ones get padded)
    :param f: feature to clamp (``f.feature``; ``f.layer`` when ``sae`` is None)
    :param value: value the feature is clamped to at every position
    :param sae: SAE containing the feature, or None to clamp the raw MLP neuron (transluce)
    :return: mean KL divergence (numpy float)
    """
    toks = model.m.to_tokens(prompts)
    clamp = functools.partial(set_feature_act_kl_hook, feature=f.feature, value=value)

    # Pure inference: without no_grad every activation is retained for a backward pass.
    with torch.no_grad():
        if sae is None:  # transluce: clamp the MLP neuron's pre-activation
            clean_logits = model.m(toks)
            hooked_logits = model.m.run_with_hooks(
                toks, fwd_hooks=[(f"blocks.{f.layer}.mlp.hook_pre", clamp)]
            )
        else:
            clean_logits = model.m.run_with_saes(toks, saes=[sae])
            hooked_logits = model.m.run_with_hooks_with_saes(
                toks, saes=[sae], fwd_hooks=[(f"{sae.cfg.hook_name}.hook_sae_acts_post", clamp)]
            )
    clean_probs = clean_logits.softmax(dim=-1)
    hooked_probs = hooked_logits.softmax(dim=-1)

    # Zero out padding positions so their KL is 0 and they are dropped below. The pad id is read
    # from the tokenizer rather than hard-coded to 0: id 0 is a real token in some vocabularies
    # (e.g. "!" in GPT-2), and such tokenizers typically pad with a different id.
    # NOTE(review): if pad id == BOS id (GPT-2 style), the BOS position is excluded as well.
    tokenizer = getattr(model.m, "tokenizer", None)
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 0
    padding = toks == pad_id
    clean_probs[padding] = 0
    hooked_probs[padding] = 0

    kl = kl_div(clean_probs, hooked_probs)
    # Mean over each prompt's non-zero (non-padding) positions, then mean across prompts.
    return np.mean([row[row != 0].mean().item() for row in kl])


def get_activation_for_kl(model: Model, prompts, f: Feature, target_kl, high_thresh=0.1, neg=False, verbose=False, sae=None):
    """
    Find the activation value we need for the desired target KL Divergence value.

    Integer binary search over clamp values in [1, 1000] (or [-1000, -1] when ``neg``). It stops
    once get_kl_div lands in [target_kl, target_kl + high_thresh] or the interval cannot be split
    further, and returns the last value tried (0 if no value was tried).
    """
    if neg:
        low, high = -1000, -1
    else:
        low, high = 1, 1000
    mid, kl = 0, -1

    while high - low > 1 and (kl < target_kl or kl > target_kl + high_thresh):
        mid = (low + high) // 2
        kl = get_kl_div(model, prompts, f, mid, sae=sae)

        # Positive clamps: too much KL -> search smaller values.
        # Negative clamps: too little KL -> search more negative values.
        take_lower_half = (kl < target_kl) if neg else (kl > target_kl)
        if take_lower_half:
            high = mid
        else:
            low = mid

        if verbose:
            print(f"Low: {low}, High: {high}, Mid: {mid}, KL: {kl}, Target KL: {target_kl}")

    return mid

In [None]:
def gen_hook(clean_act, hook, feature: int, value, sae=None):
    """
    Steering hook: set one feature to ``value`` and return the modified activation.

    With an SAE, the activation is encoded, the latent ``feature`` is overwritten, and the result
    is decoded. The SAE's reconstruction error is added back so only the steered latent changes
    the output. Without an SAE (transluce), channel ``feature`` of ``clean_act`` is set directly,
    in place.

    :param clean_act: the basic activation before the SAE
    """
    if sae is not None:
        latents = sae.encode(clean_act)
        reconstruction_error = clean_act - sae.decode(latents)
        latents[:, :, feature] = value
        return sae.decode(latents) + reconstruction_error

    # transluce
    clean_act[:, :, feature] = value
    return clean_act
    
def hooked_gen(prompt, model: Model, f: Feature, n=25, value=10, verbose=False, temperature=1, sae=None):
    """
    Generate text while feature ``f`` is clamped to ``value`` (via gen_hook).

    :param prompt: the prompt(s) for which the model will generate text (as passed to ``generate``)
    :param model: wrapper whose ``.m`` is the hooked transformer
    :param f: feature to steer
    :param n: number of new tokens to generate
    :param value: the value to amplify the feature with
    :param temperature: sampling temperature
    :param sae: SAE containing the feature, or None to steer the MLP neuron (transluce)
    :return: one completion string per prompt row, with the prompt text stripped off
    """
    hook_point = f"blocks.{f.layer}.mlp.hook_pre" if sae is None else sae.cfg.hook_name
    steering = functools.partial(gen_hook, feature=f.feature, value=value, sae=sae)

    model.m.reset_hooks()
    model.m.add_hook(hook_point, steering)
    try:
        output = model.m.generate(prompt, max_new_tokens=n, verbose=verbose, temperature=temperature)
    finally:
        # Always detach the steering hook, even when generation fails; otherwise every later
        # forward pass on this model would silently stay steered.
        model.m.reset_hooks()

    # Drop the prompt prefix by character length; assumes each decoded output starts with its
    # decoded prompt row.
    return [x[len(model.m.to_string(prompt[i])):] for i, x in enumerate(model.m.to_string(output))]

In [None]:
# Number of random-feature completion sets shown to the judge next to the real feature's set.
NUM_OF_DISTS: int = 2
# Target KL divergences used for steering; negative entries mean "clamp the feature negatively".
KL_DIV_VALUES: list[float] = [
    0.25,
    0.5,
    -0.25,
    -0.5,
]

def get_completions_for_kl_val(model: Model, prompts, f: Feature, kl, neg=False, sae=None):
    """
    Steer feature ``f`` to the clamp value that yields KL divergence ``kl`` and sample 25 tokens
    per prompt at temperature 0.75.

    :return: one completion per prompt, with newlines / carriage returns escaped so each stays on
             a single line of the judge prompt
    """
    clamp_value = get_activation_for_kl(model, prompts, f, kl, neg=neg, sae=sae)
    tokens = model.m.to_tokens(prompts)
    completions = hooked_gen(tokens, model, f, n=25, value=clamp_value, temperature=0.75, sae=sae)
    escape_table = str.maketrans({"\n": "\\n", "\r": "\\r"})
    return [completion.translate(escape_table) for completion in completions]

def get_random_amps(model: Model):
    """
    Load the precomputed random-feature completion sets for ``model`` and draw NUM_OF_DISTS of them.

    Uses NUM_OF_DISTS rather than a hard-coded 2, because get_output_score expects exactly that
    many random sets.

    :return: list of NUM_OF_DISTS formatted completion-set strings
    """
    path = f"output_metric/{model.model_id}_random_amps.pkl"
    # NOTE: pickle.load can execute arbitrary code; only load files written by write_random_amps.
    with open(path, "rb") as fh:
        random_amps = pickle.load(fh)

    return random.sample(random_amps, NUM_OF_DISTS)

In [None]:
def generate_random_amps(model: Model, num, sae):
    """
    Build ``num`` baseline completion sets from randomly chosen SAE latents.

    Each set uses ONE random latent, amplified at every value in KL_DIV_VALUES and formatted
    exactly like the real feature's set in get_output_score. Fixes over the previous version:
      * the result list is sized by ``num`` (it was sized by NUM_OF_DISTS but indexed by ``num``);
      * negative KL targets are passed as ``abs(kl)`` with ``neg=True``, as get_output_score does;
      * a set no longer mixes a different random latent at each KL level.

    :return: list of ``num`` formatted completion-set strings
    """
    random_amps = []
    for _ in range(num):
        random_feature = random.randrange(sae.cfg.d_sae)
        # TODO: wrap in a Feature (f = Feature(model.model_id, random_feature, layer...)); the
        # steering helpers read f.feature / f.layer, so a bare int will not work there yet.
        amps = ""
        for kl in KL_DIV_VALUES:
            completions = get_completions_for_kl_val(
                model, OUTPUT_METRIC_GENERATION_PROMPTS, random_feature, abs(kl), neg=kl < 0, sae=sae
            )
            label = f"\n<{'+' if kl >= 0 else ''}{kl}>"
            amps += label + label.join(
                f"'{prompt}': '{completion}'"
                for prompt, completion in zip(OUTPUT_METRIC_GENERATION_PROMPTS, completions)
            )
        random_amps.append(amps)

    return random_amps

def write_random_amps(model: Model, num, rounds=10):
    """
    Generate ``rounds`` batches of ``num`` random completion sets and pickle ALL of them for
    get_random_amps.

    Previously only the last batch was saved and returned, even though every batch was collected.

    NOTE(review): ``sae`` is still None here (see TODO), and generate_random_amps reads
    ``sae.cfg.d_sae``, so this fails until a random SAE is selected.

    :param rounds: number of generation batches (default 10, as before)
    :return: every generated completion set
    """
    all_random_amps = []
    for _ in range(rounds):
        # TODO: get random SAE for model
        sae = None
        all_random_amps.extend(generate_random_amps(model, num, sae))

    os.makedirs("output_metric", exist_ok=True)
    with open(f"output_metric/{model.model_id}_random_amps.pkl", "wb") as fh:
        pickle.dump(all_random_amps, fh)

    return all_random_amps

In [None]:
@dataclass
class OutputScore:
    """Outcome of one output-metric trial: where the real feature's set was, and what the judge picked."""

    # 1-based position of the real feature's completion set in the judge prompt
    correct_choice: int
    # 1-based set number parsed from the judge's answer
    chosen_index: int

    def success(self) -> bool:
        """Output-metric verdict: the judge picked the real feature's completion set."""
        picked_real_set = self.chosen_index == self.correct_choice
        return picked_real_set
    
def to_int(s):
    """
    Return the first run of decimal digits in ``s`` as an int (e.g. "Set **2**" -> 2).

    :raises ValueError: if ``s`` contains no digits (replaces the old bare except + print; the
                        pipeline's error handler catches ValueError and logs the message)
    """
    match = re.search(r"\d+", s)
    if match is None:
        raise ValueError(f"No int in {s}")
    return int(match.group())

def get_output_score(description: str, m: Model, f: Feature, sae=None) -> OutputScore:
    """
    Score an explanation with the output metric.

    Feature ``f`` is steered at every KL value in KL_DIV_VALUES. The resulting completions are
    shuffled together with NUM_OF_DISTS random-feature sets, and the weak explainer is asked
    which set matches ``description``.

    :return: OutputScore with the real set's 1-based position and the judge's pick
    """
    amps = ""
    for kl in KL_DIV_VALUES:
        completions = get_completions_for_kl_val(
            m, OUTPUT_METRIC_GENERATION_PROMPTS, f, abs(kl), neg=kl < 0, sae=sae
        )
        label = f"\n<{'+' if kl >= 0 else ''}{kl}>"
        amps += label + label.join(
            f"'{prompt}': '{completion}'"
            for prompt, completion in zip(OUTPUT_METRIC_GENERATION_PROMPTS, completions)
        )

    # Copy so the loaded random sets are never mutated; the real set goes last.
    candidate_sets = list(get_random_amps(m))
    candidate_sets.append(amps)
    real_index = len(candidate_sets) - 1

    order = random.sample(range(len(candidate_sets)), len(candidate_sets))
    ordered_amps = [candidate_sets[i] for i in order]
    correct_choice = order.index(real_index) + 1

    prompt = STEERING_FULL_PROMPT_TEMPLATE.format(
        explanation=description,
        amplifications1=ordered_amps[0],
        amplifications2=ordered_amps[1],
        amplifications3=ordered_amps[2],
    )
    response = explainer([
        {"role": "user", "content": OUTPUT_METRIC_SYS_PROMPT + "\n\n" + prompt}
    ], weak=True)

    # The judge is told to put the set number on its last line; skip trailing blank lines so a
    # final newline doesn't make parsing fail.
    answer_lines = [line for line in response.splitlines() if line.strip()]
    chosen_index = to_int(answer_lines[-1] if answer_lines else response)
    return OutputScore(correct_choice, chosen_index)

## Pipeline


### Getting features

In [None]:
def get_all_features_sample(m: Model, n: int = None, in_file: str = None, start_chunk: int = 0) -> List[Feature]:
    """
    Load the model's features and, for SAE models, draw a sample balanced over SAE types and sizes.

    Non-SAE models: every loaded feature is returned as-is (``n`` is ignored).
    SAE models: ``n`` (default: all) is split evenly into one chunk per (type, size) pair. Each
    chunk is the slice ``[start_chunk : start_chunk + chunk_size]`` of the matching features, and
    the combined sample is shuffled.

    Raises AssertionError if ``n`` is too large, not divisible by the number of (type, size)
    pairs, or a chunk has too few features.
    """
    features = m.get_features(in_file=in_file)
    if not m.with_sae:
        logger.info(f"Got {len(features)} features for {m.model_id}")
        return features

    total = len(features)
    missing_acts = sum(1 for feat in features if feat.activations == [])
    logger.info(f"Loaded {total} features, out of which {missing_acts} have no activations data")

    # validate n against the feature count and the chunk layout
    n = total if n is None else n
    n_types, n_sizes = len(m.sae_types), len(m.sae_sizes)
    assert n <= total, f"n must be less or equal to the numbers of sample features {total}"
    logger.info(f"Getting {n}/{total} features with {n_types} types and {n_sizes} sizes")

    n_chunks = n_types * n_sizes
    chunk_size = n // n_chunks
    assert n % n_chunks == 0, f"n must be divisible by {n_chunks} so that groups can be equal, for example {n - (n % n_chunks)}"
    logger.info(f"Dividing into {n_chunks} chunks, each of size {chunk_size}")

    stop = start_chunk + chunk_size
    sample = []
    for sae_type in m.sae_types:
        for sae_size in m.sae_sizes:
            logger.debug(f"Creating chunk for {sae_type}/{sae_size} with range [{start_chunk}:{stop}]")
            matching = [feat for feat in features if feat.type == sae_type and feat.size == sae_size]
            chunk = matching[start_chunk:stop]
            assert len(chunk) == chunk_size, f"current chunk length ({len(chunk)}) is not enough ({chunk_size})"
            sample += chunk

    random.shuffle(sample)
    logger.info(f"Got {len(sample)} final sample features for {m.model_id}")
    return sample

### Main analysis

In [None]:
@dataclass
class Result:
    """
    One row of experiment output: the explanations produced by every method for a feature, plus
    whether each explanation passed the input and output metrics.

    Each ``*_exp`` field holds the explanation text; ``*_input_success`` / ``*_output_success``
    hold InputScore.success() / OutputScore.success() for that explanation.
    """

    # int: analyze_feature stores sae.cfg.hook_layer (or f.layer); was mis-annotated as str
    layer: int
    # int: the feature is used as a tensor index; was mis-annotated as str
    feature: int
    sae_type: str
    sae_size: str
    dashboard_url: str

    # MA
    max_act_exp: str
    max_act_input_success: bool
    max_act_output_success: bool

    # VP
    vocab_proj_exp: str
    vocab_proj_input_success: bool
    vocab_proj_output_success: bool

    # TC
    token_change_exp: str
    token_change_input_success: bool
    token_change_output_success: bool

    # RAW (VP + MA)
    ensemble_raw_vm_exp: str
    ensemble_raw_vm_input_success: bool
    ensemble_raw_vm_output_success: bool

    # RAW (TC + MA)
    ensemble_raw_tm_exp: str
    ensemble_raw_tm_input_success: bool
    ensemble_raw_tm_output_success: bool

    # RAW (VP + TC)
    ensemble_raw_vt_exp: str
    ensemble_raw_vt_input_success: bool
    ensemble_raw_vt_output_success: bool

    # RAW ALL (VP + TC + MA)
    ensemble_raw_all_exp: str
    ensemble_raw_all_input_success: bool
    ensemble_raw_all_output_success: bool

    # CONCAT ALL (VP + TC + MA)
    ensemble_concat_all_exp: str
    ensemble_concat_all_input_success: bool
    ensemble_concat_all_output_success: bool

    @classmethod
    def from_row(cls, row):
        """
        Rebuild a Result from an object whose ``vars()`` maps exactly to the field names.

        NOTE(review): vars() fails on namedtuples and does not expose a ``pd.Series``'s column
        values; confirm what callers pass.
        """
        return cls(**vars(row))

In [None]:
def _load_feature_sae(f, m):
    """Load and sanity-check the SAE that feature ``f`` belongs to; return ``(sae, dashboard_url)``."""
    sae, _, _ = SAE.from_pretrained(release=f.sae_release, sae_id=f.sae_id, device=device)
    sae_model_id = "meta-llama/Llama-3.1-8B" if m.is_llama() else m.model_id
    assert sae.cfg.model_name == sae_model_id, f"sae model name {sae.cfg.model_name} doesn't match current model {sae_model_id}!"
    if f.size != "":
        assert sae.cfg.d_sae == f.get_size_int(), f"sae size {sae.cfg.d_sae} doesn't match current feature {f.get_size_int()}!"

    try:
        np_sae_id = m.get_np_sae_id(f.layer, f.type, f.size)
        dashboard_url = f.get_pedia_dashboard_url(np_sae_id)
    except Exception as e:
        # Best effort: a missing Neuronpedia mapping should not abort the analysis, but log why.
        logger.warning(f"get_np_sae_id() failed: {e}")
        dashboard_url = ""
    return sae, dashboard_url


def _generate_feature_descriptions(f, m, sae):
    """Produce every explanation variant for ``f``, keyed by its Result field prefix (in evaluation order)."""
    # 1 - MaxAct (MA)
    if not f.activations:
        max_act_user_prompt = "No activations found for feature."
        max_act_exp = "NO ACTIVATIONS DATA"
    else:
        max_act_user_prompt = generate_max_act_user_prompt(f, activating_examples=f.activations)
        max_act_exp = get_description(MAX_ACT_SYS_PROMPT, max_act_user_prompt)

    # 2 - VocabProj (VP)
    top_tokens, bottom_tokens = get_dec_unembed(m, f, sae=sae)
    vocab_tokens = top_tokens + bottom_tokens
    vocab_proj_exp = get_description(VOCAB_PROJ_SYS_PROMPT, VOACB_PROJ_USER_PROMPT.format(vocab_tokens))

    # 3 - TokenChange (TC); the four random-feature outputs of get_causal_data are not used here
    real, _, _, _, _ = get_causal_data(m, f, sae=sae)
    token_change_user_prompt = generate_token_change_user_prompt(real)
    token_change_exp = get_description(TOKEN_CHANGE_SYS_PROMPT, token_change_user_prompt)

    # 4 - Ensemble Raw (VP + MA = VM)
    ensemble_raw_vm_exp = get_description(ENSEMBLE_RAW_VM_SYS_PROMPT, ENSEMBLE_IO_USER_PROMPT.format(
        input_data=max_act_user_prompt, output_data=vocab_tokens))

    # 5 - Ensemble Raw (TC + MA = TM)
    ensemble_raw_tm_exp = get_description(ENSEMBLE_RAW_TM_SYS_PROMPT, ENSEMBLE_IO_USER_PROMPT.format(
        input_data=max_act_user_prompt, output_data=token_change_user_prompt))

    # 6 - Ensemble Raw (VP + TC = VT): the VP prompt over vocab tokens plus the de-duplicated TC
    # tokens. dict.fromkeys keeps first-seen order; set() order varies between runs, which made
    # this prompt non-reproducible.
    vt_tokens = vocab_tokens + list(dict.fromkeys(real))
    ensemble_raw_vt_exp = get_description(VOCAB_PROJ_SYS_PROMPT, VOACB_PROJ_USER_PROMPT.format(vt_tokens))

    # 7 - Ensemble Raw (All)
    ensemble_raw_all_exp = get_description(ENSEMBLE_RAW_ALL_SYS_PROMPT, ENSEMBLE_IOO_USER_PROMPT.format(
        input_data=max_act_user_prompt, output_data1=vocab_tokens, output_data2=token_change_user_prompt))

    # 8 - Ensemble Concat (All)
    ensemble_concat_all_exp = f"{vocab_proj_exp}\n{max_act_exp}\n{token_change_exp}"

    return {
        "max_act": max_act_exp,
        "vocab_proj": vocab_proj_exp,
        "token_change": token_change_exp,
        "ensemble_raw_vm": ensemble_raw_vm_exp,
        "ensemble_raw_tm": ensemble_raw_tm_exp,
        "ensemble_raw_vt": ensemble_raw_vt_exp,
        "ensemble_raw_all": ensemble_raw_all_exp,
        "ensemble_concat_all": ensemble_concat_all_exp,
    }


def analyze_feature(f: Feature, m: Model) -> Result:
    """
    Run the full pipeline for one feature: generate explanations with every method, then score
    each with the input and output metrics.

    The SAE (if any) is released and free() is called even when a step raises.
    """
    if m.with_sae:
        logger.trace(f"Analyzing {f} with {f.sae_id}")
    else:
        logger.trace(f"Analyzing {f}")

    sae, dashboard_url = None, ""
    try:
        if m.with_sae:
            sae, dashboard_url = _load_feature_sae(f, m)

        ### DESCRIPTION GENERATION ###
        descriptions = _generate_feature_descriptions(f, m, sae)
        logger.trace("Got descriptions")

        ### DESCRIPTION EVALUATION ###
        input_scores = {name: get_input_score(exp, m, f, sae=sae) for name, exp in descriptions.items()}
        logger.trace("Got input score")
        output_scores = {name: get_output_score(exp, m, f, sae=sae) for name, exp in descriptions.items()}
        logger.trace("Got output score")

        fields = {}
        for name, exp in descriptions.items():
            fields[f"{name}_exp"] = exp
            fields[f"{name}_input_success"] = input_scores[name].success()
            fields[f"{name}_output_success"] = output_scores[name].success()

        layer = sae.cfg.hook_layer if sae is not None else f.layer
        return Result(
            layer=layer, feature=f.feature, sae_type=f.type, sae_size=f.size,
            dashboard_url=dashboard_url, **fields,
        )
    finally:
        # free memory
        del sae
        free()

In [None]:
def get_df_scores(df):
    """
    Summarize an experiment dataframe as per-method success rates.

    Always returns the eight "<Method> Input Result" rates, in the original order. The matching
    "<Method> Output Result" rates are appended for every "<Method> Output Success" column present
    (they were previously never reported). Rates are NaN for an empty dataframe instead of
    raising ZeroDivisionError.

    :param df: dataframe with "<Method> Input Success" (and optionally "... Output Success") columns
    :return: dict mapping result name -> fraction of rows where the column is True
    """
    methods = [
        "Max Act", "Vocab Proj", "Token Change",
        "Ensemble Raw Vm", "Ensemble Raw Tm", "Ensemble Raw Vt",
        "Ensemble Raw All", "Ensemble Concat All",
    ]
    total = len(df)

    def rate(column):
        if total == 0:
            return float("nan")
        # `== True` (not truthiness) so only real booleans count, as before.
        return int((df[column] == True).sum()) / total  # noqa: E712

    results = {f"{name} Input Result": rate(f"{name} Input Success") for name in methods}
    results.update({
        f"{name} Output Result": rate(f"{name} Output Success")
        for name in methods
        if f"{name} Output Success" in df.columns
    })
    return results

def get_knowledge_exp_df(m: Model, features: List[Feature], limit=None, start=0, output_csv=None):
    """
    Analyze ``features[start:][:limit]`` one by one and collect the results into a dataframe.

    Column names are the Result field names title-cased with spaces (e.g. "Max Act Exp"). Every
    5th feature, the running scores are shown on the progress bar, and the partial dataframe is
    checkpointed to ``output_csv`` if given.

    Error handling: common data/parsing errors skip the feature. Any other exception stops the
    loop, and the results gathered so far are returned.
    """
    exp_results = DefaultDict(list)
    features = features[start:]
    if limit:
        features = features[:limit]

    pbar = Pbar(total=len(features))
    for i, f in enumerate(features):
        try:
            # full analysis for the current feature (explainer + scorer)
            result = analyze_feature(f, m)
            if not result:
                pbar.update(1)
                continue

            # append all results, deriving column names from the Result fields
            for k, v in asdict(result).items():
                exp_results[k.replace("_", " ").title()].append(v)

            if i % 5 == 0:
                df = pd.DataFrame(exp_results)
                # NOTE(review): passes a dict; fine if Pbar formats it, but tqdm's set_description expects a str.
                pbar.set_description(get_df_scores(df))
                if output_csv:
                    df.to_csv(output_csv, index=False)
            pbar.update(1)

        except (JSONDecodeError, KeyError, AttributeError, ValueError, IndexError) as e:
            # getattr: non-SAE features may not carry an sae_id, and failing here would abort the run.
            logger.error(f"Error with {getattr(f, 'sae_id', None)}/{f.feature}: {e}")
            traceback.print_exc()
            free()
            # keep the progress bar in step with the number of processed features
            pbar.update(1)
            continue

        except Exception as e:
            logger.error(e)
            traceback.print_exc()
            break

    return pd.DataFrame(exp_results)

In [None]:
# Make sure the experiments output directory exists (race-free, and creates parent dirs too).
os.makedirs(EXPERIMENTS_DIR, exist_ok=True)

def analyze_model(m: Model, fs: List[Feature], limit=None, csv_path="") -> "str | None":
    """
    Run the experiment over ``fs`` (first ``limit`` features if given), logging at TRACE level to a
    per-run file in LOGS_DIR.

    :param csv_path: where to write results; defaults to a timestamped file in EXPERIMENTS_DIR
    :return: path of the written CSV, or None if the run failed (the traceback is logged)
    """
    current_time = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    # Number of features actually analyzed (limit can exceed the available features).
    fs_len = min(limit, len(fs)) if limit else len(fs)
    file_name = f'{m.model_name}_{fs_len}_features_@{current_time}'

    log_full_path = os.path.join(LOGS_DIR, f"{file_name}.log")
    log_sink = logger.add(log_full_path, level="TRACE")

    try:
        csv_full_path = csv_path or os.path.join(EXPERIMENTS_DIR, f"{file_name}.csv")
        df = get_knowledge_exp_df(m, fs, limit=limit, output_csv=csv_full_path)
        # index=False matches the periodic checkpoints written to the same path.
        df.to_csv(csv_full_path, index=False)
        logger.info(f"Results saved to {csv_full_path}")
        return csv_full_path
    except Exception:
        # Best effort: keep the notebook alive, but record the full traceback in the run log.
        logger.exception("analyze_model failed")
        return None
    finally:
        logger.remove(log_sink)
    
def calculate_score(experiment_csv):
    """
    Load an experiment CSV, log its per-method success rates (see get_df_scores) and return them.

    Returning the dict (previously None) lets callers use the scores programmatically.
    """
    results = get_df_scores(pd.read_csv(experiment_csv))
    logger.info(results)
    return results

### Run

In [None]:
CHOSEN_MODEL = "gemma"

def get_model_obj(chosen_model: str) -> Model:
    """
    Map a short notebook alias ("gemma", "llama", "gpt", "llama_in") to its entry in MODELS.

    Only the selected entry is looked up, so MODELS need not contain every model.

    :raises KeyError: for an unknown alias, listing the valid ones
    """
    aliases = {
        "gemma": "GEMMA-2-2B",
        "llama": "LLAMA-3.1-8B",
        "gpt": "GPT-2-SM-V5",
        "llama_in": "LLAMA-3.1-8B-IN",
    }
    if chosen_model not in aliases:
        raise KeyError(f"Unknown model {chosen_model!r}; expected one of {sorted(aliases)}")
    return MODELS[aliases[chosen_model]]

model = get_model_obj(CHOSEN_MODEL)

In [None]:
# Load the selected model through its load_model() method (defined on the Model class).
model.load_model()

In [None]:
# SAE-backed models only: get the model's specific layers and (presumably) fetch SAE info for them.
if model.with_sae:
    specific_layers = model.get_specific_layers()
    model.get_saes_info_specific_layers(specific_layers)

In [None]:
# Disabled: presumably regenerates the feature sample that get_features reads.
# model.write_feature_sample()
all_features_sample = get_all_features_sample(model)
# Notebook display: peek at the first five sampled features.
all_features_sample[:5]

In [None]:
# Run the full explain-and-score pipeline on the first sampled feature only (limit=1).
# analyze_model returns the results CSV path, or None if the run raised.
results_csv = analyze_model(model, all_features_sample, limit=1)