From 25a98ca962895bbe8c7195137ecc75e6f2a5796b Mon Sep 17 00:00:00 2001
From: Stas Moreinis
Date: Wed, 29 Apr 2026 18:10:48 -0700
Subject: [PATCH] fix(tests): repair test_streaming_model so all 28 tests run
 and pass
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Four pre-existing bugs left this entire test file unrunnable on main
(4 failures + 24 errors); fixing them here so the suite actually
exercises TemporalStreamingModel and protects against regressions.

Bug 1 (24 errors): `conftest.py` defines the fixture
`mock_adk_streaming` (no leading underscore), but every test in
TestStreamingModelSettings and TestStreamingModelTools requested it as
`_mock_adk_streaming`, so pytest failed to resolve the fixture before
the test body ever ran. The fixture is `autouse=True` and the parameter
value was never used in any test body, so the parameter was vestigial;
it is replaced with `_streaming_context_vars`, which provides the
ContextVar setup these tests now actually need. (A minimal sketch of
the mechanics is appended at the end of this message.)

Bug 2 (4 failures): `TemporalStreamingModel.get_response()` reads
`task_id`, `trace_id`, and `parent_span_id` from ContextVars that
`ContextInterceptor` populates from request headers in real Temporal
flows. Tests had been passing `task_id=...` as a kwarg, which is
silently swallowed by `**kwargs` and ignored, so all three ContextVars
stayed at their defaults and the validation at the top of
`get_response` raised before any work happened. The new
`_streaming_context_vars` fixture in conftest sets all three vars (and
resets them on teardown), simulating what `ContextInterceptor` does in
production.

Bug 3 (test_computer_tool): A recent commit narrowed `ComputerTool`
serialization to require an actual `Computer`/`AsyncComputer` instance,
but `sample_computer_tool` still built a bare `MagicMock`. Switched to
`MagicMock(spec=Computer)` so the production isinstance check passes.

Bug 4 (3 streaming-context tests): The 3 tests in
TestStreamingModelBasics that assert on
`streaming_task_message_context` calls built event sequences with raw
`MagicMock(type="...")`. Production dispatches via
`isinstance(event, ResponseOutputItemAddedEvent)` etc., which a
`MagicMock` without `spec` never satisfies, so dispatch was silently
skipped and the assertions failed. Switched to `MagicMock(spec=...)`
for each event type, which passes isinstance without triggering
pydantic validation on the event's required fields. (See the second
sketch at the end of this message.)

Also fixed `test_task_id_threading`, which had been asserting against a
hardcoded `task_id="test_task_12345"` that was never actually threaded
anywhere (the kwarg was ignored, just like in Bug 2); it now asserts
against the value yielded by the fixture, which is the value production
reads from the ContextVar.

After all four fixes: 28/28 pass, ruff clean, pyright clean.
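For reference, a minimal sketch of the Bug 1 mechanics (a hypothetical
standalone example, not code from this repo): pytest resolves fixtures
strictly by parameter name, so a mismatched name errors at setup even
when the fixture is autouse.

    import pytest

    @pytest.fixture(autouse=True)
    def mock_adk_streaming():
        return object()

    def test_ok(mock_adk_streaming):
        # exact name match -> fixture value injected
        assert mock_adk_streaming is not None

    def test_broken(_mock_adk_streaming):
        # no fixture named "_mock_adk_streaming" exists, so pytest
        # errors at setup ("fixture '_mock_adk_streaming' not found")
        # before this body runs; the autouse fixture still executed,
        # but the parameter cannot be resolved
        ...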
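And a minimal sketch of the Bug 4 mechanics, using a hypothetical
stand-in class rather than the real OpenAI pydantic event type: `spec`
rewires the mock's `__class__`, which is what lets the production
isinstance dispatch fire.

    from unittest.mock import MagicMock

    class ResponseTextDeltaEvent:  # hypothetical stand-in event type
        delta: str

    plain = MagicMock(type="response.text.delta")   # only sets a .type attribute
    specd = MagicMock(spec=ResponseTextDeltaEvent)  # __class__ reports the spec'd type
    specd.delta = "Hello "  # plain spec (not spec_set) still allows setting attrs

    assert not isinstance(plain, ResponseTextDeltaEvent)  # dispatch silently skipped
    assert isinstance(specd, ResponseTextDeltaEvent)      # dispatch fires, and no
    assert specd.delta == "Hello "                        # pydantic validation runs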
---
 .../plugins/openai_agents/tests/conftest.py  |  38 ++++-
 .../tests/test_streaming_model.py            | 132 +++++++++++-------
 2 files changed, 115 insertions(+), 55 deletions(-)

diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/conftest.py b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/conftest.py
index 599cb1e36..aa7ad7b04 100644
--- a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/conftest.py
+++ b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/conftest.py
@@ -21,6 +21,7 @@
     CodeInterpreterTool,
     ImageGenerationTool,
 )
+from agents.computer import Computer
 from agents.model_settings import Reasoning  # type: ignore[attr-defined]
 from openai.types.responses import (
     ResponseCompletedEvent,
@@ -47,6 +48,34 @@ def sample_task_id():
     return f"task_{uuid.uuid4().hex[:8]}"
 
 
+@pytest.fixture
+def _streaming_context_vars(sample_task_id):
+    """Populate the streaming ContextVars that ContextInterceptor sets from
+    request headers in real Temporal flows. TemporalStreamingModel.get_response()
+    validates that all three are set before doing any work, so any test that
+    calls get_response() must request this fixture.
+
+    Named with a leading underscore so tests can request it purely for its
+    setup/teardown side effects without ruff flagging it as an unused argument
+    (ARG002). The yielded value is the task_id set on the ContextVar, available
+    for tests that need to assert against it.
+    """
+    from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import (
+        streaming_task_id,
+        streaming_trace_id,
+        streaming_parent_span_id,
+    )
+    task_token = streaming_task_id.set(sample_task_id)
+    trace_token = streaming_trace_id.set("test-trace-id")
+    span_token = streaming_parent_span_id.set("test-parent-span-id")
+    try:
+        yield sample_task_id
+    finally:
+        streaming_task_id.reset(task_token)
+        streaming_trace_id.reset(trace_token)
+        streaming_parent_span_id.reset(span_token)
+
+
 @pytest.fixture
 def mock_streaming_context():
     """Mock streaming context for testing"""
@@ -115,8 +144,13 @@ def sample_file_search_tool():
 
 @pytest.fixture
 def sample_computer_tool():
-    """Sample ComputerTool for testing"""
-    computer = MagicMock()
+    """Sample ComputerTool for testing.
+
+    Production validates ``isinstance(computer, (Computer, AsyncComputer))`` for
+    Responses API serialization, so the mock must be ``spec``-bound to
+    ``Computer`` for the isinstance check to pass.
+    """
+    computer = MagicMock(spec=Computer)
     computer.environment = "desktop"
     computer.dimensions = [1920, 1080]
     return ComputerTool(computer=computer)
diff --git a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py
index 457ec9547..817e5e5b7 100644
--- a/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py
+++ b/src/agentex/lib/core/temporal/plugins/openai_agents/tests/test_streaming_model.py
@@ -8,13 +8,19 @@
 from agents import ModelSettings
 from openai import NOT_GIVEN
 from agents.model_settings import Reasoning, MCPToolChoice  # type: ignore[attr-defined]
+from openai.types.responses import (
+    ResponseCompletedEvent,
+    ResponseTextDeltaEvent,
+    ResponseOutputItemAddedEvent,
+    ResponseReasoningSummaryTextDeltaEvent,
+)
 
 
 class TestStreamingModelSettings:
     """Test that all ModelSettings parameters work with Responses API"""
 
     @pytest.mark.asyncio
-    async def test_temperature_setting(self, streaming_model, _mock_adk_streaming, sample_task_id):
+    async def test_temperature_setting(self, streaming_model, _streaming_context_vars, sample_task_id):
         """Test that temperature parameter is properly passed to Responses API"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -45,7 +51,7 @@ async def test_temperature_setting(self, streaming_model, _mock_adk_streaming, s
             assert create_call.kwargs['temperature'] == temp
 
     @pytest.mark.asyncio
-    async def test_top_p_setting(self, streaming_model, _mock_adk_streaming, sample_task_id):
+    async def test_top_p_setting(self, streaming_model, _streaming_context_vars, sample_task_id):
         """Test that top_p parameter is properly passed to Responses API"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -75,7 +81,7 @@ async def test_top_p_setting(self, streaming_model, _mock_adk_streaming, sample_
             assert create_call.kwargs['top_p'] == expected
 
     @pytest.mark.asyncio
-    async def test_max_tokens_setting(self, streaming_model, _mock_adk_streaming, sample_task_id):
+    async def test_max_tokens_setting(self, streaming_model, _streaming_context_vars, sample_task_id):
         """Test that max_tokens is properly mapped to max_output_tokens"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -102,7 +108,7 @@ async def test_max_tokens_setting(self, streaming_model, _mock_adk_streaming, sa
         assert create_call.kwargs['max_output_tokens'] == 2000
 
     @pytest.mark.asyncio
-    async def test_reasoning_effort_settings(self, streaming_model, _mock_adk_streaming, sample_task_id):
+    async def test_reasoning_effort_settings(self, streaming_model, _streaming_context_vars, sample_task_id):
         """Test reasoning effort levels (low/medium/high)"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -132,7 +138,7 @@ async def test_reasoning_effort_settings(self, streaming_model, _mock_adk_stream
             assert create_call.kwargs['reasoning'] == {"effort": effort}
 
     @pytest.mark.asyncio
-    async def test_reasoning_summary_settings(self, streaming_model, _mock_adk_streaming, sample_task_id):
+    async def test_reasoning_summary_settings(self, streaming_model, _streaming_context_vars, sample_task_id):
         """Test reasoning summary settings (auto/none)"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -162,7 +168,7 @@ async def test_reasoning_summary_settings(self, streaming_model, _mock_adk_strea
             assert create_call.kwargs['reasoning'] == {"effort": "medium", "summary": summary}
 
     @pytest.mark.asyncio
-    async def test_tool_choice_variations(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_function_tool):
+    async def test_tool_choice_variations(self, streaming_model, _streaming_context_vars, sample_task_id, sample_function_tool):
         """Test various tool_choice settings"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -200,7 +206,7 @@ async def test_tool_choice_variations(self, streaming_model, _mock_adk_streaming
             assert create_call.kwargs['tool_choice'] == expected
 
     @pytest.mark.asyncio
-    async def test_parallel_tool_calls(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_function_tool):
+    async def test_parallel_tool_calls(self, streaming_model, _streaming_context_vars, sample_task_id, sample_function_tool):
         """Test parallel tool calls setting"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -228,7 +234,7 @@ async def test_parallel_tool_calls(self, streaming_model, _mock_adk_streaming, s
             assert create_call.kwargs['parallel_tool_calls'] == parallel
 
     @pytest.mark.asyncio
-    async def test_truncation_strategy(self, streaming_model, _mock_adk_streaming, sample_task_id):
+    async def test_truncation_strategy(self, streaming_model, _streaming_context_vars, sample_task_id):
         """Test truncation parameter"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -256,7 +262,7 @@ async def test_truncation_strategy(self, streaming_model, _mock_adk_streaming, s
         assert create_call.kwargs['truncation'] == "auto"
 
     @pytest.mark.asyncio
-    async def test_response_include(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_file_search_tool):
+    async def test_response_include(self, streaming_model, _streaming_context_vars, sample_task_id, sample_file_search_tool):
         """Test response include parameter"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -288,7 +294,7 @@ async def test_response_include(self, streaming_model, _mock_adk_streaming, samp
         assert "file_search_call.results" in include_list  # Added by file search tool
 
     @pytest.mark.asyncio
-    async def test_verbosity(self, streaming_model, _mock_adk_streaming, sample_task_id):
+    async def test_verbosity(self, streaming_model, _streaming_context_vars, sample_task_id):
         """Test verbosity settings"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -315,7 +321,7 @@ async def test_verbosity(self, streaming_model, _mock_adk_streaming, sample_task
         assert create_call.kwargs['text'] == {"verbosity": "high"}
 
     @pytest.mark.asyncio
-    async def test_metadata_and_store(self, streaming_model, _mock_adk_streaming, sample_task_id):
+    async def test_metadata_and_store(self, streaming_model, _streaming_context_vars, sample_task_id):
         """Test metadata and store parameters"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -349,7 +355,7 @@ async def test_metadata_and_store(self, streaming_model, _mock_adk_streaming, sa
         assert create_call.kwargs['store'] == store
 
     @pytest.mark.asyncio
-    async def test_extra_headers_and_body(self, streaming_model, _mock_adk_streaming, sample_task_id):
+    async def test_extra_headers_and_body(self, streaming_model, _streaming_context_vars, sample_task_id):
         """Test extra customization parameters"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -386,7 +392,7 @@ async def test_extra_headers_and_body(self, streaming_model, _mock_adk_streaming
         assert create_call.kwargs['extra_query'] == extra_query
 
     @pytest.mark.asyncio
-    async def test_top_logprobs(self, streaming_model, _mock_adk_streaming, sample_task_id):
+    async def test_top_logprobs(self, streaming_model, _streaming_context_vars, sample_task_id):
         """Test top_logprobs parameter"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -421,7 +427,7 @@ class TestStreamingModelTools:
     """Test that all tool types work with streaming"""
 
     @pytest.mark.asyncio
-    async def test_function_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_function_tool):
+    async def test_function_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_function_tool):
         """Test FunctionTool conversion and streaming"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -451,7 +457,7 @@ async def test_function_tool(self, streaming_model, _mock_adk_streaming, sample_
         assert 'parameters' in tools[0]
 
     @pytest.mark.asyncio
-    async def test_web_search_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_web_search_tool):
+    async def test_web_search_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_web_search_tool):
         """Test WebSearchTool conversion"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -478,7 +484,7 @@ async def test_web_search_tool(self, streaming_model, _mock_adk_streaming, sampl
         assert tools[0]['type'] == 'web_search'
 
     @pytest.mark.asyncio
-    async def test_file_search_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_file_search_tool):
+    async def test_file_search_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_file_search_tool):
         """Test FileSearchTool conversion"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -507,7 +513,7 @@ async def test_file_search_tool(self, streaming_model, _mock_adk_streaming, samp
         assert tools[0]['max_num_results'] == 10
 
     @pytest.mark.asyncio
-    async def test_computer_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_computer_tool):
+    async def test_computer_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_computer_tool):
         """Test ComputerTool conversion"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -537,7 +543,7 @@ async def test_computer_tool(self, streaming_model, _mock_adk_streaming, sample_
         assert tools[0]['display_height'] == 1080
 
     @pytest.mark.asyncio
-    async def test_multiple_computer_tools_error(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_computer_tool):
+    async def test_multiple_computer_tools_error(self, streaming_model, _streaming_context_vars, sample_task_id, sample_computer_tool):
         """Test that multiple computer tools raise an error"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -561,7 +567,7 @@ async def test_multiple_computer_tools_error(self, streaming_model, _mock_adk_st
         )
 
     @pytest.mark.asyncio
-    async def test_hosted_mcp_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_hosted_mcp_tool):
+    async def test_hosted_mcp_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_hosted_mcp_tool):
         """Test HostedMCPTool conversion"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -589,7 +595,7 @@ async def test_hosted_mcp_tool(self, streaming_model, _mock_adk_streaming, sampl
         assert tools[0]['server_label'] == 'test_server'
 
     @pytest.mark.asyncio
-    async def test_image_generation_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_image_generation_tool):
+    async def test_image_generation_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_image_generation_tool):
         """Test ImageGenerationTool conversion"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -616,7 +622,7 @@ async def test_image_generation_tool(self, streaming_model, _mock_adk_streaming,
         assert tools[0]['type'] == 'image_generation'
 
     @pytest.mark.asyncio
-    async def test_code_interpreter_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_code_interpreter_tool):
+    async def test_code_interpreter_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_code_interpreter_tool):
         """Test CodeInterpreterTool conversion"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -643,7 +649,7 @@ async def test_code_interpreter_tool(self, streaming_model, _mock_adk_streaming,
         assert tools[0]['type'] == 'code_interpreter'
 
     @pytest.mark.asyncio
-    async def test_local_shell_tool(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_local_shell_tool):
+    async def test_local_shell_tool(self, streaming_model, _streaming_context_vars, sample_task_id, sample_local_shell_tool):
         """Test LocalShellTool conversion"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -671,7 +677,7 @@ async def test_local_shell_tool(self, streaming_model, _mock_adk_streaming, samp
         # working_directory no longer in API - LocalShellTool uses executor internally
 
     @pytest.mark.asyncio
-    async def test_handoffs(self, streaming_model, _mock_adk_streaming, sample_task_id, sample_handoff):
+    async def test_handoffs(self, streaming_model, _streaming_context_vars, sample_task_id, sample_handoff):
         """Test Handoff conversion to function tools"""
         streaming_model.client.responses.create = AsyncMock()
 
@@ -700,7 +706,7 @@ async def test_handoffs(self, streaming_model, _mock_adk_streaming, sample_task_
         assert tools[0]['description'] == 'Transfer to support agent'
 
     @pytest.mark.asyncio
-    async def test_mixed_tools(self, streaming_model, _mock_adk_streaming, sample_task_id,
+    async def test_mixed_tools(self, streaming_model, _streaming_context_vars, sample_task_id,
                                sample_function_tool, sample_web_search_tool, sample_handoff):
         """Test multiple tools together"""
         streaming_model.client.responses.create = AsyncMock()
@@ -736,19 +742,24 @@ class TestStreamingModelBasics:
     """Test core streaming functionality"""
 
     @pytest.mark.asyncio
-    async def test_responses_api_streaming(self, streaming_model, mock_adk_streaming, sample_task_id):
+    async def test_responses_api_streaming(self, streaming_model, mock_adk_streaming, _streaming_context_vars, sample_task_id):
         """Test basic Responses API streaming flow"""
         streaming_model.client.responses.create = AsyncMock()
 
-        # Create a mock stream with text deltas
+        # Production uses ``isinstance(event, ...)`` against the OpenAI Responses
+        # event types to dispatch. ``spec=...`` makes isinstance pass without
+        # triggering pydantic validation on partially-constructed events.
+        item_added = MagicMock(spec=ResponseOutputItemAddedEvent)
+        item_added.item = MagicMock(type="message")
+        item_added.output_index = 0
+        text_delta_1 = MagicMock(spec=ResponseTextDeltaEvent)
+        text_delta_1.delta = "Hello "
+        text_delta_2 = MagicMock(spec=ResponseTextDeltaEvent)
+        text_delta_2.delta = "world!"
+        completed = MagicMock(spec=ResponseCompletedEvent)
+        completed.response = MagicMock(output=[], usage=MagicMock())
         mock_stream = AsyncMock()
-        events = [
-            MagicMock(type="response.output_item.added", item=MagicMock(type="message")),
-            MagicMock(type="response.text.delta", delta="Hello "),
-            MagicMock(type="response.text.delta", delta="world!"),
-            MagicMock(type="response.completed", response=MagicMock(output=[]))
-        ]
-        mock_stream.__aiter__.return_value = iter(events)
+        mock_stream.__aiter__.return_value = iter([item_added, text_delta_1, text_delta_2, completed])
         streaming_model.client.responses.create.return_value = mock_stream
 
         result = await streaming_model.get_response(
@@ -773,17 +784,24 @@ async def test_responses_api_streaming(self, streaming_model, mock_adk_streaming
         assert isinstance(result, ModelResponse)
 
     @pytest.mark.asyncio
-    async def test_task_id_threading(self, streaming_model, mock_adk_streaming):
-        """Test that task_id is properly threaded through to streaming context"""
+    async def test_task_id_threading(self, streaming_model, mock_adk_streaming, _streaming_context_vars):
+        """Test that task_id from the streaming ContextVar is threaded through to
+        the streaming context. ``_streaming_context_vars`` yields the task_id that
+        was set on the ContextVar, which is what production reads (the kwarg
+        ``task_id=...`` to ``get_response`` is swallowed by ``**kwargs`` and ignored).
+        """
         streaming_model.client.responses.create = AsyncMock()
 
+        item_added = MagicMock(spec=ResponseOutputItemAddedEvent)
+        item_added.item = MagicMock(type="message")
+        item_added.output_index = 0
+        completed = MagicMock(spec=ResponseCompletedEvent)
+        completed.response = MagicMock(output=[], usage=MagicMock())
         mock_stream = AsyncMock()
-        mock_stream.__aiter__.return_value = iter([
-            MagicMock(type="response.completed", response=MagicMock(output=[]))
-        ])
+        mock_stream.__aiter__.return_value = iter([item_added, completed])
         streaming_model.client.responses.create.return_value = mock_stream
 
-        task_id = "test_task_12345"
+        expected_task_id = _streaming_context_vars
 
         await streaming_model.get_response(
             system_instructions="Test",
@@ -793,27 +811,30 @@ async def test_task_id_threading(self, streaming_model, mock_adk_streaming):
             output_schema=None,
             handoffs=[],
             tracing=None,
-            task_id=task_id
         )
 
-        # Verify task_id was passed to streaming context
+        # Verify the ContextVar's task_id was threaded through to the streaming context
         mock_adk_streaming.streaming_task_message_context.assert_called()
         call_args = mock_adk_streaming.streaming_task_message_context.call_args
-        assert call_args.kwargs['task_id'] == task_id
+        assert call_args.kwargs['task_id'] == expected_task_id
 
     @pytest.mark.asyncio
-    async def test_redis_context_creation(self, streaming_model, mock_adk_streaming, sample_task_id):
+    async def test_redis_context_creation(self, streaming_model, mock_adk_streaming, _streaming_context_vars, sample_task_id):
         """Test that Redis streaming contexts are created properly"""
         streaming_model.client.responses.create = AsyncMock()
 
-        # Mock stream with reasoning
+        # Production uses ``isinstance`` against OpenAI Responses event types;
+        # ``spec=...`` makes isinstance pass without triggering pydantic validation.
+        item_added = MagicMock(spec=ResponseOutputItemAddedEvent)
+        item_added.item = MagicMock(type="reasoning")
+        item_added.output_index = 0
+        reasoning_delta = MagicMock(spec=ResponseReasoningSummaryTextDeltaEvent)
+        reasoning_delta.delta = "Thinking..."
+        reasoning_delta.summary_index = 0
+        completed = MagicMock(spec=ResponseCompletedEvent)
+        completed.response = MagicMock(output=[], usage=MagicMock())
        mock_stream = AsyncMock()
-        events = [
-            MagicMock(type="response.output_item.added", item=MagicMock(type="reasoning")),
-            MagicMock(type="response.reasoning_summary_text.delta", delta="Thinking...", summary_index=0),
-            MagicMock(type="response.completed", response=MagicMock(output=[]))
-        ]
-        mock_stream.__aiter__.return_value = iter(events)
+        mock_stream.__aiter__.return_value = iter([item_added, reasoning_delta, completed])
         streaming_model.client.responses.create.return_value = mock_stream
 
         await streaming_model.get_response(
@@ -832,10 +853,16 @@ async def test_redis_context_creation(self, streaming_model, mock_adk_streaming,
 
     @pytest.mark.asyncio
     async def test_missing_task_id_error(self, streaming_model):
-        """Test that missing task_id raises appropriate error"""
+        """Test that missing streaming ContextVars raise an appropriate error.
+
+        Production reads task_id, trace_id, and parent_span_id from ContextVars
+        populated by ContextInterceptor. Without ``_streaming_context_vars``
+        requested, all three are at their defaults — empty strings — and
+        ``get_response`` raises before doing any work.
+        """
         streaming_model.client.responses.create = AsyncMock()
 
-        with pytest.raises(ValueError, match="task_id is required"):
+        with pytest.raises(ValueError, match=r"task_id.*required"):
             await streaming_model.get_response(
                 system_instructions="Test",
                 input="Hello",
@@ -844,5 +871,4 @@ async def test_missing_task_id_error(self, streaming_model):
                 output_schema=None,
                 handoffs=[],
                 tracing=None,
-                task_id=None  # Missing task_id
             )
\ No newline at end of file