diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 30b6d45ad..582db2fa7 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.10.1" + ".": "0.10.2" } diff --git a/.stats.yml b/.stats.yml index 7b6d3dc17..89e80cc66 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 45 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/sgp%2Fagentex-sdk-636aa63c588134e6f47fc45212049593d91f810a9c7bd8d7a57810cf1b5ffc92.yml -openapi_spec_hash: c76a42d4510aeafd896bc0d596f17af4 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/sgp%2Fagentex-sdk-eeb5bf63b18d948611eec48d0225e9bba63b170f64eeeb35d91825724b7cf6c3.yml +openapi_spec_hash: 5bbd18a405a11e8497d38a5a88b98018 config_hash: fb079ef7936611b032568661b8165f19 diff --git a/CHANGELOG.md b/CHANGELOG.md index c0b5df1c1..536753e33 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ # Changelog +## 0.10.2 (2026-04-21) + +Full Changelog: [v0.10.1...v0.10.2](https://github.com/scaleapi/scale-agentex-python/compare/v0.10.1...v0.10.2) + +### Features + +* **api:** api update ([d5b9945](https://github.com/scaleapi/scale-agentex-python/commit/d5b99455c248a629bb2c56a2b5daf192d9f70db8)) + + +### Bug Fixes + +* **adk:** fix to queue drain ([#327](https://github.com/scaleapi/scale-agentex-python/issues/327)) ([b59d6d8](https://github.com/scaleapi/scale-agentex-python/commit/b59d6d8b59cec9548ec468cae3827d785c9f86f7)) + + +### Performance Improvements + +* **client:** optimize file structure copying in multipart requests ([87fe899](https://github.com/scaleapi/scale-agentex-python/commit/87fe899713a2ec88f1c32b347a7d5c78124aaf56)) + ## 0.10.1 (2026-04-17) Full Changelog: [v0.10.0...v0.10.1](https://github.com/scaleapi/scale-agentex-python/compare/v0.10.0...v0.10.1) diff --git a/pyproject.toml b/pyproject.toml index 90eb34012..a1851c1ac 100644 --- a/pyproject.toml +++ b/pyproject.toml 
@@ -1,6 +1,6 @@ [project] name = "agentex-sdk" -version = "0.10.1" +version = "0.10.2" description = "The official Python library for the agentex API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/agentex/_files.py b/src/agentex/_files.py index cc14c14f6..0fdce17bf 100644 --- a/src/agentex/_files.py +++ b/src/agentex/_files.py @@ -3,8 +3,8 @@ import io import os import pathlib -from typing import overload -from typing_extensions import TypeGuard +from typing import Sequence, cast, overload +from typing_extensions import TypeVar, TypeGuard import anyio @@ -17,7 +17,9 @@ HttpxFileContent, HttpxRequestFiles, ) -from ._utils import is_tuple_t, is_mapping_t, is_sequence_t +from ._utils import is_list, is_mapping, is_tuple_t, is_mapping_t, is_sequence_t + +_T = TypeVar("_T") def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: @@ -121,3 +123,51 @@ async def async_read_file_content(file: FileContent) -> HttpxFileContent: return await anyio.Path(file).read_bytes() return file + + +def deepcopy_with_paths(item: _T, paths: Sequence[Sequence[str]]) -> _T: + """Copy only the containers along the given paths. + + Used to guard against mutation by extract_files without copying the entire structure. + Only dicts and lists that lie on a path are copied; everything else + is returned by reference. + + For example, given paths=[["foo", "files", "file"]] and the structure: + { + "foo": { + "bar": {"baz": {}}, + "files": {"file": <file object>} + } + } + The root dict, "foo", and "files" are copied (they lie on the path). + "bar" and "baz" are returned by reference (off the path). 
+ """ + return _deepcopy_with_paths(item, paths, 0) + + +def _deepcopy_with_paths(item: _T, paths: Sequence[Sequence[str]], index: int) -> _T: + if not paths: + return item + if is_mapping(item): + key_to_paths: dict[str, list[Sequence[str]]] = {} + for path in paths: + if index < len(path): + key_to_paths.setdefault(path[index], []).append(path) + + # if no path continues through this mapping, it won't be mutated and copying it is redundant + if not key_to_paths: + return item + + result = dict(item) + for key, subpaths in key_to_paths.items(): + if key in result: + result[key] = _deepcopy_with_paths(result[key], subpaths, index + 1) + return cast(_T, result) + if is_list(item): + array_paths = [path for path in paths if index < len(path) and path[index] == "<array>"] + + # if no path expects a list here, nothing will be mutated inside it - return by reference + if not array_paths: + return cast(_T, item) + return cast(_T, [_deepcopy_with_paths(entry, array_paths, index + 1) for entry in item]) + return item diff --git a/src/agentex/_utils/__init__.py b/src/agentex/_utils/__init__.py index 10cb66d2d..1c090e51f 100644 --- a/src/agentex/_utils/__init__.py +++ b/src/agentex/_utils/__init__.py @@ -24,7 +24,6 @@ coerce_integer as coerce_integer, file_from_path as file_from_path, strip_not_given as strip_not_given, - deepcopy_minimal as deepcopy_minimal, get_async_library as get_async_library, maybe_coerce_float as maybe_coerce_float, get_required_header as get_required_header, diff --git a/src/agentex/_utils/_utils.py b/src/agentex/_utils/_utils.py index 63b8cd602..771859f5e 100644 --- a/src/agentex/_utils/_utils.py +++ b/src/agentex/_utils/_utils.py @@ -177,21 +177,6 @@ def is_iterable(obj: object) -> TypeGuard[Iterable[object]]: return isinstance(obj, Iterable) -def deepcopy_minimal(item: _T) -> _T: - """Minimal reimplementation of copy.deepcopy() that will only copy certain object types: - - - mappings, e.g. `dict` - - list - - This is done for performance reasons. 
- """ - if is_mapping(item): - return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()}) - if is_list(item): - return cast(_T, [deepcopy_minimal(entry) for entry in item]) - return item - - # copied from https://github.com/Rapptz/RoboDanny def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str: size = len(seq) diff --git a/src/agentex/_version.py b/src/agentex/_version.py index e5a43716a..5d881e84d 100644 --- a/src/agentex/_version.py +++ b/src/agentex/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "agentex" -__version__ = "0.10.1" # x-release-please-version +__version__ = "0.10.2" # x-release-please-version diff --git a/src/agentex/lib/adk/_modules/tracing.py b/src/agentex/lib/adk/_modules/tracing.py index cb3b4b22b..67150f01d 100644 --- a/src/agentex/lib/adk/_modules/tracing.py +++ b/src/agentex/lib/adk/_modules/tracing.py @@ -64,9 +64,7 @@ def _tracing_service(self) -> TracingService: # Re-create the underlying httpx client when the event loop changes # (e.g. between HTTP requests in a sync ASGI server) to avoid # "Event loop is closed" / "bound to a different event loop" errors. - if self._tracing_service_lazy is None or ( - loop_id is not None and loop_id != self._bound_loop_id - ): + if self._tracing_service_lazy is None or (loop_id is not None and loop_id != self._bound_loop_id): import httpx # Disable keepalive so each span HTTP call gets a fresh TCP @@ -93,6 +91,7 @@ async def span( input: list[Any] | dict[str, Any] | BaseModel | None = None, data: list[Any] | dict[str, Any] | BaseModel | None = None, parent_id: str | None = None, + task_id: str | None = None, start_to_close_timeout: timedelta = timedelta(seconds=5), heartbeat_timeout: timedelta = timedelta(seconds=5), retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, @@ -109,6 +108,7 @@ async def span( input (Union[List, Dict, BaseModel]): The input for the span. 
parent_id (Optional[str]): The parent span ID for the span. data (Optional[Union[List, Dict, BaseModel]]): The data for the span. + task_id (Optional[str]): The task ID this span belongs to. start_to_close_timeout (timedelta): The start to close timeout for the span. heartbeat_timeout (timedelta): The heartbeat timeout for the span. retry_policy (RetryPolicy): The retry policy for the span. @@ -126,6 +126,7 @@ async def span( input=input, parent_id=parent_id, data=data, + task_id=task_id, start_to_close_timeout=start_to_close_timeout, heartbeat_timeout=heartbeat_timeout, retry_policy=retry_policy, @@ -149,6 +150,7 @@ async def start_span( input: list[Any] | dict[str, Any] | BaseModel | None = None, parent_id: str | None = None, data: list[Any] | dict[str, Any] | BaseModel | None = None, + task_id: str | None = None, start_to_close_timeout: timedelta = timedelta(seconds=5), heartbeat_timeout: timedelta = timedelta(seconds=1), retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY, @@ -162,6 +164,7 @@ async def start_span( input (Union[List, Dict, BaseModel]): The input for the span. parent_id (Optional[str]): The parent span ID for the span. data (Optional[Union[List, Dict, BaseModel]]): The data for the span. + task_id (Optional[str]): The task ID this span belongs to. start_to_close_timeout (timedelta): The start to close timeout for the span. heartbeat_timeout (timedelta): The heartbeat timeout for the span. retry_policy (RetryPolicy): The retry policy for the span. 
@@ -175,6 +178,7 @@ async def start_span( name=name, input=input, data=data, + task_id=task_id, ) if in_temporal_workflow(): return await ActivityHelpers.execute_activity( @@ -192,6 +196,7 @@ async def start_span( input=input, parent_id=parent_id, data=data, + task_id=task_id, ) async def end_span( diff --git a/src/agentex/lib/core/services/adk/tracing.py b/src/agentex/lib/core/services/adk/tracing.py index 210d2f625..77efffd9e 100644 --- a/src/agentex/lib/core/services/adk/tracing.py +++ b/src/agentex/lib/core/services/adk/tracing.py @@ -22,6 +22,7 @@ async def start_span( parent_id: str | None = None, input: list[Any] | dict[str, Any] | BaseModel | None = None, data: list[Any] | dict[str, Any] | BaseModel | None = None, + task_id: str | None = None, ) -> Span | None: trace = self._tracer.trace(trace_id) span = await trace.start_span( @@ -29,6 +30,7 @@ async def start_span( parent_id=parent_id, input=input or {}, data=data, + task_id=task_id, ) heartbeat_if_in_workflow("start span") return span diff --git a/src/agentex/lib/core/temporal/activities/adk/tracing_activities.py b/src/agentex/lib/core/temporal/activities/adk/tracing_activities.py index 65afcded0..aec541afe 100644 --- a/src/agentex/lib/core/temporal/activities/adk/tracing_activities.py +++ b/src/agentex/lib/core/temporal/activities/adk/tracing_activities.py @@ -24,6 +24,7 @@ class StartSpanParams(BaseModel): name: str input: list[Any] | dict[str, Any] | BaseModel | None = None data: list[Any] | dict[str, Any] | BaseModel | None = None + task_id: str | None = None class EndSpanParams(BaseModel): @@ -47,6 +48,7 @@ async def start_span(self, params: StartSpanParams) -> Span | None: name=params.name, input=params.input, data=params.data, + task_id=params.task_id, ) @activity.defn(name=TracingActivityName.END_SPAN) diff --git a/src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py b/src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py index 54e0d1187..e00f23dac 100644 --- 
a/src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py +++ b/src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py @@ -79,6 +79,9 @@ def __init__(self, config: AgentexTracingProcessorConfig): # noqa: ARG002 ), ) + # TODO(AGX1-199): Add batch create/update endpoints to Agentex API and use + # them here instead of one HTTP call per span. + # https://linear.app/scale-epd/issue/AGX1-199/add-agentex-batch-endpoint-for-traces @override async def on_span_start(self, span: Span) -> None: await self.client.spans.create( diff --git a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py index 2f94e7f87..1376df06c 100644 --- a/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py +++ b/src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py @@ -141,6 +141,9 @@ async def on_span_start(self, span: Span) -> None: if self.disabled: logger.warning("SGP is disabled, skipping span upsert") return + # TODO(AGX1-198): Batch multiple spans into a single upsert_batch call + # instead of one span per HTTP request. 
+ # https://linear.app/scale-epd/issue/AGX1-198/actually-use-sgp-batching-for-spans await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr] items=[sgp_span.to_request_params()] ) @@ -155,6 +158,7 @@ async def on_span_end(self, span: Span) -> None: return self._add_source_to_span(span) + sgp_span.input = span.input # type: ignore[assignment] sgp_span.output = span.output # type: ignore[assignment] sgp_span.metadata = span.data # type: ignore[assignment] sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr] diff --git a/src/agentex/lib/core/tracing/span_queue.py b/src/agentex/lib/core/tracing/span_queue.py index e881cc1da..d5d09dd0f 100644 --- a/src/agentex/lib/core/tracing/span_queue.py +++ b/src/agentex/lib/core/tracing/span_queue.py @@ -12,6 +12,8 @@ logger = make_logger(__name__) +_DEFAULT_BATCH_SIZE = 50 + class SpanEventType(str, Enum): START = "start" @@ -28,15 +30,18 @@ class _SpanQueueItem: class AsyncSpanQueue: """Background FIFO queue for async span processing. - Span events are enqueued synchronously (non-blocking) and processed - sequentially by a background drain task. This keeps tracing HTTP calls - off the critical request path while preserving start-before-end ordering. + Span events are enqueued synchronously (non-blocking) and drained by a + background task. Items are processed in batches: all START events in a + batch are flushed concurrently, then all END events, so that per-span + start-before-end ordering is preserved while HTTP calls for independent + spans execute in parallel. 
""" - def __init__(self) -> None: + def __init__(self, batch_size: int = _DEFAULT_BATCH_SIZE) -> None: self._queue: asyncio.Queue[_SpanQueueItem] = asyncio.Queue() self._drain_task: asyncio.Task[None] | None = None self._stopping = False + self._batch_size = batch_size def enqueue( self, @@ -54,9 +59,45 @@ def _ensure_drain_running(self) -> None: if self._drain_task is None or self._drain_task.done(): self._drain_task = asyncio.create_task(self._drain_loop()) + # ------------------------------------------------------------------ + # Drain loop + # ------------------------------------------------------------------ + async def _drain_loop(self) -> None: while True: - item = await self._queue.get() + # Block until at least one item is available. + first = await self._queue.get() + batch: list[_SpanQueueItem] = [first] + + # Opportunistically grab more ready items (non-blocking). + while len(batch) < self._batch_size: + try: + batch.append(self._queue.get_nowait()) + except asyncio.QueueEmpty: + break + + try: + # Separate START and END events. Processing all STARTs before + # ENDs ensures that on_span_start completes before on_span_end + # for any span whose both events land in the same batch. + starts = [i for i in batch if i.event_type == SpanEventType.START] + ends = [i for i in batch if i.event_type == SpanEventType.END] + + if starts: + await self._process_items(starts) + if ends: + await self._process_items(ends) + finally: + for _ in batch: + self._queue.task_done() + # Release span data for GC. 
+ batch.clear() + + @staticmethod + async def _process_items(items: list[_SpanQueueItem]) -> None: + """Process a list of span events concurrently.""" + + async def _handle(item: _SpanQueueItem) -> None: try: if item.event_type == SpanEventType.START: coros = [p.on_span_start(item.span) for p in item.processors] @@ -72,9 +113,15 @@ async def _drain_loop(self) -> None: exc_info=result, ) except Exception: - logger.exception("Unexpected error in span queue drain loop for span %s", item.span.id) - finally: - self._queue.task_done() + logger.exception( + "Unexpected error in span queue for span %s", item.span.id + ) + + await asyncio.gather(*[_handle(item) for item in items]) + + # ------------------------------------------------------------------ + # Shutdown + # ------------------------------------------------------------------ async def shutdown(self, timeout: float = 30.0) -> None: self._stopping = True diff --git a/src/agentex/lib/core/tracing/trace.py b/src/agentex/lib/core/tracing/trace.py index 7925df7fc..70b268b18 100644 --- a/src/agentex/lib/core/tracing/trace.py +++ b/src/agentex/lib/core/tracing/trace.py @@ -54,6 +54,7 @@ def start_span( parent_id: str | None = None, input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, + task_id: str | None = None, ) -> Span: """ Start a new span and register it with the API. @@ -63,6 +64,7 @@ def start_span( parent_id: Optional parent span ID. input: Optional input data for the span. data: Optional additional data for the span. + task_id: Optional ID of the task this span belongs to. Returns: The newly created span. 
@@ -86,6 +88,7 @@ def start_span( start_time=start_time, input=serialized_input, data=serialized_data, + task_id=task_id, ) for processor in self.processors: @@ -150,6 +153,7 @@ def span( parent_id: str | None = None, input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, + task_id: str | None = None, ): """ Context manager for spans. @@ -158,7 +162,7 @@ def span( if not self.trace_id: yield None return - span = self.start_span(name, parent_id, input, data) + span = self.start_span(name, parent_id, input, data, task_id=task_id) try: yield span finally: @@ -198,6 +202,7 @@ async def start_span( parent_id: str | None = None, input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, + task_id: str | None = None, ) -> Span: """ Start a new span and register it with the API. @@ -207,6 +212,7 @@ async def start_span( parent_id: Optional parent span ID. input: Optional input data for the span. data: Optional additional data for the span. + task_id: Optional ID of the task this span belongs to. Returns: The newly created span. @@ -229,6 +235,7 @@ async def start_span( start_time=start_time, input=serialized_input, data=serialized_data, + task_id=task_id, ) if self.processors: @@ -293,6 +300,7 @@ async def span( parent_id: str | None = None, input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None, + task_id: str | None = None, ) -> AsyncGenerator[Span | None, None]: """ Context manager for spans. @@ -302,6 +310,7 @@ async def span( parent_id: Optional parent span ID. input: Optional input data for the span. data: Optional additional data for the span. + task_id: Optional ID of the task this span belongs to. Yields: The span object. 
@@ -309,7 +318,7 @@ async def span( if not self.trace_id: yield None return - span = await self.start_span(name, parent_id, input, data) + span = await self.start_span(name, parent_id, input, data, task_id=task_id) try: yield span finally: diff --git a/src/agentex/resources/spans.py b/src/agentex/resources/spans.py index f8d97e0de..ecd692644 100644 --- a/src/agentex/resources/spans.py +++ b/src/agentex/resources/spans.py @@ -57,6 +57,7 @@ def create( input: Union[Dict[str, object], Iterable[Dict[str, object]], None] | Omit = omit, output: Union[Dict[str, object], Iterable[Dict[str, object]], None] | Omit = omit, parent_id: Optional[str] | Omit = omit, + task_id: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -86,6 +87,8 @@ def create( parent_id: ID of the parent span if this is a child span in a trace + task_id: ID of the task this span belongs to + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -107,6 +110,7 @@ def create( "input": input, "output": output, "parent_id": parent_id, + "task_id": task_id, }, span_create_params.SpanCreateParams, ), @@ -160,6 +164,7 @@ def update( output: Union[Dict[str, object], Iterable[Dict[str, object]], None] | Omit = omit, parent_id: Optional[str] | Omit = omit, start_time: Union[str, datetime, None] | Omit = omit, + task_id: Optional[str] | Omit = omit, trace_id: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
@@ -186,6 +191,8 @@ def update( start_time: The time the span started + task_id: ID of the task this span belongs to + trace_id: Unique identifier for the trace this span belongs to extra_headers: Send extra headers @@ -209,6 +216,7 @@ def update( "output": output, "parent_id": parent_id, "start_time": start_time, + "task_id": task_id, "trace_id": trace_id, }, span_update_params.SpanUpdateParams, @@ -226,6 +234,7 @@ def list( order_by: Optional[str] | Omit = omit, order_direction: str | Omit = omit, page_number: int | Omit = omit, + task_id: Optional[str] | Omit = omit, trace_id: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -235,7 +244,7 @@ def list( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> SpanListResponse: """ - List all spans for a given trace ID + List spans, optionally filtered by trace_id and/or task_id Args: extra_headers: Send extra headers @@ -259,6 +268,7 @@ def list( "order_by": order_by, "order_direction": order_direction, "page_number": page_number, + "task_id": task_id, "trace_id": trace_id, }, span_list_params.SpanListParams, @@ -300,6 +310,7 @@ async def create( input: Union[Dict[str, object], Iterable[Dict[str, object]], None] | Omit = omit, output: Union[Dict[str, object], Iterable[Dict[str, object]], None] | Omit = omit, parent_id: Optional[str] | Omit = omit, + task_id: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -329,6 +340,8 @@ async def create( parent_id: ID of the parent span if this is a child span in a trace + task_id: ID of the task this span belongs to + extra_headers: Send extra headers extra_query: Add additional query parameters to the request @@ -350,6 +363,7 @@ async def create( "input": input, "output": output, "parent_id": parent_id, + "task_id": task_id, }, span_create_params.SpanCreateParams, ), @@ -403,6 +417,7 @@ async def update( output: Union[Dict[str, object], Iterable[Dict[str, object]], None] | Omit = omit, parent_id: Optional[str] | Omit = omit, start_time: Union[str, datetime, None] | Omit = omit, + task_id: Optional[str] | Omit = omit, trace_id: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -429,6 +444,8 @@ async def update( start_time: The time the span started + task_id: ID of the task this span belongs to + trace_id: Unique identifier for the trace this span belongs to extra_headers: Send extra headers @@ -452,6 +469,7 @@ async def update( "output": output, "parent_id": parent_id, "start_time": start_time, + "task_id": task_id, "trace_id": trace_id, }, span_update_params.SpanUpdateParams, @@ -469,6 +487,7 @@ async def list( order_by: Optional[str] | Omit = omit, order_direction: str | Omit = omit, page_number: int | Omit = omit, + task_id: Optional[str] | Omit = omit, trace_id: Optional[str] | Omit = omit, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
@@ -478,7 +497,7 @@ async def list( timeout: float | httpx.Timeout | None | NotGiven = not_given, ) -> SpanListResponse: """ - List all spans for a given trace ID + List spans, optionally filtered by trace_id and/or task_id Args: extra_headers: Send extra headers @@ -502,6 +521,7 @@ async def list( "order_by": order_by, "order_direction": order_direction, "page_number": page_number, + "task_id": task_id, "trace_id": trace_id, }, span_list_params.SpanListParams, diff --git a/src/agentex/types/span.py b/src/agentex/types/span.py index e5719bbe6..98793c03b 100644 --- a/src/agentex/types/span.py +++ b/src/agentex/types/span.py @@ -34,3 +34,6 @@ class Span(BaseModel): parent_id: Optional[str] = None """ID of the parent span if this is a child span in a trace""" + + task_id: Optional[str] = None + """ID of the task this span belongs to""" diff --git a/src/agentex/types/span_create_params.py b/src/agentex/types/span_create_params.py index c7111d6fa..7debfc8d4 100644 --- a/src/agentex/types/span_create_params.py +++ b/src/agentex/types/span_create_params.py @@ -38,3 +38,6 @@ class SpanCreateParams(TypedDict, total=False): parent_id: Optional[str] """ID of the parent span if this is a child span in a trace""" + + task_id: Optional[str] + """ID of the task this span belongs to""" diff --git a/src/agentex/types/span_list_params.py b/src/agentex/types/span_list_params.py index 40d4d651b..286c3d2bf 100644 --- a/src/agentex/types/span_list_params.py +++ b/src/agentex/types/span_list_params.py @@ -17,4 +17,6 @@ class SpanListParams(TypedDict, total=False): page_number: int + task_id: Optional[str] + trace_id: Optional[str] diff --git a/src/agentex/types/span_update_params.py b/src/agentex/types/span_update_params.py index 3ebb54654..fda32dbad 100644 --- a/src/agentex/types/span_update_params.py +++ b/src/agentex/types/span_update_params.py @@ -33,5 +33,8 @@ class SpanUpdateParams(TypedDict, total=False): start_time: Annotated[Union[str, datetime, None], 
PropertyInfo(format="iso8601")] """The time the span started""" + task_id: Optional[str] + """ID of the task this span belongs to""" + trace_id: Optional[str] """Unique identifier for the trace this span belongs to""" diff --git a/tests/api_resources/test_spans.py b/tests/api_resources/test_spans.py index 948760a67..7cccec2ad 100644 --- a/tests/api_resources/test_spans.py +++ b/tests/api_resources/test_spans.py @@ -42,6 +42,7 @@ def test_method_create_with_all_params(self, client: Agentex) -> None: input={"foo": "bar"}, output={"foo": "bar"}, parent_id="parent_id", + task_id="task_id", ) assert_matches_type(Span, span, path=["response"]) @@ -137,6 +138,7 @@ def test_method_update_with_all_params(self, client: Agentex) -> None: output={"foo": "bar"}, parent_id="parent_id", start_time=parse_datetime("2019-12-27T18:11:19.117Z"), + task_id="task_id", trace_id="trace_id", ) assert_matches_type(Span, span, path=["response"]) @@ -189,6 +191,7 @@ def test_method_list_with_all_params(self, client: Agentex) -> None: order_by="order_by", order_direction="order_direction", page_number=1, + task_id="task_id", trace_id="trace_id", ) assert_matches_type(SpanListResponse, span, path=["response"]) @@ -244,6 +247,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncAgentex) - input={"foo": "bar"}, output={"foo": "bar"}, parent_id="parent_id", + task_id="task_id", ) assert_matches_type(Span, span, path=["response"]) @@ -339,6 +343,7 @@ async def test_method_update_with_all_params(self, async_client: AsyncAgentex) - output={"foo": "bar"}, parent_id="parent_id", start_time=parse_datetime("2019-12-27T18:11:19.117Z"), + task_id="task_id", trace_id="trace_id", ) assert_matches_type(Span, span, path=["response"]) @@ -391,6 +396,7 @@ async def test_method_list_with_all_params(self, async_client: AsyncAgentex) -> order_by="order_by", order_direction="order_direction", page_number=1, + task_id="task_id", trace_id="trace_id", ) assert_matches_type(SpanListResponse, span, 
path=["response"]) diff --git a/tests/lib/adk/test_tracing_activities.py b/tests/lib/adk/test_tracing_activities.py new file mode 100644 index 000000000..248ba94a7 --- /dev/null +++ b/tests/lib/adk/test_tracing_activities.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import AsyncMock + +from temporalio.testing import ActivityEnvironment + +from agentex.types.span import Span + + +def _make_span(**overrides) -> Span: + defaults = { + "id": "span-123", + "name": "test-span", + "start_time": datetime(2026, 1, 1, tzinfo=timezone.utc), + "trace_id": "trace-123", + } + defaults.update(overrides) + return Span(**defaults) + + +def _make_tracing_activities(): + from agentex.lib.core.services.adk.tracing import TracingService + from agentex.lib.core.temporal.activities.adk.tracing_activities import TracingActivities + + mock_service = AsyncMock(spec=TracingService) + activities = TracingActivities(tracing_service=mock_service) + env = ActivityEnvironment() + return mock_service, activities, env + + +class TestStartSpanActivity: + async def test_start_span_with_task_id(self): + from agentex.lib.core.temporal.activities.adk.tracing_activities import StartSpanParams + + mock_service, activities, env = _make_tracing_activities() + expected = _make_span(task_id="task-abc") + mock_service.start_span.return_value = expected + + params = StartSpanParams( + trace_id="trace-123", + name="test-span", + task_id="task-abc", + ) + result = await env.run(activities.start_span, params) + + assert result == expected + assert result.task_id == "task-abc" + mock_service.start_span.assert_called_once_with( + trace_id="trace-123", + parent_id=None, + name="test-span", + input=None, + data=None, + task_id="task-abc", + ) + + async def test_start_span_without_task_id(self): + from agentex.lib.core.temporal.activities.adk.tracing_activities import StartSpanParams + + mock_service, activities, env = _make_tracing_activities() 
+ expected = _make_span() + mock_service.start_span.return_value = expected + + params = StartSpanParams(trace_id="trace-123", name="test-span") + result = await env.run(activities.start_span, params) + + assert result == expected + mock_service.start_span.assert_called_once_with( + trace_id="trace-123", + parent_id=None, + name="test-span", + input=None, + data=None, + task_id=None, + ) + + +class TestEndSpanActivity: + async def test_end_span_preserves_task_id(self): + from agentex.lib.core.temporal.activities.adk.tracing_activities import EndSpanParams + + mock_service, activities, env = _make_tracing_activities() + span = _make_span(task_id="task-abc") + expected = _make_span( + task_id="task-abc", + end_time=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + mock_service.end_span.return_value = expected + + params = EndSpanParams(trace_id="trace-123", span=span) + result = await env.run(activities.end_span, params) + + assert result == expected + assert result.task_id == "task-abc" + mock_service.end_span.assert_called_once_with(trace_id="trace-123", span=span) diff --git a/tests/lib/adk/test_tracing_module.py b/tests/lib/adk/test_tracing_module.py new file mode 100644 index 000000000..52d5d3f82 --- /dev/null +++ b/tests/lib/adk/test_tracing_module.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import agentex.lib.adk._modules.tracing as _tracing_mod +from agentex.types.span import Span +from agentex.lib.adk._modules.tracing import TracingModule +from agentex.lib.core.services.adk.tracing import TracingService + + +def _make_span(**overrides) -> Span: + defaults = { + "id": "span-123", + "name": "test-span", + "start_time": datetime(2026, 1, 1, tzinfo=timezone.utc), + "trace_id": "trace-123", + } + defaults.update(overrides) + return Span(**defaults) + + +def _make_module() -> tuple[AsyncMock, TracingModule]: + mock_service = AsyncMock(spec=TracingService) + module 
= TracingModule(tracing_service=mock_service) + return mock_service, module + + +class TestStartSpan: + async def test_start_span_with_task_id(self): + mock_service, module = _make_module() + expected = _make_span(task_id="task-abc") + mock_service.start_span.return_value = expected + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=False): + result = await module.start_span( + trace_id="trace-123", + name="test-span", + task_id="task-abc", + ) + + assert result == expected + assert result.task_id == "task-abc" + mock_service.start_span.assert_called_once_with( + trace_id="trace-123", + name="test-span", + input=None, + parent_id=None, + data=None, + task_id="task-abc", + ) + + async def test_start_span_without_task_id(self): + mock_service, module = _make_module() + expected = _make_span() + mock_service.start_span.return_value = expected + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=False): + result = await module.start_span(trace_id="trace-123", name="test-span") + + assert result == expected + mock_service.start_span.assert_called_once_with( + trace_id="trace-123", + name="test-span", + input=None, + parent_id=None, + data=None, + task_id=None, + ) + + +class TestEndSpan: + async def test_end_span_preserves_task_id(self): + mock_service, module = _make_module() + span = _make_span(task_id="task-abc") + expected = _make_span( + task_id="task-abc", + end_time=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + mock_service.end_span.return_value = expected + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=False): + result = await module.end_span(trace_id="trace-123", span=span) + + assert result == expected + assert result.task_id == "task-abc" + mock_service.end_span.assert_called_once_with(trace_id="trace-123", span=span) + + +class TestSpanContextManager: + async def test_span_context_manager_forwards_task_id(self): + mock_service, module = _make_module() + started = 
_make_span(task_id="task-abc") + mock_service.start_span.return_value = started + mock_service.end_span.return_value = started + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=False): + async with module.span( + trace_id="trace-123", + name="test-span", + task_id="task-abc", + ) as span: + assert span is not None + assert span.task_id == "task-abc" + + assert mock_service.start_span.call_args.kwargs["task_id"] == "task-abc" + mock_service.end_span.assert_called_once() + + async def test_span_context_manager_noop_when_no_trace_id(self): + mock_service, module = _make_module() + + with patch.object(_tracing_mod, "in_temporal_workflow", return_value=False): + async with module.span(trace_id="", name="test-span") as span: + assert span is None + + mock_service.start_span.assert_not_called() + mock_service.end_span.assert_not_called() diff --git a/tests/lib/adk/test_tracing_service.py b/tests/lib/adk/test_tracing_service.py new file mode 100644 index 000000000..dceb000f5 --- /dev/null +++ b/tests/lib/adk/test_tracing_service.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + +from agentex.types.span import Span +from agentex.lib.core.services.adk.tracing import TracingService + + +def _make_span(**overrides) -> Span: + defaults = { + "id": "span-123", + "name": "test-span", + "start_time": datetime(2026, 1, 1, tzinfo=timezone.utc), + "trace_id": "trace-123", + } + defaults.update(overrides) + return Span(**defaults) + + +def _make_service() -> tuple[MagicMock, MagicMock, TracingService]: + """Build a TracingService backed by an AsyncTracer whose + trace.start_span / trace.end_span are mocked.""" + mock_trace = MagicMock() + mock_trace.start_span = AsyncMock() + mock_trace.end_span = AsyncMock() + + mock_tracer = MagicMock() + mock_tracer.trace.return_value = mock_trace + + service = TracingService(tracer=mock_tracer) + return mock_tracer, mock_trace, 
service + + +class TestStartSpanService: + async def test_start_span_passes_task_id(self): + mock_tracer, mock_trace, service = _make_service() + expected = _make_span(task_id="task-abc") + mock_trace.start_span.return_value = expected + + result = await service.start_span( + trace_id="trace-123", + name="test-span", + task_id="task-abc", + ) + + assert result == expected + mock_tracer.trace.assert_called_once_with("trace-123") + mock_trace.start_span.assert_awaited_once_with( + name="test-span", + parent_id=None, + input={}, + data=None, + task_id="task-abc", + ) + + async def test_start_span_without_task_id(self): + _mock_tracer, mock_trace, service = _make_service() + expected = _make_span() + mock_trace.start_span.return_value = expected + + result = await service.start_span(trace_id="trace-123", name="test-span") + + assert result == expected + mock_trace.start_span.assert_awaited_once_with( + name="test-span", + parent_id=None, + input={}, + data=None, + task_id=None, + ) + + +class TestEndSpanService: + async def test_end_span_forwards_span(self): + mock_tracer, mock_trace, service = _make_service() + span = _make_span(task_id="task-abc") + mock_trace.end_span.return_value = span + + result = await service.end_span(trace_id="trace-123", span=span) + + assert result is span + mock_tracer.trace.assert_called_once_with("trace-123") + mock_trace.end_span.assert_awaited_once_with(span) diff --git a/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py b/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py index 1acafa527..818fed375 100644 --- a/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py +++ b/tests/lib/core/tracing/processors/test_sgp_tracing_processor.py @@ -162,3 +162,29 @@ async def test_span_end_for_unknown_span_is_noop(self): await processor.on_span_end(span) assert len(processor._spans) == 0 + + async def test_sgp_span_input_updated_on_end(self): + """on_span_end should update sgp_span.input from the incoming 
span.""" + processor, _ = self._make_processor() + + with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()): + span = _make_span() + span.input = {"messages": [{"role": "user", "content": "hello"}]} + await processor.on_span_start(span) + + assert len(processor._spans) == 1 + + # Simulate modified input at end time + updated_input: dict[str, object] = {"messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "hi"}, + ]} + span.input = updated_input + span.output = {"response": "hi"} + span.end_time = datetime.now(UTC) + await processor.on_span_end(span) + + # Span should be removed after end + assert len(processor._spans) == 0 + # The end upsert should have been called + assert processor.sgp_async_client.spans.upsert_batch.call_count == 2 # start + end diff --git a/tests/lib/core/tracing/test_span_queue.py b/tests/lib/core/tracing/test_span_queue.py index 1f39fb25d..4524ba187 100644 --- a/tests/lib/core/tracing/test_span_queue.py +++ b/tests/lib/core/tracing/test_span_queue.py @@ -3,6 +3,7 @@ import time import uuid import asyncio +from typing import cast from datetime import UTC, datetime from unittest.mock import AsyncMock, MagicMock, patch @@ -52,7 +53,8 @@ async def slow_start(span: Span) -> None: class TestAsyncSpanQueueOrdering: - async def test_fifo_ordering_preserved(self): + async def test_per_span_start_before_end(self): + """START always completes before END for the same span, even with batching.""" call_log: list[tuple[str, str]] = [] async def record_start(span: Span) -> None: @@ -77,12 +79,19 @@ async def record_end(span: Span) -> None: await queue.shutdown() - assert call_log == [ - ("start", "span-a"), - ("end", "span-a"), - ("start", "span-b"), - ("end", "span-b"), - ] + # All 4 events should fire + assert len(call_log) == 4 + + # Per-span invariant: START before END + for span_id in ("span-a", "span-b"): + start_idx = next(i for i, (ev, sid) in enumerate(call_log) if ev == "start" and 
sid == span_id) + end_idx = next(i for i, (ev, sid) in enumerate(call_log) if ev == "end" and sid == span_id) + assert start_idx < end_idx, f"START should come before END for {span_id}" + + # All STARTs before all ENDs within a batch + start_indices = [i for i, (ev, _) in enumerate(call_log) if ev == "start"] + end_indices = [i for i, (ev, _) in enumerate(call_log) if ev == "end"] + assert max(start_indices) < min(end_indices), "All STARTs should complete before any END" class TestAsyncSpanQueueErrorHandling: @@ -154,6 +163,61 @@ async def test_enqueue_after_shutdown_is_dropped(self): proc.on_span_start.assert_not_called() +class TestAsyncSpanQueueBatchConcurrency: + async def test_batch_processes_multiple_items_concurrently(self): + """Items in the same batch should run concurrently, not serially.""" + concurrency = 0 + max_concurrency = 0 + lock = asyncio.Lock() + + async def slow_start(span: Span) -> None: + nonlocal concurrency, max_concurrency + async with lock: + concurrency += 1 + max_concurrency = max(max_concurrency, concurrency) + await asyncio.sleep(0.05) + async with lock: + concurrency -= 1 + + proc = _make_processor(on_span_start=AsyncMock(side_effect=slow_start)) + queue = AsyncSpanQueue() + + # Enqueue 10 START events before the drain loop runs — they should + # all land in the same batch and be processed concurrently. 
+ for i in range(10): + queue.enqueue(SpanEventType.START, _make_span(f"span-{i}"), [proc]) + + await queue.shutdown() + + assert max_concurrency > 1, ( + f"Expected concurrent processing, but max concurrency was {max_concurrency}" + ) + + async def test_batch_faster_than_serial(self): + """Batched drain should be significantly faster than serial for slow processors.""" + n_items = 10 + per_item_delay = 0.05 # 50ms per processor call + + async def slow_start(span: Span) -> None: + await asyncio.sleep(per_item_delay) + + proc = _make_processor(on_span_start=AsyncMock(side_effect=slow_start)) + queue = AsyncSpanQueue() + + for i in range(n_items): + queue.enqueue(SpanEventType.START, _make_span(f"span-{i}"), [proc]) + + start = time.monotonic() + await queue.shutdown() + elapsed = time.monotonic() - start + + serial_time = n_items * per_item_delay + assert elapsed < serial_time * 0.5, ( + f"Batch drain took {elapsed:.3f}s — serial would be {serial_time:.3f}s. " + f"Expected at least 2x speedup from concurrency." 
+ ) + + class TestAsyncSpanQueueIntegration: async def test_integration_with_async_trace(self): call_log: list[tuple[str, str]] = [] @@ -196,3 +260,55 @@ async def record_end(span: Span) -> None: assert call_log[1][0] == "end" # Same span ID for both events assert call_log[0][1] == call_log[1][1] + + async def test_end_event_preserves_modified_input(self): + """END event should carry span.input so modifications after start are preserved.""" + start_spans: list[Span] = [] + end_spans: list[Span] = [] + + async def capture_start(span: Span) -> None: + start_spans.append(span) + + async def capture_end(span: Span) -> None: + end_spans.append(span) + + proc = _make_processor( + on_span_start=AsyncMock(side_effect=capture_start), + on_span_end=AsyncMock(side_effect=capture_end), + ) + queue = AsyncSpanQueue() + + from agentex.lib.core.tracing.trace import AsyncTrace + + mock_client = MagicMock() + trace = AsyncTrace( + processors=[proc], + client=mock_client, + trace_id="test-trace", + span_queue=queue, + ) + + initial_input: dict[str, object] = {"messages": [{"role": "user", "content": "hello"}]} + async with trace.span("llm-call", input=initial_input) as span: + # Simulate modifying input after start (e.g. 
chatbot appending messages) + messages = cast(list[dict[str, str]], cast(dict[str, object], span.input)["messages"]) + messages.append({"role": "assistant", "content": "hi there"}) + messages.append({"role": "user", "content": "how are you?"}) + span.output = cast(dict[str, object], {"response": "I'm good!"}) + + await queue.shutdown() + + assert len(start_spans) == 1 + assert len(end_spans) == 1 + + # START should carry the original input (serialized at start time) + assert start_spans[0].input is not None + assert len(cast(dict[str, list[object]], start_spans[0].input)["messages"]) == 1 # only the original message + + # END should carry the modified input (re-serialized at end time) + assert end_spans[0].input is not None + assert len(cast(dict[str, list[object]], end_spans[0].input)["messages"]) == 3 # all three messages + + # END should still carry output and end_time + assert end_spans[0].output is not None + assert end_spans[0].end_time is not None diff --git a/tests/lib/core/tracing/test_span_queue_load.py b/tests/lib/core/tracing/test_span_queue_load.py new file mode 100644 index 000000000..652589881 --- /dev/null +++ b/tests/lib/core/tracing/test_span_queue_load.py @@ -0,0 +1,306 @@ +""" +Manual load test for the tracing pipeline. + +Measures peak queue depth, drain time, and memory under sustained load with +large system prompts — the scenario that causes OOM in K8s. + +SKIPPED by default. 
Run explicitly with: + + RUN_LOAD_TESTS=1 PYTHONPATH=src python -m pytest \ + tests/lib/core/tracing/test_span_queue_load.py \ + -v -o "addopts=--tb=short" -s + +To compare before/after the fix: + + # 1) Baseline (before fix) — checkout the parent commit: + git stash # if you have uncommitted changes + git checkout ced40bb + RUN_LOAD_TESTS=1 PYTHONPATH=src python -m pytest \ + tests/lib/core/tracing/test_span_queue_load.py \ + -v -o "addopts=--tb=short" -s + + # 2) After fix — return to your branch: + git checkout - + git stash pop # if you stashed + RUN_LOAD_TESTS=1 PYTHONPATH=src python -m pytest \ + tests/lib/core/tracing/test_span_queue_load.py \ + -v -o "addopts=--tb=short" -s +""" + +from __future__ import annotations + +import gc +import os +import sys +import time +import uuid +import asyncio +import resource +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agentex.types.span import Span +from agentex.lib.core.tracing.trace import AsyncTrace +from agentex.lib.core.tracing.span_queue import AsyncSpanQueue + +# --------------------------------------------------------------------------- +# Configuration — tune to match production load profile +# --------------------------------------------------------------------------- +N_SPANS = 10_000 +PROMPT_SIZE = 50_000 # 50 KB system prompt per span +PROCESSOR_DELAY_S = 0.005 # 5 ms per processor call (simulates API latency) +REQUEST_INTERVAL_S = 0.0002 # 0.2 ms between requests (~5000 req/s burst) +SAMPLE_INTERVAL = 200 # sample queue depth every N spans + + +def _make_span(span_id: str | None = None) -> Span: + return Span( + id=span_id or str(uuid.uuid4()), + name="test-span", + start_time=datetime.now(UTC), + trace_id="trace-1", + ) + + +@pytest.mark.skipif( + not os.environ.get("RUN_LOAD_TESTS"), + reason="Load test — run with RUN_LOAD_TESTS=1", +) +class TestSpanQueueLoad: + async def test_sustained_load(self): + """ + Push 10,000 spans with 50KB system 
prompts through the tracing pipeline + at a steady rate while the drain loop runs concurrently. + + Prints a full report with peak queue depth, timing, and memory. + Compare the output between old code (ced40bb) and the fix branch. + """ + peak_queue_size = 0 + queue_samples: list[tuple[int, int]] = [] + + async def slow_start(span: Span) -> None: + await asyncio.sleep(PROCESSOR_DELAY_S) + + async def slow_end(span: Span) -> None: + await asyncio.sleep(PROCESSOR_DELAY_S) + + proc = AsyncMock() + proc.on_span_start = AsyncMock(side_effect=slow_start) + proc.on_span_end = AsyncMock(side_effect=slow_end) + + queue = AsyncSpanQueue() + trace = AsyncTrace( + processors=[proc], + client=MagicMock(), + trace_id="load-test", + span_queue=queue, + ) + + gc.collect() + + if sys.platform == "darwin": + rss_to_mb = 1 / 1024 / 1024 # bytes + else: + rss_to_mb = 1 / 1024 # KB + + rss_before = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * rss_to_mb + t_start = time.monotonic() + + # ---- Enqueue phase (steady stream) ---- + for i in range(N_SPANS): + input_data = { + "system_prompt": f"You are agent #{i}. " + "x" * PROMPT_SIZE, + "messages": [{"role": "user", "content": f"Request {i}"}], + } + span = await trace.start_span(f"llm-call-{i}", input=input_data) + span.output = { + "response": f"Reply {i}", + "usage": {"prompt_tokens": 500, "completion_tokens": 100}, + } + await trace.end_span(span) + + # Yield to event loop so the drain task can run between requests. 
+ await asyncio.sleep(REQUEST_INTERVAL_S) + + qs = queue._queue.qsize() + if qs > peak_queue_size: + peak_queue_size = qs + if i % SAMPLE_INTERVAL == 0: + queue_samples.append((i, qs)) + + t_enqueue = time.monotonic() + + # ---- Drain phase (flush remaining) ---- + await queue.shutdown(timeout=300) + t_end = time.monotonic() + + rss_after = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * rss_to_mb + + enqueue_s = t_enqueue - t_start + drain_s = t_end - t_enqueue + total_s = t_end - t_start + + # ---- Report ---- + print() + print(f"{'=' * 60}") + print(f" Load Test: {N_SPANS:,} spans x {PROMPT_SIZE // 1000}KB prompt") + print(f" Processor delay: {PROCESSOR_DELAY_S * 1000:.0f}ms" + f" | Request interval: {REQUEST_INTERVAL_S * 1000:.1f}ms") + print(f"{'=' * 60}") + print(f" Peak queue depth: {peak_queue_size:>10,} items") + print(f" Enqueue time: {enqueue_s:>10.2f} s") + print(f" Drain time: {drain_s:>10.2f} s") + print(f" Total time: {total_s:>10.2f} s") + print(f" RSS before: {rss_before:>10.1f} MB") + print(f" RSS after: {rss_after:>10.1f} MB") + print(f" RSS delta: {rss_after - rss_before:>10.1f} MB") + print(f"{'=' * 60}") + print() + print(" Queue depth over time:") + for idx, depth in queue_samples: + bar = "#" * (depth // 200) if depth > 0 else "." + print(f" span {idx:>6,}: {depth:>6,} items {bar}") + print() + + # Soft assertion — the test is informational, but flag extreme backup + assert peak_queue_size < N_SPANS * 2, ( + f"Queue never drained during load — peak was {peak_queue_size} " + f"(total items enqueued: {N_SPANS * 2})" + ) + + async def test_growing_context_chatbot(self): + """ + Simulate concurrent chatbot conversations where each turn adds to the + message history. Each LLM call span carries the FULL conversation + (system prompt + all prior messages), so input size grows linearly + per turn and total memory is O(N^2) across turns. 
+ + This is the worst-case scenario for queue memory: later turns produce + spans with much larger inputs than early turns. + + Config below: 50 concurrent conversations × 40 turns each = 2,000 + total spans. By turn 40, each span carries ~50KB system prompt + + ~80KB of message history. + """ + N_CONVERSATIONS = 50 + TURNS_PER_CONV = 40 + SYS_PROMPT_SIZE = 50_000 # 50 KB system prompt + MSG_SIZE = 2_000 # 2 KB per user/assistant message + DELAY = 0.005 # 5 ms processor latency + INTERVAL = 0.0002 # 0.2 ms between turns + + peak_queue_size = 0 + total_spans = N_CONVERSATIONS * TURNS_PER_CONV + queue_samples: list[tuple[int, int, int]] = [] # (span_idx, queue_depth, input_kb) + span_count = 0 + + async def slow_start(span: Span) -> None: + await asyncio.sleep(DELAY) + + async def slow_end(span: Span) -> None: + await asyncio.sleep(DELAY) + + proc = AsyncMock() + proc.on_span_start = AsyncMock(side_effect=slow_start) + proc.on_span_end = AsyncMock(side_effect=slow_end) + + queue = AsyncSpanQueue() + trace = AsyncTrace( + processors=[proc], + client=MagicMock(), + trace_id="chatbot-load", + span_queue=queue, + ) + + gc.collect() + + if sys.platform == "darwin": + rss_to_mb = 1 / 1024 / 1024 + else: + rss_to_mb = 1 / 1024 + + rss_before = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * rss_to_mb + t_start = time.monotonic() + + # Build N_CONVERSATIONS, each accumulating message history + conversations: list[list[dict]] = [[] for _ in range(N_CONVERSATIONS)] + system_prompt = "You are a helpful assistant. 
" + "x" * SYS_PROMPT_SIZE + + for turn in range(TURNS_PER_CONV): + for conv_id in range(N_CONVERSATIONS): + # User sends a message + conversations[conv_id].append({ + "role": "user", + "content": f"[conv={conv_id} turn={turn}] " + "u" * MSG_SIZE, + }) + + # LLM call span — carries full conversation history + input_data = { + "system_prompt": system_prompt, + "messages": list(conversations[conv_id]), # copy of full history + } + input_kb = len(str(input_data)) // 1024 + + span = await trace.start_span( + f"llm-conv{conv_id}-turn{turn}", + input=input_data, + ) + assistant_reply = f"[reply conv={conv_id} turn={turn}] " + "a" * MSG_SIZE + span.output = {"response": assistant_reply} + await trace.end_span(span) + + # Assistant reply added to history + conversations[conv_id].append({ + "role": "assistant", + "content": assistant_reply, + }) + + span_count += 1 + await asyncio.sleep(INTERVAL) + + qs = queue._queue.qsize() + if qs > peak_queue_size: + peak_queue_size = qs + if span_count % 100 == 0: + queue_samples.append((span_count, qs, input_kb)) + + t_enqueue = time.monotonic() + await queue.shutdown(timeout=300) + t_end = time.monotonic() + + rss_after = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * rss_to_mb + enqueue_s = t_enqueue - t_start + drain_s = t_end - t_enqueue + total_s = t_end - t_start + + # ---- Report ---- + print() + print(f"{'=' * 60}") + print(f" Chatbot Load Test: {N_CONVERSATIONS} convos x" + f" {TURNS_PER_CONV} turns = {total_spans:,} spans") + print(f" System prompt: {SYS_PROMPT_SIZE // 1000}KB" + f" | Message size: {MSG_SIZE // 1000}KB" + f" | Processor delay: {DELAY * 1000:.0f}ms") + print(f"{'=' * 60}") + print(f" Peak queue depth: {peak_queue_size:>10,} items") + print(f" Enqueue time: {enqueue_s:>10.2f} s") + print(f" Drain time: {drain_s:>10.2f} s") + print(f" Total time: {total_s:>10.2f} s") + print(f" RSS before: {rss_before:>10.1f} MB") + print(f" RSS after: {rss_after:>10.1f} MB") + print(f" RSS delta: {rss_after - 
rss_before:>10.1f} MB") + print(f"{'=' * 60}") + print() + print(" Queue depth & per-span input size over time:") + print(f" {'span':>8} {'queue':>8} {'input':>8}") + for idx, depth, ikb in queue_samples: + q_bar = "#" * (depth // 100) if depth > 0 else "." + print(f" {idx:>7,} {depth:>7,} {ikb:>6}KB {q_bar}") + print() + + assert peak_queue_size < total_spans * 2, ( + f"Queue never drained — peak was {peak_queue_size} " + f"(total items enqueued: {total_spans * 2})" + ) diff --git a/tests/lib/core/tracing/test_trace_task_id.py b/tests/lib/core/tracing/test_trace_task_id.py new file mode 100644 index 000000000..1a616cc94 --- /dev/null +++ b/tests/lib/core/tracing/test_trace_task_id.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from unittest.mock import MagicMock + +from agentex.lib.core.tracing.trace import Trace, AsyncTrace + + +def _make_sync_trace(trace_id: str = "trace-123") -> tuple[MagicMock, Trace]: + client = MagicMock() + trace = Trace(processors=[], client=client, trace_id=trace_id) + return client, trace + + +def _make_async_trace(trace_id: str = "trace-123") -> tuple[MagicMock, AsyncTrace]: + client = MagicMock() + trace = AsyncTrace(processors=[], client=client, trace_id=trace_id) + return client, trace + + +class TestSyncTraceTaskId: + def test_start_span_sets_task_id_on_span(self): + _client, trace = _make_sync_trace() + span = trace.start_span(name="foo", task_id="task-abc") + assert span.task_id == "task-abc" + assert span.trace_id == "trace-123" + + def test_start_span_defaults_task_id_to_none(self): + _client, trace = _make_sync_trace() + span = trace.start_span(name="foo") + assert span.task_id is None + + def test_end_span_preserves_task_id_from_span(self): + _client, trace = _make_sync_trace() + span = trace.start_span(name="foo", task_id="task-abc") + trace.end_span(span) + assert span.task_id == "task-abc" + + +class TestAsyncTraceTaskId: + async def test_start_span_sets_task_id_on_span(self): + _client, trace = 
_make_async_trace() + span = await trace.start_span(name="foo", task_id="task-abc") + assert span.task_id == "task-abc" + assert span.trace_id == "trace-123" + + async def test_start_span_defaults_task_id_to_none(self): + _client, trace = _make_async_trace() + span = await trace.start_span(name="foo") + assert span.task_id is None + + async def test_end_span_preserves_task_id_from_span(self): + _client, trace = _make_async_trace() + span = await trace.start_span(name="foo", task_id="task-abc") + await trace.end_span(span) + assert span.task_id == "task-abc" diff --git a/tests/test_deepcopy.py b/tests/test_deepcopy.py deleted file mode 100644 index 3fc74ccbc..000000000 --- a/tests/test_deepcopy.py +++ /dev/null @@ -1,58 +0,0 @@ -from agentex._utils import deepcopy_minimal - - -def assert_different_identities(obj1: object, obj2: object) -> None: - assert obj1 == obj2 - assert id(obj1) != id(obj2) - - -def test_simple_dict() -> None: - obj1 = {"foo": "bar"} - obj2 = deepcopy_minimal(obj1) - assert_different_identities(obj1, obj2) - - -def test_nested_dict() -> None: - obj1 = {"foo": {"bar": True}} - obj2 = deepcopy_minimal(obj1) - assert_different_identities(obj1, obj2) - assert_different_identities(obj1["foo"], obj2["foo"]) - - -def test_complex_nested_dict() -> None: - obj1 = {"foo": {"bar": [{"hello": "world"}]}} - obj2 = deepcopy_minimal(obj1) - assert_different_identities(obj1, obj2) - assert_different_identities(obj1["foo"], obj2["foo"]) - assert_different_identities(obj1["foo"]["bar"], obj2["foo"]["bar"]) - assert_different_identities(obj1["foo"]["bar"][0], obj2["foo"]["bar"][0]) - - -def test_simple_list() -> None: - obj1 = ["a", "b", "c"] - obj2 = deepcopy_minimal(obj1) - assert_different_identities(obj1, obj2) - - -def test_nested_list() -> None: - obj1 = ["a", [1, 2, 3]] - obj2 = deepcopy_minimal(obj1) - assert_different_identities(obj1, obj2) - assert_different_identities(obj1[1], obj2[1]) - - -class MyObject: ... 
- - -def test_ignores_other_types() -> None: - # custom classes - my_obj = MyObject() - obj1 = {"foo": my_obj} - obj2 = deepcopy_minimal(obj1) - assert_different_identities(obj1, obj2) - assert obj1["foo"] is my_obj - - # tuples - obj3 = ("a", "b") - obj4 = deepcopy_minimal(obj3) - assert obj3 is obj4 diff --git a/tests/test_files.py b/tests/test_files.py index e88f71f41..fe5c22590 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -4,7 +4,8 @@ import pytest from dirty_equals import IsDict, IsList, IsBytes, IsTuple -from agentex._files import to_httpx_files, async_to_httpx_files +from agentex._files import to_httpx_files, deepcopy_with_paths, async_to_httpx_files +from agentex._utils import extract_files readme_path = Path(__file__).parent.parent.joinpath("README.md") @@ -49,3 +50,99 @@ def test_string_not_allowed() -> None: "file": "foo", # type: ignore } ) + + +def assert_different_identities(obj1: object, obj2: object) -> None: + assert obj1 == obj2 + assert obj1 is not obj2 + + +class TestDeepcopyWithPaths: + def test_copies_top_level_dict(self) -> None: + original = {"file": b"data", "other": "value"} + result = deepcopy_with_paths(original, [["file"]]) + assert_different_identities(result, original) + + def test_file_value_is_same_reference(self) -> None: + file_bytes = b"contents" + original = {"file": file_bytes} + result = deepcopy_with_paths(original, [["file"]]) + assert_different_identities(result, original) + assert result["file"] is file_bytes + + def test_list_popped_wholesale(self) -> None: + files = [b"f1", b"f2"] + original = {"files": files, "title": "t"} + result = deepcopy_with_paths(original, [["files", ""]]) + assert_different_identities(result, original) + result_files = result["files"] + assert isinstance(result_files, list) + assert_different_identities(result_files, files) + + def test_nested_array_path_copies_list_and_elements(self) -> None: + elem1 = {"file": b"f1", "extra": 1} + elem2 = {"file": b"f2", "extra": 2} + original 
= {"items": [elem1, elem2]} + result = deepcopy_with_paths(original, [["items", "", "file"]]) + assert_different_identities(result, original) + result_items = result["items"] + assert isinstance(result_items, list) + assert_different_identities(result_items, original["items"]) + assert_different_identities(result_items[0], elem1) + assert_different_identities(result_items[1], elem2) + + def test_empty_paths_returns_same_object(self) -> None: + original = {"foo": "bar"} + result = deepcopy_with_paths(original, []) + assert result is original + + def test_multiple_paths(self) -> None: + f1 = b"file1" + f2 = b"file2" + original = {"a": f1, "b": f2, "c": "unchanged"} + result = deepcopy_with_paths(original, [["a"], ["b"]]) + assert_different_identities(result, original) + assert result["a"] is f1 + assert result["b"] is f2 + assert result["c"] is original["c"] + + def test_extract_files_does_not_mutate_original_top_level(self) -> None: + file_bytes = b"contents" + original = {"file": file_bytes, "other": "value"} + + copied = deepcopy_with_paths(original, [["file"]]) + extracted = extract_files(copied, paths=[["file"]]) + + assert extracted == [("file", file_bytes)] + assert original == {"file": file_bytes, "other": "value"} + assert copied == {"other": "value"} + + def test_extract_files_does_not_mutate_original_nested_array_path(self) -> None: + file1 = b"f1" + file2 = b"f2" + original = { + "items": [ + {"file": file1, "extra": 1}, + {"file": file2, "extra": 2}, + ], + "title": "example", + } + + copied = deepcopy_with_paths(original, [["items", "", "file"]]) + extracted = extract_files(copied, paths=[["items", "", "file"]]) + + assert extracted == [("items[][file]", file1), ("items[][file]", file2)] + assert original == { + "items": [ + {"file": file1, "extra": 1}, + {"file": file2, "extra": 2}, + ], + "title": "example", + } + assert copied == { + "items": [ + {"extra": 1}, + {"extra": 2}, + ], + "title": "example", + }