In [None]:
from pydantic_ai import Agent
from utils import AgentConfig, NamedCallback

In [None]:

# System prompt for the orchestrator agent. The model only sees the names of
# the tools registered with `@orchestrator.tool` (rewrite_user_query,
# vector_search, search_web, web_page_content). The rules below therefore
# refer to those tool names rather than to the sub-agents behind them.
orchestrator_instructions = """
You are the orchestrator agent.
Your task is to answer the user's question by delegating work to the right tool:

- rewrite_user_query: uses the clarifier_agent to rewrite the user query in three distinct ways
- vector_search: uses the search_agent to answer questions from the Huberman vector store
- search_web: uses the Brave API to search for relevant information based on the user's interest, restricted to a list of selected website domains, and has the websearch_agent summarize it
- web_page_content: fetches the Markdown content of a web page and has the websearch_agent summarize it

RULES:
- always call rewrite_user_query first
- always call vector_search and check its results before calling search_web or web_page_content
- only call search_web or web_page_content if the user asks for more information outside of what is contained in the vector store or explicitly requests a web search
"""

def create_orchestrator(config: AgentConfig = None):
    """Create the top-level orchestrator agent.

    Args:
        config: Agent settings. A default ``AgentConfig()`` is built when omitted.

    Returns:
        An ``Agent`` named "orchestrator" that uses ``orchestrator_instructions``
        and the model from the config.
    """
    settings = AgentConfig() if config is None else config
    return Agent(
        name="orchestrator",
        instructions=orchestrator_instructions,
        model=settings.model,
    )


# Module-level instance: the @orchestrator.tool decorators in later cells
# register their tools on this object.
orchestrator = create_orchestrator()
orchestrator_callback = NamedCallback(orchestrator)


In [None]:
from typing import List
from pydantic import BaseModel
from pydantic_ai import RunContext


class RewriteResponse(BaseModel):
    # Structured output of the clarifier agent (its output_type): alternative
    # phrasings of the user's query. The clarifier prompt asks for three.
    rewrites: list[str]

# Instructions for the clarifier agent. create_clarifier_agent passes them to
# the Agent, and the rewrite_user_query tool runs that agent. The text is sent
# verbatim to the model, so any edit here changes the agent's behavior.
clarifier_instructions = """ 

You assist the search_agent and websearch_agent.
You take a user's query and rewrite it 3 distinct ways using different phrasing, key terms, related subquestions.

"""

def create_clarifier_agent(config: AgentConfig = None) -> Agent:
    """Build the query-rewriting (clarifier) agent.

    Args:
        config: Agent settings. ``AgentConfig()`` is used when omitted.

    Returns:
        An ``Agent`` named "clarifier_agent" driven by ``clarifier_instructions``.
        Its structured output is a ``RewriteResponse``.
    """
    settings = config if config is not None else AgentConfig()
    return Agent(
        name="clarifier_agent",
        instructions=clarifier_instructions,
        model=settings.model,
        output_type=RewriteResponse,
    )


clarifier_agent = create_clarifier_agent()

@orchestrator.tool
async def rewrite_user_query(ctx: RunContext, query: str) -> RewriteResponse:
    """Use the clarifier agent to produce three rewritten queries.

    Args:
        query: The user's original question, to be rephrased.
    """
    # ctx is unused here, but @agent.tool functions take RunContext as their
    # first argument.
    callback = NamedCallback(clarifier_agent)
    run_result = await clarifier_agent.run(user_prompt=query, event_stream_handler=callback)
    # clarifier_agent's output_type is RewriteResponse, so this is a
    # RewriteResponse instance rather than a plain str.
    return run_result.output


In [None]:

from utils import AgentConfig
from search_agent import create_search_agent, SearchResultResponse

# Sub-agent run by the vector_search tool below. Per the orchestrator
# instructions, it answers questions from the Huberman vector store.
search_agent = create_search_agent()

@orchestrator.tool
async def vector_search(ctx: RunContext, query: str, config: AgentConfig = None):
    """Answer the query from the Huberman vector store via the search agent; use before any web search.

    Args:
        query: The question to answer from the vector store.
        config: Unused; leave unset.
    """
    # Gather what rewrite_user_query returned earlier in this run, so the
    # search agent sees the clarified phrasings alongside the raw query.
    prior_outputs = []
    for message in ctx.messages:
        for part in message.parts:
            if part.part_kind == "tool-return" and part.tool_name == "rewrite_user_query":
                prior_outputs.append(part.content)

    prior_text = "\n".join(str(x) for x in prior_outputs)

    prompt = f"""
    User query:
    {query}

    Prior clarification:
    {prior_text}
    """.strip()

    callback = NamedCallback(search_agent)

    results = await search_agent.run(
        user_prompt=prompt,
        event_stream_handler=callback,
        output_type=SearchResultResponse,
    )
    return results.output.format_response()


In [None]:
from websearch_agent import websearch_instructions, ResearchReport
import requests
import random
import os


def create_websearch_agent(config: AgentConfig = None):
    """Construct the web-research sub-agent.

    Args:
        config: Agent settings. Falls back to ``AgentConfig()`` when omitted.

    Returns:
        An ``Agent`` named "websearch_agent" using ``websearch_instructions``.
        No output_type is set here; callers may pass one to ``run``.
    """
    model = (config if config is not None else AgentConfig()).model
    return Agent(
        name="websearch_agent",
        instructions=websearch_instructions,
        model=model,
    )


websearch_agent = create_websearch_agent()


# Site names used to filter Brave results. A URL qualifies when one of the
# dot-separated labels of its hostname is in this set. For example,
# "www.nimh.nih.gov" matches via "nimh"/"nih" and "news.mit.edu" via "mit".
# Matching whole labels rather than substrings keeps short names such as
# "mit", "apa" or "nih" from matching unrelated hosts like "summit.com".
_PREFERRED_SITES = frozenset({
    "brainfacts",
    "nimh",
    "nih",
    "alleninstitute",
    "mit",
    "stanford",
    "acsm",
    "nsca",
    "acefitness",
    "exerciseismedicine",
    "bjsm",
    "apa",
    "motivationscience",
    "berkeley",
    "mayoclinic",
    "clevelandclinic",
    "harvard",
    "hopkinsmedicine",
    "cdc",
    "mpg",
    "yale",
    "scientificamerican",
    "psychologytoday",
    "nature",
    "science",
})

# Maximum number of filtered result URLs handed to the websearch agent.
_MAX_WEB_RESULTS = 5


def _is_preferred_url(url: str) -> bool:
    """Return True if a hostname label of *url* is in _PREFERRED_SITES."""
    from urllib.parse import urlparse

    try:
        host = urlparse(url).hostname or ""  # hostname is already lower-cased
    except ValueError:  # e.g. a malformed IPv6 netloc
        return False
    return any(label in _PREFERRED_SITES for label in host.split("."))


@orchestrator.tool
async def search_web(ctx: RunContext, query: str):
    """Search the web with the Brave API (restricted to preferred sites) and have the web agent summarize the results.

    Args:
        query: The search query to send to Brave.
    """
    selected_urls = []
    try:
        response = requests.get(
            "https://api.search.brave.com/res/v1/web/search",
            # Let requests URL-encode the query (spaces, "&", "#", ...).
            params={"q": query},
            headers={
                "Accept": "application/json",
                "X-Subscription-Token": os.getenv("BRAVE_API_KEY"),
            },
            timeout=45,
        )
        response.raise_for_status()
        results = response.json().get("web", {}).get("results", [])
        candidate_urls = [item.get("url") for item in results if item.get("url")]
        preferred_urls = [u for u in candidate_urls if _is_preferred_url(u)]
        selected_urls = random.sample(
            preferred_urls, min(_MAX_WEB_RESULTS, len(preferred_urls))
        )
    except (requests.exceptions.RequestException, ValueError) as e:
        # Best effort: continue without URLs instead of failing the whole run.
        # ValueError also covers JSON and Unicode decoding errors.
        print(f" Error fetching content for {query}: {e}")

    # Gather rewrite_user_query results from earlier in this run as context.
    prior_outputs = []
    for message in ctx.messages:
        for part in message.parts:
            if part.part_kind == "tool-return" and part.tool_name == "rewrite_user_query":
                prior_outputs.append(part.content)

    prior_text = "\n".join(str(x) for x in prior_outputs)
    url_text = "\n".join(selected_urls) if selected_urls else "(no results from preferred sites)"

    prompt = f"""
    User query:
    {query}

    Prior clarification:
    {prior_text}

    Search results from preferred sites:
    {url_text}
    """.strip()

    callback = NamedCallback(websearch_agent)

    run_result = await websearch_agent.run(
        user_prompt=prompt,
        event_stream_handler=callback,
    )
    # Return the agent's answer, like the other tools, not the run-result object.
    return run_result.output



In [None]:

@orchestrator.tool
async def web_page_content(ctx: RunContext, url: str, query: str, config: AgentConfig = None):
    """Fetch a web page as Markdown through the reader proxy and have the web agent summarize it for the query.

    Args:
        url: Address of the web page to read.
        query: The user's question the summary should address.
        config: Unused; leave unset.
    """
    reader_url_prefix = "https://r.jina.ai/"
    reader_url = reader_url_prefix + url

    try:
        response = requests.get(reader_url, timeout=45)
        response.raise_for_status()  # raises for 4xx/5xx HTTP errors
        content = response.content.decode("utf-8")
    except (requests.exceptions.RequestException, UnicodeDecodeError) as e:
        print(f"Error fetching content from {url}: {e}")
        # Without page content there is nothing to summarize, so report the
        # failure to the orchestrator instead of running the agent blind.
        return f"Error fetching content from {url}: {e}"

    # Include earlier search_web results from this run as extra context.
    prior_outputs = []
    for message in ctx.messages:
        for part in message.parts:
            if part.part_kind == "tool-return" and part.tool_name == "search_web":
                prior_outputs.append(part.content)

    prior_text = "\n".join(str(x) for x in prior_outputs)

    prompt = f"""
    User query:
    {query}

    Prior web search results:
    {prior_text}

    Content of {url}:
    {content}
    """.strip()

    callback = NamedCallback(websearch_agent)

    results = await websearch_agent.run(
        user_prompt=prompt,
        event_stream_handler=callback,
        output_type=ResearchReport,
    )
    return results.output.format_response()

In [None]:
# Example run: ask the orchestrator one question, starting from an empty history.
# Top-level `await` works in a Jupyter/IPython cell, not in a plain script.
message_history = []
question = "alzheimer's and coffee"
orchestrator_results = await orchestrator.run(
    user_prompt=question,
    message_history=message_history,
    event_stream_handler=orchestrator_callback,
)

In [None]:
# Append this run's new messages to message_history (presumably so a follow-up
# run can pass it back in) and print them.
messages = orchestrator_results.new_messages()
message_history.extend(messages)
# NOTE(review): print_messages is not imported or defined anywhere in the visible
# source. Confirm it comes from a helper module (e.g. utils); otherwise this line
# raises NameError.
print_messages(messages)