# Data Generation

This tutorial is the first step in implementing **Constitutional AI** techniques in the context of education. The objective is to demonstrate how to generate a foundational dataset that aligns with a "constitution" of principles aimed at guiding AI behavior. These principles ensure the AI serves as a learning aid, promoting understanding without substituting for students' effort or giving direct answers.

The approach involves:
1. **Generating prompts and initial responses**: Using an AI model to simulate potential student interactions.
2. **Critiquing and revising**: Guiding the AI to critique and improve its responses according to a defined set of principles.
3. **Creating a refined dataset**: Compiling the revised responses into a dataset for fine-tuning an AI model that aligns with educational goals.

The same procedure can be run on a GPU with the [generate_dataset.py](./generate_dataset.py) script.

In [5]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
    wrk_dir = ''
else:
    !pip install --no-deps unsloth vllm
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

    # Data directory
    from google.colab import drive
    drive.mount('/content/drive')
    wrk_dir = '/content/drive/MyDrive/constitutional-ai-education'

In [6]:
import asyncio
import json
import contextlib
import random
from dataclasses import dataclass
from collections import defaultdict
from tqdm.asyncio import tqdm_asyncio
import pandas as pd
import torch
import time
from unsloth import FastModel
from unsloth.chat_templates import get_chat_template, train_on_responses_only

## Load the LLM

This section demonstrates how to load the pretrained Gemma 3 4B model from Hugging Face and configure it for efficient inference.

In [7]:
@dataclass
class Config:
    """Configuration settings for dataset generation and model parameters."""
    max_samples: int = 5000
    max_new_tokens: int = 1200
    # Gemma-3 recommended inference settings
    temperature: float = 1.0
    top_k: int = 64
    top_p: float = 0.95
    min_p: float = 0.0
    repetition_penalty: float = 1.0  # 1.0 means disabled
    constitution_path: str = os.path.join(wrk_dir, "data/constitution_education.json")
    dataset_path: str = os.path.join(wrk_dir, "data/student_prompts.json")
    # Unsloth's optimized 4-bit version
    model_name: str = "unsloth/gemma-3-4b-it-unsloth-bnb-4bit"

# Load configuration
config = Config()

print('Loading models...')

# Initialize model and tokenizer with Unsloth
model, tokenizer = FastModel.from_pretrained(
    model_name=config.model_name,
    max_seq_length = 2048,
    load_in_4bit = True,
    load_in_8bit = False,
    full_finetuning = False,
)

# Configure tokenizer with Gemma-3 chat template
tokenizer = get_chat_template(
    tokenizer,
    chat_template="gemma-3"
)

Loading models...
==((====))==  Unsloth 2025.3.19: Fast Gemma3 patching. Transformers: 4.50.0. vLLM: 0.8.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using float16 precision for gemma3 won't work! Using float32.


## Load Constitutional Principles and Examples

In this section, we load a dataset containing **constitutional principles** and examples of **prompt-response-critique-correction** interactions. These elements are critical for aligning the model's outputs with the desired educational principles.

### Contents of the Data File (`constitution_education.json`)

1. **Constitutions**: A list of principles designed to guide AI responses. Each principle includes:
   - A **critic** section that evaluates the quality of a response (e.g., whether it promotes critical thinking or encourages independence).
   - A **revision** section that provides guidance for improving the response to better align with the principles.

2. **System Chat**: Examples of user queries, initial responses from the assistant, critiques, and revised responses. These interactions serve as templates for generating data that is both educationally valuable and aligned with the constitution.

The example below illustrates how to load and preprocess this data:

In [8]:
# Load constitutional principles and example conversations
with open(config.constitution_path) as f:
    data = json.load(f)
    constitutions = data["constitutions"]
    system_chat = [item for sublist in data["system_chat"] for item in sublist]

# Select subset for our purposes
system_chat = system_chat[:16]
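
For orientation, the layout that the loading cell expects can be sketched as follows. The two top-level keys, `constitutions` and `system_chat`, come from the loader itself; the entries below are made-up, abbreviated placeholders, and the real file may contain more fields. Each element of `system_chat` is assumed to be one complete example conversation (a list of messages), which is why the loader flattens it.

```python
import json

# Hypothetical miniature of constitution_education.json (contents abbreviated).
raw = json.loads("""
{
  "constitutions": [
    {"critic": "Does the response inspire the student to reflect?",
     "revision": "Rewrite the response to include reflective questions."}
  ],
  "system_chat": [
    [
      {"role": "user", "content": "Why is -2 x -3 = 6?"},
      {"role": "assistant", "content": "Two negatives make a positive."}
    ],
    [
      {"role": "user", "content": "Does the response inspire reflection?"},
      {"role": "assistant", "content": "No, it only states the rule."}
    ]
  ]
}
""")

# Flatten the list of conversations into one list of messages,
# mirroring the list comprehension used by the loader.
flat_chat = [msg for conversation in raw["system_chat"] for msg in conversation]

print(len(flat_chat))                  # 4
print([m["role"] for m in flat_chat])  # ['user', 'assistant', 'user', 'assistant']
```

Because a new user prompt is later appended to this few-shot history, keeping the slice length even (as `system_chat[:16]` does) keeps the user and assistant roles alternating.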

In [9]:
constitutions[:3]

[{'critic': 'Does the response inspire the student to think about how they approached the problem? For example, does it ask them to reflect on their reasoning process or consider alternative approaches?',
  'revision': 'Rewrite the response to include thoughtful questions that encourage the student to reflect on their reasoning and approach.'},
 {'critic': 'Think carefully about whether the response handed the solution to the student too easily. What is the strongest argument you can make that the response discouraged independent thinking?',
  'revision': 'Revise the response to focus on guiding the student toward the solution through prompts or hints, without explicitly solving the problem for them.'},
 {'critic': 'Did the response help the student develop problem-solving skills? If not, how could it have supported them in figuring out the solution themselves?',
  'revision': 'Please rewrite the response to include strategies or steps that guide the student in developing their problem

In [10]:
system_chat[:6]

[{'role': 'user',
  'content': 'Can you explain why multiplying -2 × -3 = 6 without asking me questions?'},
 {'role': 'assistant',
  'content': 'Multiplying two negative numbers gives a positive because it follows the rule that reversing a reversal returns to the original direction, making -2 × -3 = 6.'},
 {'role': 'user',
  'content': 'Does the response inspire the student to think about how they approached the problem? For example, does it ask them to reflect on their reasoning process or consider alternative approaches?'},
 {'role': 'assistant',
  'content': 'The response provides the direct explanation but does not encourage the student to reflect on their reasoning process or consider alternative approaches.'},
 {'role': 'user',
  'content': 'Rewrite the response to include thoughtful questions that encourage the student to reflect on their reasoning and approach.'},
 {'role': 'assistant',
  'content': 'Multiplying two negative numbers results in a positive because it follows the 

## Load Training and Test Prompts

Load a dataset of training, validation, and test prompts.

In [11]:
# Load the dataset
with open(config.dataset_path, "r") as f:
    ds = json.load(f)
    for split in ds:
        ds[split] = ds[split][:config.max_samples]

# Print an example
print(ds['train'][0])

{'prompt': 'Provide a comprehensive analysis of the factors leading to the American Civil Rights Movement.'}
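
The loader appears to assume that `student_prompts.json` maps each split name to a list of `{"prompt": ...}` records. A minimal sketch of that layout and of the per-split truncation is shown here; the split names, the prompts, and the value of `max_samples` are all illustrative assumptions rather than the contents of the real file.

```python
# Hypothetical stand-in for the parsed student_prompts.json.
ds_demo = {
    "train": [{"prompt": f"Train prompt {n}"} for n in range(5)],
    "validation": [{"prompt": "Validation prompt 0"}],
    "test": [{"prompt": f"Test prompt {n}"} for n in range(3)],
}
max_samples = 2  # stands in for config.max_samples

# Slicing never raises when a split is shorter than max_samples;
# such a split is simply kept whole.
truncated = {split: rows[:max_samples] for split, rows in ds_demo.items()}

for split, rows in truncated.items():
    print(split, len(rows))
```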


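## Response Generation

The helper below wraps the model call. It accepts either a single prompt or a full message history (but not both), applies the Gemma-3 chat template, samples a completion, and returns only the text of the model's final turn.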

In [12]:
async def generate_text(prompt=None, message_history=None, semaphore=None):
    """
    Generates text using the Gemma-3 model asynchronously.

    Args:
        prompt (str, optional): Single prompt for text generation
        message_history (list, optional): List of message dictionaries containing conversation history
        semaphore (asyncio.Semaphore, optional): Semaphore for controlling concurrent requests

    Returns:
        str: Generated text response

    Raises:
        AssertionError: If neither or both prompt and message_history are provided,
                       or if inputs are of incorrect types
    """
    # Input validation
    assert not (prompt is None and message_history is None), \
        "Either prompt or message_history must be provided"
    assert not (prompt is not None and message_history is not None), \
        "Cannot provide both prompt and message_history"

    if prompt is not None:
        assert isinstance(prompt, str), "prompt must be a string"

    if message_history is not None:
        assert isinstance(message_history, list), "message_history must be a list"
        assert len(message_history) > 0, "message_history cannot be empty"
        assert all(isinstance(msg, dict) for msg in message_history), \
            "all messages in message_history must be dictionaries"
        assert all("role" in msg and "content" in msg for msg in message_history), \
            "each message must contain 'role' and 'content' keys"
        assert all(isinstance(msg["role"], str) and isinstance(msg["content"], str)
                  for msg in message_history), \
            "message role and content must be strings"

    async with contextlib.nullcontext() if semaphore is None else semaphore:
        try:
            # Format messages for Gemma-3
            if prompt is not None:
                messages = [{
                    "role": "user",
                    "content": [{
                        "type": "text",
                        "text": prompt,
                    }]
                }]
            else:
                messages = [{
                    "role": msg["role"],
                    "content": [{
                        "type": "text",
                        "text": msg["content"],
                    }]
                } for msg in message_history]

            # Apply chat template with generation prompt; request a string
            # explicitly, since the result is tokenized in the next step
            formatted_prompt = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,  # Required for generation
                tokenize=False,
            )

            # Tokenize and move to device
            inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)

            # Generate text with Gemma-3 recommended parameters
            outputs = model.generate(
                **inputs,
                max_new_tokens=config.max_new_tokens,
                temperature=config.temperature,
                top_k=config.top_k,
                top_p=config.top_p,
                min_p=config.min_p,
                repetition_penalty=config.repetition_penalty,
                do_sample=True,
            )

            # Decode and clean response
            response = tokenizer.batch_decode(outputs)[0]
            # Extract only the model's response after the last model turn
            response = response.split('<start_of_turn>model\n')[-1].strip()

            # Remove <end_of_turn> if it appears at the end of the response
            if response.endswith('<end_of_turn>'):
                response = response[:-len('<end_of_turn>')].strip()

            return response
        except Exception as e:
            print(f"Error in generate_text: {e}")
            raise
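
The post-processing step at the end of `generate_text` can be tried in isolation. The sketch below repeats the same string logic on a hand-written transcript; `extract_model_reply` is a hypothetical helper name, and the sample transcript only approximates what a decoded Gemma-style sequence looks like.

```python
START = "<start_of_turn>model\n"
END = "<end_of_turn>"

def extract_model_reply(decoded: str) -> str:
    """Return the text of the last model turn in a decoded transcript."""
    # Keep only what follows the final model-turn marker.
    reply = decoded.split(START)[-1].strip()
    # Drop a trailing end-of-turn marker if the model emitted one.
    if reply.endswith(END):
        reply = reply[: -len(END)].strip()
    return reply

# Hand-written example of a decoded prompt plus completion.
decoded = (
    "<bos><start_of_turn>user\nWhat is 2 + 2?<end_of_turn>\n"
    "<start_of_turn>model\nWhat do you get if you count up two from 2?<end_of_turn>"
)
print(extract_model_reply(decoded))
# What do you get if you count up two from 2?
```

This approach relies on the turn markers never appearing inside the generated text. A common alternative, not used in this notebook, is to slice the generated token IDs past the prompt length and decode them with `skip_special_tokens=True`.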

In [13]:
# Example of using the LLM to generate text
async def run_generation(prompt):
    response = await generate_text(prompt)
    return response

result = await run_generation(prompt=ds['train'][0]['prompt'])
print(result)

## The Roots of the American Civil Rights Movement: A Comprehensive Analysis

The American Civil Rights Movement of the mid-20th century wasn't a spontaneous eruption, but rather the culmination of centuries of struggle, injustice, and evolving social consciousness. Understanding the movement requires examining a complex web of interconnected factors spanning slavery, Jim Crow, economic shifts, and burgeoning intellectual and political currents. Here’s a breakdown of the key elements:

**1. The Legacy of Slavery (1619-1865): The Foundational Wound**

* **Dehumanization & Brutality:**  The institution of slavery was the core foundational factor. It wasn’t merely an economic system, but a system predicated on the denial of basic human rights and the systematic dehumanization of Africans and their descendants.  Generations of violence, forced labor, and family separation instilled deep trauma and a sense of dispossession.
* **Economic Engine:** Slavery fueled the Southern economy, particu

## Generate Data Samples with User Prompts and Constitutional Principles

This section defines a function to create a single dataset sample based on a student task, a random constitutional principle, and an initial user-assistant interaction. The function incorporates critiques and revisions into the sample, making it ready for fine-tuning the AI model.

In [14]:
async def create_sample(split, i, task, semaphore=None):
    """
    Process a single task with critique and revision using constitutional principles.

    Args:
        split (str): Dataset split (train/val/test)
        i (int): Index of the task
        task (str): The initial student query
        semaphore (asyncio.Semaphore, optional): Semaphore for controlling concurrent requests

    Returns:
        tuple: (split, index, dictionary containing prompts and responses)

    Raises:
        AssertionError: If inputs are invalid or if system_chat or constitutions are not properly formatted
    """
    # Input validation
    assert isinstance(split, str), "split must be a string"
    assert isinstance(i, int), "i must be an integer"
    assert isinstance(task, str), "task must be a string"
    assert task.strip(), "task cannot be empty"
    assert semaphore is None or isinstance(semaphore, asyncio.Semaphore), \
        "semaphore must be None or an asyncio.Semaphore instance"

    # Validate system_chat structure
    assert isinstance(system_chat, list) and len(system_chat) > 0, \
        "system_chat must be a non-empty list"
    assert all(
        isinstance(msg, dict) and
        "role" in msg and
        "content" in msg and
        isinstance(msg["role"], str) and
        isinstance(msg["content"], str)
        for msg in system_chat
    ), "system_chat messages must be dictionaries with 'role' and 'content' string fields"

    # Validate constitutions structure
    assert isinstance(constitutions, list) and len(constitutions) > 0, \
        "constitutions must be a non-empty list"
    assert all(
        isinstance(const, dict) and
        "critic" in const and
        "revision" in const and
        isinstance(const["critic"], str) and
        isinstance(const["revision"], str)
        for const in constitutions
    ), "constitutions must contain dictionaries with 'critic' and 'revision' string fields"

    # Initialize chat history with system messages
    chat_history = [
        {"role": msg["role"], "content": msg["content"]}
        for msg in system_chat
    ]

    # Select a random constitutional principle for critique and revision
    constitution = random.choice(constitutions)

    # Initialize an empty dictionary to store the sample's components
    row = {}

    # Go through initial response, critique, and revision phases
    phases = [
        (task, "init_prompt", "init_response"),
        (constitution["critic"], "critic_prompt", "critic_response"),
        (constitution["revision"], "revision_prompt", "revision_response"),
    ]

    for prompt, prompt_key, response_key in phases:
        # Validate prompt for each phase
        assert isinstance(prompt, str) and prompt.strip(), \
            f"Invalid prompt for {prompt_key}"

        prompt_suffix = ''
        if 'revision' in prompt_key:
            prompt_suffix = ' Only include the revised response to the student - do not include any sort of reflection on the response or the improvements.'

        # Add the current prompt to chat history
        chat_history.append({
            "role": "user",
            "content": prompt+prompt_suffix
        })

        # Generate response using the full chat history
        completion = await generate_text(
            message_history=chat_history,
            semaphore=semaphore
        )

        # Validate completion
        assert isinstance(completion, str) and completion.strip(), \
            f"Invalid completion received for {response_key}"

        # Add response to conversation history
        chat_history.append({
            "role": "assistant",
            "content": completion
        })

        # Store prompt and response
        row[prompt_key] = prompt
        row[response_key] = completion

    # Validate final row structure
    expected_keys = {"init_prompt", "init_response", "critic_prompt",
                    "critic_response", "revision_prompt", "revision_response"}
    assert set(row.keys()) == expected_keys, \
        f"Missing keys in row. Expected: {expected_keys}, Got: {set(row.keys())}"
    assert all(isinstance(v, str) and v.strip() for v in row.values()), \
        "All values in row must be non-empty strings"

    return split, i, row
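
To see how the chat history grows across the three phases without loading a model, the loop can be exercised with a stub generator. This is only a sketch: `fake_generate` and `demo_sample` are hypothetical names, the stub stands in for `generate_text`, and the validation and the revision suffix from `create_sample` are omitted.

```python
import asyncio

async def fake_generate(message_history):
    """Stand-in for generate_text: reports how many messages it was given."""
    return f"reply after {len(message_history)} messages"

async def demo_sample(task, principle, few_shot):
    history = list(few_shot)  # copy so the shared few-shot list is not mutated
    row = {}
    phases = [
        (task, "init"),
        (principle["critic"], "critic"),
        (principle["revision"], "revision"),
    ]
    for prompt, name in phases:
        # Each phase sees the few-shot examples plus all earlier phases.
        history.append({"role": "user", "content": prompt})
        reply = await fake_generate(history)
        history.append({"role": "assistant", "content": reply})
        row[f"{name}_prompt"] = prompt
        row[f"{name}_response"] = reply
    return row, history

few_shot = [
    {"role": "user", "content": "Example question"},
    {"role": "assistant", "content": "Example answer"},
]
principle = {"critic": "Critique the answer.", "revision": "Revise the answer."}

# In a notebook cell, use `await demo_sample(...)` instead of asyncio.run.
row, history = asyncio.run(demo_sample("Explain photosynthesis.", principle, few_shot))
print(list(row))
print(len(history))  # 8: two few-shot messages plus three prompt/response pairs
```

The critique is generated with the initial response in context, and the revision with both the initial response and the critique in context, which is what lets a single conversation carry the whole critique-and-revise cycle.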

## Test Data Generation

Test the pipeline on a single example before processing the full dataset:

In [15]:
# Example usage
async def run_sample(split='train', i=1):
    task = ds[split][i]['prompt']
    return await create_sample(split, i, task)

result = await run_sample()
result[2]

{'init_prompt': 'Provide a detailed explanation of the process of cellular respiration and its importance to living organisms.',
 'init_response': "Okay, let’s delve into cellular respiration – it’s a fundamental process for almost all living organisms! Here’s a detailed explanation:\n\n**What is Cellular Respiration?**\n\nCellular respiration is the process by which cells break down glucose (a sugar) to release energy in the form of ATP (adenosine triphosphate). Think of it like a controlled burning – it’s not a dramatic explosion, but it releases a lot of usable energy.\n\n**The Overall Equation:**\n\nGlucose (C6H12O6) + Oxygen (O2)  → Carbon Dioxide (CO2) + Water (H2O) + Energy (ATP)\n\n**The Stages of Cellular Respiration:**\n\nCellular respiration isn’t just one step; it’s a series of interconnected stages:\n\n1. **Glycolysis:** (Happens in the cytoplasm - the fluid part of the cell)\n   * Glucose is broken down into two molecules of pyruvate.\n   * A small amount of ATP (energy) 

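## Create Dataset

Apply `create_sample` to every prompt in each split and save the results as one CSV file per split. The cell below keeps only the first five prompts of each split for a quick trial run; remove that filter for a full run.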

In [16]:
async def main():
    """Main function to process all tasks and generate the dataset."""
    try:
        # Configure concurrent processing based on available GPU
        semaphore = asyncio.Semaphore(torch.cuda.device_count() * 10 or 1)

        # Validate dataset structure
        assert isinstance(ds, dict) and len(ds) > 0, "Dataset must be a non-empty dictionary"
        for split, data in ds.items():
            assert isinstance(split, str), f"Split name must be string, got {type(split)}"
            assert isinstance(data, list), f"Split {split} must contain a list of samples"
            assert all(isinstance(row, dict) and "prompt" in row for row in data), \
                f"Each row in {split} must be a dictionary containing 'prompt'"

        # Create tasks for all samples
        tasks = [
            create_sample(split, idx, row["prompt"], semaphore)
            for split in ds
            for idx, row in enumerate(ds[split])
            if idx < 5 # Remove this line for actual run
        ]

        # Process all tasks with progress bar
        print(f"Processing {len(tasks)} tasks across {len(ds)} splits...")
        results = await tqdm_asyncio.gather(*tasks)

        # Organize results by dataset split
        all_ds = defaultdict(lambda: defaultdict(list))
        for split, i, row in results:
            for key, value in row.items():
                all_ds[split][key].append(value)

        # Validate and save results to CSV files
        for split, data in all_ds.items():
            df = pd.DataFrame(data)

            # Validate DataFrame structure
            expected_columns = {
                "init_prompt", "init_response",
                "critic_prompt", "critic_response",
                "revision_prompt", "revision_response"
            }
            assert set(df.columns) == expected_columns, \
                f"Missing columns in {split} dataset. Expected: {expected_columns}"

            # Ensure no empty values
            assert not df.isna().any().any(), f"Found null values in {split} dataset"

            # Save to CSV next to the input data (under wrk_dir)
            output_path = os.path.join(wrk_dir, "data", f"{split}_dataset.csv")
            df.to_csv(output_path, index=False)
            print(f"Saved {len(df)} samples to {output_path}")

    except Exception as e:
        print(f"Error in main: {e}")
        raise
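
The two mechanics inside `main`, capping in-flight work with a semaphore and regrouping per-sample rows into per-split columns, can be checked with a model-free sketch. `fake_sample` and `demo_main` are hypothetical names, and the stub sleeps instead of generating text.

```python
import asyncio
from collections import defaultdict

async def fake_sample(split, idx, semaphore, stats):
    """Stand-in for create_sample that records how many calls overlap."""
    async with semaphore:
        stats["active"] += 1
        stats["peak"] = max(stats["peak"], stats["active"])
        await asyncio.sleep(0.01)  # yields to the event loop, unlike a blocking call
        stats["active"] -= 1
    return split, idx, {"init_prompt": f"{split}-{idx}", "init_response": "ok"}

async def demo_main(limit=2):
    semaphore = asyncio.Semaphore(limit)
    stats = {"active": 0, "peak": 0}
    jobs = [
        fake_sample(split, n, semaphore, stats)
        for split in ("train", "test")
        for n in range(3)
    ]
    results = await asyncio.gather(*jobs)  # results keep the order of `jobs`

    # Regroup row dictionaries into {split: {column: [values...]}}.
    columns = defaultdict(lambda: defaultdict(list))
    for split, _, row in results:
        for key, value in row.items():
            columns[split][key].append(value)
    return columns, stats["peak"]

# In a notebook cell, use `await demo_main()` instead of asyncio.run.
columns, peak = asyncio.run(demo_main())
print(peak)                             # 2
print(columns["train"]["init_prompt"])  # ['train-0', 'train-1', 'train-2']
```

One caveat: the stub awaits `asyncio.sleep`, so its calls really do overlap. `model.generate` in `generate_text` is a synchronous call, so in the notebook the coroutines most likely run one after another and the semaphore acts mainly as an upper bound.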

In [17]:
import nest_asyncio
nest_asyncio.apply()

In [None]:
# Run the main process with timing
start_time = time.time()
asyncio.run(main())
end_time = time.time()
print(f"Data generation completed in {end_time - start_time:.2f} seconds.")

Processing 15 tasks across 3 splits...

