# Lesson Introduction

Welcome! In this lesson, we’ll see how to implement a custom tool for **Retrieval-Augmented Generation (RAG) retrieval** in an AI agent. Previously, you learned how RAG agents use a knowledge base to answer questions. But what if you want your agent to use a specialized retrieval method or expose this as a reusable tool? That’s where custom tools come in.

Our goal: learn to design, implement, and integrate a custom retrieval tool that fetches the most relevant knowledge chunks for any user query. By the end, you’ll know how to build such a tool, validate its inputs, and connect it to your agent for context-aware responses.

## Why Custom Tools for RAG Retrieval?

Why do we need a custom tool for RAG retrieval? Imagine you’re building a personal learning assistant. You want it to answer questions using your own notes or learning plans, not just generic knowledge.

In AI agents, a “tool” is a callable function or service for a specific task. For RAG, this means a tool that searches your knowledge base and returns the most relevant information. For example, if a user asks, “What are my learning plans for SQL?” the agent should use the retrieval tool to find something like “Review different types of SQL joins — especially LEFT and FULL OUTER joins.”

A custom tool gives you control over how retrieval works, what data is returned, and how it’s formatted. This flexibility is key to building agents tailored to your needs.

## Tool Parameters and Input Validation

Tools should validate their inputs. This ensures the tool gets the right data and handles errors well. **Pydantic** is a popular library for this in Python.

Define a Pydantic model for the tool’s arguments:
```python
from pydantic import BaseModel

class FunctionArgs(BaseModel):
    user_query: str
```

This model enforces that `user_query` is a string. If the field is missing or has the wrong type, Pydantic raises a `ValidationError`.
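To see the validation behavior in isolation, here is a minimal sketch, assuming Pydantic v2 is installed. The bad inputs are hypothetical examples chosen for illustration, not part of the lesson's dataset.

```python
from pydantic import BaseModel, ValidationError


class FunctionArgs(BaseModel):
    user_query: str


# A well-formed argument set becomes a typed object.
ok = FunctionArgs(user_query="SQL learning plans")
print(ok.user_query)

# Hypothetical bad inputs: a missing field and a non-string value.
rejected = []
for bad_kwargs in ({}, {"user_query": 42}):
    try:
        FunctionArgs(**bad_kwargs)
    except ValidationError as exc:
        # Each error entry carries a machine-readable "type" code.
        rejected.append(exc.errors()[0]["type"])

print(rejected)
```

Both bad inputs are rejected before any retrieval code would run, which is the point of validating at the tool boundary.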


## Using Input Validation in the Tool

Use this model in the tool’s function:

```python
async def run_function(ctx, args: str) -> str:
    parsed = FunctionArgs.model_validate_json(args)
    chunks = retrieve_top_chunks(parsed.user_query)
    return "\n".join([c["chunk"] for c in chunks])
# Output (for a relevant query):
# Review different types of SQL joins — especially LEFT and FULL OUTER joins.
```

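`model_validate_json` (Pydantic v2) parses the raw JSON string passed in `args` and validates it against `FunctionArgs` in one step, so malformed or incomplete arguments fail before any retrieval runs.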
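Because `run_function` is an ordinary coroutine, it can be exercised without an agent. The sketch below assumes Pydantic v2 and replaces `retrieve_top_chunks` with a hypothetical stand-in that returns one canned chunk, so it runs without a knowledge base.

```python
import asyncio
import json

from pydantic import BaseModel, ValidationError


class FunctionArgs(BaseModel):
    user_query: str


def retrieve_top_chunks(user_query: str, top_k: int = 3) -> list:
    # Hypothetical stand-in for the real retriever: one canned chunk.
    canned = [{"chunk": f"Notes matching '{user_query}'", "id": "demo-1", "distance": 0.0}]
    return canned[:top_k]


async def run_function(ctx, args: str) -> str:
    parsed = FunctionArgs.model_validate_json(args)
    chunks = retrieve_top_chunks(parsed.user_query)
    return "\n".join(c["chunk"] for c in chunks)


# The arguments arrive as a JSON string, so build the payload the same way.
payload = json.dumps({"user_query": "SQL joins"})
output = asyncio.run(run_function(None, payload))
print(output)

# Text that is not valid JSON is rejected by the same validation call.
malformed_rejected = False
try:
    asyncio.run(run_function(None, "not json"))
except ValidationError:
    malformed_rejected = True
print(malformed_rejected)
```

Passing `None` as the context works here only because this version of `run_function` never reads `ctx`.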

## Wrapping Retrieval Logic in a FunctionTool

With retrieval logic and input validation ready, let’s wrap it in a **FunctionTool**. This class describes the tool’s interface and connects it to the agent.

```python
from agents import FunctionTool

rag_retrieval_tool = FunctionTool(
    name="rag_retrieval_tool",
    description="A tool to retrieve context from RAG based on user query",
    params_json_schema={
        "type": "object",
        "properties": {
            "user_query": {"type": "string"},
        },
        "required": ["user_query"],
        "additionalProperties": False
    },
    on_invoke_tool=run_function
)
```

Key parameters:

- `name`: The tool’s unique identifier.
- `description`: Explains what the tool does.
- `params_json_schema`: Describes the expected input.
- `on_invoke_tool`: The function to call (our validated retrieval function).

This makes the tool discoverable and usable by any agent that supports tool invocation.
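The hand-written `params_json_schema` repeats what the Pydantic model already declares. One way to keep the two in sync is to generate the schema from the model, sketched here under the assumption that Pydantic v2 is installed. The generated schema also contains `title` keys, and whether your agent framework accepts it unchanged is something to verify rather than assume.

```python
from pydantic import BaseModel, ConfigDict


class FunctionArgs(BaseModel):
    # extra="forbid" makes the generated schema disallow unknown properties.
    model_config = ConfigDict(extra="forbid")

    user_query: str


schema = FunctionArgs.model_json_schema()

print(schema["type"])
print(schema["properties"]["user_query"]["type"])
print(schema["required"])
print(schema["additionalProperties"])
```

The four printed values correspond to the `type`, `properties`, `required`, and `additionalProperties` entries written by hand in the `FunctionTool` definition above.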

## Integrating the Tool with the Agent: Part 1

Now, connect your custom tool to the agent. This lets the agent use the tool to retrieve information from the knowledge base.

```python
from agents import Agent

AGENT = Agent(
    name="Learning Assistant",
    instructions=(
        "You are a personal learning assistant with access to rag tool. "
        "Whenever asked a question about learning plans, use the RAG retrieval tool to get context from the DB and answer user questions."
    ),
    tools=[rag_retrieval_tool]
)
```

## Integrating the Tool with the Agent: Part 2

Ask the agent a question, and it will use the retrieval tool as needed:

```python
from agents import Runner

def ask_agent(prompt):
    result = Runner.run_sync(AGENT, prompt)
    return result.final_output

# Example usage
response = ask_agent("What are my learning plans for SQL?")
print(response)  # Review different types of SQL joins — especially LEFT and FULL OUTER joins.
```

The agent automatically invokes the retrieval tool, fetches relevant knowledge, and uses it to answer the user’s question.

## Lesson Summary

You learned how to implement a custom tool for RAG retrieval and integrate it with an AI agent. We covered:

- Why custom tools are important for flexible, context-aware agents
- How to design a retrieval function that queries your knowledge base
- Using Pydantic for input validation
- Wrapping your logic in a `FunctionTool`
- Connecting the tool to an agent for relevant, knowledge-based answers

You now have the foundation to build and extend your own retrieval tools for any use case.

Now it’s your turn! Next, you’ll get hands-on experience building and testing your own custom RAG retrieval tool. You’ll practice designing retrieval logic, validating inputs, and integrating your tool with an agent. This will help solidify your understanding and prepare you for real-world applications.


```python
from typing import Any
from pydantic import BaseModel
from agents import Agent, Runner, FunctionTool
from rag_builder import load_collection

def retrieve_top_chunks(user_query, top_k=3):
    collection = load_collection()
    # TODO: Call the collection's query method to retrieve top_k chunks based on user_query
    results = ______
    retrieved_chunks = []
    for i in range(len(results['documents'][0])):
        retrieved_chunks.append({
            "chunk": results['documents'][0][i],
            "id": results['metadatas'][0][i]['id'],
            "distance": results['distances'][0][i]
        })
    return retrieved_chunks

class FunctionArgs(BaseModel):
    # TODO: Define the function arguments schema for user query as a string

async def run_function(ctx: Any, args: str) -> str:
    parsed = FunctionArgs.model_validate_json(args)
    # TODO: Retrieve top chunks based on user query
    chunks = ______
    return "\n".join([f"{c['chunk']} (Score: {c['distance']})" for c in chunks])

rag_retrieval_tool = FunctionTool(
    name="rag_retrieval_tool",
    description="A tool to retrieve context from RAG based on user query",
    params_json_schema={
        "type": "object",
        "properties": {
            "user_query": {"type": "string"},
        },
        "required": ["user_query"],
        "additionalProperties": False
    },
    on_invoke_tool=________  # TODO: Add the function to be invoked here
)

AGENT = Agent(
    name="Learning Assistant",
    instructions=(
        "You are a personal learning assistant with access to rag tool. "
        "Whenever asked a question about learning plans, use the RAG retrieval tool to get context from the DB and answer user questions."
    ),
    tools=[rag_retrieval_tool]
)

def ask_agent(prompt):
    result = Runner.run_sync(AGENT, prompt)
    return result.final_output
```

```python
from typing import Any
from pydantic import BaseModel
from agents import Agent, Runner, FunctionTool
from rag_builder import load_collection

def retrieve_top_chunks(user_query, top_k=3):
    collection = load_collection()
    # TODO: Call the collection's query method to retrieve top_k chunks based on user_query
    results = collection.query(
        query_texts=[user_query],
        n_results=top_k
    )
    retrieved_chunks = []
    for i in range(len(results['documents'][0])):
        retrieved_chunks.append({
            "chunk": results['documents'][0][i],
            "id": results['metadatas'][0][i]['id'],
            "distance": results['distances'][0][i]
        })
    return retrieved_chunks

class FunctionArgs(BaseModel):
    # TODO: Define the function arguments schema for user query as a string
    user_query: str

async def run_function(ctx: Any, args: str) -> str:
    parsed = FunctionArgs.model_validate_json(args)
    # TODO: Retrieve top chunks based on user query
    chunks = retrieve_top_chunks(parsed.user_query)
    return "\n".join([f"{c['chunk']} (Score: {c['distance']})" for c in chunks])

rag_retrieval_tool = FunctionTool(
    name="rag_retrieval_tool",
    description="A tool to retrieve context from RAG based on user query",
    params_json_schema={
        "type": "object",
        "properties": {
            "user_query": {"type": "string"},
        },
        "required": ["user_query"],
        "additionalProperties": False
    },
    on_invoke_tool=run_function  # TODO: Add the function to be invoked here
)

AGENT = Agent(
    name="Learning Assistant",
    instructions=(
        "You are a personal learning assistant with access to rag tool. "
        "Whenever asked a question about learning plans, use the RAG retrieval tool to get context from the DB and answer user questions."
    ),
    tools=[rag_retrieval_tool]
)

def ask_agent(prompt):
    result = Runner.run_sync(AGENT, prompt)
    return result.final_output
```
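The loop in `retrieve_top_chunks` indexes `[0]` because the query result holds one inner list per query text. The sketch below checks that reshaping step offline. `FakeCollection` and `reshape` are hypothetical names introduced only for this illustration, and the fake data mimics the result shape the code above relies on rather than any real knowledge base.

```python
class FakeCollection:
    """Hypothetical stand-in that mimics the nested result shape used above."""

    def query(self, query_texts, n_results):
        docs = ["Review SQL joins", "Practice window functions", "Read about indexes"]
        k = min(n_results, len(docs))
        # One inner list per query text, so a single query yields [[...]].
        return {
            "documents": [docs[:k]],
            "metadatas": [[{"id": f"doc-{i}"} for i in range(k)]],
            "distances": [[float(i) for i in range(k)]],
        }


def reshape(results):
    # Index [0] selects the results for the first (and only) query text.
    return [
        {"chunk": doc, "id": meta["id"], "distance": dist}
        for doc, meta, dist in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        )
    ]


chunks = reshape(FakeCollection().query(query_texts=["SQL"], n_results=2))
for chunk in chunks:
    print(chunk)
```

Using `zip` over the three inner lists is equivalent to the index-based loop in the exercise and avoids repeating `results[...][0][i]` for each field.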

You've come a long way! Now, let's put your skills to the test by implementing a custom tool for RAG retrieval from scratch. This tool will enable an agent to fetch the most relevant knowledge chunks based on a user query.

You will need to:

- Implement the retrieval function that queries the knowledge base for relevant chunks.
- Define a Pydantic model for input validation.
- Implement the function that validates input and calls the retrieval logic.
- Create a `FunctionTool` to wrap the retrieval logic.
- Integrate the tool with an agent so it can use the tool to answer user questions.

```python
# Final exercise
from typing import Any
from pydantic import BaseModel
from agents import Agent, Runner, FunctionTool
from rag_builder import load_collection

def retrieve_top_chunks(user_query, top_k=1):
    # TODO: Load the collection using the load_collection function
    collection = load_collection()
    results = collection.query(
        query_texts=[user_query],
        n_results=top_k
    )

    retrieved_chunks = []
    for i in range(len(results['documents'][0])):
        retrieved_chunks.append({
            "chunk": results['documents'][0][i],
            "id": results['metadatas'][0][i]['id'],
            "distance": results['distances'][0][i]
        })
    return retrieved_chunks

# TODO: Define a Pydantic model for input validation with user query as a string
class FunctionArgs(BaseModel):
    user_query: str

# TODO: Implement the run_function to use the retrieval function and return results
# - This function should accept a context and args, parse the args, and return the retrieved chunks as a string.
async def run_function(ctx, args: str) -> str:
    parsed = FunctionArgs.model_validate_json(args)
    chunks = retrieve_top_chunks(parsed.user_query)
    return "\n".join([c["chunk"] for c in chunks])
# Output (for a relevant query):
# Review different types of SQL joins — especially LEFT and FULL OUTER joins.

# TODO: Create a FunctionTool to wrap the retrieval logic
# - The tool should have a name, description, params_json_schema, and on_invoke_tool function.
# - The params_json_schema should define the expected input structure:
#   - with type "object", properties "user_query" of type "string", and required "user_query"
# - The on_invoke_tool should call the run function with the provided context and args.
rag_retrieval_tool = FunctionTool(
    name="rag_retrieval_tool",
    description="A tool to retrieve context from RAG based on user query",
    params_json_schema={
        "type": "object",
        "properties": {
            "user_query": {"type": "string"},
        },
        "required": ["user_query"],
        "additionalProperties": False
    },
    on_invoke_tool=run_function
)

# TODO: Create an Agent and integrate the FunctionTool
# - The agent should have a name, e.g. "Learning Assistant",
# - instructions that describe its role, e.g. "You are a personal learning assistant with access to rag tool. Whenever asked a question about learning plans, use the RAG retrieval tool to get context from the DB and answer user questions.",
# - and the tools list containing the rag_retrieval_tool.
AGENT = Agent(
    name="Learning Assistant",
    instructions=(
        "You are a personal learning assistant with access to rag tool. "
        "Whenever asked a question about learning plans, use the RAG retrieval tool to get context from the DB and answer user questions."
    ),
    tools=[rag_retrieval_tool]
)

def ask_agent(prompt):
    result = Runner.run_sync(AGENT, prompt)
    return result.final_output
```