Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def __init__(self, config: AgentexTracingProcessorConfig): # noqa: ARG002
),
)

# TODO(AGX1-199): Add batch create/update endpoints to Agentex API and use
# them here instead of one HTTP call per span.
# https://linear.app/scale-epd/issue/AGX1-199/add-agentex-batch-endpoint-for-traces
@override
async def on_span_start(self, span: Span) -> None:
await self.client.spans.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ async def on_span_start(self, span: Span) -> None:
if self.disabled:
logger.warning("SGP is disabled, skipping span upsert")
return
# TODO(AGX1-198): Batch multiple spans into a single upsert_batch call
# instead of one span per HTTP request.
# https://linear.app/scale-epd/issue/AGX1-198/actually-use-sgp-batching-for-spans
await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr]
items=[sgp_span.to_request_params()]
)
Expand All @@ -155,6 +158,7 @@ async def on_span_end(self, span: Span) -> None:
return

self._add_source_to_span(span)
sgp_span.input = span.input # type: ignore[assignment]
sgp_span.output = span.output # type: ignore[assignment]
sgp_span.metadata = span.data # type: ignore[assignment]
sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr]
Expand Down
63 changes: 55 additions & 8 deletions src/agentex/lib/core/tracing/span_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

logger = make_logger(__name__)

_DEFAULT_BATCH_SIZE = 50


class SpanEventType(str, Enum):
START = "start"
Expand All @@ -28,15 +30,18 @@ class _SpanQueueItem:
class AsyncSpanQueue:
"""Background FIFO queue for async span processing.

Span events are enqueued synchronously (non-blocking) and processed
sequentially by a background drain task. This keeps tracing HTTP calls
off the critical request path while preserving start-before-end ordering.
Span events are enqueued synchronously (non-blocking) and drained by a
background task. Items are processed in batches: all START events in a
batch are flushed concurrently, then all END events, so that per-span
start-before-end ordering is preserved while HTTP calls for independent
spans execute in parallel.
"""

def __init__(self) -> None:
def __init__(self, batch_size: int = _DEFAULT_BATCH_SIZE) -> None:
self._queue: asyncio.Queue[_SpanQueueItem] = asyncio.Queue()
self._drain_task: asyncio.Task[None] | None = None
self._stopping = False
self._batch_size = batch_size

def enqueue(
self,
Expand All @@ -54,9 +59,45 @@ def _ensure_drain_running(self) -> None:
if self._drain_task is None or self._drain_task.done():
self._drain_task = asyncio.create_task(self._drain_loop())

# ------------------------------------------------------------------
# Drain loop
# ------------------------------------------------------------------

async def _drain_loop(self) -> None:
while True:
item = await self._queue.get()
# Block until at least one item is available.
first = await self._queue.get()
batch: list[_SpanQueueItem] = [first]

# Opportunistically grab more ready items (non-blocking).
while len(batch) < self._batch_size:
try:
batch.append(self._queue.get_nowait())
except asyncio.QueueEmpty:
break

try:
# Separate START and END events. Processing all STARTs before
# ENDs ensures that on_span_start completes before on_span_end
# for any span whose both events land in the same batch.
starts = [i for i in batch if i.event_type == SpanEventType.START]
ends = [i for i in batch if i.event_type == SpanEventType.END]

if starts:
await self._process_items(starts)
if ends:
await self._process_items(ends)
finally:
for _ in batch:
self._queue.task_done()
# Release span data for GC.
batch.clear()

@staticmethod
async def _process_items(items: list[_SpanQueueItem]) -> None:
"""Process a list of span events concurrently."""

async def _handle(item: _SpanQueueItem) -> None:
try:
if item.event_type == SpanEventType.START:
coros = [p.on_span_start(item.span) for p in item.processors]
Expand All @@ -72,9 +113,15 @@ async def _drain_loop(self) -> None:
exc_info=result,
)
except Exception:
logger.exception("Unexpected error in span queue drain loop for span %s", item.span.id)
finally:
self._queue.task_done()
logger.exception(
"Unexpected error in span queue for span %s", item.span.id
)

await asyncio.gather(*[_handle(item) for item in items])

# ------------------------------------------------------------------
# Shutdown
# ------------------------------------------------------------------

async def shutdown(self, timeout: float = 30.0) -> None:
self._stopping = True
Expand Down
26 changes: 26 additions & 0 deletions tests/lib/core/tracing/processors/test_sgp_tracing_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,29 @@ async def test_span_end_for_unknown_span_is_noop(self):
await processor.on_span_end(span)

assert len(processor._spans) == 0

async def test_sgp_span_input_updated_on_end(self):
"""on_span_end should update sgp_span.input from the incoming span."""
processor, _ = self._make_processor()

with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()):
span = _make_span()
span.input = {"messages": [{"role": "user", "content": "hello"}]}
await processor.on_span_start(span)

assert len(processor._spans) == 1

# Simulate modified input at end time
updated_input: dict[str, object] = {"messages": [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "hi"},
]}
span.input = updated_input
span.output = {"response": "hi"}
span.end_time = datetime.now(UTC)
await processor.on_span_end(span)

# Span should be removed after end
assert len(processor._spans) == 0
# The end upsert should have been called
assert processor.sgp_async_client.spans.upsert_batch.call_count == 2 # start + end
130 changes: 123 additions & 7 deletions tests/lib/core/tracing/test_span_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
import uuid
import asyncio
from typing import cast
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch

Expand Down Expand Up @@ -52,7 +53,8 @@ async def slow_start(span: Span) -> None:


class TestAsyncSpanQueueOrdering:
async def test_fifo_ordering_preserved(self):
async def test_per_span_start_before_end(self):
"""START always completes before END for the same span, even with batching."""
call_log: list[tuple[str, str]] = []

async def record_start(span: Span) -> None:
Expand All @@ -77,12 +79,19 @@ async def record_end(span: Span) -> None:

await queue.shutdown()

assert call_log == [
("start", "span-a"),
("end", "span-a"),
("start", "span-b"),
("end", "span-b"),
]
# All 4 events should fire
assert len(call_log) == 4

# Per-span invariant: START before END
for span_id in ("span-a", "span-b"):
start_idx = next(i for i, (ev, sid) in enumerate(call_log) if ev == "start" and sid == span_id)
end_idx = next(i for i, (ev, sid) in enumerate(call_log) if ev == "end" and sid == span_id)
assert start_idx < end_idx, f"START should come before END for {span_id}"

# All STARTs before all ENDs within a batch
start_indices = [i for i, (ev, _) in enumerate(call_log) if ev == "start"]
end_indices = [i for i, (ev, _) in enumerate(call_log) if ev == "end"]
assert max(start_indices) < min(end_indices), "All STARTs should complete before any END"


class TestAsyncSpanQueueErrorHandling:
Expand Down Expand Up @@ -154,6 +163,61 @@ async def test_enqueue_after_shutdown_is_dropped(self):
proc.on_span_start.assert_not_called()


class TestAsyncSpanQueueBatchConcurrency:
    async def test_batch_processes_multiple_items_concurrently(self):
        """Items in the same batch should run concurrently, not serially."""
        active = 0
        peak = 0
        guard = asyncio.Lock()

        async def slow_start(span: Span) -> None:
            # Track how many slow_start calls overlap in time.
            nonlocal active, peak
            async with guard:
                active += 1
                peak = max(peak, active)
            await asyncio.sleep(0.05)
            async with guard:
                active -= 1

        proc = _make_processor(on_span_start=AsyncMock(side_effect=slow_start))
        queue = AsyncSpanQueue()

        # Enqueue 10 START events before the drain loop runs — they should
        # all land in the same batch and be processed concurrently.
        for idx in range(10):
            queue.enqueue(SpanEventType.START, _make_span(f"span-{idx}"), [proc])

        await queue.shutdown()

        assert peak > 1, (
            f"Expected concurrent processing, but max concurrency was {peak}"
        )

    async def test_batch_faster_than_serial(self):
        """Batched drain should be significantly faster than serial for slow processors."""
        count = 10
        delay = 0.05  # 50ms per processor call

        async def slow_start(span: Span) -> None:
            await asyncio.sleep(delay)

        proc = _make_processor(on_span_start=AsyncMock(side_effect=slow_start))
        queue = AsyncSpanQueue()

        for idx in range(count):
            queue.enqueue(SpanEventType.START, _make_span(f"span-{idx}"), [proc])

        t0 = time.monotonic()
        await queue.shutdown()
        elapsed = time.monotonic() - t0

        serial_time = count * delay
        assert elapsed < serial_time * 0.5, (
            f"Batch drain took {elapsed:.3f}s — serial would be {serial_time:.3f}s. "
            f"Expected at least 2x speedup from concurrency."
        )


class TestAsyncSpanQueueIntegration:
async def test_integration_with_async_trace(self):
call_log: list[tuple[str, str]] = []
Expand Down Expand Up @@ -196,3 +260,55 @@ async def record_end(span: Span) -> None:
assert call_log[1][0] == "end"
# Same span ID for both events
assert call_log[0][1] == call_log[1][1]

async def test_end_event_preserves_modified_input(self):
"""END event should carry span.input so modifications after start are preserved."""
start_spans: list[Span] = []
end_spans: list[Span] = []

async def capture_start(span: Span) -> None:
start_spans.append(span)

async def capture_end(span: Span) -> None:
end_spans.append(span)

proc = _make_processor(
on_span_start=AsyncMock(side_effect=capture_start),
on_span_end=AsyncMock(side_effect=capture_end),
)
queue = AsyncSpanQueue()

from agentex.lib.core.tracing.trace import AsyncTrace

mock_client = MagicMock()
trace = AsyncTrace(
processors=[proc],
client=mock_client,
trace_id="test-trace",
span_queue=queue,
)

initial_input: dict[str, object] = {"messages": [{"role": "user", "content": "hello"}]}
async with trace.span("llm-call", input=initial_input) as span:
# Simulate modifying input after start (e.g. chatbot appending messages)
messages = cast(list[dict[str, str]], cast(dict[str, object], span.input)["messages"])
messages.append({"role": "assistant", "content": "hi there"})
messages.append({"role": "user", "content": "how are you?"})
span.output = cast(dict[str, object], {"response": "I'm good!"})

await queue.shutdown()

assert len(start_spans) == 1
assert len(end_spans) == 1

# START should carry the original input (serialized at start time)
assert start_spans[0].input is not None
assert len(cast(dict[str, list[object]], start_spans[0].input)["messages"]) == 1 # only the original message

# END should carry the modified input (re-serialized at end time)
assert end_spans[0].input is not None
assert len(cast(dict[str, list[object]], end_spans[0].input)["messages"]) == 3 # all three messages

# END should still carry output and end_time
assert end_spans[0].output is not None
assert end_spans[0].end_time is not None
Loading