From 7fcd840eb22518ac370c074ab7448f8e2827b526 Mon Sep 17 00:00:00 2001 From: Rishav C Date: Mon, 18 Aug 2025 17:38:38 -0400 Subject: [PATCH 1/2] fix: enable FunctionTool serialization for Temporal worker nodes - RunAgent*Params objects must be serializable for over-the-wire transmission to Temporal workers/backend. Previous implementation failed when users specified FunctionTool with callable on_invoke_tool params. - This commit adds cloudpickle-based serialization support to resolve serialization errors - During testing, also had to pin OpenAI to v1.99.9 to avoid LiteLLM incompatibility issue ([#13711](https://github.com/BerriAI/litellm/issues/13711)) --- pyproject.toml | 3 +- requirements-dev.lock | 5 +- requirements.lock | 5 +- .../adk/providers/openai_activities.py | 94 ++++++- tests/test_function_tool.py | 251 ++++++++++++++++++ uv.lock | 15 +- 6 files changed, 364 insertions(+), 9 deletions(-) create mode 100644 tests/test_function_tool.py diff --git a/pyproject.toml b/pyproject.toml index 3571cf00..9ca24d62 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,8 @@ dependencies = [ "pytest-asyncio>=1.0.0", "scale-gp-beta==0.1.0a20", "ipykernel>=6.29.5", - "openai>=1.99.9", + "openai==1.99.9", # anything higher than 1.99.9 breaks litellm - https://github.com/BerriAI/litellm/issues/13711 + "cloudpickle>=3.1.1", ] requires-python = ">= 3.12,<4" classifiers = [ diff --git a/requirements-dev.lock b/requirements-dev.lock index 909aa446..c2b6e799 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -53,6 +53,8 @@ click==8.2.1 # via litellm # via typer # via uvicorn +cloudpickle==3.1.1 + # via agentex-sdk colorama==0.4.6 # via griffe colorlog==6.7.0 @@ -179,9 +181,10 @@ oauthlib==3.3.1 # via kubernetes # via requests-oauthlib openai==1.99.9 + # via agentex-sdk # via litellm # via openai-agents -openai-agents==0.2.6 +openai-agents==0.2.7 # via agentex-sdk packaging==23.2 # via huggingface-hub diff --git a/requirements.lock b/requirements.lock index 29120c52..58bc3e38 100644 --- a/requirements.lock +++ b/requirements.lock @@ -51,6 +51,8 @@ click==8.2.1 # via litellm # via typer # via uvicorn +cloudpickle==3.1.1 + # via agentex-sdk colorama==0.4.6 # via griffe comm==0.2.3 @@ -162,9 +164,10 @@ oauthlib==3.3.1 # via kubernetes # via requests-oauthlib openai==1.99.9 + # via agentex-sdk # via litellm # via openai-agents -openai-agents==0.2.6 +openai-agents==0.2.7 # via agentex-sdk packaging==25.0 # via huggingface-hub diff --git a/src/agentex/lib/core/temporal/activities/adk/providers/openai_activities.py b/src/agentex/lib/core/temporal/activities/adk/providers/openai_activities.py index e5004556..a53c6ecf 100644 --- a/src/agentex/lib/core/temporal/activities/adk/providers/openai_activities.py +++ b/src/agentex/lib/core/temporal/activities/adk/providers/openai_activities.py @@ -1,9 +1,13 @@ # Standard library imports +import base64 from collections.abc import Callable from contextlib import AsyncExitStack, asynccontextmanager from enum import Enum -from typing import Any, Literal +from typing import Any, Literal, Optional, override +from pydantic import Field, PrivateAttr + +import cloudpickle from agents import RunContextWrapper, RunResult, RunResultStreaming from agents.mcp import MCPServerStdio, MCPServerStdioParams from agents.model_settings import ModelSettings as OAIModelSettings @@ -41,12 +45,92 @@ class FunctionTool(BaseModelWithTraceParams): name: str description: str params_json_schema: dict[str, Any] - on_invoke_tool: Callable[[RunContextWrapper, str], Any] + strict_json_schema: bool = True is_enabled: bool = True + _on_invoke_tool: Callable[[RunContextWrapper, str], Any] = PrivateAttr() + on_invoke_tool_serialized: str = Field( + default="", + description=( + "Normally will be set automatically during initialization and" + " doesn't need to be passed. " + "Instead, pass `on_invoke_tool` to the constructor. " + "See the __init__ method for details." + ), + ) + + def __init__( + self, + *, + on_invoke_tool: Optional[Callable[[RunContextWrapper, str], Any]] = None, + **data, + ): + """ + Initialize a FunctionTool with hacks to support serialization of the + on_invoke_tool callable arg. This is required to facilitate over-the-wire + communication of this object to/from temporal services/workers. + + Args: + on_invoke_tool: The callable to invoke when the tool is called. + **data: Additional data to initialize the FunctionTool. + """ + super().__init__(**data) + if not on_invoke_tool: + if not self.on_invoke_tool_serialized: + raise ValueError( + "One of `on_invoke_tool` or `on_invoke_tool_serialized` should be set" + ) + else: + on_invoke_tool = self._deserialize_callable( + self.on_invoke_tool_serialized + ) + else: + self.on_invoke_tool_serialized = self._serialize_callable(on_invoke_tool) + + self._on_invoke_tool = on_invoke_tool + + @classmethod + def _deserialize_callable( + cls, serialized: str + ) -> Callable[[RunContextWrapper, str], Any]: + encoded = serialized.encode() + serialized_bytes = base64.b64decode(encoded) + return cloudpickle.loads(serialized_bytes) + + @classmethod + def _serialize_callable(cls, func: Callable) -> str: + serialized_bytes = cloudpickle.dumps(func) + encoded = base64.b64encode(serialized_bytes) + return encoded.decode() + + @property + def on_invoke_tool(self) -> Callable[[RunContextWrapper, str], Any]: + if self._on_invoke_tool is None and self.on_invoke_tool_serialized: + self._on_invoke_tool = self._deserialize_callable( + self.on_invoke_tool_serialized + ) + return self._on_invoke_tool + + @on_invoke_tool.setter + def on_invoke_tool(self, value: Callable[[RunContextWrapper, str], Any]): + self.on_invoke_tool_serialized = self._serialize_callable(value) + self._on_invoke_tool = value + def to_oai_function_tool(self) -> OAIFunctionTool: - return OAIFunctionTool(**self.model_dump(exclude=["trace_id", "parent_span_id"])) + """Convert to OpenAI function tool, excluding serialization fields.""" + # Create a dictionary with only the fields OAIFunctionTool expects + data = self.model_dump( + exclude={ + "trace_id", + "parent_span_id", + "_on_invoke_tool", + "on_invoke_tool_serialized", + } + ) + # Add the callable for OAI tool since properties are not serialized + data["on_invoke_tool"] = self.on_invoke_tool + return OAIFunctionTool(**data) class ModelSettings(BaseModelWithTraceParams): @@ -68,7 +152,9 @@ class ModelSettings(BaseModelWithTraceParams): extra_args: dict[str, Any] | None = None def to_oai_model_settings(self) -> OAIModelSettings: - return OAIModelSettings(**self.model_dump(exclude=["trace_id", "parent_span_id"])) + return OAIModelSettings( + **self.model_dump(exclude=["trace_id", "parent_span_id"]) + ) class RunAgentParams(BaseModelWithTraceParams): diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py new file mode 100644 index 00000000..6f100408 --- /dev/null +++ b/tests/test_function_tool.py @@ -0,0 +1,251 @@ +import json +import pytest +from typing import Any + +from src.agentex.lib.core.temporal.activities.adk.providers.openai_activities import ( + FunctionTool, +) + + +def sample_handler(context, args: str) -> str: + """Sample handler function for testing.""" + return f"Processed: {args}" + + +def complex_handler(context, args: str) -> dict[str, Any]: + """More complex handler that returns structured data.""" + parsed_args = json.loads(args) if args else {} + return { + "status": "success", + "input": parsed_args, + "context_info": str(type(context)), + } + + +class TestFunctionTool: + """Test cases for FunctionTool serialization with JSON.""" + + def test_basic_serialization_with_json(self): + """Test that FunctionTool can be serialized and deserialized with JSON.""" + # Create a FunctionTool with a callable + tool = FunctionTool( + name="test_tool", + description="A test tool", + params_json_schema={"type": "string"}, + strict_json_schema=True, + is_enabled=True, + on_invoke_tool=sample_handler, + ) + + # Serialize to JSON (this is what the caller will do) + json_data = json.dumps(tool.model_dump()) + + # Deserialize from JSON + data = json.loads(json_data) + new_tool = FunctionTool.model_validate(data) + + # Test that the callable is restored + assert new_tool.on_invoke_tool is not None + assert callable(new_tool.on_invoke_tool) + + # Test that the callable works as expected + result = new_tool.on_invoke_tool(None, "test_input") + assert result == "Processed: test_input" + + def test_complex_function_serialization(self): + """Test serialization of more complex functions.""" + tool = FunctionTool( + name="complex_tool", + description="A complex test tool", + params_json_schema={ + "type": "object", + "properties": {"key": {"type": "string"}}, + }, + on_invoke_tool=complex_handler, + ) + + # Serialize and deserialize via JSON + json_data = json.dumps(tool.model_dump()) + data = json.loads(json_data) + new_tool = FunctionTool.model_validate(data) + + # Test the complex function + test_input = '{"test": "value"}' + result = new_tool.on_invoke_tool(None, test_input) + + assert result["status"] == "success" + assert result["input"] == {"test": "value"} + + def test_none_callable_handling(self): + """Test that passing None for callable raises an error.""" + # Test that None callable raises ValueError + with pytest.raises( + ValueError, + match="One of `on_invoke_tool` or `on_invoke_tool_serialized` should be set", + ): + FunctionTool( + name="empty_tool", + description="Tool with no callable", + params_json_schema={"type": "string"}, + on_invoke_tool=None, + ) + + # Test with valid function - this should work + tool_func = FunctionTool( + name="func_tool", + description="Tool with function", + params_json_schema={"type": "string"}, + on_invoke_tool=sample_handler, + ) + assert tool_func.on_invoke_tool is not None + + def test_lambda_function_serialization(self): + """Test that lambda functions can be serialized.""" + # Set a lambda function + tool = FunctionTool( + name="lambda_tool", + description="Tool with lambda", + params_json_schema={"type": "string"}, + on_invoke_tool=lambda ctx, args: f"Lambda result: {args}", + ) + + # Serialize and deserialize via JSON + json_data = json.dumps(tool.model_dump()) + data = json.loads(json_data) + new_tool = FunctionTool.model_validate(data) + + # Test that the lambda works + result = new_tool.on_invoke_tool(None, "test") + assert result == "Lambda result: test" + + def test_closure_serialization(self): + """Test that closures can be serialized.""" + + def create_handler(prefix: str): + def handler(context, args: str) -> str: + return f"{prefix}: {args}" + + return handler + + # Set a closure + tool = FunctionTool( + name="closure_tool", + description="Tool with closure", + params_json_schema={"type": "string"}, + on_invoke_tool=create_handler("PREFIX"), + ) + + # Serialize and deserialize via JSON + json_data = json.dumps(tool.model_dump()) + data = json.loads(json_data) + new_tool = FunctionTool.model_validate(data) + + # Test that the closure works with captured variable + result = new_tool.on_invoke_tool(None, "test") + assert result == "PREFIX: test" + + def test_function_tool_with_none_handler_raises_error(self): + """Test that trying to create tool with None handler raises error.""" + # Test that None callable raises ValueError + with pytest.raises( + ValueError, + match="One of `on_invoke_tool` or `on_invoke_tool_serialized` should be set", + ): + FunctionTool( + name="none_handler_test", + description="Test tool with None handler", + params_json_schema={"type": "string"}, + on_invoke_tool=None, + ) + + def test_to_oai_function_tool_with_valid_handler(self): + """Test that to_oai_function_tool works with valid function.""" + tool = FunctionTool( + name="valid_handler_test", + description="Test tool with valid handler", + params_json_schema={"type": "string"}, + on_invoke_tool=sample_handler, + ) + + # This should work when on_invoke_tool is set + oai_tool = tool.to_oai_function_tool() + + # Verify the OAI tool was created successfully + assert oai_tool is not None + assert oai_tool.name == "valid_handler_test" + assert oai_tool.description == "Test tool with valid handler" + assert oai_tool.on_invoke_tool is not None + assert callable(oai_tool.on_invoke_tool) + + # Test that the handler works through the OAI tool + result = oai_tool.on_invoke_tool(None, "test_input") + assert result == "Processed: test_input" + + def test_serialization_error_handling(self): + """Test error handling when serialization fails.""" + + # Try to create a FunctionTool with an unserializable callable + class UnserializableCallable: + def __call__(self, context, args): + return "test" + + def __getstate__(self): + raise Exception("Cannot serialize this object") + + unserializable = UnserializableCallable() + + # This should raise an Exception during construction (from the unserializable object) + with pytest.raises(Exception, match="Cannot serialize this object"): + FunctionTool( + name="error_test_with_unserializable", + description="Test error handling with unserializable", + params_json_schema={"type": "string"}, + on_invoke_tool=unserializable, + ) + + def test_deserialization_error_handling(self): + """Test error handling when deserialization fails.""" + + # Create a tool and manually corrupt its serialized data to test deserialization error + # First, create a valid tool + valid_tool = FunctionTool( + name="valid_tool", + description="Valid tool for corruption", + params_json_schema={"type": "string"}, + on_invoke_tool=sample_handler, + ) + + # Serialize it + serialized_data = valid_tool.model_dump() + + # Corrupt the serialized callable data with invalid base64 + serialized_data["on_invoke_tool_serialized"] = ( + "invalid_base64_data!" # Add invalid character + ) + + # This should raise an error during model validation due to invalid base64 + with pytest.raises(Exception): # Could be ValidationError or ValueError + FunctionTool.model_validate(serialized_data) + + def test_full_roundtrip_with_serialization(self): + """Test a full roundtrip with a single tool.""" + tool = FunctionTool( + name="test_tool", + description="Test tool for roundtrip", + params_json_schema={"type": "string"}, + on_invoke_tool=lambda ctx, args: f"Tool result: {args}", + ) + + # Serialize tool to JSON + json_data = json.dumps(tool.model_dump()) + + # Deserialize from JSON + data = json.loads(json_data) + new_tool = FunctionTool.model_validate(data) + + # Test the tool + result = new_tool.on_invoke_tool(None, "test") + assert "Tool result: test" == result + + result = new_tool.to_oai_function_tool().on_invoke_tool(None, "test") + assert "Tool result: test" == result diff --git a/uv.lock b/uv.lock index 0b7b7472..8c6f4c13 100644 --- a/uv.lock +++ b/uv.lock @@ -1,14 +1,15 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12, <4" [[package]] name = "agentex-sdk" -version = "0.4.0" +version = "0.4.4" source = { editable = "." } dependencies = [ { name = "aiohttp" }, { name = "anyio" }, + { name = "cloudpickle" }, { name = "distro" }, { name = "fastapi" }, { name = "httpx" }, @@ -62,6 +63,7 @@ requires-dist = [ { name = "aiohttp", specifier = ">=3.10.10,<4" }, { name = "aiohttp", marker = "extra == 'aiohttp'" }, { name = "anyio", specifier = ">=3.5.0,<5" }, + { name = "cloudpickle", specifier = ">=3.1.1" }, { name = "distro", specifier = ">=1.7.0,<2" }, { name = "fastapi", specifier = ">=0.115.0,<0.116" }, { name = "httpx", specifier = ">=0.27.2,<0.28" }, @@ -332,6 +334,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" }, ] +[[package]] +name = "cloudpickle" +version = "3.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/52/39/069100b84d7418bc358d81669d5748efb14b9cceacd2f9c75f550424132f/cloudpickle-3.1.1.tar.gz", hash = "sha256:b216fa8ae4019d5482a8ac3c95d8f6346115d8835911fd4aefd1a445e4242c64", size = 22113, upload-time = "2025-01-14T17:02:05.085Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/e8/64c37fadfc2816a7701fa8a6ed8d87327c7d54eacfbfb6edab14a2f2be75/cloudpickle-3.1.1-py3-none-any.whl", hash = "sha256:c8c5a44295039331ee9dad40ba100a9c7297b6f988e50e87ccdf3765a668350e", size = 20992, upload-time = "2025-01-14T17:02:02.417Z" }, +] + [[package]] name = "colorama" version = "0.4.6" From d65379219226475d5a2c068055d6af87f458dc0a Mon Sep 17 00:00:00 2001 From: Rishav C Date: Tue, 19 Aug 2025 12:59:30 -0400 Subject: [PATCH 2/2] chore: demonstrate FunctionTool use in a (temporal) tutorial --- .../010_agent_chat/project/workflow.py | 177 +++++++++++++++--- 1 file changed, 149 insertions(+), 28 deletions(-) diff --git a/examples/tutorials/10_agentic/10_temporal/010_agent_chat/project/workflow.py b/examples/tutorials/10_agentic/10_temporal/010_agent_chat/project/workflow.py index a4444475..d4fe91d6 100644 --- a/examples/tutorials/10_agentic/10_temporal/010_agent_chat/project/workflow.py +++ b/examples/tutorials/10_agentic/10_temporal/010_agent_chat/project/workflow.py @@ -1,12 +1,12 @@ import os -from typing import Dict, List, override +import json +from typing import Dict, List, override, Any from dotenv import load_dotenv -from dotenv import load_dotenv from agentex.lib.utils.model_utils import BaseModel from mcp import StdioServerParameters from temporalio import workflow -from agents import ModelSettings +from agents import ModelSettings, RunContextWrapper from openai.types.shared import Reasoning from agentex.lib import adk @@ -14,30 +14,37 @@ from agentex.lib.core.temporal.workflows.workflow import BaseWorkflow from agentex.lib.core.temporal.types.workflow import SignalName from agentex.lib.utils.logging import make_logger -from agentex.lib.core.tracing.tracing_processor_manager import add_tracing_processor_config +from agentex.lib.core.tracing.tracing_processor_manager import ( + add_tracing_processor_config, +) from agentex.lib.types.tracing import SGPTracingProcessorConfig from agentex.lib.environment_variables import EnvironmentVariables from agentex.types.text_content import TextContent +from agentex.lib.core.temporal.activities.adk.providers.openai_activities import ( + FunctionTool, +) environment_variables = EnvironmentVariables.refresh() load_dotenv(dotenv_path=".env") -add_tracing_processor_config(SGPTracingProcessorConfig( - sgp_api_key=os.environ.get("SCALE_GP_API_KEY", ""), - sgp_account_id=os.environ.get("SCALE_GP_ACCOUNT_ID", ""), -)) +add_tracing_processor_config( + SGPTracingProcessorConfig( + sgp_api_key=os.environ.get("SCALE_GP_API_KEY", ""), + sgp_account_id=os.environ.get("SCALE_GP_ACCOUNT_ID", ""), + ) +) -if environment_variables.WORKFLOW_NAME is None: +if not environment_variables.WORKFLOW_NAME: raise ValueError("Environment variable WORKFLOW_NAME is not set") -if environment_variables.AGENT_NAME is None: +if not environment_variables.AGENT_NAME: raise ValueError("Environment variable AGENT_NAME is not set") logger = make_logger(__name__) class StateModel(BaseModel): - input_list: List[Dict] + input_list: List[Dict[str, Any]] turn_number: int @@ -49,44 +56,139 @@ class StateModel(BaseModel): StdioServerParameters( command="uvx", args=["openai-websearch-mcp"], - env={ - "OPENAI_API_KEY": os.environ.get("OPENAI_API_KEY", "") - } + env={"OPENAI_API_KEY": os.environ.get("OPENAI_API_KEY", "")}, ), ] + +async def calculator(context: RunContextWrapper, args: str) -> str: + """ + Simple calculator that can perform basic arithmetic operations. + + Args: + context: The run context wrapper + args: JSON string containing the operation and operands + + Returns: + String representation of the calculation result + """ + try: + # Parse the JSON arguments + parsed_args = json.loads(args) + operation = parsed_args.get("operation") + a = parsed_args.get("a") + b = parsed_args.get("b") + + if operation is None or a is None or b is None: + return ( + "Error: Missing required parameters. " + "Please provide 'operation', 'a', and 'b'." + ) + + # Convert to numbers + try: + a = float(a) + b = float(b) + except (ValueError, TypeError): + return "Error: 'a' and 'b' must be valid numbers." + + # Perform the calculation + if operation == "add": + result = a + b + elif operation == "subtract": + result = a - b + elif operation == "multiply": + result = a * b + elif operation == "divide": + if b == 0: + return "Error: Division by zero is not allowed." + result = a / b + else: + supported_ops = "add, subtract, multiply, divide" + return ( + f"Error: Unknown operation '{operation}'. " + f"Supported operations: {supported_ops}." + ) + + # Format the result nicely + if result == int(result): + return f"The result of {a} {operation} {b} is {int(result)}" + else: + formatted = f"{result:.6f}".rstrip("0").rstrip(".") + return f"The result of {a} {operation} {b} is {formatted}" + + except json.JSONDecodeError: + return "Error: Invalid JSON format in arguments." + except Exception as e: + return f"Error: An unexpected error occurred: {str(e)}" + + +# Create the calculator tool +CALCULATOR_TOOL = FunctionTool( + name="calculator", + description=( + "Performs basic arithmetic operations (add, subtract, multiply, divide) " + "on two numbers." + ), + params_json_schema={ + "type": "object", + "properties": { + "operation": { + "type": "string", + "enum": ["add", "subtract", "multiply", "divide"], + "description": "The arithmetic operation to perform", + }, + "a": {"type": "number", "description": "The first number"}, + "b": {"type": "number", "description": "The second number"}, + }, + "required": ["operation", "a", "b"], + "additionalProperties": False, + }, + strict_json_schema=True, + on_invoke_tool=calculator, +) + + @workflow.defn(name=environment_variables.WORKFLOW_NAME) class At010AgentChatWorkflow(BaseWorkflow): """ Minimal async workflow template for AgentEx Temporal agents. """ + def __init__(self): super().__init__(display_name=environment_variables.AGENT_NAME) self._complete_task = False - self._state = None + self._state: StateModel | None = None @workflow.signal(name=SignalName.RECEIVE_EVENT) @override async def on_task_event_send(self, params: SendEventParams) -> None: logger.info(f"Received task message instruction: {params}") - + if not params.event.content: return if params.event.content.type != "text": raise ValueError(f"Expected text message, got {params.event.content.type}") if params.event.content.author != "user": - raise ValueError(f"Expected user message, got {params.event.content.author}") - + raise ValueError( + f"Expected user message, got {params.event.content.author}" + ) + + if self._state is None: + raise ValueError("State is not initialized") + # Increment the turn number self._state.turn_number += 1 # Add the new user message to the message history - self._state.input_list.append({"role": "user", "content": params.event.content.content}) + self._state.input_list.append( + {"role": "user", "content": params.event.content.content} + ) async with adk.tracing.span( trace_id=params.task.id, name=f"Turn {self._state.turn_number}", - input=self._state + input=self._state, ) as span: # Echo back the user's message so it shows up in the UI. This is not done by default so the agent developer has full control over what is shown to the user. await adk.messages.create( @@ -102,7 +204,15 @@ async def on_task_event_send(self, params: SendEventParams) -> None: trace_id=params.task.id, content=TextContent( author="agent", - content="Hey, sorry I'm unable to respond to your message because you're running this example without an OpenAI API key. Please set the OPENAI_API_KEY environment variable to run this example. Do this by either by adding a .env file to the project/ directory or by setting the environment variable in your terminal.", + content=( + "Hey, sorry I'm unable to respond to your message " + "because you're running this example without an " + "OpenAI API key. Please set the OPENAI_API_KEY " + "environment variable to run this example. Do this " + "by either by adding a .env file to the project/ " + "directory or by setting the environment variable " + "in your terminal." + ), ), parent_span_id=span.id if span else None, ) @@ -115,9 +225,14 @@ async def on_task_event_send(self, params: SendEventParams) -> None: input_list=self._state.input_list, mcp_server_params=MCP_SERVERS, agent_name="Tool-Enabled Assistant", - agent_instructions="""You are a helpful assistant that can answer questions using various tools. - You have access to sequential thinking and web search capabilities through MCP servers. - Use these tools when appropriate to provide accurate and well-reasoned responses.""", + agent_instructions=( + "You are a helpful assistant that can answer questions " + "using various tools. You have access to sequential " + "thinking and web search capabilities through MCP servers, " + "as well as a calculator tool for performing basic " + "arithmetic operations. Use these tools when appropriate " + "to provide accurate and well-reasoned responses." + ), parent_span_id=span.id if span else None, model="o4-mini", model_settings=ModelSettings( @@ -125,12 +240,18 @@ async def on_task_event_send(self, params: SendEventParams) -> None: # response_include=["reasoning.encrypted_content"], # Ask the model to include a short reasoning summary reasoning=Reasoning(effort="medium", summary="auto"), - ) + ), + tools=[CALCULATOR_TOOL], ) - self._state.input_list = run_result.final_input_list + if self._state: + # Update the state with the final input list if available + final_list = getattr(run_result, "final_input_list", None) + if final_list is not None: + self._state.input_list = final_list # Set the span output to the state for the next turn - span.output = self._state + if span and self._state: + span.output = self._state.model_dump() @workflow.run @override @@ -151,5 +272,5 @@ async def on_task_create(self, params: CreateTaskParams) -> None: await workflow.wait_condition( lambda: self._complete_task, - timeout=None, # Set a timeout if you want to prevent the task from running indefinitely. Generally this is not needed. Temporal can run hundreds of millions of workflows in parallel and more. Only do this if you have a specific reason to do so. + timeout=None, # Set a timeout if you want to prevent the task from running indefinitely. Generally this is not needed. Temporal can run hundreds of millions of workflows in parallel and more. Only do this if you have a specific reason to do so. )