# Demo: Spectral Steering with ASAN

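This notebook demonstrates how to use ASAN's spectral steering to actively modify model behavior during generation, steering it away from harmful outputs while preserving output quality. All trajectories below are synthetic, and the predictor is used exactly as constructed (no trained checkpoint is loaded in this notebook), so the printed numbers illustrate the API rather than real model behavior.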


In [None]:
import sys
import os

# Resolve the directory to run from and make it importable.
# NOTE(review): os.path.abspath('') is the kernel's current working directory
# (normally the notebook's own folder), so the two dirname() calls below resolve
# to that folder's *grandparent*, not its parent — confirm this matches the
# repository layout.
# The resolved path is cached on the first run: without the cache, every re-run
# of this cell would os.chdir() two more levels up the directory tree.
if '_ASAN_RUN_DIR' not in globals():
    notebook_dir = os.path.dirname(os.path.abspath(''))
    _ASAN_RUN_DIR = os.path.dirname(notebook_dir)
parent_dir = _ASAN_RUN_DIR
os.chdir(parent_dir)
if parent_dir not in sys.path:  # avoid stacking duplicate entries on re-run
    sys.path.insert(0, parent_dir)

import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List

# Import ASAN components
from models.asan_predictor import ASANPredictor, ASANConfig
from alignment.spectral_steering import SpectralSteeringController, SteeringConfig
from alignment.steering_strategies import SteeringStrategyType, create_steering_strategy, StrategyConfig

# Seed the RNGs the notebook draws from (predictor initialisation and the
# torch.randn calls that build synthetic trajectories) so that
# "Restart & Run All" reproduces the same numbers.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

print("Imports successful!")


## 1. Initialize ASAN Predictor and Steering Controller


In [None]:
# --- ASAN predictor ---------------------------------------------------------
# No checkpoint is loaded here, so the predictor's weights are whatever the
# ASANPredictor constructor produces.
asan_config = ASANConfig(
    encoding_dim=256,
    attention_dim_internal=128,
    decomposition_levels=4,
)
asan_predictor = ASANPredictor(asan_config)
asan_predictor.eval()  # evaluation mode: the predictor is only used for inference below

print(f"ASAN Predictor initialized with {sum(param.numel() for param in asan_predictor.parameters())} parameters")

# --- Steering controller ----------------------------------------------------
# The controller wraps the predictor; these settings go straight to SteeringConfig.
steering_config = SteeringConfig(
    min_harm_prob_threshold=0.5,
    steering_strength=0.3,
    max_steering_magnitude=0.5,
)
steering_controller = SpectralSteeringController(asan_predictor, steering_config)
print("Steering controller initialized")


## 2. Create Synthetic Trajectories

Simulate two trajectories: a benign one, and one whose attention, activation, and token-probability patterns mimic a generation heading toward harmful output.


In [None]:
def create_synthetic_trajectory(
    num_timesteps: int = 10,
    harmful: bool = False,
    num_layers: int = 12,
    seq_len: int = 10,
    hidden_dim: int = 256,
    vocab_size: int = 50257,
    generator=None,
) -> Dict[str, object]:
    """Create a synthetic generation trajectory for testing.

    Args:
        num_timesteps: Number of timesteps in the trajectory.
        harmful: If True, use "suspicious" patterns:
            - every query attends only to the first token;
            - hidden activations are large and shifted (randn * 2 + 5);
            - each token distribution is one-hot.
            If False, use benign patterns:
            - attention is uniform;
            - activations are small (randn * 0.5);
            - token distributions are a softmax over random logits.
        num_layers: Number of transformer layers to simulate. The default of
            12 is assumed to match what ASANConfig expects — TODO confirm.
        seq_len: Sequence length of every attention map and hidden state.
        hidden_dim: Width of every hidden state.
        vocab_size: Length of each token distribution (default: GPT-2's 50257).
        generator: Optional ``torch.Generator`` used for every random draw, so
            a trajectory can be reproduced independently of the global seed.
            ``None`` uses torch's global RNG (the original behavior).

    Returns:
        A dict with keys:
          - ``'attention_patterns'``: ``{layer_idx: [Tensor(seq_len, seq_len), ...]}``
          - ``'hidden_states'``: ``{layer_idx: [Tensor(seq_len, hidden_dim), ...]}``
          - ``'token_probs'``: ``[Tensor(vocab_size), ...]``
        Each list has one entry per timestep.
    """
    # Attention maps are deterministic. A fresh tensor is built per
    # (layer, timestep) so callers never share one tensor object.
    attention_patterns = {}
    for layer_idx in range(num_layers):
        layer_attentions = []
        for _ in range(num_timesteps):
            if harmful:
                # Focused pattern: every query row puts all of its mass on
                # token 0, so each row is still a valid sums-to-one distribution.
                attn = torch.zeros(seq_len, seq_len)
                attn[:, 0] = 1.0
            else:
                # Distributed pattern: uniform attention over all keys.
                attn = torch.ones(seq_len, seq_len) / seq_len
            layer_attentions.append(attn)
        attention_patterns[layer_idx] = layer_attentions

    # Hidden states: harmful -> extreme activations, safe -> small activations.
    # Draw order (layer-major, then timestep) matches the original implementation.
    scale, shift = (2.0, 5.0) if harmful else (0.5, 0.0)
    hidden_states = {
        layer_idx: [
            torch.randn(seq_len, hidden_dim, generator=generator) * scale + shift
            for _ in range(num_timesteps)
        ]
        for layer_idx in range(num_layers)
    }

    # Token distributions: harmful -> one-hot (maximally peaked), safe -> smooth.
    token_probs = []
    for t in range(num_timesteps):
        if harmful:
            probs = torch.zeros(vocab_size)
            probs[t % vocab_size] = 1.0
        else:
            probs = torch.softmax(torch.randn(vocab_size, generator=generator), dim=0)
        token_probs.append(probs)

    return {
        'attention_patterns': attention_patterns,
        'hidden_states': hidden_states,
        'token_probs': token_probs,
    }

# One benign and one harmful trajectory of equal length, for side-by-side
# comparison (built in that order, so RNG consumption is unchanged).
safe_trajectory, harmful_trajectory = [
    create_synthetic_trajectory(num_timesteps=10, harmful=is_harmful)
    for is_harmful in (False, True)
]

print("Synthetic trajectories created")


## 3. Test ASAN Prediction


In [None]:
# Score the safe and then the harmful trajectory with autograd disabled;
# each call receives (attention_patterns, hidden_states, token_probs).
with torch.no_grad():
    safe_prediction, harmful_prediction = [
        asan_predictor(traj['attention_patterns'], traj['hidden_states'], traj['token_probs'])
        for traj in (safe_trajectory, harmful_trajectory)
    ]

for label, prediction in (("Safe", safe_prediction), ("Harmful", harmful_prediction)):
    print(f"{label} trajectory - Harm probability: {prediction['harm_probability'].item():.4f}, Confidence: {prediction['confidence'].item():.4f}")


## 4. Compute Steering Vector


In [None]:
# Request a steering vector intended to move the harmful trajectory toward
# a target safety level of 0.1.
steering_result = steering_controller.compute_steering_vector(harmful_trajectory, target_safety_level=0.1)

for key, label in (('harm_probability', "Harm probability"), ('confidence', "Confidence")):
    print(f"{label}: {steering_result[key]:.4f}")
print(f"Problematic bands: {steering_result['problematic_bands']}")
print(f"Steering vector shape: {steering_result['steering_vector'].shape}")
print(f"Steering vector norm: {steering_result['steering_vector'].norm().item():.4f}")


## 5. Apply Steering to Hidden States


In [None]:
# Steer the last-timestep hidden state of one layer, then quantify the change
# with norms and the relative size of the update.
layer_idx = 1
original_hidden = harmful_trajectory['hidden_states'][layer_idx][-1]

modified_hidden, steering_magnitude = steering_controller.apply_steering(
    original_hidden, steering_result['steering_vector'], layer_idx
)

print(f"Original hidden state norm: {original_hidden.norm().item():.4f}")
print(f"Modified hidden state norm: {modified_hidden.norm().item():.4f}")
print(f"Steering magnitude: {steering_magnitude:.4f}")
print(f"Change ratio: {((modified_hidden - original_hidden).norm() / original_hidden.norm()).item():.4f}")


## 6. Test Different Steering Strategies


In [None]:
# Ask each built-in strategy whether (and how strongly) it would steer the
# harmful trajectory, given the harmful prediction from section 3.
strategies = [
    SteeringStrategyType.CONSERVATIVE,
    SteeringStrategyType.AGGRESSIVE,
    SteeringStrategyType.ADAPTIVE,
    SteeringStrategyType.QUALITY_PRESERVING,
]

strategy_results = {}
for strategy_type in strategies:
    # A fresh default config per strategy, so strategies never share a config object.
    strategy_config = StrategyConfig()
    strategy = create_steering_strategy(strategy_type, steering_controller, strategy_config)

    should_steer = strategy.should_steer(harmful_trajectory, harmful_prediction)
    params = strategy.compute_steering_parameters(harmful_trajectory, harmful_prediction)

    # Missing parameters are reported as 0.0.
    strategy_results[strategy_type.value] = dict(
        should_steer=should_steer,
        steering_strength=params.get('steering_strength', 0.0),
        target_safety=params.get('target_safety_level', 0.0),
    )

# One indented report section per strategy.
for name, result in strategy_results.items():
    print(f"\n{name.upper()}:")
    print(f"  Should steer: {result['should_steer']}")
    print(f"  Steering strength: {result['steering_strength']:.4f}")
    print(f"  Target safety level: {result['target_safety']:.4f}")


## 7. Visualize Harm Probability Over Time (Before Steering)


In [None]:
# Track how the predicted harm probability evolves as more of the harmful
# trajectory becomes visible (prefixes of length 1..N). No steering is applied
# here; this is the baseline that steering is meant to pull below the threshold.
num_steps = len(harmful_trajectory['token_probs'])  # derived from the data, not hard-coded
timesteps = range(1, num_steps + 1)
# Draw the threshold the controller was configured with instead of repeating 0.5.
harm_threshold = steering_config.min_harm_prob_threshold
harm_probs = []

for t in timesteps:
    # Prefix of the trajectory covering its first t timesteps.
    partial_trajectory = {
        'attention_patterns': {k: v[:t] for k, v in harmful_trajectory['attention_patterns'].items()},
        'hidden_states': {k: v[:t] for k, v in harmful_trajectory['hidden_states'].items()},
        'token_probs': harmful_trajectory['token_probs'][:t],
    }

    with torch.no_grad():
        pred = asan_predictor(
            partial_trajectory['attention_patterns'],
            partial_trajectory['hidden_states'],
            partial_trajectory['token_probs'],
        )
    harm_probs.append(pred['harm_probability'].item())

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(list(timesteps), harm_probs, marker='o', linewidth=2, label='Harm Probability')
ax.axhline(y=harm_threshold, color='r', linestyle='--', label='Steering Threshold')
ax.set_xlabel('Timestep')
ax.set_ylabel('Harm Probability')
ax.set_title('Harm Probability Over Time (Before Steering)')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

print("Steering demo completed successfully!")
