diff --git a/src/runloop_api_client/_base_client.py b/src/runloop_api_client/_base_client.py index 410e78aab..88e0bbb3b 100644 --- a/src/runloop_api_client/_base_client.py +++ b/src/runloop_api_client/_base_client.py @@ -8,8 +8,10 @@ import asyncio import inspect import logging +import weakref import platform import warnings +import threading import email.utils from types import TracebackType from random import random @@ -90,6 +92,88 @@ log: logging.Logger = logging.getLogger(__name__) +# Shared HTTP transport state. We share transports (connection pools) rather +# than full httpx clients so each SDK instance keeps its own cookie jar and +# mutable client state. Refcounted wrappers close the real transport only +# when the last user releases it. +# The async transport is keyed by event loop because connections bind to the +# loop that created them and cannot be reused across asyncio.run() calls. +_pool_lock = threading.Lock() + + +class _SharedTransport(httpx.BaseTransport): + """Refcounted wrapper: delegates to a real transport, closes it when refcount hits 0.""" + + def __init__(self, transport: httpx.BaseTransport) -> None: + self._transport = transport + self._refcount = 1 + self._lock = threading.Lock() + + @property + def refcount(self) -> int: + return self._refcount + + def acquire(self) -> bool: + with self._lock: + if self._refcount <= 0: + return False + self._refcount += 1 + return True + + @override + def handle_request(self, request: httpx.Request) -> httpx.Response: + return self._transport.handle_request(request) + + @override + def close(self) -> None: + should_close = False + with self._lock: + self._refcount -= 1 + if self._refcount <= 0: + should_close = True + if should_close: + self._transport.close() + + +class _SharedAsyncTransport(httpx.AsyncBaseTransport): + """Async refcounted wrapper: delegates to a real async transport.""" + + def __init__(self, transport: httpx.AsyncBaseTransport) -> None: + self._transport = transport + self._refcount = 1 + self._lock = threading.Lock() + + @property + def refcount(self) -> int: + return self._refcount + + def acquire(self) -> bool: + with self._lock: + if self._refcount <= 0: + return False + self._refcount += 1 + return True + + @override + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + return await self._transport.handle_async_request(request) + + @override + async def aclose(self) -> None: + should_close = False + with self._lock: + self._refcount -= 1 + if self._refcount <= 0: + should_close = True + if should_close: + await self._transport.aclose() + + +_shared_sync_transport: _SharedTransport | None = None +_shared_async_transports: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, _SharedAsyncTransport] = ( + weakref.WeakKeyDictionary() +) + # TODO: make base page type vars covariant SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]") AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]") @@ -816,6 +900,7 @@ def __init__(self, **kwargs: Any) -> None: kwargs.setdefault("timeout", DEFAULT_TIMEOUT) kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS) kwargs.setdefault("follow_redirects", True) + kwargs.setdefault("http2", True) super().__init__(**kwargs) @@ -845,6 +930,8 @@ def __del__(self) -> None: class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]): _client: httpx.Client _default_stream_cls: type[Stream[Any]] | None = None + _uses_shared_pool: bool + _closed: bool def __init__( self, @@ -857,6 +944,7 @@ def __init__( custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, _strict_response_validation: bool, + shared_http_pool: bool = True, ) -> None: if not is_given(timeout): # if the user passed in a custom http client with a non-default @@ -886,24 +974,46 @@ def __init__( custom_headers=custom_headers, _strict_response_validation=_strict_response_validation, ) - self._client = http_client or SyncHttpxClientWrapper( - base_url=base_url, - # cast to a valid type because mypy doesn't understand our type narrowing - timeout=cast(Timeout, timeout), - ) + + self._closed = False + + if http_client is not None: + self._client = http_client + self._uses_shared_pool = False + elif shared_http_pool: + global _shared_sync_transport + with _pool_lock: + if _shared_sync_transport is None or not _shared_sync_transport.acquire(): + _shared_sync_transport = _SharedTransport( + httpx.HTTPTransport(limits=DEFAULT_CONNECTION_LIMITS, http2=True), + ) + self._client = SyncHttpxClientWrapper( + base_url=base_url, + timeout=cast(Timeout, timeout), + transport=_shared_sync_transport, + ) + self._uses_shared_pool = True + else: + self._client = SyncHttpxClientWrapper( + base_url=base_url, + timeout=cast(Timeout, timeout), + ) + self._uses_shared_pool = False def is_closed(self) -> bool: - return self._client.is_closed + return self._closed or self._client.is_closed def close(self) -> None: """Close the underlying HTTPX client. The client will *not* be usable after this. """ - # If an error is thrown while constructing a client, self._client - # may not be present - if hasattr(self, "_client"): - self._client.close() + if not hasattr(self, "_client"): + return + if self._closed: + return + self._closed = True + self._client.close() def __enter__(self: _T) -> _T: return self @@ -1018,6 +1128,7 @@ def request( max_retries=max_retries, options=input_options, response=None, + error=err, ) continue @@ -1032,6 +1143,7 @@ def request( max_retries=max_retries, options=input_options, response=None, + error=err, ) continue @@ -1083,7 +1195,13 @@ def request( ) def _sleep_for_retry( - self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + self, + *, + retries_taken: int, + max_retries: int, + options: FinalRequestOptions, + response: httpx.Response | None, + error: BaseException | None = None, ) -> None: remaining_retries = max_retries - retries_taken if remaining_retries == 1: @@ -1092,7 +1210,23 @@ def _sleep_for_retry( log.debug("%i retries left", remaining_retries) timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) - log.info("Retrying request to %s in %f seconds", options.url, timeout) + if response is not None: + log.info( + "Retrying request to %s in %f seconds (status %d)", + options.url, + timeout, + response.status_code, + ) + elif error is not None: + log.info( + "Retrying request to %s in %f seconds (%s: %s)", + options.url, + timeout, + type(error).__name__, + error, + ) + else: + log.info("Retrying request to %s in %f seconds", options.url, timeout) time.sleep(timeout) @@ -1428,6 +1562,8 @@ def __del__(self) -> None: class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]): _client: httpx.AsyncClient _default_stream_cls: type[AsyncStream[Any]] | None = None + _uses_shared_pool: bool + _closed: bool def __init__( self, @@ -1440,6 +1576,7 @@ def __init__( http_client: httpx.AsyncClient | None = None, custom_headers: Mapping[str, str] | None = None, custom_query: Mapping[str, object] | None = None, + shared_http_pool: bool = True, ) -> None: if not is_given(timeout): # if the user passed in a custom http client with a non-default @@ -1469,20 +1606,59 @@ def __init__( custom_headers=custom_headers, _strict_response_validation=_strict_response_validation, ) - self._client = http_client or AsyncHttpxClientWrapper( - base_url=base_url, - # cast to a valid type because mypy doesn't understand our type narrowing - timeout=cast(Timeout, timeout), - ) + + self._closed = False + + if http_client is not None: + self._client = http_client + self._uses_shared_pool = False + elif shared_http_pool: + try: + loop: asyncio.AbstractEventLoop | None = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop is not None: + with _pool_lock: + existing = _shared_async_transports.get(loop) + if existing is not None and existing.acquire(): + transport: _SharedAsyncTransport = existing + else: + transport = _SharedAsyncTransport( + httpx.AsyncHTTPTransport(limits=DEFAULT_CONNECTION_LIMITS, http2=True), + ) + _shared_async_transports[loop] = transport + self._client = AsyncHttpxClientWrapper( + base_url=base_url, + timeout=cast(Timeout, timeout), + transport=transport, + ) + self._uses_shared_pool = True + else: + self._client = AsyncHttpxClientWrapper( + base_url=base_url, + timeout=cast(Timeout, timeout), + ) + self._uses_shared_pool = False + else: + self._client = AsyncHttpxClientWrapper( + base_url=base_url, + timeout=cast(Timeout, timeout), + ) + self._uses_shared_pool = False def is_closed(self) -> bool: - return self._client.is_closed + return self._closed or self._client.is_closed async def close(self) -> None: """Close the underlying HTTPX client. The client will *not* be usable after this. """ + if not hasattr(self, "_client"): + return + if self._closed: + return + self._closed = True await self._client.aclose() async def __aenter__(self: _T) -> _T: @@ -1603,6 +1779,7 @@ async def request( max_retries=max_retries, options=input_options, response=None, + error=err, ) continue @@ -1617,6 +1794,7 @@ async def request( max_retries=max_retries, options=input_options, response=None, + error=err, ) continue @@ -1668,7 +1846,13 @@ async def request( ) async def _sleep_for_retry( - self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + self, + *, + retries_taken: int, + max_retries: int, + options: FinalRequestOptions, + response: httpx.Response | None, + error: BaseException | None = None, ) -> None: remaining_retries = max_retries - retries_taken if remaining_retries == 1: @@ -1677,7 +1861,23 @@ async def _sleep_for_retry( log.debug("%i retries left", remaining_retries) timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) - log.info("Retrying request to %s in %f seconds", options.url, timeout) + if response is not None: + log.info( + "Retrying request to %s in %f seconds (status %d)", + options.url, + timeout, + response.status_code, + ) + elif error is not None: + log.info( + "Retrying request to %s in %f seconds (%s: %s)", + options.url, + timeout, + type(error).__name__, + error, + ) + else: + log.info("Retrying request to %s in %f seconds", options.url, timeout) await anyio.sleep(timeout) diff --git a/src/runloop_api_client/_client.py b/src/runloop_api_client/_client.py index 61db3a474..1867f9997 100644 --- a/src/runloop_api_client/_client.py +++ b/src/runloop_api_client/_client.py @@ -84,6 +84,10 @@ def __init__( # We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. # See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details. http_client: httpx.Client | None = None, + # Share a single httpx connection pool across all Runloop client instances. + # Enables HTTP/2 multiplexing and avoids ConnectTimeout storms under high concurrency. + # Set to False to create a private connection pool (old behavior). + shared_http_pool: bool = True, # Enable or disable schema validation for data returned by the API. # When enabled an error APIResponseValidationError is raised # if the API responds with invalid data for the expected schema. @@ -120,6 +124,7 @@ def __init__( custom_headers=default_headers, custom_query=default_query, _strict_response_validation=_strict_response_validation, + shared_http_pool=shared_http_pool, ) self._idempotency_header = "x-request-id" @@ -249,6 +254,7 @@ def copy( base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.Client | None = None, + shared_http_pool: bool | None = None, max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, @@ -277,12 +283,19 @@ def copy( elif set_default_query is not None: params = set_default_query - http_client = http_client or self._client + if http_client is not None: + resolved_shared = False + elif shared_http_pool is not None: + resolved_shared = shared_http_pool + else: + resolved_shared = self._uses_shared_pool + return self.__class__( bearer_token=bearer_token or self.bearer_token, base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, + shared_http_pool=resolved_shared, max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, @@ -344,6 +357,10 @@ def __init__( # We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`. # See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details. http_client: httpx.AsyncClient | None = None, + # Share a single httpx connection pool across all AsyncRunloop client instances. + # Enables HTTP/2 multiplexing and avoids ConnectTimeout storms under high concurrency. + # Set to False to create a private connection pool (old behavior). + shared_http_pool: bool = True, # Enable or disable schema validation for data returned by the API. # When enabled an error APIResponseValidationError is raised # if the API responds with invalid data for the expected schema. @@ -380,6 +397,7 @@ def __init__( custom_headers=default_headers, custom_query=default_query, _strict_response_validation=_strict_response_validation, + shared_http_pool=shared_http_pool, ) self._idempotency_header = "x-request-id" @@ -509,6 +527,7 @@ def copy( base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = not_given, http_client: httpx.AsyncClient | None = None, + shared_http_pool: bool | None = None, max_retries: int | NotGiven = not_given, default_headers: Mapping[str, str] | None = None, set_default_headers: Mapping[str, str] | None = None, @@ -537,12 +556,19 @@ def copy( elif set_default_query is not None: params = set_default_query - http_client = http_client or self._client + if http_client is not None: + resolved_shared = False + elif shared_http_pool is not None: + resolved_shared = shared_http_pool + else: + resolved_shared = self._uses_shared_pool + return self.__class__( bearer_token=bearer_token or self.bearer_token, base_url=base_url or self.base_url, timeout=self.timeout if isinstance(timeout, NotGiven) else timeout, http_client=http_client, + shared_http_pool=resolved_shared, max_retries=max_retries if is_given(max_retries) else self.max_retries, default_headers=headers, default_query=params, diff --git a/src/runloop_api_client/_constants.py b/src/runloop_api_client/_constants.py index d6361c8ad..88f944ce2 100644 --- a/src/runloop_api_client/_constants.py +++ b/src/runloop_api_client/_constants.py @@ -8,7 +8,7 @@ # default timeout is 30 seconds DEFAULT_TIMEOUT = httpx.Timeout(timeout=30, connect=5.0) DEFAULT_MAX_RETRIES = 5 -DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20) +DEFAULT_CONNECTION_LIMITS = httpx.Limits(max_connections=20, max_keepalive_connections=10) INITIAL_RETRY_DELAY = 1.0 MAX_RETRY_DELAY = 60.0 diff --git a/src/runloop_api_client/lib/wait_for_status.py b/src/runloop_api_client/lib/wait_for_status.py new file mode 100644 index 000000000..73df2bf95 --- /dev/null +++ b/src/runloop_api_client/lib/wait_for_status.py @@ -0,0 +1,99 @@ +"""Helpers for polling wait_for_status long-poll endpoints. + +Each function wraps a server-side long-poll POST with a client-side retry +loop. On each iteration the remaining timeout is forwarded to the server +so the server can long-poll for up to that duration. 408 responses and +client-side timeouts are converted to a caller-supplied placeholder so the +loop can continue. No client-side sleep between iterations — the +server-side long-poll *is* the wait. +""" + +from __future__ import annotations + +import time +from typing import List, Type, TypeVar, Callable, Optional, Awaitable + +from .polling import PollingConfig, PollingTimeout +from .._exceptions import APIStatusError, APITimeoutError + +T = TypeVar("T") + + +def wait_for_status( + post_fn: Callable[..., T], + path: str, + statuses: List[str], + cast_to: Type[T], + placeholder: Callable[[], T], + is_terminal: Callable[[T], bool], + polling_config: Optional[PollingConfig] = None, +) -> T: + """Sync long-poll for a status change, retrying until *is_terminal* or timeout.""" + config = polling_config or PollingConfig() + timeout = config.interval_seconds * config.max_attempts + if config.timeout_seconds is not None and config.timeout_seconds > 0: + timeout = min(config.timeout_seconds, timeout) + + start_time = time.time() + last_result: T | None = None + + while True: + remaining = timeout - (time.time() - start_time) + if remaining <= 0: + raise PollingTimeout(f"Exceeded timeout of {timeout} seconds", last_result) + + try: + last_result = post_fn( + path, + body={"statuses": statuses, "timeout_seconds": remaining}, + cast_to=cast_to, + options={"max_retries": 0}, + ) + except (APITimeoutError, APIStatusError) as error: + if isinstance(error, APITimeoutError) or error.response.status_code == 408: + last_result = placeholder() + else: + raise + + if is_terminal(last_result): + return last_result + + +async def async_wait_for_status( + post_fn: Callable[..., Awaitable[T]], + path: str, + statuses: List[str], + cast_to: Type[T], + placeholder: Callable[[], T], + is_terminal: Callable[[T], bool], + polling_config: Optional[PollingConfig] = None, +) -> T: + """Async long-poll for a status change, retrying until *is_terminal* or timeout.""" + config = polling_config or PollingConfig() + timeout = config.interval_seconds * config.max_attempts + if config.timeout_seconds is not None and config.timeout_seconds > 0: + timeout = min(config.timeout_seconds, timeout) + + start_time = time.time() + last_result: T | None = None + + while True: + remaining = timeout - (time.time() - start_time) + if remaining <= 0: + raise PollingTimeout(f"Exceeded timeout of {timeout} seconds", last_result) + + try: + last_result = await post_fn( + path, + body={"statuses": statuses, "timeout_seconds": remaining}, + cast_to=cast_to, + options={"max_retries": 0}, + ) + except (APITimeoutError, APIStatusError) as error: + if isinstance(error, APITimeoutError) or error.response.status_code == 408: + last_result = placeholder() + else: + raise + + if is_terminal(last_result): + return last_result diff --git a/src/runloop_api_client/resources/devboxes/devboxes.py b/src/runloop_api_client/resources/devboxes/devboxes.py index 83459959b..888369e98 100644 --- a/src/runloop_api_client/resources/devboxes/devboxes.py +++ b/src/runloop_api_client/resources/devboxes/devboxes.py @@ -72,7 +72,7 @@ AsyncDiskSnapshotsCursorIDPage, ) from ..._exceptions import RunloopError, APIStatusError, APITimeoutError -from ...lib.polling import PollingConfig, poll_until, retry_server_poll_until as sync_retry_server_poll_until +from ...lib.polling import PollingConfig, poll_until from ..._base_client import AsyncPaginator, make_request_options from .disk_snapshots import ( DiskSnapshotsResource, @@ -82,9 +82,10 @@ DiskSnapshotsResourceWithStreamingResponse, AsyncDiskSnapshotsResourceWithStreamingResponse, ) -from ...lib.polling_async import async_poll_until, async_retry_server_poll_until +from ...lib.polling_async import async_poll_until from ...types.devbox_view import DevboxView from ...types.tunnel_view import TunnelView +from ...lib.wait_for_status import wait_for_status, async_wait_for_status from ...types.shared_params.mount import Mount from ...types.devbox_snapshot_view import DevboxSnapshotView from ...types.shared.launch_parameters import LaunchParameters as SharedLaunchParameters @@ -383,11 +384,7 @@ def await_running( Args: id: The ID of the devbox to wait for - config: Optional polling configuration - extra_headers: Send extra headers - extra_query: Add additional query parameters to the request - extra_body: Add additional JSON properties to the request - timeout: Override the client-level default timeout for this request, in seconds + polling_config: Optional polling configuration Returns: The devbox in running state @@ -397,31 +394,18 @@ def await_running( RunloopError: If devbox enters a non-running terminal state """ - def wait_for_devbox_status(remaining_timeout_seconds: float) -> DevboxView: - try: - return self._post( - f"/v1/devboxes/{id}/wait_for_status", - body={"statuses": ["running", "failure", "shutdown"], "timeout_seconds": remaining_timeout_seconds}, - cast_to=DevboxView, - options={"max_retries": 0}, - ) - except (APITimeoutError, APIStatusError) as error: - if isinstance(error, APITimeoutError) or error.response.status_code == 408: - return placeholder_devbox_view(id) - raise - def is_done_booting(devbox: DevboxView) -> bool: return devbox.status not in DEVBOX_BOOTING_STATES - config = polling_config - if not config: - config = PollingConfig() - - timeout = config.interval_seconds * config.max_attempts - if config.timeout_seconds is not None and config.timeout_seconds > 0: - timeout = min(config.timeout_seconds, timeout) - - devbox = sync_retry_server_poll_until(wait_for_devbox_status, is_done_booting, timeout) + devbox = wait_for_status( + self._post, + f"/v1/devboxes/{id}/wait_for_status", + ["running", "failure", "shutdown"], + DevboxView, + lambda: placeholder_devbox_view(id), + is_done_booting, + polling_config, + ) if devbox.status != "running": raise RunloopError(f"Devbox entered non-running terminal state: {devbox.status}") @@ -448,25 +432,18 @@ def await_suspended( RunloopError: If the devbox enters a non-suspended terminal state. """ - def wait_for_devbox_status() -> DevboxView: - return self._post( - f"/v1/devboxes/{id}/wait_for_status", - body={"statuses": list(DEVBOX_TERMINAL_STATES)}, - cast_to=DevboxView, - options={"max_retries": 0}, - ) - - def handle_timeout_error(error: Exception) -> DevboxView: - if isinstance(error, APITimeoutError) or ( - isinstance(error, APIStatusError) and error.response.status_code == 408 - ): - return placeholder_devbox_view(id) - raise error - def is_terminal_state(devbox: DevboxView) -> bool: return devbox.status in DEVBOX_TERMINAL_STATES - devbox = poll_until(wait_for_devbox_status, is_terminal_state, polling_config, handle_timeout_error) + devbox = wait_for_status( + self._post, + f"/v1/devboxes/{id}/wait_for_status", + list(DEVBOX_TERMINAL_STATES), + DevboxView, + lambda: placeholder_devbox_view(id), + is_terminal_state, + polling_config, + ) if devbox.status != "suspended": raise RunloopError(f"Devbox entered non-suspended terminal state: {devbox.status}") @@ -2045,9 +2022,6 @@ async def await_running( Args: id: The ID of the devbox to wait for polling_config: Optional polling configuration - extra_headers: Send extra headers - extra_query: Add additional query parameters to the request - extra_body: Add additional JSON properties to the request Returns: The devbox in running state @@ -2057,41 +2031,18 @@ async def await_running( RunloopError: If devbox enters a non-running terminal state """ - async def wait_for_devbox_status(remaining_timeout_seconds: float) -> DevboxView: - # This wait_for_status endpoint polls the devbox status for 10 seconds until it reaches either running or failure. - # If it's neither, it will throw an error. - try: - return await self._post( - f"/v1/devboxes/{id}/wait_for_status", - body={"statuses": ["running", "failure", "shutdown"], "timeout_seconds": remaining_timeout_seconds}, - cast_to=DevboxView, - options={"max_retries": 0}, - ) - except (APITimeoutError, APIStatusError) as error: - # Handle timeout errors by returning current devbox state to continue polling - if isinstance(error, APITimeoutError) or error.response.status_code == 408: - # Return a placeholder result to continue polling - return placeholder_devbox_view(id) - - # Re-raise other errors to stop polling - raise - def is_done_booting(devbox: DevboxView) -> bool: return devbox.status not in DEVBOX_BOOTING_STATES - # calculate the timeout to use. The PollingConfig doesn't - # match the semantics for server-side polling well, so we - # instead convert interval*attempts to a total time, and take - # the minimum total. - config = polling_config - if not config: - config = PollingConfig() # use defaults - - timeout = config.interval_seconds * config.max_attempts - if config.timeout_seconds is not None and config.timeout_seconds > 0: - timeout = min(config.timeout_seconds, timeout) - - devbox = await async_retry_server_poll_until(wait_for_devbox_status, is_done_booting, timeout) + devbox = await async_wait_for_status( + self._post, + f"/v1/devboxes/{id}/wait_for_status", + ["running", "failure", "shutdown"], + DevboxView, + lambda: placeholder_devbox_view(id), + is_done_booting, + polling_config, + ) if devbox.status != "running": raise RunloopError(f"Devbox entered non-running terminal state: {devbox.status}") @@ -2118,23 +2069,18 @@ async def await_suspended( RunloopError: If the devbox enters a non-suspended terminal state. """ - async def wait_for_devbox_status() -> DevboxView: - try: - return await self._post( - f"/v1/devboxes/{id}/wait_for_status", - body={"statuses": list(DEVBOX_TERMINAL_STATES)}, - cast_to=DevboxView, - options={"max_retries": 0}, - ) - except (APITimeoutError, APIStatusError) as error: - if isinstance(error, APITimeoutError) or error.response.status_code == 408: - return placeholder_devbox_view(id) - raise - def is_terminal_state(devbox: DevboxView) -> bool: return devbox.status in DEVBOX_TERMINAL_STATES - devbox = await async_poll_until(wait_for_devbox_status, is_terminal_state, polling_config) + devbox = await async_wait_for_status( + self._post, + f"/v1/devboxes/{id}/wait_for_status", + list(DEVBOX_TERMINAL_STATES), + DevboxView, + lambda: placeholder_devbox_view(id), + is_terminal_state, + polling_config, + ) if devbox.status != "suspended": raise RunloopError(f"Devbox entered non-suspended terminal state: {devbox.status}") diff --git a/src/runloop_api_client/resources/devboxes/executions.py b/src/runloop_api_client/resources/devboxes/executions.py index ff7638798..e5bfd7ed8 100755 --- a/src/runloop_api_client/resources/devboxes/executions.py +++ b/src/runloop_api_client/resources/devboxes/executions.py @@ -20,8 +20,7 @@ ) from ..._constants import DEFAULT_TIMEOUT, RAW_RESPONSE_HEADER from ..._streaming import Stream, AsyncStream, ReconnectingStream, AsyncReconnectingStream -from ..._exceptions import APIStatusError, APITimeoutError -from ...lib.polling import PollingConfig, poll_until +from ...lib.polling import PollingConfig from ..._base_client import make_request_options from ...types.devboxes import ( execution_kill_params, @@ -32,7 +31,7 @@ execution_stream_stderr_updates_params, execution_stream_stdout_updates_params, ) -from ...lib.polling_async import async_poll_until +from ...lib.wait_for_status import wait_for_status, async_wait_for_status from ...types.devbox_send_std_in_result import DevboxSendStdInResult from ...types.devbox_execution_detail_view import DevboxExecutionDetailView from ...types.devboxes.execution_update_chunk import ExecutionUpdateChunk @@ -129,12 +128,8 @@ def await_completed( Args: execution_id: The ID of the execution to wait for - id: The ID of the devbox - config: Optional polling configuration - extra_headers: Send extra headers - extra_query: Add additional query parameters to the request - extra_body: Add additional JSON properties to the request - timeout: Override the client-level default timeout for this request, in seconds + devbox_id: The ID of the devbox + polling_config: Optional polling configuration Returns: The completed execution @@ -143,29 +138,18 @@ def await_completed( PollingTimeout: If polling times out before execution completes """ - def wait_for_execution_status() -> DevboxAsyncExecutionDetailView: - # This wait_for_status endpoint polls the execution status for 60 seconds until it reaches either completed. - return self._post( - f"/v1/devboxes/{devbox_id}/executions/{execution_id}/wait_for_status", - body={"statuses": ["completed"]}, - cast_to=DevboxAsyncExecutionDetailView, - ) - - def handle_timeout_error(error: Exception) -> DevboxAsyncExecutionDetailView: - # Handle timeout errors by returning current execution state to continue polling - if isinstance(error, APITimeoutError) or ( - isinstance(error, APIStatusError) and error.response.status_code == 408 - ): - # Return a placeholder result to continue polling - return placeholder_execution_detail_view(devbox_id, execution_id) - else: - # Re-raise other errors to stop polling - raise error - def is_done(execution: DevboxAsyncExecutionDetailView) -> bool: return execution.status == "completed" - return poll_until(wait_for_execution_status, is_done, polling_config, handle_timeout_error) + return wait_for_status( + self._post, + f"/v1/devboxes/{devbox_id}/executions/{execution_id}/wait_for_status", + ["completed"], + DevboxAsyncExecutionDetailView, + lambda: placeholder_execution_detail_view(devbox_id, execution_id), + is_done, + polling_config, + ) def execute_async( self, @@ -675,12 +659,8 @@ async def await_completed( Args: execution_id: The ID of the execution to wait for - id: The ID of the devbox + devbox_id: The ID of the devbox polling_config: Optional polling configuration - extra_headers: Send extra headers - extra_query: Add additional query parameters to the request - extra_body: Add additional JSON properties to the request - timeout: Override the client-level default timeout for this request, in seconds Returns: The completed execution @@ -689,25 +669,18 @@ async def await_completed( PollingTimeout: If polling times out before execution completes """ - async def wait_for_execution_status() -> DevboxAsyncExecutionDetailView: - try: - return await self._post( - f"/v1/devboxes/{devbox_id}/executions/{execution_id}/wait_for_status", - body={"statuses": ["completed"]}, - cast_to=DevboxAsyncExecutionDetailView, - ) - except (APITimeoutError, APIStatusError) as error: - # Handle timeout errors by returning placeholder to continue polling - if isinstance(error, APITimeoutError) or error.response.status_code == 408: - return placeholder_execution_detail_view(devbox_id, execution_id) - - # Re-raise other errors to stop polling - raise - def is_done(execution: DevboxAsyncExecutionDetailView) -> bool: return execution.status == "completed" - return await async_poll_until(wait_for_execution_status, is_done, polling_config) + return await async_wait_for_status( + self._post, + f"/v1/devboxes/{devbox_id}/executions/{execution_id}/wait_for_status", + ["completed"], + DevboxAsyncExecutionDetailView, + lambda: placeholder_execution_detail_view(devbox_id, execution_id), + is_done, + polling_config, + ) async def execute_async( self, diff --git a/tests/test_client.py b/tests/test_client.py index 408c7cedd..7728bf5bb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -32,7 +32,9 @@ DefaultHttpxClient, DefaultAsyncHttpxClient, get_platform, + _SharedTransport, make_request_options, + _SharedAsyncTransport, ) from .utils import update_env @@ -105,7 +107,9 @@ async def _make_async_iterator(iterable: Iterable[T], counter: Optional[Counter] def _get_open_connections(client: Runloop | AsyncRunloop) -> int: transport = client._client._transport - assert isinstance(transport, httpx.HTTPTransport) or isinstance(transport, httpx.AsyncHTTPTransport) + if isinstance(transport, (_SharedTransport, _SharedAsyncTransport)): + transport = transport._transport + assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport)) pool = transport._pool return len(pool._requests) diff --git a/tests/test_shared_pool.py b/tests/test_shared_pool.py new file mode 100644 index 000000000..4220f8ba9 --- /dev/null +++ b/tests/test_shared_pool.py @@ -0,0 +1,319 @@ +"""Tests for shared HTTP transport pool behavior. + +Verifies that SDK clients share (or don't share) the underlying httpx +transport, and that refcounting correctly manages the transport lifecycle. +""" + +from __future__ import annotations + +import os +import asyncio +from typing import Any, Iterator + +import httpx +import pytest + +import runloop_api_client._base_client as _base_mod +from runloop_api_client import Runloop, AsyncRunloop + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") +bearer_token = "My Bearer Token" + + +@pytest.fixture(autouse=True) +def _reset_shared_pool() -> Iterator[None]: # pyright: ignore[reportUnusedFunction] + _clear_pool_state() + yield + _clear_pool_state() + + +def _clear_pool_state() -> None: + with _base_mod._pool_lock: + old_sync = _base_mod._shared_sync_transport + _base_mod._shared_sync_transport = None + _base_mod._shared_async_transports.clear() + if old_sync is not None: + try: + old_sync._transport.close() + except Exception: + pass + + +def _make_client(**kwargs: Any) -> Runloop: + kwargs.setdefault("base_url", base_url) + kwargs.setdefault("bearer_token", bearer_token) + return Runloop(**kwargs) + + +def _make_async_client(**kwargs: Any) -> AsyncRunloop: + kwargs.setdefault("base_url", base_url) + kwargs.setdefault("bearer_token", bearer_token) + return AsyncRunloop(**kwargs) + + +def _get_transport(client: Runloop | AsyncRunloop) -> Any: + return client._client._transport # type: ignore[union-attr] + + +# --------------------------------------------------------------------------- +# Sync: sharing behavior +# --------------------------------------------------------------------------- + + +class TestSyncSharedPool: + def test_shared_pool_uses_same_transport(self): + c1 = _make_client(shared_http_pool=True) + c2 = _make_client(shared_http_pool=True) + + assert _get_transport(c1) is _get_transport(c2) + assert c1._client is not c2._client + assert c1._uses_shared_pool is True + assert c2._uses_shared_pool is True + + c1.close() + c2.close() + + def test_private_pool_uses_different_transports(self): + c1 = _make_client(shared_http_pool=False) + c2 = _make_client(shared_http_pool=False) + + assert _get_transport(c1) is not _get_transport(c2) + assert c1._uses_shared_pool is False + assert c2._uses_shared_pool is False + + c1.close() + c2.close() + + def test_custom_http_client_bypasses_sharing(self): + custom = httpx.Client() + c1 = _make_client(http_client=custom, shared_http_pool=True) + + assert c1._client is custom + assert c1._uses_shared_pool is False + + c1.close() + custom.close() + + def test_default_is_shared(self): + c1 = _make_client() + assert c1._uses_shared_pool is True + c1.close() + + def test_cookie_isolation(self): + c1 = _make_client(shared_http_pool=True) + c2 = _make_client(shared_http_pool=True) + + c1._client.cookies.set("session", "secret-123") + assert "session" not in c2._client.cookies + + c1.close() + c2.close() + + +class TestSyncRefcounting: + def test_close_one_keeps_transport_alive(self): + c1 = _make_client(shared_http_pool=True) + c2 = _make_client(shared_http_pool=True) + transport = _get_transport(c1) + + assert transport.refcount == 2 + + c1.close() + assert transport.refcount == 1 + assert not c2.is_closed() + + c2.close() + assert transport.refcount == 0 + + def test_double_close_is_safe(self): + c1 = _make_client(shared_http_pool=True) + transport = _get_transport(c1) + + c1.close() + c1.close() # should not raise or double-decrement + assert transport.refcount == 0 + + def test_three_clients_refcount(self): + c1 = _make_client(shared_http_pool=True) + c2 = _make_client(shared_http_pool=True) + c3 = _make_client(shared_http_pool=True) + transport = _get_transport(c1) + + assert transport.refcount == 3 + + c1.close() + assert transport.refcount == 2 + + c2.close() + assert transport.refcount == 1 + + c3.close() + assert transport.refcount == 0 + + def test_transport_recreated_after_full_release(self): + c1 = _make_client(shared_http_pool=True) + t1 = _get_transport(c1) + c1.close() + + c2 = _make_client(shared_http_pool=True) + t2 = _get_transport(c2) + assert t2 is not t1 + assert t2.refcount == 1 + + c2.close() + + +class TestSyncCopy: + def test_copy_inherits_shared_pool(self): + c1 = _make_client(shared_http_pool=True) + c2 = c1.copy() + transport = _get_transport(c1) + + assert c2._uses_shared_pool is True + assert _get_transport(c2) is transport + assert transport.refcount == 2 + + c1.close() + c2.close() + + def test_copy_with_custom_client_disables_sharing(self): + c1 = _make_client(shared_http_pool=True) + custom = httpx.Client() + c2 = c1.copy(http_client=custom) + + assert c2._uses_shared_pool is False + assert c2._client is custom + + c1.close() + c2.close() + custom.close() + + def test_copy_of_non_shared_stays_non_shared(self): + c1 = _make_client(shared_http_pool=False) + c2 = c1.copy() + + assert c2._uses_shared_pool is False + assert _get_transport(c2) is not _get_transport(c1) + + c1.close() + c2.close() + + +# --------------------------------------------------------------------------- +# Async: sharing behavior +# --------------------------------------------------------------------------- + + +class TestAsyncSharedPool: + async def test_shared_pool_uses_same_transport(self): + c1 = _make_async_client(shared_http_pool=True) + c2 = _make_async_client(shared_http_pool=True) + + assert _get_transport(c1) is _get_transport(c2) + assert c1._client is not c2._client + assert c1._uses_shared_pool is True + assert c2._uses_shared_pool is True + + def test_private_pool_uses_different_transports(self): + c1 = _make_async_client(shared_http_pool=False) + c2 = _make_async_client(shared_http_pool=False) + + assert _get_transport(c1) is not _get_transport(c2) + assert c1._uses_shared_pool is False + + def test_custom_http_client_bypasses_sharing(self): + custom = httpx.AsyncClient() + c1 = _make_async_client(http_client=custom, shared_http_pool=True) + + assert c1._client is custom + assert c1._uses_shared_pool is False + + async def test_default_is_shared(self): + c1 = _make_async_client() + assert c1._uses_shared_pool is True + + def test_no_loop_creates_private_client(self): + c1 = _make_async_client(shared_http_pool=True) + assert c1._uses_shared_pool is False + + +class TestAsyncRefcounting: + async def test_close_one_keeps_transport_alive(self): + c1 = _make_async_client(shared_http_pool=True) + c2 = _make_async_client(shared_http_pool=True) + transport = _get_transport(c1) + + assert transport.refcount == 2 + + await c1.close() + assert transport.refcount == 1 + assert not c2.is_closed() + + await c2.close() + assert transport.refcount == 0 + + async def test_double_close_is_safe(self): + c1 = _make_async_client(shared_http_pool=True) + transport = _get_transport(c1) + + await c1.close() + await c1.close() # should not raise or double-decrement + assert transport.refcount == 0 + + def test_no_loop_client_closes_properly(self): + """Client created without a running loop should close without leaking.""" + c1 = _make_async_client(shared_http_pool=True) + assert c1._uses_shared_pool is False + + asyncio.run(c1.close()) + assert c1.is_closed() + + +class TestAsyncCopy: + async def test_copy_inherits_shared_pool(self): + c1 = _make_async_client(shared_http_pool=True) + c2 = c1.copy() + transport = _get_transport(c1) + + assert c2._uses_shared_pool is True + assert _get_transport(c2) is transport + assert transport.refcount == 2 + + async def test_copy_with_custom_client_disables_sharing(self): + c1 = _make_async_client(shared_http_pool=True) + custom = httpx.AsyncClient() + c2 = c1.copy(http_client=custom) + + assert c2._uses_shared_pool is False + assert c2._client is custom + + +class TestAsyncCrossLoop: + def test_separate_loops_get_separate_transports(self): + """Clients created in different asyncio.run() calls must not share a transport.""" + + async def create_client() -> Any: + c = _make_async_client(shared_http_pool=True) + transport = _get_transport(c) + await c.close() + return transport + + t1 = asyncio.run(create_client()) + t2 = asyncio.run(create_client()) + + assert t1 is not t2, "each loop should get its own transport" + + def test_same_loop_shares_transport(self): + """Clients created in the same asyncio.run() must share a transport.""" + + async def create_two() -> tuple[int, int]: + c1 = _make_async_client(shared_http_pool=True) + c2 = _make_async_client(shared_http_pool=True) + id1 = id(_get_transport(c1)) + id2 = id(_get_transport(c2)) + await c1.close() + await c2.close() + return id1, id2 + + id1, id2 = asyncio.run(create_two()) + assert id1 == id2 diff --git a/uv.lock b/uv.lock index 88dc754a1..a35165b2c 100644 --- a/uv.lock +++ b/uv.lock @@ -2422,7 +2422,7 @@ wheels = [ [[package]] name = "runloop-api-client" -version = "1.20.0" +version = "1.20.2" source = { editable = "." } dependencies = [ { name = "anyio" },