From ff07afd944597336d09f207ccd6e9171e3613f75 Mon Sep 17 00:00:00 2001 From: Giselle van Dongen Date: Fri, 23 May 2025 12:56:03 +0200 Subject: [PATCH 1/5] Improve error handling of agent session and middleware --- common/a2a_middleware.py | 2 + common/agent_session.py | 88 ++++++++++++++-------------------------- uv.lock | 12 ------ 3 files changed, 32 insertions(+), 70 deletions(-) diff --git a/common/a2a_middleware.py b/common/a2a_middleware.py index 0656671..2ddf213 100644 --- a/common/a2a_middleware.py +++ b/common/a2a_middleware.py @@ -314,6 +314,8 @@ async def process_request( try: json_rpc_request = A2ARequest.validate_python(req.model_dump()) except Exception as e: + if isinstance(e, restate.vm.SuspendedException): + raise e logger.error("Error validating request: %s", e) return JSONRPCResponse( id=req.id, diff --git a/common/agent_session.py b/common/agent_session.py index 0b1914e..1367e60 100644 --- a/common/agent_session.py +++ b/common/agent_session.py @@ -34,7 +34,6 @@ from pydantic import BaseModel, ConfigDict, Field from restate import TerminalError from restate.handler import handler_from_callable -from restate.serde import PydanticJsonSerde from typing_extensions import Generic from .models import ( @@ -47,7 +46,6 @@ SendTaskResponse, JSONRPCRequest, A2AClientHTTPError, - A2AClientJSONError, Message, TextPart, TaskState, @@ -232,7 +230,6 @@ class AgentInput(BaseModel): starting_agent: Agent agents: list[Agent] message: str - force_starting_agent: bool = False class AgentResponse(BaseModel): @@ -348,12 +345,8 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons logging.info(f"{logging_prefix} Starting agent session") session_state = SessionState(input_items=await ctx.get("agent_state")) session_state.add_user_message(ctx, req.message) - - if req.force_starting_agent: - # We ignore the current agent, and use the starting agent in the message - agent_name = req.starting_agent.formatted_name - else: - agent_name = await ctx.get("agent_name") or req.starting_agent.formatted_name + + agent_name = await ctx.get("agent_name") or req.starting_agent.formatted_name ctx.set("agent_name", agent_name) logging.info(f"{logging_prefix} Current agent is {agent_name}") @@ -405,7 +398,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons parallel_tool_calls=True, stream=False, ), - serde=PydanticJsonSerde(Response), + type_hint=Response, ) # Register the output in the session state @@ -419,13 +412,8 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons agents_dict, response.output, tools ) except Exception as e: - logger.info( - f"""{logging_prefix} Output of LLM response parsing: - Tool calls: {tool_calls} - Run handoffs: {run_handoffs} - Output messages: {output_messages} - """ - ) + if isinstance(e, restate.vm.SuspendedException): + raise e logger.warning(f"{logging_prefix} Failed to parse LLM response: {str(e)}") session_state.add_system_message( ctx, f"Failed to parse LLM response: {str(e)}" @@ -437,7 +425,6 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons parallel_tools = [] for tool_call in tool_calls: logger.info(f"{logging_prefix} Executing tool {tool_call.name}") - # session_state.add_system_message(ctx, f"Executing tool {tool_call.name}.") try: if tool_call.delay_in_millis is None: handle = ctx.generic_call( @@ -460,17 +447,17 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons ctx, f"Task {tool_call.name} was scheduled" ) 
except Exception as e: - if not isinstance(e, restate.vm.SuspendedException): - logger.warning( - f"{logging_prefix} Failed to execute tool {tool_call.name}: {str(e)}" - ) - session_state.add_system_message( - ctx, - f"Failed to execute tool {tool_call.name}: {str(e)}", - ) - # We add it to the session_state to feed it back into the next LLM call - else: + if isinstance(e, restate.vm.SuspendedException): raise e + # We add it to the session_state to feed it back into the next LLM call + logger.warning( + f"{logging_prefix} Failed to execute tool {tool_call.name}: {str(e)}" + ) + session_state.add_system_message( + ctx, + f"Failed to execute tool {tool_call.name}: {str(e)}", + ) + if len(parallel_tools) > 0: results_done = await restate.gather(*parallel_tools) results = [(await result).decode() for result in results_done] @@ -513,7 +500,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons try: remote_agent_to_call = agents_dict.get(handoff_command.name) if remote_agent_to_call is None: - raise ValueError( + raise TerminalError( f"Agent {handoff_command.name} not found in the list of agents." ) remote_agent_output = await call_remote_agent( @@ -525,21 +512,15 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons ctx, f"Agent response: {remote_agent_output}", ) - except Exception as e: - # if isinstance(e, restate.vm.SuspendedException) or isinstance(e, httpx.ReadTimeout): - # Surface suspensions - # And surface read timeouts --> leads to a retry with idempotency key (=attach) - raise e - # TODO Are there some errors that we don't want to retry but feed back into the model? Think about the split here - # else: - # logger.warning( - # f"{logging_prefix} Failed to call remote agent {handoff_command.model_dump_json()}: {str(e)}" - # ) - # session_state.add_system_message( - # ctx, - # f"Failed to call remote agent {handoff_command.model_dump_json()}: {str(e)}", - # ) - # We add it to the session_state to feed it back into the next LLM call + except TerminalError as e: + logger.warning( + f"{logging_prefix} Failed to call remote agent {handoff_command.model_dump_json()}: {str(e)}" + ) + session_state.add_system_message( + ctx, + f"Failed to call remote agent {handoff_command.model_dump_json()}: {str(e)}", + ) + # We add it to the session_state to feed it back into the next LLM call else: # Start a new agent loop with the new agent agent = next_agent @@ -607,7 +588,6 @@ async def parse_llm_response( # UTILS - def format_name(name: str) -> str: return name.replace(" ", "_").lower() @@ -699,13 +679,11 @@ async def call_remote_agent( logger.info( f"Sending request to {card.name} at {card.url} with request payload: {request.model_dump()}" ) - response_json = await ctx.run("Call Agent", send_request, args=(card.url, request)) - response = SendTaskResponse(**response_json) + response = await ctx.run("Call Agent", send_request, args=(card.url, request)) logger.info( f"Received response from {card.name}: {response.result.model_dump_json()}" ) - # TODO Check with the protocol to see if it is expected behavior to find these fields for these states match response.result.status.state: case TaskState.INPUT_REQUIRED: final_output = f"MISSING_INFO: {response.result.status.message.parts}" @@ -722,13 +700,10 @@ async def call_remote_agent( case _: final_output = "Task status unknown" - # if task_callback: - # task_callback(final_output) - return final_output -async def send_request(url: str, request: JSONRPCRequest) -> dict[str, Any]: +async def 
send_request(url: str, request: JSONRPCRequest) -> SendTaskResponse: async with httpx.AsyncClient() as client: try: # Image generation could take time, increasing timeout @@ -736,11 +711,8 @@ async def send_request(url: str, request: JSONRPCRequest) -> dict[str, Any]: resp = await client.post( url, json=request.model_dump(), headers=headers, timeout=300 ) - # TODO handle httpx.ReadTimeout --> for long-running tasks what is the behavior we want here? Use push notifications? # Right now we are just catching these and letting it retry resp.raise_for_status() - return resp.json() - except httpx.HTTPStatusError as e: - raise A2AClientHTTPError(e.response.status_code, str(e)) from e - except json.JSONDecodeError as e: - raise A2AClientJSONError(str(e)) from e + return SendTaskResponse(**resp.json()) + except (json.JSONDecodeError, TypeError) as e: + raise TerminalError(str(e)) from e diff --git a/uv.lock b/uv.lock index 96cf003..0c7c015 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,6 @@ resolution-markers = [ members = [ "a2a", "agents", - "credit-workflows", "diy-patterns", "insurance-claims", "restate-ai-examples", @@ -241,17 +240,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, ] -[[package]] -name = "credit-workflows" -version = "0.1.0" -source = { editable = "end-to-end-applications/credit-workflows" } -dependencies = [ - { name = "restate-ai-examples" }, -] - -[package.metadata] -requires-dist = [{ name = "restate-ai-examples", editable = "." }] - [[package]] name = "cryptography" version = "45.0.2" From d75c371bb8f9b4c9f68863192c9abc6e56cd06e2 Mon Sep 17 00:00:00 2001 From: Giselle van Dongen Date: Mon, 26 May 2025 14:14:47 +0200 Subject: [PATCH 2/5] Improve error handling --- agents/native_restate/__main__.py | 6 +++++ agents/openai_sdk/__main__.py | 5 ++++ common/agent_session.py | 31 +++++++++--------------- pyproject.toml | 2 +- uv.lock | 40 +++++++++++++++---------------- 5 files changed, 43 insertions(+), 41 deletions(-) diff --git a/agents/native_restate/__main__.py b/agents/native_restate/__main__.py index 1ed86d2..1ddfbef 100644 --- a/agents/native_restate/__main__.py +++ b/agents/native_restate/__main__.py @@ -1,10 +1,16 @@ import hypercorn import asyncio import restate +import logging from agent import agent from account import account +logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] [%(process)d] [%(levelname)s] - %(message)s", +) + def main(): app = restate.app( diff --git a/agents/openai_sdk/__main__.py b/agents/openai_sdk/__main__.py index a9b2d2f..4728d92 100644 --- a/agents/openai_sdk/__main__.py +++ b/agents/openai_sdk/__main__.py @@ -1,9 +1,14 @@ import hypercorn import asyncio import restate +import logging from agent import agent +logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] [%(process)d] [%(levelname)s] - %(message)s", +) def main(): app = restate.app(services=[agent]) diff --git a/common/agent_session.py b/common/agent_session.py index 1367e60..16259d3 100644 --- a/common/agent_session.py +++ b/common/agent_session.py @@ -28,7 +28,6 @@ ResponseFunctionWebSearch, ResponseReasoningItem, ResponseComputerToolCall, - ResponseOutputText, ResponseOutputItem, ) from pydantic import BaseModel, ConfigDict, Field @@ -45,7 +44,6 @@ SendTaskRequest, SendTaskResponse, JSONRPCRequest, - A2AClientHTTPError, Message, 
TextPart, TaskState, @@ -350,7 +348,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons ctx.set("agent_name", agent_name) logging.info(f"{logging_prefix} Current agent is {agent_name}") - agents_dict = {agent.formatted_name: agent for agent in req.agents} + agents_dict = {ag.formatted_name: ag for ag in req.agents} agent = agents_dict.get(agent_name) if agent is None: @@ -369,12 +367,12 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons tools = {tool.formatted_name: tool for tool in agent.tools} tool_and_handoffs_list = [tool.tool_schema for tool in agent.tools] logger.info( - f"{logging_prefix} Starting iteration of agent loop with agent: {agent.name} and tools/handoffs: {[tool.formatted_name for tool in agent.tools]}" + f"{logging_prefix} Starting iteration of agent loop with agent: {agent.name} and tools/handoffs: {[t.formatted_name for t in agent.tools]}" ) for handoff_agent_name in agent.handoffs: handoff_agent = agents_dict.get(format_name(handoff_agent_name)) - # If the agent is not found, we ignore only use the other handoff agents. + # If the agent is not found, we ignore it if handoff_agent is None: logger.warning( f"Agent {handoff_agent_name} not found in the list of agents. Ignoring this agent. Available agents: {list(agents_dict.keys())}" @@ -446,9 +444,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons session_state.add_system_message( ctx, f"Task {tool_call.name} was scheduled" ) - except Exception as e: - if isinstance(e, restate.vm.SuspendedException): - raise e + except TerminalError as e: # We add it to the session_state to feed it back into the next LLM call logger.warning( f"{logging_prefix} Failed to execute tool {tool_call.name}: {str(e)}" @@ -532,18 +528,14 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons # Handle output messages # If there are no output messages, then we just continue the loop - if output_messages: - last_content = ( - output_messages[-1].content[-1] if output_messages[-1].content else None + final_output = response.output_text + if final_output != "": + logger.info(f"{logging_prefix} Final output message: {final_output}") + return AgentResponse( + agent=agent.name, + messages=session_state.get_new_items(), + final_output=final_output, ) - if isinstance(last_content, ResponseOutputText): - logger.info(f"{logging_prefix} Final output message: {last_content}") - return AgentResponse( - agent=agent.name, - messages=session_state.get_new_items(), - final_output=last_content.text, - ) - async def parse_llm_response( agents_dict: Dict[str, Agent], @@ -586,7 +578,6 @@ async def parse_llm_response( return output_messages, run_handoffs, tool_calls -# UTILS def format_name(name: str) -> str: return name.replace(" ", "_").lower() diff --git a/pyproject.toml b/pyproject.toml index 2b4a641..2cc3466 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.12" dependencies = [ "httpx>=0.28.1", "hypercorn", - "restate_sdk[serde]>=0.7.2", + "restate_sdk[serde]>=0.7.3", "pydantic>=2.10.6", "openai", ] diff --git a/uv.lock b/uv.lock index 0c7c015..8e5f8fd 100644 --- a/uv.lock +++ b/uv.lock @@ -1768,31 +1768,31 @@ requires-dist = [ { name = "hypercorn" }, { name = "openai" }, { name = "pydantic", specifier = ">=2.10.6" }, - { name = "restate-sdk", extras = ["serde"], specifier = ">=0.7.2" }, + { name = "restate-sdk", extras = ["serde"], specifier = ">=0.7.3" }, ] provides-extras = ["dev"] [[package]] 
name = "restate-sdk" -version = "0.7.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/cb/2b/66da94e16898a943b9500e6a9556bbcb41c9b635fc9ade0a3a95b5dbed8f/restate_sdk-0.7.2.tar.gz", hash = "sha256:82995f9d4c673c7a3799c2ef214b6966236fc5f6d2acd1f4840ae828cdd6df69", size = 59250 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7b/4b/4ff54e1fe2f88dd2c6802a464ed59735ea3f09ddf5c1df93ce2dfae98fae/restate_sdk-0.7.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:b5c262ceaf41cf98e649d48e9ddf5290c5ed870b4b159bb64e65bac65c26f680", size = 1793410 }, - { url = "https://files.pythonhosted.org/packages/1e/c7/b5b7445addfd26c3ffb423631a78d64617f4ee106f20b0df8e9450ab0eb6/restate_sdk-0.7.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a2bce3fe1b2c0d11184b8a7118bc83e0c35c93fda01078dc6dd5223f05dc6dbd", size = 1716616 }, - { url = "https://files.pythonhosted.org/packages/c2/eb/a460425261d8fa0c61160161252fb72e88591a59034557968ecdd7008c7c/restate_sdk-0.7.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65fe05134c52872ac80640b745d5a0f0988e3600dea002a6e81463f547cfa395", size = 1981138 }, - { url = "https://files.pythonhosted.org/packages/56/10/868068c19fb46bf41f54a9d3d95f485b9a538e649e0e9935f477a64da42e/restate_sdk-0.7.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:926e2b23b5c4b2dfe00e40407dfbaf43c3051cbf35fcd5cec6004969e400741b", size = 1942515 }, - { url = "https://files.pythonhosted.org/packages/f0/a2/80546b3833c8749baafbfd283efba5332ceadfc6e48d30bdb1a2c78c2ad8/restate_sdk-0.7.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f35395a224ee272bbfc855e8bf5eeb50f387a2178104d9130b4f94e6b6fa1f9e", size = 2129698 }, - { url = "https://files.pythonhosted.org/packages/87/90/af2742fc8b8d93a43b48d6dbd95c14a40d9626681e48cb0f2991d56e3322/restate_sdk-0.7.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:23acb3b61823693a86bcc6ff48b4ab83bd19ef727f5adf8cff88403ce7f2becc", size = 2151279 }, - { url = "https://files.pythonhosted.org/packages/e6/6d/e810a5b3dc55f8d2e32e9c20fcc7d692bacb3f3e205468936a06487b60d2/restate_sdk-0.7.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ecedf8b53c19ca6ade5618ee0f9a0b130266181c6bbc9ffc82fcae50a7fe0112", size = 1793611 }, - { url = "https://files.pythonhosted.org/packages/8c/c9/1267483feb761edc92ffe36c21b11c968181b9964ba46a5850f111ac0d72/restate_sdk-0.7.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3a3df3b6c57c2ec3d909848c9730ccbaec213618a5236d6c8cfbdaa9637714cc", size = 1716649 }, - { url = "https://files.pythonhosted.org/packages/f7/59/514a299824091807575a153b080bf336ca16d6bddc929a176ed222907d8f/restate_sdk-0.7.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d429ac4d2c9b61a16215d6bb9665699f312185c07aa76b93dd87cf3206c3951f", size = 1980940 }, - { url = "https://files.pythonhosted.org/packages/cb/97/e5f828b27f1b8104b08b6a61257cc0b430b0ccf0b8423be7fae668040138/restate_sdk-0.7.2-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:16022c0eda79018afbf9d6eb85c8a5df4bdc72f50dc9a6b2722b936dcbbc7fbc", size = 1942548 }, - { url = "https://files.pythonhosted.org/packages/ee/5d/b25f46a4bad0b22a9b2c9fe901e01144866df180566b03b2b4094a509f92/restate_sdk-0.7.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac392b6c14d778e7560b9c5fca54e09f4cf059b9ecf76375797f332408f93b16", size = 2129613 }, - { url = 
"https://files.pythonhosted.org/packages/6f/73/c55b5a3d5fc4c5c0ca8be14597b98b8dfd19bd49926c3b1e2ea46a6353a0/restate_sdk-0.7.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2eb82ae0e80820d6ee95748d94b39d64bab811106b3dcaf683d182227eb7e161", size = 2151387 }, - { url = "https://files.pythonhosted.org/packages/62/e6/75786eea31ca69df124dc6e2d92cce54bbfb733d8e0c16aa1dfd6d877617/restate_sdk-0.7.2-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:ce02ef38e05ab796896703f8672b5dc12b0db3ddfdccbd5d0b478fd32cec81f0", size = 1940035 }, - { url = "https://files.pythonhosted.org/packages/72/bc/fcba0d58f38ca27618816921b45dd56714028aeee9de952f5a2fe74c4d84/restate_sdk-0.7.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:ab57c51081c1462ccbba1b5a34526dd04aa31b60d1165685505ff0397fccd0d8", size = 2130234 }, - { url = "https://files.pythonhosted.org/packages/72/1e/4fcbb51b61ccf43d016c41f6f8fb45d61c3b45b051b67d91eb5e663924cc/restate_sdk-0.7.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d24f8887e7efb642a0c4167f7c91712f026b25d46f6c54451b9132e297dbc3b7", size = 2153920 }, +version = "0.7.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c7/12/17b40ffa24324436de45a07d9048383a5cb5587c37039c1001b61b88abce/restate_sdk-0.7.3.tar.gz", hash = "sha256:102a60285a37b4fee66b399ef7593b0645eb04fa28ee6874beaa834c6642bb73", size = 59513 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/72/cdb2cedcd959c96bf0bb8976b9153942ab688fb8cccb6083e4bb24a73a93/restate_sdk-0.7.3-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:6c87eea37caaeece474e67405013484c0cd280616ffdc68a888792408e4be30e", size = 1793791 }, + { url = "https://files.pythonhosted.org/packages/99/06/5d6f7bb5819d96894aceb9cbb9317ebc0814217d8eece0d40e6319674fea/restate_sdk-0.7.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4986a1de827436a7f2f2547dbc123ce881cff88ca6dfc628367da0821174ef04", size = 1716880 }, + { url = "https://files.pythonhosted.org/packages/60/7c/b01fc618f45402ea43c9f7cbf7431a145b28e7c7781d35d23beea382d9ac/restate_sdk-0.7.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffdf2fe01b60380c9771adfc42931923770f593123f2ad2bcb846b339812d76", size = 1981139 }, + { url = "https://files.pythonhosted.org/packages/32/47/1b19d82cbfa82063bcd183ff0620ec132ed190305147efebe2c88e08bb27/restate_sdk-0.7.3-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:201303283fed514fe1a46402360e77464e3788e04302ac5da439e600b139c0c2", size = 1942679 }, + { url = "https://files.pythonhosted.org/packages/0e/e4/8a2becd2eba796f2a1312b5ff6cd6040564c4232f5c12e5198e5f2a671fe/restate_sdk-0.7.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:af94fb31126a0458ede7e4723a4dea3397c6266604113bb1b74e78bbac9f7250", size = 2129488 }, + { url = "https://files.pythonhosted.org/packages/0d/94/e88fff6eae81753ca3c76a6b90431fe2ef862f6176fe1a23bf0530ee7754/restate_sdk-0.7.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:23543262158242b3fcca7d77062c922dfd3823f204ad2eabe349217b9b0c589e", size = 2151334 }, + { url = "https://files.pythonhosted.org/packages/0d/0c/e15d745fb279ec18a476d25900ce1b503b7136bfc8f32103d4feba00362f/restate_sdk-0.7.3-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:0f94de8727bc18bbb99d407304bb1cad257deffe6615a22c56e33ebf7fab9576", size = 1793922 }, + { url = "https://files.pythonhosted.org/packages/30/2f/eb7198a4edfc1a1b8957dc2ac3a2f44daacfbd1e0c16a281b58705ef8be5/restate_sdk-0.7.3-cp313-cp313-macosx_11_0_arm64.whl", hash = 
"sha256:a0bd3d8106a454b04cd2d7184d9382dfe7fb93e72425fd3f693c05ad8fed0739", size = 1717043 }, + { url = "https://files.pythonhosted.org/packages/6d/a5/20cf6321baf0f11218478edcfb75d5cc61766c7e6c9338a5ef81e9bf993b/restate_sdk-0.7.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:229763952f38a706ac3f443ec8329eb7228376557a9187ffd08ed7916a696821", size = 1980981 }, + { url = "https://files.pythonhosted.org/packages/50/2d/26ab75a2173f7454534a0fdacd6943bfe0c770f9c8da116a3aea7d993dec/restate_sdk-0.7.3-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:f6715078bef2699c4f197fbc0fcf5af1c3283bcd3434ee3f60aa51de235e4fdf", size = 1942676 }, + { url = "https://files.pythonhosted.org/packages/b7/68/2cab951fbaf9cecf1d4a5c5f80ced510c57e09f9ebc3c2974ae123360698/restate_sdk-0.7.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:be6f7cf1fd24301140a727245d3a00b9f5b7d121ae1a4054ce192f5757888656", size = 2129594 }, + { url = "https://files.pythonhosted.org/packages/f7/67/838c5f0a625ac21a8ff52d68588ef535924556b86dfbf117a627bd829f12/restate_sdk-0.7.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5e5034d87dc170296bed0c1b5684818b32216ecc937d363abb223aadb12e571a", size = 2151426 }, + { url = "https://files.pythonhosted.org/packages/8e/01/abe53980b67a116874b4102970d7d01b9bc767bfe0c2422313176eaad6cd/restate_sdk-0.7.3-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:69262fb2061f96265399522fe18e4d1acb9d6d4c33cf8224a8964b71bb0ef0e5", size = 1940202 }, + { url = "https://files.pythonhosted.org/packages/7d/92/18137566a69bcad414f78a423154248745a07aeb759c34dc87ec970efa09/restate_sdk-0.7.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:690b1fc7f221cdb0516c54475b64a04d8220b8649b5eb3b86a282417e3c04b78", size = 2130299 }, + { url = "https://files.pythonhosted.org/packages/43/b0/1dac29dbc180682d564e09fcd00f5d9cccac70e81ab21f4e58ded9f9b1e9/restate_sdk-0.7.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2d9c42e4851dc872932490853d43c38edd2648d0fb1b837cb350d45b5c95ef71", size = 2154280 }, ] [package.optional-dependencies] From 878dc95bd73d896d7d8e3e289728d4d03cc4ceab Mon Sep 17 00:00:00 2001 From: Giselle van Dongen Date: Mon, 26 May 2025 16:34:53 +0200 Subject: [PATCH 3/5] Improve error handling --- agents/openai_sdk/__main__.py | 1 + common/a2a_middleware.py | 6 ++-- common/agent_session.py | 58 +++++++++++++++++++++++------------ 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/agents/openai_sdk/__main__.py b/agents/openai_sdk/__main__.py index 4728d92..8f5ea1c 100644 --- a/agents/openai_sdk/__main__.py +++ b/agents/openai_sdk/__main__.py @@ -10,6 +10,7 @@ format="[%(asctime)s] [%(process)d] [%(levelname)s] - %(message)s", ) + def main(): app = restate.app(services=[agent]) diff --git a/common/a2a_middleware.py b/common/a2a_middleware.py index 2ddf213..4c566c0 100644 --- a/common/a2a_middleware.py +++ b/common/a2a_middleware.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterable, Iterable from datetime import datetime -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from restate.serde import PydanticJsonSerde from .a2a_agent import GenericRestateAgent @@ -313,9 +313,7 @@ async def process_request( try: json_rpc_request = A2ARequest.validate_python(req.model_dump()) - except Exception as e: - if isinstance(e, restate.vm.SuspendedException): - raise e + except ValidationError as e: logger.error("Error validating request: %s", e) return JSONRPCResponse( id=req.id, diff --git a/common/agent_session.py 
b/common/agent_session.py
index 16259d3..53d256a 100644
--- a/common/agent_session.py
+++ b/common/agent_session.py
@@ -101,6 +101,16 @@ class Empty(BaseModel):
     model_config = ConfigDict(extra="forbid")
 
 
+class AgentError(Exception):
+    """
+    Errors that should be fed back into the next agent loop.
+    """
+
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(f"Agent Error: {message}")
+
+
 class RestateRequest(BaseModel, Generic[I]):
     """
     Represents a request to a Restate service.
@@ -343,7 +353,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons
     logging.info(f"{logging_prefix} Starting agent session")
     session_state = SessionState(input_items=await ctx.get("agent_state"))
     session_state.add_user_message(ctx, req.message)
-    
+
     agent_name = await ctx.get("agent_name") or req.starting_agent.formatted_name
     ctx.set("agent_name", agent_name)
     logging.info(f"{logging_prefix} Current agent is {agent_name}")
@@ -397,6 +407,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons
                 stream=False,
             ),
             type_hint=Response,
+            max_attempts=3,  # To avoid using too many credits on infinite retries during development
         )
 
         # Register the output in the session state
@@ -409,9 +420,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons
             output_messages, run_handoffs, tool_calls = await parse_llm_response(
                 agents_dict, response.output, tools
             )
-        except Exception as e:
-            if isinstance(e, restate.vm.SuspendedException):
-                raise e
+        except AgentError as e:
             logger.warning(f"{logging_prefix} Failed to parse LLM response: {str(e)}")
             session_state.add_system_message(
                 ctx, f"Failed to parse LLM response: {str(e)}"
@@ -496,7 +505,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons
             try:
                 remote_agent_to_call = agents_dict.get(handoff_command.name)
                 if remote_agent_to_call is None:
-                    raise TerminalError(
+                    raise AgentError(
                         f"Agent {handoff_command.name} not found in the list of agents."
                     )
                 remote_agent_output = await call_remote_agent(
@@ -508,7 +517,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons
                     ctx,
                     f"Agent response: {remote_agent_output}",
                 )
-            except TerminalError as e:
+            except AgentError as e:
                 logger.warning(
                     f"{logging_prefix} Failed to call remote agent {handoff_command.model_dump_json()}: {str(e)}"
                 )
@@ -537,6 +546,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons
                 final_output=final_output,
             )
 
+
 async def parse_llm_response(
     agents_dict: Dict[str, Agent],
     output: List[ResponseOutputItem],
@@ -554,7 +564,8 @@ async def parse_llm_response(
             or isinstance(item, ResponseReasoningItem)
             or isinstance(item, ResponseComputerToolCall)
         ):
-            raise ValueError(
+            # feed error message back to LLM
+            raise AgentError(
                 "This implementation does not support file search, web search, computer tools, or reasoning yet."
             )
 
@@ -565,20 +576,21 @@ async def parse_llm_response(
             else:
                 # Tool calls
                 if item.name not in tools.keys():
-                    raise ValueError(
+                    # feed error message back to LLM
+                    raise AgentError(
                         f"This agent does not have access to this tool: {item.name}. Use another tool or handoff."
                     )
                 tool = tools[item.name]
                 tool_calls.append(to_tool_call(tool, item))
         else:
-            raise ValueError(
+            # feed error message back to LLM
+            raise AgentError(
                 f"This agent cannot handle this output type {type(item)}. Use another tool or handoff.",
Use another tool or handoff.", ) return output_messages, run_handoffs, tool_calls - def format_name(name: str) -> str: return name.replace(" ", "_").lower() @@ -599,7 +611,7 @@ def restate_tool(tool_call: Callable[[Any, I], Awaitable[O]]) -> RestateTool: case "service": description = target_handler.description case _: - raise TerminalError(f"Unknown service type {service_type}") + raise TerminalError(f"Unknown service type {service_type}. Is this tool a Restate handler?") return RestateTool( service_name=target_handler.service_tag.name, @@ -638,7 +650,8 @@ def to_tool_call(tool: RestateTool, item: ResponseFunctionToolCall) -> ToolCall: if tool.service_type in {"workflow", "object"}: key = tool_request.get("key") if key is None: - raise ValueError( + # feed error message back to LLM + raise AgentError( f"Service key is required for {tool.service_type} ${tool.service_name} but not provided in the request." ) @@ -696,14 +709,19 @@ async def call_remote_agent( async def send_request(url: str, request: JSONRPCRequest) -> SendTaskResponse: async with httpx.AsyncClient() as client: + # retry any errors that come out of this + resp = await client.post( + url, + json=request.model_dump(), + headers={"idempotency-key": request.id}, + timeout=300, + ) + resp.raise_for_status() + try: - # Image generation could take time, increasing timeout - headers = {"idempotency-key": request.id} - resp = await client.post( - url, json=request.model_dump(), headers=headers, timeout=300 - ) - # Right now we are just catching these and letting it retry - resp.raise_for_status() return SendTaskResponse(**resp.json()) except (json.JSONDecodeError, TypeError) as e: - raise TerminalError(str(e)) from e + # feed error message back to LLM + raise AgentError( + f"Response was not in A2A SendTaskResponse format. Error: {str(e)}" + ) from e From c5231d833796b5a1cf82aa8408d278167e5867ea Mon Sep 17 00:00:00 2001 From: Giselle van Dongen Date: Tue, 27 May 2025 08:44:13 +0200 Subject: [PATCH 4/5] Cleanup agent session --- common/agent_session.py | 308 +++++++----------- .../interruptible-agent/app/__main__.py | 3 +- .../interruptible-agent/app/agent.py | 40 ++- .../interruptible-agent/app/chat.py | 7 +- 4 files changed, 152 insertions(+), 206 deletions(-) diff --git a/common/agent_session.py b/common/agent_session.py index 0be6895..d073fd3 100644 --- a/common/agent_session.py +++ b/common/agent_session.py @@ -24,10 +24,6 @@ ResponseFunctionToolCall, Response, ResponseOutputMessage, - ResponseFileSearchToolCall, - ResponseFunctionWebSearch, - ResponseReasoningItem, - ResponseComputerToolCall, ResponseOutputItem, ) from pydantic import BaseModel, ConfigDict, Field @@ -191,8 +187,6 @@ class Agent(BaseModel): push_notifications_support: bool = False def to_tool_schema(self) -> dict[str, Any]: - # If the agent is a remote agent, we want to attach extra context/message to the request - # TODO it might be better to use the A2A TaskSendRequest format here? 
         schema = to_strict_json_schema(RemoteAgentMessage if self.remote_url else Empty)
         return {
             "type": "function",
@@ -298,18 +292,13 @@ def add_user_messages(self, ctx: restate.ObjectContext, items: List[str]):
         ctx.set(self.state_name, self._input_items)
         self._new_items.extend(user_messages)
 
-    def add_system_message(
-        self, ctx: restate.ObjectContext, item: str, flush: bool = True
-    ):
+    def add_system_message(self, ctx: restate.ObjectContext, item: str):
         system_message = SessionItem(content=item, role="system")
         self._input_items.append(system_message)
         self._new_items.append(system_message)
-        if flush:
-            ctx.set(self.state_name, self._input_items)
+        ctx.set(self.state_name, self._input_items)
 
-    def add_system_messages(
-        self, ctx: restate.ObjectContext, items: List[str], flush: bool = False
-    ):
+    def add_system_messages(self, ctx: restate.ObjectContext, items: List[str]):
         system_messages = [SessionItem(content=item, role="system") for item in items]
         self._input_items.extend(system_messages)
         ctx.set(self.state_name, self._input_items)
@@ -347,203 +336,164 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons
     Args:
         req (AgentInput): The input for the agent
     """
-    logging_prefix = f"Agent session {ctx.key()} - "
+    log_prefix = f"{ctx.request().id} - agent-session {ctx.key()} - "
 
     # === 1. initialize the agent session ===
-    logging.info(f"{logging_prefix} Starting agent session")
+    logging.info(f"{log_prefix} Starting agent session")
     session_state = SessionState(input_items=await ctx.get("agent_state"))
     session_state.add_user_message(ctx, req.message)
 
     agent_name = await ctx.get("agent_name") or req.starting_agent.formatted_name
     ctx.set("agent_name", agent_name)
-    logging.info(f"{logging_prefix} Current agent is {agent_name}")
 
-    agents_dict = {ag.formatted_name: ag for ag in req.agents}
+    agents_dict = {a.formatted_name: a for a in req.agents}
 
     agent = agents_dict.get(agent_name)
     if agent is None:
-        # Don't retry this. It's a configuration error
-        session_state.add_system_message(
-            ctx,
-            f"Current/starting agent not found in the list of agents {agent_name}. Available agents: {list(agents_dict.keys())}",
-        )
         raise TerminalError(
-            f"Agent {agent_name} not found in the list of agents. Available agents: {list(agents_dict.keys())}"
+            f"Agent {agent_name} not found in the list of agents: {list(agents_dict.keys())}"
         )
 
     # === 2. Run the agent loop ===
     while True:
         # Get the tools in the right format for the LLM
-        tools = {tool.formatted_name: tool for tool in agent.tools}
-        tool_and_handoffs_list = [tool.tool_schema for tool in agent.tools]
-        logger.info(
-            f"{logging_prefix} Starting iteration of agent loop with agent: {agent.name} and tools/handoffs: {[t.formatted_name for t in agent.tools]}"
-        )
+        try:
+            tools = {tool.formatted_name: tool for tool in agent.tools}
+            logger.info(
+                f"{log_prefix} Starting iteration of agent: {agent.name} with tools/handoffs: {list(tools.keys())}"
+            )
 
-        for handoff_agent_name in agent.handoffs:
-            handoff_agent = agents_dict.get(format_name(handoff_agent_name))
-            # If the agent is not found, we ignore it
-            if handoff_agent is None:
-                logger.warning(
-                    f"Agent {handoff_agent_name} not found in the list of agents. Ignoring this agent. Available agents: {list(agents_dict.keys())}"
-                )
-                session_state.add_system_message(
-                    ctx,
-                    f"Agent {handoff_agent_name} not found in the list of agents. Available agents: {list(agents_dict.keys())}",
-                )
-            else:
-                tool_and_handoffs_list.append(handoff_agent.to_tool_schema())
-
-        # Call the LLM - OpenAPI Responses API
-        logger.info(f"{logging_prefix} Calling LLM")
-        response: Response = await ctx.run(
-            "Call LLM",
-            lambda: client.responses.create(
-                model="gpt-4o",
-                instructions=agent.instructions,
-                input=session_state.get_input_items(),
-                tools=tool_and_handoffs_list,
-                parallel_tool_calls=True,
-                stream=False,
-            ),
-            type_hint=Response,
-            max_attempts=3,  # To avoid using too many credits on infinite retries during development
-        )
+            tool_schemas = [tool.tool_schema for tool in agent.tools]
+            for handoff_agent_name in agent.handoffs:
+                handoff_agent = agents_dict.get(format_name(handoff_agent_name))
+                if handoff_agent is None:
+                    logger.warning(
+                        f"Agent {handoff_agent_name} not found in the list of agents. Ignoring this handoff agent."
+                    )
+                tool_schemas.append(handoff_agent.to_tool_schema())
+
+            # Call the LLM - OpenAPI Responses API
+            logger.info(f"{log_prefix} Calling LLM")
+            response: Response = await ctx.run(
+                "Call LLM",
+                lambda: client.responses.create(
+                    model="gpt-4o",
+                    instructions=agent.instructions,
+                    input=session_state.get_input_items(),
+                    tools=tool_schemas,
+                    parallel_tool_calls=True,
+                    stream=False,
+                ),
+                type_hint=Response,
+                max_attempts=3,  # To avoid using too many credits on infinite retries during development
+            )
 
-        # Register the output in the session state
-        session_state.add_system_messages(
-            ctx, [item.model_dump_json() for item in response.output], flush=True
-        )
+            # Register the output in the session state
+            session_state.add_system_messages(
+                ctx, [item.model_dump_json() for item in response.output]
+            )
 
-        # Parse LLM response
-        try:
+            # Parse LLM response
             output_messages, run_handoffs, tool_calls = await parse_llm_response(
                 agents_dict, response.output, tools
             )
-        except AgentError as e:
-            logger.warning(f"{logging_prefix} Failed to parse LLM response: {str(e)}")
-            session_state.add_system_message(
-                ctx, f"Failed to parse LLM response: {str(e)}"
-            )
-            # Let the LLM evaluate what it did wrong and correct
-            continue
-
-        # Execute (parallel) tool calls
-        parallel_tools = []
-        for tool_call in tool_calls:
-            logger.info(f"{logging_prefix} Executing tool {tool_call.name}")
-            try:
-                if tool_call.delay_in_millis is None:
-                    handle = ctx.generic_call(
-                        service=tool_call.tool.service_name,
-                        handler=tool_call.tool.name,
-                        arg=tool_call.input_bytes,
-                        key=tool_call.key,
+
+            # Execute (parallel) tool calls
+            parallel_tools = []
+            for tool_call in tool_calls:
+                logger.info(f"{log_prefix} Executing tool {tool_call.name}")
+                try:
+                    if tool_call.delay_in_millis is None:
+                        handle = ctx.generic_call(
+                            service=tool_call.tool.service_name,
+                            handler=tool_call.tool.name,
+                            arg=tool_call.input_bytes,
+                            key=tool_call.key,
+                        )
+                        parallel_tools.append(handle)
+                    else:
+                        # Used for scheduling tasks in the future or long-running tasks like workflows
+                        ctx.generic_send(
+                            service=tool_call.tool.service_name,
+                            handler=tool_call.tool.name,
+                            arg=tool_call.input_bytes,
+                            key=tool_call.key,
send_delay=timedelta(milliseconds=tool_call.delay_in_millis), + except TerminalError as e: + # We add it to the session_state to feed it back into the next LLM call + # Let the other parallel tool executions continue + session_state.add_system_message( + ctx, + f"Failed to execute tool {tool_call.name}: {str(e)}", ) - session_state.add_system_message( - ctx, f"Task {tool_call.name} was scheduled" - ) - except TerminalError as e: - # We add it to the session_state to feed it back into the next LLM call - logger.warning( - f"{logging_prefix} Failed to execute tool {tool_call.name}: {str(e)}" - ) - session_state.add_system_message( - ctx, - f"Failed to execute tool {tool_call.name}: {str(e)}", - ) - if len(parallel_tools) > 0: - results_done = await restate.gather(*parallel_tools) - results = [(await result).decode() for result in results_done] - logger.info(f"{logging_prefix} Gathered tool execution results: {results}") - session_state.add_system_messages(ctx, results) - - # Handle handoffs - if run_handoffs: - # Only one agent can be in charge of the conversation at a time. - # So if there are multiple handoffs in the response, only run the first one. - # For the others, we add a tool response that we will not handle them. - if len(run_handoffs) > 1: - for handoff in run_handoffs[1:]: + if len(parallel_tools) > 0: + results_done = await restate.gather(*parallel_tools) + results = [(await result).decode() for result in results_done] + session_state.add_system_messages(ctx, results) + + # Handle handoffs + if run_handoffs: + # Only one agent can be in charge of the conversation at a time. + # So if there are multiple handoffs in the response, only run the first one. + if len(run_handoffs) > 1: logger.info( - f"{logging_prefix} Multiple handoffs detected, ignoring this one: {handoff.name} with arguments {handoff.arguments}." + f"{log_prefix} Multiple handoffs detected. Ignoring: {[h.name for h in run_handoffs[1:]]}" ) - handoff_command = run_handoffs[0] - - # Determine the new agent in charge - next_agent = agents_dict.get(handoff_command.name) - if next_agent is None: - session_state.add_system_message( - ctx, - f"Agent {handoff_command.name} not found in the list of agents.", - ) - logger.info( - f"{logging_prefix} Agent {handoff_command.name} not found in the list of agents." - ) - continue + handoff_command = run_handoffs[0] + next_agent = agents_dict.get(handoff_command.name) + if next_agent is None: + raise AgentError( + f"Agent {handoff_command.name} not found in the list of agents." + ) - if next_agent.remote_url not in {None, ""}: - logger.info( - f"{logging_prefix} Calling Remote Agent over A2A {handoff_command.name}" - ) - session_state.add_system_message( - ctx, - f"Calling Remote Agent: {handoff_command.name} with arguments {handoff_command.arguments}.", - ) - try: + if next_agent.remote_url not in {None, ""}: remote_agent_to_call = agents_dict.get(handoff_command.name) if remote_agent_to_call is None: raise AgentError( f"Agent {handoff_command.name} not found in the list of agents." 
) + + logger.info( + f"{log_prefix} Calling Remote A2A Agent {handoff_command.name}" + ) remote_agent_output = await call_remote_agent( ctx, remote_agent_to_call.as_agent_card(), handoff_command.arguments, ) session_state.add_system_message( - ctx, - f"Agent response: {remote_agent_output}", - ) - except AgentError as e: - logger.warning( - f"{logging_prefix} Failed to call remote agent {handoff_command.model_dump_json()}: {str(e)}" - ) - session_state.add_system_message( - ctx, - f"Failed to call remote agent {handoff_command.model_dump_json()}: {str(e)}", + ctx, f"{handoff_command.name} response: {remote_agent_output}" ) - # We add it to the session_state to feed it back into the next LLM call - else: - # Start a new agent loop with the new agent - agent = next_agent - logger.info(f"{logging_prefix} Handing off to agent {agent.name}") - session_state.add_system_message(ctx, f"Transferred to {agent.name}.") - ctx.set("agent_name", format_name(agent.name)) - - continue - - # Handle output messages - # If there are no output messages, then we just continue the loop - final_output = response.output_text - if final_output != "": - logger.info(f"{logging_prefix} Final output message: {final_output}") - return AgentResponse( - agent=agent.name, - messages=session_state.get_new_items(), - final_output=final_output, + else: + # Start a new agent loop with the new agent + agent = next_agent + ctx.set("agent_name", format_name(agent.name)) + continue + + # Handle output messages + # If there are no output messages, then we just continue the loop + final_output = response.output_text + if final_output != "": + logger.info(f"{log_prefix} Final output message generated.") + return AgentResponse( + agent=agent.name, + messages=session_state.get_new_items(), + final_output=final_output, + ) + except AgentError as e: + logger.warning( + f"{log_prefix} Iteration of agent run failed. Updating state and feeding back error to LLM: {str(e)}" + ) + session_state.add_system_message( + ctx, f"Failed iteration of agent run: {str(e)}" ) @@ -558,16 +508,6 @@ async def parse_llm_response( for item in output: if isinstance(item, ResponseOutputMessage): output_messages.append(item) - elif ( - isinstance(item, ResponseFileSearchToolCall) - or isinstance(item, ResponseFunctionWebSearch) - or isinstance(item, ResponseReasoningItem) - or isinstance(item, ResponseComputerToolCall) - ): - # feed error message back to LLM - raise AgentError( - "This implementation does not support file search, web search, computer tools, or reasoning yet." - ) elif isinstance(item, ResponseFunctionToolCall): if item.name in agents_dict.keys(): @@ -578,14 +518,14 @@ async def parse_llm_response( if item.name not in tools.keys(): # feed error message back to LLM raise AgentError( - f"This agent does not have access to this tool: {item.name}. Use another tool or handoff." + f"Error while parsing LLM response: This agent does not have access to this tool: {item.name}. Use another tool or handoff." ) tool = tools[item.name] tool_calls.append(to_tool_call(tool, item)) else: # feed error message back to LLM raise AgentError( - f"This agent cannot handle this output type {type(item)}. Use another tool or handoff.", + f"Error while parsing LLM response: This agent cannot handle this output type {type(item)}. 
Use another tool or handoff.", ) return output_messages, run_handoffs, tool_calls @@ -611,7 +551,9 @@ def restate_tool(tool_call: Callable[[Any, I], Awaitable[O]]) -> RestateTool: case "service": description = target_handler.description case _: - raise TerminalError(f"Unknown service type {service_type}. Is this tool a Restate handler?") + raise TerminalError( + f"Unknown service type {service_type}. Is this tool a Restate handler?" + ) return RestateTool( service_name=target_handler.service_tag.name, @@ -703,7 +645,7 @@ async def call_remote_agent( final_output = "Task is in progress" case _: final_output = "Task status unknown" - + return final_output diff --git a/end-to-end-applications/interruptible-agent/app/__main__.py b/end-to-end-applications/interruptible-agent/app/__main__.py index dd4e445..528540a 100644 --- a/end-to-end-applications/interruptible-agent/app/__main__.py +++ b/end-to-end-applications/interruptible-agent/app/__main__.py @@ -8,9 +8,10 @@ logging.basicConfig( level=logging.INFO, - format='[%(asctime)s] [%(levelname)s] %(name)s: %(message)s', + format="[%(asctime)s] [%(levelname)s] %(name)s: %(message)s", ) + def main(): app = restate.app( services=[ diff --git a/end-to-end-applications/interruptible-agent/app/agent.py b/end-to-end-applications/interruptible-agent/app/agent.py index 58dcb27..aed7086 100644 --- a/end-to-end-applications/interruptible-agent/app/agent.py +++ b/end-to-end-applications/interruptible-agent/app/agent.py @@ -4,11 +4,7 @@ from datetime import timedelta from openai import OpenAI -from openai.types.responses import ( - Response, - ResponseFunctionToolCall, - FunctionToolParam -) +from openai.types.responses import Response, ResponseFunctionToolCall, FunctionToolParam from pydantic import BaseModel from utils.models import AgentResponse, AgentInput @@ -17,6 +13,7 @@ client = OpenAI() + class Task(BaseModel): """ The task that needs to get executed @@ -24,6 +21,7 @@ class Task(BaseModel): Args: description (str): The description of the task to execute. """ + description: str class Config: @@ -57,9 +55,10 @@ async def run(ctx: restate.ObjectContext, req: AgentInput): input_items.append({"role": "system", "content": task_result}) # peek for new input - match await restate.select(new_input_promise=new_input_promise, - timeout=ctx.sleep(timedelta(seconds=0))): - case ['new_input_promise', new_input]: + match await restate.select( + new_input_promise=new_input_promise, timeout=ctx.sleep(timedelta(seconds=0)) + ): + case ["new_input_promise", new_input]: logger.info(f"Incorporating new input for {ctx.key()}: {new_input}") id, new_input_promise = ctx.awakeable() ctx.set(NEW_INPUT_PROMISE, id) @@ -74,12 +73,11 @@ async def run(ctx: restate.ObjectContext, req: AgentInput): if response.output_text != "": logger.info(f"Final output message: {response.output_text}") from chat import process_agent_response + ctx.object_send( process_agent_response, key=ctx.key(), - arg=AgentResponse( - final_output=response.output_text - ), + arg=AgentResponse(final_output=response.output_text), ) ctx.clear(NEW_INPUT_PROMISE) return @@ -90,16 +88,21 @@ async def incorporate_new_input(ctx: restate.ObjectSharedContext, req: str) -> b id = await ctx.get(NEW_INPUT_PROMISE) if id is None: - logger.warning(f"No awakeable ID found. Maybe invocation finished in the meantime. Cannot incorporate new input for {ctx.key()}.") + logger.warning( + f"No awakeable ID found. Maybe invocation finished in the meantime. Cannot incorporate new input for {ctx.key()}." 
+ ) return False ctx.resolve_awakeable(id, req) - logger.info(f"Resolved awakeable with ID {id} with new input for {ctx.key()}: {req}") + logger.info( + f"Resolved awakeable with ID {id} with new input for {ctx.key()}: {req}" + ) return True # UTILS + async def execute_task(ctx: restate.ObjectContext, req: Task) -> str: """ Executes tasks, based on the description provided. @@ -113,10 +116,13 @@ async def execute_task(ctx: restate.ObjectContext, req: Task) -> str: return f"Task executed successfully: {req.description}" -execute_task_tool=FunctionToolParam( +execute_task_tool = FunctionToolParam( name=execute_task.__name__, description=execute_task.__doc__, - parameters=Task.model_json_schema(), strict=True, type="function") + parameters=Task.model_json_schema(), + strict=True, + type="function", +) async def call_task_agent(input_items) -> Response: @@ -129,5 +135,5 @@ async def call_task_agent(input_items) -> Response: input=input_items, tools=[execute_task_tool], parallel_tool_calls=False, - stream=False - ) \ No newline at end of file + stream=False, + ) diff --git a/end-to-end-applications/interruptible-agent/app/chat.py b/end-to-end-applications/interruptible-agent/app/chat.py index bf5e07c..9db3bd3 100644 --- a/end-to-end-applications/interruptible-agent/app/chat.py +++ b/end-to-end-applications/interruptible-agent/app/chat.py @@ -62,15 +62,12 @@ async def process_user_message(ctx: restate.ObjectContext, req: ChatMessage): # Queue the new agent run await send_message_to_agent(ctx, history) + async def send_message_to_agent(ctx, history): handle = ctx.object_send( agent_session_run, key=ctx.key(), - arg=AgentInput( - message_history=[ - entry.content for entry in history.entries - ] - ), + arg=AgentInput(message_history=[entry.content for entry in history.entries]), ) ctx.set(ACTIVE_AGENT_INVOCATION_ID, await handle.invocation_id()) From a60647376d15ebbfd7e1922a028bde6c3ff082e4 Mon Sep 17 00:00:00 2001 From: Giselle van Dongen Date: Tue, 27 May 2025 09:11:29 +0200 Subject: [PATCH 5/5] Cleanup agent session --- common/agent_session.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/common/agent_session.py b/common/agent_session.py index d073fd3..44d199f 100644 --- a/common/agent_session.py +++ b/common/agent_session.py @@ -363,14 +363,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons f"{log_prefix} Starting iteration of agent: {agent.name} with tools/handoffs: {list(tools.keys())}" ) - tool_schemas = [tool.tool_schema for tool in agent.tools] - for handoff_agent_name in agent.handoffs: - handoff_agent = agents_dict.get(format_name(handoff_agent_name)) - if handoff_agent is None: - logger.warning( - f"Agent {handoff_agent_name} not found in the list of agents. Ignoring this handoff agent." 
- ) - tool_schemas.append(handoff_agent.to_tool_schema()) + tool_schemas = await generate_tool_schemas(agent, agents_dict) # Call the LLM - OpenAPI Responses API logger.info(f"{log_prefix} Calling LLM") @@ -428,6 +421,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons except TerminalError as e: # We add it to the session_state to feed it back into the next LLM call # Let the other parallel tool executions continue + logger.warning(f"Failed to execute tool {tool_call.name}: {str(e)}") session_state.add_system_message( ctx, f"Failed to execute tool {tool_call.name}: {str(e)}", @@ -436,6 +430,7 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons if len(parallel_tools) > 0: results_done = await restate.gather(*parallel_tools) results = [(await result).decode() for result in results_done] + logger.info(f"{log_prefix} Gathered tool results.") session_state.add_system_messages(ctx, results) # Handle handoffs @@ -497,6 +492,18 @@ async def run_agent(ctx: restate.ObjectContext, req: AgentInput) -> AgentRespons ) +async def generate_tool_schemas(agent, agents_dict) -> list[dict[str, Any]]: + tool_schemas = [tool.tool_schema for tool in agent.tools] + for handoff_agent_name in agent.handoffs: + handoff_agent = agents_dict.get(format_name(handoff_agent_name)) + if handoff_agent is None: + logger.warning( + f"Agent {handoff_agent_name} not found in the list of agents. Ignoring this handoff agent." + ) + tool_schemas.append(handoff_agent.to_tool_schema()) + return tool_schemas + + async def parse_llm_response( agents_dict: Dict[str, Agent], output: List[ResponseOutputItem],