In [1]:
import pandas as pd
from collections import Counter

df_name = "results/imagenet-1k-256x256_gemma-2b-it_layer12.csv"
# df_name = 'results/cifar100-enriched_gemma-2b_layer12.csv'
dataset = df_name.split("/")[1].split("_gemma")[0]
layer = df_name.split("layer")[1].split(".csv")[0]
release = df_name.split("/")[1].split("_")[1].split("_layer")[0]
df = pd.read_csv(df_name)
df.head()

Unnamed: 0,image_id,fine_label,retrieved_features
0,0,65,"{""284"": ""programming codes and technical instr..."
1,1,970,"{""284"": ""programming codes and technical instr..."
2,2,230,"{""284"": ""programming codes and technical instr..."
3,3,809,"{""312"": ""phrases related to physical appearanc..."
4,4,516,"{""284"": ""programming codes and technical instr..."


In [2]:
dataset, layer, release

('imagenet-1k-256x256', '12', 'gemma-2b-it')

In [3]:
# load image_net_labels.txt as a dictionary
with open("image_net_labels.txt", "r") as f:
    image_net_labels = eval(f.read())


def map_label(df):
    label = [image_net_labels[i] for i in df["fine_label"]]
    df["label"] = label
    return df


if "image_net" in df_name:
    df = map_label(df)

In [4]:
def get_feature_counts(df):

    all_keys = []
    for features in df["retrieved_features"]:
        feature_dict = eval(features)  # Convert string to dictionary
        all_keys.extend(feature_dict.keys())

    x = Counter(all_keys).most_common()
    print("before:", len(x))
    # now i want to remove key value > 1000 from feature_dict
    x = dict(x)
    for key in list(x.keys()):
        if x[key] > 1000:
            del x[key]
    print("after:", len(x))
    print("=====")
    return x


def clean_features(df, feature_counts):
    new_features = []
    for features in df["retrieved_features"]:
        features = eval(features)
        features = {k: v for k, v in features.items() if k in feature_counts}
        new_features.append(features)
    df["cleaned_features"] = new_features
    print(f"layer{layer}:", sum([i != {} for i in df["cleaned_features"]]) / len(df))
    return df


feature_counts = get_feature_counts(df)
df = clean_features(df, feature_counts)

before: 4844
after: 4490
=====
layer12: 0.93208


In [5]:
df.head()

Unnamed: 0,image_id,fine_label,retrieved_features,cleaned_features
0,0,65,"{""284"": ""programming codes and technical instr...",{'7695': 'phrases describing visual actions li...
1,1,970,"{""284"": ""programming codes and technical instr...",{'4201': 'references to arctic-related topics ...
2,2,230,"{""284"": ""programming codes and technical instr...",{}
3,3,809,"{""312"": ""phrases related to physical appearanc...",{'3701': 'technical terms related to data anal...
4,4,516,"{""284"": ""programming codes and technical instr...","{'1668': 'references to rabbits', '4653': 'wor..."


In [6]:
import requests
import torch
from PIL import Image
from transformers import LlavaForConditionalGeneration, AutoProcessor
import pandas as pd
from sae_lens import SAE
from datasets import load_dataset
from tqdm import tqdm
import json
from torch.utils.data import DataLoader

# Standard imports
import os
import torch
from tqdm import tqdm

# Set the environment variable to use only GPU 1 # jack might need to change this
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

if torch.cuda.is_available():
    device = "cuda:1"
else:
    device = "cpu"

print(f"Device: {device}")


def fetch_and_process_feature_descriptions(layer, sae):
    url = "https://www.neuronpedia.org/api/explanation/export"
    payload = {"modelId": sae, "saeId": f"{layer}-res-jb"}  # "gemma-2b",
    headers = {"X-Api-Key": "YOUR_TOKEN"}
    response = requests.get(url, params=payload, headers=headers)

    if response.status_code != 200:
        raise ValueError(f"Failed to fetch feature descriptions: {response.text}")

    data = response.json()
    explanations_df = pd.DataFrame(data)
    explanations_df.rename(columns={"index": "feature"}, inplace=True)
    explanations_df["description"] = explanations_df["description"].apply(
        lambda x: x.lower()
    )
    feature_to_description = explanations_df.set_index("feature")[
        "description"
    ].to_dict()
    return feature_to_description


explainations = fetch_and_process_feature_descriptions(layer, release)
explainations

  from .autonotebook import tqdm as notebook_tqdm


Device: cpu


{'3130': 'government-related phrases and terms',
 '13141': 'code snippets related to mining cryptocurrencies',
 '3579': 'code snippets relating to memory allocation and management',
 '27': 'references to specific events or occurrences',
 '367': 'phrases related to research and analysis',
 '1107': 'phrases related to skepticism or disbelief',
 '2502': 'phrases related to mathematical instructions and procedures',
 '3091': 'phrases related to policy and legislation',
 '3138': 'verbs related to caring and concern',
 '3211': 'the beginning of a text',
 '3585': 'phrases related to politics and government',
 '4322': 'phrases related to issues or problems',
 '4387': 'words related to ancient mythology and historical events',
 '4455': 'mentions of subsidies and related economic terms',
 '5463': 'mentions of a specific name, "tig" or "tiger"',
 '5741': 'mentions of natural disasters and public safety concerns',
 '6720': 'keywords related to financial transactions and blockchain technology',
 '6

In [7]:
len(explainations)

15824

In [8]:
def create_vector(size, indices):
    """
    Create a vector of given size where the elements at the specified indices are 1 and the rest are 0.

    Parameters:
    size (int): The size of the vector.
    indices (list of int): The indices to be set to 1.

    Returns:
    torch.Tensor: The resulting vector.
    """
    vector = torch.zeros(1, size)
    for index in indices:
        vector[0, index] = 1
    return vector.to(device)


def get_activated_sae(release, layer, df, explainations):
    """
    Get the activated SAE and create a dictionary containing counts, explanations, and vectors for each SAE ID.

    Parameters:
    - release (str): The release version.
    - layer (int): The layer number.
    - df (pd.DataFrame): DataFrame containing 'cleaned_features'.
    - explainations (dict): Dictionary mapping SAE IDs to explanations.

    Returns:
    - dict: A dictionary with SAE IDs as keys and their corresponding data.
    """
    sae, cfg_dict, sparsity = SAE.from_pretrained(
        release=f"{release}-res-jb",
        sae_id=f"blocks.{layer}.hook_resid_post",
        device=device,
    )

    # Initialize sae_dict with explanations
    sae_dict = {}
    for i in range(cfg_dict["d_sae"]):
        # Convert i to string to match keys in explainations
        sae_id_str = str(i)
        explanation = explainations.get(sae_id_str, "UNK")
        sae_dict[i] = {
            "count": 0,
            "explanation": explanation,
            "vector": create_vector(cfg_dict["d_sae"], [i]) @ sae.W_dec,
        }

    # Update counts based on df['cleaned_features']
    for feature_dict in df["cleaned_features"]:
        if feature_dict != {}:
            for sae_id_str in feature_dict.keys():
                sae_id = int(sae_id_str)  # Convert sae_id_str to integer
                sae_dict[sae_id]["count"] += 1

    return sae_dict


# Call the function with explainations passed as an argument
dict_result = get_activated_sae(release, layer, df, explainations)
dict_result

{0: {'count': 0,
  'explanation': 'references to historical events or political situations that involve conflict or challenge to authority',
  'vector': tensor([[-0.0289,  0.0290,  0.0271,  ..., -0.0452, -0.0041,  0.0211]],
         grad_fn=<MmBackward0>)},
 1: {'count': 0,
  'explanation': ' mentions of dates and chronological order',
  'vector': tensor([[-0.0143,  0.0438,  0.0210,  ...,  0.0070, -0.0185,  0.0136]],
         grad_fn=<MmBackward0>)},
 2: {'count': 0,
  'explanation': 'historical information related to cultural genocide and major errors in scientific studies, with a possible focus on legal cases and investigations',
  'vector': tensor([[ 0.0324,  0.0426,  0.0010,  ..., -0.0086, -0.0125, -0.0035]],
         grad_fn=<MmBackward0>)},
 3: {'count': 0,
  'explanation': 'references to staying tuned for updates or additional information',
  'vector': tensor([[-0.0226,  0.0059,  0.0279,  ...,  0.0143,  0.0272, -0.0243]],
         grad_fn=<MmBackward0>)},
 4: {'count': 0,
  'exp

In [9]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.decomposition import PCA
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import networkx as nx

# Convert dict_result to a DataFrame for easier manipulation
sae_df = pd.DataFrame(
    [
        {
            "sae_id": sae_id,
            "count": data["count"],
            "explanation": data["explanation"],
            "vector": data["vector"].cpu().detach().numpy().flatten(),
        }
        for sae_id, data in dict_result.items()
    ]
)

# Sort SAEs by activation count
sae_df = sae_df.sort_values("count", ascending=False).reset_index(drop=True)

In [33]:
import pandas as pd
import numpy as np
import ast
import plotly.graph_objects as go
import plotly.express as px
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import umap


# Load ImageNet labels
with open("image_net_labels.txt", "r") as f:
    image_net_labels = eval(f.read())

# Assuming df is already loaded
# df = pd.read_csv('your_preprocessed_data.csv')


# Map fine labels to ImageNet labels
def map_label(df):
    df["label"] = df["fine_label"].map(image_net_labels)
    return df


df = map_label(df)


# Convert cleaned_features to the correct format
def parse_features(features):
    if isinstance(features, str):
        features_dict = ast.literal_eval(features)
    else:
        features_dict = features
    return features_dict  # Keep as is, with string keys and values


df["cleaned_features"] = df["cleaned_features"].apply(parse_features)

# Print a sample of the cleaned features to verify
print("Sample of cleaned features:")
print(df[["fine_label", "label", "cleaned_features"]].head())

# Get unique SAE IDs from all rows
all_sae_ids = set()
for features in df["cleaned_features"]:
    all_sae_ids.update(features.keys())

# Convert SAE IDs to integers and sort
sae_id_list = sorted(all_sae_ids, key=int)

# Create a mapping of SAE ID to column index
sae_id_to_col = {sae_id: idx for idx, sae_id in enumerate(sae_id_list)}

# Initialize the matrix for label and SAE ID occurrences
num_labels = df["fine_label"].nunique()
num_sae_ids = len(sae_id_list)
label_sae_matrix = np.zeros((num_labels, num_sae_ids), dtype=int)

# Create a mapping of label to row index
label_to_row = {
    label: idx for idx, label in enumerate(sorted(df["fine_label"].unique()))
}
row_to_label = {v: k for k, v in label_to_row.items()}

# Populate the matrix
for _, row in df.iterrows():
    label = row["fine_label"]
    features = row["cleaned_features"]
    for sae_id in features:
        label_sae_matrix[label_to_row[label], sae_id_to_col[sae_id]] += 1

# Print matrix statistics
print("\nMatrix shape:", label_sae_matrix.shape)
print("Total activations:", label_sae_matrix.sum())
print("Max value in matrix:", label_sae_matrix.max())
print("Number of non-zero elements:", np.count_nonzero(label_sae_matrix))

# Calculate label frequencies
label_freq = label_sae_matrix.sum(axis=1)
label_freq_dict = dict(zip(sorted(df["fine_label"].unique()), label_freq))
label_freq_sorted = dict(
    sorted(label_freq_dict.items(), key=lambda x: x[1], reverse=True)
)

# Print top 20 labels
print("\nTop 20 labels:")
print(dict(list(label_freq_sorted.items())[:20]))

# Select top N labels and features for visualization
N = 50  # Adjust this value to show more or fewer labels/features
top_labels = list(label_freq_sorted.keys())[:N]
top_features = np.argsort(label_sae_matrix.sum(axis=0))[-N:]

# Create a submatrix with top labels and features
submatrix = label_sae_matrix[
    np.ix_([label_to_row[label] for label in top_labels], top_features)
]

# Get feature explanations
feature_explanations = {}
for _, row in df.iterrows():
    for sae_id, explanation in row["cleaned_features"].items():
        feature_explanations[sae_id] = explanation


# Create heatmap using Plotly
# Function to truncate text
def truncate_text(text, max_length=30):
    return text[:max_length] + "..." if len(text) > max_length else text


# Create label frequencies bar plot
label_freq_df = pd.DataFrame(
    list(label_freq_sorted.items()), columns=["Label", "Frequency"]
)
label_freq_df["ImageNet_Label"] = label_freq_df["Label"].map(image_net_labels)
fig_label = px.bar(
    label_freq_df,
    x="Label",
    y="Frequency",
    title="Label Frequencies",
    labels={"Label": "Label ID", "Frequency": "Frequency"},
    height=600,
    hover_data=["ImageNet_Label"],
)
fig_label.update_xaxes(type="category")  # Treat x-axis as categorical
fig_label.show()

# Create feature frequencies bar plot
feature_freq = label_sae_matrix.sum(axis=0)
feature_freq_sorted = sorted(enumerate(feature_freq), key=lambda x: x[1], reverse=True)
feature_freq_df = pd.DataFrame(feature_freq_sorted, columns=["SAE_ID", "Frequency"])
feature_freq_df["SAE_ID"] = feature_freq_df["SAE_ID"].apply(lambda x: sae_id_list[x])
feature_freq_df["Explanation"] = feature_freq_df["SAE_ID"].map(feature_explanations)

fig_feature = px.bar(
    feature_freq_df,
    x="SAE_ID",
    y="Frequency",
    title="Feature Frequencies",
    labels={"SAE_ID": "SAE ID", "Frequency": "Frequency"},
    height=600,
    hover_data=["Explanation"],
)
fig_feature.update_xaxes(type="category")  # Treat x-axis as categorical
fig_feature.show()

# Create heatmap using Plotly
fig = go.Figure(
    data=go.Heatmap(
        z=submatrix,
        x=[f"SAE {sae_id_list[i]}" for i in top_features],
        y=[
            f"{truncate_text(image_net_labels[label])} (Label {label})"
            for label in top_labels
        ],
        colorscale="YlOrRd",
        colorbar=dict(title="Activation Count"),
        hovertemplate="Label: %{customdata}<br>SAE: %{x}<br>Count: %{z}<br>Explanation: %{text}<extra></extra>",
        text=[
            [feature_explanations[sae_id_list[i]] for i in top_features]
            for _ in top_labels
        ],
        customdata=[
            [f"{image_net_labels[label]} (Label {label})" for _ in top_features]
            for label in top_labels
        ],
    )
)

fig.update_layout(
    title=f"Top {N} SAE Activations Across Top {N} Labels",
    xaxis_title="Top SAE IDs",
    yaxis_title="Top Labels (truncated)",
    height=800,
    width=1000,
    yaxis=dict(tickfont=dict(size=10)),
)

fig.show()

# Prepare data for PCA and UMAP
X = label_sae_matrix
y = np.array([image_net_labels[label] for label in sorted(df["fine_label"].unique())])

# Calculate activation frequency for each label
activation_frequency = X.sum(axis=1)

# Standardize the features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Perform PCA
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)

# Create PCA plot
fig_pca = px.scatter(
    x=X_pca[:, 0],
    y=X_pca[:, 1],
    color=activation_frequency,
    labels={
        "x": "First Principal Component",
        "y": "Second Principal Component",
        "color": "Activation Frequency",
    },
    title="PCA of SAE Activations",
    hover_data={"label": y, "frequency": activation_frequency},
    color_continuous_scale="Viridis",
)
fig_pca.update_traces(
    hovertemplate="<b>%{customdata[0]}</b><br>Frequency: %{customdata[1]}<br>PC1: %{x:.2f}<br>PC2: %{y:.2f}"
)
fig_pca.show()

# Perform UMAP
umap_reducer = umap.UMAP(random_state=42)
X_umap = umap_reducer.fit_transform(X_scaled)

# Create UMAP plot
fig_umap = px.scatter(
    x=X_umap[:, 0],
    y=X_umap[:, 1],
    color=activation_frequency,
    labels={"x": "UMAP1", "y": "UMAP2", "color": "Activation Frequency"},
    title="UMAP of SAE Activations",
    hover_data={"label": y, "frequency": activation_frequency},
    color_continuous_scale="Viridis",
)
fig_umap.update_traces(
    hovertemplate="<b>%{customdata[0]}</b><br>Frequency: %{customdata[1]}<br>UMAP1: %{x:.2f}<br>UMAP2: %{y:.2f}"
)
fig_umap.show()

# Print explained variance ratio for PCA
print("PCA explained variance ratio:", pca.explained_variance_ratio_)

Sample of cleaned features:
   fine_label                                            label  \
0          65                                        sea snake   
1         970                                              alp   
2         230  Shetland sheepdog, Shetland sheep dog, Shetland   
3         809                                        soup bowl   
4         516                                           cradle   

                                    cleaned_features  
0  {7695: 'phrases describing visual actions like...  
1  {4201: 'references to arctic-related topics su...  
2                                                 {}  
3  {3701: 'technical terms related to data analys...  
4  {1668: 'references to rabbits', 4653: 'words a...  

Matrix shape: (1000, 4490)
Total activations: 271974
Max value in matrix: 50
Number of non-zero elements: 153904

Top 20 labels:
{np.int64(916): np.int64(2138), np.int64(458): np.int64(1540), np.int64(922): np.int64(1532), np.int64(918): np.int


n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



PCA explained variance ratio: [0.03441682 0.01782704]
