# SAE Feature Steering
Two approaches to steering using Sparse Autoencoder features:
1. **SingleFeature** — Use a single SAE decoder direction (manually chosen or by highest probe correlation)
2. **TopKFeatures** — Weighted sum of the top-k most class-correlated SAE features

These methods use the SAE to decompose activations into interpretable features, then construct steering vectors from decoder directions of task-relevant features.

In [None]:
# @title Setup
# Environment bootstrap: make the project's packages importable whether the
# notebook runs on Google Colab or from a local checkout.
# NOTE: the `!` lines below are IPython shell escapes, so this cell only runs
# inside a notebook kernel (not as a plain Python script).
import sys, os

# Colab is detected by the presence of the `google.colab` module.
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    # Clone repo and install dependencies
    # The clone is skipped when the target directory already exists, so
    # re-running the cell does not fail on an existing checkout.
    if not os.path.exists('/content/optimal_sparse_steering'):
        !git clone https://github.com/tgautam23/optimal_sparse_steering.git /content/optimal_sparse_steering
    !pip install -q torch transformer-lens sae-lens cvxpy datasets scikit-learn matplotlib seaborn tqdm transformers
    PROJECT_ROOT = '/content/optimal_sparse_steering'
else:
    # Outside Colab the project root is taken to be the parent of the current
    # working directory.
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))

# Put the project root first on sys.path so the `configs` and `src` imports
# that follow resolve against it.
sys.path.insert(0, PROJECT_ROOT)

import logging
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from configs.base import ExperimentConfig, ModelConfig
from src.models.wrapper import ModelWrapper
from src.models.sae_utils import get_decoder_matrix, get_feature_direction
from src.data.loader import load_dataset_splits
from src.data.preprocessing import extract_activations, extract_sae_features
from src.data.prompts import get_neutral_queries
from src.probes.linear_probe import LinearProbe
from src.steering.single_feature import SingleFeature
from src.steering.topk_features import TopKFeatures
from src.evaluation.metrics import compute_probe_score
from src.evaluation.generation import steered_generation

# Timestamped INFO-level logging and a white-grid plotting theme for the
# whole notebook.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
)
sns.set_theme(style="whitegrid")

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU; the
# GPU name is reported only in the CUDA case.
if torch.cuda.is_available():
    device = "cuda"
    print(f"Device: {device}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    device = "cpu"
    print(f"Device: {device}")

## Helper Functions

In [None]:
def evaluate_steering(model_wrapper, probe, test_acts_np, test_labels,
                      method, layer, alpha_values, target_class=1):
    """Evaluate a steering method across alpha values.

    For every ``alpha`` the steering vector is scaled, added to each row of
    ``test_acts_np``, and the shifted activations are scored with
    ``compute_probe_score``. The unsteered score is the same for every alpha,
    so it is computed once up front.

    Args:
        model_wrapper: Unused here; kept so existing call sites keep working.
        probe: Probe object handed through to ``compute_probe_score``.
        test_acts_np: NumPy array of activations; the steering vector is
            broadcast across its first axis.
        test_labels: Unused here; kept so existing call sites keep working.
        method: Object exposing a ``steering_vector`` torch tensor.
        layer: Unused here; kept so existing call sites keep working.
        alpha_values: Iterable of steering strengths to evaluate.
        target_class: Class index forwarded to ``compute_probe_score``.

    Returns:
        A list with one dict per alpha, holding the keys ``alpha``,
        ``probe_score_base``, ``probe_score_steered``, ``probe_score_delta``
        and ``steering_norm``.
    """
    # detach()/cpu() make the NumPy conversion work even when the vector lives
    # on the GPU or is attached to an autograd graph.
    steering = method.steering_vector.detach().cpu()
    sv = steering.numpy()
    steering_norm = float(steering.norm().item())

    # Loop-invariant: the baseline does not depend on alpha.
    base_score = compute_probe_score(probe, test_acts_np, target_class)

    results = []
    for alpha in alpha_values:
        steered_acts = test_acts_np + alpha * sv[np.newaxis, :]
        steered_score = compute_probe_score(probe, steered_acts, target_class)
        results.append({
            'alpha': alpha,
            'probe_score_base': base_score,
            'probe_score_steered': steered_score,
            'probe_score_delta': steered_score - base_score,
            'steering_norm': steering_norm,
        })
    return results


def find_top_correlated_features(sae_features_np, labels_np, topk=20):
    """Find the SAE features most correlated with binary labels.

    Computes the Pearson correlation between every feature column and the
    label vector, then ranks features by absolute correlation.

    Args:
        sae_features_np: Array of shape (n_samples, n_features).
        labels_np: Array of shape (n_samples,) with the class labels.
        topk: Number of top-ranked feature indices to return.

    Returns:
        ``(top_indices, correlations)``: ``top_indices`` holds up to ``topk``
        feature indices ordered by decreasing ``|correlation|``;
        ``correlations`` holds the signed correlation of every feature.
        Constant features get a correlation of ~0 thanks to the epsilon in
        the denominator.
    """
    features = np.asarray(sae_features_np)
    # Promote half precision (and integer/bool) inputs to at least float32:
    # squaring float16 activations overflows to inf and would silently zero
    # out the correlation of strongly active features.
    features = features.astype(np.result_type(features.dtype, np.float32), copy=False)
    labels = np.asarray(labels_np, dtype=np.float64)

    feat_centered = features - features.mean(axis=0, keepdims=True)
    labels_centered = labels - labels.mean()

    numerator = feat_centered.T @ labels_centered
    # Accumulate the sums of squares in float64 so that long columns do not
    # lose precision; the epsilon guards against division by zero.
    feat_std = np.sqrt(np.square(feat_centered).sum(axis=0, dtype=np.float64) + 1e-10)
    label_std = np.sqrt((labels_centered ** 2).sum() + 1e-10)
    correlations = numerator / (feat_std * label_std)

    top_indices = np.argsort(np.abs(correlations))[::-1][:topk]
    return top_indices, correlations


def plot_feature_correlations(correlations, top_indices, title="Top Feature Correlations"):
    """Draw a bar chart of the correlations of the selected features.

    Bars appear in the order given by ``top_indices``; positive correlations
    are green, all others red. The figure is shown and also returned.
    """
    selected = correlations[top_indices]
    positions = range(len(top_indices))
    tick_labels = [str(feature) for feature in top_indices]
    bar_colors = ['#2ecc71' if value > 0 else '#e74c3c' for value in selected]

    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    ax.bar(positions, selected, color=bar_colors)

    # The x axis carries the feature ids rather than the bar positions.
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=45, fontsize=8)

    ax.set_xlabel('Feature Index')
    ax.set_ylabel('Pearson Correlation with Label')
    ax.set_title(title)

    # Zero reference line to separate positive from negative bars.
    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

    plt.tight_layout()
    plt.show()
    return fig

---
## GPT-2 Small + SST-2

In [None]:
# @title Load GPT-2 Small + SAE
# Build the experiment from the default configuration (presumably GPT-2 Small,
# per the section title — confirm in configs.base) and point it at the device
# chosen in the setup cell.
config = ExperimentConfig()
config.model.device = device

# `model_wrapper` and `layer` are notebook globals used by the cells below.
model_wrapper = ModelWrapper(config.model)
layer = config.model.steering_layer

# Pre-load SAE
sae = model_wrapper.get_sae(layer)
print(f"Model: {config.model.name}, d_model: {config.model.d_model}")
# d_sae is read from the first axis of the decoder matrix — assumes W_dec is
# laid out as (d_sae, d_model); TODO confirm.
print(f"SAE loaded at layer {layer}, d_sae: {sae.W_dec.shape[0]}")

In [None]:
# @title Data + Probe + SAE Features
# Load the train/test split, extract layer activations and SAE features, and
# fit a linear probe on the training activations.
data = load_dataset_splits(config.data)
print(f"Train: {len(data['train_texts'])}, Test: {len(data['test_texts'])}")

# Extract residual stream activations
train_acts = extract_activations(data['train_texts'], model_wrapper, layer,
                                  batch_size=config.model.batch_size)
test_acts = extract_activations(data['test_texts'], model_wrapper, layer,
                                 batch_size=config.model.batch_size)

# Extract SAE feature activations
train_sae = extract_sae_features(data['train_texts'], model_wrapper, layer,
                                  batch_size=config.model.batch_size)
# NOTE(review): test_sae is not referenced again in the visible notebook; if
# nothing else uses it, this extraction pass could be dropped.
test_sae = extract_sae_features(data['test_texts'], model_wrapper, layer,
                                 batch_size=config.model.batch_size)

# NumPy views for the probe and the correlation analysis.
# NOTE(review): .numpy() needs CPU tensors — presumably the extractors return
# CPU tensors; confirm.
train_acts_np = train_acts.numpy()
test_acts_np = test_acts.numpy()
train_labels = np.array(data['train_labels'])
test_labels = np.array(data['test_labels'])
train_sae_np = train_sae.numpy()

# Train probe
probe = LinearProbe(d_model=config.model.d_model)
probe.fit(train_acts_np, train_labels)
print(f"Probe accuracy \u2014 train: {probe.score(train_acts_np, train_labels):.4f}, "
      f"test: {probe.score(test_acts_np, test_labels):.4f}")
print(f"SAE features shape: {train_sae.shape}")

### Feature Correlation Analysis
Before steering, let's identify which SAE features are most correlated with the sentiment label.

In [None]:
# @title Feature Correlation Analysis
# Rank SAE features by |correlation| with the training labels, list the top
# 20 with their signed correlation, then plot them.
top_indices, correlations = find_top_correlated_features(
    train_sae_np, train_labels, topk=20,
)

print("Top 20 most correlated SAE features:")
for rank, feature_id in enumerate(top_indices, start=1):
    corr = correlations[feature_id]
    if corr > 0:
        sign = "+"
    else:
        sign = "-"
    print(f"  {rank:2d}. Feature {feature_id:5d}  correlation = {corr:+.4f}  ({sign} sentiment)")

fig = plot_feature_correlations(
    correlations,
    top_indices,
    title="GPT-2 Small: Top SAE Features Correlated with Sentiment",
)

### Method 1: SingleFeature
Use the decoder direction of the single most positively-correlated SAE feature.

In [None]:
# @title SingleFeature Steering
# Use the feature most positively correlated with target class.
# The feature is picked by its *signed* correlation: `top_indices` is ranked by
# absolute value, so its first two entries can both be negatively correlated
# and must not be used to find the best positive feature.
best_pos_idx = int(np.argmax(correlations))
if correlations[best_pos_idx] <= 0:
    print("Warning: no SAE feature is positively correlated with the target class")
print(f"Selected feature: {best_pos_idx} (correlation: {correlations[best_pos_idx]:+.4f})")

sf = SingleFeature(feature_idx=best_pos_idx)
sf.compute_steering_vector(model_wrapper=model_wrapper, layer=layer)

print(f"Steering vector norm: {sf.steering_vector.norm():.4f}")

# Also check: cosine similarity between this feature's direction and the probe direction
probe_dir = torch.tensor(probe.weight_vector, dtype=torch.float32)
probe_dir = probe_dir / probe_dir.norm()
# Bring the steering vector to float32 on the CPU so it can be compared with
# the probe direction whatever device/dtype it was computed in.
cos_sim = torch.nn.functional.cosine_similarity(
    sf.steering_vector.detach().float().cpu().unsqueeze(0), probe_dir.unsqueeze(0)
).item()
print(f"Cosine similarity with probe direction: {cos_sim:.4f}")

# Sweep the steering strength and report the change in probe score.
alpha_values = [0.0, 1.0, 3.0, 5.0, 10.0, 20.0, 50.0]
results_sf = evaluate_steering(model_wrapper, probe, test_acts_np, test_labels,
                                sf, layer, alpha_values)

for r in results_sf:
    print(f"  alpha={r['alpha']:5.1f}  probe_delta={r['probe_score_delta']:+.4f}")

### Method 2: TopKFeatures
Weighted sum of top-k features by Pearson correlation with the label.

In [None]:
# @title TopKFeatures — Sweep over k
# For each k, build a TopKFeatures steering vector and record the probe-score
# change at alpha=5, keeping the fitted method for later cells.
k_values = [1, 3, 5, 10, 20, 50]
topk_results = {}

for k in k_values:
    tk = TopKFeatures(topk=k)
    tk.compute_steering_vector(
        model_wrapper=model_wrapper,
        layer=layer,
        sae_features=train_sae,
        labels=train_labels,
    )

    results = evaluate_steering(
        model_wrapper, probe, test_acts_np, test_labels, tk, layer, [5.0],
    )
    delta = results[0]['probe_score_delta']
    leading_features = tk.selected_features[:5]  # preview of the first five only

    topk_results[k] = dict(
        method=tk,
        results=results,
        selected_features=leading_features,
        probe_delta_at_5=delta,
    )
    print(f"k={k:3d}  probe_delta(alpha=5)={delta:+.4f}  "
          f"top features: {leading_features}")

# Plot k vs effectiveness
ks = list(topk_results)
deltas = [entry['probe_delta_at_5'] for entry in topk_results.values()]

fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.plot(ks, deltas, marker='o', color='#3498db')
ax.set_title('GPT-2 Small: TopKFeatures \u2014 Effectiveness vs k')
ax.set_xlabel('k (number of features)')
ax.set_ylabel('Probe Score Delta (alpha=5)')
plt.tight_layout()
plt.show()

In [None]:
# @title TopKFeatures — Alpha sweep (best k)
# Take the k whose probe-score change at alpha=5 was largest in the sweep
# above and evaluate that method over a range of steering strengths.
best_k, best_entry = max(
    topk_results.items(),
    key=lambda item: item[1]['probe_delta_at_5'],
)
print(f"Best k: {best_k}")

tk_best = best_entry['method']
results_tk = evaluate_steering(
    model_wrapper, probe, test_acts_np, test_labels,
    tk_best, layer, [0.0, 1.0, 3.0, 5.0, 10.0, 20.0],
)

for row in results_tk:
    print(f"  alpha={row['alpha']:5.1f}  probe_delta={row['probe_score_delta']:+.4f}")

### GPT-2 Small: Comparison

In [None]:
# @title GPT-2 Small — SingleFeature vs TopKFeatures
# Evaluate both methods on a shared alpha grid, plot the probe-score change,
# then print a few steered generations per method.
alpha_values_cmp = [0.0, 1.0, 3.0, 5.0, 10.0, 20.0]
topk_label = f'TopK (k={best_k})'

results_sf_cmp = evaluate_steering(
    model_wrapper, probe, test_acts_np, test_labels, sf, layer, alpha_values_cmp,
)
results_tk_cmp = evaluate_steering(
    model_wrapper, probe, test_acts_np, test_labels, tk_best, layer, alpha_values_cmp,
)

fig, ax = plt.subplots(1, 1, figsize=(8, 5))
curves = {'SingleFeature': results_sf_cmp, topk_label: results_tk_cmp}
for name, results in curves.items():
    alphas = [row['alpha'] for row in results]
    deltas = [row['probe_score_delta'] for row in results]
    ax.plot(alphas, deltas, marker='o', label=name)
ax.set_xlabel('Alpha')
ax.set_ylabel('Probe Score Delta')
ax.set_title('GPT-2 Small: SAE Feature Steering Comparison')
ax.legend()
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()

# Generated samples
queries = get_neutral_queries("sst2")[:3]
steering_methods = {'SingleFeature': sf, topk_label: tk_best}
for name, method in steering_methods.items():
    print(f"\n--- {name} (alpha=5.0) ---")
    # Best effort: a failure for one method is reported and the loop goes on.
    try:
        gens = steered_generation(
            model_wrapper, queries, method, layer,
            alpha=5.0, max_new_tokens=50, temperature=0.7,
        )
        for number, text in enumerate(gens, start=1):
            print(f"  [{number}] {text[:150]}...")
    except Exception as e:
        print(f"  Generation error: {e}")

---
## Gemma-2-2B Pretrained + SST-2

In [None]:
# @title Load Gemma-2-2B + SAE
import gc

# Free the GPT-2 model before loading Gemma. `sae` is a second notebook-level
# reference to the GPT-2 SAE weights, so it is dropped together with the
# wrapper; gc.collect() then reclaims objects that are only kept alive by
# reference cycles, and only after that can empty_cache() hand the cached GPU
# blocks back to the driver.
# NOTE(review): other notebook globals (e.g. the fitted steering methods) may
# still hold tensors — confirm if memory remains tight.
del model_wrapper, sae
gc.collect()
torch.cuda.empty_cache()

# Gemma-2-2B (pretrained) with the Gemma Scope residual-stream SAE release,
# steering at layer 15 in half precision with a small batch size.
gemma_config = ExperimentConfig(
    model=ModelConfig(
        name="gemma-2-2b-pt",
        tl_name="google/gemma-2-2b",
        sae_release="gemma-scope-2b-pt-res-canonical",
        sae_id_template="layer_{layer}/width_16k/canonical",
        hook_template="blocks.{layer}.hook_resid_post",
        d_model=2304, n_layers=26, steering_layer=15,
        dtype="float16", device=device, batch_size=4,
    ),
    experiment_name="gemma-2-2b-pt_sst2",
)

model_wrapper_g = ModelWrapper(gemma_config.model)
layer_g = gemma_config.model.steering_layer
sae_g = model_wrapper_g.get_sae(layer_g)
print(f"Model: {gemma_config.model.name}, d_sae: {sae_g.W_dec.shape[0]}")

In [None]:
# @title Gemma: Data + SAE Features + Probe
# Same pipeline as the GPT-2 section, run against the Gemma wrapper: load the
# split, extract layer activations and training SAE features, fit a probe.
data_g = load_dataset_splits(gemma_config.data)

train_acts_g = extract_activations(data_g['train_texts'], model_wrapper_g, layer_g,
                                    batch_size=gemma_config.model.batch_size)
test_acts_g = extract_activations(data_g['test_texts'], model_wrapper_g, layer_g,
                                   batch_size=gemma_config.model.batch_size)
# Only the training texts get SAE features here; they feed the correlation
# analysis and TopKFeatures in the next cell.
train_sae_g = extract_sae_features(data_g['train_texts'], model_wrapper_g, layer_g,
                                    batch_size=gemma_config.model.batch_size)

# NOTE(review): .numpy() needs CPU tensors, and the model is configured with
# dtype="float16" — presumably the extractors return CPU tensors; confirm the
# dtype the probe receives.
train_acts_g_np = train_acts_g.numpy()
test_acts_g_np = test_acts_g.numpy()
train_labels_g = np.array(data_g['train_labels'])
test_labels_g = np.array(data_g['test_labels'])

# Linear probe on the training activations; accuracy is printed for both splits.
probe_g = LinearProbe(d_model=gemma_config.model.d_model)
probe_g.fit(train_acts_g_np, train_labels_g)
print(f"Probe accuracy \u2014 train: {probe_g.score(train_acts_g_np, train_labels_g):.4f}, "
      f"test: {probe_g.score(test_acts_g_np, test_labels_g):.4f}")

In [None]:
# @title Gemma: Feature Analysis + Steering
# Rank Gemma SAE features by |correlation| with the training labels.
train_sae_g_np = train_sae_g.numpy()
top_idx_g, corrs_g = find_top_correlated_features(train_sae_g_np, train_labels_g, topk=20)

print("Top 10 correlated Gemma SAE features:")
# Slicing (rather than indexing 0..9) stays valid if fewer than ten features
# are returned.
for feat_idx in top_idx_g[:10]:
    print(f"  Feature {feat_idx:5d}  corr = {corrs_g[feat_idx]:+.4f}")

fig = plot_feature_correlations(corrs_g, top_idx_g,
                                title="Gemma-2-2B: Top SAE Features Correlated with Sentiment")

# SingleFeature
# Pick by *signed* correlation: `top_idx_g` is ranked by absolute value, so
# its first two entries can both be negatively correlated.
best_pos_g = int(np.argmax(corrs_g))
if corrs_g[best_pos_g] <= 0:
    print("Warning: no SAE feature is positively correlated with the target class")
sf_g = SingleFeature(feature_idx=best_pos_g)
sf_g.compute_steering_vector(model_wrapper=model_wrapper_g, layer=layer_g)

# TopK (k=10)
tk_g = TopKFeatures(topk=10)
tk_g.compute_steering_vector(sae_features=train_sae_g, labels=train_labels_g,
                              model_wrapper=model_wrapper_g, layer=layer_g)

# Report the probe-score change of both methods over a shared alpha grid.
alpha_vals_g = [0.0, 1.0, 3.0, 5.0, 10.0]
for name, method in [('SingleFeature', sf_g), ('TopK(k=10)', tk_g)]:
    results = evaluate_steering(model_wrapper_g, probe_g, test_acts_g_np, test_labels_g,
                                 method, layer_g, alpha_vals_g)
    print(f"\n{name}:")
    for r in results:
        print(f"  alpha={r['alpha']:5.1f}  probe_delta={r['probe_score_delta']:+.4f}")

## Key Takeaways
- **SingleFeature** is interpretable (one decoder direction) but limited in effectiveness
- **TopKFeatures** scales better: weighted combination of correlated features captures more of the concept
- Feature selection quality depends on the SAE — Gemma Scope features may be more or less disentangled than JB SAEs
- These heuristic methods motivate the convex optimization approach: instead of picking features by correlation, let the optimizer find the minimal intervention