In [1]:
import ast
from collections import Counter

import pandas as pd

# Path of the results CSV; the file name encodes the run as
# "<dataset>_<release>_layer<N>.csv".
df_name = 'results/imagenet-1k-256x256_gemma-2b-it_layer12.csv'
# df_name = 'results/cifar100-enriched_gemma-2b_layer12.csv'

# Parse from the right-hand side so that extra directory levels, dataset names
# containing "_" and non-gemma releases are all handled the same way.
file_stem = df_name.rsplit('/', 1)[-1].rsplit('.', 1)[0]
dataset, release, layer_tag = file_stem.rsplit('_', 2)
if not layer_tag.startswith('layer'):
    raise ValueError(f"Unexpected results file name (no 'layer<N>' suffix): {df_name}")
layer = layer_tag[len('layer'):]  # kept as a string, e.g. '12'
# Load the results table; later cells read its 'retrieved_features' column.
df = pd.read_csv(df_name)

In [2]:
# Show the run metadata parsed from the results file name.
dataset, layer, release

('imagenet-1k-256x256', '12', 'gemma-2b-it')

In [3]:
# Load image_net_labels.txt, which stores the label mapping as a Python literal.
# ast.literal_eval only accepts literals, so a tampered label file cannot
# execute code the way eval() would.
with open('image_net_labels.txt', 'r') as f:
    image_net_labels = ast.literal_eval(f.read())

def map_label(df):
    """Add a 'label' column by looking up every 'fine_label' value in image_net_labels.

    The frame is modified in place and the same object is returned.
    """
    df['label'] = list(map(image_net_labels.__getitem__, df['fine_label']))
    return df

# Attach readable labels only when the results path contains "image_net".
# NOTE(review): the active df_name ('results/imagenet-1k-256x256_...') contains
# "imagenet", not "image_net", so this branch does not run for it — confirm
# whether "imagenet" was intended and whether that CSV has a 'fine_label' column.
if "image_net" in df_name:
    df = map_label(df)


In [4]:
def get_feature_counts(df, max_count=1000):
    """
    Count in how many rows each retrieved feature occurs and drop the very frequent ones.

    Parameters:
    df (pd.DataFrame): Frame whose 'retrieved_features' column holds dict literals stored as strings.
    max_count (int): Features occurring in more than this many rows are discarded.
        Defaults to 1000, the threshold that used to be hard-coded here.

    Returns:
    dict: feature key -> number of rows containing it, ordered from most to least
        frequent and restricted to features with a count <= max_count.
    """
    occurrences = Counter()
    for features in df['retrieved_features']:
        # literal_eval parses the dict literal without executing code read from the CSV
        occurrences.update(ast.literal_eval(features).keys())

    ranked = occurrences.most_common()
    print('before:',len(ranked))
    # Very frequent features are filtered out; everything else keeps its rank order
    kept = {key: count for key, count in ranked if count <= max_count}
    print('after:',len(kept))
    print('=====')
    return kept

def clean_features(df, feature_counts):
    """
    Keep, for every row, only the retrieved features that are present in feature_counts.

    Parameters:
    df (pd.DataFrame): Frame whose 'retrieved_features' column holds dict literals stored as strings.
    feature_counts (dict): Features to keep; only key membership is used.

    Returns:
    pd.DataFrame: The same frame, modified in place, with a new 'cleaned_features'
        column of dicts. Also prints the fraction of rows that still have at least
        one feature, prefixed with the module-level `layer`.
    """
    cleaned = []
    for features in df['retrieved_features']:
        # literal_eval parses the dict literal without executing code read from the CSV
        parsed = ast.literal_eval(features)
        cleaned.append({k: v for k, v in parsed.items() if k in feature_counts})
    df['cleaned_features'] = cleaned

    non_empty = sum(1 for kept in cleaned if kept)
    # An empty frame reports 0.0 instead of raising ZeroDivisionError
    coverage = non_empty / len(df) if len(df) else 0.0
    print(f'layer{layer}:', coverage)
    return df
# Count feature occurrences over the whole table, then filter each row's features against those counts.
feature_counts = get_feature_counts(df)
df = clean_features(df, feature_counts)

before: 4844
after: 4490
=====
layer12: 0.93208


In [5]:
import requests
import torch
from PIL import Image
from transformers import LlavaForConditionalGeneration, AutoProcessor
import pandas as pd
from sae_lens import SAE
from datasets import load_dataset
from tqdm import tqdm
import json
from torch.utils.data import DataLoader

# Standard imports
import os
import torch
from tqdm import tqdm

# Restrict this process to physical GPU 1 (jack might need to change this).
# The mask only takes effect if CUDA has not been initialised yet in this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

if torch.cuda.is_available():
    # When the mask is applied, the single visible GPU is renumbered to index 0,
    # so "cuda:1" would be an invalid device ordinal. Index 1 is only used when
    # more than one device is still visible (i.e. the mask was not applied).
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"

print(f"Device: {device}")

def fetch_and_process_feature_descriptions(layer, sae, timeout=60):
    """
    Download the Neuronpedia explanation export for one SAE and map feature index to description.

    Parameters:
    - layer (int or str): Layer number; used to build the SAE id "<layer>-res-jb".
    - sae (str): Neuronpedia model id, e.g. "gemma-2b".
    - timeout (float): Seconds to wait for the HTTP response before giving up.

    Returns:
    - dict: feature index (as delivered by the API) -> lower-cased description.

    Raises:
    - ValueError: If the API answers with a status code other than 200.
    """
    url = "https://www.neuronpedia.org/api/explanation/export"
    payload = {
        "modelId": sae,  # "gemma-2b",
        "saeId": f"{layer}-res-jb"
    }
    # Prefer a key from the environment so a real token never has to be stored in
    # the notebook; the placeholder keeps the previous behaviour when it is unset.
    api_key = os.environ.get("NEURONPEDIA_API_KEY", "YOUR_TOKEN")
    headers = {"X-Api-Key": api_key}
    # Without a timeout a stalled connection would block the cell indefinitely.
    response = requests.get(url, params=payload, headers=headers, timeout=timeout)

    if response.status_code != 200:
        raise ValueError(f"Failed to fetch feature descriptions: {response.text}")

    explanations_df = pd.DataFrame(response.json()).rename(columns={"index": "feature"})
    # Vectorised lower-casing; missing descriptions stay missing instead of raising
    explanations_df["description"] = explanations_df["description"].str.lower()
    return explanations_df.set_index('feature')['description'].to_dict()

# Fetch the explanations for the parsed layer/release and keep them for later cells.
# NOTE(review): the bare name below renders the entire mapping as cell output.
explainations = fetch_and_process_feature_descriptions(layer, release)
explainations

  from .autonotebook import tqdm as notebook_tqdm


Device: cuda:1


{'3130': 'government-related phrases and terms',
 '13141': 'code snippets related to mining cryptocurrencies',
 '3579': 'code snippets relating to memory allocation and management',
 '27': 'references to specific events or occurrences',
 '367': 'phrases related to research and analysis',
 '1107': 'phrases related to skepticism or disbelief',
 '2502': 'phrases related to mathematical instructions and procedures',
 '3091': 'phrases related to policy and legislation',
 '3138': 'verbs related to caring and concern',
 '3211': 'the beginning of a text',
 '3585': 'phrases related to politics and government',
 '4322': 'phrases related to issues or problems',
 '4387': 'words related to ancient mythology and historical events',
 '4455': 'mentions of subsidies and related economic terms',
 '5463': 'mentions of a specific name, "tig" or "tiger"',
 '5741': 'mentions of natural disasters and public safety concerns',
 '6720': 'keywords related to financial transactions and blockchain technology',
 '6

In [6]:
# Number of features for which an explanation was returned.
len(explainations)

15824

In [14]:

def create_vector(size, indices):
    """
    Build a 1 x size indicator row: 1 at every requested position, 0 everywhere else.

    Parameters:
    size (int): Number of columns of the resulting row vector.
    indices (list of int): Positions that are set to 1.

    Returns:
    torch.Tensor: Tensor of shape (1, size), moved to the module-level `device`.
    """
    one_hot = torch.zeros(1, size)
    row = one_hot[0]  # view onto the only row; writes land in one_hot
    for position in indices:
        row[position] = 1
    return one_hot.to(device)


def get_activated_sae(release, layer, df, explainations):
    """
    Load the pretrained SAE and build a per-feature dictionary of counts, explanations and vectors.

    Parameters:
    - release (str): Release name; "-res-jb" is appended to form the SAE release.
    - layer (int or str): Layer number; used in "blocks.<layer>.hook_resid_post".
    - df (pd.DataFrame): DataFrame containing 'cleaned_features' (one dict per row).
    - explainations (dict): Dictionary mapping SAE IDs (as strings) to explanations.

    Returns:
    - dict: Maps every SAE ID (int, 0 .. d_sae - 1) to a dict with
        'count'       - number of rows whose 'cleaned_features' contain the ID,
        'explanation' - text from `explainations`, or 'UNK' when there is none,
        'vector'      - row of `sae.W_dec` for that ID, shape (1, W_dec.shape[1]).

    Raises:
    - KeyError: If a key in 'cleaned_features' is outside the range of SAE IDs.
    """
    sae, cfg_dict, _ = SAE.from_pretrained(
        release=f"{release}-res-jb",
        sae_id=f"blocks.{layer}.hook_resid_post",
        device=device
    )

    decoder = sae.W_dec
    # A one-hot row times W_dec simply selects one row of W_dec, so the row is
    # sliced directly instead of building a d_sae-sized vector and running a
    # matmul per feature. The i:i+1 slice keeps the leading dimension of 1.
    sae_dict = {
        feature_id: {
            'count': 0,
            # explainations is keyed by strings, sae_dict by ints
            'explanation': explainations.get(str(feature_id), 'UNK'),
            'vector': decoder[feature_id:feature_id + 1],
        }
        for feature_id in range(cfg_dict['d_sae'])
    }

    # Update counts based on df['cleaned_features']; empty dicts contribute nothing
    for feature_dict in df['cleaned_features']:
        for sae_id_str in feature_dict:
            sae_dict[int(sae_id_str)]['count'] += 1

    return sae_dict

# Build the per-feature summary (count, explanation, vector) for the loaded results.
# NOTE(review): displaying dict_result renders one entry per SAE feature, tensors included.
dict_result = get_activated_sae(release, layer, df, explainations)
dict_result


{0: {'count': 0,
  'explanation': 'references to historical events or political situations that involve conflict or challenge to authority',
  'vector': tensor([[-0.0289,  0.0290,  0.0271,  ..., -0.0452, -0.0041,  0.0211]],
         device='cuda:1', grad_fn=<MmBackward0>)},
 1: {'count': 0,
  'explanation': ' mentions of dates and chronological order',
  'vector': tensor([[-0.0143,  0.0438,  0.0210,  ...,  0.0070, -0.0185,  0.0136]],
         device='cuda:1', grad_fn=<MmBackward0>)},
 2: {'count': 0,
  'explanation': 'historical information related to cultural genocide and major errors in scientific studies, with a possible focus on legal cases and investigations',
  'vector': tensor([[ 0.0324,  0.0426,  0.0010,  ..., -0.0086, -0.0125, -0.0035]],
         device='cuda:1', grad_fn=<MmBackward0>)},
 3: {'count': 0,
  'explanation': 'references to staying tuned for updates or additional information',
  'vector': tensor([[-0.0226,  0.0059,  0.0279,  ...,  0.0143,  0.0272, -0.0243]],
      