From 5a1698deef8375d1f4a24f541d7e6d929ecb5862 Mon Sep 17 00:00:00 2001 From: Declan Brady Date: Tue, 21 Apr 2026 10:35:04 -0400 Subject: [PATCH] Add task_id to span creation --- src/agentex/lib/adk/_modules/tracing.py | 11 +- src/agentex/lib/core/services/adk/tracing.py | 2 + .../activities/adk/tracing_activities.py | 2 + src/agentex/lib/core/tracing/trace.py | 13 +- tests/lib/adk/test_tracing_activities.py | 96 ++++++++++++++ tests/lib/adk/test_tracing_module.py | 117 ++++++++++++++++++ tests/lib/adk/test_tracing_service.py | 84 +++++++++++++ tests/lib/core/tracing/test_trace_task_id.py | 55 ++++++++ 8 files changed, 375 insertions(+), 5 deletions(-) create mode 100644 tests/lib/adk/test_tracing_activities.py create mode 100644 tests/lib/adk/test_tracing_module.py create mode 100644 tests/lib/adk/test_tracing_service.py create mode 100644 tests/lib/core/tracing/test_trace_task_id.py diff --git a/src/agentex/lib/adk/_modules/tracing.py b/src/agentex/lib/adk/_modules/tracing.py index cb3b4b22b..67150f01d 100644 --- a/src/agentex/lib/adk/_modules/tracing.py +++ b/src/agentex/lib/adk/_modules/tracing.py @@ -64,9 +64,7 @@ def _tracing_service(self) -> TracingService: # Re-create the underlying httpx client when the event loop changes # (e.g. between HTTP requests in a sync ASGI server) to avoid # "Event loop is closed" / "bound to a different event loop" errors. - if self._tracing_service_lazy is None or ( - loop_id is not None and loop_id != self._bound_loop_id - ): + if self._tracing_service_lazy is None or (loop_id is not None and loop_id != self._bound_loop_id): import httpx # Disable keepalive so each span HTTP call gets a fresh TCP @@ -93,6 +91,7 @@ async def span( input: list[Any] | dict[str, Any] | BaseModel | None = None, data: list[Any] | dict[str, Any] | BaseModel | None = None, parent_id: str | None = None, + task_id: str | None = None, start_to_close_timeout: timedelta = timedelta(seconds=5), heartbeat_timeout: timedelta = timedelta(seconds=5), retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, @@ -109,6 +108,7 @@ async def span( input (Union[List, Dict, BaseModel]): The input for the span. parent_id (Optional[str]): The parent span ID for the span. data (Optional[Union[List, Dict, BaseModel]]): The data for the span. + task_id (Optional[str]): The task ID this span belongs to. start_to_close_timeout (timedelta): The start to close timeout for the span. heartbeat_timeout (timedelta): The heartbeat timeout for the span. retry_policy (RetryPolicy): The retry policy for the span. @@ -126,6 +126,7 @@ async def span( input=input, parent_id=parent_id, data=data, + task_id=task_id, start_to_close_timeout=start_to_close_timeout, heartbeat_timeout=heartbeat_timeout, retry_policy=retry_policy, @@ -149,6 +150,7 @@ async def start_span( input: list[Any] | dict[str, Any] | BaseModel | None = None, parent_id: str | None = None, data: list[Any] | dict[str, Any] | BaseModel | None = None, + task_id: str | None = None, start_to_close_timeout: timedelta = timedelta(seconds=5), heartbeat_timeout: timedelta = timedelta(seconds=1), retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, @@ -162,6 +164,7 @@ async def start_span( input (Union[List, Dict, BaseModel]): The input for the span. parent_id (Optional[str]): The parent span ID for the span. data (Optional[Union[List, Dict, BaseModel]]): The data for the span. + task_id (Optional[str]): The task ID this span belongs to. start_to_close_timeout (timedelta): The start to close timeout for the span. heartbeat_timeout (timedelta): The heartbeat timeout for the span. retry_policy (RetryPolicy): The retry policy for the span. @@ -175,6 +178,7 @@ async def start_span( name=name, input=input, data=data, + task_id=task_id, ) if in_temporal_workflow(): return await ActivityHelpers.execute_activity( @@ -192,6 +196,7 @@ async def start_span( input=input, parent_id=parent_id, data=data, + task_id=task_id, ) async def end_span( diff --git a/src/agentex/lib/core/services/adk/tracing.py b/src/agentex/lib/core/services/adk/tracing.py index 210d2f625..77efffd9e 100644 --- a/src/agentex/lib/core/services/adk/tracing.py +++ b/src/agentex/lib/core/services/adk/tracing.py @@ -22,6 +22,7 @@ async def start_span( parent_id: str | None = None, input: list[Any] | dict[str, Any] | BaseModel | None = None, data: list[Any] | dict[str, Any] | BaseModel | None = None, + task_id: str | None = None, ) -> Span | None: trace = self._tracer.trace(trace_id) span = await trace.start_span( @@ -29,6 +30,7 @@ async def start_span( parent_id=parent_id, input=input or {}, data=data, + task_id=task_id, ) heartbeat_if_in_workflow("start span") return span diff --git a/src/agentex/lib/core/temporal/activities/adk/tracing_activities.py b/src/agentex/lib/core/temporal/activities/adk/tracing_activities.py index 65afcded0..aec541afe 100644 --- a/src/agentex/lib/core/temporal/activities/adk/tracing_activities.py +++ b/src/agentex/lib/core/temporal/activities/adk/tracing_activities.py @@ -24,6 +24,7 @@ class StartSpanParams(BaseModel): name: str input: list[Any] | dict[str, Any] | BaseModel | None = None data: list[Any] | dict[str, Any] | BaseModel | None = None + task_id: str | None = None class EndSpanParams(BaseModel): @@ -47,6 +48,7 @@ async def start_span(self, params: StartSpanParams) -> Span | None: name=params.name, input=params.input, data=params.data, + task_id=params.task_id, ) @activity.defn(name=TracingActivityName.END_SPAN) diff --git a/src/agentex/lib/core/tracing/trace.py b/src/agentex/lib/core/tracing/trace.py index 7925df7fc..70b268b18 100644 --- a/src/agentex/lib/core/tracing/trace.py +++ b/src/agentex/lib/core/tracing/trace.py @@ -54,6 +54,7 @@ def start_span( parent_id: str | None = None, input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, + task_id: str | None = None, ) -> Span: """ Start a new span and register it with the API. @@ -63,6 +64,7 @@ def start_span( parent_id: Optional parent span ID. input: Optional input data for the span. data: Optional additional data for the span. + task_id: Optional ID of the task this span belongs to. Returns: The newly created span. @@ -86,6 +88,7 @@ def start_span( start_time=start_time, input=serialized_input, data=serialized_data, + task_id=task_id, ) for processor in self.processors: @@ -150,6 +153,7 @@ def span( parent_id: str | None = None, input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, + task_id: str | None = None, ): """ Context manager for spans. @@ -158,7 +162,7 @@ def span( if not self.trace_id: yield None return - span = self.start_span(name, parent_id, input, data) + span = self.start_span(name, parent_id, input, data, task_id=task_id) try: yield span finally: @@ -198,6 +202,7 @@ async def start_span( parent_id: str | None = None, input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, + task_id: str | None = None, ) -> Span: """ Start a new span and register it with the API. @@ -207,6 +212,7 @@ async def start_span( parent_id: Optional parent span ID. input: Optional input data for the span. data: Optional additional data for the span. + task_id: Optional ID of the task this span belongs to. Returns: The newly created span. @@ -229,6 +235,7 @@ async def start_span( start_time=start_time, input=serialized_input, data=serialized_data, + task_id=task_id, ) if self.processors: @@ -293,6 +300,7 @@ async def span( parent_id: str | None = None, input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, + task_id: str | None = None, ) -> AsyncGenerator[Span | None, None]: """ Context manager for spans. @@ -302,6 +310,7 @@ async def span( parent_id: Optional parent span ID. input: Optional input data for the span. data: Optional additional data for the span. + task_id: Optional ID of the task this span belongs to. Yields: The span object. @@ -309,7 +318,7 @@ async def span( if not self.trace_id: yield None return - span = await self.start_span(name, parent_id, input, data) + span = await self.start_span(name, parent_id, input, data, task_id=task_id) try: yield span finally: diff --git a/tests/lib/adk/test_tracing_activities.py b/tests/lib/adk/test_tracing_activities.py new file mode 100644 index 000000000..248ba94a7 --- /dev/null +++ b/tests/lib/adk/test_tracing_activities.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import AsyncMock + +from temporalio.testing import ActivityEnvironment + +from agentex.types.span import Span + + +def _make_span(**overrides) -> Span: + defaults = { + "id": "span-123", + "name": "test-span", + "start_time": datetime(2026, 1, 1, tzinfo=timezone.utc), + "trace_id": "trace-123", + } + defaults.update(overrides) + return Span(**defaults) + + +def _make_tracing_activities(): + from agentex.lib.core.services.adk.tracing import TracingService + from agentex.lib.core.temporal.activities.adk.tracing_activities import TracingActivities + + mock_service = AsyncMock(spec=TracingService) + activities = TracingActivities(tracing_service=mock_service) + env = ActivityEnvironment() + return mock_service, activities, env + + +class TestStartSpanActivity: + async def test_start_span_with_task_id(self): + from agentex.lib.core.temporal.activities.adk.tracing_activities import StartSpanParams + + mock_service, activities, env = _make_tracing_activities() + expected = _make_span(task_id="task-abc") + mock_service.start_span.return_value = expected + + params = StartSpanParams( + trace_id="trace-123", + name="test-span", + task_id="task-abc", + ) + result = await env.run(activities.start_span, params) + + assert result == expected + assert result.task_id == "task-abc" + mock_service.start_span.assert_called_once_with( + trace_id="trace-123", + parent_id=None, + name="test-span", + input=None, + data=None, + task_id="task-abc", + ) + + async def test_start_span_without_task_id(self): + from agentex.lib.core.temporal.activities.adk.tracing_activities import StartSpanParams + + mock_service, activities, env = _make_tracing_activities() + expected = _make_span() + mock_service.start_span.return_value = expected + + params = StartSpanParams(trace_id="trace-123", name="test-span") + result = await env.run(activities.start_span, params) + + assert result == expected + mock_service.start_span.assert_called_once_with( + trace_id="trace-123", + parent_id=None, + name="test-span", + input=None, + data=None, + task_id=None, + ) + + +class TestEndSpanActivity: + async def test_end_span_preserves_task_id(self): + from agentex.lib.core.temporal.activities.adk.tracing_activities import EndSpanParams + + mock_service, activities, env = _make_tracing_activities() + span = _make_span(task_id="task-abc") + expected = _make_span( + task_id="task-abc", + end_time=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + mock_service.end_span.return_value = expected + + params = EndSpanParams(trace_id="trace-123", span=span) + result = await env.run(activities.end_span, params) + + assert result == expected + assert result.task_id == "task-abc" + mock_service.end_span.assert_called_once_with(trace_id="trace-123", span=span) diff --git a/tests/lib/adk/test_tracing_module.py b/tests/lib/adk/test_tracing_module.py new file mode 100644 index 000000000..52d5d3f82 --- /dev/null +++ b/tests/lib/adk/test_tracing_module.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import agentex.lib.adk._modules.tracing as _tracing_mod +from agentex.types.span import Span +from agentex.lib.adk._modules.tracing import TracingModule +from agentex.lib.core.services.adk.tracing import TracingService + + +def _make_span(**overrides) -> Span: + defaults = { + "id": "span-123", + "name": "test-span", + "start_time": datetime(2026, 1, 1, tzinfo=timezone.utc), + "trace_id": "trace-123", + } + defaults.update(overrides) + return Span(**defaults) + + +def _make_module() -> tuple[AsyncMock, TracingModule]: + mock_service = AsyncMock(spec=TracingService) + module = TracingModule(tracing_service=mock_service) + return mock_service, module + + +class TestStartSpan: + async def test_start_span_with_task_id(self): + mock_service, module = _make_module() + expected = _make_span(task_id="task-abc") + mock_service.start_span.return_value = expected + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=False): + result = await module.start_span( + trace_id="trace-123", + name="test-span", + task_id="task-abc", + ) + + assert result == expected + assert result.task_id == "task-abc" + mock_service.start_span.assert_called_once_with( + trace_id="trace-123", + name="test-span", + input=None, + parent_id=None, + data=None, + task_id="task-abc", + ) + + async def test_start_span_without_task_id(self): + mock_service, module = _make_module() + expected = _make_span() + mock_service.start_span.return_value = expected + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=False): + result = await module.start_span(trace_id="trace-123", name="test-span") + + assert result == expected + mock_service.start_span.assert_called_once_with( + trace_id="trace-123", + name="test-span", + input=None, + parent_id=None, + data=None, + task_id=None, + ) + + +class TestEndSpan: + async def test_end_span_preserves_task_id(self): + mock_service, module = _make_module() + span = _make_span(task_id="task-abc") + expected = _make_span( + task_id="task-abc", + end_time=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + mock_service.end_span.return_value = expected + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=False): + result = await module.end_span(trace_id="trace-123", span=span) + + assert result == expected + assert result.task_id == "task-abc" + mock_service.end_span.assert_called_once_with(trace_id="trace-123", span=span) + + +class TestSpanContextManager: + async def test_span_context_manager_forwards_task_id(self): + mock_service, module = _make_module() + started = _make_span(task_id="task-abc") + mock_service.start_span.return_value = started + mock_service.end_span.return_value = started + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=False): + async with module.span( + trace_id="trace-123", + name="test-span", + task_id="task-abc", + ) as span: + assert span is not None + assert span.task_id == "task-abc" + + assert mock_service.start_span.call_args.kwargs["task_id"] == "task-abc" + mock_service.end_span.assert_called_once() + + async def test_span_context_manager_noop_when_no_trace_id(self): + mock_service, module = _make_module() + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=False): + async with module.span(trace_id="", name="test-span") as span: + assert span is None + + mock_service.start_span.assert_not_called() + mock_service.end_span.assert_not_called() diff --git a/tests/lib/adk/test_tracing_service.py b/tests/lib/adk/test_tracing_service.py new file mode 100644 index 000000000..dceb000f5 --- /dev/null +++ b/tests/lib/adk/test_tracing_service.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +from agentex.types.span import Span +from agentex.lib.core.services.adk.tracing import TracingService + + +def _make_span(**overrides) -> Span: + defaults = { + "id": "span-123", + "name": "test-span", + "start_time": datetime(2026, 1, 1, tzinfo=timezone.utc), + "trace_id": "trace-123", + } + defaults.update(overrides) + return Span(**defaults) + + +def _make_service() -> tuple[MagicMock, MagicMock, TracingService]: + """Build a TracingService backed by an AsyncTracer whose + trace.start_span / trace.end_span are mocked.""" + mock_trace = MagicMock() + mock_trace.start_span = AsyncMock() + mock_trace.end_span = AsyncMock() + + mock_tracer = MagicMock() + mock_tracer.trace.return_value = mock_trace + + service = TracingService(tracer=mock_tracer) + return mock_tracer, mock_trace, service + + +class TestStartSpanService: + async def test_start_span_passes_task_id(self): + mock_tracer, mock_trace, service = _make_service() + expected = _make_span(task_id="task-abc") + mock_trace.start_span.return_value = expected + + result = await service.start_span( + trace_id="trace-123", + name="test-span", + task_id="task-abc", + ) + + assert result == expected + mock_tracer.trace.assert_called_once_with("trace-123") + mock_trace.start_span.assert_awaited_once_with( + name="test-span", + parent_id=None, + input={}, + data=None, + task_id="task-abc", + ) + + async def test_start_span_without_task_id(self): + _mock_tracer, mock_trace, service = _make_service() + expected = _make_span() + mock_trace.start_span.return_value = expected + + result = await service.start_span(trace_id="trace-123", name="test-span") + + assert result == expected + mock_trace.start_span.assert_awaited_once_with( + name="test-span", + parent_id=None, + input={}, + data=None, + task_id=None, + ) + + +class TestEndSpanService: + async def test_end_span_forwards_span(self): + mock_tracer, mock_trace, service = _make_service() + span = _make_span(task_id="task-abc") + mock_trace.end_span.return_value = span + + result = await service.end_span(trace_id="trace-123", span=span) + + assert result is span + mock_tracer.trace.assert_called_once_with("trace-123") + mock_trace.end_span.assert_awaited_once_with(span) diff --git a/tests/lib/core/tracing/test_trace_task_id.py b/tests/lib/core/tracing/test_trace_task_id.py new file mode 100644 index 000000000..1a616cc94 --- /dev/null +++ b/tests/lib/core/tracing/test_trace_task_id.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +from agentex.lib.core.tracing.trace import Trace, AsyncTrace + + +def _make_sync_trace(trace_id: str = "trace-123") -> tuple[MagicMock, Trace]: + client = MagicMock() + trace = Trace(processors=[], client=client, trace_id=trace_id) + return client, trace + + +def _make_async_trace(trace_id: str = "trace-123") -> tuple[MagicMock, AsyncTrace]: + client = MagicMock() + trace = AsyncTrace(processors=[], client=client, trace_id=trace_id) + return client, trace + + +class TestSyncTraceTaskId: + def test_start_span_sets_task_id_on_span(self): + _client, trace = _make_sync_trace() + span = trace.start_span(name="foo", task_id="task-abc") + assert span.task_id == "task-abc" + assert span.trace_id == "trace-123" + + def test_start_span_defaults_task_id_to_none(self): + _client, trace = _make_sync_trace() + span = trace.start_span(name="foo") + assert span.task_id is None + + def test_end_span_preserves_task_id_from_span(self): + _client, trace = _make_sync_trace() + span = trace.start_span(name="foo", task_id="task-abc") + trace.end_span(span) + assert span.task_id == "task-abc" + + +class TestAsyncTraceTaskId: + async def test_start_span_sets_task_id_on_span(self): + _client, trace = _make_async_trace() + span = await trace.start_span(name="foo", task_id="task-abc") + assert span.task_id == "task-abc" + assert span.trace_id == "trace-123" + + async def test_start_span_defaults_task_id_to_none(self): + _client, trace = _make_async_trace() + span = await trace.start_span(name="foo") + assert span.task_id is None + + async def test_end_span_preserves_task_id_from_span(self): + _client, trace = _make_async_trace() + span = await trace.start_span(name="foo", task_id="task-abc") + await trace.end_span(span) + assert span.task_id == "task-abc"