# Demo: Online Learning with ASAN

This notebook demonstrates continual learning from real-world interactions, adapting ASAN and steering based on feedback.


In [1]:
import sys
import os

# Change to parent directory to run as package
notebook_dir = os.path.dirname(os.path.abspath(''))
parent_dir = os.path.dirname(notebook_dir)
os.chdir(parent_dir)
sys.path.insert(0, parent_dir)

import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List

# Import ASAN components
from models.asan_predictor import ASANPredictor, ASANConfig
from alignment.spectral_steering import SpectralSteeringController, SteeringConfig
from integration.online_learning import OnlineLearningSystem, OnlineLearningConfig
from integration.feedback_loop import FeedbackLoop, FeedbackLoopConfig

print("Imports successful!")


ModuleNotFoundError: No module named 'models'

## 1. Initialize Online Learning System


In [None]:
# Initialize ASAN predictor
asan_config = ASANConfig()
asan_predictor = ASANPredictor(asan_config)
asan_predictor.eval()

steering_config = SteeringConfig()
steering_controller = SpectralSteeringController(asan_predictor, steering_config)

# Initialize online learning
online_config = OnlineLearningConfig(
    learning_rate=1e-5,
    update_frequency=10,
    memory_size=100
)

online_learning = OnlineLearningSystem(
    asan_predictor,
    steering_controller,
    None,  # Policy (optional)
    online_config
)

# Initialize feedback loop
feedback_config = FeedbackLoopConfig()
feedback_loop = FeedbackLoop(asan_predictor, steering_controller, feedback_config)

print("Online learning system initialized")


## 2. Simulate Interaction Stream


In [None]:
def create_synthetic_interaction(interaction_id, harmful_ratio=0.3):
    """Create a synthetic interaction with feedback"""
    is_harmful = np.random.random() < harmful_ratio
    
    # Create trajectory
    trajectory = {
        'attention_patterns': {},
        'hidden_states': {},
        'token_probs': []
    }
    
    # Simulate trajectory
    for layer_idx in range(3):
        trajectory['attention_patterns'][layer_idx] = [
            torch.ones(10, 10) / 10 for _ in range(10)
        ]
        trajectory['hidden_states'][layer_idx] = [
            torch.randn(10, 256) * (2.0 if is_harmful else 0.5) for _ in range(10)
        ]
    
    trajectory['token_probs'] = [torch.softmax(torch.randn(50257), dim=0) for _ in range(10)]
    
    # Get ASAN prediction
    with torch.no_grad():
        asan_output = asan_predictor(
            trajectory['attention_patterns'],
            trajectory['hidden_states'],
            trajectory['token_probs']
        )
    
    # Create interaction
    interaction = {
        'interaction_id': interaction_id,
        'prompt': f'User query {interaction_id}',
        'trajectory': trajectory,
        'asan_prediction': {
            'harm_probability': asan_output['harm_probability'].item(),
            'confidence': asan_output['confidence'].item()
        },
        'actual_label': 'harmful' if is_harmful else 'safe',  # Ground truth (available in real scenario)
        'steering_applied': asan_output['harm_probability'].item() > 0.5,
        'user_feedback': None,  # Optional: user satisfaction rating
        'moderator_judgment': None  # Optional: human moderator judgment
    }
    
    return interaction

# Create interaction stream
num_interactions = 50
interaction_stream = [
    create_synthetic_interaction(i, harmful_ratio=0.3)
    for i in range(num_interactions)
]

print(f"Created {num_interactions} synthetic interactions")


## 3. Process Interactions and Learn


In [None]:
# Process interactions
prediction_errors = []
accuracy_history = []

for interaction in interaction_stream:
    # Process interaction
    online_learning.process_interaction(interaction)
    
    # Get feedback for feedback loop
    if interaction['actual_label'] is not None:
        predicted_harm = interaction['asan_prediction']['harm_probability'] > 0.5
        actual_harm = interaction['actual_label'] == 'harmful'
        
        feedback = {
            'prediction_correct': predicted_harm == actual_harm,
            'prediction_type': 'false_positive' if predicted_harm and not actual_harm else 
                             ('false_negative' if not predicted_harm and actual_harm else 'correct'),
            'steering_effective': interaction.get('steering_applied', False) if actual_harm else True,
            'quality_degraded': False  # Would check actual quality
        }
        
        feedback_loop.process_feedback(feedback)
        
        # Track accuracy
        if predicted_harm == actual_harm:
            accuracy_history.append(1.0)
        else:
            accuracy_history.append(0.0)
        
        # Track prediction error
        error = abs(interaction['asan_prediction']['harm_probability'] - (1.0 if actual_harm else 0.0))
        prediction_errors.append(error)

print(f"Processed {len(interaction_stream)} interactions")
print(f"Average prediction error: {np.mean(prediction_errors):.4f}")
print(f"Accuracy: {np.mean(accuracy_history):.4f}")


## 4. Check Statistics and Adaptation


In [None]:
# Get online learning statistics
stats = online_learning.get_statistics()
print("Online Learning Statistics:")
print(f"  Interaction count: {stats['interaction_count']}")
print(f"  Memory size: {stats['memory_size']}")
print(f"  Avg prediction error: {stats['avg_prediction_error']:.4f}")
print(f"  False positives: {stats['false_positives']}")
print(f"  False negatives: {stats['false_negatives']}")

# Get feedback loop metrics
feedback_metrics = feedback_loop.get_performance_metrics()
print("\nFeedback Loop Metrics:")
print(f"  Prediction accuracy: {feedback_metrics['prediction_accuracy']:.4f}")
print(f"  Steering effectiveness: {feedback_metrics['steering_effectiveness']:.4f}")
print(f"  False positive rate: {feedback_metrics['false_positive_rate']:.4f}")
print(f"  False negative rate: {feedback_metrics['false_negative_rate']:.4f}")

# Get adaptive thresholds
thresholds = feedback_loop.get_adaptive_thresholds()
print("\nAdaptive Thresholds:")
print(f"  Harm threshold: {thresholds['harm_threshold']:.4f}")
print(f"  Steering strength: {thresholds['steering_strength']:.4f}")

# Check for distribution shift
recent_interactions = list(online_learning.interaction_memory)[-20:]
shift_result = online_learning.detect_distribution_shift(recent_interactions)
print(f"\nDistribution Shift Detection:")
print(f"  Shift detected: {shift_result.get('shift_detected', False)}")
if shift_result.get('shift_detected', False):
    print(f"  Shift magnitude: {shift_result.get('shift_magnitude', 0.0):.4f}")
    print(f"  Recommendation: {shift_result.get('recommendation', 'None')}")


## 5. Visualize Learning Progress


In [None]:
# Plot accuracy over time
if len(accuracy_history) > 0:
    plt.figure(figsize=(12, 4))
    
    plt.subplot(1, 2, 1)
    window_size = 10
    moving_avg = [np.mean(accuracy_history[max(0, i-window_size):i+1]) 
                  for i in range(len(accuracy_history))]
    plt.plot(range(len(accuracy_history)), accuracy_history, alpha=0.3, label='Instant')
    plt.plot(range(len(moving_avg)), moving_avg, linewidth=2, label=f'Moving Avg ({window_size})')
    plt.xlabel('Interaction')
    plt.ylabel('Accuracy')
    plt.title('Prediction Accuracy Over Time')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.subplot(1, 2, 2)
    if len(prediction_errors) > 0:
        moving_avg_errors = [np.mean(prediction_errors[max(0, i-window_size):i+1]) 
                            for i in range(len(prediction_errors))]
        plt.plot(range(len(prediction_errors)), prediction_errors, alpha=0.3, label='Instant')
        plt.plot(range(len(moving_avg_errors)), moving_avg_errors, linewidth=2, label=f'Moving Avg ({window_size})')
        plt.xlabel('Interaction')
        plt.ylabel('Prediction Error')
        plt.title('Prediction Error Over Time')
        plt.legend()
        plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

print("\nOnline Learning demo completed successfully!")
