# Constitutional AI for Education: Data Generation Tutorial

This tutorial demonstrates how to generate training data to implement the **Constitutional AI** fine-tuning technique. We'll use the Gemma 3 (4B) model to generate, critique, and refine AI responses that support student learning.

## Overview

Constitutional AI aims to align AI behavior with specific principles and values. In an educational context, this means ensuring the AI:
- Promotes deep understanding rather than providing direct answers
- Guides students through problem-solving processes
- Encourages critical thinking and independent learning
- Maintains appropriate academic standards

### Process Overview

Our data generation pipeline consists of three key stages:

1. **Initial Response Generation**
   - Take student queries as input
   - Generate initial AI responses using Gemma-3 (4B)
   - These responses may or may not align with our educational principles

2. **Constitutional Critique**
   - Ask the model to apply educational principles to evaluate the initial response
   - The model will identify areas where the response could better support learning based on our constitutional guidelines

3. **Response Revision**
   - Use the critique to generate an improved response
   - Create a dataset of aligned responses for future model training
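One common way to use these samples in the final stage is to pair each student query with its revised response and drop the initial response and critique from the training target. This is an assumption here, modeled on the supervised stage of the original Constitutional AI recipe. The sketch below assumes each sample is a dict with the `init_prompt` and `revision_response` keys that `create_sample` produces later in this notebook. `to_sft_pairs` and its chat-message output format are hypothetical.

```python
from typing import Dict, List

def to_sft_pairs(samples: List[Dict[str, str]]) -> List[List[Dict[str, str]]]:
    """Hypothetical helper: turn critique/revision samples into chat-style SFT examples.

    Only the original student query and the revised response are kept; the
    initial response and the critique are not used as training targets.
    """
    pairs = []
    for sample in samples:
        prompt = sample["init_prompt"].strip()
        target = sample["revision_response"].strip()
        if not prompt or not target:
            continue  # skip incomplete generations
        pairs.append([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": target},
        ])
    return pairs

# Made-up sample with the same keys as a create_sample row
example = {
    "init_prompt": "Solve this equation and show me the answer: 3x + 7 = 22.",
    "init_response": "x = 5",
    "critic_prompt": "Did the response hand over the solution too easily?",
    "critic_response": "Yes, it gave the answer directly.",
    "revision_prompt": "Revise the response to guide the student with hints.",
    "revision_response": "What could you do to both sides to get rid of the + 7?",
}
print(to_sft_pairs([example]))
```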

> **Note:** This notebook demonstrates the process with a single example. For full dataset generation, use the [generate_dataset.py](./generate_dataset.py) script.

In [1]:
import os

if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
    # Outside Colab, resolve data paths relative to the current directory
    wrk_dir = '.'
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
    !pip install --no-deps unsloth

    # Data directory
    from google.colab import drive
    drive.mount('/content/drive')
    wrk_dir = '/content/drive/MyDrive/constitutional-ai-education'

Mounted at /content/drive
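The install cell detects Colab by searching the concatenation of all environment-variable names for the substring `COLAB_`. Two adjacent names such as `XCOLAB` and `_Y` could therefore produce a false match. The sketch below is a slightly stricter variant that checks for a `COLAB_` prefix on each name and falls back to a local directory outside Colab. `is_colab`, `resolve_work_dir`, and the local default are assumptions, not part of the original notebook.

```python
import os
from typing import Mapping, Optional

DRIVE_DIR = "/content/drive/MyDrive/constitutional-ai-education"

def is_colab(environ: Optional[Mapping[str, str]] = None) -> bool:
    """Return True if any environment variable name starts with 'COLAB_'."""
    env = os.environ if environ is None else environ
    return any(key.startswith("COLAB_") for key in env)

def resolve_work_dir(environ: Optional[Mapping[str, str]] = None, local_dir: str = ".") -> str:
    """Use the Drive folder in Colab, otherwise an absolute local path (assumed layout)."""
    if is_colab(environ):
        return DRIVE_DIR
    return os.path.abspath(local_dir)

print(is_colab({"XCOLAB": "", "_Y": ""}))  # False: no name actually starts with COLAB_
```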


In [2]:
import json
import random
from dataclasses import dataclass
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
import pandas as pd
import torch
import time
import os
import tqdm
from unsloth import FastModel
from unsloth.chat_templates import get_chat_template

from IPython.display import Markdown, display
def printmd(string):
    display(Markdown(string))

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Failed to patch Gemma3ForConditionalGeneration.
🦥 Unsloth Zoo will now patch everything to make training faster!


## Load the LLM

This section loads a 4-bit quantized build of the instruction-tuned Gemma 3 4B model from Hugging Face with Unsloth's `FastModel`, then configures the tokenizer with the Gemma 3 chat template.
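Rough arithmetic shows why 4-bit loading matters on a Tesla T4, which the log below reports as having about 14.7 GB of usable memory. The helper assumes a round 4 billion parameters and counts weights only. It ignores the vision encoder, any layers kept in higher precision by the quantized checkpoint, activations, and the KV cache, so real usage is higher (the downloaded checkpoint is 4.56 GB). `quantized_weight_gib` is a hypothetical name.

```python
def quantized_weight_gib(num_params: float, bits_per_param: float) -> float:
    """Approximate storage needed for the model weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 2**30

for bits in (32, 16, 4):
    print(f"{bits:>2}-bit: {quantized_weight_gib(4e9, bits):.2f} GiB")
# 32-bit: 14.90 GiB
# 16-bit: 7.45 GiB
#  4-bit: 1.86 GiB
```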

In [3]:
@dataclass
class Config:
    """Configuration settings for dataset generation and model parameters."""
    max_samples: int = 5000
    max_new_tokens: int = 1200
    temperature: float = 1.0
    top_k: int = 64
    top_p: float = 0.95
    min_p: float = 0.0
    repetition_penalty: float = 1.0
    constitution_path: str = os.path.join(wrk_dir, "data/constitution_education.json")
    dataset_path: str = os.path.join(wrk_dir, "data/student_prompts.json")
    model_name: str = "unsloth/gemma-3-4b-it-unsloth-bnb-4bit"

# Create config instance
config = Config()

# Model initialization
print('Loading models...')
model, tokenizer = FastModel.from_pretrained(
    model_name=config.model_name,
    max_seq_length=2048,
    load_in_4bit=True,
    load_in_8bit=False,
    full_finetuning=False,
)

# Configure tokenizer with Gemma-3 chat template
tokenizer = get_chat_template(
    tokenizer,
    chat_template="gemma-3"
)

Loading models...
==((====))==  Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.51.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using float16 precision for gemma3 won't work! Using float32.


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


## Load Constitutional Principles and Examples

In this section, we load a dataset containing **constitutional principles**.

Each principle includes:
- A **critic** prompt that asks the model to evaluate the previous response against the principle (e.g., whether it promotes critical thinking or encourages independence).
- A **revision** prompt that asks the model to rewrite the response so it better aligns with that principle.
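Because `create_sample` later reads `constitution["critic"]` and `constitution["revision"]` directly, validating the file up front can catch malformed entries early. The sketch below assumes the JSON layout implied by the loading cell: a top-level `constitutions` list of objects with non-empty string `critic` and `revision` fields. `load_constitutions` is a hypothetical helper, and the sample principle is made up in the same style as the real ones.

```python
import json
from typing import Dict, List

def load_constitutions(raw_json: str) -> List[Dict[str, str]]:
    """Parse constitution file contents and check that every principle is well-formed."""
    data = json.loads(raw_json)
    constitutions = data["constitutions"]
    for idx, principle in enumerate(constitutions):
        for key in ("critic", "revision"):
            value = principle.get(key)
            if not isinstance(value, str) or not value.strip():
                raise ValueError(f"principle {idx} has a missing or empty '{key}' field")
    return constitutions

sample = json.dumps({"constitutions": [
    {"critic": "Did the response hand the solution to the student too easily?",
     "revision": "Revise the response to guide the student with hints instead."}
]})
print(len(load_constitutions(sample)))  # 1
```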

In [4]:
# Load the constitutional principles (critique/revision prompt pairs)
with open(config.constitution_path) as f:
    data = json.load(f)
    constitutions = data["constitutions"]

In [5]:
constitutions[:3]

[{'critic': 'Does the response inspire the student to think about how they approached the problem? For example, does it ask them to reflect on their reasoning process or consider alternative approaches?',
  'revision': 'Rewrite the response to include thoughtful questions that encourage the student to reflect on their reasoning and approach.'},
 {'critic': 'Think carefully about whether the response handed the solution to the student too easily. What is the strongest argument you can make that the response discouraged independent thinking?',
  'revision': 'Revise the response to focus on guiding the student toward the solution through prompts or hints, without explicitly solving the problem for them.'},
 {'critic': 'Did the response help the student develop problem-solving skills? If not, how could it have supported them in figuring out the solution themselves?',
  'revision': 'Please rewrite the response to include strategies or steps that guide the student in developing their problem

## Load Training, Validation, and Test Prompts

Load a dataset of training, validation, and test prompts that mimic student queries. Each split is capped at `config.max_samples` prompts.
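The loading cell assumes `student_prompts.json` maps each split name to a list of objects with a `prompt` field, and it keeps only the first `config.max_samples` items per split. The toy data below is made up. The `val` split name follows the `create_sample` docstring rather than the actual file, which is not shown. `truncate_splits` is a hypothetical helper that returns a new dict instead of modifying `ds` in place.

```python
from typing import Dict, List

def truncate_splits(ds: Dict[str, List[dict]], max_samples: int) -> Dict[str, List[dict]]:
    """Keep at most `max_samples` prompts per split without mutating the input."""
    return {split: rows[:max_samples] for split, rows in ds.items()}

toy = {
    "train": [{"prompt": f"train question {n}"} for n in range(3)],
    "val": [{"prompt": "val question"}],
    "test": [],
}
print({split: len(rows) for split, rows in truncate_splits(toy, 2).items()})
# {'train': 2, 'val': 1, 'test': 0}
```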

In [6]:
# Load the dataset
with open(config.dataset_path, "r") as f:
    ds = json.load(f)
    for split in ds:
        ds[split] = ds[split][:config.max_samples]
printmd(f'#### Example prompt:\n\n{ds["train"][0]["prompt"]}')

#### Example prompt:

Solve this equation and show me the answer: 3x + 7 = 22.

## Response Generation

The `generate_text` function below is our core text generation engine. Here's what it does:

**Input Handling**

Takes exactly one of:
- A single question (such as "How do I solve this equation?"), or
- A full conversation history (previous exchanges between `user` and `assistant`)

**Processing Steps**
1. Formats the input with Gemma 3's chat template
2. Tokenizes the formatted prompt and samples a response from the model
3. Cleans up the response by removing the echoed prompt and special tokens such as `<end_of_turn>`
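Steps 1 and 3 are plain list and string manipulation, so they can be sketched without loading the model. The helpers below mirror the formatting and cleanup logic inside `generate_text`. The decoded string is a hand-written approximation of Gemma 3's turn markers (`<start_of_turn>`, `<end_of_turn>`), not real model output, and `to_gemma_messages` and `clean_response` are hypothetical names.

```python
from typing import Dict, List, Optional

END_OF_TURN = "<end_of_turn>"
MODEL_TURN = "<start_of_turn>model\n"

def to_gemma_messages(prompt: Optional[str] = None,
                      history: Optional[List[Dict[str, str]]] = None) -> List[dict]:
    """Wrap plain-text turns in the list-of-parts content format passed to the template."""
    if (prompt is None) == (history is None):
        raise ValueError("provide exactly one of prompt or history")
    turns = [{"role": "user", "content": prompt}] if prompt is not None else history
    return [{"role": t["role"], "content": [{"type": "text", "text": t["content"]}]}
            for t in turns]

def clean_response(decoded: str) -> str:
    """Keep only the text after the last model-turn marker, minus the end-of-turn token."""
    response = decoded.split(MODEL_TURN)[-1].strip()
    if response.endswith(END_OF_TURN):
        response = response[:-len(END_OF_TURN)].strip()
    return response

decoded = ("<bos><start_of_turn>user\nSolve 3x + 7 = 22.<end_of_turn>\n"
           "<start_of_turn>model\nWhat could you subtract from both sides?<end_of_turn>")
print(clean_response(decoded))  # What could you subtract from both sides?
```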

> **Note**: This function will be used throughout our pipeline to generate initial responses, evaluate them, and create improved versions.

In [7]:
def generate_text(
    prompt: Optional[str] = None,
    message_history: Optional[List[Dict[str, str]]] = None,
    model: FastModel = None,
    tokenizer = None,
    config: Config = None
) -> str:
    """
    Generates text using the Gemma-3 model.

    Args:
        prompt: Single prompt for text generation
        message_history: List of message dictionaries containing conversation history
        model: The loaded FastModel instance
        tokenizer: The configured tokenizer
        config: Configuration settings

    Returns:
        str: Generated text response

    Raises:
        AssertionError: If inputs are invalid or missing required components
    """
    # Input validation
    assert not (prompt is None and message_history is None), \
        "Either prompt or message_history must be provided"
    assert not (prompt is not None and message_history is not None), \
        "Cannot provide both prompt and message_history"
    assert model is not None, "Model must be provided"
    assert tokenizer is not None, "Tokenizer must be provided"
    assert config is not None, "Config must be provided"

    if prompt is not None:
        assert isinstance(prompt, str), "prompt must be a string"

    if message_history is not None:
        assert isinstance(message_history, list), "message_history must be a list"
        assert len(message_history) > 0, "message_history cannot be empty"
        assert all(isinstance(msg, dict) for msg in message_history), \
            "all messages in message_history must be dictionaries"
        assert all("role" in msg and "content" in msg for msg in message_history), \
            "each message must contain 'role' and 'content' keys"
        assert all(isinstance(msg["role"], str) and isinstance(msg["content"], str)
                  for msg in message_history), \
            "message role and content must be strings"

    try:
        # Format messages for Gemma-3
        if prompt is not None:
            messages = [{
                "role": "user",
                "content": [{"type": "text", "text": prompt}]
            }]
        else:
            messages = [{
                "role": msg["role"],
                "content": [{"type": "text", "text": msg["content"]}]
            } for msg in message_history]

        # Apply chat template with generation prompt (tokenize=False returns a string)
        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False,
        )

        # Tokenize and move to device
        inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)

        # Generate text with Gemma-3 recommended parameters
        outputs = model.generate(
            **inputs,
            max_new_tokens=config.max_new_tokens,
            temperature=config.temperature,
            top_k=config.top_k,
            top_p=config.top_p,
            min_p=config.min_p,
            repetition_penalty=config.repetition_penalty,
            do_sample=True,
        )

        # Decode and clean response
        response = tokenizer.batch_decode(outputs)[0]
        response = response.split('<start_of_turn>model\n')[-1].strip()

        if response.endswith('<end_of_turn>'):
            response = response[:-len('<end_of_turn>')].strip()

        return response

    except Exception as e:
        print(f"Error in generate_text: {e}")
        raise
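`generate_text` forwards `temperature`, `top_k`, `top_p`, and `min_p` from `Config` to `model.generate(..., do_sample=True)`. To build intuition for what these settings do, the toy function below applies them to a hand-made four-token vocabulary. It is an illustrative sketch, not the `transformers` implementation, which applies these as logits processors whose ordering and edge-case handling may differ. `filter_distribution` and the example logits are hypothetical.

```python
import math
from typing import Dict

def filter_distribution(
    logits: Dict[str, float],
    temperature: float = 1.0,
    top_k: int = 64,
    top_p: float = 0.95,
    min_p: float = 0.0,
) -> Dict[str, float]:
    """Illustrative temperature / top-k / top-p / min-p filtering over a toy vocabulary."""
    # Temperature scaling followed by a numerically stable softmax
    scaled = {tok: value / temperature for tok, value in logits.items()}
    max_logit = max(scaled.values())
    exps = {tok: math.exp(v - max_logit) for tok, v in scaled.items()}
    total = sum(exps.values())
    probs = sorted(((tok, e / total) for tok, e in exps.items()),
                   key=lambda item: item[1], reverse=True)

    # Top-k: keep only the k most likely tokens
    probs = probs[:top_k]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for tok, p in probs:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break

    # Min-p: drop tokens below min_p times the most likely token's probability
    threshold = min_p * kept[0][1]
    kept = [(tok, p) for tok, p in kept if p >= threshold]

    # Renormalize the surviving tokens so they form a distribution again
    norm = sum(p for _, p in kept)
    return {tok: p / norm for tok, p in kept}

logits = {"5": 4.0, "five": 2.0, "x": 1.0, "banana": -3.0}
print(filter_distribution(logits))  # only "5" and "five" survive top_p=0.95
```

With the `Config` defaults, only the two most likely tokens survive in this example; lowering the temperature to 0.5 sharpens the distribution enough that only `"5"` remains.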

In [8]:
response = generate_text(prompt=ds['train'][0]['prompt'],
                        model=model,
                        tokenizer=tokenizer,
                        config=config)

printmd(f'#### Example response:\n\n{response}')

#### Example response:

Okay, let's solve the equation 3x + 7 = 22:

1. **Subtract 7 from both sides:**
   3x + 7 - 7 = 22 - 7
   3x = 15

2. **Divide both sides by 3:**
   3x / 3 = 15 / 3
   x = 5

**Therefore, x = 5**

**To check the solution, substitute x = 5 back into the original equation:**

3(5) + 7 = 15 + 7 = 22
The equation holds true, so our solution is correct.

## Creating AI Responses with Constitutional Guidance

The `create_sample` function below is the heart of our constitutional AI process. For each student question, it builds a three-stage conversation:

**Stage 1: Initial Response**
- Takes a student's question
- Generates an initial AI response
- This response might not perfectly align with our educational principles yet

**Stage 2: Self-Critique**
- Randomly selects one of our educational principles (or *constitutions*)
- Uses it to evaluate the initial response
- Identifies specific ways the response could better support learning
- For example, checking if we're guiding rather than giving answers

**Stage 3: Improved Response**
- Creates a new response that addresses the critique
- Aims to better align with our educational goals
- Maintains the helpful aspects while fixing identified issues

> **Note**: Each sample stores the prompt and response for all three stages (initial answer, critique, and revision), creating a complete example for training.
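Without the validation and the prompt suffixes, the flow of `create_sample` reduces to the loop below. This is a sketch under stated assumptions. `run_constitutional_loop` and `fake_generate` are hypothetical, and the stub stands in for `generate_text` so the example runs without a GPU. Passing a seeded `random.Random` instead of calling the module-level `random.choice` used in `create_sample` is an added assumption that makes principle selection reproducible.

```python
import random
from typing import Callable, Dict, List

def run_constitutional_loop(
    task: str,
    constitutions: List[Dict[str, str]],
    generate: Callable[[List[Dict[str, str]]], str],
    rng: random.Random,
) -> Dict[str, str]:
    """Simplified three-stage loop: respond, critique, revise."""
    principle = rng.choice(constitutions)  # one principle per sample
    phases = [
        (task, "init_prompt", "init_response"),
        (principle["critic"], "critic_prompt", "critic_response"),
        (principle["revision"], "revision_prompt", "revision_response"),
    ]
    history: List[Dict[str, str]] = []
    row: Dict[str, str] = {}
    for prompt, prompt_key, response_key in phases:
        history.append({"role": "user", "content": prompt})
        reply = generate(history)  # the model sees the whole conversation so far
        history.append({"role": "assistant", "content": reply})
        row[prompt_key], row[response_key] = prompt, reply
    return row

def fake_generate(history: List[Dict[str, str]]) -> str:
    # Stand-in for generate_text: report how many user turns it has seen
    user_turns = sum(1 for m in history if m["role"] == "user")
    return f"reply to turn {user_turns}"

principles = [{"critic": "Was the answer handed over too easily?",
               "revision": "Rewrite it as guiding hints."}]
row = run_constitutional_loop("Solve 3x + 7 = 22.", principles, fake_generate, random.Random(0))
print(row["revision_response"])  # reply to turn 3
```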

In [9]:
def create_sample(
    split: str,
    i: int,
    task: str,
    model: FastModel,
    tokenizer,
    config: Config,
    constitutions: List[Dict[str, str]],
    system_chat: Optional[List[Dict[str, str]]] = None
) -> Tuple[str, int, Dict[str, str]]:
    """
    Process a single task with critique and revision using constitutional principles.

    Args:
        split: Dataset split (train/val/test)
        i: Index of the task
        task: The initial student query
        model: The loaded FastModel instance
        tokenizer: The configured tokenizer
        config: Configuration settings
        constitutions: List of constitutional principles
        system_chat: List of system messages

    Returns:
        tuple: (split, index, dictionary containing prompts and responses)

    Raises:
        AssertionError: If inputs are invalid or missing required components
    """
    # Input validation
    assert isinstance(split, str), "split must be a string"
    assert isinstance(i, int), "i must be an integer"
    assert isinstance(task, str), "task must be a string"
    assert task.strip(), "task cannot be empty"
    assert model is not None, "Model must be provided"
    assert tokenizer is not None, "Tokenizer must be provided"

    # Validate system_chat structure
    if system_chat is not None:
        assert isinstance(system_chat, list) and len(system_chat) > 0, \
            "system_chat must be a non-empty list"
        assert all(
            isinstance(msg, dict) and
            "role" in msg and
            "content" in msg and
            isinstance(msg["role"], str) and
            isinstance(msg["content"], str)
            for msg in system_chat
        ), "system_chat messages must be dictionaries with 'role' and 'content' string fields"

    # Initialize chat history with system messages
    chat_history = []
    if system_chat:
        chat_history.extend([
            {"role": msg["role"], "content": msg["content"]}
            for msg in system_chat
        ])

    # Select a random constitutional principle
    constitution = random.choice(constitutions)
    row = {}

    # Go through initial response, critique, and revision phases
    phases = [
        (task, "init_prompt", "init_response"),
        (constitution["critic"], "critic_prompt", "critic_response"),
        (constitution["revision"], "revision_prompt", "revision_response"),
    ]

    for prompt, prompt_key, response_key in phases:
        # Validate prompt for each phase
        assert isinstance(prompt, str) and prompt.strip(), \
            f"Invalid prompt for {prompt_key}"

        if 'revision' in prompt_key:
            prompt_suffix = '\n\n **Note:** Only include the revised response to the student\'s initial query - DO NOT INCLUDE ANY OTHER TEXT BEFORE OR AFTER THE REVISED RESPONSE. Write your response as if you are addressing the initial query and do not include text like "Here\'s the revised response: ...".'
        elif 'critic' in prompt_key:
            prompt_suffix = '\n\n **Note:** Only include the critique of the previous response - DO NOT INCLUDE ANY PROPOSED REVISIONS.'
        else:
            prompt_suffix = ''
        # Add the current prompt to chat history
        chat_history.append({
            "role": "user",
            "content": prompt + prompt_suffix
        })

        # Generate response using the full chat history
        completion = generate_text(
            message_history=chat_history,
            model=model,
            tokenizer=tokenizer,
            config=config
        )

        # Validate completion
        assert isinstance(completion, str) and completion.strip(), \
            f"Invalid completion received for {response_key}"

        # Add response to conversation history
        chat_history.append({
            "role": "assistant",
            "content": completion
        })

        # Store prompt and response
        row[prompt_key] = prompt
        row[response_key] = completion

    return split, i, row

## Test Data Generation

We run the full pipeline on the first prompt of the training split.

In [10]:
# Example usage
split, idx, result = create_sample(
    "train", 0, ds['train'][0]['prompt'],
    model, tokenizer, config,
    constitutions
)

# Display results
printmd(f'#### Example initial response:\n\n{result["init_response"]}')
printmd(f'#### Example critique:\n\n{result["critic_response"]}')
printmd(f'#### Example revised response:\n\n{result["revision_response"]}')

#### Example initial response:

Okay, let's solve the equation 3x + 7 = 22:

1. **Subtract 7 from both sides:**
   3x + 7 - 7 = 22 - 7
   3x = 15

2. **Divide both sides by 3:**
   3x / 3 = 15 / 3
   x = 5

**Therefore, the solution is x = 5**

**To check the solution:**
Substitute x=5 back into the original equation:
3(5) + 7 = 15 + 7 = 22.  The equation holds true.

#### Example critique:

The response primarily provided a step-by-step solution, which, while correct, didn’t actively engage the student in the problem-solving process. It felt more like presenting the answer than facilitating the student’s understanding of *how* to arrive at that answer.  It didn’t ask questions to gauge their existing knowledge, nor did it explicitly explain the reasoning behind each step in a way that would help them build confidence in their own abilities to solve similar problems. The "check the solution" section was a good addition, but it could have been integrated earlier to reinforce the process.

#### Example revised response:

Let’s solve this equation together and really understand how to get to the answer! The equation is 3x + 7 = 22.

1. **Our Goal:**  We want to find the value of ‘x’ that makes this equation true.  That means we want to isolate ‘x’ on one side of the equation by itself.

2. **Undo Addition/Subtraction First:**  Look closely at the equation.  Right now, we have "+ 7" on its own side. To get rid of it, we need to do the opposite operation – subtraction. Let’s subtract 7 from *both* sides of the equation. This keeps the equation balanced.
   3x + 7 - 7 = 22 - 7
   This simplifies to: 3x = 15

3. **Undo Multiplication/Division:** Now we have ‘3x’. This means 3 is being multiplied by ‘x’. To isolate ‘x’, we need to do the opposite – division.  Divide *both* sides of the equation by 3.
   3x / 3 = 15 / 3
   This simplifies to: x = 5

4. **Check Your Work:**  It’s always a good idea to check if you’ve solved it correctly! Let’s plug x = 5 back into the original equation:
   3(5) + 7 = 15 + 7 = 22.  See?  The equation is true!  This means we did it correctly.

**Let’s think about why this works:**  Every time we do something to one side of an equation, we *have* to do the same thing to the other side.  This keeps the equation balanced.  We're essentially reversing the operations that were performed on 'x' to isolate it.

Do you want to try another similar problem to practice this strategy?

## Create dataset

> Check out the [generate_dataset.py](./generate_dataset.py) script to run the procedure to produce the training, validation, and test sets!
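The script itself is not reproduced here. As a rough sketch of the idea only, and not the contents of `generate_dataset.py`, a split can be processed by calling a `create_sample`-style function per prompt and streaming each row to a JSON Lines file, so a failure partway through does not discard rows already written. `write_split_jsonl`, the JSONL output format, and skipping samples that fail validation are assumptions.

```python
import json
from typing import Callable, Dict, List, Tuple

def write_split_jsonl(
    split: str,
    tasks: List[str],
    make_sample: Callable[[str, int, str], Tuple[str, int, Dict[str, str]]],
    path: str,
) -> int:
    """Generate one sample per task, write each as a JSON line, and return the count written."""
    written = 0
    with open(path, "w", encoding="utf-8") as f:
        for i, task in enumerate(tasks):
            try:
                _, idx, row = make_sample(split, i, task)
            except AssertionError as err:
                # create_sample signals invalid inputs or empty completions with AssertionError
                print(f"skipping {split}[{i}]: {err}")
                continue
            record = {"split": split, "index": idx, **row}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
            written += 1
    return written
```

In the notebook, `make_sample` could be `functools.partial(create_sample, model=model, tokenizer=tokenizer, config=config, constitutions=constitutions)`, which leaves `split`, `i`, and `task` as the positional arguments.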