# HypotheSAEs Quickstart

This notebook demonstrates basic usage of HypotheSAEs on a sample of the Yelp review dataset.  
We use GPT-4.1 for hypothesis generation (interpreting neurons), and GPT-4.1-mini for text annotation.  
Set your OpenAI API key in the `OPENAI_KEY_SAE` environment variable in the notebook cell below.

In [None]:
%load_ext autoreload
%autoreload 2

import os
os.environ['OPENAI_KEY_SAE'] = '...' # Replace with your OpenAI API key, or with another environment variable (e.g. os.environ['OPENAI_API_KEY'])

import numpy as np
import pandas as pd

from hypothesaes.quickstart import train_sae, interpret_sae, generate_hypotheses, evaluate_hypotheses
from hypothesaes.embedding import get_openai_embeddings, get_local_embeddings

INTERPRETER_MODEL = "gpt-4.1"
ANNOTATOR_MODEL = "gpt-4.1-mini"
N_WORKERS_ANNOTATION = 30 # Number of parallel threads to use for annotation API calls; lower if hitting OpenAI rate limits



**Load data**

The dataset is a sample of Yelp reviews: 20K reviews for training, plus 2K reviews used for validation during SAE training.

The target variable is the `stars` column, which is a rating between 1 and 5. We treat this as a regression task.

A further 2K reviews are held out for evaluation, which we'll use at the end of the notebook.

In [2]:
current_dir = os.getcwd()
if current_dir.endswith("notebooks"):
    prefix = "../"
else:
    prefix = "./"

base_dir = os.path.join(prefix, "demo-data")
train_df = pd.read_json(os.path.join(base_dir, "yelp-demo-train-20K.json"), lines=True)
val_df = pd.read_json(os.path.join(base_dir, "yelp-demo-val-2K.json"), lines=True)

texts = train_df['text'].tolist()
labels = train_df['stars'].values
val_texts = val_df['text'].tolist() # These are only used for early stopping of SAE training, so we don't need labels.

**Compute text embeddings for your dataset**

We'll compute text embeddings for a training set and, optionally, a validation set. The validation embeddings are used for SAE evaluation and early stopping during training.

Embeddings will be stored in the `emb_cache` directory (or `os.environ["EMB_CACHE_DIR"]` if you set it) using the `cache_name` parameter, so you only need to compute embeddings once.

You can use OpenAI or a local model.

Local models will run much faster on GPU. The default local model is `nomic-ai/modernbert-embed-base`. You can use any sentence-transformers model, but please read the model's docs; you may need to edit `get_local_embeddings`.

In [3]:
EMBEDDER = "text-embedding-3-small" # OpenAI
# EMBEDDER = "nomic-ai/modernbert-embed-base" # Huggingface model, will run locally
CACHE_NAME = f"yelp_quickstart_{EMBEDDER}"

text2embedding = get_openai_embeddings(texts + val_texts, model=EMBEDDER, cache_name=CACHE_NAME)
# text2embedding = get_local_embeddings(texts + val_texts, model=EMBEDDER, batch_size=128, cache_name=CACHE_NAME)
train_embeddings = np.stack([text2embedding[text] for text in texts])
val_embeddings = np.stack([text2embedding[text] for text in val_texts])
embeddings = train_embeddings # Alias for the training embeddings, used when generating hypotheses


Loaded 22000 embeddings in 1.4s


**Train SAE** 

We will train a Matryoshka SAE with $M=256$ total neurons, $K=8$ active neurons per input, and $\text{prefix\_lengths} = [32, 256]$.  

With the Matryoshka loss, the SAE will learn to reconstruct the input from (1) just the first 32 neurons, and (2) all 256 neurons.  
This will produce 32 coarse-grained features, and 224 finer-grained features.  

See the README for more details about selecting SAE hyperparameters. 
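To make the Matryoshka loss concrete, here is a minimal NumPy sketch of a top-k forward pass and a loss with one reconstruction term per prefix. It is an illustration under stated assumptions, not the `train_sae` implementation: the input dimension `D`, the random weights, and the helper names `topk_activations` and `matryoshka_loss` are hypothetical, and the library may compute or weight the terms differently.

```python
import numpy as np

rng = np.random.default_rng(0)

D, M, K = 16, 256, 8          # toy input dim (assumption), total neurons, active neurons per input
PREFIX_LENGTHS = [32, 256]    # reconstruct from the first 32 neurons, then from all 256

# Randomly initialised toy parameters; a real SAE would learn these.
W_enc = rng.normal(scale=0.1, size=(D, M))
b_enc = np.zeros(M)
W_dec = rng.normal(scale=0.1, size=(M, D))
b_dec = np.zeros(D)

def topk_activations(x, k=K):
    """Keep the k largest ReLU activations in each row and zero out the rest."""
    pre = np.maximum(x @ W_enc + b_enc, 0.0)
    top_idx = np.argpartition(pre, -k, axis=1)[:, -k:]   # indices of the k largest entries per row
    acts = np.zeros_like(pre)
    np.put_along_axis(acts, top_idx, np.take_along_axis(pre, top_idx, axis=1), axis=1)
    return acts

def matryoshka_loss(x, prefix_lengths=PREFIX_LENGTHS):
    """Sum of mean squared reconstruction errors, one term per prefix length."""
    acts = topk_activations(x)
    total = 0.0
    for p in prefix_lengths:
        recon = acts[:, :p] @ W_dec[:p] + b_dec   # decode using only the first p neurons
        total += np.mean((x - recon) ** 2)
    return total

x = rng.normal(size=(4, D))   # four toy "embeddings"
acts = topk_activations(x)
loss = matryoshka_loss(x)
```

Because the first 32 neurons must reconstruct the input on their own, they are pushed toward broad features, while the remaining 224 neurons only contribute to the full reconstruction and can specialize.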

In [None]:
checkpoint_dir = os.path.join(prefix, "checkpoints", CACHE_NAME)
sae = train_sae(embeddings=train_embeddings, val_embeddings=val_embeddings,
                M=256, K=8, matryoshka_prefix_lengths=[32, 256], 
                checkpoint_dir=checkpoint_dir)


Early stopping triggered after 67 epochs
Saved model to ../checkpoints/yelp_quickstart_text-embedding-3-small/SAE_matryoshka_M=256_K=8_prefixes=32-256.pt


**Interpret neurons**  

Interpret a random subset of neurons in the SAE to sanity-check that the learned features, and their interpretations, seem reasonable. We generate and print labels for `n_random_neurons` neurons, and we also print out the top-activating texts for each neuron.

In [None]:
# This instruction will be included in the neuron interpretation prompt.
# The below instructions are specific to Yelp, but you can customize this for your task.
# If you don't pass in task-specific instructions, there is a generic instruction (see src/interpret_neurons.py);
# task-specific instructions are optional, but they help produce hypotheses at the desired level of specificity.

TASK_SPECIFIC_INSTRUCTIONS = """All of the texts are reviews of restaurants on Yelp.
Features should describe a specific aspect of the review. For example:
- "mentions long wait times to receive service"
- "praises how a dish was cooked, with phrases like 'perfect medium-rare'\""""

# Interpret random neurons
results = interpret_sae(
    texts=texts,
    embeddings=train_embeddings,
    sae=sae,
    n_random_neurons=5,
    print_examples_n=3,
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS,
    interpreter_model=INTERPRETER_MODEL,
)


Activations shape: (20000, 256)




Neuron 232 (from SAE M=256, K=8): mentions smoothies or blended fruit drinks as a menu item or focus of the review

Top activating examples:
1. stopped by here because my friend recommended me to if i was in SB! wow that TROPICAL PINEAPPLE smoothie is bomb!!!! one of the best smoothies of my life. the pineapple wasn't overpowering at all, it was so tasty. that was actually my boyfriend's drink and i couldn't stop reaching for his. i ordered the peach smoothie too and it was good also but the aftertaste was kinda flintstone vitaminy to me, maybe because of the sherbert?   i loved how the smoothies come with a free boost. we both got bee pollen because a lot of people recommended it, it also listed good benefits on the menu, but i'm not really sure what that actually tastes like LOL   wished they had this place in SD :(
2. Delicious and nutritious blended drinks/smoothies! Definitely worth going back to whenever you're feeling for a little something extra after your meal. The power gree

**Generate hypotheses**

Generate hypotheses which are predictive of the target variable.

The `selection_method` parameter defines how we compute neuron predictiveness (see `src/select_neurons.py` for more details):
- `"separation_score"`: E[target | top-activating examples] - E[target | zero-activating examples]
- `"correlation"`: Pearson correlation between the neuron's activations and the target variable
- `"lasso"`: select the N features that receive nonzero coefficients in an L1-regularized model
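The first two scores can be sketched in a few lines of NumPy for a single neuron. This is an illustration only: the function names and the `n_top` parameter are hypothetical, and the actual logic in `src/select_neurons.py` may define "top-activating" differently. The lasso option is omitted because it needs a regression library.

```python
import numpy as np

def separation_score(activations, target, n_top=100):
    """E[target | top-activating examples] - E[target | zero-activating examples]."""
    top_idx = np.argsort(activations)[-n_top:]   # indices of the n_top largest activations
    zero_mask = activations == 0                 # examples where the neuron does not fire
    return target[top_idx].mean() - target[zero_mask].mean()

def correlation_score(activations, target):
    """Pearson correlation between one neuron's activations and the target."""
    return np.corrcoef(activations, target)[0, 1]

# Toy data: the neuron fires only on the two highest-rated reviews.
activations = np.array([0.0, 0.0, 0.0, 0.0, 0.5, 1.0])
target = np.array([1.0, 2.0, 2.0, 3.0, 5.0, 5.0])

sep = separation_score(activations, target, n_top=2)   # 5.0 - 2.0 = 3.0
corr = correlation_score(activations, target)
```

Since an SAE with small $K$ leaves most activations at exactly zero, the separation score compares the examples where a neuron fires most strongly against those where it does not fire at all.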

This cell outputs a dataframe with the following columns:
- `neuron_idx`: The index of the neuron in the SAE (if you're using multiple SAEs, this will be a global index across all of them).
- `source_sae`: The SAE that the neuron was selected from.
- `target_{selection_method}`: The predictiveness of the neuron for the target variable, using the selected `selection_method`.
- `interpretation`: The natural language interpretation of the neuron.
- `f1_fidelity_score`: The F1 fidelity score for how well the neuron's interpretation actually corresponds to its activation pattern.
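One way to read the fidelity score in the last bullet: treat "the neuron is active" as the reference label and the annotator's yes/no judgment of the interpretation as the prediction, then compute F1. The sketch below assumes exactly that setup; the function name `f1_fidelity` and the way examples are sampled and binarized are assumptions, not the library's confirmed procedure.

```python
def f1_fidelity(neuron_active, annotation):
    """F1 of annotations (prediction) against binarized neuron activity (reference)."""
    pairs = list(zip(neuron_active, annotation))
    tp = sum(1 for a, p in pairs if a and p)        # neuron active, interpretation applies
    fp = sum(1 for a, p in pairs if not a and p)    # neuron inactive, interpretation applies
    fn = sum(1 for a, p in pairs if a and not p)    # neuron active, interpretation does not apply
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# Toy example: 4 active and 4 inactive texts, with one error of each kind.
neuron_active = [1, 1, 1, 1, 0, 0, 0, 0]
annotation    = [1, 1, 1, 0, 1, 0, 0, 0]
score = f1_fidelity(neuron_active, annotation)   # precision = recall = 0.75
```

A low score suggests the interpretation is too broad, too narrow, or simply does not describe what the neuron responds to, so such hypotheses deserve extra scrutiny.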

In [10]:
selection_method = "correlation"
results = generate_hypotheses(
    texts=texts,
    labels=labels,
    embeddings=embeddings,
    sae=sae,
    cache_name=CACHE_NAME,
    selection_method=selection_method,
    n_selected_neurons=20,
    n_candidate_interpretations=1,
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS,
    interpreter_model=INTERPRETER_MODEL,
    annotator_model=ANNOTATOR_MODEL,
    n_workers_annotation=N_WORKERS_ANNOTATION, # Please lower this parameter if you are running into OpenAI API rate limits
)

print("\nMost predictive features of Yelp reviews:")
pd.set_option('display.max_colwidth', None)
display(results.sort_values(by=f"target_{selection_method}", ascending=False))
pd.reset_option('display.max_colwidth')

Embeddings shape: (20000, 1536)



Activations shape: (20000, 256)

Step 1: Selecting top 20 predictive neurons

Step 2: Interpreting selected neurons




Step 3: Scoring Interpretations
Found 0 cached items; annotating 2000 uncached items


Scoring neuron interpretation fidelity (20 neurons; 1 candidate interps per neuron; 100 examples to score each…


Most predictive features of Yelp reviews:


| | neuron_idx | source_sae | target_correlation | interpretation | f1_fidelity_score |
|---|---|---|---|---|---|
| 2 | 6 | (256, 8) | 0.315468 | mentions a strong intent or desire to return to the restaurant, using phrases like 'definitely coming back', 'will be back', or 'can't wait to come back' | 0.761538 |
| 3 | 2 | (256, 8) | 0.309113 | describes repeatedly returning to the restaurant over many years and consistently having positive experiences | 0.331034 |
| 13 | 137 | (256, 8) | 0.120681 | makes explicit claims that something is the 'best', often using phrases like 'best in the city', 'best po-boy anywhere', 'the Best!', 'the best drinks', or similar superlative language | 0.899556 |
| 15 | 13 | (256, 8) | 0.115959 | describes special occasion or event dining experiences, such as weddings, anniversaries, birthdays, or chef's table/rehearsal dinners | 0.780488 |
| 19 | 46 | (256, 8) | -0.102475 | mentions uncomfortable temperatures inside the restaurant, such as being too hot or too cold, affecting the dining experience | 0.529412 |
| 18 | 185 | (256, 8) | -0.105393 | mentions cleanliness or dirtiness of the restaurant or its facilities (e.g., floors, bathrooms, tables, employees, trash cans), often specifically describing them as dirty, filthy, or clean | 0.989899 |
| 17 | 111 | (256, 8) | -0.11137 | mentions that the restaurant has changed ownership or management, and describes a decline in food quality, service, or atmosphere compared to previous visits | 0.820241 |
| 16 | 245 | (256, 8) | -0.11557 | describes receiving an incorrect food order or items missing from the order | 0.899556 |
| 14 | 172 | (256, 8) | -0.117763 | mentions issues with restaurant staffing levels or management decisions affecting staff efficiency (e.g., understaffing, overwhelmed servers, need for better staff training or utilization, absent management, or poor planning regarding staff) | 0.652778 |
| 12 | 61 | (256, 8) | -0.121326 | mentions contacting the restaurant or business by phone (e.g., calling to place an order, speak to a manager, or resolve an issue) | 0.707368 |


**Evaluate held-out generalization**

Finally, we evaluate whether these are good hypotheses by testing whether their natural language interpretations can predict the target variable.  

We compute annotations for each hypothesized concept on a holdout set that was not seen during SAE training or feature selection.

After annotation, we output a dataframe with the following columns:
- `hypothesis`: The natural language hypothesis (which came from interpreting a predictive neuron in the SAE)
- `separation_score`: How much the target variable differs when the concept is present vs. absent (i.e., $E[Y\mid\text{concept} = 1] - E[Y\mid\text{concept} = 0]$).
- `separation_pval`: The t-test p-value of the null hypothesis that the separation score is 0 (i.e., the concept is not associated with the target variable).
- `regression_coef`: The coefficient of the concept in a multivariate linear regression of the target variable on all concepts.
- `regression_pval`: The p-value of the null hypothesis that the regression coefficient is 0.
- `feature_prevalence`: The fraction of examples that contain the concept.
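The difference between the univariate `separation_score` and the multivariate `regression_coef` can be seen on synthetic data. The sketch below uses made-up binary annotations and a made-up target, and the helper names are hypothetical; it is not the code behind `evaluate_hypotheses`, and it leaves out the p-values, which need a statistics library.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 500

# Hypothetical binary annotations for two concepts (1 = concept present in the text).
X = rng.integers(0, 2, size=(n, 2)).astype(float)
# Synthetic target: concept 0 raises the rating by 1.0, concept 1 lowers it by 0.5.
y = 3.0 + 1.0 * X[:, 0] - 0.5 * X[:, 1] + rng.normal(scale=0.1, size=n)

def separation_score(concept, target):
    """E[Y | concept = 1] - E[Y | concept = 0] for a single binary concept."""
    return target[concept == 1].mean() - target[concept == 0].mean()

def regression_coefs(annotations, target):
    """Ordinary least squares on all concepts jointly, with an intercept column."""
    design = np.column_stack([np.ones(len(target)), annotations])
    beta, *_ = np.linalg.lstsq(design, target, rcond=None)
    return beta[1:]   # drop the intercept

separation = [separation_score(X[:, j], y) for j in range(X.shape[1])]
coefs = regression_coefs(X, y)
prevalence = X.mean(axis=0)   # fraction of examples containing each concept
```

When concepts co-occur, the two numbers can diverge: a concept may have a large separation score but a small regression coefficient if another concept in the regression already accounts for the same variation in the target.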

Additionally, we output the evaluation metrics used in the paper:
- Significant hypotheses: the number of hypotheses that are significant in the multivariate regression at a specified significance level (default $0.1$) after Bonferroni correction. You can pass in a different significance level using the `corrected_pval_threshold` parameter.
- AUC or $R^2$: how well the hypotheses collectively predict the target variable in the multivariate regression.
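The Bonferroni correction divides the significance level by the number of hypotheses tested, so with 20 hypotheses and the default level of $0.1$, each regression p-value must fall below $0.1 / 20 = 0.005$. A minimal sketch of the count follows; the function name `count_significant` and the p-values are illustrative, not taken from the library.

```python
def count_significant(pvalues, corrected_pval_threshold=0.1):
    """Count p-values below the Bonferroni-corrected per-test threshold."""
    per_test_threshold = corrected_pval_threshold / len(pvalues)
    n_significant = sum(p < per_test_threshold for p in pvalues)
    return n_significant, len(pvalues), per_test_threshold

# Illustrative regression p-values for four hypotheses.
pvals = [6.3e-15, 3.2e-04, 0.17, 0.44]
n_sig, n_total, threshold = count_significant(pvals)   # threshold = 0.1 / 4 = 0.025
```

The correction keeps the chance of any false positive across all hypotheses at or below the chosen level, at the cost of a stricter bar for each individual hypothesis.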


In [12]:
holdout_df = pd.read_json(os.path.join(base_dir, "yelp-demo-holdout-2K.json"), lines=True)
holdout_texts = holdout_df['text'].tolist()
holdout_labels = holdout_df['stars'].values

metrics, evaluation_df = evaluate_hypotheses(
    hypotheses_df=results,
    texts=holdout_texts,
    labels=holdout_labels,
    cache_name=CACHE_NAME,
    annotator_model=ANNOTATOR_MODEL,
    n_workers_annotation=N_WORKERS_ANNOTATION, # Please lower this parameter if you are running into OpenAI API rate limits
)

pd.set_option('display.max_colwidth', None)
display(evaluation_df)
pd.reset_option('display.max_colwidth')

print("\nHoldout Set Metrics:")
print(f"R² Score: {metrics['r2']:.3f}")
print(f"Significant hypotheses: {metrics['Significant'][0]}/{metrics['Significant'][1]} " 
      f"(p < {metrics['Significant'][2]:.3e})")

Step 1: Annotating texts with 20 hypotheses
Found 0 cached items; annotating 40000 uncached items



Step 2: Computing predictiveness of hypothesis annotations


| | hypothesis | separation_score | separation_pval | regression_coef | regression_pval | feature_prevalence |
|---|---|---|---|---|---|---|
| 13 | makes explicit claims that something is the 'best', often using phrases like 'best in the city', 'best po-boy anywhere', 'the Best!', 'the best drinks', or similar superlative language | 1.017268 | 9.815024e-26 | 0.517771 | 6.263042e-15 | 0.1145 |
| 3 | describes repeatedly returning to the restaurant over many years and consistently having positive experiences | 1.006586 | 2.865108e-08 | 0.435373 | 0.00032136 | 0.0305 |
| 2 | mentions a strong intent or desire to return to the restaurant, using phrases like 'definitely coming back', 'will be back', or 'can't wait to come back' | 0.977331 | 6.40044e-32 | 0.467886 | 3.941076e-16 | 0.164 |
| 15 | describes special occasion or event dining experiences, such as weddings, anniversaries, birthdays, or chef's table/rehearsal dinners | 0.327083 | 0.04054698 | 0.146469 | 0.1658999 | 0.04 |
| 18 | mentions cleanliness or dirtiness of the restaurant or its facilities (e.g., floors, bathrooms, tables, employees, trash cans), often specifically describing them as dirty, filthy, or clean | -0.573659 | 1.510775e-05 | -0.311449 | 0.0004254345 | 0.059 |
| 19 | mentions uncomfortable temperatures inside the restaurant, such as being too hot or too cold, affecting the dining experience | -0.930485 | 0.004964482 | -0.550415 | 0.0122454 | 0.009 |
| 12 | mentions contacting the restaurant or business by phone (e.g., calling to place an order, speak to a manager, or resolve an issue) | -1.270704 | 2.876688e-10 | -0.107538 | 0.4425678 | 0.0245 |
| 5 | describes waiting a specific, often long, amount of time to receive food or service, usually mentioning the number of minutes waited | -1.303414 | 2.434631e-25 | -0.253231 | 0.005023272 | 0.065 |
| 8 | describes the food as mediocre, average, or not particularly good, often using terms like 'mediocre', 'so-so', 'not that good', 'OK, not great', or 'average quality' | -1.559029 | 3.176006e-63 | -0.966914 | 1.655417e-39 | 0.1205 |
| 16 | describes receiving an incorrect food order or items missing from the order | -1.695516 | 1.069435e-56 | -0.050427 | 0.6703634 | 0.0885 |



Holdout Set Metrics:
R² Score: 0.570
Significant hypotheses: 13/20 (p < 5.000e-03)
