In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import torch

# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Replace '...' with your OpenAI API key, or leave it untouched to reuse a key that is already
# present in the environment (OPENAI_KEY_SAE, falling back to OPENAI_API_KEY).
openai_key = '...'
if openai_key != '...':
    # An explicitly pasted key always wins.
    os.environ['OPENAI_KEY_SAE'] = openai_key
else:
    # Never clobber an already-configured key with the placeholder when this cell is (re-)run.
    os.environ.setdefault('OPENAI_KEY_SAE', os.environ.get('OPENAI_API_KEY', '...'))

from hypothesaes.embedding import get_openai_embeddings
from hypothesaes.quickstart import train_sae
from hypothesaes.interpret_neurons import NeuronInterpreter, SamplingConfig, LLMConfig, InterpretConfig, ScoringConfig
from hypothesaes.annotate import annotate_texts_with_concepts
from hypothesaes.evaluation import score_hypotheses
from hypothesaes.select_neurons import select_neurons

Using device: cuda


In [None]:
# Names of the review-text column and the prediction-target column in the demo files.
text_col = 'text'
target_col = 'stars'

# Read the training and validation splits (JSON-lines files).
train_df = pd.read_json('../demo_data/yelp-demo-train-20K.json', lines=True)
val_df = pd.read_json('../demo_data/yelp-demo-val-2K.json', lines=True)

train_texts = train_df[text_col].tolist()
train_labels = train_df[target_col].values
val_texts = val_df[text_col].tolist()

# Embed both splits with a single call; the result is indexed by the text itself.
EMBEDDER = "text-embedding-3-small"
CACHE_NAME = f"yelp_quickstart_{EMBEDDER}"
text2embedding = get_openai_embeddings(
    cache_name=CACHE_NAME,
    model=EMBEDDER,
    texts=train_texts + val_texts,
)

# Look each split's embeddings back up in the original text order.
train_embeddings = np.array([text2embedding[review] for review in train_texts])
val_embeddings = np.array([text2embedding[review] for review in val_texts])

Loading embedding chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Loaded 22000 embeddings in 1.4s


In [None]:
# Train a sparse autoencoder (SAE) with Matryoshka prefix lengths on the training embeddings.

# SAE parameters (passed straight through to train_sae).
M = 256
K = 8
prefix_lengths = [32, 256]
checkpoint_dir = f'./checkpoints/{CACHE_NAME}'

# Optional train_sae keyword arguments that may be added to the call below:
#   aux_k=None                       number of neurons for dead neuron revival (None=default)
#   multi_k=None                     number of neurons for secondary reconstruction
#   dead_neuron_threshold_steps=256  non-firing steps after which a neuron is considered dead
#   batch_size=512
#   learning_rate=5e-4
#   aux_coef=1/32                    coefficient for auxiliary loss
#   multi_coef=0.0                   coefficient for multi-k loss
#   patience=3                       early stopping patience
#   clip_grad=1.0                    gradient clipping value
model = train_sae(
    embeddings=train_embeddings,
    val_embeddings=val_embeddings,
    M=M,
    K=K,
    matryoshka_prefix_lengths=prefix_lengths,
    n_epochs=100,
    checkpoint_dir=checkpoint_dir,
)

# Run the trained model over the training embeddings to obtain per-neuron activations.
train_activations = model.get_activations(train_embeddings)
print(f"Neuron activations shape: {train_activations.shape}")

  0%|          | 0/100 [00:00<?, ?it/s]

Early stopping triggered after 60 epochs
Saved model to ./checkpoints/yelp_quickstart_text-embedding-3-small/SAE_matryoshka_M=256_K=8_prefixes=32-256.pt


Computing activations (batchsize=16384):   0%|          | 0/2 [00:00<?, ?it/s]

Neuron activations shape: (20000, 256)


In [5]:
# Pick the neurons to interpret, based on their activations and the target labels.
# Available methods: "lasso", "separation_score", or "correlation".
selection_method = "correlation"
top_neuron_count = 20

# Further optional arguments depend on the selection method; see select_neurons.py.
selected_neurons, scores = select_neurons(
    target=train_labels,
    activations=train_activations,
    method=selection_method,
    n_select=top_neuron_count,
)

In [6]:
# Dataset-specific guidance passed to the interpretation config below.
TASK_SPECIFIC_INSTRUCTIONS = """All of the texts are reviews of restaurants on Yelp.
Features should describe a specific aspect of the review. For example:
- "mentions long wait times to receive service"
- "praises how a dish was cooked, with phrases like 'perfect medium-rare'\""""

# Interpretation settings: example sampling, LLM generation options, and the number of
# candidate interpretations requested per neuron.
interpret_config = InterpretConfig(
    sampling=SamplingConfig(n_examples=20, max_words_per_example=128),
    llm=LLMConfig(temperature=0.7, max_interpretation_tokens=75),
    n_candidates=3,
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS,
)

# One model writes the interpretations, another annotates; each has its own worker count.
interpreter = NeuronInterpreter(
    interpreter_model="gpt-4.1",
    n_workers_interpretation=10,
    annotator_model="gpt-4.1-mini",
    n_workers_annotation=50,
    cache_name=CACHE_NAME,
)

# Generate candidate interpretations for every selected neuron.
interpretations = interpreter.interpret_neurons(
    neuron_indices=selected_neurons,
    activations=train_activations,
    texts=train_texts,
    config=interpret_config,
)

Generating interpretations:   0%|          | 0/60 [00:00<?, ?it/s]

In [7]:
# Settings for scoring each candidate interpretation against the neuron's activations.
scoring_config = ScoringConfig(max_words_per_example=128, n_examples=200)

# Score every candidate interpretation; the result is indexed by neuron, then by interpretation.
all_metrics = interpreter.score_interpretations(
    interpretations=interpretations,
    activations=train_activations,
    texts=train_texts,
    config=scoring_config,
)

Found 0 cached items; annotating 12000 uncached items


Scoring neuron interpretation fidelity (20 neurons; 3 candidate interps per neuron; 200 examples to score each…

In [9]:
# Build a table with, for each selected neuron, its best and worst candidate interpretation
# (ranked by F1) alongside the neuron's selection score.

# The score column is named after the selection method, e.g. 'target_correlation'.
score_col = f'target_{selection_method}'

interpretations_data = []
# scores is positionally aligned with selected_neurons, so walk the two in lockstep instead of
# searching for each neuron's position with .index().
for neuron_idx, neuron_score in zip(selected_neurons, scores):
    # Mapping of interpretation text -> metrics dict for this neuron.
    neuron_metrics = all_metrics[neuron_idx]
    best_interp, best_metrics = max(neuron_metrics.items(), key=lambda item: item[1]['f1'])
    worst_interp, worst_metrics = min(neuron_metrics.items(), key=lambda item: item[1]['f1'])

    interpretations_data.append({
        'neuron_idx': neuron_idx,
        score_col: neuron_score,
        'best_interpretation': best_interp,
        'best_f1': best_metrics['f1'],
        'worst_interpretation': worst_interp,
        'worst_f1': worst_metrics['f1'],
    })

# Most positively associated neurons first.
best_interp_df = pd.DataFrame(interpretations_data).sort_values(by=score_col, ascending=False)

display(
    best_interp_df.style.format({
        # Key the format on the actual score column; a key that matches no column is silently
        # ignored by Styler.format, leaving the scores unrounded.
        score_col: '{:.2f}',
        'best_f1': '{:.2f}',
        'worst_f1': '{:.2f}',
    })
)

Unnamed: 0,neuron_idx,target_correlation,best_interpretation,best_f1,worst_interpretation,worst_f1
1,14,0.293127,"explicitly praises the server or staff by name, often using enthusiastic language",0.57,"mentions by name or gives special praise to specific staff members (e.g., Kari, Matt, Oliver, James, Ms. Sherry, Tiana, Tristan) for their outstanding service",0.46
2,22,0.291691,"expresses strong personal attachment or loyalty to the restaurant, using phrases like 'my favorite', 'I love this place', or describing frequent, long-term patronage",0.86,"describes the restaurant as a personal or family favorite, emphasizing repeated visits and long-term patronage",0.72
14,140,0.118263,"describes the restaurant or its food as 'the best' or 'best in [area/world/city]', often using superlative phrases like 'the best pizza', 'best pho', 'best restaurant experience', or similar",0.87,"explicitly claims the restaurant or its food is the 'best' in a given area, city, or even the world, using phrases like 'best pizza in town', 'best pho in the area', 'best Italian food', or similar superlative statements",0.8
18,18,0.095312,"describes multi-course or tasting menus (e.g., chef's tasting menu, four or five course meals, blind tastings)",0.5,mentions multi-course tasting menus or chef’s tasting experiences,0.41
19,176,-0.089604,"describes the food as average, mediocre, or nothing special, using words like 'average', 'ok', 'just okay', or 'nothing special'",0.88,"describes the food as average or just okay, using phrases like 'average', 'okay', 'nothing special', 'nothing stood out', or 'just ok'",0.84
17,169,-0.095762,"complains about slow or delayed service, specifically using words like 'slow', 'slowest', or 'takes a long time' to describe wait times for food, drinks, or drive-through orders",0.91,complains about slow service or long wait times to receive food or drinks,0.89
16,186,-0.110237,"mentions that the food or the restaurant environment is uncomfortably cold (e.g., food served cold, room temperature too low, complaints about being cold)",0.81,"mentions food or restaurant being cold, uncomfortably cold, or served at the wrong (cold) temperature",0.69
15,27,-0.110907,"mentions a change in restaurant ownership, management, or chef and compares the food or experience before and after the change",0.74,"mentions a change in ownership, management, or chef and describes a difference in food quality or experience after the change",0.65
13,122,-0.121827,"describes being seated at a table but not being acknowledged or served by any staff for an extended period, leading to leaving the restaurant without receiving service",0.44,"describes being seated at a table in a restaurant but not being approached or acknowledged by any server or staff for an extended period, leading to leaving without receiving service",0.41
12,236,-0.129189,"describes experiencing food poisoning or illness (e.g. nausea, vomiting, diarrhea, stomach pain) after eating at the restaurant",0.79,"describes experiencing foodborne illness symptoms such as nausea, vomiting, diarrhea, or stomach pain shortly after eating at the restaurant",0.72


In [None]:
# Evaluate the best interpretations on the full holdout split (no subsampling is performed).
# The seed is kept in case downstream helpers draw from NumPy's global RNG; nothing visible in
# this cell does.
np.random.seed(42)
holdout_df = pd.read_json('../demo_data/yelp-demo-holdout-2K.json', lines=True)
holdout_texts = holdout_df[text_col].tolist()
holdout_labels = holdout_df[target_col].values

# Annotate every holdout text with each neuron's best interpretation.
holdout_annotations = annotate_texts_with_concepts(
    texts=holdout_texts,
    concepts=best_interp_df['best_interpretation'].tolist(),
    max_words_per_example=128,
    cache_name=CACHE_NAME,
    n_workers=50,
)

# Score the annotated hypotheses against the holdout labels (regression, not classification).
metrics, hypothesis_df = score_hypotheses(
    hypothesis_annotations=holdout_annotations,
    y_true=holdout_labels,
    classification=False,
)

# Show full hypothesis text for this display only; option_context restores the caller's previous
# column-width setting afterwards, even if display raises.
with pd.option_context('display.max_colwidth', None):
    display(hypothesis_df.round(3))

print("\nHoldout Set Metrics:")
print(f"R² Score: {metrics['r2']:.3f}")
# metrics['Significant'] is indexed as (number significant, number tested, p-value threshold).
print(f"Significant hypotheses: {metrics['Significant'][0]}/{metrics['Significant'][1]} "
      f"(p < {metrics['Significant'][2]:.3e})")

Found 0 cached items; annotating 40000 uncached items


Annotating:   0%|          | 0/40000 [00:00<?, ?it/s]

Unnamed: 0,hypothesis,separation_score,separation_pval,regression_coef,regression_pval,feature_prevalence
0,"explicitly praises the server or staff by name, often using enthusiastic language",1.096,0.0,0.441,0.0,0.068
1,"expresses strong personal attachment or loyalty to the restaurant, using phrases like 'my favorite', 'I love this place', or describing frequent, long-term patronage",1.034,0.0,0.342,0.0,0.17
2,"describes the restaurant or its food as 'the best' or 'best in [area/world/city]', often using superlative phrases like 'the best pizza', 'best pho', 'best restaurant experience', or similar",1.028,0.0,0.365,0.0,0.099
3,"describes multi-course or tasting menus (e.g., chef's tasting menu, four or five course meals, blind tastings)",0.787,0.003,0.299,0.064,0.014
7,"mentions a change in restaurant ownership, management, or chef and compares the food or experience before and after the change",-0.273,0.253,-0.206,0.162,0.018
4,"describes the food as average, mediocre, or nothing special, using words like 'average', 'ok', 'just okay', or 'nothing special'",-1.18,0.0,0.136,0.18,0.08
10,"describes the food, service, or ambiance as average, mediocre, or just okay, often using phrases like 'you get what you pay for', 'nothing's wonderful', 'not a fan', or 'not bad but not great'",-1.273,0.0,-0.002,0.984,0.108
5,"complains about slow or delayed service, specifically using words like 'slow', 'slowest', or 'takes a long time' to describe wait times for food, drinks, or drive-through orders",-1.477,0.0,0.479,0.0,0.045
16,"describes long wait times to be seated or to receive food/service, often specifying durations such as 30 minutes or more",-1.504,0.0,-0.498,0.0,0.036
6,"mentions that the food or the restaurant environment is uncomfortably cold (e.g., food served cold, room temperature too low, complaints about being cold)",-1.831,0.0,-0.444,0.0,0.03



Holdout Set Metrics:
R² Score: 0.629
Significant hypotheses: 14/20 (p < 5.000e-03)
