## Using Multiple Specialized SubAgents

### 1. Configure the LLM and Load the Toolbox Tools

In [1]:
import os
import json
import asyncio
from typing import List, Dict, Optional
from pydantic import BaseModel, ConfigDict

from deepagents import CompiledSubAgent, create_deep_agent
from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langgraph.graph import StateGraph, END
from toolbox_langchain import ToolboxClient

# OpenRouter exposes an OpenAI-compatible API, so ChatOpenAI is pointed at its
# base URL. The key is required explicitly: with api_key=None the underlying
# OpenAI client falls back to the OPENAI_API_KEY environment variable, which
# would send that (unrelated) credential to openrouter.ai.
_openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
if not _openrouter_api_key:
    raise RuntimeError(
        "OPENROUTER_API_KEY is not set; export it before running this notebook."
    )

llm = ChatOpenAI(
    api_key=_openrouter_api_key,
    base_url="https://openrouter.ai/api/v1",
    model="amazon/nova-2-lite-v1:free",
    temperature=0,
)

# Module-level ToolboxClient, created lazily on first request and then handed
# out to every caller so the pipeline shares a single client instance.
_toolbox_client: Optional[ToolboxClient] = None

async def get_toolbox_client():
    """Return the shared ToolboxClient, constructing it on first use.

    The client targets the toolbox server at http://127.0.0.1:5000.
    """
    global _toolbox_client
    if _toolbox_client is not None:
        return _toolbox_client
    _toolbox_client = ToolboxClient("http://127.0.0.1:5000")
    return _toolbox_client

# Cache of (client, {tool_name: tool}) so the toolset is fetched from the
# toolbox server once per client rather than on every node invocation.
_tools_cache: Optional[tuple] = None

async def load_tools():
    """Return the toolbox tools as a ``{name: tool}`` dict, loading them once.

    The toolset is fetched with ``aload_toolset()`` on the first call and
    reused afterwards. The cache is tied to the client object it came from:
    if the shared client has been replaced (for example recreated after it
    was reset to None), the tools are reloaded from the new client.

    Returns:
        dict mapping each tool's ``name`` to the tool object. The same dict
        is returned on every cached call, so treat it as read-only.
    """
    global _tools_cache
    client = await get_toolbox_client()
    if _tools_cache is None or _tools_cache[0] is not client:
        tools = await client.aload_toolset()
        _tools_cache = (client, {t.name: t for t in tools})
    return _tools_cache[1]

### 2. Pipeline State

In [2]:
class PipelineState(BaseModel):
    """State shared by the resale, MRT and summary graph nodes.

    Fields:
        messages: conversation history; summary_node appends its AIMessage.
        flats: flat records returned by the ``list-hdb-flats`` tool.
        enriched_flats: flat records with ``nearest_mrt`` and ``dist_m`` added.
    """

    # Permit field types that pydantic does not know how to validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Pydantic copies mutable defaults per instance, so these lists are not shared.
    messages: List[BaseMessage] = []
    flats: List[Dict] = []
    enriched_flats: List[Dict] = []

### 3. Resale and MRT SubAgent

In [3]:
async def resale_node(state: PipelineState):
    """Fetch resale flats from the database via the ``list-hdb-flats`` tool.

    The search criteria are currently fixed (town "TOA PAYOH", flat_type
    "4 ROOM", max_price 500000). The tool's response is accepted either as a
    list or as a JSON string encoding a list; anything else is logged and
    treated as an empty result.

    Returns:
        A new PipelineState carrying the existing messages, the fetched
        ``flats`` and an empty ``enriched_flats`` list.
    """
    tools = await load_tools()
    sql_tool = tools["list-hdb-flats"]

    print("[RESALE] Fetching 4-room flats in Toa Payoh under 500k...")

    flats = await sql_tool.ainvoke({
        "town": "TOA PAYOH",
        "max_price": 500000,
        "flat_type": "4 ROOM",
    })

    if isinstance(flats, str):
        try:
            flats = json.loads(flats)
        except json.JSONDecodeError as e:
            # Surface malformed tool output instead of silently reporting 0 flats.
            print(f"[RESALE] Could not parse tool response as JSON: {e}")
            flats = []

    if not isinstance(flats, list):
        print(f"[RESALE] Unexpected response type {type(flats).__name__}; ignoring it")
        flats = []

    print(f"[RESALE] Found {len(flats)} flats")

    return PipelineState(
        messages=state.messages,
        flats=flats,
        enriched_flats=[],
    )

def _coord_key(flat):
    """Return ``(lat, lon)`` for a flat record, or None if either is missing."""
    lat, lon = flat.get("lat"), flat.get("lon")
    if lat is None or lon is None:
        return None
    return (lat, lon)


def _parse_nearest_mrt(res):
    """Extract ``{"nearest_mrt", "dist_m"}`` from a geospatial-query response.

    *res* may be a JSON string or an already-decoded list; only its first
    element is used. Returns None when no usable dict result is present.
    """
    if isinstance(res, str):
        try:
            res = json.loads(res)
        except json.JSONDecodeError:
            res = []
    if not (isinstance(res, list) and res and isinstance(res[0], dict)):
        return None
    first = res[0]
    dist = first.get("dist_m")
    if dist is None:
        # Explicit None check: a distance of 0 is valid and must not be
        # replaced by the fallback key (the old `or` chain dropped it).
        dist = first.get("distance")
    return {
        "nearest_mrt": first.get("label") or first.get("station_name"),
        "dist_m": dist,
    }


async def _lookup_nearest_mrt(geo_tool, lat, lon):
    """Query the nearest MRT for one coordinate; errors are logged, not raised."""
    try:
        res = await geo_tool.ainvoke({
            "mode": "nearest_mrt",
            "lat": lat,
            "lon": lon,
            "radius": 800,
        })
        print(f"[MRT] Response for ({lat}, {lon}): {res}")
        parsed = _parse_nearest_mrt(res)
    except Exception as e:
        # Best-effort per coordinate: one failing lookup must not abort the rest.
        print(f"[MRT] Error at ({lat}, {lon}): {e}")
        import traceback
        traceback.print_exc()
        return {"nearest_mrt": None, "dist_m": None}

    if parsed is None:
        print(f"[MRT] No results for ({lat}, {lon})")
        return {"nearest_mrt": None, "dist_m": None}

    # NOTE(review): the dist_m values printed in this notebook's output
    # (~0.002-0.011) look more like degrees than meters; confirm the tool's
    # distance units and what unit `radius` is interpreted in.
    print(f"[MRT] Extracted: {parsed['nearest_mrt']} at {parsed['dist_m']}m")
    return parsed


async def mrt_node(state: PipelineState):
    """Enrich flats with their nearest MRT station, one tool call per location.

    Flats sharing the same ``(lat, lon)`` are looked up once via the
    ``geospatial-query`` tool. Each dict flat is copied and given
    ``nearest_mrt`` and ``dist_m`` keys (None when the flat has no
    coordinates or the lookup found nothing); non-dict entries are dropped
    from the enriched list. The records in ``state.flats`` are not modified.

    Returns:
        ``state`` unchanged when there are no flats; otherwise a new
        PipelineState with ``enriched_flats`` populated.
    """
    flats = state.flats or []

    if not flats:
        print("[MRT] No flats to enrich")
        return state

    tools = await load_tools()
    geo_tool = tools["geospatial-query"]

    print(f"[MRT] Enriching {len(flats)} flats with MRT data...")

    # Unique coordinates in first-seen order (dict preserves insertion order).
    keys = (_coord_key(f) for f in flats if isinstance(f, dict))
    unique_coords = list(dict.fromkeys(k for k in keys if k is not None))

    print(f"[MRT] Found {len(unique_coords)} unique coordinates")

    coord_results = {}
    for lat, lon in unique_coords:
        coord_results[(lat, lon)] = await _lookup_nearest_mrt(geo_tool, lat, lon)

    # Build new dicts rather than mutating the records held in state.flats.
    # Flats without coordinates have key None, which is never in coord_results.
    enriched = []
    for f in flats:
        if not isinstance(f, dict):
            continue
        mrt_data = coord_results.get(_coord_key(f), {})
        enriched.append({
            **f,
            "nearest_mrt": mrt_data.get("nearest_mrt"),
            "dist_m": mrt_data.get("dist_m"),
        })

    print(f"[MRT] Enrichment complete - sample flat: {enriched[0] if enriched else 'none'}")

    return PipelineState(
        messages=state.messages,
        flats=flats,
        enriched_flats=enriched,
    )

# Build resale-mrt subagent: resale -> mrt -> END.
resale_graph = StateGraph(PipelineState)
resale_graph.add_node("resale", resale_node)
resale_graph.add_node("mrt", mrt_node)
resale_graph.set_entry_point("resale")
resale_graph.add_edge("resale", "mrt")
resale_graph.add_edge("mrt", END)

compiled_resale = resale_graph.compile()

# The description is read by the LLM, so it must match what the graph really
# does: resale_node uses fixed search criteria and accepts no parameters.
resale_mrt_subagent = CompiledSubAgent(
    name="resale-mrt-subagent",
    description=(
        "Handles HDB resale flat retrieval + geospatial MRT enrichment. "
        "Uses fixed search criteria (4-room flats in Toa Payoh, max price 500k). "
        "Returns list of flats with nearest_mrt and dist_m."
    ),
    runnable=compiled_resale
)

### 4. Summary SubAgent

In [4]:
async def summary_node(state: PipelineState):
    """Produce a user-facing summary of ``state.enriched_flats`` via the LLM.

    Only the first 10 enriched records are serialized into the prompt. When
    there are no records, a fixed "no flats" message is used and the LLM is
    not called. The summary is appended to ``messages`` as an AIMessage.
    """
    enriched = state.enriched_flats or []

    print(f"[SUMMARY] Summarizing {len(enriched)} flats...")

    if enriched:
        payload = json.dumps(enriched[:10], indent=2)
        prompt = f"""You are a helpful real estate assistant. Summarize these HDB flats clearly.

For each flat, include:
- Block and street
- Type, price, and nearest MRT with distance

Data:
{payload}

Keep it concise and highlight the best options."""
        reply = await llm.ainvoke(prompt)
        summary_text = reply.content
    else:
        summary_text = "No flats found matching your criteria."

    print("[SUMMARY] Done")

    return PipelineState(
        messages=[*state.messages, AIMessage(content=summary_text)],
        flats=state.flats,
        enriched_flats=enriched,
    )

# Summary subagent: a single-node graph (summarize -> END).
summary_graph = StateGraph(PipelineState)
summary_graph.add_node("summarize", summary_node)
summary_graph.set_entry_point("summarize")
summary_graph.add_edge("summarize", END)
compiled_summary = summary_graph.compile()

# Exposes the compiled summary graph under the name "summary-writer".
summary_subagent = CompiledSubAgent(
    name="summary-writer",
    description="Summarizes enriched flat data. Takes enriched_flats list and produces a user-friendly summary.",
    runnable=compiled_summary,
)

### 5. Orchestrator SubAgent

In [5]:
def _state_as_input(state: PipelineState) -> dict:
    """Return a shallow ``{field: value}`` dict of *state* for graph input.

    ``model_dump()`` would recursively serialize nested pydantic models,
    including the message objects, into plain dicts. On re-validation those
    would be rebuilt as plain BaseMessage objects instead of their original
    HumanMessage/AIMessage classes. Passing the field values through keeps
    the original objects.
    """
    return {name: getattr(state, name) for name in PipelineState.model_fields}


async def orchestrator_node(state: PipelineState):
    """Run the resale-MRT graph, then the summary graph, passing state along.

    The state produced by ``compiled_resale`` (with ``enriched_flats``
    populated) is fed into ``compiled_summary``. The summary graph's final
    state is returned.
    """
    # Step 1: fetch and enrich flats.
    print("\n[ORCHESTRATOR] Calling resale-mrt-subagent...")
    resale_result_dict = await compiled_resale.ainvoke(_state_as_input(state))

    # The compiled graph returns a dict of channel values; rebuild the model.
    resale_result = PipelineState(**resale_result_dict)
    print(f"[ORCHESTRATOR] Got enriched_flats: {len(resale_result.enriched_flats)} flats")

    # Step 2: summarize the enriched flats.
    print("[ORCHESTRATOR] Calling summary-writer...")
    summary_result_dict = await compiled_summary.ainvoke(_state_as_input(resale_result))

    summary_result = PipelineState(**summary_result_dict)
    print("[ORCHESTRATOR] Summary complete")

    return summary_result

# Orchestrator: a one-node graph wrapping orchestrator_node, so the whole
# resale -> MRT -> summary workflow can be exposed as a single subagent.
orchestrator_graph = StateGraph(PipelineState)
orchestrator_graph.add_node("orchestrate", orchestrator_node)
orchestrator_graph.set_entry_point("orchestrate")
orchestrator_graph.add_edge("orchestrate", END)
compiled_orchestrator = orchestrator_graph.compile()

orchestrator_subagent = CompiledSubAgent(
    name="orchestrator",
    description=(
        "Manages the entire HDB flat search workflow. "
        "Calls resale-mrt-subagent then summary-writer."
    ),
    runnable=compiled_orchestrator,
)

### 6. Register the Orchestrator SubAgent in the Deep Agent

In [6]:
# Instructions for the top-level agent: it delegates every flat query to the
# orchestrator subagent, the only subagent registered below.
system_prompt = (
    "You are an HDB property finder. When users ask about flats, use the orchestrator subagent.\n"
    "\n"
    "The orchestrator will:\n"
    "1. Fetch 4-room flats in Toa Payoh under 500k\n"
    "2. Enrich them with MRT data (deduplicating API calls)\n"
    "3. Summarize and present them to the user\n"
    "\n"
    "Just pass the user's request to the orchestrator subagent."
)

deep_agent = create_deep_agent(
    model=llm,
    system_prompt=system_prompt,
    subagents=[orchestrator_subagent],
)

### 7. Cleanup

In [7]:
# Strong references to in-flight async close() tasks; without them a task
# scheduled from cleanup() could be garbage-collected before it finishes.
_pending_close_tasks: set = set()


def _report_close_result(task) -> None:
    """Done-callback for an async close() task: drop the reference, log errors."""
    _pending_close_tasks.discard(task)
    if not task.cancelled() and task.exception() is not None:
        print(f"Warning: Error closing client: {task.exception()}")


def cleanup():
    """Close the shared toolbox client, if any, and reset it to None.

    Handles a ``close`` that is either a plain method or a coroutine
    function. A returned coroutine is run with ``asyncio.run`` when no event
    loop is running. Otherwise, for example when called from ``main``'s
    ``finally``, it is scheduled as a task on the running loop. Errors are
    printed as warnings rather than raised. The global is cleared in every case.
    """
    import inspect

    global _toolbox_client
    client, _toolbox_client = _toolbox_client, None
    if client is None:
        return

    close = getattr(client, "close", None)
    if close is None:
        return

    try:
        outcome = close()
        # An async close() only returns a coroutine; it must be awaited to run.
        if inspect.iscoroutine(outcome):
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                asyncio.run(outcome)
            else:
                task = loop.create_task(outcome)
                _pending_close_tasks.add(task)
                task.add_done_callback(_report_close_result)
    except Exception as e:
        print(f"Warning: Error closing client: {e}")

### 8. Run

In [8]:
async def main():
    """Run with deep agent."""
    try:
        print("\n" + "="*60)
        print("HDB FLAT FINDER (with Deep Agent)")
        print("="*60 + "\n")
        
        result = await deep_agent.ainvoke({
            "messages": [
                HumanMessage(
                    content="Show me 4-room flats in Toa Payoh under 500k, near MRT"
                )
            ]
        })
        
        print("\n" + "="*60)
        print("FINAL OUTPUT:")
        print("="*60 + "\n")
        
        # Extract final message
        if "messages" in result:
            for msg in result["messages"]:
                if isinstance(msg, AIMessage) and msg.content and not msg.content.startswith("I'm currently"):
                    print(msg.content)
                    break
        
    finally:
        cleanup()
        
result = await main()



HDB FLAT FINDER (with Deep Agent)


[ORCHESTRATOR] Calling resale-mrt-subagent...
[RESALE] Fetching 4-room flats in Toa Payoh under 500k...
[RESALE] Found 30 flats
[MRT] Enriching 30 flats with MRT data...
[MRT] Found 13 unique coordinates
[MRT] Response for (1.33914880025231, 103.858779661838): [{"dist_m":0.011440705003769033,"info":"Exit: Exit C","label":"BRADDELL MRT STATION","metric":null,"qtype":"nearest_mrt"}]
[MRT] Extracted: BRADDELL MRT STATION at 0.011440705003769033m
[MRT] Response for (1.34152911615925, 103.849589127712): [{"dist_m":0.002312908908895664,"info":"Exit: Exit C","label":"BRADDELL MRT STATION","metric":null,"qtype":"nearest_mrt"}]
[MRT] Extracted: BRADDELL MRT STATION at 0.002312908908895664m
[MRT] Response for (1.34176590225727, 103.856041153786): [{"dist_m":0.008669591833450473,"info":"Exit: Exit C","label":"BRADDELL MRT STATION","metric":null,"qtype":"nearest_mrt"}]
[MRT] Extracted: BRADDELL MRT STATION at 0.008669591833450473m
[MRT] Response for (1.34167292