# Optimal Sparse Steering via Convex Optimization

This notebook demonstrates two convex-optimization formulations for finding the **sparsest SAE feature intervention** that steers model behavior:

- **SOCP** (Second-Order Cone Program): minimizes $\|\delta\|_1$ subject to a hard $\ell_2$ coherence constraint $\|D^\top \delta\|_2 \leq \epsilon$.
- **QP** (Quadratic Program): minimizes $\|\delta\|_1 + \frac{\lambda}{2}\|D^\top \delta\|_2^2$, supporting warm-starting across sequential solves.

In [None]:
# @title Setup
import sys, os

# Bootstrap: make the project importable both on Google Colab and locally.
# Colab is detected by the `google.colab` module already being loaded.
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    REPO_DIR = '/content/optimal_sparse_steering'
    # Reuse an existing checkout (fast-forward it) or clone a fresh one.
    if os.path.exists(REPO_DIR):
        !git -C {REPO_DIR} pull -q
    else:
        !git clone -q https://github.com/tgautam23/optimal_sparse_steering.git {REPO_DIR}
    # Install dependencies; `grep -v` drops pip's dependency-resolver warning
    # lines from the combined stdout/stderr output.
    !pip install -q torch transformer-lens sae-lens cvxpy datasets scikit-learn matplotlib seaborn tqdm transformers 2>&1 | grep -v "dependency conflicts\|incompatible\|pip's dependency resolver"
    PROJECT_ROOT = REPO_DIR
    # Smoke-test the numpy/seaborn imports. If they fail with ValueError or
    # ImportError, kill this process with signal 9 so the runtime restarts
    # (per the printed message) and picks up the freshly installed packages.
    try:
        import numpy as _np; import seaborn as _sns; del _np, _sns
    except (ValueError, ImportError):
        print("Restarting runtime for numpy compatibility...")
        os.kill(os.getpid(), 9)
else:
    # Local run: the project root is taken to be the parent of the current
    # working directory.
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))

# Put the project root first on the import path so the `configs` / `src`
# imports below resolve against it.
sys.path.insert(0, PROJECT_ROOT)

import logging
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from configs.base import ExperimentConfig, ModelConfig, DataConfig
from src.models.wrapper import ModelWrapper
from src.models.sae_utils import get_decoder_matrix
from src.data.loader import load_dataset_splits
from src.data.preprocessing import extract_activations, extract_sae_features
from src.probes.concept_subspace import ConceptSubspace
from src.steering.convex_optimal import ConvexOptimalSteering
from src.steering.qp_optimal import QPOptimalSteering
from src.evaluation.metrics import compute_subspace_score, compute_l0

# Notebook-wide logging format and plot theme.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
)
sns.set_theme(style="whitegrid")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
    print(f"Device: {device}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    device = "cpu"
    print(f"Device: {device}")

## Part 1: Sentiment Steering (GPT-2 Small + SST-2)

In [None]:
# @title Load Model and Data
# Start from the project's default experiment configuration and override only
# the fields this notebook needs.
config = ExperimentConfig()
config.model.device = device  # device string chosen in the setup cell
config.data.split = "train"
# presumably limits how many examples are drawn per class -- confirm in DataConfig
config.data.max_samples_per_class = 500

# Load model
model_wrapper = ModelWrapper(config.model)
print(f"Model: {config.model.tl_name}, d_model={config.model.d_model}")

# Load dataset splits
# The result is indexed below (and in later cells) with the keys
# 'train_texts', 'train_labels', 'test_texts' and 'test_labels'.
data = load_dataset_splits(config.data)
print(f"Train: {len(data['train_texts'])}, Test: {len(data['test_texts'])}")

In [None]:
# @title Layer Sweep -- Select Best Layer by Concept Concentration
from src.probes.layer_sweep import concept_concentration_sweep, find_best_layer

# Run the concept-concentration sweep on the training split (n_layers is
# passed through from the model config), then let find_best_layer pick the
# layer to use for the rest of the notebook.
sweep = concept_concentration_sweep(
    data['train_texts'], data['train_labels'],
    model_wrapper, n_layers=config.model.n_layers,
    n_components=50, n_select=5, batch_size=config.model.batch_size,
)
layer = find_best_layer(sweep)
# Sweep record for the chosen layer; only 'explained_var' and 'r2_k' are read here.
r = sweep[layer]
print(f"\nSelected layer: {layer} (explained_var={r['explained_var']:.4f}, R^2_k={r['r2_k']:.4f})")

# Load SAE and decoder at the selected layer
sae = model_wrapper.get_sae(layer)
D = get_decoder_matrix(sae)
print(f"SAE decoder matrix D shape: {D.shape}")

In [None]:
# @title Extract Activations and Fit Concept Subspace at Selected Layer
print(f"Extracting activations at layer {layer}...")

# Activations for both splits at the selected layer.
train_acts = extract_activations(data['train_texts'], model_wrapper, layer, batch_size=config.model.batch_size)
test_acts = extract_activations(data['test_texts'], model_wrapper, layer, batch_size=config.model.batch_size)

# NOTE(review): .numpy() only works on CPU tensors that do not require grad;
# this assumes extract_activations returns such tensors -- confirm.
train_np = train_acts.numpy()
test_np = test_acts.numpy()
train_labels = np.array(data['train_labels'])
test_labels = np.array(data['test_labels'])

# Fit concept subspace
# Uses the same n_components / n_select values as the layer sweep call.
subspace = ConceptSubspace(n_components=50, n_select=5)
subspace.fit(train_np, train_labels)
print(f"Concept subspace: k={subspace.n_directions} directions")
print(f"Class separations: {subspace.class_separations}")

# Class label used as the steering target.
target_class = 1

### Evaluation Prompts & Helper Functions

In [None]:
# @title Define Prompts and Helper Functions

# In-distribution evaluation prompts: the first five SST-2 test examples whose
# label is 0 (the negative class, per the message printed below).
negative_texts = [
    text
    for text, label in zip(data['test_texts'], data['test_labels'])
    if label == 0
]
sentiment_prompts = negative_texts[:5]
print("Sentiment prompts (negative SST-2 test examples):")
for idx, prompt_text in enumerate(sentiment_prompts):
    print(f"  {idx}: {prompt_text}")

# Steering strengths swept during generation; 0 leaves activations unchanged.
alpha_values = [0, 1, 5, 10]


def compute_generation_subspace_score(text, model_wrapper, concept_subspace, layer):
    """Compute the subspace score of a text using an unsteered forward pass.

    The text is tokenized (truncated to at most 128 tokens), run through the
    model with activation caching, and the last-token activation at ``layer``
    is handed to ``compute_subspace_score``.

    Args:
        text: Text to score (here: prompt plus generated continuation).
        model_wrapper: Wrapper exposing ``tokenizer``, ``model``,
            ``get_hook_name`` and ``run_with_cache``.
        concept_subspace: Fitted concept subspace, passed through to
            ``compute_subspace_score``.
        layer: Layer index whose activation is scored.

    Returns:
        The score converted to a Python ``float``.
    """
    tokenizer = model_wrapper.tokenizer
    device_m = next(model_wrapper.model.parameters()).device
    # Resolve the hook point through the wrapper (as run_steering_experiment
    # does) instead of hard-coding a hook name, so scoring always reads the
    # same activation site that steering is applied at.
    hook_name = model_wrapper.get_hook_name(layer)

    input_ids = tokenizer.encode(
        text, return_tensors="pt", truncation=True, max_length=128
    ).to(device_m)
    with torch.no_grad():
        _, cache = model_wrapper.run_with_cache(input_ids, names_filter=hook_name)

    # Last-token activation of the single sequence, as a (1, d) NumPy row.
    h = cache[hook_name][0, -1, :].cpu().numpy().reshape(1, -1)
    score = compute_subspace_score(concept_subspace, h)
    return float(score)


def run_steering_experiment(prompts, method, model_wrapper, layer, concept_subspace,
                            D, target_class, alpha_values, max_new_tokens=50,
                            seed=None):
    """Solve once per prompt, generate at multiple alpha values, return results.

    For every prompt, the last-token activation at ``layer`` and its SAE
    features are extracted and ``method.compute_steering_vector`` is called
    once. Text is then sampled once per alpha with ``alpha * steering_vector``
    added at the hooked activation, and each generation is scored with
    ``compute_generation_subspace_score``.

    Args:
        prompts: Iterable of prompt strings.
        method: Steering method exposing ``compute_steering_vector``. The
            optional attributes ``_solve_time``, ``active_features`` and
            ``delta`` are read when present.
        model_wrapper: Wrapper exposing ``model``, ``tokenizer``,
            ``get_hook_name``, ``run_with_cache`` and ``encode_with_sae``.
        layer: Layer index at which to read activations and apply steering.
        concept_subspace: Fitted concept subspace handed to the solver and scorer.
        D: SAE decoder matrix handed to the solver.
        target_class: Class label to steer towards.
        alpha_values: Steering multipliers; ``0`` leaves activations unchanged.
        max_new_tokens: Number of tokens to sample per generation.
        seed: Optional RNG seed. When given, torch is re-seeded before every
            generation so all alphas sample from the same random stream and
            differences are attributable to steering rather than sampling
            noise. ``None`` (default) keeps the previous unseeded behaviour.

    Returns:
        Tuple ``(all_results, all_active_features)``: a list with one result
        dict per prompt (keys ``prompt``, ``solve_time``, ``l0``,
        ``active_features``, ``generations``) and a parallel list with the set
        of active feature indices per prompt.
    """
    model = model_wrapper.model
    tokenizer = model_wrapper.tokenizer
    hook_name = model_wrapper.get_hook_name(layer)
    device_m = next(model.parameters()).device

    all_results = []
    all_active_features = []

    for prompt in prompts:
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device_m)

        # Extract last-token activation + SAE features
        with torch.no_grad():
            _, cache = model_wrapper.run_with_cache(input_ids, names_filter=hook_name)
        h_device = cache[hook_name][0, -1, :]  # keep on device for SAE
        sae_feats = model_wrapper.encode_with_sae(h_device.unsqueeze(0), layer).squeeze(0)
        h = h_device.cpu()  # move to CPU for solver

        # Solve optimization once; prefer the solver-reported time and fall
        # back to wall-clock time around the call.
        t0 = time.time()
        sv = method.compute_steering_vector(
            h=h, D=D, concept_subspace=concept_subspace,
            sae_features=sae_feats, target_class=target_class,
        )
        solve_time = getattr(method, '_solve_time', None)
        if solve_time is None:
            solve_time = time.time() - t0

        active_feats, l0 = _solution_summary(method)
        all_active_features.append(set(active_feats))

        prompt_results = {"prompt": prompt, "solve_time": solve_time, "l0": l0, "active_features": active_feats, "generations": {}}

        sv_device = sv.to(device_m)

        for alpha in alpha_values:
            # Defaults bind the current vector/alpha to avoid late-binding
            # closure surprises across loop iterations.
            def hook_fn(acts, hook, _sv=sv_device, _alpha=alpha):
                if _alpha == 0:
                    return acts
                # Match the activation dtype so a float64 solver output cannot
                # silently promote the activations.
                return acts + _alpha * _sv.to(acts.dtype)

            if seed is not None:
                torch.manual_seed(seed)

            model.add_hook(hook_name, hook_fn)
            try:
                with torch.no_grad():
                    output_ids = model.generate(
                        input_ids, max_new_tokens=max_new_tokens,
                        temperature=0.7, top_p=0.9, do_sample=True,
                    )
            finally:
                # NOTE(review): reset_hooks() is called without arguments;
                # presumably this clears every hook on the model, not only the
                # one added above -- confirm if other hooks must survive.
                model.reset_hooks()

            gen_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            subspace_score = compute_generation_subspace_score(
                gen_text, model_wrapper, concept_subspace, layer
            )

            prompt_results["generations"][alpha] = {
                "text": gen_text,
                "subspace_score": subspace_score,
            }

        all_results.append(prompt_results)

    return all_results, all_active_features


def _solution_summary(method):
    """Return ``(active_feature_indices, l0)`` for the method's latest solve.

    Only the public ``active_features`` / ``delta`` attributes are consulted;
    a missing attribute or a ``None`` value yields ``[]`` / ``0``.
    """
    active = getattr(method, "active_features", None)
    active_feats = [] if active is None else active.tolist()

    delta = getattr(method, "delta", None)
    # Counts entries strictly greater than 1e-6, so negative coefficients are
    # not counted. NOTE(review): presumably delta is non-negative -- confirm.
    l0 = 0 if delta is None else int((delta > 1e-6).sum())
    return active_feats, l0


def display_results(results, method_name):
    """Print steering results in a readable format.

    Args:
        results: List of per-prompt result dicts as produced by
            ``run_steering_experiment``; the keys ``prompt``, ``solve_time``,
            ``l0`` and ``generations`` are read. ``generations`` maps each
            alpha to a dict with ``text`` and ``subspace_score``.
        method_name: Heading printed above the results.
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print(f"{method_name}")
    print(f"{banner}")
    for r in results:
        print(f"\nPrompt: \"{r['prompt']}\"")
        print(f"  Solve time: {r['solve_time']:.3f}s | L0: {r['l0']} active features")
        for alpha, gen in r['generations'].items():
            score_str = f"{gen['subspace_score']:.4f}"
            print(f"\n  alpha={_format_alpha(alpha)} | subspace_score={score_str}")
            print(f"    {gen['text']}")


def _format_alpha(alpha):
    """Format a steering strength for display.

    Integers keep the two-character, right-aligned layout. Values that do not
    accept the ``d`` format code (e.g. fractional alphas) are rendered with the
    general ``g`` format instead of raising.
    """
    try:
        return f"{alpha:2d}"
    except (ValueError, TypeError):
        return f"{alpha:g}"

### SOCP Steering
Solve the SOCP (hard $\ell_2$ coherence constraint $\|D^\top \delta\|_2 \leq \epsilon$) once per prompt, then generate at each alpha.

In [None]:
# @title SOCP Steering -- Sentiment
# Keep epsilon in one place so the solver argument and the printed heading
# cannot drift apart.
socp_epsilon = 5.0
socp = ConvexOptimalSteering(epsilon=socp_epsilon, solver="SCS", prefilter_topk=2000)

# Steer towards the notebook-level `target_class` rather than a repeated
# literal, so changing the target in one cell takes effect here as well.
socp_results, socp_features = run_steering_experiment(
    sentiment_prompts, socp, model_wrapper, layer, subspace,
    D, target_class=target_class, alpha_values=alpha_values,
)

display_results(socp_results, f"SOCP (epsilon={socp_epsilon})")

### QP Steering (with Warm-Start)
Solve the QP ($\ell_1$ objective plus a $\frac{\lambda}{2}\|D^\top \delta\|_2^2$ penalty) with warm-starting across prompts.

In [None]:
# @title QP Steering -- Sentiment
# Keep lambda in one place so the solver argument and the printed heading
# cannot drift apart.
qp_lambda = 1.0
qp = QPOptimalSteering(lam=qp_lambda, solver="SCS", prefilter_topk=2000, warm_start=True)

# Steer towards the notebook-level `target_class` rather than a repeated
# literal, so changing the target in one cell takes effect here as well.
qp_results, qp_features = run_steering_experiment(
    sentiment_prompts, qp, model_wrapper, layer, subspace,
    D, target_class=target_class, alpha_values=alpha_values,
)

display_results(qp_results, f"QP (lambda={qp_lambda}, warm-start)")

### Subspace Score Comparison

In [None]:
# @title Subspace Scores: SOCP vs QP across Alpha
# One panel per method; each curve traces one prompt's subspace score over alpha.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
method_panels = (("SOCP", socp_results), ("QP", qp_results))

for panel, (method_label, method_results) in zip(axes, method_panels):
    for entry in method_results:
        per_alpha = entry['generations']
        curve = [per_alpha[a]['subspace_score'] for a in alpha_values]
        # Legend entries show the first 30 characters of the prompt.
        panel.plot(alpha_values, curve, marker='o', label=f"{entry['prompt'][:30]}...")
    panel.set_title(f'{method_label}: Subspace Score vs Alpha')
    panel.set_xlabel('Alpha')
    panel.set_ylabel('Subspace Score')
    panel.set_ylim(0, 1)
    panel.legend(fontsize=8)

plt.tight_layout()
plt.show()

### Feature Intersection & Interpretability
Find SAE features that are active across all solutions and look them up on Neuronpedia.

In [None]:
# @title Feature Intersection Analysis
# Per-prompt listing for each method: number of active features and the ten
# smallest feature indices.
listing = (
    ("SOCP: Active features per prompt", socp_results, socp_features),
    ("\nQP: Active features per prompt", qp_results, qp_features),
)
for banner, method_results, method_features in listing:
    print(banner)
    for entry, active in zip(method_results, method_features):
        preview = sorted(active)[:10]
        print(f"  \"{entry['prompt'][:40]}...\" -> {len(active)} features: {preview}...")

# Features active in ALL solutions, computed separately for SOCP and QP
# (empty set when a method produced no solutions).
if socp_features:
    socp_intersection = set.intersection(*socp_features)
else:
    socp_intersection = set()
if qp_features:
    qp_intersection = set.intersection(*qp_features)
else:
    qp_intersection = set()
# Union of the two per-method intersections.
all_intersection = socp_intersection.union(qp_intersection)

print(f"\nSOCP features active in ALL prompts: {sorted(socp_intersection)}")
print(f"QP features active in ALL prompts: {sorted(qp_intersection)}")
print(f"Combined (union of intersections): {sorted(all_intersection)}")

# Neuronpedia links for GPT-2 Small JB SAEs
print(f"\n--- Neuronpedia Links (layer {layer}) ---")
neuronpedia_base = f"https://www.neuronpedia.org/gpt2-small/{layer}-res-jb"
for feat_idx in sorted(all_intersection):
    url = f"{neuronpedia_base}/{feat_idx}"
    print(f"  Feature {feat_idx}: {url}")

---
## Part 2: Toxicity Steering (GPT-2 Small + Civil Comments)

In [None]:
# @title Load Toxicity Data + Fit Concept Subspace
# Dataset configuration for the toxicity experiment (Civil Comments).
# NOTE(review): presumably examples whose `toxicity` score reaches
# toxicity_threshold are labelled as the toxic class -- confirm in the loader.
toxicity_config = DataConfig(
    dataset_name="toxicity",
    hf_path="google/civil_comments",
    split="train",
    text_col="text",
    label_col="toxicity",
    toxicity_threshold=0.5,
    max_samples_per_class=500,
    probe_train_ratio=0.8,
    max_seq_len=128,
)

tox_data = load_dataset_splits(toxicity_config)
print(f"Toxicity -- Train: {len(tox_data['train_texts'])}, Test: {len(tox_data['test_texts'])}")

# Extract activations
# Reuses the model and the layer selected in Part 1. The batch size is fixed
# at 8 here rather than taken from config.model.batch_size.
tox_train_acts = extract_activations(tox_data['train_texts'], model_wrapper, layer, batch_size=8)
tox_test_acts = extract_activations(tox_data['test_texts'], model_wrapper, layer, batch_size=8)

# NOTE(review): .numpy() only works on CPU tensors that do not require grad;
# this assumes extract_activations returns such tensors -- confirm.
tox_train_np = tox_train_acts.numpy()
tox_test_np = tox_test_acts.numpy()
tox_train_labels = np.array(tox_data['train_labels'])
tox_test_labels = np.array(tox_data['test_labels'])

# Fit toxicity concept subspace
# Same n_components / n_select values as the sentiment subspace.
tox_subspace = ConceptSubspace(n_components=50, n_select=5)
tox_subspace.fit(tox_train_np, tox_train_labels)
print(f"Toxicity concept subspace: k={tox_subspace.n_directions} directions")
print(f"Class separations: {tox_subspace.class_separations}")

In [None]:
# @title Toxicity Steering
# Hand-written open-ended prompts used for the toxicity experiment.
toxicity_prompts = [
    "I think people who disagree are",
    "The problem with the government is that",
    "When someone makes me angry, I want to",
    "My honest opinion about this group is",
    "The reason everything is failing is because",
]

# NOTE(review): both runs below pass target_class=1. Presumably label 1 is the
# toxic class for this dataset, i.e. this steers *towards* toxicity -- confirm
# that this is the intended direction.

# SOCP
# Same solver settings as the sentiment SOCP run, but with the toxicity subspace.
tox_socp = ConvexOptimalSteering(epsilon=5.0, solver="SCS", prefilter_topk=2000)
tox_socp_results, tox_socp_features = run_steering_experiment(
    toxicity_prompts, tox_socp, model_wrapper, layer, tox_subspace,
    D, target_class=1, alpha_values=alpha_values,
)
display_results(tox_socp_results, "SOCP -- Toxicity (epsilon=5.0)")

# QP
# Same solver settings as the sentiment QP run, using a fresh solver instance.
tox_qp = QPOptimalSteering(lam=1.0, solver="SCS", prefilter_topk=2000, warm_start=True)
tox_qp_results, tox_qp_features = run_steering_experiment(
    toxicity_prompts, tox_qp, model_wrapper, layer, tox_subspace,
    D, target_class=1, alpha_values=alpha_values,
)
display_results(tox_qp_results, "QP -- Toxicity (lambda=1.0)")

In [None]:
# @title Toxicity: Subspace Scores & Feature Analysis
# One panel per method; each curve traces one toxicity prompt's subspace score
# over alpha.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
tox_panels = (("SOCP", tox_socp_results), ("QP", tox_qp_results))

for panel, (method_label, method_results) in zip(axes, tox_panels):
    for entry in method_results:
        per_alpha = entry['generations']
        curve = [per_alpha[a]['subspace_score'] for a in alpha_values]
        # Legend entries show the first 30 characters of the prompt.
        panel.plot(alpha_values, curve, marker='o', label=f"{entry['prompt'][:30]}...")
    panel.set_title(f'{method_label}: Toxicity Subspace Score vs Alpha')
    panel.set_xlabel('Alpha')
    panel.set_ylabel('Subspace Score')
    panel.set_ylim(0, 1)
    panel.legend(fontsize=8)

plt.tight_layout()
plt.show()

# Feature intersection: features active for every prompt, per method
# (empty set when a method produced no solutions), then their union.
if tox_socp_features:
    tox_socp_inter = set.intersection(*tox_socp_features)
else:
    tox_socp_inter = set()
if tox_qp_features:
    tox_qp_inter = set.intersection(*tox_qp_features)
else:
    tox_qp_inter = set()
tox_all_inter = tox_socp_inter.union(tox_qp_inter)

print(f"\nToxicity -- SOCP features in ALL prompts: {sorted(tox_socp_inter)}")
print(f"Toxicity -- QP features in ALL prompts: {sorted(tox_qp_inter)}")
print(f"\n--- Neuronpedia Links (layer {layer}) ---")
neuronpedia_base = f"https://www.neuronpedia.org/gpt2-small/{layer}-res-jb"
for feat_idx in sorted(tox_all_inter):
    url = f"{neuronpedia_base}/{feat_idx}"
    print(f"  Feature {feat_idx}: {url}")

## Takeaways
- Both SOCP and QP find extremely sparse solutions (typically 3-15 active features out of ~24k)
- Higher alpha increases the shift in subspace score but may degrade generation quality
- The QP formulation benefits from warm-starting across prompts
- Feature intersection analysis reveals which SAE features are consistently selected for a given attribute -- these can be inspected on Neuronpedia for interpretability