diff --git a/src/agentex/lib/core/clients/temporal/utils.py b/src/agentex/lib/core/clients/temporal/utils.py index ffdb04b6..9be7cf5c 100644 --- a/src/agentex/lib/core/clients/temporal/utils.py +++ b/src/agentex/lib/core/clients/temporal/utils.py @@ -3,6 +3,7 @@ from typing import Any from temporalio.client import Client, Plugin as ClientPlugin +from temporalio.worker import Interceptor from temporalio.runtime import Runtime, TelemetryConfig, OpenTelemetryConfig from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.contrib.openai_agents import OpenAIAgentsPlugin @@ -61,6 +62,24 @@ def validate_client_plugins(plugins: list[Any]) -> None: ) +def validate_worker_interceptors(interceptors: list[Any]) -> None: + """ + Validate that all items in the interceptors list are valid Temporal worker interceptors. + + Args: + interceptors: List of interceptors to validate + + Raises: + TypeError: If any interceptor is not a valid Interceptor instance + """ + for i, interceptor in enumerate(interceptors): + if not isinstance(interceptor, Interceptor): + raise TypeError( + f"Interceptor at index {i} must be an instance of temporalio.worker.Interceptor, " + f"got {type(interceptor).__name__}" + ) + + async def get_temporal_client(temporal_address: str, metrics_url: str | None = None, plugins: list[Any] = []) -> Client: """ Create a Temporal client with plugin integration. diff --git a/src/agentex/lib/core/temporal/plugins/__init__.py b/src/agentex/lib/core/temporal/plugins/__init__.py new file mode 100644 index 00000000..52ab6eac --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/__init__.py @@ -0,0 +1,58 @@ +"""OpenAI Agents SDK Temporal Plugin with Streaming Support. + +This module provides streaming capabilities for the OpenAI Agents SDK in Temporal +using interceptors to thread task_id through workflows to activities. + +The streaming implementation works by: +1. Using Temporal interceptors to thread task_id through the execution +2. Streaming LLM responses to Redis in real-time from activities +3. Returning complete responses to maintain Temporal determinism + +Example: + >>> from agentex.lib.core.temporal.plugins.openai_agents import ( + ... TemporalStreamingModelProvider, + ... TemporalTracingModelProvider, + ... ContextInterceptor, + ... ) + >>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters + >>> from datetime import timedelta + >>> + >>> # Create streaming model provider + >>> model_provider = TemporalStreamingModelProvider() + >>> + >>> # Create STANDARD plugin with streaming model provider + >>> plugin = OpenAIAgentsPlugin( + ... model_params=ModelActivityParameters( + ... start_to_close_timeout=timedelta(seconds=120), + ... ), + ... model_provider=model_provider, + ... 
) + >>> + >>> # Register interceptor with worker + >>> interceptor = ContextInterceptor() + >>> # Add interceptor to worker configuration +""" + +from agentex.lib.core.temporal.plugins.openai_agents import ( + ContextInterceptor, + TemporalStreamingHooks, + TemporalStreamingModel, + TemporalTracingModelProvider, + TemporalStreamingModelProvider, + streaming_task_id, + streaming_trace_id, + stream_lifecycle_content, + streaming_parent_span_id, +) + +__all__ = [ + "TemporalStreamingModel", + "TemporalStreamingModelProvider", + "TemporalTracingModelProvider", + "ContextInterceptor", + "streaming_task_id", + "streaming_trace_id", + "streaming_parent_span_id", + "TemporalStreamingHooks", + "stream_lifecycle_content", +] \ No newline at end of file diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/README.md b/src/agentex/lib/core/temporal/plugins/openai_agents/README.md new file mode 100644 index 00000000..5497c466 --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/README.md @@ -0,0 +1,750 @@ +# Temporal + OpenAI Agents SDK Streaming Implementation + +## TL;DR + +We use Temporal interceptors to add real-time streaming to Redis/UI while maintaining workflow determinism with the STANDARD OpenAI Agents plugin. The key challenge was threading `task_id` (only known at runtime) through a plugin system initialized at startup. We solved this using Temporal's interceptor pattern to inject task_id into activity headers, making it available via context variables in the model. + +**What we built**: Real-time streaming of LLM responses to users while preserving Temporal's durability guarantees. + +**How**: Interceptors thread task_id → Model reads from context → stream to Redis during activity → return complete response for determinism. + +**The win**: NO forked plugin needed - uses standard `temporalio.contrib.openai_agents.OpenAIAgentsPlugin`! + +## Table of Contents +1. [Background: How OpenAI Agents SDK Works](#background-how-openai-agents-sdk-works) +2. [How Temporal's OpenAI Plugin Works](#how-temporals-openai-plugin-works) +3. [The Streaming Challenge](#the-streaming-challenge) +4. [Our Streaming Solution](#our-streaming-solution) +5. [Implementation Details](#implementation-details) +6. [Usage](#usage) +7. [Drawbacks and Maintenance](#drawbacks-and-maintenance) + +--- + +## Background: How OpenAI Agents SDK Works + +Before diving into Temporal integration, let's understand the basic OpenAI Agents SDK flow: + +```python +# Standard OpenAI Agents SDK usage +agent = Agent( + name="Assistant", + model="gpt-4", + instructions="You are a helpful assistant" +) + +# Under the hood, this happens: +runner = AgentRunner() +result = await runner.run(agent, "Hello") +# ↓ +# runner.run() calls agent.model.get_response() +# ↓ +# model.get_response() makes the actual LLM API call to OpenAI +``` + +The key insight: **`model.get_response()`** is where the actual LLM call happens. + +--- + +## How Temporal's OpenAI Plugin Works + +The Temporal plugin intercepts this flow to make LLM calls durable by converting them into Temporal activities. Here's how: + +### 1. Plugin Setup and Runner Override + +When you create the Temporal plugin and pass it to the worker: + +```python +# In _temporal_openai_agents.py (lines ~72-112) +@contextmanager +def set_open_ai_agent_temporal_overrides(model_params): + # This is the critical line - replaces the default runner! + set_default_agent_runner(TemporalOpenAIRunner(model_params)) +``` + +### 2. 
Model Interception Chain + +Here's the clever interception that happens: + +``` +Original OpenAI SDK Flow: +┌─────────┐ ┌──────────────┐ ┌───────────────────┐ ┌────────────┐ +│ Agent │ --> │ Runner.run() │ --> │ Model.get_response│ --> │ OpenAI API │ +└─────────┘ └──────────────┘ └───────────────────┘ └────────────┘ + +Temporal Plugin Flow: +┌─────────┐ ┌────────────────────┐ ┌──────────────────────┐ +│ Agent │ --> │ TemporalRunner.run │ --> │ _TemporalModelStub │ +└─────────┘ └────────────────────┘ │ .get_response() │ + └──────────┬───────────┘ + ↓ + ┌──────────────────────┐ + │ Temporal Activity │ + │ "invoke_model_activity"│ + └──────────┬───────────┘ + ↓ + ┌──────────────────────┐ ┌────────────┐ + │ Model.get_response() │ --> │ OpenAI API │ + └──────────────────────┘ └────────────┘ +``` + +### 3. The Model Stub Trick + +The `TemporalOpenAIRunner` replaces the agent's model with `_TemporalModelStub`: + +```python +# In _openai_runner.py +def _convert_agent(agent): + # Replace the model with a stub + new_agent.model = _TemporalModelStub( + model_name=agent.model, + model_params=model_params + ) + return new_agent +``` + +### 4. Activity Creation + +The `_TemporalModelStub` doesn't call the LLM directly. Instead, it creates a Temporal activity: + +```python +# In _temporal_model_stub.py +class _TemporalModelStub: + async def get_response(self, ...): + # Instead of calling the LLM, create an activity! + return await workflow.execute_activity_method( + ModelActivity.invoke_model_activity, # ← This becomes visible in Temporal UI + activity_input, + ... + ) +``` + +### 5. Actual LLM Call in Activity + +Finally, inside the activity, the real LLM call happens: + +```python +# In _invoke_model_activity.py +class ModelActivity: + async def invoke_model_activity(self, input): + model = self._model_provider.get_model(input["model_name"]) + # NOW we actually call the LLM + return await model.get_response(...) # ← Real OpenAI API call +``` + +**Summary**: The plugin intercepts at TWO levels: +1. **Runner level**: Replaces default runner with TemporalRunner +2. **Model level**: Replaces agent.model with _TemporalModelStub that creates activities + +--- + +## The Streaming Challenge + +### Why Temporal Doesn't Support Streaming by Default + +Temporal's philosophy is that activities should be: +- **Idempotent**: Same input → same output +- **Retriable**: Can restart from beginning on failure +- **Deterministic**: Replays produce identical results + +Streaming breaks these guarantees: +- If streaming fails halfway, where do you restart? +- How do you replay a stream deterministically? +- Partial responses violate idempotency + +### Why We Need Streaming Anyway + +For Scale/AgentEx customers, **latency is critical**: +- Time to first token matters more than total generation time +- Users expect to see responses as they're generated +- 10-30 second waits for long responses are unacceptable + +Our pragmatic decision: **Accept the tradeoff**. If streaming fails midway, we restart from the beginning. This may cause a brief UX hiccup but enables the streaming experience users expect. + +--- + +## Our Streaming Solution + +### The Key Insight: Where We Can Hook In + +When we instantiate the OpenAI plugin for Temporal, we can pass in a **model provider**: + +```python +plugin = OpenAIAgentsPlugin( + model_provider=StreamingModelProvider() # ← This is our hook! +) +``` + +**IMPORTANT**: This model provider returns the ACTUAL model that makes the LLM call - this is the final layer, NOT the stub. 
This is where `model.get_response()` actually calls OpenAI's API. By providing our own model here, we can: + +1. Make the same OpenAI chat completion call with `stream=True` +2. Capture chunks as they arrive +3. Stream them to Redis +4. Still return the complete response for Temporal + +Our `StreamingModel` implementation: +1. **Streams to Redis** using XADD commands +2. **Returns complete response** to maintain Temporal determinism + +### The Task ID Problem + +Here's the critical issue we had to solve: + +``` +Timeline of Execution: +═══════════════════════════════════════════════════════════════════ +Time T0: Application Startup + plugin = CustomStreamingOpenAIAgentsPlugin( + model_provider=StreamingModelProvider() ← No task_id exists yet! + ) + +Time T1: Worker Creation + worker = Worker(plugins=[plugin]) ← Still no task_id! + +Time T2: Worker Starts + await worker.run() ← Still no task_id! + +Time T3: Workflow Receives Request + @workflow.defn + async def on_task_create(params): + task_id = params.task.id ← task_id CREATED HERE! 🎯 + +Time T4: Model Needs to Stream + StreamingModel.get_response(...?) ← Need task_id but how?! +═══════════════════════════════════════════════════════════════════ +``` + +**The problem**: The model provider is configured before we know the task_id, but streaming requires task_id to route to the correct Redis channel. + +### Our Solution: Temporal Interceptors + Context Variables + +Instead of forking the plugin, we use Temporal's interceptor pattern to thread task_id through the system. This elegant solution uses standard Temporal features and requires NO custom plugin components! + +Here's exactly how task_id flows through the interceptor chain: + +``` +┌──────────────────────────────────────────────────────────────────┐ +│ WORKFLOW EXECUTION │ +│ self._task_id = params.task.id <-- Store in instance variable │ +└────────────────────────────┬─────────────────────────────────────┘ + ↓ workflow.instance() +┌──────────────────────────────────────────────────────────────────┐ +│ StreamingWorkflowOutboundInterceptor │ +│ • Reads _task_id from workflow.instance() │ +│ • Injects into activity headers │ +└────────────────────────────┬─────────────────────────────────────┘ + ↓ headers["streaming-task-id"]="abc123" +┌──────────────────────────────────────────────────────────────────┐ +│ STANDARD Temporal Plugin │ +│ • Uses standard TemporalRunner (no fork!) │ +│ • Uses standard TemporalModelStub (no fork!) │ +│ • Creates standard invoke_model_activity │ +└────────────────────────────┬─────────────────────────────────────┘ + ↓ activity with headers +┌──────────────────────────────────────────────────────────────────┐ +│ StreamingActivityInboundInterceptor │ +│ • Extracts task_id from headers │ +│ • Sets streaming_task_id ContextVar │ +└────────────────────────────┬─────────────────────────────────────┘ + ↓ streaming_task_id.set("abc123") +┌──────────────────────────────────────────────────────────────────┐ +│ StreamingModel.get_response() │ +│ • Reads task_id from streaming_task_id.get() │ +│ • Streams chunks to Redis channel: "stream:abc123" │ +│ • Returns complete response for Temporal │ +└──────────────────────────────────────────────────────────────────┘ + ↓ +┌──────────────────────────────────────────────────────────────────┐ +│ REDIS │ +│ XADD stream:abc123 chunk1, chunk2, chunk3... 
│ +└────────────────────────────┬─────────────────────────────────────┘ + ↓ +┌──────────────────────────────────────────────────────────────────┐ +│ UI SUBSCRIBER │ +│ Reads from stream:abc123 and displays real-time updates │ +└──────────────────────────────────────────────────────────────────┘ +``` + +--- + +## Implementation Details + +### The Interceptor Approach - Clean and Maintainable + +Instead of forking components, we use Temporal's interceptor system. Here's what we built: + +### 1. StreamingInterceptor - The Main Component + +```python +# streaming_interceptor.py +class StreamingInterceptor(Interceptor): + """Main interceptor that enables task_id threading.""" + + def intercept_activity(self, next): + # Create activity interceptor to extract headers + return StreamingActivityInboundInterceptor(next, self._payload_converter) + + def workflow_interceptor_class(self, input): + # Return workflow interceptor class + return StreamingWorkflowInboundInterceptor +``` + +### 2. Task ID Flow - Using Standard Components + +Here's EXACTLY how task_id flows through the system without any forked components: + +#### Step 1: Workflow stores task_id in instance variable +```python +# workflow.py +self._task_id = params.task.id # Store in instance variable +result = await Runner.run(agent, input) # No context needed! +``` + +#### Step 2: Outbound Interceptor injects task_id into headers +```python +# StreamingWorkflowOutboundInterceptor +def start_activity(self, input): + workflow_instance = workflow.instance() + task_id = getattr(workflow_instance, '_task_id', None) + if task_id and "invoke_model_activity" in str(input.activity): + input.headers["streaming-task-id"] = self._payload_converter.to_payload(task_id) +``` + +#### Step 3: Inbound Interceptor extracts from headers and sets context +```python +# StreamingActivityInboundInterceptor +async def execute_activity(self, input): + if input.headers and "streaming-task-id" in input.headers: + task_id = self._payload_converter.from_payload(input.headers["streaming-task-id"], str) + streaming_task_id.set(task_id) # Set ContextVar! +``` + +#### Step 4: StreamingModel reads from context variable +```python +# StreamingModel.get_response() +from agentex.lib.core.temporal.plugins.openai_agents.streaming_interceptor import ( + streaming_task_id, + streaming_trace_id, + streaming_parent_span_id +) + +async def get_response(self, ...): + # Read from ContextVar - set by interceptor! + task_id = streaming_task_id.get() + trace_id = streaming_trace_id.get() + parent_span_id = streaming_parent_span_id.get() + + if task_id: + # Open streaming context to Redis + async with adk.streaming.streaming_task_message_context( + task_id=task_id, + ... + ) as streaming_context: + # Stream tokens as they arrive + ... +``` + +### 3. Worker Configuration - Simply Add the Interceptor + +```python +# run_worker.py +from temporalio.contrib.openai_agents import OpenAIAgentsPlugin # STANDARD! +from agentex.lib.core.temporal.plugins.openai_agents import ( + StreamingInterceptor, + StreamingModelProvider, +) + +# Create the interceptor +interceptor = StreamingInterceptor() + +# Use STANDARD plugin with streaming model provider +plugin = OpenAIAgentsPlugin( + model_provider=StreamingModelProvider(), + model_params=ModelActivityParameters(...) +) + +# Create worker with interceptor +worker = Worker( + client, + task_queue="example_tutorial_queue", + workflows=[ExampleTutorialWorkflow], + activities=[...], + interceptors=[interceptor], # Just add interceptor! +) +``` + +### 4. 
The Streaming Model - Where Magic Happens + +This is where the actual streaming happens. Our `StreamingModel` is what gets called inside the activity: + +```python +# streaming_model.py +class StreamingModel(Model): + async def get_response(self, ..., task_id=None): + # 1. Open Redis streaming context with task_id + async with adk.streaming.streaming_task_message_context( + task_id=task_id, # ← This creates Redis channel stream:abc123 + initial_content=TextContent(author="agent", content="") + ) as streaming_context: + + # 2. Make OpenAI call WITH STREAMING + stream = await self.client.chat.completions.create( + model=self.model_name, + messages=messages, + stream=True, # ← Enable streaming! + # ... other params ... + ) + + # 3. Process chunks as they arrive + full_content = "" + async for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + full_content += content + + # 4. Stream to Redis (UI sees this immediately!) + delta = TextDelta(type="text", text_delta=content) + update = StreamTaskMessageDelta( + parent_task_message=streaming_context.task_message, + delta=delta, + type="delta" + ) + await streaming_context.stream_update(update) + + # 5. Handle tool calls (sent as complete messages, not streamed) + if tool_calls: + for tool_call_data in tool_calls.values(): + tool_request = ToolRequestContent( + author="agent", + tool_call_id=tool_call_data["id"], + name=tool_call_data["function"]["name"], + arguments=json.loads(tool_call_data["function"]["arguments"]) + ) + + # Tool calls use StreamTaskMessageFull (complete message) + async with adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=tool_request + ) as tool_context: + await tool_context.stream_update( + StreamTaskMessageFull( + parent_task_message=tool_context.task_message, + content=tool_request, + type="full" + ) + ) + + # 6. Handle reasoning tokens (o1 models) + if reasoning_content: # For o1 models + reasoning = ReasoningContent( + author="agent", + summary=[reasoning_content], + type="reasoning" + ) + # Stream reasoning as complete message + await stream_reasoning_update(reasoning) + + # 7. Context auto-closes and saves to DB + # The streaming_task_message_context: + # - Accumulates all chunks + # - Saves complete message to database + # - Sends DONE signal to Redis + + # 8. Return complete response for Temporal determinism + return ModelResponse( + output=output_items, # Complete response + usage=usage, + response_id=completion_id + ) +``` + +### 5. Redis and AgentEx Streaming Infrastructure + +Here's what happens under the hood with AgentEx's streaming system: + +#### Redis Implementation Details + +1. **Channel Creation**: `stream:{task_id}` - Each task gets its own Redis stream +2. **XADD Commands**: Each chunk is appended using Redis XADD +3. **Message Types**: + - `StreamTaskMessageDelta`: For text chunks (token by token) + - `StreamTaskMessageFull`: For complete messages (tool calls, reasoning) +4. **Auto-accumulation**: The streaming context accumulates all chunks +5. **Database Persistence**: Complete message saved to DB when context closes +6. 
**DONE Signal**: Sent to Redis when streaming completes
+
+#### What Gets Streamed
+
+```python
+# Text content - streamed token by token
+await streaming_context.stream_update(
+    StreamTaskMessageDelta(delta=TextDelta(text_delta=chunk))
+)
+
+# Tool calls - sent as complete messages
+await streaming_context.stream_update(
+    StreamTaskMessageFull(content=ToolRequestContent(...))
+)
+
+# Reasoning (o1 models) - sent as complete
+await streaming_context.stream_update(
+    StreamTaskMessageFull(content=ReasoningContent(...))
+)
+
+# Guardrails - sent as complete
+await streaming_context.stream_update(
+    StreamTaskMessageFull(content=GuardrailContent(...))
+)
+```
+
+#### UI Subscription
+
+The frontend subscribes to `stream:{task_id}` and receives:
+1. Real-time text chunks as they're generated
+2. Complete tool calls when they're ready
+3. Reasoning summaries for o1 models
+4. DONE signal when complete
+
+This decoupling means we can stream anything we want through Redis!
+
+### 6. Workflow Integration
+
+```python
+# workflow.py
+@workflow.defn
+class ExampleWorkflow:
+    async def on_task_event_send(self, params):
+        # Store task_id in an instance variable so the outbound
+        # interceptor can inject it into activity headers
+        self._task_id = params.task.id  # ← Critical line!
+
+        result = await Runner.run(agent, input)  # Standard runner - no context dict needed
+```
+
+---
+
+## Usage
+
+### Installation
+
+This plugin is included in the agentex-python package. No additional installation is needed.
+
+### Basic Setup
+
+```python
+from agentex.lib.core.temporal.plugins.openai_agents import (
+    StreamingInterceptor,
+    StreamingModelProvider,
+)
+from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters
+from temporalio.client import Client
+from temporalio.worker import Worker
+from datetime import timedelta
+
+# Create streaming model provider
+model_provider = StreamingModelProvider()
+
+# Create the STANDARD plugin with the streaming model provider
+plugin = OpenAIAgentsPlugin(
+    model_params=ModelActivityParameters(
+        start_to_close_timeout=timedelta(seconds=120),
+    ),
+    model_provider=model_provider,
+)
+
+# Use with Temporal client
+client = await Client.connect(
+    "localhost:7233",
+    plugins=[plugin]
+)
+
+# Create worker with the streaming interceptor registered
+worker = Worker(
+    client,
+    task_queue="my-task-queue",
+    workflows=[MyWorkflow],
+    interceptors=[StreamingInterceptor()],
+)
+```
+
+### In Your Workflow
+
+```python
+from agents import Agent, Runner
+from temporalio import workflow
+
+@workflow.defn
+class MyWorkflow:
+    @workflow.run
+    async def run(self, params):
+        # Create an agent
+        agent = Agent(
+            name="Assistant",
+            instructions="You are a helpful assistant",
+            model="gpt-4o",
+        )
+
+        # Store task_id in an instance variable so the interceptor can thread it
+        self._task_id = params.task.id
+
+        # Run the agent - streaming happens automatically!
+        result = await Runner.run(agent, params.event.content)
+
+        return result.final_output
+```
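+
+### Inspecting the Stream (Debugging Sketch)
+
+For reference, here is a minimal, hypothetical consumer for the Redis stream described above. It assumes only the redis-py asyncio client and the `stream:{task_id}` naming convention; the exact fields in each entry are whatever the AgentEx streaming context writes, so treat this as a debugging aid rather than the UI's actual subscriber.
+
+```python
+import asyncio
+
+import redis.asyncio as redis
+
+
+async def tail_task_stream(task_id: str, redis_url: str = "redis://localhost:6379") -> None:
+    """Print raw entries from stream:{task_id} as they are appended via XADD."""
+    client = redis.from_url(redis_url, decode_responses=True)
+    last_id = "0-0"  # start from the beginning; use "$" to see only new entries
+    while True:
+        # Block for up to 5 seconds waiting for new entries
+        batches = await client.xread({f"stream:{task_id}": last_id}, block=5000)
+        for _stream_name, entries in batches:
+            for entry_id, fields in entries:
+                last_id = entry_id
+                # Each entry carries a serialized task message update (delta, full, or DONE)
+                print(entry_id, fields)
+
+
+asyncio.run(tail_task_stream("abc123"))
+```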
+
+### Comparison with Original Temporal Plugin
+
+| Feature | Original Plugin | With Streaming (Interceptor Approach) |
+|---------|----------------|---------------------------------------|
+| **Response Time** | Complete response only (10-30s wait) | Real-time streaming (immediate feedback) |
+| **User Experience** | No feedback during generation | See response as it's generated |
+| **Task ID Support** | Not supported | Threaded at runtime via interceptor headers |
+| **Activity Name** | `invoke_model_activity` | `invoke_model_activity` (unchanged) |
+| **Model Stub** | `_TemporalModelStub` | `_TemporalModelStub` (unchanged) |
+| **Runner** | `TemporalOpenAIRunner` | `TemporalOpenAIRunner` (unchanged) |
+| **Model Provider** | OpenAI default | `StreamingModelProvider` (streams to Redis) |
+| **Redis Integration** | None | Full streaming via AgentEx ADK |
+| **Temporal Determinism** | ✅ Yes | ✅ Yes (returns complete response) |
+| **Replay Safety** | ✅ Yes | ✅ Yes (streaming is side-effect only) |
+
+---
+
+## Benefits of the Interceptor Approach
+
+### Major Advantages Over Forking
+
+1. **No Code Duplication**: Uses the standard `temporalio.contrib.openai_agents` plugin
+   - Automatic compatibility with Temporal updates
+   - No risk of divergence from upstream features
+   - Zero maintenance of forked code
+
+2. **Clean Architecture**:
+   - Interceptors are Temporal's official extension mechanism
+   - Clear separation between streaming logic and the core plugin
+   - Easy to enable or disable streaming by adding or removing the interceptor
+
+3. **Simplicity**:
+   - A single interceptor handles all task_id threading
+   - Uses Python's ContextVar for thread-safe async state
+   - No need to understand Temporal plugin internals
+
+### Minimal Limitations
+
+1. **Streaming Semantics** (unchanged):
+   - On failure, streaming restarts from the beginning (users may briefly see duplicate partial content)
+   - This is an acceptable tradeoff for the improved experience
+
+2. **Worker Configuration**:
+   - The interceptor must be registered with the worker
+   - The workflow must store task_id in an instance variable
+
+### Future Improvements
+
+1. **Contribute Back**:
+   - This pattern could be contributed to Temporal as an example
+   - It shows how to extend plugins without forking
+
+2. **Enhanced Features**:
+   - Request/response interceptors could be added for other use cases
+   - The pattern works for any runtime context threading need
+
+### Alternative Approaches Considered
+
+1. **Workflow-level streaming**: Stream directly from the workflow (violates determinism)
+2. **Separate streaming service**: Additional infrastructure complexity
+3. **Polling pattern**: Poor latency characteristics
+4. **WebSockets**: Doesn't integrate with the existing AgentEx infrastructure
+
+---
+
+## Key Innovation
+
+The most important innovation is **using interceptors for runtime context threading**. Instead of forking the plugin to pass task_id through custom components, we use Temporal's interceptor system with Python's ContextVar. This allows:
+
+- One plugin instance for all workflows (the standard plugin!)
+- Dynamic streaming channels per execution
+- Clean separation of concerns
+- No forked components to maintain
+- Thread-safe async context propagation
+- Compatible with all Temporal updates
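+
+One practical note on the current implementation: the outbound interceptor only injects headers when `_task_id`, `_trace_id`, and `_parent_span_id` are all set on the workflow instance (it looks them up with `getattr`). A minimal sketch of a workflow that satisfies that check is below - the `params.trace_id` and `params.parent_span_id` fields are hypothetical placeholders for wherever your tracing setup actually sources those values.
+
+```python
+from agents import Agent, Runner
+from temporalio import workflow
+
+
+@workflow.defn
+class StreamingWorkflow:
+    @workflow.run
+    async def run(self, params):
+        # The outbound interceptor reads these three attributes via getattr()
+        # and skips header injection unless all of them are truthy.
+        self._task_id = params.task.id
+        self._trace_id = params.trace_id              # hypothetical source
+        self._parent_span_id = params.parent_span_id  # hypothetical source
+
+        agent = Agent(
+            name="Assistant",
+            instructions="You are a helpful assistant",
+            model="gpt-4o",
+        )
+        result = await Runner.run(agent, params.event.content)
+        return result.final_output
+```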
+
+---
+
+## Troubleshooting
+
+**No streaming visible in UI:**
+- Ensure the workflow stores the task ID in an instance variable: `self._task_id = params.task.id`
+- Verify Redis is running and accessible
+- Check that the UI is subscribed to the correct task channel
+
+**Import errors:**
+- Make sure agentex-python/src is in your Python path
+- Install required dependencies: `uv add agentex-sdk openai-agents temporalio`
+
+**Activity not found:**
+- Ensure the plugin is registered with both the client and the worker
+- Check that the plugin's `invoke_model_activity` is registered on the worker
+
+---
+
+## Testing
+
+### Running Tests
+
+The streaming model implementation has comprehensive tests in `tests/test_streaming_model.py` that verify all configurations, tool types, and edge cases.
+
+#### From Repository Root
+
+```bash
+# Run all tests
+rye run pytest src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py -v
+
+# Run without parallel execution (more stable)
+rye run pytest src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py -v -n0
+
+# Run a specific test
+rye run pytest src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py::TestStreamingModelSettings::test_temperature_setting -v
+```
+
+#### From Test Directory
+
+```bash
+cd src/agentex/lib/core/temporal/plugins/openai_agents/tests
+
+# Run all tests
+rye run pytest test_streaming_model.py -v
+
+# Run without parallel execution (recommended)
+rye run pytest test_streaming_model.py -v -n0
+
+# Run a specific test class
+rye run pytest test_streaming_model.py::TestStreamingModelSettings -v
+```
+
+#### Test Coverage
+
+The test suite covers:
+- **ModelSettings**: All configuration parameters (temperature, reasoning, truncation, etc.)
+- **Tool Types**: Function tools, web search, file search, computer tools, MCP tools, etc.
+- **Streaming**: Redis context creation, task ID threading, error handling
+- **Edge Cases**: Missing task IDs, multiple computer tools, handoffs
+
+**Note**: Running without parallel execution (the `-n0` flag) is more reliable and avoids state pollution between test workers. All 29 tests pass individually; parallel execution may show 4-6 intermittent failures due to shared mock state.
+
+---
+
+## Conclusion
+
+This implementation uses Temporal interceptors to thread task_id through the standard OpenAI plugin, enabling real-time streaming while maintaining workflow determinism. The key innovation is using interceptors with Python's ContextVar to propagate runtime context without forking any Temporal components.
+
+This approach provides the optimal user experience with:
+- **Zero code duplication** - uses the standard Temporal plugin
+- **Minimal maintenance** - only the interceptor and streaming model to maintain
+- **Clean architecture** - leverages Temporal's official extension mechanism
+- **Full compatibility** - works with all Temporal and OpenAI SDK updates
+
+The interceptor pattern demonstrates how to extend Temporal plugins without forking, setting a precedent for future enhancements. 
\ No newline at end of file diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/__init__.py b/src/agentex/lib/core/temporal/plugins/openai_agents/__init__.py new file mode 100644 index 00000000..def67c9a --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/__init__.py @@ -0,0 +1,84 @@ +"""OpenAI Agents SDK Temporal Plugin with Streaming Support. + +This module provides streaming capabilities for the OpenAI Agents SDK in Temporal +using interceptors to thread task_id through workflows to activities. + +The streaming implementation works by: +1. Using Temporal interceptors to thread task_id through the execution +2. Streaming LLM responses to Redis in real-time from activities +3. Streaming lifecycle events (tool calls, handoffs) via hooks and activities +4. Returning complete responses to maintain Temporal determinism + +Example - Complete Setup: + >>> from agentex.lib.core.temporal.plugins.openai_agents import ( + ... StreamingModelProvider, + ... TemporalStreamingHooks, + ... ContextInterceptor, + ... ) + >>> from temporalio.contrib.openai_agents import OpenAIAgentsPlugin, ModelActivityParameters + >>> from datetime import timedelta + >>> from agents import Agent, Runner + >>> + >>> # 1. Create streaming model provider + >>> model_provider = StreamingModelProvider() + >>> + >>> # 2. Create STANDARD plugin with streaming model provider + >>> plugin = OpenAIAgentsPlugin( + ... model_params=ModelActivityParameters( + ... start_to_close_timeout=timedelta(seconds=120), + ... ), + ... model_provider=model_provider, + ... ) + >>> + >>> # 3. Register interceptor with worker + >>> interceptor = ContextInterceptor() + >>> # Add interceptor to worker configuration + >>> + >>> # 4. In workflow, store task_id in instance variable + >>> self._task_id = params.task.id + >>> + >>> # 5. Create hooks for streaming lifecycle events + >>> hooks = TemporalStreamingHooks(task_id="your-task-id") + >>> + >>> # 6. 
Run agent - interceptor handles task_id threading automatically + >>> result = await Runner.run(agent, input, hooks=hooks) + +This gives you: +- Real-time streaming of LLM responses (via StreamingModel + interceptors) +- Real-time streaming of tool calls (via TemporalStreamingHooks) +- Real-time streaming of agent handoffs (via TemporalStreamingHooks) +- Full Temporal durability and observability +- No forked plugin required - uses standard OpenAIAgentsPlugin +""" + +from agentex.lib.core.temporal.plugins.openai_agents.hooks.hooks import ( + TemporalStreamingHooks, +) +from agentex.lib.core.temporal.plugins.openai_agents.hooks.activities import ( + stream_lifecycle_content, +) +from agentex.lib.core.temporal.plugins.openai_agents.models.temporal_tracing_model import ( + TemporalTracingModelProvider, +) +from agentex.lib.core.temporal.plugins.openai_agents.models.temporal_streaming_model import ( + TemporalStreamingModel, + TemporalStreamingModelProvider, +) +from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import ( + ContextInterceptor, + streaming_task_id, + streaming_trace_id, + streaming_parent_span_id, +) + +__all__ = [ + "TemporalStreamingModel", + "TemporalStreamingModelProvider", + "TemporalTracingModelProvider", + "ContextInterceptor", + "streaming_task_id", + "streaming_trace_id", + "streaming_parent_span_id", + "TemporalStreamingHooks", + "stream_lifecycle_content", +] \ No newline at end of file diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/__init__.py b/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/__init__.py new file mode 100644 index 00000000..7a01e3f5 --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/__init__.py @@ -0,0 +1,17 @@ +"""Temporal streaming hooks and activities for OpenAI Agents SDK. + +This module provides hooks for streaming agent lifecycle events and +activities for streaming content to the AgentEx UI. +""" + +from agentex.lib.core.temporal.plugins.openai_agents.hooks.hooks import ( + TemporalStreamingHooks, +) +from agentex.lib.core.temporal.plugins.openai_agents.hooks.activities import ( + stream_lifecycle_content, +) + +__all__ = [ + "TemporalStreamingHooks", + "stream_lifecycle_content", +] \ No newline at end of file diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/activities.py b/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/activities.py new file mode 100644 index 00000000..bcd82385 --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/activities.py @@ -0,0 +1,78 @@ +"""Temporal activities for streaming agent lifecycle events. + +This module provides reusable Temporal activities for streaming content +to the AgentEx UI, designed to work with TemporalStreamingHooks. +""" + +from typing import Union + +from temporalio import activity + +from agentex.lib import adk +from agentex.types.text_content import TextContent +from agentex.types.task_message_update import StreamTaskMessageFull +from agentex.types.task_message_content import ( + TaskMessageContent, + ToolRequestContent, + ToolResponseContent, +) + + +@activity.defn(name="stream_lifecycle_content") +async def stream_lifecycle_content( + task_id: str, + content: Union[TextContent, ToolRequestContent, ToolResponseContent, TaskMessageContent], +) -> None: + """Stream agent lifecycle content to the AgentEx UI. 
+ + This is a universal streaming activity that can handle any type of agent + lifecycle content (text messages, tool requests, tool responses, etc.). + It uses the AgentEx streaming context to send updates to the UI in real-time. + + Designed to work seamlessly with TemporalStreamingHooks. The hooks class + will call this activity automatically when lifecycle events occur. + + Args: + task_id: The AgentEx task ID for routing the content to the correct UI session + content: The content to stream - can be any of: + - TextContent: Plain text messages (e.g., handoff notifications) + - ToolRequestContent: Tool invocation requests with call_id and name + - ToolResponseContent: Tool execution results with call_id and output + - TaskMessageContent: Generic task message content + + Example: + Register this activity with your Temporal worker:: + + from agentex.lib.core.temporal.plugins.openai_agents import ( + TemporalStreamingHooks, + stream_lifecycle_content, + ) + + # In your workflow + hooks = TemporalStreamingHooks( + task_id=params.task.id, + stream_activity=stream_lifecycle_content + ) + result = await Runner.run(agent, input, hooks=hooks) + + Note: + This activity is non-blocking and will not throw exceptions to the workflow. + Any streaming errors are logged but do not fail the activity. This ensures + that streaming failures don't break the agent execution. + """ + try: + async with adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=content, + ) as streaming_context: + # Send the content as a full message update + await streaming_context.stream_update( + StreamTaskMessageFull( + parent_task_message=streaming_context.task_message, + content=content, + type="full", + ) + ) + except Exception as e: + # Log error but don't fail the activity - streaming failures shouldn't break execution + activity.logger.warning(f"Failed to stream content to task {task_id}: {e}") diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py b/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py new file mode 100644 index 00000000..2d2765fb --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/hooks/hooks.py @@ -0,0 +1,201 @@ +"""Temporal streaming hooks for OpenAI Agents SDK lifecycle events. + +This module provides a convenience class for streaming agent lifecycle events +to the AgentEx UI via Temporal activities. +""" + +import logging +from typing import Any, override +from datetime import timedelta + +from agents import Tool, Agent, RunHooks, RunContextWrapper +from temporalio import workflow +from agents.tool_context import ToolContext + +from agentex.types.text_content import TextContent +from agentex.types.task_message_content import ToolRequestContent, ToolResponseContent +from agentex.lib.core.temporal.plugins.openai_agents.hooks.activities import stream_lifecycle_content + +logger = logging.getLogger(__name__) + + +class TemporalStreamingHooks(RunHooks): + """Convenience hooks class for streaming OpenAI Agent lifecycle events to the AgentEx UI. + + This class automatically streams agent lifecycle events (tool calls, handoffs) to the + AgentEx UI via Temporal activities. It subclasses the OpenAI Agents SDK's RunHooks + to intercept lifecycle events and forward them for real-time UI updates. 
+ + Lifecycle events streamed: + - Tool requests (on_tool_start): Streams when a tool is about to be invoked + - Tool responses (on_tool_end): Streams the tool's execution result + - Agent handoffs (on_handoff): Streams when control transfers between agents + + Usage: + Basic usage - streams all lifecycle events:: + + from agentex.lib.core.temporal.plugins.openai_agents import TemporalStreamingHooks + + hooks = TemporalStreamingHooks(task_id="abc123") + result = await Runner.run(agent, input, hooks=hooks) + + Advanced - subclass for custom behavior:: + + class MyCustomHooks(TemporalStreamingHooks): + async def on_tool_start(self, context, agent, tool): + # Add custom logic before streaming + await self.my_custom_logging(tool) + # Call parent to stream to UI + await super().on_tool_start(context, agent, tool) + + async def on_agent_start(self, context, agent): + # Override empty methods for additional tracking + print(f"Agent {agent.name} started") + + Power users can ignore this class and subclass agents.RunHooks directly for full control. + + Note: + Tool arguments are not available in hooks due to OpenAI SDK architecture. + The SDK's hook signature doesn't include tool arguments - they're only passed + to the actual tool function. This is why arguments={} in ToolRequestContent. + + Attributes: + task_id: The AgentEx task ID for routing streamed events + timeout: Timeout for streaming activity calls (default: 10 seconds) + """ + + def __init__( + self, + task_id: str, + timeout: timedelta = timedelta(seconds=10), + ): + """Initialize the streaming hooks. + + Args: + task_id: AgentEx task ID for routing streamed events to the correct UI session + timeout: Timeout for streaming activity invocations (default: 10 seconds) + """ + super().__init__() + self.task_id = task_id + self.timeout = timeout + + @override + async def on_agent_start(self, context: RunContextWrapper, agent: Agent) -> None: # noqa: ARG002 + """Called when an agent starts execution. + + Default implementation logs the event. Override to add custom behavior. + + Args: + context: The run context wrapper + agent: The agent that is starting + """ + logger.debug(f"[TemporalStreamingHooks] Agent '{agent.name}' started execution") + + @override + async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: Any) -> None: # noqa: ARG002 + """Called when an agent completes execution. + + Default implementation logs the event. Override to add custom behavior. + + Args: + context: The run context wrapper + agent: The agent that completed + output: The agent's output + """ + logger.debug(f"[TemporalStreamingHooks] Agent '{agent.name}' completed execution with output type: {type(output).__name__}") + + @override + async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None: # noqa: ARG002 + """Stream tool request when a tool starts execution. + + Extracts the tool_call_id from the context and streams a ToolRequestContent + message to the UI showing that the tool is about to execute. + + Note: Tool arguments are not available in the hook context due to OpenAI SDK + design. The hook signature doesn't include tool arguments - they're passed + directly to the tool function instead. We send an empty dict as a placeholder. 
+ + Args: + context: The run context wrapper (will be a ToolContext with tool_call_id) + agent: The agent executing the tool + tool: The tool being executed + """ + tool_context = context if isinstance(context, ToolContext) else None + tool_call_id = tool_context.tool_call_id if tool_context else f"call_{id(tool)}" + + await workflow.execute_activity_method( + stream_lifecycle_content, + args=[ + self.task_id, + ToolRequestContent( + author="agent", + tool_call_id=tool_call_id, + name=tool.name, + arguments={}, # Not available in hook context - SDK limitation + ), + ], + start_to_close_timeout=self.timeout, + ) + + @override + async def on_tool_end( + self, context: RunContextWrapper, agent: Agent, tool: Tool, result: str # noqa: ARG002 + ) -> None: + """Stream tool response when a tool completes execution. + + Extracts the tool_call_id and streams a ToolResponseContent message to the UI + showing the tool's execution result. + + Args: + context: The run context wrapper (will be a ToolContext with tool_call_id) + agent: The agent that executed the tool + tool: The tool that was executed + result: The tool's execution result + """ + tool_context = context if isinstance(context, ToolContext) else None + tool_call_id = ( + getattr(tool_context, "tool_call_id", f"call_{id(tool)}") + if tool_context + else f"call_{id(tool)}" + ) + + await workflow.execute_activity_method( + stream_lifecycle_content, + args=[ + self.task_id, + ToolResponseContent( + author="agent", + tool_call_id=tool_call_id, + name=tool.name, + content=result, + ), + ], + start_to_close_timeout=self.timeout, + ) + + @override + async def on_handoff( + self, context: RunContextWrapper, from_agent: Agent, to_agent: Agent # noqa: ARG002 + ) -> None: + """Stream handoff message when control transfers between agents. + + Sends a text message to the UI indicating that one agent is handing off + to another agent. + + Args: + context: The run context wrapper + from_agent: The agent transferring control + to_agent: The agent receiving control + """ + await workflow.execute_activity_method( + stream_lifecycle_content, + args=[ + self.task_id, + TextContent( + author="agent", + content=f"Handoff from {from_agent.name} to {to_agent.name}", + type="text", + ), + ], + start_to_close_timeout=self.timeout, + ) diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/interceptors/__init__.py b/src/agentex/lib/core/temporal/plugins/openai_agents/interceptors/__init__.py new file mode 100644 index 00000000..47290ea4 --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/interceptors/__init__.py @@ -0,0 +1,19 @@ +"""Temporal interceptors for OpenAI Agents SDK integration. + +This module provides interceptors for threading context (task_id, trace_id, parent_span_id) +from workflows to activities in Temporal. 
+""" + +from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import ( + ContextInterceptor, + streaming_task_id, + streaming_trace_id, + streaming_parent_span_id, +) + +__all__ = [ + "ContextInterceptor", + "streaming_task_id", + "streaming_trace_id", + "streaming_parent_span_id", +] \ No newline at end of file diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/interceptors/context_interceptor.py b/src/agentex/lib/core/temporal/plugins/openai_agents/interceptors/context_interceptor.py new file mode 100644 index 00000000..8e551fc2 --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/interceptors/context_interceptor.py @@ -0,0 +1,160 @@ +""" +Temporal context interceptors for threading runtime context through workflows and activities. + +This module provides interceptors that pass task_id, trace_id, and parent_span_id from +workflows to activities via headers, making them available via ContextVars for models +to use for streaming, tracing, or other purposes. +""" + +import logging +from typing import Any, Type, Optional, override +from contextvars import ContextVar + +from temporalio import workflow +from temporalio.worker import ( + Interceptor, + StartActivityInput, + ExecuteActivityInput, + ExecuteWorkflowInput, + ActivityInboundInterceptor, + WorkflowInboundInterceptor, + WorkflowOutboundInterceptor, +) +from temporalio.converter import default + +# Set up logging +logger = logging.getLogger("context.interceptor") + +# Global context variables that models can read +# These are thread-safe and work across async boundaries +streaming_task_id: ContextVar[Optional[str]] = ContextVar('streaming_task_id', default=None) +streaming_trace_id: ContextVar[Optional[str]] = ContextVar('streaming_trace_id', default=None) +streaming_parent_span_id: ContextVar[Optional[str]] = ContextVar('streaming_parent_span_id', default=None) + +# Header keys for passing context +TASK_ID_HEADER = "context-task-id" +TRACE_ID_HEADER = "context-trace-id" +PARENT_SPAN_ID_HEADER = "context-parent-span-id" + +class ContextInterceptor(Interceptor): + """Main interceptor that enables context threading through Temporal.""" + + def __init__(self): + self._payload_converter = default().payload_converter + logger.info("[ContextInterceptor] Initialized") + + @override + def intercept_activity(self, next: ActivityInboundInterceptor) -> ActivityInboundInterceptor: + """Create activity interceptor to read context from headers.""" + return ContextActivityInboundInterceptor(next, self._payload_converter) + + @override + def workflow_interceptor_class(self, _input: Any) -> Optional[Type[WorkflowInboundInterceptor]]: + """Return workflow interceptor class.""" + return ContextWorkflowInboundInterceptor + + +class ContextWorkflowInboundInterceptor(WorkflowInboundInterceptor): + """Workflow interceptor that creates the outbound interceptor.""" + + def __init__(self, next: WorkflowInboundInterceptor): + super().__init__(next) + self._payload_converter = default().payload_converter + + @override + async def execute_workflow(self, input: ExecuteWorkflowInput) -> Any: + """Execute workflow - just pass through.""" + return await self.next.execute_workflow(input) + + @override + def init(self, outbound: WorkflowOutboundInterceptor) -> None: + """Initialize with our custom outbound interceptor.""" + self.next.init(ContextWorkflowOutboundInterceptor( + outbound, self._payload_converter + )) + + +class ContextWorkflowOutboundInterceptor(WorkflowOutboundInterceptor): + """Outbound 
interceptor that adds task_id to activity headers.""" + + def __init__(self, next, payload_converter): + super().__init__(next) + self._payload_converter = payload_converter + + @override + def start_activity(self, input: StartActivityInput) -> workflow.ActivityHandle: + """Add task_id, trace_id, and parent_span_id to headers when starting model activities.""" + + # Only add headers for invoke_model_activity calls + activity_name = str(input.activity) if hasattr(input, 'activity') else "" + + if "invoke_model_activity" in activity_name or "invoke-model-activity" in activity_name: + # Get task_id, trace_id, and parent_span_id from workflow instance instead of inbound interceptor + try: + workflow_instance = workflow.instance() + task_id = getattr(workflow_instance, '_task_id', None) + trace_id = getattr(workflow_instance, '_trace_id', None) + parent_span_id = getattr(workflow_instance, '_parent_span_id', None) + + if task_id and trace_id and parent_span_id: + # Initialize headers if needed + if not input.headers: + input.headers = {} + + # Add task_id to headers + input.headers[TASK_ID_HEADER] = self._payload_converter.to_payload(task_id) # type: ignore[index] + input.headers[TRACE_ID_HEADER] = self._payload_converter.to_payload(trace_id) # type: ignore[index] + input.headers[PARENT_SPAN_ID_HEADER] = self._payload_converter.to_payload(parent_span_id) # type: ignore[index] + logger.debug(f"[OutboundInterceptor] Added task_id, trace_id, and parent_span_id to activity headers: {task_id}, {trace_id}, {parent_span_id}") + else: + logger.warning("[OutboundInterceptor] No _task_id, _trace_id, or _parent_span_id found in workflow instance") + except Exception as e: + logger.error(f"[OutboundInterceptor] Failed to get task_id, trace_id, or parent_span_id from workflow instance: {e}") + + return self.next.start_activity(input) + + +class ContextActivityInboundInterceptor(ActivityInboundInterceptor): + """Activity interceptor that extracts task_id, trace_id, and parent_span_id from headers and sets context variables.""" + + def __init__(self, next, payload_converter): + super().__init__(next) + self._payload_converter = payload_converter + + @override + async def execute_activity(self, input: ExecuteActivityInput) -> Any: + """Extract task_id, trace_id, and parent_span_id from headers and set context variables.""" + + # Extract task_id from headers if present + if input.headers and TASK_ID_HEADER in input.headers: + task_id_value = self._payload_converter.from_payload( + input.headers[TASK_ID_HEADER], str + ) + trace_id_value = self._payload_converter.from_payload( + input.headers[TRACE_ID_HEADER], str + ) + parent_span_id_value = self._payload_converter.from_payload( + input.headers[PARENT_SPAN_ID_HEADER], str + ) + + # P THIS IS THE KEY PART - Set the context variable! 
+ # This makes task_id available to TemporalStreamingModel.get_response() + streaming_task_id.set(task_id_value) + streaming_trace_id.set(trace_id_value) + streaming_parent_span_id.set(parent_span_id_value) + logger.info(f"[ActivityInterceptor] Set task_id, trace_id, and parent_span_id in context: {task_id_value}, {trace_id_value}, {parent_span_id_value}") + else: + logger.debug("[ActivityInterceptor] No task_id, trace_id, or parent_span_id in headers") + + try: + # Execute the activity + # The TemporalStreamingModel can now read streaming_task_id.get() + result = await self.next.execute_activity(input) + return result + finally: + # Clean up context after activity + streaming_task_id.set(None) + streaming_trace_id.set(None) + streaming_parent_span_id.set(None) + logger.debug("[ActivityInterceptor] Cleared task_id, trace_id, and parent_span_id from context") + diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/__init__.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/__init__.py new file mode 100644 index 00000000..bb5dc97e --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/__init__.py @@ -0,0 +1,23 @@ +"""Model providers for Temporal OpenAI Agents SDK integration. + +This module provides model implementations that add streaming and tracing +capabilities to standard OpenAI models when running in Temporal workflows/activities. +""" + +from agentex.lib.core.temporal.plugins.openai_agents.models.temporal_tracing_model import ( + TemporalTracingModelProvider, + TemporalTracingResponsesModel, + TemporalTracingChatCompletionsModel, +) +from agentex.lib.core.temporal.plugins.openai_agents.models.temporal_streaming_model import ( + TemporalStreamingModel, + TemporalStreamingModelProvider, +) + +__all__ = [ + "TemporalStreamingModel", + "TemporalStreamingModelProvider", + "TemporalTracingModelProvider", + "TemporalTracingResponsesModel", + "TemporalTracingChatCompletionsModel", +] \ No newline at end of file diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py new file mode 100644 index 00000000..cde08606 --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py @@ -0,0 +1,786 @@ +"""Custom Temporal Model Provider with streaming support for OpenAI agents.""" +from __future__ import annotations + +import uuid +import logging +from typing import Any, List, Union, Optional, override + +from agents import ( + Tool, + Model, + Handoff, + FunctionTool, + ModelTracing, + ModelProvider, + ModelResponse, + ModelSettings, + TResponseInputItem, + AgentOutputSchemaBase, +) +from openai import NOT_GIVEN, AsyncOpenAI +from agents.tool import ( + ComputerTool, + HostedMCPTool, + WebSearchTool, + FileSearchTool, + LocalShellTool, + CodeInterpreterTool, + ImageGenerationTool, +) +from agents.usage import Usage, InputTokensDetails, OutputTokensDetails # type: ignore[attr-defined] +from agents.model_settings import MCPToolChoice +from openai.types.responses import ( + ResponseOutputText, + ResponseOutputMessage, + ResponseCompletedEvent, + ResponseTextDeltaEvent, + ResponseFunctionToolCall, + ResponseOutputItemDoneEvent, + # Event types for proper type checking + ResponseOutputItemAddedEvent, + ResponseReasoningTextDeltaEvent, + ResponseReasoningSummaryPartDoneEvent, + ResponseFunctionCallArgumentsDoneEvent, + ResponseReasoningSummaryPartAddedEvent, + 
ResponseReasoningSummaryTextDeltaEvent, + ResponseFunctionCallArgumentsDeltaEvent, +) + +# AgentEx SDK imports +from agentex.lib import adk +from agentex.lib.core.tracing.tracer import AsyncTracer +from agentex.types.task_message_delta import TextDelta, ReasoningContentDelta, ReasoningSummaryDelta +from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta +from agentex.types.task_message_content import TextContent, ReasoningContent +from agentex.lib.adk.utils._modules.client import create_async_agentex_client +from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import ( + streaming_task_id, + streaming_trace_id, + streaming_parent_span_id, +) + +# Create logger for this module +logger = logging.getLogger("agentex.temporal.streaming") + +class TemporalStreamingModel(Model): + """Custom model implementation with streaming support.""" + + def __init__(self, model_name: str = "gpt-4o", _use_responses_api: bool = True): + """Initialize the streaming model with OpenAI client and model name.""" + # Match the default behavior with no retries (Temporal handles retries) + self.client = AsyncOpenAI(max_retries=0) + self.model_name = model_name + # Always use Responses API for all models + self.use_responses_api = True + + # Initialize tracer as a class variable + agentex_client = create_async_agentex_client() + self.tracer = AsyncTracer(agentex_client) + + logger.info(f"[TemporalStreamingModel] Initialized model={self.model_name}, use_responses_api={self.use_responses_api}, tracer=initialized") + + def _non_null_or_not_given(self, value: Any) -> Any: + """Convert None to NOT_GIVEN sentinel, matching OpenAI SDK pattern.""" + return value if value is not None else NOT_GIVEN + + def _prepare_response_input(self, input: Union[str, list[TResponseInputItem]]) -> List[dict]: + """Convert input to Responses API format. 
+ + Args: + input: Either a string prompt or list of ResponseInputItem messages + + Returns: + List of input items in Responses API format + """ + response_input = [] + + if isinstance(input, list): + # Process list of ResponseInputItem objects + for _idx, item in enumerate(input): + # Convert to dict if needed + if isinstance(item, dict): + item_dict = item + else: + item_dict = item.model_dump() if hasattr(item, 'model_dump') else item + + item_type = item_dict.get("type") + + if item_type == "message": + # ResponseOutputMessage format + role = item_dict.get("role", "assistant") + content_list = item_dict.get("content", []) + + # Build content array + content_array = [] + for content_item in content_list: + if isinstance(content_item, dict): + if content_item.get("type") == "output_text": + # For assistant messages, keep as output_text + # For user messages, convert to input_text + if role == "user": + content_array.append({ + "type": "input_text", + "text": content_item.get("text", "") + }) + else: + content_array.append({ + "type": "output_text", + "text": content_item.get("text", "") + }) + else: + content_array.append(content_item) + + response_input.append({ + "type": "message", + "role": role, + "content": content_array + }) + + elif item_type == "function_call": + # Function call from previous response + logger.debug(f"[Responses API] function_call item keys: {list(item_dict.keys())}") + call_id = item_dict.get("call_id") or item_dict.get("id") + if not call_id: + logger.debug(f"[Responses API] WARNING: No call_id found in function_call item!") + logger.debug(f"[Responses API] Full item: {item_dict}") + # Generate a fallback ID if missing + call_id = f"call_{uuid.uuid4().hex[:8]}" + logger.debug(f"[Responses API] Generated fallback call_id: {call_id}") + logger.debug(f"[Responses API] Adding function_call with call_id={call_id}, name={item_dict.get('name')}") + response_input.append({ + "type": "function_call", + "call_id": call_id, # API expects 'call_id' not 'id' + "name": item_dict.get("name", ""), + "arguments": item_dict.get("arguments", "{}"), + }) + + elif item_type == "function_call_output": + # Function output/response + call_id = item_dict.get("call_id") + if not call_id: + logger.debug(f"[Responses API] WARNING: No call_id in function_call_output!") + # Try to find it from id field + call_id = item_dict.get("id") + response_input.append({ + "type": "function_call_output", + "call_id": call_id or "", + "output": item_dict.get("output", "") + }) + + elif item_dict.get("role") == "user": + # Simple user message + response_input.append({ + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": item_dict.get("content", "")}] + }) + + elif item_dict.get("role") == "tool": + # Tool message + response_input.append({ + "type": "function_call_output", + "call_id": item_dict.get("tool_call_id"), + "output": item_dict.get("content") + }) + else: + logger.debug(f"[Responses API] Skipping unhandled item type: {item_type}, role: {item_dict.get('role')}") + + elif isinstance(input, str): + # Simple string input + response_input.append({ + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": input}] + }) + + return response_input + + def _convert_tools(self, tools: list[Tool], handoffs: list[Handoff]) -> tuple[List[dict], List[str]]: + """Convert tools and handoffs to Responses API format. 
+ + Args: + tools: List of Tool objects + handoffs: List of Handoff objects + + Returns: + Tuple of (converted_tools, include_list) where include_list contains + additional response data to request + """ + response_tools = [] + tool_includes = [] + + # Check for multiple computer tools (only one allowed) + computer_tools = [tool for tool in tools if isinstance(tool, ComputerTool)] + if len(computer_tools) > 1: + raise ValueError(f"You can only provide one computer tool. Got {len(computer_tools)}") + + # Convert each tool based on its type + for tool in tools: + if isinstance(tool, FunctionTool): + response_tools.append({ + "type": "function", + "name": tool.name, + "description": tool.description or "", + "parameters": tool.params_json_schema if tool.params_json_schema else {}, + "strict": tool.strict_json_schema, + }) + + elif isinstance(tool, WebSearchTool): + tool_config = { + "type": "web_search", + } + # filters attribute was removed from WebSearchTool API + if hasattr(tool, 'user_location') and tool.user_location is not None: + tool_config["user_location"] = tool.user_location + if hasattr(tool, 'search_context_size') and tool.search_context_size is not None: + tool_config["search_context_size"] = tool.search_context_size + response_tools.append(tool_config) + + elif isinstance(tool, FileSearchTool): + tool_config = { + "type": "file_search", + "vector_store_ids": tool.vector_store_ids, + } + if tool.max_num_results: + tool_config["max_num_results"] = tool.max_num_results + if tool.ranking_options: + tool_config["ranking_options"] = tool.ranking_options + if tool.filters: + tool_config["filters"] = tool.filters + response_tools.append(tool_config) + + # Add include for file search results if needed + if tool.include_search_results: + tool_includes.append("file_search_call.results") + + elif isinstance(tool, ComputerTool): + response_tools.append({ + "type": "computer_use_preview", + "environment": tool.computer.environment, + "display_width": tool.computer.dimensions[0], + "display_height": tool.computer.dimensions[1], + }) + + elif isinstance(tool, HostedMCPTool): + response_tools.append(tool.tool_config) + + elif isinstance(tool, ImageGenerationTool): + response_tools.append(tool.tool_config) + + elif isinstance(tool, CodeInterpreterTool): + response_tools.append(tool.tool_config) + + elif isinstance(tool, LocalShellTool): + # LocalShellTool API changed - no longer has working_directory + # The executor handles execution details internally + response_tools.append({ + "type": "local_shell", + }) + + else: + logger.warning(f"Unknown tool type: {type(tool).__name__}, skipping") + + # Convert handoffs (always function tools) + for handoff in handoffs: + response_tools.append({ + "type": "function", + "name": handoff.tool_name, + "description": handoff.tool_description or f"Transfer to {handoff.agent_name}", + "parameters": handoff.input_json_schema if handoff.input_json_schema else {}, + }) + + return response_tools, tool_includes + + def _build_reasoning_param(self, model_settings: ModelSettings) -> Any: + """Build reasoning parameter from model settings. 
+ + Args: + model_settings: Model configuration settings + + Returns: + Reasoning parameter dict or NOT_GIVEN + """ + if not model_settings.reasoning: + return NOT_GIVEN + + if hasattr(model_settings.reasoning, 'effort') and model_settings.reasoning.effort: + # For Responses API, reasoning is an object + reasoning_param = { + "effort": model_settings.reasoning.effort, + } + # Add generate_summary if specified and not None + if hasattr(model_settings.reasoning, 'generate_summary') and model_settings.reasoning.generate_summary is not None: + reasoning_param["summary"] = model_settings.reasoning.generate_summary + logger.debug(f"[TemporalStreamingModel] Using reasoning param: {reasoning_param}") + return reasoning_param + + return NOT_GIVEN + + def _convert_tool_choice(self, tool_choice: Any) -> Any: + """Convert tool_choice to Responses API format. + + Args: + tool_choice: Tool choice from model settings + + Returns: + Converted tool choice or NOT_GIVEN + """ + if tool_choice is None: + return NOT_GIVEN + + if isinstance(tool_choice, MCPToolChoice): + # MCP tool choice with server label + return { + "server_label": tool_choice.server_label, + "type": "mcp", + "name": tool_choice.name, + } + elif tool_choice == "required": + return "required" + elif tool_choice == "auto": + return "auto" + elif tool_choice == "none": + return "none" + elif tool_choice == "file_search": + return {"type": "file_search"} + elif tool_choice == "web_search": + return {"type": "web_search"} + elif tool_choice == "web_search_preview": + return {"type": "web_search_preview"} + elif tool_choice == "computer_use_preview": + return {"type": "computer_use_preview"} + elif tool_choice == "image_generation": + return {"type": "image_generation"} + elif tool_choice == "code_interpreter": + return {"type": "code_interpreter"} + elif tool_choice == "mcp": + # Generic MCP without specific tool + return {"type": "mcp"} + elif isinstance(tool_choice, str): + # Specific function tool by name + return { + "type": "function", + "name": tool_choice, + } + else: + # Pass through as-is for other types + return tool_choice + + @override + async def get_response( + self, + system_instructions: Optional[str], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Optional[AgentOutputSchemaBase], + handoffs: list[Handoff], + tracing: ModelTracing, # noqa: ARG002 + **kwargs, # noqa: ARG002 + ) -> ModelResponse: + """Get a non-streaming response from the model with streaming to Redis. + + This method is used by Temporal activities and needs to return a complete + response, but we stream the response to Redis while generating it. 
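+
+        The streaming context (task_id / trace_id / parent_span_id) is read from the
+        context variables populated by the ContextInterceptor. A minimal, illustrative
+        sketch of what the interceptor arranges before this method runs (the identifier
+        values below are hypothetical):
+
+            streaming_task_id.set("task_123")
+            streaming_trace_id.set("trace_456")
+            streaming_parent_span_id.set("span_789")
+            response = await model.get_response(...)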
+ """ + + task_id = streaming_task_id.get() + trace_id = streaming_trace_id.get() + parent_span_id = streaming_parent_span_id.get() + + if not task_id or not trace_id or not parent_span_id: + raise ValueError("task_id, trace_id, and parent_span_id are required for streaming with Responses API") + + trace = self.tracer.trace(trace_id) + + async with trace.span( + parent_id=parent_span_id, + name="streaming_model_get_response", + input={ + "model": self.model_name, + "has_system_instructions": system_instructions is not None, + "input_type": type(input).__name__, + "tools_count": len(tools) if tools else 0, + "handoffs_count": len(handoffs) if handoffs else 0, + }, + ) as span: + # Always use Responses API for streaming + if not task_id: + # If no task_id, we can't use streaming - this shouldn't happen normally + raise ValueError("task_id is required for streaming with Responses API") + + logger.info(f"[TemporalStreamingModel] Using Responses API for {self.model_name}") + + try: + # Prepare input using helper method + response_input = self._prepare_response_input(input) + + # Convert tools and handoffs using helper method + response_tools, tool_includes = self._convert_tools(tools, handoffs) + openai_tools = response_tools if response_tools else None + + # Build reasoning parameter using helper method + reasoning_param = self._build_reasoning_param(model_settings) + + # Convert tool_choice using helper method + tool_choice = self._convert_tool_choice(model_settings.tool_choice) + + # Build include list for response data + include_list = [] + # Add tool-specific includes + if tool_includes: + include_list.extend(tool_includes) + # Add user-specified includes + if model_settings.response_include: + include_list.extend(model_settings.response_include) + # Add logprobs include if top_logprobs is set + if model_settings.top_logprobs is not None: + include_list.append("message.output_text.logprobs") + # Build response format for verbosity and structured output + response_format = NOT_GIVEN + if output_schema is not None: + # Handle structured output schema + # This would need conversion logic similar to Converter.get_response_format + pass # TODO: Implement output_schema conversion + elif model_settings.verbosity is not None: + response_format = {"verbosity": model_settings.verbosity} + + # Build extra_args dict for additional parameters + extra_args = dict(model_settings.extra_args or {}) + if model_settings.top_logprobs is not None: + extra_args["top_logprobs"] = model_settings.top_logprobs + + # Create the response stream using Responses API + logger.debug(f"[TemporalStreamingModel] Creating response stream with Responses API") + stream = await self.client.responses.create( # type: ignore[call-overload] + + model=self.model_name, + input=response_input, + instructions=system_instructions, + tools=openai_tools or NOT_GIVEN, + stream=True, + # Temperature and sampling parameters + temperature=self._non_null_or_not_given(model_settings.temperature), + max_output_tokens=self._non_null_or_not_given(model_settings.max_tokens), + top_p=self._non_null_or_not_given(model_settings.top_p), + # Note: frequency_penalty and presence_penalty are not supported by Responses API + # Tool and reasoning parameters + reasoning=reasoning_param, + tool_choice=tool_choice, + parallel_tool_calls=self._non_null_or_not_given(model_settings.parallel_tool_calls), + # Context and truncation + truncation=self._non_null_or_not_given(model_settings.truncation), + # Response configuration + text=response_format, + 
include=include_list if include_list else NOT_GIVEN, + # Metadata and storage + metadata=self._non_null_or_not_given(model_settings.metadata), + store=self._non_null_or_not_given(model_settings.store), + # Extra customization + extra_headers=model_settings.extra_headers, + extra_query=model_settings.extra_query, + extra_body=model_settings.extra_body, + # Any additional parameters from extra_args + **extra_args, + ) + + # Process the stream of events from Responses API + output_items = [] + current_text = "" + reasoning_context = None + reasoning_summaries = [] + reasoning_contents = [] + current_reasoning_summary = "" + event_count = 0 + + # We expect task_id to always be provided for streaming + if not task_id: + raise ValueError("[TemporalStreamingModel] task_id is required for streaming model") + + # Use proper async with context manager for streaming to Redis + async with adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=TextContent( + author="agent", + content="", + format="markdown", + ), + ) as streaming_context: + # Process events from the Responses API stream + function_calls_in_progress = {} # Track function calls being streamed + + async for event in stream: + event_count += 1 + + # Log event type + logger.debug(f"[TemporalStreamingModel] Event {event_count}: {type(event).__name__}") + + # Handle different event types using isinstance for type safety + if isinstance(event, ResponseOutputItemAddedEvent): + # New output item (reasoning, function call, or message) + item = getattr(event, 'item', None) + output_index = getattr(event, 'output_index', 0) + + if item and getattr(item, 'type', None) == 'reasoning': + logger.debug(f"[TemporalStreamingModel] Starting reasoning item") + if not reasoning_context: + # Start a reasoning context for streaming reasoning to UI + reasoning_context = await adk.streaming.streaming_task_message_context( + task_id=task_id, + initial_content=ReasoningContent( + author="agent", + summary=[], + content=[], + type="reasoning", + style="active", + ), + ).__aenter__() + elif item and getattr(item, 'type', None) == 'function_call': + # Track the function call being streamed + function_calls_in_progress[output_index] = { + 'id': getattr(item, 'id', ''), + 'call_id': getattr(item, 'call_id', ''), + 'name': getattr(item, 'name', ''), + 'arguments': getattr(item, 'arguments', ''), + } + logger.debug(f"[TemporalStreamingModel] Starting function call: {item.name}") + + elif isinstance(event, ResponseFunctionCallArgumentsDeltaEvent): + # Stream function call arguments + output_index = getattr(event, 'output_index', 0) + delta = getattr(event, 'delta', '') + + if output_index in function_calls_in_progress: + function_calls_in_progress[output_index]['arguments'] += delta + logger.debug(f"[TemporalStreamingModel] Function call args delta: {delta[:50]}...") + + elif isinstance(event, ResponseFunctionCallArgumentsDoneEvent): + # Function call arguments complete + output_index = getattr(event, 'output_index', 0) + arguments = getattr(event, 'arguments', '') + + if output_index in function_calls_in_progress: + function_calls_in_progress[output_index]['arguments'] = arguments + logger.debug(f"[TemporalStreamingModel] Function call args done") + + elif isinstance(event, (ResponseReasoningTextDeltaEvent, ResponseReasoningSummaryTextDeltaEvent, ResponseTextDeltaEvent)): + # Handle text streaming + delta = getattr(event, 'delta', '') + + if isinstance(event, ResponseReasoningSummaryTextDeltaEvent) and reasoning_context: + # Stream reasoning 
summary deltas - these are the actual reasoning tokens! + try: + # Use ReasoningSummaryDelta for reasoning summaries + summary_index = getattr(event, 'summary_index', 0) + delta_obj = ReasoningSummaryDelta( + summary_index=summary_index, + summary_delta=delta, + type="reasoning_summary", + ) + update = StreamTaskMessageDelta( + parent_task_message=reasoning_context.task_message, + delta=delta_obj, + type="delta", + ) + await reasoning_context.stream_update(update) + # Accumulate the reasoning summary + if len(reasoning_summaries) <= summary_index: + reasoning_summaries.extend([""] * (summary_index + 1 - len(reasoning_summaries))) + reasoning_summaries[summary_index] += delta + logger.debug(f"[TemporalStreamingModel] Streamed reasoning summary: {delta[:30]}..." if len(delta) > 30 else f"[TemporalStreamingModel] Streamed reasoning summary: {delta}") + except Exception as e: + logger.warning(f"Failed to send reasoning delta: {e}") + elif isinstance(event, ResponseReasoningTextDeltaEvent) and reasoning_context: + # Regular reasoning delta (if these ever appear) + try: + delta_obj = ReasoningContentDelta( + content_index=0, + content_delta=delta, + type="reasoning_content", + ) + update = StreamTaskMessageDelta( + parent_task_message=reasoning_context.task_message, + delta=delta_obj, + type="delta", + ) + await reasoning_context.stream_update(update) + reasoning_contents.append(delta) + except Exception as e: + logger.warning(f"Failed to send reasoning delta: {e}") + elif isinstance(event, ResponseTextDeltaEvent): + # Stream regular text output + current_text += delta + try: + delta_obj = TextDelta( + type="text", + text_delta=delta, + ) + update = StreamTaskMessageDelta( + parent_task_message=streaming_context.task_message, + delta=delta_obj, + type="delta", + ) + await streaming_context.stream_update(update) + except Exception as e: + logger.warning(f"Failed to send text delta: {e}") + + elif isinstance(event, ResponseOutputItemDoneEvent): + # Output item completed + item = getattr(event, 'item', None) + output_index = getattr(event, 'output_index', 0) + + if item and getattr(item, 'type', None) == 'reasoning': + logger.debug(f"[TemporalStreamingModel] Reasoning item completed") + # Don't close the context here - let it stay open for more reasoning events + # It will be closed when we send the final update or at the end + elif item and getattr(item, 'type', None) == 'function_call': + # Function call completed - add to output + if output_index in function_calls_in_progress: + call_data = function_calls_in_progress[output_index] + logger.debug(f"[TemporalStreamingModel] Function call completed: {call_data['name']}") + + # Create proper function call object + tool_call = ResponseFunctionToolCall( + id=call_data['id'], + call_id=call_data['call_id'], + type="function_call", + name=call_data['name'], + arguments=call_data['arguments'], + ) + output_items.append(tool_call) + + elif isinstance(event, ResponseReasoningSummaryPartAddedEvent): + # New reasoning part/summary started - reset accumulator + part = getattr(event, 'part', None) + if part: + part_type = getattr(part, 'type', 'unknown') + logger.debug(f"[TemporalStreamingModel] New reasoning part: type={part_type}") + # Reset the current reasoning summary for this new part + current_reasoning_summary = "" + + elif isinstance(event, ResponseReasoningSummaryPartDoneEvent): + # Reasoning part completed - send final update and close if this is the last part + if reasoning_context and reasoning_summaries: + logger.debug(f"[TemporalStreamingModel] 
Reasoning part completed, sending final update") + try: + # Send a full message update with the complete reasoning content + complete_reasoning_content = ReasoningContent( + author="agent", + summary=reasoning_summaries, # Use accumulated summaries + content=reasoning_contents if reasoning_contents else [], + type="reasoning", + style="static", + ) + + await reasoning_context.stream_update( + update=StreamTaskMessageFull( + parent_task_message=reasoning_context.task_message, + content=complete_reasoning_content, + type="full", + ), + ) + + # Close the reasoning context after sending the final update + # This matches the reference implementation pattern + await reasoning_context.close() + reasoning_context = None + logger.debug(f"[TemporalStreamingModel] Closed reasoning context after final update") + except Exception as e: + logger.warning(f"Failed to send reasoning part done update: {e}") + + elif isinstance(event, ResponseCompletedEvent): + # Response completed + logger.debug(f"[TemporalStreamingModel] Response completed") + response = getattr(event, 'response', None) + if response and hasattr(response, 'output'): + # Use the final output from the response + output_items = response.output + logger.debug(f"[TemporalStreamingModel] Found {len(output_items)} output items in final response") + + # End of event processing loop - close any open contexts + if reasoning_context: + await reasoning_context.close() + reasoning_context = None + + # Build the response from output items collected during streaming + # Create output from the items we collected + response_output = [] + + # Process output items from the response + if output_items: + for item in output_items: + if isinstance(item, ResponseFunctionToolCall): + response_output.append(item) + elif isinstance(item, ResponseOutputMessage): + response_output.append(item) + else: + response_output.append(item) + else: + # No output items - create empty message + message = ResponseOutputMessage( + id=f"msg_{uuid.uuid4().hex[:8]}", + type="message", + status="completed", + role="assistant", + content=[ResponseOutputText( + type="output_text", + text=current_text if current_text else "", + annotations=[] + )] + ) + response_output.append(message) + + # Create usage object + usage = Usage( + input_tokens=0, + output_tokens=0, + total_tokens=0, + input_tokens_details=InputTokensDetails(cached_tokens=0), + output_tokens_details=OutputTokensDetails(reasoning_tokens=len(''.join(reasoning_contents)) // 4), # Approximate + ) + + # Return the response + return ModelResponse( + output=response_output, + usage=usage, + response_id=f"resp_{uuid.uuid4().hex[:8]}", + ) + + except Exception as e: + logger.error(f"Error using Responses API: {e}") + raise + + # The _get_response_with_responses_api method has been merged into get_response above + # All Responses API logic is now integrated directly in get_response() method + + @override + def stream_response(self, *args, **kwargs): + """Streaming is not implemented as we use the async get_response method. + This method is included for compatibility with the Model interface but should not be used. 
+ All streaming is handled through the async get_response method with the Responses API.""" + raise NotImplementedError("stream_response is not used in Temporal activities - use get_response instead") + + +class TemporalStreamingModelProvider(ModelProvider): + """Custom model provider that returns a streaming-capable model.""" + + def __init__(self): + """Initialize the provider.""" + super().__init__() + logger.info("[TemporalStreamingModelProvider] Initialized") + + @override + def get_model(self, model_name: Union[str, None]) -> Model: + """Get a model instance with streaming capabilities. + + Args: + model_name: The name of the model to retrieve + + Returns: + A Model instance with streaming support. + """ + # Use the provided model_name or default to gpt-4o + actual_model = model_name if model_name else "gpt-4o" + logger.info(f"[TemporalStreamingModelProvider] Creating TemporalStreamingModel for model_name: {actual_model}") + model = TemporalStreamingModel(model_name=actual_model) + return model diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_tracing_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_tracing_model.py new file mode 100644 index 00000000..ae0b79e5 --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_tracing_model.py @@ -0,0 +1,316 @@ +"""Temporal-aware tracing model provider. + +This module provides model implementations that add AgentEx tracing to standard OpenAI models +when running in Temporal workflows/activities. It uses context variables set by the Temporal +context interceptor to access task_id, trace_id, and parent_span_id. + +The key innovation is that these are thin wrappers around the standard OpenAI models, +avoiding code duplication while adding tracing capabilities. +""" + +import logging +from typing import List, Union, Optional, override + +from agents import ( + Tool, + Model, + Handoff, + ModelTracing, + ModelResponse, + ModelSettings, + OpenAIProvider, + TResponseInputItem, + AgentOutputSchemaBase, +) +from openai.types.responses import ResponsePromptParam +from agents.models.openai_responses import OpenAIResponsesModel +from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel + +from agentex.lib.core.tracing.tracer import AsyncTracer + +# Import AgentEx components +from agentex.lib.adk.utils._modules.client import create_async_agentex_client + +# Import context variables from the interceptor +from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import ( + streaming_task_id, + streaming_trace_id, + streaming_parent_span_id, +) + +logger = logging.getLogger("agentex.temporal.tracing") + + +class TemporalTracingModelProvider(OpenAIProvider): + """Model provider that returns OpenAI models wrapped with AgentEx tracing. + + This provider extends the standard OpenAIProvider to return models that add + tracing spans around model calls when running in Temporal activities with + the context interceptor enabled. + """ + + def __init__(self, *args, **kwargs): + """Initialize the tracing model provider. + + Accepts all the same arguments as OpenAIProvider. 
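+
+        Example (illustrative only):
+
+            >>> provider = TemporalTracingModelProvider()
+            >>> model = provider.get_model("gpt-4o")  # returns a tracing wrapper around the base model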
+ """ + super().__init__(*args, **kwargs) + + # Initialize tracer for all models + agentex_client = create_async_agentex_client() + self._tracer = AsyncTracer(agentex_client) + logger.info("[TemporalTracingModelProvider] Initialized with AgentEx tracer") + + @override + def get_model(self, model_name: Optional[str]) -> Model: + """Get a model wrapped with tracing capabilities. + + Args: + model_name: The name of the model to use + + Returns: + A model instance wrapped with tracing + """ + # Get the base model from the parent provider + base_model = super().get_model(model_name) + + # Wrap with appropriate tracing wrapper based on model type + if isinstance(base_model, OpenAIResponsesModel): + logger.info(f"[TemporalTracingModelProvider] Wrapping OpenAIResponsesModel '{model_name}' with tracing") + return TemporalTracingResponsesModel(base_model, self._tracer) # type: ignore[abstract] + elif isinstance(base_model, OpenAIChatCompletionsModel): + logger.info(f"[TemporalTracingModelProvider] Wrapping OpenAIChatCompletionsModel '{model_name}' with tracing") + return TemporalTracingChatCompletionsModel(base_model, self._tracer) # type: ignore[abstract] + else: + logger.warning(f"[TemporalTracingModelProvider] Unknown model type, returning without tracing: {type(base_model)}") + return base_model + + +class TemporalTracingResponsesModel(Model): + """Wrapper for OpenAIResponsesModel that adds AgentEx tracing. + + This is a thin wrapper that adds tracing spans around the base model's + get_response() method. It reads tracing context from ContextVars set by + the Temporal context interceptor. + """ + + def __init__(self, base_model: OpenAIResponsesModel, tracer: AsyncTracer): + """Initialize the tracing wrapper. + + Args: + base_model: The OpenAI Responses model to wrap + tracer: The AgentEx tracer to use + """ + self._base_model = base_model + self._tracer = tracer + # Expose the model name for compatibility + self.model = base_model.model + + @override + async def get_response( + self, + system_instructions: Optional[str], + input: Union[str, List[TResponseInputItem]], + model_settings: ModelSettings, + tools: List[Tool], + output_schema: Optional[AgentOutputSchemaBase], + handoffs: List[Handoff], + tracing: ModelTracing, + previous_response_id: Optional[str] = None, + conversation_id: Optional[str] = None, + prompt: Optional[ResponsePromptParam] = None, + **kwargs, + ) -> ModelResponse: + """Get a response from the model with optional tracing. + + If tracing context is available from the interceptor, this wraps the + model call in a tracing span. Otherwise, it passes through to the + base model without tracing. 
+ """ + # Try to get tracing context from ContextVars + task_id = streaming_task_id.get() + trace_id = streaming_trace_id.get() + parent_span_id = streaming_parent_span_id.get() + + # If we have tracing context, wrap with span + if trace_id and parent_span_id: + logger.debug(f"[TemporalTracingResponsesModel] Adding tracing span for task_id={task_id}, trace_id={trace_id}") + + trace = self._tracer.trace(trace_id) + + async with trace.span( + parent_id=parent_span_id, + name="model_get_response", + input={ + "model": str(self.model), + "has_system_instructions": system_instructions is not None, + "input_type": type(input).__name__, + "tools_count": len(tools) if tools else 0, + "handoffs_count": len(handoffs) if handoffs else 0, + "has_output_schema": output_schema is not None, + "model_settings": { + "temperature": model_settings.temperature, + "max_tokens": model_settings.max_tokens, + "reasoning": model_settings.reasoning, + } if model_settings else None, + }, + ) as span: + try: + # Call the base model + response = await self._base_model.get_response( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + previous_response_id=previous_response_id, + conversation_id=conversation_id, # type: ignore[call-arg] + prompt=prompt, + **kwargs, + ) + + # Add response info to span output + span.output = { # type: ignore[attr-defined] + "response_id": getattr(response, "id", None), + "model_used": getattr(response, "model", None), + "usage": { + "input_tokens": response.usage.input_tokens if response.usage else None, + "output_tokens": response.usage.output_tokens if response.usage else None, + "total_tokens": response.usage.total_tokens if response.usage else None, + } if response.usage else None, + } + + return response + + except Exception as e: + # Record error in span + span.error = str(e) # type: ignore[attr-defined] + raise + else: + # No tracing context, just pass through + logger.debug("[TemporalTracingResponsesModel] No tracing context available, calling base model directly") + return await self._base_model.get_response( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + previous_response_id=previous_response_id, + conversation_id=conversation_id, # type: ignore[call-arg] + prompt=prompt, + **kwargs, + ) + + +class TemporalTracingChatCompletionsModel(Model): + """Wrapper for OpenAIChatCompletionsModel that adds AgentEx tracing. + + This is a thin wrapper that adds tracing spans around the base model's + get_response() method. It reads tracing context from ContextVars set by + the Temporal context interceptor. + """ + + def __init__(self, base_model: OpenAIChatCompletionsModel, tracer: AsyncTracer): + """Initialize the tracing wrapper. + + Args: + base_model: The OpenAI ChatCompletions model to wrap + tracer: The AgentEx tracer to use + """ + self._base_model = base_model + self._tracer = tracer + # Expose the model name for compatibility + self.model = base_model.model + + @override + async def get_response( + self, + system_instructions: Optional[str], + input: Union[str, List[TResponseInputItem]], + model_settings: ModelSettings, + tools: List[Tool], + output_schema: Optional[AgentOutputSchemaBase], + handoffs: List[Handoff], + tracing: ModelTracing, + **kwargs, + ) -> ModelResponse: + """Get a response from the model with optional tracing. 
+ + If tracing context is available from the interceptor, this wraps the + model call in a tracing span. Otherwise, it passes through to the + base model without tracing. + """ + # Try to get tracing context from ContextVars + task_id = streaming_task_id.get() + trace_id = streaming_trace_id.get() + parent_span_id = streaming_parent_span_id.get() + + # If we have tracing context, wrap with span + if trace_id and parent_span_id: + logger.debug(f"[TemporalTracingChatCompletionsModel] Adding tracing span for task_id={task_id}, trace_id={trace_id}") + + trace = self._tracer.trace(trace_id) + + async with trace.span( + parent_id=parent_span_id, + name="model_get_response", + input={ + "model": str(self.model), + "has_system_instructions": system_instructions is not None, + "input_type": type(input).__name__, + "tools_count": len(tools) if tools else 0, + "handoffs_count": len(handoffs) if handoffs else 0, + "has_output_schema": output_schema is not None, + "model_settings": { + "temperature": model_settings.temperature, + "max_tokens": model_settings.max_tokens, + } if model_settings else None, + }, + ) as span: + try: + # Call the base model + response = await self._base_model.get_response( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + **kwargs, + ) + + # Add response info to span output + span.output = { # type: ignore[attr-defined] + "response_id": getattr(response, "id", None), + "model_used": getattr(response, "model", None), + "usage": { + "input_tokens": response.usage.input_tokens if response.usage else None, + "output_tokens": response.usage.output_tokens if response.usage else None, + "total_tokens": response.usage.total_tokens if response.usage else None, + } if response.usage else None, + } + + return response + + except Exception as e: + # Record error in span + span.error = str(e) # type: ignore[attr-defined] + raise + else: + # No tracing context, just pass through + logger.debug("[TemporalTracingChatCompletionsModel] No tracing context available, calling base model directly") + return await self._base_model.get_response( + system_instructions=system_instructions, + input=input, + model_settings=model_settings, + tools=tools, + output_schema=output_schema, + handoffs=handoffs, + tracing=tracing, + **kwargs, + ) \ No newline at end of file diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/__init__.py b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/__init__.py new file mode 100644 index 00000000..0c635833 --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/__init__.py @@ -0,0 +1,3 @@ +""" +Tests for the StreamingModel implementation in the OpenAI Agents plugin. +""" \ No newline at end of file diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/conftest.py b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/conftest.py new file mode 100644 index 00000000..599cb1e3 --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/conftest.py @@ -0,0 +1,297 @@ +""" +Pytest configuration and fixtures for StreamingModel tests. 
+""" + +import uuid +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio +from agents import ( + Handoff, + FunctionTool, + ModelSettings, +) +from agents.tool import ( + ComputerTool, + HostedMCPTool, + WebSearchTool, + FileSearchTool, + LocalShellTool, + CodeInterpreterTool, + ImageGenerationTool, +) +from agents.model_settings import Reasoning # type: ignore[attr-defined] +from openai.types.responses import ( + ResponseCompletedEvent, + ResponseTextDeltaEvent, + ResponseOutputItemAddedEvent, + ResponseReasoningSummaryTextDeltaEvent, +) + +# Configure pytest-asyncio +pytest_plugins = ("pytest_asyncio",) + + +@pytest.fixture +def mock_openai_client(): + """Mock AsyncOpenAI client""" + client = MagicMock() + client.responses = MagicMock() + return client + + +@pytest.fixture +def sample_task_id(): + """Generate a sample task ID""" + return f"task_{uuid.uuid4().hex[:8]}" + + +@pytest.fixture +def mock_streaming_context(): + """Mock streaming context for testing""" + context = AsyncMock() + context.task_message = MagicMock() + context.stream_update = AsyncMock() + context.close = AsyncMock() + context.__aenter__ = AsyncMock(return_value=context) + context.__aexit__ = AsyncMock() + return context + + +@pytest.fixture(autouse=True) +def mock_adk_streaming(): + """Mock the ADK streaming module""" + with patch('agentex.lib.adk.streaming') as mock_streaming: + mock_context = AsyncMock() + mock_context.task_message = MagicMock() + mock_context.stream_update = AsyncMock() + mock_context.close = AsyncMock() + mock_context.__aenter__ = AsyncMock(return_value=mock_context) + mock_context.__aexit__ = AsyncMock() + + mock_streaming.streaming_task_message_context.return_value = mock_context + yield mock_streaming + + +@pytest.fixture +def sample_function_tool(): + """Sample FunctionTool for testing""" + async def mock_tool_handler(_context, _args): + return {"temperature": "72F", "condition": "sunny"} + + return FunctionTool( + name="get_weather", + description="Get the current weather", + params_json_schema={ + "type": "object", + "properties": { + "location": {"type": "string"} + } + }, + on_invoke_tool=mock_tool_handler, + strict_json_schema=False + ) + + +@pytest.fixture +def sample_web_search_tool(): + """Sample WebSearchTool for testing""" + return WebSearchTool( + user_location=None, + search_context_size="medium" + ) + + +@pytest.fixture +def sample_file_search_tool(): + """Sample FileSearchTool for testing""" + return FileSearchTool( + vector_store_ids=["vs_123"], + max_num_results=10, + include_search_results=True + ) + + +@pytest.fixture +def sample_computer_tool(): + """Sample ComputerTool for testing""" + computer = MagicMock() + computer.environment = "desktop" + computer.dimensions = [1920, 1080] + return ComputerTool(computer=computer) + + +@pytest.fixture +def sample_hosted_mcp_tool(): + """Sample HostedMCPTool for testing""" + tool = MagicMock(spec=HostedMCPTool) + tool.tool_config = { + "type": "mcp", + "server_label": "test_server", + "name": "test_tool" + } + return tool + + +@pytest.fixture +def sample_image_generation_tool(): + """Sample ImageGenerationTool for testing""" + tool = MagicMock(spec=ImageGenerationTool) + tool.tool_config = { + "type": "image_generation", + "model": "dall-e-3" + } + return tool + + +@pytest.fixture +def sample_code_interpreter_tool(): + """Sample CodeInterpreterTool for testing""" + tool = MagicMock(spec=CodeInterpreterTool) + tool.tool_config = { + "type": "code_interpreter" + } + return tool + + 
+@pytest.fixture +def sample_local_shell_tool(): + """Sample LocalShellTool for testing""" + from agents import LocalShellExecutor + executor = MagicMock(spec=LocalShellExecutor) + return LocalShellTool(executor=executor) + + +@pytest.fixture +def sample_handoff(): + """Sample Handoff for testing""" + from agents import Agent + + async def mock_handoff_handler(_context, _args): + # Return a mock agent + return MagicMock(spec=Agent) + + return Handoff( + agent_name="support_agent", + tool_name="transfer_to_support", + tool_description="Transfer to support agent", + input_json_schema={"type": "object"}, + on_invoke_handoff=mock_handoff_handler + ) + + +@pytest.fixture +def basic_model_settings(): + """Basic ModelSettings for testing""" + return ModelSettings( + temperature=0.7, + max_tokens=1000, + top_p=0.9 + ) + + +@pytest.fixture +def reasoning_model_settings(): + """ModelSettings with reasoning enabled""" + return ModelSettings( + reasoning=Reasoning( + effort="medium", + generate_summary="auto" + ) + ) + + +@pytest.fixture +def mock_response_stream(): + """Mock a response stream with basic events""" + async def stream_generator(): + # Yield some basic events + yield ResponseOutputItemAddedEvent( # type: ignore[call-arg] + type="response.output_item.added", + output_index=0, + item=MagicMock(type="message") + ) + + yield ResponseTextDeltaEvent( # type: ignore[call-arg] + type="response.text.delta", + delta="Hello ", + output_index=0 + ) + + yield ResponseTextDeltaEvent( # type: ignore[call-arg] + type="response.text.delta", + delta="world!", + output_index=0 + ) + + yield ResponseCompletedEvent( # type: ignore[call-arg] + type="response.completed", + response=MagicMock( + output=[], + usage=MagicMock() + ) + ) + + return stream_generator() + + +@pytest.fixture +def mock_reasoning_stream(): + """Mock a response stream with reasoning events""" + async def stream_generator(): + # Start reasoning + yield ResponseOutputItemAddedEvent( # type: ignore[call-arg] + type="response.output_item.added", + output_index=0, + item=MagicMock(type="reasoning") + ) + + # Reasoning deltas + yield ResponseReasoningSummaryTextDeltaEvent( # type: ignore[call-arg] + type="response.reasoning_summary_text.delta", + delta="Let me think about this...", + summary_index=0 + ) + + # Complete + yield ResponseCompletedEvent( # type: ignore[call-arg] + type="response.completed", + response=MagicMock( + output=[], + usage=MagicMock() + ) + ) + + return stream_generator() + + +@pytest_asyncio.fixture(scope="function") +async def streaming_model(): + """Create a TemporalStreamingModel instance for testing""" + from ..models.temporal_streaming_model import TemporalStreamingModel + + model = TemporalStreamingModel(model_name="gpt-4o") + # Mock the OpenAI client with fresh mocks for each test + model.client = AsyncMock() + model.client.responses = AsyncMock() + + yield model + + # Cleanup after each test + if hasattr(model.client, 'close'): + await model.client.close() + + +# Mock environment variables for testing +@pytest.fixture(autouse=True) +def mock_env_vars(): + """Mock environment variables""" + env_vars = { + "OPENAI_API_KEY": "test-key-123", + "AGENT_NAME": "test-agent", + "ACP_URL": "http://localhost:8000", + } + + with patch.dict("os.environ", env_vars): + yield env_vars \ No newline at end of file diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py new file mode 100644 index 
00000000..457ec954 --- /dev/null +++ b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py @@ -0,0 +1,848 @@ +""" +Comprehensive tests for StreamingModel with all configurations and tool types. +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from agents import ModelSettings +from openai import NOT_GIVEN +from agents.model_settings import Reasoning, MCPToolChoice # type: ignore[attr-defined] + + +class TestStreamingModelSettings: + """Test that all ModelSettings parameters work with Responses API""" + + @pytest.mark.asyncio + async def test_temperature_setting(self, streaming_model, _mock_adk_streaming, sample_task_id): + """Test that temperature parameter is properly passed to Responses API""" + streaming_model.client.responses.create = AsyncMock() + + # Mock the response stream + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + # Test with various temperature values + for temp in [0.0, 0.7, 1.5, 2.0]: + settings = ModelSettings(temperature=temp) + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + # Verify temperature was passed correctly + create_call = streaming_model.client.responses.create.call_args + assert create_call.kwargs['temperature'] == temp + + @pytest.mark.asyncio + async def test_top_p_setting(self, streaming_model, _mock_adk_streaming, sample_task_id): + """Test that top_p parameter is properly passed to Responses API""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + # Test with various top_p values + for top_p in [0.1, 0.5, 0.9, None]: + settings = ModelSettings(top_p=top_p) + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + expected = top_p if top_p is not None else NOT_GIVEN + assert create_call.kwargs['top_p'] == expected + + @pytest.mark.asyncio + async def test_max_tokens_setting(self, streaming_model, _mock_adk_streaming, sample_task_id): + """Test that max_tokens is properly mapped to max_output_tokens""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + settings = ModelSettings(max_tokens=2000) + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + assert create_call.kwargs['max_output_tokens'] == 2000 + + @pytest.mark.asyncio + async def test_reasoning_effort_settings(self, streaming_model, _mock_adk_streaming, sample_task_id): + """Test reasoning effort levels (low/medium/high)""" + 
streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + for effort in ["low", "medium", "high"]: + settings = ModelSettings( + reasoning=Reasoning(effort=effort) + ) + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + assert create_call.kwargs['reasoning'] == {"effort": effort} + + @pytest.mark.asyncio + async def test_reasoning_summary_settings(self, streaming_model, _mock_adk_streaming, sample_task_id): + """Test reasoning summary settings (auto/none)""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + for summary in ["auto", "concise", "detailed"]: + settings = ModelSettings( + reasoning=Reasoning(effort="medium", generate_summary=summary) + ) + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + assert create_call.kwargs['reasoning'] == {"effort": "medium", "summary": summary} + + @pytest.mark.asyncio + async def test_tool_choice_variations(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_function_tool): + """Test various tool_choice settings""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + # Test different tool_choice options + test_cases = [ + ("auto", "auto"), + ("required", "required"), + ("none", "none"), + ("get_weather", {"type": "function", "name": "get_weather"}), + ("web_search", {"type": "web_search"}), + (MCPToolChoice(server_label="test", name="tool"), {"server_label": "test", "type": "mcp", "name": "tool"}) + ] + + for tool_choice, expected in test_cases: + settings = ModelSettings(tool_choice=tool_choice) + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[sample_function_tool], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + assert create_call.kwargs['tool_choice'] == expected + + @pytest.mark.asyncio + async def test_parallel_tool_calls(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_function_tool): + """Test parallel tool calls setting""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + for parallel in [True, False]: + settings = ModelSettings(parallel_tool_calls=parallel) + + await streaming_model.get_response( + 
system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[sample_function_tool], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + assert create_call.kwargs['parallel_tool_calls'] == parallel + + @pytest.mark.asyncio + async def test_truncation_strategy(self, streaming_model, _mock_adk_streaming, sample_task_id): + """Test truncation parameter""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + # truncation now accepts 'auto' or 'disabled' string literals + settings = ModelSettings(truncation="auto") + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + assert create_call.kwargs['truncation'] == "auto" + + @pytest.mark.asyncio + async def test_response_include(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_file_search_tool): + """Test response include parameter""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + settings = ModelSettings( + response_include=["reasoning.encrypted_content", "message.output_text.logprobs"] + ) + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[sample_file_search_tool], # This adds file_search_call.results + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + include_list = create_call.kwargs['include'] + assert "reasoning.encrypted_content" in include_list + assert "message.output_text.logprobs" in include_list + assert "file_search_call.results" in include_list # Added by file search tool + + @pytest.mark.asyncio + async def test_verbosity(self, streaming_model, _mock_adk_streaming, sample_task_id): + """Test verbosity settings""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + settings = ModelSettings(verbosity="high") + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + assert create_call.kwargs['text'] == {"verbosity": "high"} + + @pytest.mark.asyncio + async def test_metadata_and_store(self, streaming_model, _mock_adk_streaming, sample_task_id): + """Test metadata and store parameters""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + 
streaming_model.client.responses.create.return_value = mock_stream + + metadata = {"user_id": "123", "session": "abc"} + store = True + + settings = ModelSettings( + metadata=metadata, + store=store + ) + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + assert create_call.kwargs['metadata'] == metadata + assert create_call.kwargs['store'] == store + + @pytest.mark.asyncio + async def test_extra_headers_and_body(self, streaming_model, _mock_adk_streaming, sample_task_id): + """Test extra customization parameters""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + extra_headers = {"X-Custom": "header"} + extra_body = {"custom_field": "value"} + extra_query = {"param": "value"} + + settings = ModelSettings( + extra_headers=extra_headers, + extra_body=extra_body, + extra_query=extra_query + ) + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + assert create_call.kwargs['extra_headers'] == extra_headers + assert create_call.kwargs['extra_body'] == extra_body + assert create_call.kwargs['extra_query'] == extra_query + + @pytest.mark.asyncio + async def test_top_logprobs(self, streaming_model, _mock_adk_streaming, sample_task_id): + """Test top_logprobs parameter""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + settings = ModelSettings(top_logprobs=5) + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=settings, + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + # top_logprobs goes into extra_args + assert "top_logprobs" in create_call.kwargs + assert create_call.kwargs['top_logprobs'] == 5 + # Also should add to include list + assert "message.output_text.logprobs" in create_call.kwargs['include'] + + +class TestStreamingModelTools: + """Test that all tool types work with streaming""" + + @pytest.mark.asyncio + async def test_function_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_function_tool): + """Test FunctionTool conversion and streaming""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[sample_function_tool], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = 
streaming_model.client.responses.create.call_args + tools = create_call.kwargs['tools'] + assert len(tools) == 1 + assert tools[0]['type'] == 'function' + assert tools[0]['name'] == 'get_weather' + assert tools[0]['description'] == 'Get the current weather' + assert 'parameters' in tools[0] + + @pytest.mark.asyncio + async def test_web_search_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_web_search_tool): + """Test WebSearchTool conversion""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[sample_web_search_tool], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + tools = create_call.kwargs['tools'] + assert len(tools) == 1 + assert tools[0]['type'] == 'web_search' + + @pytest.mark.asyncio + async def test_file_search_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_file_search_tool): + """Test FileSearchTool conversion""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[sample_file_search_tool], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + tools = create_call.kwargs['tools'] + assert len(tools) == 1 + assert tools[0]['type'] == 'file_search' + assert tools[0]['vector_store_ids'] == ['vs_123'] + assert tools[0]['max_num_results'] == 10 + + @pytest.mark.asyncio + async def test_computer_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_computer_tool): + """Test ComputerTool conversion""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[sample_computer_tool], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + tools = create_call.kwargs['tools'] + assert len(tools) == 1 + assert tools[0]['type'] == 'computer_use_preview' + assert tools[0]['environment'] == 'desktop' + assert tools[0]['display_width'] == 1920 + assert tools[0]['display_height'] == 1080 + + @pytest.mark.asyncio + async def test_multiple_computer_tools_error(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_computer_tool): + """Test that multiple computer tools raise an error""" + streaming_model.client.responses.create = AsyncMock() + + # Create two computer tools + computer2 = MagicMock() + computer2.environment = "mobile" + computer2.dimensions = [375, 812] + from agents.tool import 
ComputerTool + second_computer_tool = ComputerTool(computer=computer2) + + with pytest.raises(ValueError, match="You can only provide one computer tool"): + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[sample_computer_tool, second_computer_tool], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + @pytest.mark.asyncio + async def test_hosted_mcp_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_hosted_mcp_tool): + """Test HostedMCPTool conversion""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[sample_hosted_mcp_tool], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + tools = create_call.kwargs['tools'] + assert len(tools) == 1 + assert tools[0]['type'] == 'mcp' + assert tools[0]['server_label'] == 'test_server' + + @pytest.mark.asyncio + async def test_image_generation_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_image_generation_tool): + """Test ImageGenerationTool conversion""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[sample_image_generation_tool], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + tools = create_call.kwargs['tools'] + assert len(tools) == 1 + assert tools[0]['type'] == 'image_generation' + + @pytest.mark.asyncio + async def test_code_interpreter_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_code_interpreter_tool): + """Test CodeInterpreterTool conversion""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[sample_code_interpreter_tool], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + tools = create_call.kwargs['tools'] + assert len(tools) == 1 + assert tools[0]['type'] == 'code_interpreter' + + @pytest.mark.asyncio + async def test_local_shell_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_local_shell_tool): + """Test LocalShellTool conversion""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + 
streaming_model.client.responses.create.return_value = mock_stream + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[sample_local_shell_tool], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + tools = create_call.kwargs['tools'] + assert len(tools) == 1 + assert tools[0]['type'] == 'local_shell' + # working_directory no longer in API - LocalShellTool uses executor internally + + @pytest.mark.asyncio + async def test_handoffs(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_handoff): + """Test Handoff conversion to function tools""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[sample_handoff], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + tools = create_call.kwargs['tools'] + assert len(tools) == 1 + assert tools[0]['type'] == 'function' + assert tools[0]['name'] == 'transfer_to_support' + assert tools[0]['description'] == 'Transfer to support agent' + + @pytest.mark.asyncio + async def test_mixed_tools(self, streaming_model, _mock_adk_streaming, sample_task_id, + sample_function_tool, sample_web_search_tool, sample_handoff): + """Test multiple tools together""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[sample_function_tool, sample_web_search_tool], + output_schema=None, + handoffs=[sample_handoff], + tracing=None, + task_id=sample_task_id + ) + + create_call = streaming_model.client.responses.create.call_args + tools = create_call.kwargs['tools'] + assert len(tools) == 3 # 2 tools + 1 handoff + + # Check each tool type is present + tool_types = [t['type'] for t in tools] + assert 'function' in tool_types # function tool and handoff + assert 'web_search' in tool_types + + +class TestStreamingModelBasics: + """Test core streaming functionality""" + + @pytest.mark.asyncio + async def test_responses_api_streaming(self, streaming_model, mock_adk_streaming, sample_task_id): + """Test basic Responses API streaming flow""" + streaming_model.client.responses.create = AsyncMock() + + # Create a mock stream with text deltas + mock_stream = AsyncMock() + events = [ + MagicMock(type="response.output_item.added", item=MagicMock(type="message")), + MagicMock(type="response.text.delta", delta="Hello "), + MagicMock(type="response.text.delta", delta="world!"), + MagicMock(type="response.completed", response=MagicMock(output=[])) + ] + mock_stream.__aiter__.return_value = iter(events) + streaming_model.client.responses.create.return_value = mock_stream + + result = await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[], + 
output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + # Verify streaming context was created + mock_adk_streaming.streaming_task_message_context.assert_called_with( + task_id=sample_task_id, + initial_content=mock_adk_streaming.streaming_task_message_context.call_args.kwargs['initial_content'] + ) + + # Verify result is returned as ModelResponse + from agents import ModelResponse + assert isinstance(result, ModelResponse) + + @pytest.mark.asyncio + async def test_task_id_threading(self, streaming_model, mock_adk_streaming): + """Test that task_id is properly threaded through to streaming context""" + streaming_model.client.responses.create = AsyncMock() + + mock_stream = AsyncMock() + mock_stream.__aiter__.return_value = iter([ + MagicMock(type="response.completed", response=MagicMock(output=[])) + ]) + streaming_model.client.responses.create.return_value = mock_stream + + task_id = "test_task_12345" + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=task_id + ) + + # Verify task_id was passed to streaming context + mock_adk_streaming.streaming_task_message_context.assert_called() + call_args = mock_adk_streaming.streaming_task_message_context.call_args + assert call_args.kwargs['task_id'] == task_id + + @pytest.mark.asyncio + async def test_redis_context_creation(self, streaming_model, mock_adk_streaming, sample_task_id): + """Test that Redis streaming contexts are created properly""" + streaming_model.client.responses.create = AsyncMock() + + # Mock stream with reasoning + mock_stream = AsyncMock() + events = [ + MagicMock(type="response.output_item.added", item=MagicMock(type="reasoning")), + MagicMock(type="response.reasoning_summary_text.delta", delta="Thinking...", summary_index=0), + MagicMock(type="response.completed", response=MagicMock(output=[])) + ] + mock_stream.__aiter__.return_value = iter(events) + streaming_model.client.responses.create.return_value = mock_stream + + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(reasoning=Reasoning(effort="medium")), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=sample_task_id + ) + + # Should create at least one context for reasoning + assert mock_adk_streaming.streaming_task_message_context.call_count >= 1 + + @pytest.mark.asyncio + async def test_missing_task_id_error(self, streaming_model): + """Test that missing task_id raises appropriate error""" + streaming_model.client.responses.create = AsyncMock() + + with pytest.raises(ValueError, match="task_id is required"): + await streaming_model.get_response( + system_instructions="Test", + input="Hello", + model_settings=ModelSettings(), + tools=[], + output_schema=None, + handoffs=[], + tracing=None, + task_id=None # Missing task_id + ) \ No newline at end of file diff --git a/src/agentex/lib/core/temporal/workers/worker.py b/src/agentex/lib/core/temporal/workers/worker.py index 6a92d305..28cab2e1 100644 --- a/src/agentex/lib/core/temporal/workers/worker.py +++ b/src/agentex/lib/core/temporal/workers/worker.py @@ -13,6 +13,7 @@ from temporalio.worker import ( Plugin as WorkerPlugin, Worker, + Interceptor, UnsandboxedWorkflowRunner, ) from temporalio.runtime import Runtime, TelemetryConfig, OpenTelemetryConfig @@ -80,6 +81,16 @@ def _validate_plugins(plugins: list) -> None: ) +def _validate_interceptors(interceptors: list) 
-> None: + """Validate that all items in the interceptors list are valid Temporal interceptors.""" + for i, interceptor in enumerate(interceptors): + if not isinstance(interceptor, Interceptor): + raise TypeError( + f"Interceptor at index {i} must be an instance of temporalio.worker.Interceptor, " + f"got {type(interceptor).__name__}" + ) + + async def get_temporal_client(temporal_address: str, metrics_url: str | None = None, plugins: list = []) -> Client: if plugins != []: # We don't need to validate the plugins if they are empty _validate_plugins(plugins) @@ -116,6 +127,7 @@ def __init__( max_concurrent_activities: int = 10, health_check_port: int | None = None, plugins: list = [], + interceptors: list = [], ): self.task_queue = task_queue self.activity_handles = [] @@ -125,6 +137,7 @@ def __init__( self.healthy = False self.health_check_port = health_check_port if health_check_port is not None else EnvironmentVariables.refresh().HEALTH_CHECK_PORT self.plugins = plugins + self.interceptors = interceptors @overload async def run( @@ -151,6 +164,11 @@ async def run( ): await self.start_health_check_server() await self._register_agent() + + # Validate interceptors if any are provided + if self.interceptors: + _validate_interceptors(self.interceptors) + temporal_client = await get_temporal_client( temporal_address=os.environ.get("TEMPORAL_ADDRESS", "localhost:7233"), plugins=self.plugins, @@ -174,6 +192,7 @@ async def run( max_concurrent_activities=self.max_concurrent_activities, build_id=str(uuid.uuid4()), debug_mode=debug_enabled, # Disable deadlock detection in debug mode + interceptors=self.interceptors, # Pass interceptors to Worker ) logger.info(f"Starting workers for task queue: {self.task_queue}") diff --git a/src/agentex/lib/sdk/fastacp/fastacp.py b/src/agentex/lib/sdk/fastacp/fastacp.py index fd943382..a1b2c1a3 100644 --- a/src/agentex/lib/sdk/fastacp/fastacp.py +++ b/src/agentex/lib/sdk/fastacp/fastacp.py @@ -52,12 +52,14 @@ def create_agentic_acp(config: AgenticACPConfig, **kwargs) -> BaseACPServer: implementation_class = AGENTIC_ACP_IMPLEMENTATIONS[config.type] # Handle temporal-specific configuration if config.type == "temporal": - # Extract temporal_address and plugins from config if it's a TemporalACPConfig + # Extract temporal_address, plugins, and interceptors from config if it's a TemporalACPConfig temporal_config = kwargs.copy() if hasattr(config, "temporal_address"): temporal_config["temporal_address"] = config.temporal_address # type: ignore[attr-defined] if hasattr(config, "plugins"): temporal_config["plugins"] = config.plugins # type: ignore[attr-defined] + if hasattr(config, "interceptors"): + temporal_config["interceptors"] = config.interceptors # type: ignore[attr-defined] return implementation_class.create(**temporal_config) else: return implementation_class.create(**kwargs) diff --git a/src/agentex/lib/sdk/fastacp/impl/temporal_acp.py b/src/agentex/lib/sdk/fastacp/impl/temporal_acp.py index 9a8ebb19..750707c4 100644 --- a/src/agentex/lib/sdk/fastacp/impl/temporal_acp.py +++ b/src/agentex/lib/sdk/fastacp/impl/temporal_acp.py @@ -30,19 +30,21 @@ def __init__( temporal_address: str, temporal_task_service: TemporalTaskService | None = None, plugins: list[Any] | None = None, + interceptors: list[Any] | None = None, ): super().__init__() self._temporal_task_service = temporal_task_service self._temporal_address = temporal_address self._plugins = plugins or [] + self._interceptors = interceptors or [] @classmethod @override - def create(cls, temporal_address: str, 
plugins: list[Any] | None = None) -> "TemporalACP": + def create(cls, temporal_address: str, plugins: list[Any] | None = None, interceptors: list[Any] | None = None) -> "TemporalACP": logger.info("Initializing TemporalACP instance") # Create instance without temporal client initially - temporal_acp = cls(temporal_address=temporal_address, plugins=plugins) + temporal_acp = cls(temporal_address=temporal_address, plugins=plugins, interceptors=interceptors) temporal_acp._setup_handlers() logger.info("TemporalACP instance initialized now") return temporal_acp diff --git a/src/agentex/lib/types/fastacp.py b/src/agentex/lib/types/fastacp.py index 5a2428ad..5b9b6c59 100644 --- a/src/agentex/lib/types/fastacp.py +++ b/src/agentex/lib/types/fastacp.py @@ -4,7 +4,7 @@ from pydantic import Field, BaseModel, field_validator -from agentex.lib.core.clients.temporal.utils import validate_client_plugins +from agentex.lib.core.clients.temporal.utils import validate_client_plugins, validate_worker_interceptors class BaseACPConfig(BaseModel): @@ -48,11 +48,13 @@ class TemporalACPConfig(AgenticACPConfig): type: The type of ACP implementation temporal_address: The address of the temporal server plugins: List of Temporal client plugins + interceptors: List of Temporal worker interceptors """ type: Literal["temporal"] = Field(default="temporal", frozen=True) temporal_address: str = Field(default="temporal-frontend.temporal.svc.cluster.local:7233", frozen=True) plugins: list[Any] = Field(default=[], frozen=True) + interceptors: list[Any] = Field(default=[], frozen=True) @field_validator("plugins") @classmethod @@ -61,6 +63,13 @@ def validate_plugins(cls, v: list[Any]) -> list[Any]: validate_client_plugins(v) return v + @field_validator("interceptors") + @classmethod + def validate_interceptors(cls, v: list[Any]) -> list[Any]: + """Validate that all interceptors are valid Temporal worker interceptors.""" + validate_worker_interceptors(v) + return v + class AgenticBaseACPConfig(AgenticACPConfig): """Configuration for AgenticBaseACP implementation