Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/agentex/lib/adk/_modules/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def _tracing_service(self) -> TracingService:
# Re-create the underlying httpx client when the event loop changes
# (e.g. between HTTP requests in a sync ASGI server) to avoid
# "Event loop is closed" / "bound to a different event loop" errors.
if self._tracing_service_lazy is None or (
loop_id is not None and loop_id != self._bound_loop_id
):
if self._tracing_service_lazy is None or (loop_id is not None and loop_id != self._bound_loop_id):
import httpx

# Disable keepalive so each span HTTP call gets a fresh TCP
Expand All @@ -93,6 +91,7 @@ async def span(
input: list[Any] | dict[str, Any] | BaseModel | None = None,
data: list[Any] | dict[str, Any] | BaseModel | None = None,
parent_id: str | None = None,
task_id: str | None = None,
start_to_close_timeout: timedelta = timedelta(seconds=5),
heartbeat_timeout: timedelta = timedelta(seconds=5),
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
Expand All @@ -109,6 +108,7 @@ async def span(
input (Union[List, Dict, BaseModel]): The input for the span.
parent_id (Optional[str]): The parent span ID for the span.
data (Optional[Union[List, Dict, BaseModel]]): The data for the span.
task_id (Optional[str]): The task ID this span belongs to.
start_to_close_timeout (timedelta): The start to close timeout for the span.
heartbeat_timeout (timedelta): The heartbeat timeout for the span.
retry_policy (RetryPolicy): The retry policy for the span.
Expand All @@ -126,6 +126,7 @@ async def span(
input=input,
parent_id=parent_id,
data=data,
task_id=task_id,
start_to_close_timeout=start_to_close_timeout,
heartbeat_timeout=heartbeat_timeout,
retry_policy=retry_policy,
Expand All @@ -149,6 +150,7 @@ async def start_span(
input: list[Any] | dict[str, Any] | BaseModel | None = None,
parent_id: str | None = None,
data: list[Any] | dict[str, Any] | BaseModel | None = None,
task_id: str | None = None,
start_to_close_timeout: timedelta = timedelta(seconds=5),
heartbeat_timeout: timedelta = timedelta(seconds=1),
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
Expand All @@ -162,6 +164,7 @@ async def start_span(
input (Union[List, Dict, BaseModel]): The input for the span.
parent_id (Optional[str]): The parent span ID for the span.
data (Optional[Union[List, Dict, BaseModel]]): The data for the span.
task_id (Optional[str]): The task ID this span belongs to.
start_to_close_timeout (timedelta): The start to close timeout for the span.
heartbeat_timeout (timedelta): The heartbeat timeout for the span.
retry_policy (RetryPolicy): The retry policy for the span.
Expand All @@ -175,6 +178,7 @@ async def start_span(
name=name,
input=input,
data=data,
task_id=task_id,
)
if in_temporal_workflow():
return await ActivityHelpers.execute_activity(
Expand All @@ -192,6 +196,7 @@ async def start_span(
input=input,
parent_id=parent_id,
data=data,
task_id=task_id,
)

async def end_span(
Expand Down
2 changes: 2 additions & 0 deletions src/agentex/lib/core/services/adk/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ async def start_span(
parent_id: str | None = None,
input: list[Any] | dict[str, Any] | BaseModel | None = None,
data: list[Any] | dict[str, Any] | BaseModel | None = None,
task_id: str | None = None,
) -> Span | None:
trace = self._tracer.trace(trace_id)
span = await trace.start_span(
name=name,
parent_id=parent_id,
input=input or {},
data=data,
task_id=task_id,
)
heartbeat_if_in_workflow("start span")
return span
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class StartSpanParams(BaseModel):
name: str
input: list[Any] | dict[str, Any] | BaseModel | None = None
data: list[Any] | dict[str, Any] | BaseModel | None = None
task_id: str | None = None


class EndSpanParams(BaseModel):
Expand All @@ -47,6 +48,7 @@ async def start_span(self, params: StartSpanParams) -> Span | None:
name=params.name,
input=params.input,
data=params.data,
task_id=params.task_id,
)

@activity.defn(name=TracingActivityName.END_SPAN)
Expand Down
13 changes: 11 additions & 2 deletions src/agentex/lib/core/tracing/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def start_span(
parent_id: str | None = None,
input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
task_id: str | None = None,
) -> Span:
"""
Start a new span and register it with the API.
Expand All @@ -63,6 +64,7 @@ def start_span(
parent_id: Optional parent span ID.
input: Optional input data for the span.
data: Optional additional data for the span.
task_id: Optional ID of the task this span belongs to.

Returns:
The newly created span.
Expand All @@ -86,6 +88,7 @@ def start_span(
start_time=start_time,
input=serialized_input,
data=serialized_data,
task_id=task_id,
)

for processor in self.processors:
Expand Down Expand Up @@ -150,6 +153,7 @@ def span(
parent_id: str | None = None,
input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
task_id: str | None = None,
):
"""
Context manager for spans.
Expand All @@ -158,7 +162,7 @@ def span(
if not self.trace_id:
yield None
return
span = self.start_span(name, parent_id, input, data)
span = self.start_span(name, parent_id, input, data, task_id=task_id)
try:
yield span
finally:
Expand Down Expand Up @@ -198,6 +202,7 @@ async def start_span(
parent_id: str | None = None,
input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
task_id: str | None = None,
) -> Span:
"""
Start a new span and register it with the API.
Expand All @@ -207,6 +212,7 @@ async def start_span(
parent_id: Optional parent span ID.
input: Optional input data for the span.
data: Optional additional data for the span.
task_id: Optional ID of the task this span belongs to.

Returns:
The newly created span.
Expand All @@ -229,6 +235,7 @@ async def start_span(
start_time=start_time,
input=serialized_input,
data=serialized_data,
task_id=task_id,
)

if self.processors:
Expand Down Expand Up @@ -293,6 +300,7 @@ async def span(
parent_id: str | None = None,
input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
task_id: str | None = None,
) -> AsyncGenerator[Span | None, None]:
"""
Context manager for spans.
Expand All @@ -302,14 +310,15 @@ async def span(
parent_id: Optional parent span ID.
input: Optional input data for the span.
data: Optional additional data for the span.
task_id: Optional ID of the task this span belongs to.

Yields:
The span object.
"""
if not self.trace_id:
yield None
return
span = await self.start_span(name, parent_id, input, data)
span = await self.start_span(name, parent_id, input, data, task_id=task_id)
try:
yield span
finally:
Expand Down
96 changes: 96 additions & 0 deletions tests/lib/adk/test_tracing_activities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

from datetime import datetime, timezone
from unittest.mock import AsyncMock

from temporalio.testing import ActivityEnvironment

from agentex.types.span import Span


def _make_span(**overrides) -> Span:
defaults = {
"id": "span-123",
"name": "test-span",
"start_time": datetime(2026, 1, 1, tzinfo=timezone.utc),
"trace_id": "trace-123",
}
defaults.update(overrides)
return Span(**defaults)


def _make_tracing_activities():
from agentex.lib.core.services.adk.tracing import TracingService
from agentex.lib.core.temporal.activities.adk.tracing_activities import TracingActivities

mock_service = AsyncMock(spec=TracingService)
activities = TracingActivities(tracing_service=mock_service)
env = ActivityEnvironment()
return mock_service, activities, env


class TestStartSpanActivity:
async def test_start_span_with_task_id(self):
from agentex.lib.core.temporal.activities.adk.tracing_activities import StartSpanParams

mock_service, activities, env = _make_tracing_activities()
expected = _make_span(task_id="task-abc")
mock_service.start_span.return_value = expected

params = StartSpanParams(
trace_id="trace-123",
name="test-span",
task_id="task-abc",
)
result = await env.run(activities.start_span, params)

assert result == expected
assert result.task_id == "task-abc"
mock_service.start_span.assert_called_once_with(
trace_id="trace-123",
parent_id=None,
name="test-span",
input=None,
data=None,
task_id="task-abc",
)

async def test_start_span_without_task_id(self):
from agentex.lib.core.temporal.activities.adk.tracing_activities import StartSpanParams

mock_service, activities, env = _make_tracing_activities()
expected = _make_span()
mock_service.start_span.return_value = expected

params = StartSpanParams(trace_id="trace-123", name="test-span")
result = await env.run(activities.start_span, params)

assert result == expected
mock_service.start_span.assert_called_once_with(
trace_id="trace-123",
parent_id=None,
name="test-span",
input=None,
data=None,
task_id=None,
)


class TestEndSpanActivity:
async def test_end_span_preserves_task_id(self):
from agentex.lib.core.temporal.activities.adk.tracing_activities import EndSpanParams

mock_service, activities, env = _make_tracing_activities()
span = _make_span(task_id="task-abc")
expected = _make_span(
task_id="task-abc",
end_time=datetime(2026, 1, 1, tzinfo=timezone.utc),
)
mock_service.end_span.return_value = expected

params = EndSpanParams(trace_id="trace-123", span=span)
result = await env.run(activities.end_span, params)

assert result == expected
assert result.task_id == "task-abc"
mock_service.end_span.assert_called_once_with(trace_id="trace-123", span=span)
Loading