diff --git a/src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py b/src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py index 54e0d1187..e00f23dac 100644 --- a/src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py +++ b/src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py @@ -79,6 +79,9 @@ def __init__(self, config: AgentexTracingProcessorConfig): # noqa: ARG002 ), ) + # TODO(AGX1-199): Add batch create/update endpoints to Agentex API and use + # them here instead of one HTTP call per span. + # https://linear.app/scale-epd/issue/AGX1-199/add-agentex-batch-endpoint-for-traces @override async def on_span_start(self, span: Span) -> None: await self.client.spans.create( diff --git a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py index 2f94e7f87..1376df06c 100644 --- a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py +++ b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py @@ -141,6 +141,9 @@ async def on_span_start(self, span: Span) -> None: if self.disabled: logger.warning("SGP is disabled, skipping span upsert") return + # TODO(AGX1-198): Batch multiple spans into a single upsert_batch call + # instead of one span per HTTP request. + # https://linear.app/scale-epd/issue/AGX1-198/actually-use-sgp-batching-for-spans await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr] items=[sgp_span.to_request_params()] ) @@ -155,6 +158,7 @@ async def on_span_end(self, span: Span) -> None: return self._add_source_to_span(span) + sgp_span.input = span.input # type: ignore[assignment] sgp_span.output = span.output # type: ignore[assignment] sgp_span.metadata = span.data # type: ignore[assignment] sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr] diff --git a/src/agentex/lib/core/tracing/span_queue.py b/src/agentex/lib/core/tracing/span_queue.py index e881cc1da..d5d09dd0f 100644 --- a/src/agentex/lib/core/tracing/span_queue.py +++ b/src/agentex/lib/core/tracing/span_queue.py @@ -12,6 +12,8 @@ logger = make_logger(__name__) +_DEFAULT_BATCH_SIZE = 50 + class SpanEventType(str, Enum): START = "start" @@ -28,15 +30,18 @@ class _SpanQueueItem: class AsyncSpanQueue: """Background FIFO queue for async span processing. - Span events are enqueued synchronously (non-blocking) and processed - sequentially by a background drain task. This keeps tracing HTTP calls - off the critical request path while preserving start-before-end ordering. + Span events are enqueued synchronously (non-blocking) and drained by a + background task. Items are processed in batches: all START events in a + batch are flushed concurrently, then all END events, so that per-span + start-before-end ordering is preserved while HTTP calls for independent + spans execute in parallel. """ - def __init__(self) -> None: + def __init__(self, batch_size: int = _DEFAULT_BATCH_SIZE) -> None: self._queue: asyncio.Queue[_SpanQueueItem] = asyncio.Queue() self._drain_task: asyncio.Task[None] | None = None self._stopping = False + self._batch_size = batch_size def enqueue( self, @@ -54,9 +59,45 @@ def _ensure_drain_running(self) -> None: if self._drain_task is None or self._drain_task.done(): self._drain_task = asyncio.create_task(self._drain_loop()) + # ------------------------------------------------------------------ + # Drain loop + # ------------------------------------------------------------------ + async def _drain_loop(self) -> None: while True: - item = await self._queue.get() + # Block until at least one item is available. + first = await self._queue.get() + batch: list[_SpanQueueItem] = [first] + + # Opportunistically grab more ready items (non-blocking). + while len(batch) < self._batch_size: + try: + batch.append(self._queue.get_nowait()) + except asyncio.QueueEmpty: + break + + try: + # Separate START and END events. Processing all STARTs before + # ENDs ensures that on_span_start completes before on_span_end + # for any span whose both events land in the same batch. + starts = [i for i in batch if i.event_type == SpanEventType.START] + ends = [i for i in batch if i.event_type == SpanEventType.END] + + if starts: + await self._process_items(starts) + if ends: + await self._process_items(ends) + finally: + for _ in batch: + self._queue.task_done() + # Release span data for GC. + batch.clear() + + @staticmethod + async def _process_items(items: list[_SpanQueueItem]) -> None: + """Process a list of span events concurrently.""" + + async def _handle(item: _SpanQueueItem) -> None: try: if item.event_type == SpanEventType.START: coros = [p.on_span_start(item.span) for p in item.processors] @@ -72,9 +113,15 @@ async def _drain_loop(self) -> None: exc_info=result, ) except Exception: - logger.exception("Unexpected error in span queue drain loop for span %s", item.span.id) - finally: - self._queue.task_done() + logger.exception( + "Unexpected error in span queue for span %s", item.span.id + ) + + await asyncio.gather(*[_handle(item) for item in items]) + + # ------------------------------------------------------------------ + # Shutdown + # ------------------------------------------------------------------ async def shutdown(self, timeout: float = 30.0) -> None: self._stopping = True diff --git a/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py b/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py index 1acafa527..818fed375 100644 --- a/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py +++ b/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py @@ -162,3 +162,29 @@ async def test_span_end_for_unknown_span_is_noop(self): await processor.on_span_end(span) assert len(processor._spans) == 0 + + async def test_sgp_span_input_updated_on_end(self): + """on_span_end should update sgp_span.input from the incoming span.""" + processor, _ = self._make_processor() + + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + span = _make_span() + span.input = {"messages": [{"role": "user", "content": "hello"}]} + await processor.on_span_start(span) + + assert len(processor._spans) == 1 + + # Simulate modified input at end time + updated_input: dict[str, object] = {"messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ]} + span.input = updated_input + span.output = {"response": "hi"} + span.end_time = datetime.now(UTC) + await processor.on_span_end(span) + + # Span should be removed after end + assert len(processor._spans) == 0 + # The end upsert should have been called + assert processor.sgp_async_client.spans.upsert_batch.call_count == 2 # start + end diff --git a/tests/lib/core/tracing/test_span_queue.py b/tests/lib/core/tracing/test_span_queue.py index 1f39fb25d..4524ba187 100644 --- a/tests/lib/core/tracing/test_span_queue.py +++ b/tests/lib/core/tracing/test_span_queue.py @@ -3,6 +3,7 @@ import time import uuid import asyncio +from typing import cast from datetime import UTC, datetime from unittest.mock import AsyncMock, MagicMock, patch @@ -52,7 +53,8 @@ async def slow_start(span: Span) -> None: class TestAsyncSpanQueueOrdering: - async def test_fifo_ordering_preserved(self): + async def test_per_span_start_before_end(self): + """START always completes before END for the same span, even with batching.""" call_log: list[tuple[str, str]] = [] async def record_start(span: Span) -> None: @@ -77,12 +79,19 @@ async def record_end(span: Span) -> None: await queue.shutdown() - assert call_log == [ - ("start", "span-a"), - ("end", "span-a"), - ("start", "span-b"), - ("end", "span-b"), - ] + # All 4 events should fire + assert len(call_log) == 4 + + # Per-span invariant: START before END + for span_id in ("span-a", "span-b"): + start_idx = next(i for i, (ev, sid) in enumerate(call_log) if ev == "start" and sid == span_id) + end_idx = next(i for i, (ev, sid) in enumerate(call_log) if ev == "end" and sid == span_id) + assert start_idx < end_idx, f"START should come before END for {span_id}" + + # All STARTs before all ENDs within a batch + start_indices = [i for i, (ev, _) in enumerate(call_log) if ev == "start"] + end_indices = [i for i, (ev, _) in enumerate(call_log) if ev == "end"] + assert max(start_indices) < min(end_indices), "All STARTs should complete before any END" class TestAsyncSpanQueueErrorHandling: @@ -154,6 +163,61 @@ async def test_enqueue_after_shutdown_is_dropped(self): proc.on_span_start.assert_not_called() +class TestAsyncSpanQueueBatchConcurrency: + async def test_batch_processes_multiple_items_concurrently(self): + """Items in the same batch should run concurrently, not serially.""" + concurrency = 0 + max_concurrency = 0 + lock = asyncio.Lock() + + async def slow_start(span: Span) -> None: + nonlocal concurrency, max_concurrency + async with lock: + concurrency += 1 + max_concurrency = max(max_concurrency, concurrency) + await asyncio.sleep(0.05) + async with lock: + concurrency -= 1 + + proc = _make_processor(on_span_start=AsyncMock(side_effect=slow_start)) + queue = AsyncSpanQueue() + + # Enqueue 10 START events before the drain loop runs — they should + # all land in the same batch and be processed concurrently. + for i in range(10): + queue.enqueue(SpanEventType.START, _make_span(f"span-{i}"), [proc]) + + await queue.shutdown() + + assert max_concurrency > 1, ( + f"Expected concurrent processing, but max concurrency was {max_concurrency}" + ) + + async def test_batch_faster_than_serial(self): + """Batched drain should be significantly faster than serial for slow processors.""" + n_items = 10 + per_item_delay = 0.05 # 50ms per processor call + + async def slow_start(span: Span) -> None: + await asyncio.sleep(per_item_delay) + + proc = _make_processor(on_span_start=AsyncMock(side_effect=slow_start)) + queue = AsyncSpanQueue() + + for i in range(n_items): + queue.enqueue(SpanEventType.START, _make_span(f"span-{i}"), [proc]) + + start = time.monotonic() + await queue.shutdown() + elapsed = time.monotonic() - start + + serial_time = n_items * per_item_delay + assert elapsed < serial_time * 0.5, ( + f"Batch drain took {elapsed:.3f}s — serial would be {serial_time:.3f}s. " + f"Expected at least 2x speedup from concurrency." + ) + + class TestAsyncSpanQueueIntegration: async def test_integration_with_async_trace(self): call_log: list[tuple[str, str]] = [] @@ -196,3 +260,55 @@ async def record_end(span: Span) -> None: assert call_log[1][0] == "end" # Same span ID for both events assert call_log[0][1] == call_log[1][1] + + async def test_end_event_preserves_modified_input(self): + """END event should carry span.input so modifications after start are preserved.""" + start_spans: list[Span] = [] + end_spans: list[Span] = [] + + async def capture_start(span: Span) -> None: + start_spans.append(span) + + async def capture_end(span: Span) -> None: + end_spans.append(span) + + proc = _make_processor( + on_span_start=AsyncMock(side_effect=capture_start), + on_span_end=AsyncMock(side_effect=capture_end), + ) + queue = AsyncSpanQueue() + + from agentex.lib.core.tracing.trace import AsyncTrace + + mock_client = MagicMock() + trace = AsyncTrace( + processors=[proc], + client=mock_client, + trace_id="test-trace", + span_queue=queue, + ) + + initial_input: dict[str, object] = {"messages": [{"role": "user", "content": "hello"}]} + async with trace.span("llm-call", input=initial_input) as span: + # Simulate modifying input after start (e.g. chatbot appending messages) + messages = cast(list[dict[str, str]], cast(dict[str, object], span.input)["messages"]) + messages.append({"role": "assistant", "content": "hi there"}) + messages.append({"role": "user", "content": "how are you?"}) + span.output = cast(dict[str, object], {"response": "I'm good!"}) + + await queue.shutdown() + + assert len(start_spans) == 1 + assert len(end_spans) == 1 + + # START should carry the original input (serialized at start time) + assert start_spans[0].input is not None + assert len(cast(dict[str, list[object]], start_spans[0].input)["messages"]) == 1 # only the original message + + # END should carry the modified input (re-serialized at end time) + assert end_spans[0].input is not None + assert len(cast(dict[str, list[object]], end_spans[0].input)["messages"]) == 3 # all three messages + + # END should still carry output and end_time + assert end_spans[0].output is not None + assert end_spans[0].end_time is not None diff --git a/tests/lib/core/tracing/test_span_queue_load.py b/tests/lib/core/tracing/test_span_queue_load.py new file mode 100644 index 000000000..652589881 --- /dev/null +++ b/tests/lib/core/tracing/test_span_queue_load.py @@ -0,0 +1,306 @@ +""" +Manual load test for the tracing pipeline. + +Measures peak queue depth, drain time, and memory under sustained load with +large system prompts — the scenario that causes OOM in K8s. + +SKIPPED by default. Run explicitly with: + + RUN_LOAD_TESTS=1 PYTHONPATH=src python -m pytest \ + tests/lib/core/tracing/test_span_queue_load.py \ + -v -o "addopts=--tb=short" -s + +To compare before/after the fix: + + # 1) Baseline (before fix) — checkout the parent commit: + git stash # if you have uncommitted changes + git checkout ced40bb + RUN_LOAD_TESTS=1 PYTHONPATH=src python -m pytest \ + tests/lib/core/tracing/test_span_queue_load.py \ + -v -o "addopts=--tb=short" -s + + # 2) After fix — return to your branch: + git checkout - + git stash pop # if you stashed + RUN_LOAD_TESTS=1 PYTHONPATH=src python -m pytest \ + tests/lib/core/tracing/test_span_queue_load.py \ + -v -o "addopts=--tb=short" -s +""" + +from __future__ import annotations + +import gc +import os +import sys +import time +import uuid +import asyncio +import resource +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agentex.types.span import Span +from agentex.lib.core.tracing.trace import AsyncTrace +from agentex.lib.core.tracing.span_queue import AsyncSpanQueue + +# --------------------------------------------------------------------------- +# Configuration — tune to match production load profile +# --------------------------------------------------------------------------- +N_SPANS = 10_000 +PROMPT_SIZE = 50_000 # 50 KB system prompt per span +PROCESSOR_DELAY_S = 0.005 # 5 ms per processor call (simulates API latency) +REQUEST_INTERVAL_S = 0.0002 # 0.2 ms between requests (~5000 req/s burst) +SAMPLE_INTERVAL = 200 # sample queue depth every N spans + + +def _make_span(span_id: str | None = None) -> Span: + return Span( + id=span_id or str(uuid.uuid4()), + name="test-span", + start_time=datetime.now(UTC), + trace_id="trace-1", + ) + + +@pytest.mark.skipif( + not os.environ.get("RUN_LOAD_TESTS"), + reason="Load test — run with RUN_LOAD_TESTS=1", +) +class TestSpanQueueLoad: + async def test_sustained_load(self): + """ + Push 10,000 spans with 50KB system prompts through the tracing pipeline + at a steady rate while the drain loop runs concurrently. + + Prints a full report with peak queue depth, timing, and memory. + Compare the output between old code (ced40bb) and the fix branch. + """ + peak_queue_size = 0 + queue_samples: list[tuple[int, int]] = [] + + async def slow_start(span: Span) -> None: + await asyncio.sleep(PROCESSOR_DELAY_S) + + async def slow_end(span: Span) -> None: + await asyncio.sleep(PROCESSOR_DELAY_S) + + proc = AsyncMock() + proc.on_span_start = AsyncMock(side_effect=slow_start) + proc.on_span_end = AsyncMock(side_effect=slow_end) + + queue = AsyncSpanQueue() + trace = AsyncTrace( + processors=[proc], + client=MagicMock(), + trace_id="load-test", + span_queue=queue, + ) + + gc.collect() + + if sys.platform == "darwin": + rss_to_mb = 1 / 1024 / 1024 # bytes + else: + rss_to_mb = 1 / 1024 # KB + + rss_before = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * rss_to_mb + t_start = time.monotonic() + + # ---- Enqueue phase (steady stream) ---- + for i in range(N_SPANS): + input_data = { + "system_prompt": f"You are agent #{i}. " + "x" * PROMPT_SIZE, + "messages": [{"role": "user", "content": f"Request {i}"}], + } + span = await trace.start_span(f"llm-call-{i}", input=input_data) + span.output = { + "response": f"Reply {i}", + "usage": {"prompt_tokens": 500, "completion_tokens": 100}, + } + await trace.end_span(span) + + # Yield to event loop so the drain task can run between requests. + await asyncio.sleep(REQUEST_INTERVAL_S) + + qs = queue._queue.qsize() + if qs > peak_queue_size: + peak_queue_size = qs + if i % SAMPLE_INTERVAL == 0: + queue_samples.append((i, qs)) + + t_enqueue = time.monotonic() + + # ---- Drain phase (flush remaining) ---- + await queue.shutdown(timeout=300) + t_end = time.monotonic() + + rss_after = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * rss_to_mb + + enqueue_s = t_enqueue - t_start + drain_s = t_end - t_enqueue + total_s = t_end - t_start + + # ---- Report ---- + print() + print(f"{'=' * 60}") + print(f" Load Test: {N_SPANS:,} spans x {PROMPT_SIZE // 1000}KB prompt") + print(f" Processor delay: {PROCESSOR_DELAY_S * 1000:.0f}ms" + f" | Request interval: {REQUEST_INTERVAL_S * 1000:.1f}ms") + print(f"{'=' * 60}") + print(f" Peak queue depth: {peak_queue_size:>10,} items") + print(f" Enqueue time: {enqueue_s:>10.2f} s") + print(f" Drain time: {drain_s:>10.2f} s") + print(f" Total time: {total_s:>10.2f} s") + print(f" RSS before: {rss_before:>10.1f} MB") + print(f" RSS after: {rss_after:>10.1f} MB") + print(f" RSS delta: {rss_after - rss_before:>10.1f} MB") + print(f"{'=' * 60}") + print() + print(" Queue depth over time:") + for idx, depth in queue_samples: + bar = "#" * (depth // 200) if depth > 0 else "." + print(f" span {idx:>6,}: {depth:>6,} items {bar}") + print() + + # Soft assertion — the test is informational, but flag extreme backup + assert peak_queue_size < N_SPANS * 2, ( + f"Queue never drained during load — peak was {peak_queue_size} " + f"(total items enqueued: {N_SPANS * 2})" + ) + + async def test_growing_context_chatbot(self): + """ + Simulate concurrent chatbot conversations where each turn adds to the + message history. Each LLM call span carries the FULL conversation + (system prompt + all prior messages), so input size grows linearly + per turn and total memory is O(N^2) across turns. + + This is the worst-case scenario for queue memory: later turns produce + spans with much larger inputs than early turns. + + Config below: 50 concurrent conversations × 40 turns each = 2,000 + total spans. By turn 40, each span carries ~50KB system prompt + + ~80KB of message history. + """ + N_CONVERSATIONS = 50 + TURNS_PER_CONV = 40 + SYS_PROMPT_SIZE = 50_000 # 50 KB system prompt + MSG_SIZE = 2_000 # 2 KB per user/assistant message + DELAY = 0.005 # 5 ms processor latency + INTERVAL = 0.0002 # 0.2 ms between turns + + peak_queue_size = 0 + total_spans = N_CONVERSATIONS * TURNS_PER_CONV + queue_samples: list[tuple[int, int, int]] = [] # (span_idx, queue_depth, input_kb) + span_count = 0 + + async def slow_start(span: Span) -> None: + await asyncio.sleep(DELAY) + + async def slow_end(span: Span) -> None: + await asyncio.sleep(DELAY) + + proc = AsyncMock() + proc.on_span_start = AsyncMock(side_effect=slow_start) + proc.on_span_end = AsyncMock(side_effect=slow_end) + + queue = AsyncSpanQueue() + trace = AsyncTrace( + processors=[proc], + client=MagicMock(), + trace_id="chatbot-load", + span_queue=queue, + ) + + gc.collect() + + if sys.platform == "darwin": + rss_to_mb = 1 / 1024 / 1024 + else: + rss_to_mb = 1 / 1024 + + rss_before = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * rss_to_mb + t_start = time.monotonic() + + # Build N_CONVERSATIONS, each accumulating message history + conversations: list[list[dict]] = [[] for _ in range(N_CONVERSATIONS)] + system_prompt = "You are a helpful assistant. " + "x" * SYS_PROMPT_SIZE + + for turn in range(TURNS_PER_CONV): + for conv_id in range(N_CONVERSATIONS): + # User sends a message + conversations[conv_id].append({ + "role": "user", + "content": f"[conv={conv_id} turn={turn}] " + "u" * MSG_SIZE, + }) + + # LLM call span — carries full conversation history + input_data = { + "system_prompt": system_prompt, + "messages": list(conversations[conv_id]), # copy of full history + } + input_kb = len(str(input_data)) // 1024 + + span = await trace.start_span( + f"llm-conv{conv_id}-turn{turn}", + input=input_data, + ) + assistant_reply = f"[reply conv={conv_id} turn={turn}] " + "a" * MSG_SIZE + span.output = {"response": assistant_reply} + await trace.end_span(span) + + # Assistant reply added to history + conversations[conv_id].append({ + "role": "assistant", + "content": assistant_reply, + }) + + span_count += 1 + await asyncio.sleep(INTERVAL) + + qs = queue._queue.qsize() + if qs > peak_queue_size: + peak_queue_size = qs + if span_count % 100 == 0: + queue_samples.append((span_count, qs, input_kb)) + + t_enqueue = time.monotonic() + await queue.shutdown(timeout=300) + t_end = time.monotonic() + + rss_after = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * rss_to_mb + enqueue_s = t_enqueue - t_start + drain_s = t_end - t_enqueue + total_s = t_end - t_start + + # ---- Report ---- + print() + print(f"{'=' * 60}") + print(f" Chatbot Load Test: {N_CONVERSATIONS} convos x" + f" {TURNS_PER_CONV} turns = {total_spans:,} spans") + print(f" System prompt: {SYS_PROMPT_SIZE // 1000}KB" + f" | Message size: {MSG_SIZE // 1000}KB" + f" | Processor delay: {DELAY * 1000:.0f}ms") + print(f"{'=' * 60}") + print(f" Peak queue depth: {peak_queue_size:>10,} items") + print(f" Enqueue time: {enqueue_s:>10.2f} s") + print(f" Drain time: {drain_s:>10.2f} s") + print(f" Total time: {total_s:>10.2f} s") + print(f" RSS before: {rss_before:>10.1f} MB") + print(f" RSS after: {rss_after:>10.1f} MB") + print(f" RSS delta: {rss_after - rss_before:>10.1f} MB") + print(f"{'=' * 60}") + print() + print(" Queue depth & per-span input size over time:") + print(f" {'span':>8} {'queue':>8} {'input':>8}") + for idx, depth, ikb in queue_samples: + q_bar = "#" * (depth // 100) if depth > 0 else "." + print(f" {idx:>7,} {depth:>7,} {ikb:>6}KB {q_bar}") + print() + + assert peak_queue_size < total_spans * 2, ( + f"Queue never drained — peak was {peak_queue_size} " + f"(total items enqueued: {total_spans * 2})" + )