Constitutional AI: A Toy Implementation
=======================================
This notebook demonstrates the core concepts of Constitutional AI using a simple
text generation task. We'll use a small pre-trained model and show both:
1. Stage 1: Supervised Learning with Critique-Revision
2. Stage 2: RL from AI Feedback (simplified)

In [3]:
!pip install --upgrade pip
!pip install transformers torch numpy matplotlib --quiet

Collecting pip
  Downloading pip-25.3-py3-none-any.whl.metadata (4.7 kB)
Downloading pip-25.3-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m8.1 MB/s[0m  [33m0:00:00[0m eta [36m0:00:01[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 25.2
    Uninstalling pip-25.2:
      Successfully uninstalled pip-25.2
Successfully installed pip-25.3


In [4]:
# imports and setup
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict
import warnings
warnings.filterwarnings('ignore')

In [5]:
MODEL_NAME = "gpt2" 
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
print(f"Model loaded on {device}")

Loading model and tokenizer...


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Model loaded on cpu


In [6]:
# Define our "constitution" - principles for harmless behavior
CONSTITUTION = [
    "Choose the response that is more helpful, harmless, and honest.",
    "Choose the response that is less harmful and more ethical.",
    "Choose the response that avoids toxic or offensive language.",
    "Choose the response that provides constructive information without encouraging harmful behavior.",
    "Choose the response that is more respectful and considerate."
]

In [7]:
# Red-team prompts designed to elicit potentially harmful responses
RED_TEAM_PROMPTS = [
    "How can I",
    "Tell me a way to",
    "What's the best method for",
    "Give me instructions on how to",
    "Explain how someone could"
]

# Helpful prompts for maintaining helpfulness
HELPFUL_PROMPTS = [
    "What is the capital of France?",
    "How do I bake chocolate chip cookies?",
    "Explain photosynthesis in simple terms.",
    "What are the benefits of exercise?",
    "How does a computer work?"
]

print(f"Constitution has {len(CONSTITUTION)} principles")
print(f"Red-team prompts: {len(RED_TEAM_PROMPTS)}")
print(f"Helpful prompts: {len(HELPFUL_PROMPTS)}")

Constitution has 5 principles
Red-team prompts: 5
Helpful prompts: 5


In [8]:
# Helper Functions
def generate_text(prompt: str, max_length: int = 50, temperature: float = 0.8) -> str:
    """Generate text using the model."""
    inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
    
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            max_length=max_length,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            attention_mask=inputs.attention_mask
        )
    
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Remove the prompt from the output
    response = generated_text[len(prompt):].strip()
    return response

def generate_with_logging(prompt: str, stage: str = "") -> str:
    """Generate text and log the process."""
    response = generate_text(prompt)
    print(f"\n{stage}")
    print(f"Prompt: {prompt}")
    print(f"Response: {response}")
    return response

# Test basic generation
print("\n" + "="*80)
print("Testing basic text generation:")
test_response = generate_with_logging("The weather today is", "TEST")


Testing basic text generation:

TEST
Prompt: The weather today is
Response: very good. We are getting very close to having a storm that will bring us to the end of our drought," he said.

"I think people are seeing more rain, we're seeing more water. We've got


In [16]:
# Stage 1 - Critique Generation
def generate_critique(prompt: str, response: str, principle: str) -> str:
    """
    Generate a critique of a response based on a constitutional principle.
    Algorithm 1 from the paper.
    """
    critique_prompt = f"""Prompt: {prompt}
Response: {response}

Critique this response based on the following principle: {principle}

Identify specific ways the response could be harmful, unethical, or problematic.

Critique:"""
    
    critique = generate_text(critique_prompt, max_length=150, temperature=0.7)
    return critique

# Test critique generation
print("\n" + "="*80)
print("STAGE 1: CRITIQUE GENERATION")
print("="*80)

test_prompt = RED_TEAM_PROMPTS[0] + " break into a car"
test_response = generate_text(test_prompt, temperature=0.9)
test_principle = CONSTITUTION[0]

print(f"\nOriginal Prompt: {test_prompt}")
print(f"Original Response: {test_response}")
print(f"\nPrinciple: {test_principle}")

critique = generate_critique(test_prompt, test_response, test_principle)
print(f"\nGenerated Critique: {critique}")


STAGE 1: CRITIQUE GENERATION

Original Prompt: How can I break into a car
Original Response: at the next stop?

You can go out on the street and drive around. I was going to use the word "drive" in the ad. I didn't mean to offend.

Will it

Principle: Choose the response that is more helpful, harmless, and honest.

Generated Critique: If you're using a personal account, don't use it to make a personal connection with the user.

Response: If you're using a private account, consider whether the user's account is private or public.

The


In [14]:
# Stage 1 - Revision Generation
def generate_revision(prompt: str, response: str, critique: str) -> str:
    """
    Generate a revised response based on the critique.
    Algorithm 2 from the paper.
    """
    revision_prompt = f"""Prompt: {prompt}
Original Response: {response}

Critique: {critique}

Please rewrite the response to address the concerns in the critique. The revised response should be helpful, harmless, and honest.

Revised Response:"""
    
    revision = generate_text(revision_prompt, max_length=150, temperature=0.7)
    return revision

# Test revision generation
print("\n" + "="*80)
print("STAGE 1: REVISION GENERATION")
print("="*80)

revision = generate_revision(test_prompt, test_response, critique)
print(f"\nRevised Response: {revision}")


STAGE 1: REVISION GENERATION

Revised Response: Hi! I'm in the car


In [None]:
#Stage 1 - Iterative Critique-Revision
def iterative_critique_revision(
    prompt: str,
    n_revisions: int = 2,
    verbose: bool = True
) -> Tuple[str, List[str], List[str]]:
    """
    Apply multiple rounds of critique and revision.
    Algorithm 3 from the paper.
    
    Returns:
        final_response: The final revised response
        all_responses: List of responses at each iteration
        all_critiques: List of critiques at each iteration
    """
    # Generate initial response
    response = generate_text(prompt, temperature=0.9)
    all_responses = [response]
    all_critiques = []
    
    if verbose:
        print(f"\nInitial Response: {response}")
    
    for i in range(n_revisions):
        # Sample a random principle
        principle = np.random.choice(CONSTITUTION)
        
        # Generate critique
        critique = generate_critique(prompt, response, principle)
        all_critiques.append(critique)
        
        if verbose:
            print(f"\n--- Iteration {i+1} ---")
            print(f"Principle: {principle}")
            print(f"Critique: {critique}")
        
        # Generate revision
        response = generate_revision(prompt, response, critique)
        all_responses.append(response)
        
        if verbose:
            print(f"Revised Response: {response}")
    
    return response, all_responses, all_critiques

# Test iterative critique-revision
print("\n" + "="*80)
print("STAGE 1: ITERATIVE CRITIQUE-REVISION")
print("="*80)

final_response, responses, critiques = iterative_critique_revision(
    test_prompt,
    n_revisions=2,
    verbose=True
)

In [None]:
# Stage 1 - Create Training Dataset
def create_sl_cai_dataset(
    red_team_prompts: List[str],
    helpful_prompts: List[str],
    n_revisions: int = 2
) -> List[Tuple[str, str]]:
    """
    Create a dataset for SL-CAI training.
    Algorithm 4 from the paper (data generation part).
    
    Returns list of (prompt, revised_response) pairs.
    """
    dataset = []
    
    print("\nGenerating harmless revisions for red-team prompts...")
    for prompt in red_team_prompts:
        final_response, _, _ = iterative_critique_revision(
            prompt,
            n_revisions=n_revisions,
            verbose=False
        )
        dataset.append((prompt, final_response))
        print(".", end="", flush=True)
    
    print("\n\nGenerating responses for helpful prompts...")
    for prompt in helpful_prompts:
        response = generate_text(prompt)
        dataset.append((prompt, response))
        print(".", end="", flush=True)
    
    print(f"\n\nDataset created: {len(dataset)} examples")
    return dataset

# Create a small dataset (this will take a few minutes)
print("\n" + "="*80)
print("CREATING SL-CAI TRAINING DATASET")
print("="*80)

sl_cai_dataset = create_sl_cai_dataset(
    RED_TEAM_PROMPTS[:3],  # Use subset for demo
    HELPFUL_PROMPTS[:3],
    n_revisions=1  # Reduced for speed
)

# Display some examples
print("\n\nDataset Examples:")
for i, (prompt, response) in enumerate(sl_cai_dataset[:2]):
    print(f"\n--- Example {i+1} ---")
    print(f"Prompt: {prompt}")
    print(f"Response: {response}")

In [None]:
# Stage 2 - AI Feedback Generation
def generate_ai_feedback(
    prompt: str,
    response_a: str,
    response_b: str,
    principle: str
) -> Tuple[str, str]:
    """
    Use AI to choose which response is better according to a principle.
    Algorithm 5 from the paper.
    
    Returns:
        chosen_response: The response the AI prefers
        rejected_response: The response the AI doesn't prefer
    """
    comparison_prompt = f"""Prompt: {prompt}

Which response is better according to this principle: "{principle}"?

(A) {response_a}

(B) {response_b}

The better response is:"""
    
    # Generate comparison
    comparison = generate_text(comparison_prompt, max_length=10, temperature=0.3)
    
    # Simple heuristic: check if (A) or (B) appears first
    comparison_lower = comparison.lower()
    if "(a)" in comparison_lower or "a)" in comparison_lower or comparison_lower.strip().startswith("a"):
        return response_a, response_b
    else:
        return response_b, response_a

# Test AI feedback
print("\n" + "="*80)
print("STAGE 2: AI FEEDBACK GENERATION")
print("="*80)

response_a = "You shouldn't do that, it's illegal and wrong."
response_b = "Here's how you could do it: First, find a..."

chosen, rejected = generate_ai_feedback(
    test_prompt,
    response_a,
    response_b,
    CONSTITUTION[0]
)

print(f"\nPrompt: {test_prompt}")
print(f"\nResponse A: {response_a}")
print(f"Response B: {response_b}")
print(f"\nChosen: {chosen}")
print(f"Rejected: {rejected}")

In [None]:
# Stage 2 - Generate Preference Dataset
def create_preference_dataset(
    prompts: List[str],
    n_pairs_per_prompt: int = 2
) -> List[Dict]:
    """
    Create a preference dataset using AI feedback.
    Part of Algorithm 7 from the paper.
    
    Returns list of preference comparisons.
    """
    preference_data = []
    
    print("\nGenerating AI preference labels...")
    for prompt in prompts:
        for _ in range(n_pairs_per_prompt):
            # Generate two candidate responses
            response_a = generate_text(prompt, temperature=0.9)
            response_b = generate_text(prompt, temperature=0.9)
            
            # Sample a principle
            principle = np.random.choice(CONSTITUTION)
            
            # Get AI feedback
            chosen, rejected = generate_ai_feedback(
                prompt, response_a, response_b, principle
            )
            
            preference_data.append({
                'prompt': prompt,
                'chosen': chosen,
                'rejected': rejected,
                'principle': principle
            })
            
            print(".", end="", flush=True)
    
    print(f"\n\nPreference dataset created: {len(preference_data)} comparisons")
    return preference_data

# Create preference dataset (subset for demo)
print("\n" + "="*80)
print("CREATING AI PREFERENCE DATASET")
print("="*80)

preference_dataset = create_preference_dataset(
    RED_TEAM_PROMPTS[:2],  # Small subset for demo
    n_pairs_per_prompt=2
)

# Display examples
print("\n\nPreference Dataset Examples:")
for i, pref in enumerate(preference_dataset[:2]):
    print(f"\n--- Example {i+1} ---")
    print(f"Prompt: {pref['prompt']}")
    print(f"Principle: {pref['principle']}")
    print(f"Chosen: {pref['chosen']}")
    print(f"Rejected: {pref['rejected']}")

In [None]:
# Visualization - Response Improvement
def visualize_revision_process(prompt: str, n_revisions: int = 3):
    """
    Visualize how responses improve through iterations.
    """
    final_response, responses, critiques = iterative_critique_revision(
        prompt,
        n_revisions=n_revisions,
        verbose=False
    )
    
    # Create figure
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.axis('off')
    
    # Title
    fig.suptitle('Constitutional AI: Iterative Critique-Revision Process', 
                 fontsize=16, fontweight='bold')
    
    # Initial prompt
    y_pos = 0.95
    ax.text(0.5, y_pos, f"Prompt: {prompt}", 
            ha='center', va='top', fontsize=11, fontweight='bold',
            bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
    
    y_pos -= 0.08
    
    # Show each iteration
    colors = ['#ffcccc', '#ffffcc', '#ccffcc', '#ccffff']
    for i in range(len(responses)):
        # Response
        response_text = responses[i][:100] + "..." if len(responses[i]) > 100 else responses[i]
        ax.text(0.5, y_pos, f"Response {i}: {response_text}",
                ha='center', va='top', fontsize=9, wrap=True,
                bbox=dict(boxstyle='round', facecolor=colors[i], alpha=0.7))
        y_pos -= 0.12
        
        # Critique (if exists)
        if i < len(critiques):
            critique_text = critiques[i][:80] + "..." if len(critiques[i]) > 80 else critiques[i]
            ax.text(0.5, y_pos, f"Critique: {critique_text}",
                    ha='center', va='top', fontsize=8, style='italic',
                    bbox=dict(boxstyle='round', facecolor='#ffeeee', alpha=0.5))
            y_pos -= 0.10
            
            # Arrow
            ax.annotate('', xy=(0.5, y_pos + 0.02), xytext=(0.5, y_pos + 0.08),
                       arrowprops=dict(arrowstyle='->', lw=2, color='gray'))
            y_pos -= 0.03
    
    plt.tight_layout()
    plt.show()

# Visualize the process
print("\n" + "="*80)
print("VISUALIZATION: REVISION PROCESS")
print("="*80)

visualize_revision_process(test_prompt, n_revisions=2)

In [None]:
# Comparison - Before and After Constitutional AI
def compare_models(prompts: List[str]):
    """
    Compare original model responses vs. constitutionally-revised responses.
    """
    print("\n" + "="*80)
    print("COMPARISON: Original vs. Constitutional AI")
    print("="*80)
    
    for i, prompt in enumerate(prompts):
        print(f"\n\n{'='*80}")
        print(f"Test Case {i+1}")
        print(f"{'='*80}")
        print(f"\nPrompt: {prompt}")
        
        # Original response
        original = generate_text(prompt, temperature=0.9)
        print(f"\n[ORIGINAL MODEL]")
        print(f"Response: {original}")
        
        # Constitutional AI response
        revised, _, _ = iterative_critique_revision(prompt, n_revisions=2, verbose=False)
        print(f"\n[CONSTITUTIONAL AI]")
        print(f"Response: {revised}")
        
        print(f"\n{'-'*80}")

# Compare on test prompts
test_cases = [
    "How can I",
    "Tell me about",
    "What is the best way to"
]

compare_models(test_cases[:2])  # Run on subset

In [None]:
# Statistics and Analysis
def analyze_datasets():
    """
    Analyze the datasets we created.
    """
    print("\n" + "="*80)
    print("DATASET STATISTICS")
    print("="*80)
    
    print(f"\nSL-CAI Dataset:")
    print(f"  Total examples: {len(sl_cai_dataset)}")
    print(f"  Red-team prompts: {len([x for x in sl_cai_dataset if any(rt in x[0] for rt in RED_TEAM_PROMPTS)])}")
    print(f"  Helpful prompts: {len([x for x in sl_cai_dataset if any(hp in x[0] for hp in HELPFUL_PROMPTS)])}")
    
    print(f"\nPreference Dataset:")
    print(f"  Total comparisons: {len(preference_dataset)}")
    print(f"  Unique prompts: {len(set([x['prompt'] for x in preference_dataset]))}")
    print(f"  Principles used: {len(set([x['principle'] for x in preference_dataset]))}")
    
    # Response length analysis
    response_lengths = [len(x[1].split()) for x in sl_cai_dataset]
    print(f"\nResponse Length Statistics:")
    print(f"  Mean: {np.mean(response_lengths):.1f} words")
    print(f"  Std: {np.std(response_lengths):.1f} words")
    print(f"  Min: {np.min(response_lengths)} words")
    print(f"  Max: {np.max(response_lengths)} words")
    
    # Visualize
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Dataset composition
    categories = ['Red-team\n(revised)', 'Helpful\n(original)']
    counts = [
        len([x for x in sl_cai_dataset if any(rt in x[0] for rt in RED_TEAM_PROMPTS)]),
        len([x for x in sl_cai_dataset if any(hp in x[0] for hp in HELPFUL_PROMPTS)])
    ]
    ax1.bar(categories, counts, color=['#ff9999', '#99ff99'])
    ax1.set_ylabel('Number of Examples')
    ax1.set_title('SL-CAI Dataset Composition')
    ax1.grid(axis='y', alpha=0.3)
    
    # Response length distribution
    ax2.hist(response_lengths, bins=15, color='skyblue', edgecolor='black', alpha=0.7)
    ax2.set_xlabel('Response Length (words)')
    ax2.set_ylabel('Frequency')
    ax2.set_title('Response Length Distribution')
    ax2.grid(axis='y', alpha=0.3)
    
    plt.tight_layout()
    plt.show()

analyze_datasets()

### Key Insights
1. Self-improvement through critique-revision 
   - Models can critique their own outputs when prompted with principles
   - Iterative revision progressively reduces harmful content
   - No human labels needed for harmlessness training data

2. AI feedback for preference learning
   - AI models can evaluate which responses better follow principles
   - These AI preferences replace expensive human annotations
   - Enables scalable alignment with explicit constitutional principles

3. Transparency through explicit principles 
   - Constitution makes training objectives clear and auditable
   - Principles can be modified without collecting new human data
   - Chain-of-thought reasoning shows decision process

4. Computational efficiency
   - Uses same transformer architecture (no architectural changes)
   - Only requires ~5% additional training time for uptraining
   - Significantly cheaper than collecting thousands of human labels

5. Balancing helpfulness and harmlessness
   - Mix harmless (revised) and helpful (original) data in training
   - Reduces evasiveness while maintaining safety
   - Model explains objections rather than refusing to engage

### Limitations and Notes
1. Small model
   - Using GPT-2 (117M parameters) vs. production models (10B+ parameters)
   - Smaller models have limited instruction-following ability
   - Quality of critiques and revisions is lower

2. No actual training
   - We demonstrate data generation, not model finetuning
   - Real implementation requires training on generated data
   - Would need PyTorch training loop with proper optimization

3. Simplified AI feedback
   - Real implementation uses log probabilities for calibrated preferences
   - We use simple heuristics for demo purposes
   - Production systems use more sophisticated comparison methods

4. Small dataset
   - Real CAI uses 100K+ examples
   - We generate <10 examples for speed
   - Insufficient for meaningful model improvement

5. No RL stage
   - We skip the actual RL training with PPO
   - Real implementation requires reward model and policy optimization
   - This adds significant complexity