In [7]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

os.environ['OPENAI_KEY_SAE'] = '...' # Replace with your OpenAI API key, or with another environment variable (e.g. os.environ['OPENAI_API_KEY'])

from hypothesaes.embedding import get_openai_embeddings
from hypothesaes.sae import SparseAutoencoder, load_model
from hypothesaes.interpret_neurons import NeuronInterpreter, SamplingConfig, LLMConfig, InterpretConfig, ScoringConfig
from hypothesaes.annotate import annotate_texts_with_concepts
from hypothesaes.evaluation import score_hypotheses
from hypothesaes.select_neurons import select_neurons

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using device: cuda


In [4]:
# Load the train/validation splits (JSON Lines files, read with lines=True).
train_df = pd.read_json('./demo-data/demo-train-20K.json', lines=True)
val_df = pd.read_json('./demo-data/demo-val-2K.json', lines=True)

text_col = 'text'     # column holding the review text
target_col = 'stars'  # column holding the prediction target

train_texts = train_df[text_col].tolist()
val_texts = val_df[text_col].tolist()

# Embed both splits in a single call; embeddings are cached on disk under CACHE_NAME,
# so re-running this cell reuses them instead of calling the API again.
EMBEDDER = "text-embedding-3-small"
CACHE_NAME = f"yelp_demo_{EMBEDDER}"
text2embedding = get_openai_embeddings(
    texts=train_texts + val_texts,
    model=EMBEDDER,
    cache_name=CACHE_NAME,
)

# text2embedding is keyed by text; rebuild each split's matrix in its original row order.
train_embeddings = np.array([text2embedding[text] for text in train_texts])
val_embeddings = np.array([text2embedding[text] for text in val_texts])

# Sanity check: a 2-D matrix per split, with the same embedding width in both.
assert train_embeddings.ndim == 2, f"unexpected train embedding shape {train_embeddings.shape}"
assert val_embeddings.shape[1:] == train_embeddings.shape[1:], "train/val embedding widths differ"

Processing chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Chunk 0:   0%|          | 0/86 [00:00<?, ?it/s]

Saved 22000 embeddings to ~/saetools-dev/emb_cache/yelp_demo_text-embedding-3-small/chunk_000.npy


In [27]:
# Train (or load cached checkpoints of) SAEs with different sizes/sparsity, then place
# their training-set neuron activations side by side in one matrix.
X_train = torch.tensor(train_embeddings, dtype=torch.float32).to(device)
X_val = torch.tensor(val_embeddings, dtype=torch.float32).to(device)

# M -> m_total_neurons, K -> k_active_neurons for each SAE.
sae_params = [
    {"M": 256, "K": 8},
    {"M": 32, "K": 4},
]

# All configurations share one checkpoint directory (loop-invariant, so computed once).
save_dir = f'./checkpoints/{CACHE_NAME}'

models = []
activations_list = []
# neuron_source_info[j] is the (M, K) of the SAE that owns column j of train_activations.
neuron_source_info = []

for params in sae_params:
    M, K = params["M"], params["K"]
    save_path = f'{save_dir}/SAE_M={M}_K={K}.pt'

    # Initialize and train (or load) the SAE model.
    # NOTE(review): the reload path assumes fit() writes its checkpoint to save_path — confirm.
    if os.path.exists(save_path):
        print(f"Loading existing model: M={M}, K={K}")
        model = load_model(save_path).to(device)
    else:
        print(f"Training new model: M={M}, K={K}")
        model = SparseAutoencoder(
            input_dim=X_train.shape[1],
            m_total_neurons=M,
            k_active_neurons=K,
            # Optional parameters:
            # aux_k=None,  # Number of neurons for dead neuron revival (None=default)
            # multi_k=None,  # Number of neurons for secondary reconstruction
            # dead_neuron_threshold_steps=256,  # Number of non-firing steps after which a neuron is considered dead
        ).to(device)

        model.fit(
            X_train=X_train,
            X_val=X_val,
            n_epochs=100,
            save_dir=save_dir,
            # Optional parameters:
            # batch_size=512,
            # learning_rate=5e-4,
            # aux_coef=1/32,  # Coefficient for auxiliary loss
            # multi_coef=0.0,  # Coefficient for multi-k loss
            # patience=3,     # Early stopping patience
            # clip_grad=1.0,  # Gradient clipping value
        )

    models.append(model)

    # Activations of this model's neurons on the training set.
    model_activations = model.get_activations(X_train)
    activations_list.append(model_activations)

    # Each of this SAE's M neurons gets the same (immutable) source tag.
    neuron_source_info.extend([(M, K)] * M)

# Concatenate activations from all models along the neuron axis.
train_activations = np.concatenate(activations_list, axis=1)
assert train_activations.shape[1] == len(neuron_source_info), "neuron bookkeeping out of sync"
print(f"Neuron activations shape (from {len(sae_params)} models): {train_activations.shape}")

Loading existing model: M=256, K=8
Loaded model from ./checkpoints/yelp_demo_text-embedding-3-small/SAE_M=256_K=8.pt
Loading existing model: M=32, K=4
Loaded model from ./checkpoints/yelp_demo_text-embedding-3-small/SAE_M=32_K=4.pt
Neuron activations shape (from 2 models): (20000, 288)


In [28]:
# Rank all concatenated SAE neurons against the target and keep the strongest ones.
# Available methods: "lasso", "separation_score", or "correlation".
selection_method = "correlation"
top_neuron_count = 20

train_targets = train_df[target_col].values
selected_neurons, scores = select_neurons(
    activations=train_activations,
    target=train_targets,
    n_select=top_neuron_count,
    method=selection_method,
    # Extra keyword arguments vary by method; see select_neurons.py
)

In [29]:
# Domain guidance passed to the interpreter LLM, with example feature phrasings.
TASK_SPECIFIC_INSTRUCTIONS = """All of the texts are reviews of restaurants on Yelp.
Features should describe a specific aspect of the review. For example:
- "mentions long wait times to receive service"
- "praises how a dish was cooked, with phrases like 'perfect medium-rare'\""""

# gpt-4o proposes interpretations; gpt-4o-mini is configured as the annotator.
interpreter = NeuronInterpreter(
    interpreter_model="gpt-4o",
    annotator_model="gpt-4o-mini",
    n_workers_interpretation=10,
    n_workers_annotation=50,
    cache_name=CACHE_NAME,
)

# Build the sub-configs first, then bundle them into the interpretation config.
sampling_config = SamplingConfig(n_examples=20, max_words_per_example=128)
llm_config = LLMConfig(temperature=0.7, max_interpretation_tokens=75)
interpret_config = InterpretConfig(
    sampling=sampling_config,
    llm=llm_config,
    n_candidates=3,  # candidate interpretations generated per neuron
    task_specific_instructions=TASK_SPECIFIC_INSTRUCTIONS,
)

interpretations = interpreter.interpret_neurons(
    texts=train_texts,
    activations=train_activations,
    neuron_indices=selected_neurons,
    config=interpret_config,
)

Generating 3 interpretation(s) per neuron:   0%|          | 0/60 [00:00<?, ?it/s]

In [30]:
# Score the fidelity of every candidate interpretation on up to 200 training examples per neuron.
scoring_config = ScoringConfig(n_examples=200, max_words_per_example=128)

all_metrics = interpreter.score_interpretations(
    texts=train_texts,
    activations=train_activations,
    interpretations=interpretations,
    config=scoring_config,
)

Found 800 cached items; annotating 11200 uncached items


Scoring neuron interpretation fidelity (20 neurons; 3 candidate interps per neuron; 200 examples to score each…

In [31]:
# For each selected neuron, keep its highest- and lowest-F1 candidate interpretation
# alongside the neuron's selection score.
interpretations_data = []
# scores is positionally aligned with selected_neurons, so pair them directly
# (avoids the O(n^2) selected_neurons.index(...) lookup).
for neuron_idx, neuron_score in zip(selected_neurons, scores):
    neuron_metrics = all_metrics[neuron_idx]  # {interpretation: {'f1': ..., ...}}
    best_interp, best_metrics = max(neuron_metrics.items(), key=lambda item: item[1]['f1'])
    worst_interp, worst_metrics = min(neuron_metrics.items(), key=lambda item: item[1]['f1'])

    interpretations_data.append({
        'neuron_idx': neuron_idx,
        'source_sae': neuron_source_info[neuron_idx],
        selection_method: neuron_score,
        'best_interpretation': best_interp,
        'best_f1': best_metrics['f1'],
        'worst_interpretation': worst_interp,
        'worst_f1': worst_metrics['f1'],
    })

best_interp_df = pd.DataFrame(interpretations_data).sort_values(by=selection_method, ascending=False)

# Format the score column under its actual name (it is named after selection_method,
# not always 'separation_score'), plus both F1 columns.
display(
    best_interp_df.style.format({
        selection_method: '{:.2f}',
        'best_f1': '{:.2f}',
        'worst_f1': '{:.2f}',
    })
)

Unnamed: 0,neuron_idx,source_sae,correlation,best_interpretation,best_f1,worst_interpretation,worst_f1
4,286,"(32, 4)",0.349065,emphasizes the friendliness and warmth of the staff using enthusiastic language,0.79,"mentions friendly and amazing staff with words like 'friendly', 'amazing', 'wonderful', or 'smile'",0.58
13,19,"(256, 8)",0.208405,mentions specific positive interactions with named staff members,0.83,mentions specific staff members by name along with praise for their service,0.8
16,235,"(256, 8)",0.198075,"expresses personal love or favorite status for the restaurant, using phrases like 'I love this place', 'my favorite restaurant', or 'go-to spot'",0.83,"expresses personal attachment or emotional connection to the restaurant, using phrases like 'my favorite', 'love this place', or 'go-to spot'",0.77
17,258,"(32, 4)",0.193979,explicitly mentions a desire or intention to return to the restaurant,0.88,mentions intent or desire to return to the restaurant in the future,0.84
19,146,"(256, 8)",-0.174695,mentions being ignored or not attended to by staff despite being seated or present in the restaurant,0.92,mentions long wait times for initial service or being ignored by staff after being seated,0.9
18,246,"(256, 8)",-0.179931,mentions that the quality of food or service has declined over time,0.92,"mentions a decline or change in quality, service, or menu compared to past experiences",0.88
15,268,"(32, 4)",-0.204537,mentions long wait times or delays in receiving service or food,0.97,"mentions long wait times for food or service, often specifying durations such as 20 minutes, 30 minutes, or more",0.93
14,158,"(256, 8)",-0.207138,mentions repeated issues with the restaurant's service or food quality across multiple visits or orders,0.86,"mentions repeated negative experiences with food portion sizes, missing items, or incorrect orders in multiple visits",0.73
12,83,"(256, 8)",-0.22025,mentions problems caused by understaffing or disorganized staff,0.89,mentions long wait times caused by staffing issues or disorganization,0.67
11,138,"(256, 8)",-0.227552,"describes employees, servers, or bartenders as being rude or disrespectful",0.94,describes rude or unprofessional behavior by staff or employees,0.91


In [32]:
# Evaluate the best interpretations as hypotheses on the holdout split.
np.random.seed(42)
# None scores the full holdout file; set e.g. 500 to score a reproducible random subsample
# (fewer LLM annotation calls).
N_HOLDOUT_SAMPLES = None

holdout_df = pd.read_json('./demo-data/demo-holdout-2K.json', lines=True)
if N_HOLDOUT_SAMPLES is not None:
    holdout_df = holdout_df.sample(
        n=min(N_HOLDOUT_SAMPLES, len(holdout_df)), random_state=42
    ).reset_index(drop=True)
holdout_texts = holdout_df[text_col].tolist()
holdout_labels = holdout_df[target_col].values

# Annotate texts with best interpretations
holdout_annotations = annotate_texts_with_concepts(
    texts=holdout_texts,
    concepts=best_interp_df['best_interpretation'].tolist(),
    max_words_per_example=128,
    cache_name=CACHE_NAME,
    n_workers=50,
)

# Evaluate on holdout set (regression target, hence classification=False)
metrics, hypothesis_df = score_hypotheses(
    hypothesis_annotations=holdout_annotations,
    y_true=holdout_labels,
    classification=False,
)

# Show full hypothesis text; option_context restores the previous setting even if display fails.
with pd.option_context('display.max_colwidth', None):
    display(hypothesis_df.round(3))

print("\nHoldout Set Metrics:")
print(f"R² Score: {metrics['r2']:.3f}")
# metrics['Significant'] is printed as (significant count, total count, p-value threshold).
print(f"Significant hypotheses: {metrics['Significant'][0]}/{metrics['Significant'][1]} " 
      f"(p < {metrics['Significant'][2]:.3e})")

Found 0 cached items; annotating 40000 uncached items


Annotating:   0%|          | 0/40000 [00:00<?, ?it/s]

Unnamed: 0,hypothesis,separation_score,separation_pval,regression_coef,regression_pval,feature_prevalence
2,"expresses personal love or favorite status for the restaurant, using phrases like 'I love this place', 'my favorite restaurant', or 'go-to spot'",1.468,0.0,0.378,0.0,0.389
0,emphasizes the friendliness and warmth of the staff using enthusiastic language,1.271,0.0,0.243,0.0,0.301
3,explicitly mentions a desire or intention to return to the restaurant,1.013,0.0,0.123,0.005,0.203
1,mentions specific positive interactions with named staff members,0.886,0.0,0.085,0.153,0.096
6,mentions long wait times or delays in receiving service or food,-1.122,0.0,0.267,0.0,0.138
10,"mentions long wait times for food or service, often with specific durations given (e.g., 20 minutes, 30 minutes, over an hour)",-1.435,0.0,-0.31,0.0,0.06
17,expresses disappointment in the flavor or seasoning of the food,-1.69,0.0,0.177,0.031,0.213
12,"mentions errors or mistakes in the food order, such as missing items, incorrect ingredients, or receiving the wrong dish",-1.805,0.0,0.01,0.868,0.144
16,expresses disappointment with food being bland or lacking flavor,-1.863,0.0,-0.266,0.002,0.166
7,mentions repeated issues with the restaurant's service or food quality across multiple visits or orders,-2.086,0.0,0.128,0.054,0.136



Holdout Set Metrics:
R² Score: 0.738
Significant hypotheses: 11/20 (p < 5.000e-03)
