# Demo: Multi-Turn RL with ASAN Rewards

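This notebook demonstrates how to use ASAN predictions as reward signals for multi-turn RL training, with the goal of optimizing long-term conversational safety. All trajectories below are synthetic (randomly generated tensors), and the predictor is built from a default `ASANConfig` with no checkpoint loaded in this notebook, so the printed rewards illustrate the workflow rather than the behavior of a trained model.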


In [None]:
import sys
import os

# Change to the project root so the local packages import as top-level modules.
#
# The root is resolved only once per kernel session. os.path.abspath('') reflects
# the *current* working directory, which this cell itself changes, so recomputing
# it on every run would climb two more directory levels (and push another entry
# onto sys.path) each time the cell is re-executed.
if '_asan_project_root' not in globals():
    # NOTE(review): despite its name, `notebook_dir` is the parent of the working
    # directory at first execution, and `parent_dir` is one level above that.
    # TODO confirm that two levels up is the intended project root.
    notebook_dir = os.path.dirname(os.path.abspath(''))
    parent_dir = os.path.dirname(notebook_dir)
    _asan_project_root = parent_dir

os.chdir(_asan_project_root)
# Keep the root at the front of sys.path without stacking duplicates on re-runs.
if not sys.path or sys.path[0] != _asan_project_root:
    sys.path.insert(0, _asan_project_root)

import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List

# Import ASAN components
from models.asan_predictor import ASANPredictor, ASANConfig
from rl_training.asan_reward_model import ASANRewardModel, RewardConfig
from rl_training.long_horizon_rewards import LongHorizonRewardComputer, LongHorizonConfig
from rl_training.multi_turn_environment import MultiTurnSafetyEnvironment, EnvironmentConfig, UserType
from rl_training.trajectory_replay_buffer import TrajectoryReplayBuffer, ReplayBufferConfig

# Reaching this line means every import above resolved without raising.
print("Imports successful!")


## 1. Initialize Components


In [None]:
# Seed the RNGs before anything is constructed so that a fresh
# "Restart Kernel -> Run All" reproduces the same numbers in the cells below.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Initialize ASAN predictor from its default configuration.
asan_config = ASANConfig()
asan_predictor = ASANPredictor(asan_config)
# NOTE(review): presumably ASANPredictor is a torch.nn.Module and eval() puts it
# in inference mode -- confirm against models/asan_predictor.py.
asan_predictor.eval()

# Initialize ASAN reward model; the three weights are passed straight through
# to RewardConfig.
reward_config = RewardConfig(
    safety_weight=1.0,
    trajectory_smoothness_weight=0.2,
    early_detection_bonus_weight=0.3
)
asan_reward_model = ASANRewardModel(asan_predictor, reward_config)

# Initialize long-horizon reward computer (shares the same predictor instance).
horizon_config = LongHorizonConfig(horizon_length=5)
horizon_computer = LongHorizonRewardComputer(asan_predictor, horizon_config)

print("Components initialized")


## 2. Create Multi-Turn Conversation Trajectories


In [None]:
def create_multi_turn_trajectory(num_turns=5, gradually_harmful=False, num_layers=12,
                                 num_timesteps=10, seq_len=10, hidden_dim=256,
                                 vocab_size=50257) -> List[Dict]:
    """Build a synthetic multi-turn conversation made of random tensors.

    Every turn is a dict with three entries:

    * ``'attention_patterns'``: ``{layer_idx: [tensor, ...]}`` with
      ``num_timesteps`` tensors of shape ``(seq_len, seq_len)`` per layer.
    * ``'hidden_states'``: ``{layer_idx: [tensor, ...]}`` with
      ``num_timesteps`` tensors of shape ``(seq_len, hidden_dim)`` per layer.
    * ``'token_probs'``: ``num_timesteps`` tensors of shape ``(vocab_size,)``,
      each the softmax of random logits.

    When ``gradually_harmful`` is set, the harm level of a turn is
    ``(turn + 1) / num_turns``. It scales the hidden-state noise
    (``0.5 + 1.5 * harm_level`` instead of ``0.5``), and once it exceeds 0.7
    the uniform attention matrix is replaced by one whose only non-zero row is
    row 0 (so the remaining rows do not sum to 1).

    Args:
        num_turns: Number of conversation turns; ``0`` yields an empty list.
        gradually_harmful: Whether to gradually increase harm across turns.
        num_layers: Number of layers to simulate (default 12, chosen to match
            ASANConfig according to the original notebook).
        num_timesteps: Simulated timesteps per turn.
        seq_len: Sequence length used for attention and hidden-state tensors.
        hidden_dim: Width of each hidden-state vector.
        vocab_size: Length of each token-probability vector.

    Returns:
        A list with one trajectory dict per turn, in turn order.
    """
    # Harm level above which the attention pattern collapses onto row 0.
    harm_attention_threshold = 0.7
    base_hidden_scale = 0.5

    def _attention(concentrated):
        # A fresh tensor per call, so no two timesteps share storage.
        if concentrated:
            attn = torch.zeros(seq_len, seq_len)
            attn[0, :] = 1.0 / seq_len
            return attn
        return torch.ones(seq_len, seq_len) / seq_len

    conversation = []

    for turn in range(num_turns):
        # Everything harm-related depends only on the turn, so compute it once
        # here rather than inside the per-layer / per-timestep loops.
        if gradually_harmful:
            harm_level = (turn + 1) / num_turns
            concentrated = harm_level > harm_attention_threshold
            hidden_scale = base_hidden_scale + harm_level * 1.5
        else:
            concentrated = False
            hidden_scale = base_hidden_scale

        # Random draws happen in the same order as before: hidden states
        # (layer-major, timestep-minor) first, then token probabilities.
        trajectory = {
            'attention_patterns': {
                layer_idx: [_attention(concentrated) for _ in range(num_timesteps)]
                for layer_idx in range(num_layers)
            },
            'hidden_states': {
                layer_idx: [
                    torch.randn(seq_len, hidden_dim) * hidden_scale
                    for _ in range(num_timesteps)
                ]
                for layer_idx in range(num_layers)
            },
            'token_probs': [
                torch.softmax(torch.randn(vocab_size), dim=0)
                for _ in range(num_timesteps)
            ],
        }

        conversation.append(trajectory)

    return conversation

# Create test conversations.
# The generator draws from torch's global RNG, so seed it first: re-running this
# cell (or the whole notebook) then yields the same synthetic tensors, and the
# rewards printed further down stay comparable between runs.
torch.manual_seed(0)
safe_conversation = create_multi_turn_trajectory(num_turns=5, gradually_harmful=False)
harmful_conversation = create_multi_turn_trajectory(num_turns=5, gradually_harmful=True)

print(f"Created conversations with {len(safe_conversation)} turns each")


## 3. Compute Trajectory Rewards


In [None]:
# Score one turn from each conversation: the first turn of the safe one and
# the last turn of the gradually harmful one.
safe_turn = safe_conversation[0]
harmful_turn = harmful_conversation[-1]

safe_reward = asan_reward_model.compute_trajectory_reward(safe_turn)
harmful_reward = asan_reward_model.compute_trajectory_reward(harmful_turn)

# Emit the whole report with one print call; sep="\n" puts each entry on its
# own line, matching one print per line.
print(
    "Single Turn Rewards:",
    f"Safe turn - Total reward: {safe_reward['total_reward']:.4f}",
    f"  Safety: {safe_reward['component_rewards']['safety']:.4f}",
    f"  Smoothness: {safe_reward['component_rewards']['smoothness']:.4f}",
    f"  Early detection: {safe_reward['component_rewards']['early_detection']:.4f}",
    f"\nHarmful turn - Total reward: {harmful_reward['total_reward']:.4f}",
    f"  Safety: {harmful_reward['component_rewards']['safety']:.4f}",
    f"  Smoothness: {harmful_reward['component_rewards']['smoothness']:.4f}",
    f"  Early detection: {harmful_reward['component_rewards']['early_detection']:.4f}",
    sep="\n",
)


## 4. Compute Long-Horizon Conversation Rewards


In [None]:
# Score both conversations as a whole with the long-horizon reward computer.
safe_conv_reward = horizon_computer.compute_conversation_reward(safe_conversation)
harmful_conv_reward = horizon_computer.compute_conversation_reward(harmful_conversation)

print("\n".join([
    "Conversation-Level Rewards:",
    f"\nSafe conversation:",
    f"  Conversation reward: {safe_conv_reward['conversation_reward']:.4f}",
    f"  Turn-level rewards: {[f'{r:.3f}' for r in safe_conv_reward['turn_level_rewards']]}",
    f"\nHarmful conversation:",
    f"  Conversation reward: {harmful_conv_reward['conversation_reward']:.4f}",
    f"  Turn-level rewards: {[f'{r:.3f}' for r in harmful_conv_reward['turn_level_rewards']]}",
]))

# Run the multi-turn exploitation detector on the harmful conversation.
# Flags missing from the returned mapping are reported as False.
exploitation = horizon_computer.detect_multi_turn_exploitation(harmful_conversation)
print("\n".join([
    f"\nMulti-turn exploitation detection:",
    f"  Trust building: {exploitation.get('trust_building', False)}",
    f"  Gradual context shift: {exploitation.get('gradual_context_shift', False)}",
    f"  Exploitation attempt: {exploitation.get('exploitation_attempt', False)}",
]))


## 5. Temporal Credit Assignment


In [None]:
# Pretend harm was observed at turn 4 and ask the horizon computer how much
# each turn contributed.
# NOTE(review): with 5 turns, 4 is the last turn if indices are 0-based -- confirm.
harm_occurred_turn = 4
credits = horizon_computer.temporal_credit_assignment(
    harmful_conversation,
    harm_occurred_turn,
)

print(f"Temporal Credit Assignment (harm at turn {harm_occurred_turn}):")
for turn_idx, credit in sorted(credits.items()):
    # Only strictly negative credit is labelled as harmful; zero counts as safe.
    if credit < 0:
        credit_type = "Contributed to harm"
    else:
        credit_type = "Contributed to safety"
    print(f"  Turn {turn_idx}: {credit:.4f} ({credit_type})")


## 6. Test Multi-Turn Environment


In [None]:
# Build the environment; the three probabilities are passed to EnvironmentConfig
# alongside the turn limit.
env_config = EnvironmentConfig(
    max_turns=5,
    normal_user_prob=0.7,
    mild_adversarial_prob=0.2,
    highly_adversarial_prob=0.1,
)
env = MultiTurnSafetyEnvironment(asan_predictor, asan_reward_model, env_config)

# Roll out three short episodes. Each one stops when the environment reports
# `done` or after 3 steps, whichever comes first (below the max_turns=5 limit).
for episode in range(3):
    state = env.reset()
    total_reward, done, turn = 0, False, 0

    print(f"\nEpisode {episode + 1}:")
    print(f"  User type: {state['user_type'].value}")

    while turn < 3 and not done:
        # Placeholder policy: replay a pre-built trajectory instead of querying
        # a model. NORMAL users get the safe conversation, every other user
        # type gets the gradually harmful one.
        if state['user_type'] == UserType.NORMAL:
            turn_source = safe_conversation
        else:
            turn_source = harmful_conversation

        action = {
            'state': state,
            'response': {'text': f'Model response at turn {turn}'},
            # Clamp the index so it never runs past the last available turn.
            'trajectory': turn_source[min(turn, len(turn_source) - 1)],
        }

        state, reward, done, info = env.step(action)
        total_reward += reward
        turn += 1

        print(f"  Turn {turn}: reward={reward:.4f}, harm_prob={info.get('harm_probability', 0.0):.4f}")

    print(f"  Total episode reward: {total_reward:.4f}")


## 7. Test Replay Buffer


In [None]:
# Set up the replay buffer with the given size and priority sampling enabled.
replay_config = ReplayBufferConfig(buffer_size=100, priority_sampling=True)
replay_buffer = TrajectoryReplayBuffer(replay_config)

# Fill it with ten entries, alternating safe (even i) and harmful (odd i) turns.
# The harm probability stored with each entry is a fixed placeholder
# (0.2 / 0.8), not a model prediction.
for i in range(10):
    if i % 2 == 0:
        trajectory = safe_conversation[i % len(safe_conversation)]
        harm_prob = 0.2
    else:
        trajectory = harmful_conversation[i % len(harmful_conversation)]
        harm_prob = 0.8
    reward = asan_reward_model.compute_trajectory_reward(trajectory)['total_reward']
    replay_buffer.add(trajectory, reward, harm_prob)

# Summarise the buffer contents.
stats = replay_buffer.get_statistics()
print(
    "Replay Buffer Statistics:",
    f"  Size: {stats['size']}",
    f"  Harmful ratio: {stats['harmful_ratio']:.2f}",
    f"  Safe ratio: {stats['safe_ratio']:.2f}",
    f"  Avg reward: {stats['avg_reward']:.4f}",
    f"  Avg harm prob: {stats['avg_harm_prob']:.4f}",
    sep="\n",
)

# Draw a batch only once the buffer says it is ready.
if replay_buffer.is_ready():
    batch = replay_buffer.sample(5)
    print(f"\nSampled batch of {len(batch)} trajectories")
    print(f"  First trajectory reward: {batch[0]['reward']:.4f}")

print("\nMulti-turn RL demo completed successfully!")
