# 🤖 ART•E - Train an Email Agent that Beats o3!

Imagine you're drowning in emails and need an AI assistant that can actually find the information that you're looking for. Today, we're going to build and train one using **reinforcement learning** - the same technique used to train o3 and DeepSeek-R1. Specifically, we'll be using GRPO (group relative policy optimization) to teach Qwen 2.5 7B to retrieve information from a database of emails with a higher degree of accuracy than o3.


## 🧭 The Journey Ahead

1. **🔧 Setup** - Get our tools ready
2. **📚 Data** - Create a realistic email database
3. **🛠️ Tools** - Build search capabilities
4. **🎮 Training** - Teach our agent through trial and error
5. **🚀 Deploy** - Watch it work its magic!

If you come across any issues or have questions while following along, please join the Discord and ask away! For feature requests or to leave a star, visit our [GitHub](https://github.com/openpipe/art).


<div class="align-center">
<a href="https://github.com/openpipe/art"><img src="https://github.com/openpipe/art/raw/main/assets/ART_pill.png" height="50"></a>
<a href="https://discord.gg/zbBHRUpwf4"><img src="https://github.com/openpipe/art/raw/main/assets/Discord_pill.png" height="50"></a>
<a href="https://art.openpipe.ai"><img src="https://github.com/openpipe/art/raw/main/assets/Documentation_pill.png" height="50"></a>

</div>

Now let's dive in!

---

## Step 1: Installing Packages 📦

We'll start by installing the packages we need to build our email search agent:


In [2]:
print("📦 Installing packages...")

# Install packages with uv's pip-compatible interface
!uv pip install --quiet openpipe-art datasets litellm pydantic python-dotenv langchain-core tenacity weave rich tqdm nbformat

print("✅ Successfully installed all packages!")
print("\n🎉 Package installation complete!")

📦 Installing packages...
✅ Successfully installed all packages!

🎉 Package installation complete!


## 📥 Data: Thank you Enron! 🫡

Next, we need some realistic email inboxes for our research agent to navigate during training. Fortunately, when notorious energy trader Enron was sued for massive accounting fraud in 2001, 500K of their emails were made public in the litigation. (Pro tip: if you're engaging in massive accounting fraud, maybe don't save all your emails.)

These emails are perfect for our project - they're real business correspondence with the kind of questions people actually ask: meeting times, flight confirmations, document locations, and more.

In this step, we transform this raw email dump into a searchable database that our agent can learn to navigate. We'll do the following:

1. **Download the emails** - Use the `datasets` library to download the emails from the [Enron dataset](https://huggingface.co/datasets/corbt/enron-emails).
2. **Create database schema** - Set up SQLite tables for emails and recipients with proper indexes.
3. **Build full-text search** - Create SQLite FTS5 indexes and triggers for optimized keyword searches.
4. **Optimize for performance** - Add database indexes and configure search capabilities for real-time queries.


When we're done, our agent will be able to practice searching through thousands of these real emails to find the exact information users are looking for.
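Before the full initialization cell, here is a minimal sketch of the idea behind steps 2 and 3. It uses an in-memory database with a reduced two-column schema and three made-up emails, all of which are illustrative assumptions rather than the real Enron data. It also assumes your Python build of SQLite ships with the FTS5 extension, which most do.

```python
import sqlite3

# In-memory database so the sketch leaves nothing on disk.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE emails (
    id INTEGER PRIMARY KEY,
    subject TEXT,
    body TEXT
);

-- External-content FTS5 index that mirrors the emails table.
CREATE VIRTUAL TABLE emails_fts USING fts5(
    subject,
    body,
    content='emails',
    content_rowid='id'
);

-- Keep the index in sync whenever a new email is inserted.
CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN
    INSERT INTO emails_fts (rowid, subject, body)
    VALUES (new.id, new.subject, new.body);
END;
""")

# Toy rows (hypothetical), inserted in order so they get ids 1, 2, 3.
conn.executemany(
    "INSERT INTO emails (subject, body) VALUES (?, ?)",
    [
        ("trades", "What day are we going for lunch next week?"),
        ("Flight confirmation", "Your itinerary is attached."),
        ("Lunch plans", "Thursday works for me."),
    ],
)

# Keyword search goes through the FTS index, then joins back to the base table.
rows = conn.execute(
    """
    SELECT e.id, e.subject
    FROM emails e JOIN emails_fts ON e.id = emails_fts.rowid
    WHERE emails_fts MATCH ?
    ORDER BY e.id
    """,
    ('"lunch"',),
).fetchall()

for row in rows:
    print(row)
conn.close()
```

The match is case-insensitive with the default tokenizer, so the keyword is found both in the body of the first email and in the subject of the third.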

*The code below might take a few minutes to run while the email environment is being created. Feel free to expand the cell to see implementation details, or skip and continue on to the next section while it runs.*


In [9]:
# @title Email Database Initialization {display-mode:"form"}
# Setup SQLite database with email data from Hugging Face

import sqlite3
import os
from datasets import load_dataset, Features, Value, Sequence
from tqdm import tqdm

# Configuration
BASE_DIR = os.getcwd()
DEFAULT_DB_PATH = os.path.join(BASE_DIR, "data", "enron_emails.db")
DEFAULT_REPO_ID = "corbt/enron-emails"

# Database schema
SQL_CREATE_TABLES = """
DROP TABLE IF EXISTS recipients;
DROP TABLE IF EXISTS emails_fts;
DROP TABLE IF EXISTS emails;

CREATE TABLE emails (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    message_id TEXT UNIQUE,
    subject TEXT,
    from_address TEXT,
    date TEXT,
    body TEXT,
    file_name TEXT
);

CREATE TABLE recipients (
    email_id INTEGER,
    recipient_address TEXT,
    recipient_type TEXT,
    FOREIGN KEY(email_id) REFERENCES emails(id) ON DELETE CASCADE
);
"""

SQL_CREATE_INDEXES_TRIGGERS = """
CREATE INDEX idx_emails_from ON emails(from_address);
CREATE INDEX idx_emails_date ON emails(date);
CREATE INDEX idx_emails_message_id ON emails(message_id);
CREATE INDEX idx_recipients_address ON recipients(recipient_address);
CREATE INDEX idx_recipients_type ON recipients(recipient_type);
CREATE INDEX idx_recipients_email_id ON recipients(email_id);
CREATE INDEX idx_recipients_address_email ON recipients(recipient_address, email_id);

CREATE VIRTUAL TABLE emails_fts USING fts5(
    subject,
    body,
    content='emails',
    content_rowid='id'
);

CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN
    INSERT INTO emails_fts (rowid, subject, body)
    VALUES (new.id, new.subject, new.body);
END;

CREATE TRIGGER emails_ad AFTER DELETE ON emails BEGIN
    DELETE FROM emails_fts WHERE rowid=old.id;
END;

CREATE TRIGGER emails_au AFTER UPDATE ON emails BEGIN
    UPDATE emails_fts SET subject=new.subject, body=new.body WHERE rowid=old.id;
END;

INSERT INTO emails_fts (rowid, subject, body) SELECT id, subject, body FROM emails;
"""


def generate_database(overwrite: bool = False):
    """Generate the email database from Hugging Face dataset."""
    if os.path.exists(DEFAULT_DB_PATH) and not overwrite:
        print(f"Database already exists at {DEFAULT_DB_PATH}")
        return

    os.makedirs(os.path.dirname(DEFAULT_DB_PATH), exist_ok=True)

    print("📥 Downloading dataset from Hugging Face...")
    # Download dataset
    expected_features = Features(
        {
            "message_id": Value("string"),
            "subject": Value("string"),
            "from": Value("string"),
            "to": Sequence(Value("string")),
            "cc": Sequence(Value("string")),
            "bcc": Sequence(Value("string")),
            "date": Value("timestamp[us]"),
            "body": Value("string"),
            "file_name": Value("string"),
        }
    )
    dataset = load_dataset(DEFAULT_REPO_ID, features=expected_features, split="train")

    print("🗄️ Creating database and tables...")
    # Create database
    conn = sqlite3.connect(DEFAULT_DB_PATH)
    cursor = conn.cursor()
    cursor.executescript(SQL_CREATE_TABLES)
    conn.commit()

    print("📝 Populating database with email data...")
    # Populate database
    conn.execute("PRAGMA synchronous = OFF;")
    conn.execute("PRAGMA journal_mode = MEMORY;")
    conn.execute("BEGIN TRANSACTION;")

    processed_emails = set()

    for email_data in tqdm(dataset, desc="Inserting emails"):
        message_id = email_data["message_id"]
        subject = email_data["subject"]
        from_address = email_data["from"]
        date_obj = email_data["date"]
        body = email_data["body"]
        file_name = email_data["file_name"]

        # Filter long emails and high recipient counts
        if len(body) > 5000:
            continue

        to_list = [str(addr) for addr in email_data["to"] if addr]
        cc_list = [str(addr) for addr in email_data["cc"] if addr]
        bcc_list = [str(addr) for addr in email_data["bcc"] if addr]

        total_recipients = len(to_list) + len(cc_list) + len(bcc_list)
        if total_recipients > 30:
            continue

        # Deduplicate
        email_key = (subject, body, from_address)
        if email_key in processed_emails:
            continue
        processed_emails.add(email_key)

        date_str = date_obj.strftime("%Y-%m-%d %H:%M:%S")

        cursor.execute(
            """
            INSERT INTO emails (message_id, subject, from_address, date, body, file_name)
            VALUES (?, ?, ?, ?, ?, ?)
        """,
            (message_id, subject, from_address, date_str, body, file_name),
        )

        email_pk_id = cursor.lastrowid

        # Insert recipients
        recipient_data = []
        for addr in to_list:
            recipient_data.append((email_pk_id, addr, "to"))
        for addr in cc_list:
            recipient_data.append((email_pk_id, addr, "cc"))
        for addr in bcc_list:
            recipient_data.append((email_pk_id, addr, "bcc"))

        if recipient_data:
            cursor.executemany(
                """
                INSERT INTO recipients (email_id, recipient_address, recipient_type)
                VALUES (?, ?, ?)
            """,
                recipient_data,
            )

    conn.execute("COMMIT;")
    print("🔍 Creating search indexes and triggers...")
    cursor.executescript(SQL_CREATE_INDEXES_TRIGGERS)
    conn.commit()
    conn.close()

    print(f"✅ Database successfully created at {DEFAULT_DB_PATH}")


# Initialize the database
generate_database()


# Load and print a sample email from the database
def get_sample_email():
    """Get a sample email from the database"""
    conn = sqlite3.connect(f"file:{DEFAULT_DB_PATH}?mode=ro", uri=True)
    cursor = conn.cursor()

    # Get a sample email with some content
    cursor.execute("""
        SELECT message_id, subject, from_address, date, body
        FROM emails 
        WHERE subject IS NOT NULL 
        AND body IS NOT NULL 
        AND length(body) > 100
        AND length(body) < 1000
        ORDER BY date DESC
        LIMIT 1
    """)

    result = cursor.fetchone()
    conn.close()

    if result:
        message_id, subject, from_addr, date, body = result
        return {
            "message_id": message_id,
            "subject": subject,
            "from_address": from_addr,
            "date": date,
            "body": body,
        }
    return None


# Load and display sample email
sample_email = get_sample_email()
if sample_email:
    print("📧 Sample Email from Database:")
    print("=" * 50)
    print(f"📅 Date: {sample_email['date']}")
    print(f"👤 From: {sample_email['from_address']}")
    print(f"📝 Subject: {sample_email['subject']}")
    print(f"🆔 Message ID: {sample_email['message_id']}")
    print("\n📄 Body:")
    print("-" * 30)
    print(sample_email["body"])
    print("=" * 50)
else:
    print("❌ No sample email found in database")

Database already exists at /root/sky_workdir/examples/art-e/data/enron_emails.db
📧 Sample Email from Database:
📅 Date: 2044-01-04 22:48:58
👤 From: cramer@cadvision.com
📝 Subject: trades
🆔 Message ID: <21511287.1075842027020.JavaMail.evans@thyme>

📄 Body:
------------------------------

Howdy, 
bom went out 35 at 35.5 
Feb traded 32.75 and 33 
Mar 33 ,(away) , 33.5, 33.75 , and  34.25. 
Q2 was lifted 33 
and Q4 closed at 39 for 25 MW 
What day are we going for lunch  next week ? 
Have good weekend 
Erik 
  


## Generating Training Scenarios 🧞

In addition to our email database, we need another form of data: realistic questions that people might ask about their emails. Since a good dataset of real user queries doesn't exist, we need to use a synthetically generated one.

To save time, we'll use a [dataset](https://huggingface.co/datasets/corbt/enron_emails_sample_questions) of training scenarios that we've already generated.

For each inbox, we iterated through 1000 emails in batches of 20, and gpt-4.1 generated synthetic question-answer pairs for every email in each batch. The resulting questions look like:
- "What time is the Astros group game against the Cubs?"
- "What is my confirmation number for my Continental Airlines flight?"
- "When is the Tuesday afternoon meeting with EES?"

We linked each question to the correct answer and the source email ID. We even asked gpt-4.1 to rate how realistic each question was - surprisingly effective at filtering out questions no real person would ask!

In ART parlance, we refer to these synthetic question-answer pairs as [scenarios](https://art.openpipe.ai/resources/glossary#training-scenarios). Each scenario represents a situation that the agent will encounter during training and learn to handle before it's deployed.
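As a rough sketch of the realism filter described above, the snippet below keeps only scenarios whose realism score clears a cutoff. The `Scenario` class, the `filter_realistic` helper, the sample records, and the `0.7` threshold are all hypothetical; the source does not say which cutoff was actually used.

```python
from dataclasses import dataclass


@dataclass
class Scenario:
    """Illustrative stand-in for one synthetic question-answer pair."""

    question: str
    answer: str
    message_ids: list[str]
    how_realistic: float  # judge-assigned score, assumed here to lie in [0, 1]


def filter_realistic(scenarios: list[Scenario], min_score: float = 0.7) -> list[Scenario]:
    """Drop scenarios that the judge rated as unlikely to be asked by a real person."""
    return [s for s in scenarios if s.how_realistic >= min_score]


# Made-up records for illustration only.
candidates = [
    Scenario("When is the Tuesday afternoon meeting with EES?", "<answer>", ["<msg-1>"], 0.9),
    Scenario("How many commas are in the third paragraph?", "<answer>", ["<msg-2>"], 0.1),
]

kept = filter_realistic(candidates)
print(f"kept {len(kept)} of {len(candidates)} scenarios")
```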

*Feel free to expand the cell below if you want to go deeper into the details of how we're loading the scenarios. Otherwise, let's move on to the next step!*

In [4]:
# @title Training Scenario Utilities {display-mode:"form"}
# Data types, dataset loading, and trajectory reporting

import art
from typing import List, Optional, Literal, cast
from dataclasses import dataclass
from pydantic import BaseModel
import json
import weave
from weave.trace.autopatch import AutopatchSettings
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Configuration constants
HF_REPO_ID = "corbt/enron_emails_sample_questions"
bad_queries = [49, 101, 129, 171, 208, 266, 327]

# ==================== DATA TYPES ====================


class TrainingConfig(BaseModel):
    trajectories_per_group: int = 6
    groups_per_step: int = 1
    learning_rate: float = 1.2e-5
    eval_steps: int = 30
    val_set_size: int = 100
    training_dataset_size: int = 4000
    num_epochs: int = 4
    group_judge_model: str = "openai/gpt-4.1"
    minimum_reward_std_dev: float = 0.0
    training_dataset_seed: int | None = None


class ProjectPolicyConfig(BaseModel):
    max_turns: int = 10
    max_tokens: int = 2048
    log_to_openpipe: bool = False
    litellm_model_name: str | None = None
    stupid_simple_reward_fn: bool = False
    training_config: TrainingConfig | None = None


class SyntheticQuery(BaseModel):
    id: int
    question: str
    answer: str
    message_ids: List[str]
    how_realistic: float
    inbox_address: str
    query_date: str
    split: Literal["train", "test"]


class Email(BaseModel):
    message_id: str
    date: str
    subject: Optional[str] = None
    from_address: Optional[str] = None
    to_addresses: List[str] = []
    cc_addresses: List[str] = []
    bcc_addresses: List[str] = []
    body: Optional[str] = None
    file_name: Optional[str] = None


@dataclass
class SearchResult:
    message_id: str
    snippet: str


# ==================== DATASET LOADING ====================


def load_synthetic_queries(
    split: Literal["train", "test"] = "train",
    limit: Optional[int] = None,
    max_messages: Optional[int] = 1,
    shuffle: bool = False,
    seed: Optional[int] = None,
    exclude_known_bad_queries: bool = True,
) -> List[SyntheticQuery]:
    """Load synthetic query dataset."""
    dataset = load_dataset(HF_REPO_ID, split=split)

    if max_messages is not None:
        dataset = dataset.filter(lambda x: len(x["message_ids"]) <= max_messages)

    if exclude_known_bad_queries:
        dataset = dataset.filter(lambda x: x["id"] not in bad_queries)

    if shuffle or seed is not None:
        if seed is not None:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()

    queries = [SyntheticQuery(**row, split=split) for row in dataset]

    if max_messages is not None:
        queries = [query for query in queries if len(query.message_ids) <= max_messages]

    if limit is not None:
        return queries[:limit]
    else:
        return queries


synthetic_queries = load_synthetic_queries(split="train", limit=2)
print("number of scenarios: ", len(synthetic_queries))
print("--------------------------------")
print("first scenario:")
print("question: ", synthetic_queries[0].question)
print("answer: ", synthetic_queries[0].answer)
print("message_ids: ", synthetic_queries[0].message_ids)
print("--------------------------------")
print("second scenario:")
print("question: ", synthetic_queries[1].question)
print("answer: ", synthetic_queries[1].answer)
print("message_ids: ", synthetic_queries[1].message_ids)

# ==================== TRAJECTORY REPORTING ====================


class ProjectTrajectory(art.Trajectory):
    scenario: SyntheticQuery
    generated_answer: str | None = None


def report_trajectory(
    model: art.Model,
    trajectory: ProjectTrajectory,
    step: int = 0,
):
    """Report trajectory to Weave for logging."""
    client = weave.init(
        model.project, autopatch_settings=AutopatchSettings(disable_autopatch=True)
    )

    inputs = {
        "model": model.name,
        "scenario": trajectory.scenario,
        "step": step,
    }

    if isinstance(model, art.TrainableModel):
        inputs["base_model"] = model.base_model

    call = client.create_call("trajectory", inputs=inputs)
    client.finish_call(call, output={"tr": trajectory})


print("✅ General utilities loaded successfully!")

number of scenarios:  2
--------------------------------
first scenario:
question:  Were there any variances detected for hour 6 on 3/9/01?
answer:  Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.
message_ids:  ['<17407857.1075840601283.JavaMail.evans@thyme>']
--------------------------------
second scenario:
question:  What changes are happening to the ISDA Master Agreements due to the NationsBank and Bank of America merger?
answer:  On July 5, 1999, NationsBank changed its name to Bank of America, N.A. On July 23, 1999, Bank of America National Trust and Savings Association merged with Bank of America, N.A. The ISDA Master Agreement between Enron Corp. and NationsBank, N.A. will remain in place, with Bank of America, National Association as counterparty. The ISDA Master Agreement between Enron Corp. and Bank of America National Savings and Trust Association is terminated, and existing trades will be transferred to the other agr

## Giving the Agent Tools 🛠️

Every agent operates in an environment, which just means "the tools you give it and the information it has access to." In our case, the environment consists of a database of email inboxes and the tools it can use to search through them.

Our agent will have access to 3 tools:

1. **`search_emails(keywords, sent_after, sent_before)`** - finds up to 10 emails matching given keywords with date filters applied, returns message IDs and matching snippets
2. **`read_email(message_id)`** - returns the full email body for a given message ID  
3. **`return_final_answer(answer, sources)`** - returns the final answer to the user's question with the list of message IDs that supported the answer

Our agent will have to learn to use these tools wisely to find the right information.
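One detail worth sketching is how a list of keywords can be turned into a safe full-text query. The helper name `build_fts_query` is hypothetical, but the expression mirrors the one used in this notebook's search tool: each keyword is wrapped in double quotes so FTS5 treats it as a literal phrase, and embedded double quotes are doubled to escape them. Space-separated phrases are implicitly ANDed by FTS5, so every keyword must appear in a matching email.

```python
def build_fts_query(keywords: list[str]) -> str:
    """Quote each keyword as an FTS5 phrase and join the phrases with spaces."""
    if not keywords:
        raise ValueError("No keywords provided for search.")
    # Doubling an embedded double quote is how FTS5 escapes it inside a phrase.
    return " ".join('"' + k.replace('"', '""') + '"' for k in keywords)


simple_query = build_fts_query(["meeting", "schedule"])
quoted_query = build_fts_query(['say "hi"'])

print(simple_query)
print(quoted_query)
```

Quoting matters because characters such as `-`, `:` or `*` otherwise carry special meaning in FTS5 query syntax, and an agent is free to pass any string it likes as a keyword.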


In [5]:
# Email Search Tools
# Database connection and email search/retrieval functionality

from typing import List, Optional

# Global database connection
conn = None


def get_conn():
    """Get database connection (singleton pattern)"""
    global conn
    if conn is None:
        conn = sqlite3.connect(
            f"file:{DEFAULT_DB_PATH}?mode=ro", uri=True, check_same_thread=False
        )
    return conn


def search_emails(
    inbox: str,
    keywords: List[str],
    from_addr: Optional[str] = None,
    to_addr: Optional[str] = None,
    sent_after: Optional[str] = None,
    sent_before: Optional[str] = None,
    max_results: int = 10,
) -> List[SearchResult]:
    """Search the email database based on keywords and filters."""
    if not keywords:
        raise ValueError("No keywords provided for search.")

    if max_results > 10:
        raise ValueError("max_results must be less than or equal to 10.")

    cursor = get_conn().cursor()
    where_clauses = []
    params = []

    # Keywords (FTS): quote each keyword and escape embedded double quotes
    fts_query = " ".join('"' + k.replace('"', '""') + '"' for k in keywords)
    where_clauses.append("fts.emails_fts MATCH ?")
    params.append(fts_query)

    # Inbox filter
    where_clauses.append("""
        (e.from_address = ? OR EXISTS (
            SELECT 1 FROM recipients r_inbox
            WHERE r_inbox.recipient_address = ? AND r_inbox.email_id = e.id
        ))
    """)
    params.extend([inbox, inbox])

    # Optional filters
    if from_addr:
        where_clauses.append("e.from_address = ?")
        params.append(from_addr)

    if to_addr:
        where_clauses.append("""
            EXISTS (
                SELECT 1 FROM recipients r_to
                WHERE r_to.recipient_address = ? AND r_to.email_id = e.id
            )
        """)
        params.append(to_addr)

    if sent_after:
        where_clauses.append("e.date >= ?")
        params.append(f"{sent_after} 00:00:00")

    if sent_before:
        where_clauses.append("e.date < ?")
        params.append(f"{sent_before} 00:00:00")

    sql = f"""
        SELECT
            e.message_id,
            snippet(emails_fts, -1, '<b>', '</b>', ' ... ', 15) as snippet
        FROM
            emails e JOIN emails_fts fts ON e.id = fts.rowid
        WHERE
            {" AND ".join(where_clauses)}
        ORDER BY
            e.date DESC
        LIMIT ?;
    """
    params.append(max_results)

    cursor.execute(sql, params)
    results = cursor.fetchall()

    return [SearchResult(message_id=row[0], snippet=row[1]) for row in results]


def read_email(message_id: str) -> Optional[Email]:
    """Retrieve a single email by its message_id."""
    cursor = get_conn().cursor()

    # Get email details
    email_sql = """
        SELECT message_id, date, subject, from_address, body, file_name
        FROM emails
        WHERE message_id = ?;
    """
    cursor.execute(email_sql, (message_id,))
    email_row = cursor.fetchone()

    if not email_row:
        return None

    msg_id, date, subject, from_addr, body, file_name = email_row

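    # Get recipients. The recipients table is keyed by the integer emails.id,
    # not by message_id, so resolve the primary key in a subquery.
    recipients_sql = """
        SELECT recipient_address, recipient_type
        FROM recipients
        WHERE email_id = (SELECT id FROM emails WHERE message_id = ?);
    """
    cursor.execute(recipients_sql, (message_id,))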
    recipient_rows = cursor.fetchall()

    to_addresses = []
    cc_addresses = []
    bcc_addresses = []

    # Use a name that does not shadow the built-in `type`
    for addr, recipient_type in recipient_rows:
        type_lower = recipient_type.lower()
        if type_lower == "to":
            to_addresses.append(addr)
        elif type_lower == "cc":
            cc_addresses.append(addr)
        elif type_lower == "bcc":
            bcc_addresses.append(addr)

    return Email(
        message_id=msg_id,
        date=date,
        subject=subject,
        from_address=from_addr,
        to_addresses=to_addresses,
        cc_addresses=cc_addresses,
        bcc_addresses=bcc_addresses,
        body=body,
        file_name=file_name,
    )


# Test the search functionality
print("🔍 Testing email search functionality...")
try:
    # Test with some sample keywords
    test_results = search_emails(
        inbox="jeff.dasovich@enron.com", keywords=["meeting", "schedule"], max_results=3
    )
    print(f"✅ Found {len(test_results)} emails in test search")

    if test_results:
        # Test reading an email
        email = read_email(test_results[0].message_id)
        if email:
            print(f"✅ Successfully read email: {email.subject[:50]}...")
        else:
            print("❌ Failed to read email")

    print("✅ Email search tools working correctly!")

except Exception as e:
    print(f"❌ Error testing email search: {e}")
    print("This might be expected if the database hasn't been created yet.")

🔍 Testing email search functionality...
✅ Found 3 emails in test search
✅ Successfully read email: RE: Meeting w/PG&E this Friday, Right after DJ's M...
✅ Email search tools working correctly!


## Rubric and Correctness

As our model learns to answer user queries, we need a way to track its performance over time. We'll define a rubric for storing metrics to be reported to wandb at the end of each training run.

### Rubric

* `answer_correct`: Whether the LLM judge determines the answer is correct
* `sources_correct`: Whether the agent cited the right source message ID
* `num_turns`: Number of conversation turns the agent took to answer the question
* `attempted_answer`: Whether the agent actually tried to answer
* `ever_found_right_email`: Whether the correct email appeared in search results
* `ever_read_right_email`: Whether the agent read the correct email
* `ever_tried_to_read_invalid_email`: Whether the agent tried to read non-existent emails
* `num_sources`: Number of sources cited
* `cant_parse_tool_call`: JSON parsing errors in tool calls
* `bad_tool_call_name`: Calling non-existent tools
* `bad_tool_call_args`: Invalid tool arguments
* `ran_out_of_turns`: Exceeded maximum turns
* `returned_i_dont_know`: Explicitly said "I don't know"
* `prompt_tokens`: Total prompt tokens used
* `completion_tokens`: Total completion tokens generated

Importantly, we'll also judge the correctness of the agent's answers. For each scenario, an LLM judge compares the agent's answer to the reference answer and decides whether the two match; the verdict is stored in the rubric as `answer_correct`. We'll use it as a proxy for agent performance to track whether the model is getting better over time or running off the rails.
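To make the tracking idea concrete, here is a minimal sketch of how per-trajectory rubrics could be averaged into batch-level metrics. `MiniRubric` and `average_metrics` are hypothetical names, the rubric is cut down to three fields, and the batch values are invented; the casting of booleans to integers follows the same approach as the rubric cell in this notebook.

```python
from dataclasses import dataclass, asdict


@dataclass
class MiniRubric:
    """Cut-down, illustrative version of the rubric with three fields."""

    answer_correct: bool = False
    sources_correct: bool = False
    num_turns: int = 0


def average_metrics(rubrics: list[MiniRubric]) -> dict[str, float]:
    """Cast every field to int, then average each field across the batch."""
    rows = [{k: int(v) for k, v in asdict(r).items()} for r in rubrics]
    return {k: sum(row[k] for row in rows) / len(rows) for k in rows[0]}


# Invented batch of four rollouts.
batch = [
    MiniRubric(answer_correct=True, sources_correct=True, num_turns=3),
    MiniRubric(answer_correct=False, sources_correct=True, num_turns=10),
    MiniRubric(answer_correct=True, sources_correct=False, num_turns=4),
    MiniRubric(answer_correct=True, sources_correct=True, num_turns=3),
]

summary = average_metrics(batch)
print(summary)
```

Averaging the boolean `answer_correct` field gives the fraction of questions answered correctly, which is the number to watch from one evaluation to the next.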


*Expand the cell below to see the full rubric and the correctness judge's prompt.*

In [None]:
# Judging Correctness
# Trajectory execution, reward calculation, and correctness judging

import textwrap
from dataclasses import dataclass, asdict
from pydantic import BaseModel, Field, validate_call, ValidationError
from litellm import acompletion
import litellm
from langchain_core.utils.function_calling import convert_to_openai_tool
from litellm.caching.caching import LiteLLMCacheType, Cache
from art.utils.litellm import convert_litellm_choice_to_openai
from tenacity import retry, stop_after_attempt
import logging

# Setup LiteLLM
litellm.cache = Cache(type=LiteLLMCacheType.DISK)
litellm.drop_params = True
logging.getLogger("weave.trace.op").setLevel(logging.WARNING)

# ==================== RUBRIC ====================


@dataclass
class FinalRubric:
    answer_correct: bool = False
    sources_correct: bool = False
    num_turns: int = 0
    attempted_answer: bool = False
    ever_found_right_email: bool = False
    ever_read_right_email: bool = False
    cant_parse_tool_call: bool = False
    bad_tool_call_name: bool = False
    bad_tool_call_args: bool = False
    ran_out_of_turns: bool = False
    returned_i_dont_know: bool = False
    num_sources: int = 0
    ever_tried_to_read_invalid_email: bool = False
    prompt_tokens: int = 0
    completion_tokens: int = 0

    def to_metrics(self) -> dict[str, float | int]:
        metrics = {k: int(v) for k, v in asdict(self).items()}
        metrics["failed_format_validation"] = int(
            self.bad_tool_call_name
            or self.bad_tool_call_args
            or self.cant_parse_tool_call
        )
        return metrics


# ==================== CORRECTNESS JUDGING ====================


class CorrectnessJudgeResponse(BaseModel):
    thinking: str = Field(description="Explanation of the reasoning process.")
    accept: bool = Field(description="Whether the AI answer should be accepted.")


@retry(stop=stop_after_attempt(3))
async def judge_correctness(
    answer: str, query: SyntheticQuery
) -> CorrectnessJudgeResponse:
    """Use an LLM to judge whether answer matches the expected answer."""
    system_prompt = textwrap.dedent("""
        You are given a question, the reference answer, and an AI-generated answer.

        Follow these steps:
        1. Identify EXACTLY what information the question is asking for.
        2. Extract ONLY the essential facts from the reference answer.
        3. Verify that every essential fact appears in the AI answer.
        4. If any essential fact is missing or contradicted, set accept to false.

        Return pure JSON with this schema:
        {
          "thinking": string,
          "accept": boolean
        }
    """)

    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": (
                f"Question: {query.question}\n"
                f"Reference answer: {query.answer}\n"
                f"AI answer: {answer}"
            ),
        },
    ]

    response = await acompletion(
        model="openai/gpt-4.1",
        messages=messages,
        caching=True,
        response_format=CorrectnessJudgeResponse,
    )

    first_choice = response.choices[0]
    raw_content = first_choice.message.content or "{}"

    try:
        return CorrectnessJudgeResponse.model_validate_json(raw_content)
    except Exception as e:
        return CorrectnessJudgeResponse(
            thinking=f"Parse error: {e}\nRaw: {raw_content}", accept=False
        )

## The Rollout Function

*"Show me the incentive and I'll show you the outcome."* - Charlie Munger

(Charlie would have made a great RL researcher!)

A robust reward function is essential to successful RL training. Its purpose is to tell the model when it did well and when it did poorly, so that its weights can be updated to make desirable behavior more likely and undesirable behavior less likely. Historically, reward functions have been tricky to get right. Fortunately, a new technique called GRO (Group Reward Optimization) makes it much easier.

Rather than judge the agent's performance individually, GRO groups agent attempts by scenario and sends a history of them to an LLM judge to be ranked. In practice, it's much easier to judge which of two attempts is better than to judge what score each individual attempt should get. This approach also takes advantage of GRPO's ability to learn from relative differences in rewards, rather than absolute scores.
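The point about relative differences can be illustrated with a simplified sketch of the group normalization commonly associated with GRPO: each attempt's reward is compared with the mean of its group and scaled by the group's standard deviation. The function name, the `eps` guard, and the sample rewards are assumptions for illustration, and ART's actual implementation may differ.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-8) -> list[float]:
    """Score each attempt relative to the other attempts at the same scenario."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    if sigma < eps:
        # Every attempt scored the same, so the group carries no learning signal.
        return [0.0] * len(rewards)
    return [(r - mu) / sigma for r in rewards]


# Four hypothetical attempts at one scenario, as scored by a judge.
advantages = group_relative_advantages([1.0, 0.0, 0.5, 0.5])
print([round(a, 3) for a in advantages])

# A group where all attempts tie produces all-zero advantages.
tied = group_relative_advantages([0.5, 0.5, 0.5])
print(tied)
```

Only the ordering and spacing of rewards within a group matter here, which is why a judge that ranks attempts against each other can be enough, even if its absolute scores are not well calibrated.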




In [6]:
# Rollout Function
# Trajectory execution, reward calculation, and correctness judging

import asyncio
from pydantic import BaseModel, Field
import litellm
from litellm.caching.caching import LiteLLMCacheType, Cache
from tenacity import retry, stop_after_attempt
import logging

# Setup LiteLLM
litellm.cache = Cache(type=LiteLLMCacheType.DISK)
litellm.drop_params = True
logging.getLogger("weave.trace.op").setLevel(logging.WARNING)


@retry(stop=stop_after_attempt(3))
async def rollout(
    model: art.Model,
    scenario: SyntheticQuery,
) -> ProjectTrajectory:
    """Execute a single trajectory rollout."""
    rubric = FinalRubric()
    traj = ProjectTrajectory(
        messages_and_choices=[],
        reward=0,
        metadata={"email_inbox": scenario.inbox_address, "scenario_id": scenario.id},
        scenario=scenario,
    )

    system_prompt = textwrap.dedent(f"""\
        You are an email search agent. You are given a user query and tools to search emails.
        Use the tools to find the answer to the user's query. You may take up to {model.config.max_turns} turns.

        User's email address is {scenario.inbox_address}
        Today's date is {scenario.query_date}
    """)

    async def search_emails_tool(keywords: list[str]) -> list[dict]:
        """Search the user's email inbox for emails matching keywords."""
        resp = search_emails(
            inbox=scenario.inbox_address,
            sent_before=scenario.query_date,
            keywords=keywords,
        )

        for r in resp:
            if r.message_id == scenario.message_ids[0]:
                rubric.ever_found_right_email = True
        return [asdict(r) for r in resp]

    async def read_email_tool(message_id: str) -> Email | dict:
        """Read the content of an email."""
        email_content = read_email(message_id)

        if message_id == scenario.message_ids[0]:
            rubric.ever_read_right_email = True
        if email_content is None:
            return {"error": "Email not found"}
        else:
            return email_content.model_dump()

    async def return_final_answer(answer: str, sources: list[str]):
        """Return the final answer with sources."""
        rubric.attempted_answer = True
        traj.generated_answer = answer

        if answer == "I don't know":
            rubric.returned_i_dont_know = True
        else:
            async with traj.track_duration("determine_if_answer_is_correct"):
                judge_response = await judge_correctness(answer, scenario)
                traj.logs.append(f"Correctness judge response: {judge_response}")
                rubric.answer_correct = judge_response.accept
            rubric.sources_correct = scenario.message_ids[0] in sources

    tools = [search_emails_tool, read_email_tool, return_final_answer]
    traj.tools = [convert_to_openai_tool(t) for t in tools]

    traj.messages_and_choices = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": scenario.question},
    ]

    while not rubric.attempted_answer:
        rubric.num_turns += 1

        if rubric.num_turns > model.config.max_turns:
            rubric.ran_out_of_turns = True
            break

        litellm_model_name = model.config.litellm_model_name
        if litellm_model_name is None:
            litellm_model_name = f"hosted_vllm/{model.name}"

        async with traj.track_duration("llm_completion"):
            llm_response = await acompletion(
                model=litellm_model_name,
                base_url=model.inference_base_url,
                messages=traj.messages(),
                caching=not model.trainable,
                api_key=model.inference_api_key,
                max_completion_tokens=model.config.max_tokens,
                tools=traj.tools,
            )

        rubric.prompt_tokens += llm_response.usage.prompt_tokens
        rubric.completion_tokens += llm_response.usage.completion_tokens
        choice = llm_response.choices[0]

        # Handle only one tool call at a time
        if choice.message.tool_calls is not None and len(choice.message.tool_calls) > 1:
            choice.message.tool_calls = choice.message.tool_calls[:1]
        traj.messages_and_choices.append(convert_litellm_choice_to_openai(choice))

        if choice.message.tool_calls is None:
            rubric.bad_tool_call_name = True
            break

        for tool_call in choice.message.tool_calls:
            if tool_call is None:
                rubric.bad_tool_call_args = True
                break

            try:
                tool_args = json.loads(tool_call.function.arguments)
            except Exception:
                rubric.bad_tool_call_args = True
                break

            for tool_fn in tools:
                if tool_fn.__name__ == tool_call.function.name:
                    try:
                        validated_fn = validate_call(tool_fn)
                        result = await validated_fn(**tool_args)
                        traj.messages_and_choices.append(
                            {
                                "role": "tool",
                                "tool_call_id": tool_call.id,
                                "content": json.dumps(result),
                            }
                        )
                    except ValidationError as e:
                        rubric.bad_tool_call_args = True
                        traj.logs.append(
                            f"Invalid args for {tool_call.function.name}: {e}"
                        )
                        break
                    break
            else:
                rubric.bad_tool_call_name = True
                break

        if rubric.bad_tool_call_name or rubric.bad_tool_call_args:
            break

    traj.metrics = rubric.to_metrics()

    traj.finish()
    return traj


print("✅ Rollout function loaded successfully!")

✅ Rollout function loaded successfully!


## The Secret Sauce: LLM-as-Judge Rewards 🎓

Instead of relying only on a basic reward function that scores each rollout in isolation, we're bringing in an AI teacher (GPT-4.1 by default) to compare rollouts side-by-side and give nuanced feedback on our agent's performance.

**Why this matters:** The AI judge can see subtle differences that our simple reward function might miss. It can tell the difference between "Agent A found the answer but took 8 searches" and "Agent B found it in just 3 searches." This relative feedback gives training a much richer signal than a simple pass/fail score.

**The Group Judge Process:**
1. Run our agent 4 times on the same question
2. Compare all 4 attempts side-by-side
3. Score each one based on efficiency, accuracy, and strategy
4. Use these scores to teach our model which approaches work best

This is a great way to teach elegant problem-solving. The model learns not just to find answers, but to find them efficiently and reliably.
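
To make that last step concrete, here is a minimal sketch of how one group's scores can become a learning signal. It uses one common formulation of GRPO's group-relative advantage: subtract the group mean, then divide by the group standard deviation. The function name, the `eps` constant, and the example scores are illustrative assumptions, and ART's internal computation may differ in its details.

```python
import statistics


def group_relative_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Center and scale rewards within one group of rollouts."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    # eps (an assumed constant) avoids division by zero when all rewards match
    return [(r - mean) / (std + eps) for r in rewards]


# Hypothetical judge scores for 4 rollouts on the same question
judge_scores = [0.9, 0.7, 0.2, 0.2]
advantages = group_relative_advantages(judge_scores)

for score, adv in zip(judge_scores, advantages):
    print(f"score={score:.2f} -> advantage={adv:+.2f}")

# A group where every rollout scored the same carries no signal
flat_advantages = group_relative_advantages([0.5, 0.5, 0.5, 0.5])
```

Rollouts that beat their group's average get a positive advantage and are reinforced; the rest are discouraged. When all four rollouts score the same, every advantage is zero and the group teaches the model nothing, which is why the training loop can optionally drop low-variance groups.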

We'll ask our LLM judge to consider the following:

**🎯 Answer Correctness** - The big one. Getting the right answer gets the highest reward.

**⚡ Turn Efficiency** - We give small "extra credit" for taking fewer turns. This is a proxy for latency: fewer round trips means faster answers.

**🚫 Hallucination Penalty** - If the agent can't find the right answer, saying "I don't know" is much better than making something up. We add a significant penalty to incorrect answers.

**🏆 Partial Credit** - Small rewards for finding the right email in search results, actually reading it, and getting the source right even if the answer is wrong.

*This is the difference between training a model that works and training one that excels.*
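
For intuition, here is a rough sketch of how those four criteria could be folded into a single number. Everything in it is an assumption made for illustration: the `Outcome` fields, the `sketch_reward` name, and every weight. In this notebook the LLM judge weighs these factors itself, so treat this as a mental model rather than the reward actually used.

```python
from dataclasses import dataclass


@dataclass
class Outcome:
    """Hypothetical summary of one rollout (field names are illustrative)."""

    answer_correct: bool = False
    attempted_answer: bool = False
    returned_i_dont_know: bool = False
    ever_found_right_email: bool = False
    ever_read_right_email: bool = False
    sources_correct: bool = False
    num_turns: int = 0
    max_turns: int = 5


def sketch_reward(o: Outcome) -> float:
    """Toy reward; every weight below is an assumption, not a tuned value."""
    # Partial credit: small bonuses for progress toward the right email
    partial = (
        0.1 * o.ever_found_right_email
        + 0.1 * o.ever_read_right_email
        + 0.1 * o.sources_correct
    )
    if o.answer_correct:
        # Correct answers dominate; fewer turns earn a small bonus
        efficiency = 0.2 * (1 - o.num_turns / o.max_turns)
        return 1.0 + partial + efficiency
    if o.attempted_answer and not o.returned_i_dont_know:
        # A confident wrong answer is penalized more than "I don't know"
        return -1.0 + partial
    return partial


correct = Outcome(
    answer_correct=True,
    attempted_answer=True,
    ever_found_right_email=True,
    ever_read_right_email=True,
    sources_correct=True,
    num_turns=3,
)
unsure = Outcome(
    attempted_answer=True,
    returned_i_dont_know=True,
    ever_found_right_email=True,
    num_turns=5,
)
hallucinated = Outcome(attempted_answer=True, ever_found_right_email=True, num_turns=2)

for name, outcome in [("correct", correct), ("unsure", unsure), ("hallucinated", hallucinated)]:
    print(f"{name}: {sketch_reward(outcome):+.2f}")
```

The ordering is the point: a correct answer beats an honest "I don't know", which in turn beats a made-up answer.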


In [7]:
# Group Judgement
# LLM-based scoring of trajectory groups

from typing import List, Literal
from pydantic import BaseModel, Field

# ==================== JUDGE DATA TYPES ====================


class Issue(BaseModel):
    label: str = Field(description="A short label for the issue.")
    explanation: str = Field(description="A human-readable explanation of the issue.")
    severity: Literal["minor", "major", "fatal"] = Field(
        description="The severity of the issue."
    )


class RolloutScore(BaseModel):
    rollout_id: str = Field(description="The id of the rollout being scored.")
    explanation: str = Field(
        description="A short explanation of why you gave this score."
    )
    score: float = Field(description="A score between 0 and 1.")
    issues: List[str] = Field(
        description="The list of labels for each issue identified."
    )


class JudgeGroupResponse(BaseModel):
    new_issues: List[Issue] = Field(
        description="Any new issues identified on the rollouts."
    )
    scores: List[RolloutScore] = Field(description="The scores for each rollout.")


# ==================== GROUP JUDGE CLASS ====================

DEFAULT_RUBRIC = """
- A rollout that achieves its goal should always get a significantly higher score than a rollout that does not achieve its goal.
- A rollout that achieves its goal more efficiently should get a higher score than a rollout that achieves its goal less efficiently.
- If one rollout is only slightly better than another, the difference in scores should be small. If it is significantly better, the difference in scores should be large.
- You may give some partial credit for a rollout that makes progress towards its goal but does not complete it.
"""


class GroupJudge:
    """LLM-based judge for groups of rollouts."""

    def __init__(
        self,
        project: str,
        judge_model: str = "openai/gpt-4.1",
        rubric: str = DEFAULT_RUBRIC,
        initial_issues: List[Issue] | None = None,
    ):
        self.project = project
        self.judge_model = judge_model
        self.rubric = rubric
        self.all_issues = initial_issues or [
            Issue(
                label="looping",
                explanation="The assistant repeats itself unnecessarily but is able to recover.",
                severity="minor",
            ),
            Issue(
                label="fatal_looping",
                explanation="The assistant began repeating itself and is unable to recover.",
                severity="fatal",
            ),
        ]

    async def judge(self, rollouts: list[ProjectTrajectory]) -> list[ProjectTrajectory]:
        """Score every trajectory in rollouts and write the score to traj.reward."""
        if not rollouts:
            return rollouts

        # Determine common prefix to save tokens
        message_lists = [traj.messages() for traj in rollouts]
        common_prefix_len = 0
        for i, msg in enumerate(message_lists[0]):
            # Guard against rollouts that are shorter than the first one
            if all(
                len(msg_list) > i and msg_list[i] == msg for msg_list in message_lists
            ):
                common_prefix_len += 1
            else:
                break

        user_text = ""
        if common_prefix_len > 0:
            common_prefix_messages = message_lists[0][:common_prefix_len]
            user_text += (
                "<context>\n" + json.dumps(common_prefix_messages) + "\n</context>\n\n"
            )

        # Serialize rollouts without common prefix
        serialized_rollouts = []
        for idx, (traj, full_messages) in enumerate(
            zip(rollouts, message_lists), start=1
        ):
            traj.metrics["independent_reward"] = traj.reward
            trimmed_messages = full_messages[common_prefix_len:]
            serialized_rollouts.append(
                f'<rollout id="{idx}">\n'
                + json.dumps(trimmed_messages)
                + "\n</rollout>"
            )

        user_text += "Rollouts:\n\n" + "\n\n".join(serialized_rollouts)

        judge_prompt = f"""
All of the rollouts below have been given the same goal. Your job is to consider each of them and give them a score between 0 and 1.

Grading standards:
{self.rubric}

Existing issues:
{json.dumps([issue.model_dump() for issue in self.all_issues], indent=2)}
"""

        messages = [
            {"role": "system", "content": judge_prompt},
            {"role": "user", "content": user_text},
        ]

        response = await acompletion(
            model=self.judge_model,
            messages=messages,
            response_format=JudgeGroupResponse,
            caching=True,
        )

        first_choice = response.choices[0]
        content = first_choice.message.content or "{}"
        parsed = JudgeGroupResponse.model_validate_json(content)

        # Merge new issues
        if parsed.new_issues:
            existing_labels = {fm.label for fm in self.all_issues}
            for fm in parsed.new_issues:
                if fm.label not in existing_labels:
                    self.all_issues.append(fm)
                    existing_labels.add(fm.label)

        # Apply scores
        for traj, score in zip(rollouts, parsed.scores):
            traj.metrics["group_judge_score"] = score.score
            traj.reward = (
                score.score
                if traj.metrics.get("failed_format_validation", 0) == 0
                else 0
            )
            traj.logs.append(f"Judge group explanation: {score.explanation}")

            # Record issue metrics
            for issue in self.all_issues:
                metric_key = f"issues/{issue.severity}/{issue.label}"
                traj.metrics[metric_key] = issue.label in score.issues

        return rollouts


print("✅ Group judgment functionality loaded successfully!")

✅ Group judgment functionality loaded successfully!


## Let's Train! 🚀

Ok, now that we have our dataset, environment, and reward function, we can train our model!

We're using our open source ART library for all the training. ART (Agent Reinforcement Trainer) is purpose-built to make it easy to train real-world multi-turn agents using Group Relative Policy Optimization (GRPO).

**The GRPO training loop is beautifully simple:**

1. **Load a batch** of questions (and their correct answers) from our dataset, for example 12 per step (the actual batch size is set by `groups_per_step` in the config below)
2. **Generate trajectories** - For each question, run the agent 4 times using our rollout function
3. **Score everything** - Use our reward function + LLM judge to score all 4 trajectories for each question
4. **Learn from the best** - Use every trajectory in the batch (12 × 4 = 48 in this example) and its reward to calculate the loss and update the model
5. **Repeat** until the model stops improving

**What you'll see:** Your agent literally getting smarter in real-time. It will start by making random searches, then gradually learn which keywords work, how to read emails efficiently, and when to say "I don't know" instead of hallucinating.

This is reinforcement learning in action - learning from experience, not just memorizing patterns. Let's watch it happen!
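
Before the real thing, here is a stripped-down sketch of the loop's batching arithmetic, using example values of 12 questions per step and 4 rollouts per question. `iterate_batches` is a hypothetical stand-in written for this illustration; it is not ART's `iterate_dataset`, which may shuffle or resume differently.

```python
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def iterate_batches(
    items: Sequence[T], batch_size: int, num_epochs: int
) -> Iterator[tuple[list[T], int, int]]:
    """Yield (batch, epoch, global_step); a simplified, hypothetical helper."""
    global_step = 0
    for epoch in range(num_epochs):
        for start in range(0, len(items), batch_size):
            yield list(items[start : start + batch_size]), epoch, global_step
            global_step += 1


questions = [f"question-{i}" for i in range(20)]  # hypothetical dataset
groups_per_step = 12  # example value: questions per training step
trajectories_per_group = 4  # example value: rollouts per question

steps = list(iterate_batches(questions, groups_per_step, num_epochs=2))
for batch, epoch, step in steps:
    n_trajectories = len(batch) * trajectories_per_group
    print(f"epoch={epoch} step={step} questions={len(batch)} trajectories={n_trajectories}")
```

Note that the last batch of each epoch can be smaller than `groups_per_step` when the dataset size is not a multiple of it, so the number of trajectories per step is not always the same.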


In [None]:
# Main Training Function
# Complete training orchestration for email search agent

import asyncio
import os
import statistics
from typing import cast

from art.local import LocalBackend
from art.utils import iterate_dataset

# Ensure database is ready (will skip if already exists)
generate_database()

config = ProjectPolicyConfig(
    max_turns=5,
    max_tokens=1024,
    stupid_simple_reward_fn=True,  # Use simple reward for testing
    training_config=TrainingConfig(
        trajectories_per_group=4,  # Small numbers for testing
        groups_per_step=24,
        learning_rate=1e-5,
        eval_steps=5,
        val_set_size=10,
        training_dataset_size=20,
        num_epochs=4,
        minimum_reward_std_dev=0.0,
    ),
)
model = art.TrainableModel(
    name="agent-001",
    project="art-e-tutorial",
    config=config,
    base_model="Qwen/Qwen2.5-7B-Instruct",  # Use smaller model for testing
)

if model.config.training_config is None:
    raise ValueError("Training config is not set")

group_judge = GroupJudge(
    project=model.project,
    judge_model=model.config.training_config.group_judge_model,
)

with LocalBackend() as backend:
    # Try to pull from S3 if a bucket is configured
    if "BACKUP_BUCKET" in os.environ:
        print(f"🔄 Pulling from S3 bucket: `{os.environ['BACKUP_BUCKET']}`")
        await backend._experimental_pull_from_s3(
            model,
            s3_bucket=os.environ["BACKUP_BUCKET"],
            verbose=True,
        )
    else:
        print("⚠️  No S3 bucket configured, skipping pull")

    await model.register(backend)

    print("📚 Loading training data...")
    tc = model.config.training_config
    seed = tc.training_dataset_seed if tc is not None else None
    train_scenarios = load_synthetic_queries(
        split="train",
        limit=tc.training_dataset_size if tc is not None else None,
        seed=seed,
    )
    print("📚 Loading validation data...")
    val_scenarios = load_synthetic_queries(
        split="test", limit=model.config.training_config.val_set_size
    )

    print(f"📊 Training data size: {len(train_scenarios)}")
    print(f"📊 Validation data size: {len(val_scenarios)}")

    train_iterator = iterate_dataset(
        train_scenarios,
        groups_per_step=model.config.training_config.groups_per_step,
        num_epochs=model.config.training_config.num_epochs,
        initial_step=await model.get_step(),
    )

    for batch, epoch, global_step, epoch_step in train_iterator:
        if global_step % model.config.training_config.eval_steps == 0:
            print(f"\n🔍 Evaluating at Iteration {global_step}")
            # Note: evaluation/benchmarking code is omitted from this tutorial;
            # this block only prunes old checkpoints and backs up the model
            await model.delete_checkpoints()

            if "BACKUP_BUCKET" in os.environ:
                await backend._experimental_push_to_s3(
                    model,
                    s3_bucket=os.environ["BACKUP_BUCKET"],
                )

        print(f"🎯 Generating trajectories for step {global_step}...")
        groups = await art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(
                    (
                        rollout(model, scenario)
                        for _ in range(
                            model.config.training_config.trajectories_per_group
                        )
                    )
                )
                for scenario in batch
            )
        )

        # Score each trajectory group with the LLM judge (runs concurrently)
        training_cfg = model.config.training_config
        print("⚖️  Applying group judge")
        judge_tasks = [
            group_judge.judge(cast(list[ProjectTrajectory], g.trajectories))
            for g in groups
        ]

        results = await asyncio.gather(*judge_tasks, return_exceptions=True)

        successful_groups = []
        for grp_idx, (g, res) in enumerate(zip(groups, results)):
            if isinstance(res, Exception):
                print(
                    f"⚠️  WARNING: Judge group failed for group {grp_idx} at step {global_step}: {res!r}"
                )
            else:
                successful_groups.append(g)

        groups = successful_groups

        for g in groups:
            for t in g.trajectories:
                report_trajectory(model, t, global_step)

        if not groups:
            print(
                f"⚠️  WARNING: All judge groups failed at step {global_step}; skipping training step"
            )
            continue

        # Filter groups by reward standard deviation
        if (
            training_cfg.minimum_reward_std_dev is not None
            and training_cfg.minimum_reward_std_dev > 0
        ):
            print(
                f"📊 Filtering groups by reward std dev (min: {training_cfg.minimum_reward_std_dev})..."
            )
            filtered_groups = []
            for grp_idx, g in enumerate(groups):
                rewards = [t.reward for t in g.trajectories]
                if len(rewards) < 2:
                    std_dev = 0.0
                else:
                    std_dev = statistics.pstdev(rewards)
                if std_dev < training_cfg.minimum_reward_std_dev:
                    print(
                        f"⚠️  Dropping group {grp_idx} at step {global_step} (std dev: {std_dev:.4f})"
                    )
                    continue
                filtered_groups.append(g)

            groups = filtered_groups

            if not groups:
                print(
                    f"⚠️  WARNING: All groups dropped due to low std dev at step {global_step}; skipping training step"
                )
                continue

        print(f"🎓 Training model with {len(groups)} groups...")
        await model.train(
            groups,
            config=art.TrainConfig(
                learning_rate=model.config.training_config.learning_rate
            ),
        )
        print(f"✅ Completed training step {global_step}")

    # Final backup
    print("💾 Final model backup...")
    if "BACKUP_BUCKET" in os.environ:
        await backend._experimental_push_to_s3(
            model,
            s3_bucket=os.environ["BACKUP_BUCKET"],
        )
    print("🎉 Training finished successfully!")

## 🎉 You Just Built ART•E!

**Congratulations!** You just built an email research agent that can beat o3 on this task!

### What You Just Accomplished:
- 🏗️ **Built a realistic environment** using 500K real business emails
- 🧠 **Created synthetic training data** with thousands of realistic question-answer pairs
- 🛠️ **Designed a minimal but powerful toolset** (search, read, answer)
- 🎯 **Implemented a multi-objective reward function** that optimizes for accuracy AND efficiency
- 🎓 **Used LLM-as-judge** to provide nuanced feedback during training
- 🚀 **Trained with GRPO** to create an agent that learns from experience

**This is the future of AI.** You've just built something that:
- Is **faster** than o3 (fewer turns to find answers)
- Is **cheaper** than o3 (smaller model, more efficient)
- Is **more accurate** than o3 (better at finding the right emails)
- **Learns from experience** instead of just memorizing patterns

**The possibilities are endless!** You now have the blueprint to build RL agents for any task. Email search was just the beginning - what will you build next? 🌟

*P.S. - Your agent is now equipped to handle the "How do I RSVP for my daughter's classroom party?" and "What time is my brother's flight?" questions that started this whole journey. So 2025! 😉*
