# Full Method Comparison
Head-to-head comparison of all steering methods on the same model + dataset setup.

**Methods compared:**
| Category | Method | Description |
|----------|--------|-------------|
| Control | NoSteering | Zero vector baseline |
| Control | RandomDirection | Random unit vector |
| CAA | CAAMeanDiff | Mean difference of class activations |
| CAA | CAAContrastive | Contrastive prompt pairs |
| CAA | CAARepE | Persona prefix contrast |
| SAE | SingleFeature | Best correlated SAE feature |
| SAE | TopKFeatures | Top-k correlated features |
| Optimal | ConvexOptimal (SOCP) | SOCP with hard L2 constraint |
| Optimal | QPOptimal (QP) | QP with L2 penalty |

**Metrics:** Probe score delta, KL divergence, L0 sparsity, steering vector norm, downstream classifier accuracy.

In [None]:
# @title Setup
import sys, os

# Detect Google Colab by checking whether the `google.colab` module is already loaded.
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    # Clone repo and install dependencies
    # NOTE: the lines starting with "!" are IPython shell escapes, so this cell
    # only runs inside a notebook kernel, not as a plain Python script.
    # The clone is skipped when the checkout directory already exists.
    if not os.path.exists('/content/optimal_sparse_steering'):
        !git clone https://github.com/tgautam23/optimal_sparse_steering.git /content/optimal_sparse_steering
    !pip install -q torch transformer-lens sae-lens cvxpy datasets scikit-learn matplotlib seaborn tqdm transformers pandas
    PROJECT_ROOT = '/content/optimal_sparse_steering'
else:
    # Outside Colab the project root is taken to be the parent of the current
    # working directory.
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))

# Put the project root first on sys.path so the `configs` / `src` imports that
# follow in this cell can be resolved.
sys.path.insert(0, PROJECT_ROOT)

import logging
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from configs.base import ExperimentConfig, ModelConfig
from src.models.wrapper import ModelWrapper
from src.models.sae_utils import get_decoder_matrix
from src.data.loader import load_dataset_splits
from src.data.preprocessing import extract_activations, extract_sae_features
from src.data.prompts import get_contrastive_pairs, get_persona_prefix, get_neutral_queries
from src.probes.linear_probe import LinearProbe
from src.steering.no_steering import NoSteering
from src.steering.random_direction import RandomDirection
from src.steering.caa import CAAMeanDiff, CAAContrastive, CAARepE
from src.steering.single_feature import SingleFeature
from src.steering.topk_features import TopKFeatures
from src.steering.convex_optimal import ConvexOptimalSteering
from src.steering.qp_optimal import QPOptimalSteering
from src.evaluation.metrics import compute_probe_score, compute_kl_divergence, compute_l0
from src.evaluation.generation import steered_generation, steered_forward

# INFO-level logging with timestamp, level and logger name on every record.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
    level=logging.INFO,
)
sns.set_theme(style="whitegrid")

# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

---
## GPT-2 Small + SST-2

In [None]:
# @title Load Model + Data + Probe + SAE Features
# Settings shared by the later cells.
target_class = 1
alpha = 5.0  # default steering strength for comparison
config = ExperimentConfig()
config.model.device = device

# Model wrapper, the SAE fetched for the steering layer, and its decoder matrix.
model_wrapper = ModelWrapper(config.model)
layer = config.model.steering_layer
sae = model_wrapper.get_sae(layer)
D = get_decoder_matrix(sae)
d_sae, d_model = D.shape  # D.shape is unpacked as (d_sae, d_model)

data = load_dataset_splits(config.data)
print(f"Train: {len(data['train_texts'])}, Test: {len(data['test_texts'])}")


def _featurize(extract_fn, split):
    """Apply `extract_fn` to the texts of `split` ('train' or 'test') at `layer`."""
    return extract_fn(data[f'{split}_texts'], model_wrapper, layer,
                      batch_size=config.model.batch_size)


# Activations
train_acts = _featurize(extract_activations, 'train')
test_acts = _featurize(extract_activations, 'test')

# SAE features (needed for SAE-based methods)
train_sae = _featurize(extract_sae_features, 'train')
test_sae = _featurize(extract_sae_features, 'test')

# NumPy views of the activations and label arrays used by the probe and metrics.
train_np, test_np = train_acts.numpy(), test_acts.numpy()
train_labels, test_labels = (
    np.array(data[key]) for key in ('train_labels', 'test_labels')
)

# Probe: fit on the training activations, then report train/test accuracy.
probe = LinearProbe(d_model=d_model)
probe.fit(train_np, train_labels)
train_acc = probe.score(train_np, train_labels)
test_acc = probe.score(test_np, test_labels)
print(f"Probe accuracy \u2014 train: {train_acc:.4f}, "
      f"test: {test_acc:.4f}")

### Run All Methods

In [None]:
# @title Compute All Steering Vectors
# Fits every steering method once and registers it in `methods` under its display name.
methods = {}
pos_mask = train_labels == target_class
neg_mask = ~pos_mask

# 1. No Steering
no_steer = NoSteering()
no_steer.compute_steering_vector(d_model=d_model)
methods['NoSteering'] = no_steer

# 2. Random Direction (fixed seed)
random_dir = RandomDirection(seed=42)
random_dir.compute_steering_vector(d_model=d_model)
methods['Random'] = random_dir

# 3. CAAMeanDiff: training activations split by the target-class mask
mean_diff = CAAMeanDiff()
mean_diff.compute_steering_vector(
    activations_pos=torch.tensor(train_np[pos_mask]),
    activations_neg=torch.tensor(train_np[neg_mask]),
)
methods['CAAMeanDiff'] = mean_diff

# 4. CAAContrastive
contrastive = CAAContrastive()
contrastive.compute_steering_vector(
    model_wrapper=model_wrapper,
    layer=layer,
    contrastive_pairs=get_contrastive_pairs("sst2"),
    batch_size=config.model.batch_size,
)
methods['CAAContrastive'] = contrastive

# 5. CAARepE: first 20 neutral queries, with the class-1 / class-0 persona prefixes
rep_e = CAARepE()
rep_e.compute_steering_vector(
    model_wrapper=model_wrapper,
    layer=layer,
    neutral_queries=get_neutral_queries("sst2")[:20],
    positive_prefix=get_persona_prefix("sst2", 1),
    negative_prefix=get_persona_prefix("sst2", 0),
    batch_size=config.model.batch_size,
)
methods['CAARepE'] = rep_e

# 6. SingleFeature (best correlated)
# Correlate each SAE feature's training activations with the training labels;
# the 1e-10 terms keep both denominators strictly positive.
train_sae_np = train_sae.numpy()
feat_centered = train_sae_np - train_sae_np.mean(axis=0, keepdims=True)
labels_centered = train_labels - train_labels.mean()
feat_scale = np.sqrt((feat_centered ** 2).sum(axis=0) + 1e-10)
label_scale = np.sqrt((labels_centered ** 2).sum() + 1e-10)
correlations = (feat_centered.T @ labels_centered) / (feat_scale * label_scale)
best_feat = int(correlations.argmax())  # most positively correlated

single_feature = SingleFeature(feature_idx=best_feat)
single_feature.compute_steering_vector(model_wrapper=model_wrapper, layer=layer)
methods['SingleFeature'] = single_feature

# 7. TopKFeatures (k=10)
top_k = TopKFeatures(topk=10)
top_k.compute_steering_vector(
    sae_features=train_sae,
    labels=train_labels,
    model_wrapper=model_wrapper,
    layer=layer,
)
methods['TopK(k=10)'] = top_k


def _first_test_example_inputs():
    """Solver inputs built from test example 0; a new `h` tensor on every call."""
    return dict(
        h=torch.tensor(test_np[0]),
        probe_w=probe.weight_vector,
        probe_b=probe.bias,
        D=D,
        sae_features=test_sae[0],
        target_class=target_class,
    )


# 8. SOCP
socp = ConvexOptimalSteering(epsilon=5.0, tau=0.5, solver="SCS")
socp.compute_steering_vector(**_first_test_example_inputs())
methods['SOCP'] = socp

# 9. QP
qp = QPOptimalSteering(lam=1.0, tau=0.5, solver="SCS", prefilter_topk=2000)
qp.compute_steering_vector(**_first_test_example_inputs())
methods['QP'] = qp

# Report how many vectors were built and the norm of each one.
print(f"Computed {len(methods)} steering vectors.")
for label, steerer in methods.items():
    print(f"  {label:20s}  norm={steerer.steering_vector.norm().item():.4f}")

In [None]:
# @title Evaluate All Methods
# For every steering method, at the shared `alpha`, collect one result row with:
#   Probe Delta   - steered minus unsteered probe score for `target_class`
#   Steering Norm - norm of the method's steering vector
#   L0            - compute_l0 of the method's `_delta`, when it has a non-None one
#   KL Divergence - mean of compute_kl_divergence over the first 5 test texts
#                   (None when no sample could be evaluated)
eval_logger = logging.getLogger(__name__)
rows = []

# Loop invariants: neither the tokenizer nor the unsteered probe score depends
# on the steering method, so compute them once.
tokenizer = model_wrapper.tokenizer
base_score = compute_probe_score(probe, test_np, target_class)

for name, method in methods.items():
    sv = method.steering_vector.numpy()
    steered = test_np + alpha * sv[np.newaxis, :]
    steered_score = compute_probe_score(probe, steered, target_class)

    # L0 sparsity (for SAE-based methods)
    l0 = None
    if getattr(method, '_delta', None) is not None:
        l0 = compute_l0(method._delta)

    # KL divergence (sample a few texts)
    kl_vals = []
    for text in data['test_texts'][:5]:
        try:
            tokens = tokenizer(text, return_tensors="pt", truncation=True,
                               max_length=128)
            input_ids = tokens["input_ids"].to(device)

            logits_steered, _ = steered_forward(
                model_wrapper, input_ids, method, layer, alpha=alpha,
            )
            with torch.no_grad():
                logits_base = model_wrapper.model(input_ids)

            kl = compute_kl_divergence(logits_steered, logits_base)
            kl_vals.append(kl)
        except Exception as exc:
            # Still best-effort (one bad sample must not abort the comparison),
            # but say which method lost a sample and why instead of hiding it.
            eval_logger.warning("KL divergence failed for %s: %s", name, exc)

    kl_mean = float(np.mean(kl_vals)) if kl_vals else None

    rows.append({
        'Method': name,
        'Probe Delta': steered_score - base_score,
        'Steering Norm': float(method.steering_vector.norm().item()),
        'L0': l0,
        'KL Divergence': kl_mean,
    })

df = pd.DataFrame(rows)
print(df.to_string(index=False))

### Comparison Visualizations

In [None]:
# @title Probe Score Delta — All Methods
# Bar chart of the probe score delta per method, read from `df`.
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

# One colour per method; also looked up by the later plotting cells.
colors = {
    'NoSteering': '#95a5a6', 'Random': '#95a5a6',
    'CAAMeanDiff': '#3498db', 'CAAContrastive': '#2980b9', 'CAARepE': '#1abc9c',
    'SingleFeature': '#f39c12', 'TopK(k=10)': '#e67e22',
    'SOCP': '#e74c3c', 'QP': '#c0392b',
}

method_names = df['Method'].tolist()
deltas = df['Probe Delta'].tolist()
# Methods missing from `colors` fall back to a neutral grey.
bar_colors = [colors.get(name, '#7f8c8d') for name in method_names]

bars = ax.bar(range(len(method_names)), deltas, color=bar_colors, edgecolor='white', linewidth=0.5)
ax.set_xticks(range(len(method_names)))
ax.set_xticklabels(method_names, rotation=30, ha='right')
ax.set_ylabel('Probe Score Delta')
ax.set_title(f'GPT-2 Small + SST-2: Steering Effectiveness (alpha={alpha})')
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

# Add value labels. Bars for negative deltas point downwards (their height is
# negative), so their label goes just below the tip; otherwise it would be
# drawn inside the bar.
for bar, val in zip(bars, deltas):
    if val >= 0:
        label_y, label_va = bar.get_height() + 0.005, 'bottom'
    else:
        label_y, label_va = bar.get_height() - 0.005, 'top'
    ax.text(bar.get_x() + bar.get_width()/2., label_y,
            f'{val:+.3f}', ha='center', va=label_va, fontsize=9)

plt.tight_layout()
plt.show()

In [None]:
# @title Effectiveness vs Sparsity (SAE methods only)
# Scatter of L0 against probe delta, restricted to rows of `df` that have an L0 value.
sae_methods = df[df['L0'].notna()].copy()

if sae_methods.empty:
    print("No SAE methods with L0 data available.")
else:
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    # One labelled point per method, drawn in row order.
    points = zip(sae_methods['Method'], sae_methods['L0'], sae_methods['Probe Delta'])
    for label, n_active, probe_delta in points:
        ax.scatter(n_active, probe_delta, c=colors.get(label, '#7f8c8d'),
                   s=120, zorder=3, edgecolors='black')
        ax.annotate(label, (n_active, probe_delta),
                    xytext=(5, 5), textcoords='offset points', fontsize=9)

    ax.set_xlabel('L0 (Number of Active SAE Features)')
    ax.set_ylabel('Probe Score Delta')
    ax.set_title('Sparsity vs Effectiveness (SAE-based Methods)')
    plt.tight_layout()
    plt.show()

In [None]:
# @title KL Divergence Comparison
# Bar chart of the mean KL divergence, skipping methods with no KL value in `df`.
df_kl = df[df['KL Divergence'].notna()].copy()

if not df_kl.empty:
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))

    positions = range(len(df_kl))
    # Methods missing from `colors` fall back to a neutral grey.
    bar_colors_kl = df_kl['Method'].map(lambda m: colors.get(m, '#7f8c8d')).tolist()
    ax.bar(positions, df_kl['KL Divergence'], color=bar_colors_kl, edgecolor='white')
    ax.set_xticks(positions)
    ax.set_xticklabels(df_kl['Method'], rotation=30, ha='right')
    ax.set_ylabel('KL Divergence')
    ax.set_title(f'Distribution Shift: KL(steered || base) at alpha={alpha}')

    plt.tight_layout()
    plt.show()

### Generated Text Comparison

In [None]:
# @title Generated Text Samples
# Prints up to 200 characters of each steered generation for three neutral prompts.
gen_alpha = 5.0
queries = get_neutral_queries("sst2")[:3]

print(f"Steering alpha = {gen_alpha}")
print(f"Prompts: {queries}")
print("=" * 80)

for name, method in methods.items():
    # The NoSteering entry is left out of the generation samples.
    if name != 'NoSteering':
        print(f"\n--- {name} ---")
        try:
            gens = steered_generation(
                model_wrapper, queries, method, layer,
                alpha=gen_alpha, max_new_tokens=50, temperature=0.7,
            )
            for idx, sample in enumerate(gens, start=1):
                print(f"  [{idx}] {sample[:200]}")
        except Exception as e:
            # A failure for one method is reported inline; the loop moves on.
            print(f"  Error: {e}")

### Alpha Sweep — All Methods

In [None]:
# @title Alpha Sweep — All Methods
# For each method, scale its steering vector by every alpha, add it to the test
# activations, and record the resulting probe score delta; then plot one curve
# per method.
alpha_values = [0.0, 1.0, 2.0, 3.0, 5.0, 7.0, 10.0, 15.0]
sweep_results = {}

# The unsteered probe score does not depend on the method or on alpha, so it is
# computed once rather than once per (method, alpha) pair.
base_s = compute_probe_score(probe, test_np, target_class)

for name, method in methods.items():
    sv = method.steering_vector.numpy()
    method_results = []
    for a in alpha_values:
        steered = test_np + a * sv[np.newaxis, :]
        steer_s = compute_probe_score(probe, steered, target_class)
        method_results.append({'alpha': a, 'probe_delta': steer_s - base_s})
    sweep_results[name] = method_results

# Plot
fig, ax = plt.subplots(1, 1, figsize=(12, 7))
linestyles = {
    'NoSteering': '--', 'Random': ':',
    'CAAMeanDiff': '-', 'CAAContrastive': '-', 'CAARepE': '-',
    'SingleFeature': '-.', 'TopK(k=10)': '-.',
    'SOCP': '-', 'QP': '-',
}

for name, results in sweep_results.items():
    alphas = [r['alpha'] for r in results]
    deltas = [r['probe_delta'] for r in results]
    color = colors.get(name, '#7f8c8d')
    ls = linestyles.get(name, '-')
    # The two convex-optimal methods are drawn with a thicker line.
    lw = 2.5 if name in ('SOCP', 'QP') else 1.5
    ax.plot(alphas, deltas, marker='o', color=color, linestyle=ls,
            linewidth=lw, markersize=4, label=name)

ax.set_xlabel('Alpha (steering strength)')
ax.set_ylabel('Probe Score Delta')
ax.set_title('GPT-2 Small + SST-2: All Methods \u2014 Steering Effectiveness vs Alpha')
# Legend goes outside the axes, to the right of the plot.
ax.legend(bbox_to_anchor=(1.02, 1), loc='upper left', fontsize=9)
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.3)
plt.tight_layout()
plt.show()

---
## Summary

In [None]:
# @title Final Summary
# Prints one fixed-width table row per method from `df`, followed by a key.
print("=" * 90)
print(f"{'Method':20s} | {'Probe Delta':>12s} | {'L0':>6s} | {'KL Div':>10s} | {'Norm':>8s}")
print("-" * 90)

for _, row in df.iterrows():
    l0_val, kl_val = row['L0'], row['KL Divergence']
    # pandas stores a missing entry in a numeric column as NaN (and the column
    # as float), so an `is not None` test lets NaN through and the integer
    # format code then fails on a float. pd.notna covers both None and NaN, and
    # the int() cast restores the integer count for the "d" format.
    l0_str = f"{int(l0_val):6d}" if pd.notna(l0_val) else "   N/A"
    kl_str = f"{kl_val:10.4f}" if pd.notna(kl_val) else "       N/A"
    print(f"{row['Method']:20s} | {row['Probe Delta']:+12.4f} | {l0_str} | {kl_str} | {row['Steering Norm']:8.4f}")

print("=" * 90)
print(f"\nKey: Probe Delta = shift in P(target_class) at alpha={alpha}")
print("     L0 = number of active SAE features (sparsity)")
print("     KL Div = KL divergence between steered and base distributions")

## Conclusions
- **Control baselines** (NoSteering, Random) confirm that steering requires meaningful directions
- **CAA methods** are effective and don't require SAE access, but produce dense (non-sparse) steering vectors
- **SAE heuristic methods** (SingleFeature, TopK) offer interpretability but effectiveness depends on feature selection quality
- **Convex optimal methods** (SOCP, QP) achieve competitive effectiveness with *minimal* feature interventions — the key advantage is extreme sparsity
- The **QP formulation** trades a hard L2 constraint for a penalty, enabling warm-starting and a cleaner hyperparameter (lambda)