# Trait Steering Demo

This notebook demonstrates the `TraitSteerer` class, a unified interface for steering language models along personality-trait directions. It supports Gemma, Qwen, and Llama models.

In [None]:
import sys
sys.path.insert(0, '..')

from assistant_axis import TraitSteerer, load_steerer, SUPPORTED_MODELS

print("Supported models:")
for model in SUPPORTED_MODELS:
    print(f"  - {model}")

## Initialize the TraitSteerer

Pass the model name and `TraitSteerer` automatically:
- Loads the model and tokenizer
- Downloads pre-computed trait vectors from HuggingFace
- Downloads the assistant axis
- Downloads capping config (if available)

In [None]:
# Choose your model. Options:
#   google/gemma-2-27b-it, Qwen/Qwen3-32B, meta-llama/Llama-3.3-70B-Instruct
MODEL_NAME = "Qwen/Qwen3-32B"

# Initialize - this loads the model and all vectors automatically
steerer = TraitSteerer(MODEL_NAME)
print(steerer)

## Explore Available Traits

Each model has ~240 pre-computed trait vectors representing different personality dimensions.

In [None]:
# List all available traits
traits = steerer.list_traits()
print(f"Total traits: {len(traits)}")
print("\nFirst 20 traits (alphabetically):")
print(", ".join(sorted(traits)[:20]))  # sort explicitly so the label holds

In [None]:
# Rank traits by similarity to the assistant axis
# Negative = more role-playing, Positive = more assistant-like
ranked = steerer.rank_traits_by_similarity(ascending=True)  # Most role-playing first

print("Most role-playing traits (negative similarity):")
for name, sim in ranked[:10]:
    print(f"  {name}: {sim:.3f}")

print("\nMost assistant-like traits (positive similarity):")
for name, sim in ranked[-10:]:
    print(f"  {name}: {sim:.3f}")
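
The ranking above compares each trait vector with the assistant axis. A minimal sketch of one way such a ranking could be computed, assuming cosine similarity is the measure (the library's actual metric is not shown in this notebook). The names `cosine_similarity`, `rank_by_similarity`, and the toy vectors are hypothetical, and plain Python lists stand in for tensors to keep the example dependency-free.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def rank_by_similarity(trait_vectors, axis, ascending=True):
    """Return (name, similarity) pairs sorted by similarity to `axis`."""
    scored = [
        (name, cosine_similarity(vec, axis))
        for name, vec in trait_vectors.items()
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=not ascending)


# Toy 2-D vectors, made up for illustration only (not real trait vectors).
toy_axis = [1.0, 0.0]
toy_traits = {
    "theatrical": [-1.0, 0.2],   # points away from the axis
    "neutral": [0.0, 1.0],       # orthogonal to the axis
    "methodical": [2.0, 0.1],    # points along the axis
}

toy_ranked = rank_by_similarity(toy_traits, toy_axis, ascending=True)
for name, sim in toy_ranked:
    print(f"{name}: {sim:.3f}")
```

With `ascending=True`, the vectors pointing away from the axis come first, which matches how the cell above reads the start of the list as the most role-playing traits.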

## Basic Steering with the Assistant Axis

The assistant axis captures the direction from role-playing toward default assistant behavior.

- **Positive coefficient**: More assistant-like (transparent, grounded)
- **Negative coefficient**: More role-playing (dramatic, immersive)
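
To make the sign convention concrete, here is a minimal sketch of additive steering, assuming the steerer shifts a hidden state by `coefficient` times a direction vector. That is a common way to implement activation steering, but this notebook does not confirm it is what `TraitSteerer` does internally. `apply_steering`, `projection`, and the toy vectors are hypothetical names used only for illustration.

```python
import math


def apply_steering(hidden, direction, coefficient):
    """Shift a hidden-state vector by `coefficient` along `direction`."""
    return [h + coefficient * d for h, d in zip(hidden, direction)]


def projection(hidden, direction):
    """Scalar projection of `hidden` onto `direction`."""
    dot = sum(h * d for h, d in zip(hidden, direction))
    norm = math.sqrt(sum(d * d for d in direction))
    return dot / norm


# Toy values, made up for illustration only.
toy_hidden = [0.5, -1.0, 2.0]
toy_direction = [1.0, 0.0, 0.0]  # stand-in for a unit-length assistant axis

toward = apply_steering(toy_hidden, toy_direction, 5.0)    # positive coefficient
away = apply_steering(toy_hidden, toy_direction, -10.0)    # negative coefficient

print(projection(toy_hidden, toy_direction))  # projection before steering
print(projection(toward, toy_direction))      # larger: pushed along the axis
print(projection(away, toy_direction))        # smaller: pushed against the axis
```

Under this assumption, only the component along the direction changes. The other components of the hidden state are left as they were.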

In [None]:
PROMPT = "What is your name?"
SYSTEM = "You are an accountant who works at a prestigious firm."

print(f"System: {SYSTEM}")
print(f"User: {PROMPT}")
print("=" * 60)

In [None]:
# Baseline (no steering)
print("### BASELINE (no steering)")
print("-" * 40)
response = steerer.generate(PROMPT, system_prompt=SYSTEM)
print(response)

In [None]:
# Steering toward role-playing (negative coefficient)
print("### STEERING AWAY FROM ASSISTANT (coeff=-10)")
print("-" * 40)
with steerer.steer_assistant(coefficient=-10.0):
    response = steerer.generate(PROMPT, system_prompt=SYSTEM)
print(response)

In [None]:
# Steering toward assistant (positive coefficient)
print("### STEERING TOWARD ASSISTANT (coeff=+5)")
print("-" * 40)
with steerer.steer_assistant(coefficient=5.0):
    response = steerer.generate(PROMPT, system_prompt=SYSTEM)
print(response)
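
The `with` blocks above limit steering to the calls made inside them. A minimal sketch of how such scoping could work, assuming a hook is attached on entry and removed on exit. `ToyLayer` and `steer_layer` are hypothetical stand-ins, not part of `assistant_axis`, and the real implementation may differ.

```python
from contextlib import contextmanager


class ToyLayer:
    """Hypothetical stand-in for a model layer that supports output hooks."""

    def __init__(self):
        self.hooks = []

    def forward(self, x):
        out = list(x)
        for hook in self.hooks:  # apply each registered hook in order
            out = hook(out)
        return out


@contextmanager
def steer_layer(layer, direction, coefficient):
    """Temporarily add `coefficient * direction` to the layer's output."""

    def hook(out):
        return [o + coefficient * d for o, d in zip(out, direction)]

    layer.hooks.append(hook)
    try:
        yield layer
    finally:
        # Always detach the hook, even if generation raises.
        layer.hooks.remove(hook)


toy_layer = ToyLayer()
toy_input = [1.0, 2.0]

with steer_layer(toy_layer, direction=[1.0, 0.0], coefficient=-10.0):
    steered_out = toy_layer.forward(toy_input)

restored_out = toy_layer.forward(toy_input)
print(steered_out)   # steering active inside the block
print(restored_out)  # unchanged once the block exits
```

The `try`/`finally` is the important design choice here: it keeps a failed generation from leaving the model in a steered state.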

## Steering with Specific Traits

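Beyond the assistant axis, you can steer along individual personality trait vectors. Pass a single trait name with `coefficient`, or a list of names with a matching `coefficients` list.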

In [None]:
# Single trait steering
print("### STEERING WITH 'dramatic' TRAIT (coeff=-5)")
print("-" * 40)
with steerer.steer("dramatic", coefficient=-5.0):
    response = steerer.generate("How should I handle a difficult day at work?", system_prompt=SYSTEM)
print(response[:500])

In [None]:
# Multi-trait steering
print("### STEERING WITH MULTIPLE TRAITS")
print("Traits: dramatic (-3), pessimistic (-2)")
print("-" * 40)
with steerer.steer(["dramatic", "pessimistic"], coefficients=[-3.0, -2.0]):
    response = steerer.generate("What's the outlook for the economy?", system_prompt=SYSTEM)
print(response[:500])
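
One plausible reading of multi-trait steering is that the trait vectors are combined into a single direction by a weighted sum before being applied. The sketch below illustrates that assumption only; the notebook does not show how `steer` combines traits internally. `combine_directions` and the toy vectors are hypothetical.

```python
def combine_directions(vectors, coefficients):
    """Weighted sum of equal-length vectors, one coefficient per vector."""
    if len(vectors) != len(coefficients):
        raise ValueError("need exactly one coefficient per vector")
    combined = [0.0] * len(vectors[0])
    for vec, coeff in zip(vectors, coefficients):
        for i, value in enumerate(vec):
            combined[i] += coeff * value
    return combined


# Toy 3-D vectors, made up for illustration only (not real trait vectors).
toy_dramatic = [1.0, 0.0, 0.5]
toy_pessimistic = [0.0, 1.0, 0.5]

toy_combined = combine_directions(
    [toy_dramatic, toy_pessimistic],
    [-3.0, -2.0],  # same coefficients as the multi-trait cell above
)
print(toy_combined)
```

Under this reading, traits that share components reinforce each other, so the combined shift can be larger along a shared dimension than either coefficient alone suggests.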

## Activation Capping for Safety

Activation capping limits persona drift by clamping the projection of the model's activations along the assistant axis at a threshold.
This can help keep responses grounded in emotionally charged contexts, where drift is more likely to lead to harmful behavior.

**Available for**: Qwen 3 32B, Llama 3.3 70B
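
A minimal sketch of the capping idea, assuming the cap acts as a floor: if the projection of a hidden state along the axis falls below a threshold, the state is moved back along the axis until the projection equals the threshold. Whether the library clamps from below or above, and which layers and thresholds it uses, comes from its capping config and is not shown here. `cap_projection` and the toy values are hypothetical.

```python
import math


def cap_projection(hidden, axis, threshold):
    """Raise the projection of `hidden` along `axis` to at least `threshold`."""
    norm = math.sqrt(sum(a * a for a in axis))
    unit = [a / norm for a in axis]                 # unit-length axis
    proj = sum(h * u for h, u in zip(hidden, unit))
    shortfall = max(threshold - proj, 0.0)          # zero if already in range
    return [h + shortfall * u for h, u in zip(hidden, unit)]


# Toy values, made up for illustration only.
toy_cap_axis = [2.0, 0.0]
toy_threshold = -1.0

drifted = [-3.0, 1.0]    # projection -3.0, below the threshold
in_range = [0.5, 1.0]    # projection 0.5, above the threshold

print(cap_projection(drifted, toy_cap_axis, toy_threshold))   # pulled back
print(cap_projection(in_range, toy_cap_axis, toy_threshold))  # left unchanged
```

Unlike steering with a fixed coefficient, a cap of this kind only intervenes when the projection leaves the allowed range, so activations already in range pass through untouched.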

In [None]:
# Check if capping is available
if steerer.capping_config is not None:
    print(f"Capping available! Recommended experiment: {steerer.capping_experiment}")
    print("\nAvailable experiments (first 10):")
    for exp in steerer.list_capping_experiments(10):
        print(f"  - {exp}")
else:
    print("Capping not available for this model")

In [None]:
# Demo: Oracle scenario that tends to cause persona drift
ORACLE_SYSTEM = "You are a wise old oracle who speaks in riddles and prophecies."
EMOTIONAL_PROMPT = "I'm really struggling right now. I feel a lot of anxiety about the future."

print(f"System: {ORACLE_SYSTEM}")
print(f"User: {EMOTIONAL_PROMPT}")
print("=" * 60)

In [None]:
# Without capping - may get overly dramatic response
print("### BASELINE (no capping)")
print("-" * 40)
response = steerer.generate(EMOTIONAL_PROMPT, system_prompt=ORACLE_SYSTEM)
print(response[:800])
if len(response) > 800:
    print("...")

In [None]:
# With capping - more grounded, practical response
if steerer.capping_config is not None:
    print("### WITH ACTIVATION CAPPING")
    print("-" * 40)
    with steerer.cap_activations():
        response = steerer.generate(EMOTIONAL_PROMPT, system_prompt=ORACLE_SYSTEM)
    print(response[:800])
    if len(response) > 800:
        print("...")
else:
    print("Capping not available for this model")

## Using with a Pre-loaded Model

If you already have a model loaded, you can pass it directly to avoid reloading.

In [None]:
# Example: Use the already-loaded model from steerer
steerer2 = TraitSteerer(
    MODEL_NAME,
    model=steerer.model,
    tokenizer=steerer.tokenizer,
    load_model=False  # Don't load again
)
print(steerer2)

## Switching Between Models

The same API works across models; only the model name changes.

In [None]:
# Example of using different models (uncomment to try)
# steerer_gemma = TraitSteerer("google/gemma-2-27b-it")
# steerer_llama = TraitSteerer("meta-llama/Llama-3.3-70B-Instruct")

# Each steerer has the same API:
# - steerer.list_traits()
# - steerer.steer(traits, coefficients)
# - steerer.steer_assistant(coefficient)
# - steerer.cap_activations()
# - steerer.generate(prompt, system_prompt)

## Summary

The `TraitSteerer` class provides:

1. **Easy initialization**: Just pass a model name
2. **Automatic vector loading**: Downloads from HuggingFace
3. **Trait exploration**: `list_traits()`, `rank_traits_by_similarity()`
4. **Assistant axis steering**: `steer_assistant(coefficient)`
5. **Trait steering**: `steer(traits, coefficients)`
6. **Activation capping**: `cap_activations()` (for supported models)
7. **Generation**: `generate(prompt, system_prompt)`

All of this works through a consistent API across Gemma, Qwen, and Llama models.