# HypotheSAEs Quickstart: Local

This notebook demonstrates basic usage of HypotheSAEs on a sample of the Yelp review dataset, using only local compute. You will need a GPU*.
- We use a sentence-transformers model for text embeddings. This runs on either CPU or GPU (it uses the GPU if a CUDA device is available).
- We use a large language model loaded in vLLM for hypothesis generation and text annotation. This requires a GPU.
- This notebook uses `Qwen/Qwen3-32B-AWQ`. (If your GPU doesn't support this model or doesn't have enough memory, you can use a different model.)
- The notebook takes about 80 minutes to run on an NVIDIA A6000. The slowest step is hypothesis evaluation on the holdout set, which requires many LLM annotations.

*Some smaller LLMs may run on CPU, possibly with modifications to the vLLM config, but this is untested.

In [30]:
%load_ext autoreload
%autoreload 2

import os
import sys
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # Set to the index of the GPU you want to use; see visible GPUs with `nvidia-smi` on command line

import numpy as np
import pandas as pd

current_dir = os.getcwd()
if current_dir.endswith("notebooks"):
    parent_dir = os.path.dirname(current_dir)
    sys.path.insert(0, parent_dir)

from hypothesaes.quickstart import train_sae, interpret_sae, generate_hypotheses, evaluate_hypotheses
from hypothesaes.embedding import get_local_embeddings
from hypothesaes.llm_local import get_vllm_engine



**0. Load data**

The dataset is a sample of Yelp reviews: 20K reviews for training, plus a separate set of 2K reviews used for validation (early stopping during SAE training).

The target variable is the `stars` column, which is a rating between 1 and 5. We treat this as a regression task.

There are also 2K reviews used for holdout evaluation, which we'll use at the end of the notebook.

In [2]:
current_dir = os.getcwd()
if current_dir.endswith("notebooks"):
    prefix = "../"
else:
    prefix = "./"

base_dir = os.path.join(prefix, "demo_data")
train_df = pd.read_json(os.path.join(base_dir, "yelp-demo-train-20K.json"), lines=True)
val_df = pd.read_json(os.path.join(base_dir, "yelp-demo-val-2K.json"), lines=True)

texts = train_df['text'].tolist()
labels = train_df['stars'].values
val_texts = val_df['text'].tolist() # These are only used for early stopping of SAE training, so we don't need labels.

**1. Compute text embeddings for your dataset**

We'll compute text embeddings for a training set, and optionally a validation set. The validation embeddings are used for SAE eval and early-stopping during training.

Embeddings will be stored in the `emb_cache` directory (or `os.environ["EMB_CACHE_DIR"]` if you set it) using the `cache_name` parameter, so you only need to compute embeddings once.

We compute embeddings with the sentence-transformers model `nomic-ai/modernbert-embed-base`. Other embedding models, such as `sentence-transformers/all-roberta-large-v1`, should also work well; a smaller model may fit on GPUs with less memory.
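The compute-once caching described above can be illustrated with a minimal sketch. `cached_embeddings` and its pickle-on-disk format are hypothetical and are not HypotheSAEs' actual cache implementation; `embed_fn` stands in for any embedding model. Only texts missing from the cache are embedded, so re-running the cell is cheap once everything is cached.

```python
import hashlib
import os
import pickle
import tempfile

import numpy as np


def cached_embeddings(texts, embed_fn, cache_dir, cache_name):
    """Map each text to its embedding, embedding only texts not already cached on disk.

    Hypothetical helper: `embed_fn` takes a list of strings and returns a 2-D array
    with one row per string.
    """
    # Hash the cache name so names containing "/" (e.g. model IDs) give one flat file.
    fname = hashlib.sha1(cache_name.encode("utf-8")).hexdigest() + ".pkl"
    path = os.path.join(cache_dir, fname)

    cache = {}
    if os.path.exists(path):
        with open(path, "rb") as f:
            cache = pickle.load(f)

    # Deduplicate while preserving order, then keep only uncached texts.
    missing = [t for t in dict.fromkeys(texts) if t not in cache]
    if missing:
        cache.update(zip(missing, embed_fn(missing)))
        os.makedirs(cache_dir, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(cache, f)

    return {t: cache[t] for t in texts}


# Usage with a toy embedder that records how many texts it was asked to embed.
calls = []

def toy_embed(batch):
    calls.append(len(batch))
    return np.array([[len(t), t.count("a")] for t in batch], dtype=float)

with tempfile.TemporaryDirectory() as tmp:
    first = cached_embeddings(["pizza", "pasta", "pizza"], toy_embed, tmp, "demo/model")
    second = cached_embeddings(["pizza", "salad"], toy_embed, tmp, "demo/model")

print(calls)  # [2, 1]: the second call only embeds "salad"
```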

In [None]:
EMBEDDER = "nomic-ai/modernbert-embed-base"
CACHE_NAME = f"yelp_quickstart_local_{EMBEDDER}"

text2embedding = get_local_embeddings(texts + val_texts, model=EMBEDDER, layer_idx=-2, component="hidden_states", batch_size=128, cache_name=CACHE_NAME)

train_embeddings = np.stack([text2embedding[text] for text in texts])
embeddings = train_embeddings  # alias: later cells refer to the training embeddings by this name
val_embeddings = np.stack([text2embedding[text] for text in val_texts])

**2. Train SAE**

We will train a Matryoshka SAE with $M=256$, $k=8$, and $\text{prefix\_lengths} = [32, 256]$.  

With the Matryoshka loss, the SAE learns to reconstruct the input both from (1) just the first 32 neurons and (2) all 256 neurons.  
This produces 32 coarse-grained features and 224 finer-grained features.  

See the README for more details about selecting SAE hyperparameters. 
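To make the roles of $M$, $k$, and the prefix lengths concrete, here is a minimal NumPy sketch of a top-k Matryoshka objective. It assumes random weights in place of trained ones and an unweighted average of the per-prefix reconstruction errors; the library's actual loss weighting, normalization, and any auxiliary terms may differ. It shows that each input activates at most $k$ neurons, and that the first prefix must reconstruct the input using only neurons 0 to 31.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, M, K = 16, 256, 8        # embedding dim, number of neurons, active neurons per input
prefix_lengths = [32, 256]    # Matryoshka prefixes: first 32 neurons, then all 256

# Random weights standing in for a trained SAE.
W_enc = rng.normal(scale=0.1, size=(dim, M))
b_enc = np.zeros(M)
W_dec = rng.normal(scale=0.1, size=(M, dim))
b_dec = np.zeros(dim)

def topk_encode(x, k):
    """Encode x, keeping only the k largest pre-activations per row (ReLU'd); all others are 0."""
    pre = (x - b_dec) @ W_enc + b_enc
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]  # column indices of the k largest entries
    rows = np.arange(pre.shape[0])[:, None]
    z = np.zeros_like(pre)
    z[rows, idx] = np.maximum(pre[rows, idx], 0.0)
    return z

def matryoshka_loss(x, z):
    """Average reconstruction MSE over the prefixes: each prefix must reconstruct x on its own."""
    per_prefix = []
    for m in prefix_lengths:
        x_hat = z[:, :m] @ W_dec[:m] + b_dec       # decode using only the first m neurons
        per_prefix.append(float(np.mean((x - x_hat) ** 2)))
    return float(np.mean(per_prefix)), per_prefix

x = rng.normal(size=(4, dim))
z = topk_encode(x, K)
total, per_prefix = matryoshka_loss(x, z)
print((z != 0).sum(axis=1))  # at most K active neurons per input
print(per_prefix)            # one reconstruction error per prefix length
```

Because the 32-neuron prefix is penalized on its own, the optimizer is pushed to put broadly useful (coarse) directions in those first neurons.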





In [4]:
checkpoint_dir = os.path.join(prefix, "checkpoints", CACHE_NAME)
sae = train_sae(embeddings=train_embeddings, M=256, K=8, matryoshka_prefix_lengths=[32, 256], checkpoint_dir=checkpoint_dir, val_embeddings=val_embeddings)

Loaded model from ../checkpoints/yelp_quickstart_local_nomic-ai/modernbert-embed-base/SAE_matryoshka_M=256_K=8_prefixes=32-256.pt onto device cuda


**3. Interpret neurons**

This step loads a local LLM in vLLM and uses it to interpret a few random neurons in the SAE.  
This is a sanity check that the local LLM works and that the interpretations look reasonable, if not perfect.
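For intuition about what the interpreter LLM is shown, here is a minimal sketch of picking random neurons and their top-activating texts from an activation matrix (rows are texts, columns are neurons), similar to the "Top activating examples" printed by `interpret_sae`. `top_activating_examples` is a hypothetical helper; how HypotheSAEs actually samples examples for its prompt (how many, and whether non-activating texts are included) may differ.

```python
import numpy as np

def top_activating_examples(texts, activations, neuron_idx, n=3):
    """Return up to n (text, activation) pairs with the largest positive activation for one neuron."""
    acts = activations[:, neuron_idx]
    order = np.argsort(-acts, kind="stable")[:n]   # highest activations first
    return [(texts[i], float(acts[i])) for i in order if acts[i] > 0]

demo_texts = ["crispy crust", "slow service", "hot and fresh", "rude host"]
demo_acts = np.array([[0.9, 0.0], [0.0, 0.7], [0.6, 0.0], [0.0, 1.2]])

rng = np.random.default_rng(0)
for j in rng.choice(demo_acts.shape[1], size=2, replace=False):   # random neurons to inspect
    print(j, top_activating_examples(demo_texts, demo_acts, j, n=2))
```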

In [None]:
'''
Some vLLM notes:
- When running in a notebook, it's best to use the same model for interpretation and annotation;
  switching models has some memory overhead (i.e., we can't fully free GPU memory).
- We set `gpu_memory_utilization=0.85`, which avoids OOMs with Qwen3-32B-AWQ on an A6000.
  You may want to adjust this depending on your hardware/model.
- You can increase `tensor_parallel_size` to use multiple GPUs on the same node.
  See vLLM docs: https://docs.vllm.ai/en/latest/serving/distributed_serving.html
'''

INTERPRETER_MODEL = ANNOTATOR_MODEL = "Qwen/Qwen3-32B-AWQ"

engine = get_vllm_engine(
    INTERPRETER_MODEL, 
    gpu_memory_utilization=0.85,
    tensor_parallel_size=1,
)

Loading Qwen/Qwen3-32B-AWQ in vLLM...


Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:02<00:06,  2.22s/it]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:04<00:04,  2.12s/it]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:06<00:02,  2.39s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:09<00:00,  2.53s/it]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:09<00:00,  2.43s/it]

Capturing CUDA graph shapes: 100%|██████████| 67/67 [01:05<00:00,  1.03it/s]


Loaded Qwen/Qwen3-32B-AWQ with dtype: torch.float16 (took 146.8s)


In [6]:
# This instruction will be included in the neuron interpretation prompt.
# The below instructions are specific to Yelp, but you can customize this for your task.
# If you don't pass in task-specific instructions, a generic instruction is used (see src/interpret_neurons.py).
# Task-specific instructions are optional, but they help produce hypotheses at the desired level of specificity.

TASK_SPECIFIC_INSTRUCTIONS = """All of the texts are reviews of restaurants on Yelp.
Features should describe a specific aspect of the review. For example:
- "mentions long wait times to receive service"
- "praises how a dish was cooked, with phrases like 'perfect medium-rare'\""""

# Interpret random neurons
results = interpret_sae(
    texts=texts,
    embeddings=train_embeddings,
    sae=sae,
    interpreter_model=INTERPRETER_MODEL,
    n_random_neurons=5,
    print_examples_n=3,
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS
)


Activations shape: (20000, 256)




Neuron 208 (from SAE M=256, K=8): mentions specific food preparation techniques or textures, such as 'crunchy on the outside with a creamier center,' 'wood-burning oven,' 'hot and fresh,' or 'crispy'

Top activating examples:
1. Nice thin crust crunchy pizza.  Walking the French Quarter needing something to eat and drink and stumbled upon Louisiana Kitchen which was away from the hustle and bustle of the quarter.  We had an Abita amber and IPA  along with a cheese pizza.  The service was prompt and professional from a very friendly server.  The pizza was made in a wood burning oven and arrived quickly, hot and crispy.  After a nice relaxing break we left with our batteries recharged and ready to head back to hustle and bustle.
2. Improved flow with high-tops added and table service all the way. MidiCi does not need to be as complicated as it was at first.   I just got a Margherita pizza with half pepperoni and      it was both light and filling. Flavorful; a hint of char. A Coke.   Th

**4. Generate hypotheses**

Generate hypotheses which are predictive of the target variable.  

We use the local LLM to generate interpretations and then estimate each interpretation's fidelity by annotating texts with the interpretation.

Here, we select neurons after ranking them by correlation with the target variable. See `src/select_neurons.py` for more details and other selection methods (e.g., Lasso).
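A minimal sketch of correlation-based selection, assuming neurons are ranked by the absolute value of their Pearson correlation with the target (the selected neurons in this cell's output include negatively correlated ones, which is consistent with this, but `src/select_neurons.py` is the authority). `select_by_correlation` is a hypothetical name.

```python
import numpy as np

def select_by_correlation(activations, target, n_select):
    """Rank neurons by |Pearson correlation| with the target; return top indices and their correlations."""
    a = activations - activations.mean(axis=0)
    y = target - target.mean()
    denom = np.sqrt((a ** 2).sum(axis=0) * (y ** 2).sum())
    # Neurons that never vary (e.g. never fire) get correlation 0 instead of NaN.
    corr = np.divide(a.T @ y, denom, out=np.zeros(a.shape[1]), where=denom > 0)
    top = np.argsort(-np.abs(corr))[:n_select]
    return top, corr[top]

rng = np.random.default_rng(0)
demo_y = rng.normal(size=500)
demo_acts = np.column_stack([
    np.maximum(demo_y, 0),    # fires on high targets: positive correlation
    rng.normal(size=500),     # unrelated noise
    np.maximum(-demo_y, 0),   # fires on low targets: negative correlation
    np.zeros(500),            # dead neuron
])
idx, corr = select_by_correlation(demo_acts, demo_y, n_select=2)
print(idx, corr)
```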

This cell outputs a dataframe with the following columns:
- `neuron_idx`: The index of the neuron in the SAE (if you're using multiple SAEs, this will be a global index across all of them).
- `source_sae`: The SAE that the neuron was selected from.
- `target_{selection_method}`: The predictiveness of the neuron for the target variable, using the selected `selection_method`.
- `interpretation`: The natural language interpretation of the neuron.
- `f1_fidelity_score`: The F1 fidelity score, measuring how well the neuron's interpretation corresponds to its activation pattern.
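One way to read `f1_fidelity_score`: label a sample of texts as positive (the text activates the neuron) or negative, have the annotator LLM judge whether the interpretation applies to each text, and compute F1 between the two sets of labels. This is a sketch under that assumption; how HypotheSAEs samples positives and negatives is defined in its source, and `f1_fidelity` is a hypothetical helper.

```python
def f1_fidelity(neuron_labels, llm_labels):
    """F1 of LLM annotations (1 = interpretation applies) against neuron labels (1 = text activates the neuron)."""
    pairs = list(zip(neuron_labels, llm_labels))
    tp = sum(1 for n, a in pairs if n == 1 and a == 1)
    fp = sum(1 for n, a in pairs if n == 0 and a == 1)
    fn = sum(1 for n, a in pairs if n == 1 and a == 0)
    if tp == 0:
        return 0.0  # no true positives: precision or recall is 0 (or undefined)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# 4 activating and 4 non-activating texts; the LLM misses one positive and flags one negative.
print(f1_fidelity([1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 0, 1, 0, 0, 0]))  # 0.75
```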

In [7]:
selection_method = "correlation"
results = generate_hypotheses(
    texts=texts,
    labels=labels,
    embeddings=train_embeddings,
    sae=sae,
    interpreter_model=INTERPRETER_MODEL,
    annotator_model=ANNOTATOR_MODEL,
    selection_method=selection_method,
    n_selected_neurons=20,
    n_candidate_interpretations=1,
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS
)

print("\nMost predictive features of Yelp reviews:")
pd.set_option('display.max_colwidth', None)
display(results.sort_values(by=f"target_{selection_method}", ascending=False))
pd.reset_option('display.max_colwidth')

Embeddings shape: (20000, 768)



Activations shape: (20000, 256)

Step 1: Selecting top 20 predictive neurons

Step 2: Interpreting selected neurons




Step 3: Scoring Interpretations
Found 0 cached items; annotating 2000 uncached items




Most predictive features of Yelp reviews:


| | neuron_idx | source_sae | target_correlation | interpretation | f1_fidelity_score |
|---:|---:|:---|---:|:---|---:|
| 5 | 15 | (256, 8) | 0.200716 | uses hyperbolic superlatives (e.g., 'BEYOND AMAZING', 'PERFECT', 'FABULOUS') paired with multiple exclamation points to express extreme satisfaction | 0.795181 |
| 6 | 254 | (256, 8) | 0.163757 | expresses that the food is the best they've ever had, using superlatives like 'best ever,' 'epic,' 'unbelievable,' or 'out of this world' to describe the food quality | 0.582154 |
| 11 | 34 | (256, 8) | 0.122538 | mentions specific dishes and praises their preparation, taste, or texture with detailed descriptions | 0.444255 |
| 15 | 227 | (256, 8) | 0.094623 | mentions exceptional service quality and attention to detail, using phrases like 'outstanding service,' 'professional service,' or 'culture of hospitality' | 0.427119 |
| 16 | 2 | (256, 8) | 0.093676 | mentions friendly and attentive staff who check on customers or ensure they are taken care of | 0.630303 |
| 17 | 214 | (256, 8) | 0.090071 | mentions frequent visits and strong cravings for the food, using phrases like 'at least once a week,' 'crave this place,' or 'every time I visit' | 0.499375 |
| 18 | 27 | (256, 8) | 0.08998 | mentions that the restaurant is considered the best in a specific geographic area or category and compares it to other restaurants | 0.872727 |
| 19 | 61 | (256, 8) | -0.089931 | mentions problems with online ordering or delivery services, including incorrect orders, lack of communication, or poor customer service resolution | 0.802439 |
| 14 | 127 | (256, 8) | -0.099267 | mentions receiving food that was different from what was ordered, including incorrect items, wrong preparation, or insufficient portions, or the restaurant refusing to accommodate specific dietary or preparation requests | 0.779747 |
| 13 | 58 | (256, 8) | -0.099812 | mentions being denied seating despite available tables or encountering unfair seating policies and uncomfortable seating arrangements | 0.888889 |


**5. Evaluate held-out generalization**

Finally, we evaluate whether these are good hypotheses by testing whether their natural language interpretations can predict the target variable.  

We compute annotations for each hypothesized concept on a holdout set (not seen during SAE training & feature selection).  
This step uses the same local LLM as the one used to generate interpretations (though, for Qwen models, we set `enable_thinking=False` for annotation to improve speed).

After annotation, we output a dataframe with the following columns:
- `hypothesis`: The natural language hypothesis (which came from interpreting a predictive neuron in the SAE)
- `separation_score`: How much the target variable differs when the concept is present vs. absent (i.e., $E[Y\mid\text{concept} = 1] - E[Y\mid\text{concept} = 0]$).
- `separation_pval`: The t-test p-value for the null hypothesis that the separation score is 0 (i.e., the concept is not associated with the target variable).
- `regression_coef`: The coefficient of the concept in a multivariate linear regression of the target variable on all concepts.
- `regression_pval`: The p-value of the null hypothesis that the regression coefficient is 0.
- `feature_prevalence`: The fraction of examples that contain the concept.
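The per-hypothesis columns can be sketched with NumPy for a single binary concept: `separation_score` is a difference in group means and `feature_prevalence` is the fraction of texts annotated 1. As an assumption, this sketch uses a Welch (unequal-variance) standard error and replaces the exact t-test p-value with a normal approximation via `math.erfc`, which is close only when both groups are large (as with the 2K holdout set, not the tiny demo). `separation_stats` is a hypothetical helper.

```python
import math

import numpy as np

def separation_stats(annotations, y):
    """Return (separation_score, approximate p-value, prevalence) for one binary concept."""
    present = np.asarray(annotations, dtype=bool)
    y1, y0 = y[present], y[~present]
    score = y1.mean() - y0.mean()  # E[Y | concept = 1] - E[Y | concept = 0]
    # Welch (unequal-variance) standard error of the difference in means.
    se = math.sqrt(y1.var(ddof=1) / len(y1) + y0.var(ddof=1) / len(y0))
    t = score / se
    pval = math.erfc(abs(t) / math.sqrt(2))  # two-sided p-value under a normal approximation
    return float(score), pval, float(present.mean())

# Tiny demo: 3 of 8 reviews mention the concept, and they have higher star ratings.
demo_annot = np.array([1, 1, 1, 0, 0, 0, 0, 0])
demo_stars = np.array([5.0, 4.0, 5.0, 2.0, 3.0, 1.0, 2.0, 3.0])
print(separation_stats(demo_annot, demo_stars))
```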

Additionally, we output the evaluation metrics used in the paper:
- Significant hypotheses: the number of hypotheses that are significant in the multivariate regression at a specified significance level (default $0.1$) after Bonferroni correction. You can pass in a different significance level using the `corrected_pval_threshold` parameter.
- AUC or $R^2$: how well the hypotheses collectively predict the target variable in the multivariate regression (AUC for binary targets, $R^2$ for continuous targets such as star ratings).
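A sketch of the two summary metrics, assuming ordinary least squares with an intercept. The Bonferroni threshold is the significance level divided by the number of hypotheses (0.1 / 20 = 0.005, matching the `p < 5.000e-03` printed at the end of this notebook), and $R^2$ is the fraction of target variance the concepts jointly explain. Per-coefficient p-values need standard errors, which a library such as statsmodels provides; they are omitted here. `multivariate_eval` is a hypothetical helper, not the library's API.

```python
import numpy as np

def multivariate_eval(concepts, y, alpha=0.1):
    """OLS of y on all binary concept columns plus an intercept.

    Returns (coefficients without intercept, R², Bonferroni-corrected threshold).
    """
    X = np.column_stack([np.ones(len(y)), concepts])
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    resid = y - X @ beta
    centered = y - y.mean()
    r2 = 1.0 - (resid @ resid) / (centered @ centered)
    threshold = alpha / concepts.shape[1]  # Bonferroni: divide alpha by the number of hypotheses
    return beta[1:], float(r2), threshold

# Noise-free toy target: concept 0 adds one star, concept 1 removes half a star.
rng = np.random.default_rng(0)
demo_concepts = rng.integers(0, 2, size=(200, 2)).astype(float)
demo_target = 3.0 + 1.0 * demo_concepts[:, 0] - 0.5 * demo_concepts[:, 1]
coefs, r2, threshold = multivariate_eval(demo_concepts, demo_target)
print(np.round(coefs, 6), round(r2, 6), threshold)
```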


In [8]:
holdout_df = pd.read_json(os.path.join(base_dir, "yelp-demo-holdout-2K.json"), lines=True)
holdout_texts = holdout_df['text'].tolist()
holdout_labels = holdout_df['stars'].values

metrics, evaluation_df = evaluate_hypotheses(
    hypotheses_df=results,
    texts=holdout_texts,
    labels=holdout_labels,
    annotator_model=ANNOTATOR_MODEL,
)

pd.set_option('display.max_colwidth', None)
display(evaluation_df)
pd.reset_option('display.max_colwidth')

print("\nHoldout Set Metrics:")
print(f"R² Score: {metrics['r2']:.3f}")
print(f"Significant hypotheses: {metrics['Significant'][0]}/{metrics['Significant'][1]} " 
      f"(p < {metrics['Significant'][2]:.3e})")

Step 1: Annotating texts with 20 hypotheses
Found 0 cached items; annotating 40000 uncached items



Step 2: Computing predictiveness of hypothesis annotations


| | hypothesis | separation_score | separation_pval | regression_coef | regression_pval | feature_prevalence |
|---:|:---|---:|---:|---:|---:|---:|
| 6 | expresses that the food is the best they've ever had, using superlatives like 'best ever,' 'epic,' 'unbelievable,' or 'out of this world' to describe the food quality | 1.100535 | 1.972704e-34 | 0.307047 | 8.053908e-07 | 0.135 |
| 11 | mentions specific dishes and praises their preparation, taste, or texture with detailed descriptions | 1.086586 | 6.844723e-73 | 0.420917 | 5.429902e-24 | 0.4865 |
| 5 | uses hyperbolic superlatives (e.g., 'BEYOND AMAZING', 'PERFECT', 'FABULOUS') paired with multiple exclamation points to express extreme satisfaction | 1.058716 | 7.385799e-16 | 0.201603 | 0.01638637 | 0.0595 |
| 15 | mentions exceptional service quality and attention to detail, using phrases like 'outstanding service,' 'professional service,' or 'culture of hospitality' | 0.986522 | 2.255149e-23 | 0.266751 | 7.200962e-05 | 0.1095 |
| 16 | mentions friendly and attentive staff who check on customers or ensure they are taken care of | 0.891304 | 7.406501e-37 | 0.222345 | 5.317346e-06 | 0.2565 |
| 17 | mentions frequent visits and strong cravings for the food, using phrases like 'at least once a week,' 'crave this place,' or 'every time I visit' | 0.631096 | 2.969104e-07 | 0.235836 | 0.002015711 | 0.069 |
| 18 | mentions that the restaurant is considered the best in a specific geographic area or category and compares it to other restaurants | 0.6125 | 1.665475e-10 | 0.086565 | 0.1632675 | 0.12 |
| 12 | mentions automatic gratuity or hidden fees added to the bill without prior notice | -1.048585 | 0.007066972 | 0.002354 | 0.9921679 | 0.0065 |
| 2 | mentions long wait times to receive service or food | -1.104529 | 1.829027e-34 | 0.089186 | 0.1712972 | 0.134 |
| 9 | mentions specific criticisms of food quality, such as dishes being bland, under-seasoned, or lacking flavor | -1.770692 | 1.791816e-108 | -0.576484 | 5.839361e-12 | 0.162 |



Holdout Set Metrics:
R² Score: 0.635
Significant hypotheses: 11/20 (p < 5.000e-03)
