# Optimal Sparse Steering via Convex Optimization

**Core idea:** Instead of heuristically selecting SAE features for steering, we formulate the problem as a convex program that finds the *sparsest* feature perturbation to shift model behavior past a target margin.

## Two Formulations

**SOCP (Second-Order Cone Program):**
$$\min_{\delta \geq 0} \; \mathbf{1}^\top \delta \quad \text{s.t.} \quad (Dw)^\top \delta \geq \tau' - w^\top h - b, \quad \|D^\top \delta\|_2 \leq \epsilon$$

**QP (Quadratic Program) — relaxed formulation:**
$$\min_{\delta \geq 0} \; \mathbf{1}^\top \delta + \frac{\lambda}{2} \|D^\top \delta\|_2^2 \quad \text{s.t.} \quad (Dw)^\top \delta \geq \tau' - w^\top h - b$$

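Here $\delta \geq 0$ is the sparse feature perturbation, $D \in \mathbb{R}^{d_{sae} \times d_{model}}$ is the SAE decoder (so $D^\top \delta$ is the steering vector added to the residual stream), $w$ and $b$ are the linear probe's weight vector and bias, $h$ is the input's activation, and $\tau'$ is the target margin.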

The QP moves the coherence (L2) constraint into the objective as a penalty, enabling faster QP solvers and warm-starting.

In [None]:
# @title Setup
# Bootstrap the environment. In Colab: clone (or update) the repo and install
# dependencies. Locally: assume this notebook's working directory is one level
# below the project root. Either way the project root is put on sys.path so the
# `configs` and `src` packages imported below resolve.
import sys, os

IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    REPO_DIR = '/content/optimal_sparse_steering'
    # Reuse an existing checkout when present (pull), otherwise clone fresh.
    if os.path.exists(REPO_DIR):
        !git -C {REPO_DIR} pull -q
    else:
        !git clone -q https://github.com/tgautam23/optimal_sparse_steering.git {REPO_DIR}
    # `!` lines are IPython shell escapes; the grep drops pip's resolver-conflict chatter.
    !pip install -q torch transformer-lens sae-lens cvxpy datasets scikit-learn matplotlib seaborn tqdm transformers 2>&1 | grep -v "dependency conflicts\|incompatible\|pip's dependency resolver"
    PROJECT_ROOT = REPO_DIR
    # If the install left numpy incompatible with already-loaded binaries, these
    # imports fail; killing the process is used to force a runtime restart
    # (presumably the user then re-runs this cell).
    try:
        import numpy as _np; import seaborn as _sns; del _np, _sns
    except (ValueError, ImportError):
        print("Restarting runtime for numpy compatibility...")
        os.kill(os.getpid(), 9)
else:
    # Local run: project root is assumed to be the parent of the current directory.
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))

sys.path.insert(0, PROJECT_ROOT)

import logging
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns

from configs.base import ExperimentConfig, ModelConfig
from src.models.wrapper import ModelWrapper
from src.models.sae_utils import get_decoder_matrix
from src.data.loader import load_dataset_splits
from src.data.preprocessing import extract_activations, extract_sae_features
from src.data.prompts import get_neutral_queries
from src.probes.linear_probe import LinearProbe
from src.steering.convex_optimal import ConvexOptimalSteering
from src.steering.qp_optimal import QPOptimalSteering
from src.evaluation.metrics import compute_probe_score, compute_l0
from src.evaluation.generation import steered_generation

# Global logging/plot configuration, then pick the compute device.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
)
sns.set_theme(style="whitegrid")

# Default to CPU and upgrade to CUDA whenever PyTorch can see a GPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

---
## GPT-2 Small + SST-2

In [None]:
# @title Load Model + Data + Train Probe
config = ExperimentConfig()
config.model.device = device

model_wrapper = ModelWrapper(config.model)
layer = config.model.steering_layer

# SAE at the steering layer; D has one row per SAE feature.
sae = model_wrapper.get_sae(layer)
D = get_decoder_matrix(sae)  # (d_sae, d_model)
d_sae, d_model = D.shape
print(f"Model: {config.model.name}, d_model: {d_model}")
print(f"SAE: d_sae = {d_sae}, layer = {layer}")

# Dataset splits
data = load_dataset_splits(config.data)
train_texts, test_texts = data['train_texts'], data['test_texts']
print(f"Train: {len(train_texts)}, Test: {len(test_texts)}")

# Residual-stream activations (train/test) and SAE feature activations (test)
batch_size = config.model.batch_size
train_acts = extract_activations(train_texts, model_wrapper, layer, batch_size=batch_size)
test_acts = extract_activations(test_texts, model_wrapper, layer, batch_size=batch_size)
test_sae_feats = extract_sae_features(test_texts, model_wrapper, layer, batch_size=batch_size)

train_acts_np, test_acts_np = train_acts.numpy(), test_acts.numpy()
train_labels, test_labels = (np.array(data[key]) for key in ('train_labels', 'test_labels'))

# Linear probe on residual activations; its (w, b) define the steering target.
probe = LinearProbe(d_model=d_model)
probe.fit(train_acts_np, train_labels)
train_acc = probe.score(train_acts_np, train_labels)
test_acc = probe.score(test_acts_np, test_labels)
print(f"Probe accuracy \u2014 train: {train_acc:.4f}, test: {test_acc:.4f}")

probe_w, probe_b = probe.weight_vector, probe.bias
target_class = 1  # class index the steering pushes toward

### SOCP Formulation
The original formulation with a hard L2 coherence constraint $\|D^\top \delta\|_2 \leq \epsilon$.

In [None]:
# @title SOCP — Single Example Solve
# Per-input SOCP: minimise the L1 size of the feature perturbation delta subject
# to a probe-margin constraint (parameterised by tau) and the hard coherence
# budget ||D^T delta||_2 <= epsilon.
socp = ConvexOptimalSteering(epsilon=5.0, tau=0.5, solver="SCS", prefilter_threshold=0.01)

# Solve for the first test input
h_0 = torch.tensor(test_acts_np[0])
sae_feat_0 = test_sae_feats[0]

# perf_counter is monotonic and high-resolution; time.time() is wall-clock and
# can jump under clock adjustments, which corrupts short interval measurements.
t0 = time.perf_counter()
sv_socp = socp.compute_steering_vector(
    h=h_0, probe_w=probe_w, probe_b=probe_b,
    D=D, sae_features=sae_feat_0, target_class=target_class,
)
socp_time = time.perf_counter() - t0

print(f"SOCP solve time: {socp_time:.3f}s")
print(f"Status: {socp.solve_status}")
if socp.solve_status not in ("optimal", "optimal_inaccurate"):
    # Same acceptance set the sweeps use; otherwise the stats below do not
    # describe a valid optimal solution.
    print("WARNING: SOCP solve did not reach an optimal status.")
print(f"Steering vector norm: {sv_socp.norm():.4f}")
print(f"L0 (nonzero features): {compute_l0(socp.delta)}")
print(f"L1 (total perturbation): {socp.delta.sum():.4f}")
print(f"Active features: {socp.active_features}")

### QP Formulation
The relaxed formulation with a quadratic coherence penalty $\frac{\lambda}{2}\|D^\top \delta\|_2^2$ in the objective.

In [None]:
# @title QP — Single Example Solve
# Per-input QP: the coherence term moves into the objective as
# (lam/2) * ||D^T delta||_2^2. Note this uses top-k prefiltering, whereas the
# SOCP above used a threshold prefilter, so their candidate feature sets can differ.
qp = QPOptimalSteering(lam=1.0, tau=0.5, solver="SCS", prefilter_topk=2000)

# perf_counter: monotonic, high-resolution timer for measuring durations.
t0 = time.perf_counter()
sv_qp = qp.compute_steering_vector(
    h=h_0, probe_w=probe_w, probe_b=probe_b,
    D=D, sae_features=sae_feat_0, target_class=target_class,
)
qp_time = time.perf_counter() - t0

print(f"QP solve time: {qp_time:.3f}s")
print(f"Status: {qp.solve_status}")
if qp.solve_status not in ("optimal", "optimal_inaccurate"):
    print("WARNING: QP solve did not reach an optimal status.")
print(f"Steering vector norm: {sv_qp.norm():.4f}")
print(f"L0 (nonzero features): {compute_l0(qp.delta)}")
print(f"L1 (total perturbation): {qp.delta.sum():.4f}")
print(f"Active features: {qp.active_features}")

# Compare SOCP vs QP steering directions in residual space.
cos_sim = torch.nn.functional.cosine_similarity(sv_socp.unsqueeze(0), sv_qp.unsqueeze(0)).item()
print(f"\nSOCP vs QP cosine similarity: {cos_sim:.4f}")

### Sparsity Pattern Analysis
Which features does the optimizer select, and how do they compare?

In [None]:
# @title Sparsity Pattern Analysis
# Entries of delta at or below this value are treated as inactive. NOTE(review):
# compute_l0 may use its own threshold, so counts here can differ from the L0 above.
_ACTIVE_TOL = 1e-6


def _plot_active_features(ax, delta, label, color):
    """Bar-plot the active (> _ACTIVE_TOL) entries of ``delta`` on ``ax``.

    Bars are placed at consecutive positions and labelled with the original
    feature indices. Returns the array of active feature indices.
    """
    active = np.where(delta > _ACTIVE_TOL)[0]
    positions = range(len(active))
    ax.bar(positions, delta[active], color=color)
    ax.set_xticks(positions)
    ax.set_xticklabels([str(i) for i in active], rotation=45, fontsize=8)
    ax.set_xlabel('Feature Index')
    ax.set_ylabel('Perturbation Magnitude')
    ax.set_title(f'{label}: {len(active)} active features')
    return active


fig, axes = plt.subplots(1, 2, figsize=(14, 5))

socp_delta = socp.delta
qp_delta = qp.delta
socp_nonzero = _plot_active_features(axes[0], socp_delta, 'SOCP', '#3498db')
qp_nonzero = _plot_active_features(axes[1], qp_delta, 'QP', '#e74c3c')

plt.suptitle('Optimal Sparse Solutions: Feature Perturbation Patterns', fontsize=13)
plt.tight_layout()
plt.show()

# Overlap analysis
socp_set = set(socp_nonzero.tolist())
qp_set = set(qp_nonzero.tolist())
overlap = socp_set & qp_set
union = socp_set | qp_set
print(f"SOCP active features: {len(socp_set)}")
print(f"QP active features:   {len(qp_set)}")
print(f"Overlap:              {len(overlap)}")
if union:  # Jaccard is undefined when neither solution has active features
    print(f"Jaccard similarity:   {len(overlap) / len(union):.3f}")

### Batch Solving with Warm-Starting
The QP supports warm-starting: the solution for input $i$ seeds the solve for input $i+1$.

In [None]:
# @title Batch Solve: SOCP vs QP (with warm-start)
# Times three solvers on the same first n_test test inputs. These solvers use
# their default prefiltering (the single-example cells set prefilter options).
n_test = min(20, len(data['test_texts']))


def _timed_batch(solver):
    """Run ``solver.compute_batch_steering`` on the first ``n_test`` test inputs.

    Returns ``(steering_vectors, elapsed_seconds)``; timing uses the monotonic
    ``time.perf_counter`` clock.
    """
    start = time.perf_counter()
    vecs = solver.compute_batch_steering(
        activations=test_acts[:n_test],
        probe_w=probe_w, probe_b=probe_b, D=D,
        sae_features_batch=test_sae_feats[:n_test],
        target_class=target_class,
    )
    return vecs, time.perf_counter() - start


# SOCP batch
socp_batch = ConvexOptimalSteering(epsilon=5.0, tau=0.5, solver="SCS")
socp_vecs, socp_batch_time = _timed_batch(socp_batch)

# QP batch (with warm-starting: solution i seeds solve i+1)
qp_batch = QPOptimalSteering(lam=1.0, tau=0.5, solver="SCS", warm_start=True)
qp_vecs, qp_batch_time = _timed_batch(qp_batch)

# QP without warm-starting (baseline for the warm-start speed-up)
qp_cold = QPOptimalSteering(lam=1.0, tau=0.5, solver="SCS", warm_start=False)
qp_cold_vecs, qp_cold_time = _timed_batch(qp_cold)

print(f"Batch of {n_test} solves:")
print(f"  SOCP:              {socp_batch_time:.2f}s ({socp_batch_time/n_test:.3f}s avg)")
print(f"  QP (warm-start):   {qp_batch_time:.2f}s ({qp_batch_time/n_test:.3f}s avg)")
print(f"  QP (cold):         {qp_cold_time:.2f}s ({qp_cold_time/n_test:.3f}s avg)")

### QP Shared Steering Vector
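Instead of solving per input, compute a single steering vector that meets the target margin for every input in the batch. Because all inputs share the same probe direction $w$, this is equivalent to meeting the margin for the worst-case (hardest-to-steer) input.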

In [None]:
# @title Shared Steering Vector (Worst-Case)
# One steering vector for all of the first n_test inputs, then evaluate how the
# probe score moves as it is scaled by alpha.
qp_shared = QPOptimalSteering(lam=1.0, tau=0.5, solver="SCS")

t0 = time.perf_counter()  # monotonic timer for durations
sv_shared = qp_shared.compute_shared_steering(
    activations=test_acts[:n_test],
    probe_w=probe_w, probe_b=probe_b, D=D,
    sae_features=test_sae_feats[:n_test],
    target_class=target_class,
)
shared_time = time.perf_counter() - t0

print(f"Shared solve time: {shared_time:.3f}s")
print(f"L0: {compute_l0(qp_shared.delta)}")
print(f"Steering norm: {sv_shared.norm():.4f}")

# Evaluate shared vector across all test inputs
sv_shared_np = sv_shared.numpy()
eval_acts = test_acts_np[:n_test]
# The unsteered score does not depend on alpha, so compute it once.
base_score = compute_probe_score(probe, eval_acts, target_class)
for alpha in [1.0, 3.0, 5.0, 10.0]:
    steered = eval_acts + alpha * sv_shared_np[np.newaxis, :]
    steered_score = compute_probe_score(probe, steered, target_class)
    print(f"  alpha={alpha:5.1f}  probe_delta={steered_score - base_score:+.4f}")

### Pareto Sweep: Sparsity vs Effectiveness
Sweep over $\lambda$ (QP) and $\epsilon$ (SOCP) to map the tradeoff between sparsity and steering success.

In [None]:
# @title Pareto Sweep over lambda (QP) and epsilon (SOCP)
# Use a single test input for speed
h_test = torch.tensor(test_acts_np[0])
sae_test = test_sae_feats[0]
x_test = test_acts_np[0:1]  # single-row batch for the probe
PARETO_ALPHA = 5.0          # steering scale; the Pareto plot's axis labels assume 5
_OK_STATUSES = ("optimal", "optimal_inaccurate")

# Loop invariants, hoisted: the decoder as numpy (d_sae, d_model) and the
# unsteered probe score of the test input.
D_np = D.detach().cpu().numpy()
base_score_test = compute_probe_score(probe, x_test, target_class)


def _steering_stats(solver, sv):
    """Return (probe-score gain at PARETO_ALPHA, coherence ||D^T delta||_2) for one solve."""
    steered = x_test + PARETO_ALPHA * sv.numpy()[np.newaxis, :]
    gain = compute_probe_score(probe, steered, target_class) - base_score_test
    coherence = float(np.linalg.norm(D_np.T @ solver.delta))
    return gain, coherence


# QP lambda sweep
lam_values = [0.01, 0.05, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0]
qp_pareto = []
for lam in lam_values:
    qp_i = QPOptimalSteering(lam=lam, tau=0.5, solver="SCS", prefilter_topk=2000)
    sv = qp_i.compute_steering_vector(
        h=h_test, probe_w=probe_w, probe_b=probe_b,
        D=D, sae_features=sae_test, target_class=target_class,
    )
    if qp_i.solve_status not in _OK_STATUSES:
        # Report instead of silently dropping the sweep point.
        print(f"lam={lam:6.2f}  skipped (status={qp_i.solve_status})")
        continue
    delta, coherence = _steering_stats(qp_i, sv)
    qp_pareto.append({
        'lam': lam, 'l0': compute_l0(qp_i.delta),
        'l1': float(qp_i.delta.sum()),
        'probe_delta': delta,
        'coherence': coherence,
        'solve_time': qp_i.solve_time,
    })
    print(f"lam={lam:6.2f}  L0={qp_pareto[-1]['l0']:4d}  "
          f"probe_delta={delta:+.4f}  "
          f"coherence={qp_pareto[-1]['coherence']:.3f}  "
          f"time={qp_pareto[-1]['solve_time']:.3f}s")

# SOCP epsilon sweep
eps_values = [1.0, 2.0, 3.0, 5.0, 10.0, 20.0, 50.0]
socp_pareto = []
for eps in eps_values:
    socp_i = ConvexOptimalSteering(epsilon=eps, tau=0.5, solver="SCS")
    sv = socp_i.compute_steering_vector(
        h=h_test, probe_w=probe_w, probe_b=probe_b,
        D=D, sae_features=sae_test, target_class=target_class,
    )
    if socp_i.solve_status not in _OK_STATUSES:
        # Small epsilon can make the margin unreachable; say so explicitly.
        print(f"eps={eps:5.1f}  skipped (status={socp_i.solve_status})")
        continue
    delta, coherence = _steering_stats(socp_i, sv)
    socp_pareto.append({
        'epsilon': eps, 'l0': compute_l0(socp_i.delta),
        'l1': float(socp_i.delta.sum()),
        'probe_delta': delta,
        'coherence': coherence,
    })
    print(f"eps={eps:5.1f}  L0={socp_pareto[-1]['l0']:4d}  "
          f"probe_delta={delta:+.4f}  "
          f"coherence={socp_pareto[-1]['coherence']:.3f}")

In [None]:
# @title Pareto Frontier Plot


def _sorted_xy(records, x_key):
    """Return (xs, probe_deltas) of ``records`` sorted by ``x_key``.

    A frontier line drawn in sweep order zig-zags whenever x is not monotone
    in the swept parameter; sorting by x draws it left-to-right.
    """
    pairs = sorted((r[x_key], r['probe_delta']) for r in records)
    return [p[0] for p in pairs], [p[1] for p in pairs]


def _legend_if_any(ax):
    """Add a legend only when labelled artists exist (avoids matplotlib's warning)."""
    if ax.get_legend_handles_labels()[0]:
        ax.legend()


fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# L0 vs Probe Delta
ax = axes[0]
if qp_pareto:
    ax.scatter([r['l0'] for r in qp_pareto],
               [r['probe_delta'] for r in qp_pareto],
               c='#e74c3c', s=80, zorder=3, label='QP (lambda sweep)')
    ax.plot(*_sorted_xy(qp_pareto, 'l0'), color='#e74c3c', alpha=0.4)
if socp_pareto:
    ax.scatter([r['l0'] for r in socp_pareto],
               [r['probe_delta'] for r in socp_pareto],
               c='#3498db', s=80, marker='s', zorder=3, label='SOCP (epsilon sweep)')
    ax.plot(*_sorted_xy(socp_pareto, 'l0'), color='#3498db', alpha=0.4)
ax.set_xlabel('L0 (number of active features)')
ax.set_ylabel('Probe Score Delta (alpha=5)')
ax.set_title('Pareto Frontier: Sparsity vs Effectiveness')
_legend_if_any(ax)

# Coherence vs Probe Delta (points annotated with their sweep parameter)
ax = axes[1]
if qp_pareto:
    ax.scatter([r['coherence'] for r in qp_pareto],
               [r['probe_delta'] for r in qp_pareto],
               c='#e74c3c', s=80, zorder=3, label='QP')
    for r in qp_pareto:
        ax.annotate(f"\u03bb={r['lam']}", (r['coherence'], r['probe_delta']),
                    fontsize=7, alpha=0.7)
if socp_pareto:
    ax.scatter([r['coherence'] for r in socp_pareto],
               [r['probe_delta'] for r in socp_pareto],
               c='#3498db', s=80, marker='s', zorder=3, label='SOCP')
    for r in socp_pareto:
        ax.annotate(f"\u03b5={r['epsilon']}", (r['coherence'], r['probe_delta']),
                    fontsize=7, alpha=0.7)
ax.set_xlabel('Coherence (||D^T delta||_2)')
ax.set_ylabel('Probe Score Delta (alpha=5)')
ax.set_title('Coherence vs Effectiveness')
_legend_if_any(ax)

plt.tight_layout()
plt.show()

### Generated Text Samples

In [None]:
# @title Steered Generation — QP Optimal
queries = get_neutral_queries("sst2")[:5]
# Seed reset before every alpha so all settings start from the same RNG state.
# NOTE(review): this assumes steered_generation samples via torch's RNG.
GEN_SEED = 0

# Solve QP for a shared steering vector across (up to) the first 20 test inputs
qp_gen = QPOptimalSteering(lam=1.0, tau=0.5, solver="SCS")
sv_gen = qp_gen.compute_shared_steering(
    activations=test_acts[:20],
    probe_w=probe_w, probe_b=probe_b, D=D,
    sae_features=test_sae_feats[:20],
    target_class=target_class,
)

print(f"QP steering: L0={compute_l0(qp_gen.delta)}, norm={sv_gen.norm():.4f}")
print()

for alpha in [0.0, 5.0, 10.0]:
    print(f"--- alpha = {alpha} ---")
    # temperature=0.7 sampling is stochastic: without a fixed seed the alpha=0
    # baseline and the steered runs differ by sampling noise as well as by steering.
    torch.manual_seed(GEN_SEED)
    try:
        gens = steered_generation(model_wrapper, queries, qp_gen, layer,
                                   alpha=alpha, max_new_tokens=50, temperature=0.7)
    except Exception as e:  # best-effort demo: report and move on to the next alpha
        print(f"  Error: {e}")
    else:
        for i, g in enumerate(gens, start=1):
            print(f"  [{i}] {g[:200]}")
    print()

---
## Gemma-2-2B Pretrained

In [None]:
# @title Load Gemma-2-2B
import gc

# Release the GPT-2 model and its SAE before loading Gemma. globals().pop
# (instead of `del`) keeps this cell re-runnable once the names are gone.
# NOTE(review): D and other GPT-2 tensors/solvers are kept; if D shares storage
# with the SAE, that memory stays allocated.
for _name in ("model_wrapper", "sae"):
    globals().pop(_name, None)
# Collect unreachable objects first so empty_cache can actually return their memory.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

gemma_config = ExperimentConfig(
    model=ModelConfig(
        name="gemma-2-2b-pt",
        tl_name="google/gemma-2-2b",
        sae_release="gemma-scope-2b-pt-res-canonical",
        sae_id_template="layer_{layer}/width_16k/canonical",
        hook_template="blocks.{layer}.hook_resid_post",
        d_model=2304, n_layers=26, steering_layer=15,
        dtype="float16", device=device, batch_size=4,
    ),
    experiment_name="gemma-2-2b-pt_sst2",
)

mw_g = ModelWrapper(gemma_config.model)
layer_g = gemma_config.model.steering_layer
sae_g = mw_g.get_sae(layer_g)
D_g = get_decoder_matrix(sae_g)  # (d_sae, d_model), as for GPT-2
print(f"Gemma d_sae: {D_g.shape[0]}, d_model: {D_g.shape[1]}")

In [None]:
# @title Gemma: Data + Probe
data_g = load_dataset_splits(gemma_config.data)

# Read batch size and width from the Gemma config (4 and 2304 above) rather
# than repeating the literals, so editing the config cannot desynchronise this cell.
bs_g = gemma_config.model.batch_size
train_acts_g = extract_activations(data_g['train_texts'], mw_g, layer_g, batch_size=bs_g)
test_acts_g = extract_activations(data_g['test_texts'], mw_g, layer_g, batch_size=bs_g)
test_sae_g = extract_sae_features(data_g['test_texts'], mw_g, layer_g, batch_size=bs_g)

train_g_np = train_acts_g.numpy()
test_g_np = test_acts_g.numpy()
train_labels_g = np.array(data_g['train_labels'])
test_labels_g = np.array(data_g['test_labels'])

probe_g = LinearProbe(d_model=gemma_config.model.d_model)
probe_g.fit(train_g_np, train_labels_g)
print(f"Probe \u2014 train: {probe_g.score(train_g_np, train_labels_g):.4f}, "
      f"test: {probe_g.score(test_g_np, test_labels_g):.4f}")

In [None]:
# @title Gemma: QP Optimal Steering
target_class_g = 1  # same target class as the GPT-2 experiments
qp_g = QPOptimalSteering(lam=1.0, tau=0.5, solver="SCS", prefilter_topk=2000)

h_g = torch.tensor(test_g_np[0])
sv_g = qp_g.compute_steering_vector(
    h=h_g, probe_w=probe_g.weight_vector, probe_b=probe_g.bias,
    D=D_g, sae_features=test_sae_g[0], target_class=target_class_g,
)

print(f"Gemma QP: L0={compute_l0(qp_g.delta)}, norm={sv_g.norm():.4f}, time={qp_g.solve_time:.3f}s")

# Lambda sweep on Gemma
print("\nLambda sweep:")
x_g = test_g_np[0:1]  # single-row batch for the probe
# Unsteered score is identical for every lambda, so compute it once.
base_score_g = compute_probe_score(probe_g, x_g, target_class_g)
for lam in [0.1, 0.5, 1.0, 5.0, 10.0]:
    qp_gi = QPOptimalSteering(lam=lam, tau=0.5, solver="SCS", prefilter_topk=2000)
    sv = qp_gi.compute_steering_vector(
        h=h_g, probe_w=probe_g.weight_vector, probe_b=probe_g.bias,
        D=D_g, sae_features=test_sae_g[0], target_class=target_class_g,
    )
    if qp_gi.solve_status not in ("optimal", "optimal_inaccurate"):
        # Report rather than silently dropping the sweep point.
        print(f"  lam={lam:5.1f}  skipped (status={qp_gi.solve_status})")
        continue
    steered = x_g + 5.0 * sv.numpy()[np.newaxis, :]
    delta = compute_probe_score(probe_g, steered, target_class_g) - base_score_g
    print(f"  lam={lam:5.1f}  L0={compute_l0(qp_gi.delta):4d}  "
          f"probe_delta={delta:+.4f}  time={qp_gi.solve_time:.3f}s")

## Key Findings
- **Both SOCP and QP find extremely sparse solutions** — often just 3-15 active features out of thousands
- **QP with warm-starting is faster** for batch solving, and the penalty formulation gives a cleaner tradeoff knob (lambda)
- **The shared (worst-case) steering vector** provides a practical single-vector alternative to per-input solving
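- **The Pareto frontier** maps the sparsity–coherence tradeoff. Tightening the coherence budget (higher lambda / lower epsilon) forces the optimizer to spread the perturbation across more features, which gives denser but lower-norm solutions. In the opposite limit (lambda → 0 / large epsilon), both problems reduce to a single-constraint LP whose optimum uses very few features. Compare the probe-score deltas in the plot to see how steering strength varies along the frontier.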
- **Scaling to Gemma** (16k features, 2304-dim residual space): the QP remains tractable with pre-filtering