diff --git a/src/runloop_api_client/lib/polling.py b/src/runloop_api_client/lib/polling.py index 899d2a9bf..1fd4e2a3a 100644 --- a/src/runloop_api_client/lib/polling.py +++ b/src/runloop_api_client/lib/polling.py @@ -1,5 +1,5 @@ import time -from typing import Any, TypeVar, Callable, Optional +from typing import Any, Union, TypeVar, Callable, Optional from dataclasses import dataclass T = TypeVar("T") @@ -73,3 +73,46 @@ def poll_until( raise PollingTimeout(f"Exceeded timeout of {config.timeout_seconds} seconds", last_result) time.sleep(config.interval_seconds) + + +def retry_server_poll_until( + retriever: Callable[[float], T], + is_terminal: Callable[[T], bool], + timeout_seconds: float = 30.0, + on_error: Optional[Callable[[Exception], T]] = None, +) -> T: + """ + Retry a server-side long-poll until a condition is met or max timeout is reached. + + Args: + retriever: Callable that takes the remaining timeout (seconds) and + returns the object to check. + is_terminal: Callable that returns True when polling should stop + timeout_seconds: Total time to wait. Must be > 0 + on_error: Optional error handler that can return a value to continue polling + or re-raise the exception to stop polling + + Returns: + The final state of the polled object + + Raises: + PollingTimeout: When max attempts or timeout is reached + """ + last_result: Union[T, None] = None + start_time = time.time() + + while True: + remaining_time = timeout_seconds - (time.time() - start_time) + if remaining_time <= 0: + raise PollingTimeout(f"Exceeded timeout of {timeout_seconds} seconds", last_result) + + try: + last_result = retriever(remaining_time) + except Exception as e: + if on_error is not None: + last_result = on_error(e) + else: + raise + + if is_terminal(last_result): + return last_result diff --git a/src/runloop_api_client/lib/polling_async.py b/src/runloop_api_client/lib/polling_async.py index 9bc1bb752..f6789d3a9 100644 --- a/src/runloop_api_client/lib/polling_async.py +++ b/src/runloop_api_client/lib/polling_async.py @@ -60,7 +60,7 @@ async def async_poll_until( await asyncio.sleep(config.interval_seconds) -async def retry_server_poll_until( +async def async_retry_server_poll_until( retriever: Callable[[float], Awaitable[T]], is_terminal: Callable[[T], bool], timeout_seconds: float = 30.0, diff --git a/src/runloop_api_client/resources/devboxes/devboxes.py b/src/runloop_api_client/resources/devboxes/devboxes.py index 96628e2bb..83459959b 100644 --- a/src/runloop_api_client/resources/devboxes/devboxes.py +++ b/src/runloop_api_client/resources/devboxes/devboxes.py @@ -72,7 +72,7 @@ AsyncDiskSnapshotsCursorIDPage, ) from ..._exceptions import RunloopError, APIStatusError, APITimeoutError -from ...lib.polling import PollingConfig, poll_until +from ...lib.polling import PollingConfig, poll_until, retry_server_poll_until as sync_retry_server_poll_until from ..._base_client import AsyncPaginator, make_request_options from .disk_snapshots import ( DiskSnapshotsResource, @@ -82,7 +82,7 @@ DiskSnapshotsResourceWithStreamingResponse, AsyncDiskSnapshotsResourceWithStreamingResponse, ) -from ...lib.polling_async import async_poll_until, retry_server_poll_until +from ...lib.polling_async import async_poll_until, async_retry_server_poll_until from ...types.devbox_view import DevboxView from ...types.tunnel_view import TunnelView from ...types.shared_params.mount import Mount @@ -397,30 +397,31 @@ def await_running( RunloopError: If devbox enters a non-running terminal state """ - def wait_for_devbox_status() -> DevboxView: - # This wait_for_status endpoint polls the devbox status for 10 seconds until it reaches either running or failure. - # If it's neither, it will throw an error. - return self._post( - f"/v1/devboxes/{id}/wait_for_status", - body={"statuses": ["running", "failure", "shutdown"]}, - cast_to=DevboxView, - ) - - def handle_timeout_error(error: Exception) -> DevboxView: - # Handle timeout errors by returning current devbox state to continue polling - if isinstance(error, APITimeoutError) or ( - isinstance(error, APIStatusError) and error.response.status_code == 408 - ): - # Return a placeholder result to continue polling - return placeholder_devbox_view(id) - - # Re-raise other errors to stop polling - raise error + def wait_for_devbox_status(remaining_timeout_seconds: float) -> DevboxView: + try: + return self._post( + f"/v1/devboxes/{id}/wait_for_status", + body={"statuses": ["running", "failure", "shutdown"], "timeout_seconds": remaining_timeout_seconds}, + cast_to=DevboxView, + options={"max_retries": 0}, + ) + except (APITimeoutError, APIStatusError) as error: + if isinstance(error, APITimeoutError) or error.response.status_code == 408: + return placeholder_devbox_view(id) + raise def is_done_booting(devbox: DevboxView) -> bool: return devbox.status not in DEVBOX_BOOTING_STATES - devbox = poll_until(wait_for_devbox_status, is_done_booting, polling_config, handle_timeout_error) + config = polling_config + if not config: + config = PollingConfig() + + timeout = config.interval_seconds * config.max_attempts + if config.timeout_seconds is not None and config.timeout_seconds > 0: + timeout = min(config.timeout_seconds, timeout) + + devbox = sync_retry_server_poll_until(wait_for_devbox_status, is_done_booting, timeout) if devbox.status != "running": raise RunloopError(f"Devbox entered non-running terminal state: {devbox.status}") @@ -452,6 +453,7 @@ def wait_for_devbox_status() -> DevboxView: f"/v1/devboxes/{id}/wait_for_status", body={"statuses": list(DEVBOX_TERMINAL_STATES)}, cast_to=DevboxView, + options={"max_retries": 0}, ) def handle_timeout_error(error: Exception) -> DevboxView: @@ -2063,6 +2065,7 @@ async def wait_for_devbox_status(remaining_timeout_seconds: float) -> DevboxView f"/v1/devboxes/{id}/wait_for_status", body={"statuses": ["running", "failure", "shutdown"], "timeout_seconds": remaining_timeout_seconds}, cast_to=DevboxView, + options={"max_retries": 0}, ) except (APITimeoutError, APIStatusError) as error: # Handle timeout errors by returning current devbox state to continue polling @@ -2088,7 +2091,7 @@ def is_done_booting(devbox: DevboxView) -> bool: if config.timeout_seconds is not None and config.timeout_seconds > 0: timeout = min(config.timeout_seconds, timeout) - devbox = await retry_server_poll_until(wait_for_devbox_status, is_done_booting, timeout) + devbox = await async_retry_server_poll_until(wait_for_devbox_status, is_done_booting, timeout) if devbox.status != "running": raise RunloopError(f"Devbox entered non-running terminal state: {devbox.status}") @@ -2121,6 +2124,7 @@ async def wait_for_devbox_status() -> DevboxView: f"/v1/devboxes/{id}/wait_for_status", body={"statuses": list(DEVBOX_TERMINAL_STATES)}, cast_to=DevboxView, + options={"max_retries": 0}, ) except (APITimeoutError, APIStatusError) as error: if isinstance(error, APITimeoutError) or error.response.status_code == 408: