diff --git a/src/runpod_flash/endpoint.py b/src/runpod_flash/endpoint.py index d47cc7d3..30db9484 100644 --- a/src/runpod_flash/endpoint.py +++ b/src/runpod_flash/endpoint.py @@ -27,6 +27,10 @@ _POLL_MAX_INTERVAL = 5.0 _POLL_BACKOFF_FACTOR = 1.5 +# max consecutive transient httpx errors tolerated during wait() polling +# before re-raising. resets on any successful poll. +_POLL_MAX_CONSECUTIVE_ERRORS = 5 + class _ClientCoroutine: """wraps a coroutine from a client-mode HTTP call. @@ -137,8 +141,11 @@ async def wait(self, timeout: Optional[float] = None) -> "EndpointJob": import asyncio import time + import httpx + deadline = (time.monotonic() + timeout) if timeout is not None else None interval = _POLL_INITIAL_INTERVAL + consecutive_errors = 0 while not self.done: if deadline is not None and time.monotonic() >= deadline: @@ -152,7 +159,27 @@ async def wait(self, timeout: Optional[float] = None) -> "EndpointJob": f"job {self.id} did not complete within {timeout}s " f"(last status: {self._data.get('status', 'UNKNOWN')})" ) - await self.status() + try: + await self.status() + except (httpx.TransportError, httpx.TimeoutException) as e: + # transient network / protocol / timeout error from the + # runpod api. the underlying job is still healthy, so back + # off and retry rather than aborting wait(). + # HTTPStatusError (4xx/5xx from raise_for_status) is NOT + # caught here: 4xx auth/config bugs must fail loud. + consecutive_errors += 1 + log.debug( + "transient httpx error polling job %s (%d/%d): %s", + self.id, + consecutive_errors, + _POLL_MAX_CONSECUTIVE_ERRORS, + e, + ) + if consecutive_errors >= _POLL_MAX_CONSECUTIVE_ERRORS: + raise + interval = min(interval * _POLL_BACKOFF_FACTOR, _POLL_MAX_INTERVAL) + continue + consecutive_errors = 0 interval = min(interval * _POLL_BACKOFF_FACTOR, _POLL_MAX_INTERVAL) return self diff --git a/tests/unit/test_endpoint_client.py b/tests/unit/test_endpoint_client.py index 236e1570..6b85a2e5 100644 --- a/tests/unit/test_endpoint_client.py +++ b/tests/unit/test_endpoint_client.py @@ -199,6 +199,132 @@ async def test_wait_timeout_raises(self): await job.wait(timeout=0.3) +@pytest.fixture +def fast_poll(monkeypatch): + """shrink the poll intervals so retry tests don't sit on real sleeps.""" + monkeypatch.setattr("runpod_flash.endpoint._POLL_INITIAL_INTERVAL", 0.001) + monkeypatch.setattr("runpod_flash.endpoint._POLL_MAX_INTERVAL", 0.005) + + +class TestEndpointJobWaitTransientErrors: + """retry behavior for transient httpx errors during wait() polling (AE-3154).""" + + @staticmethod + def _make_job(): + ep = Endpoint(id="ep-1") + ep._endpoint_url = "https://api.runpod.ai/v2/ep-1" + job = EndpointJob({"id": "j-1", "status": "IN_QUEUE"}, ep) + return ep, job + + @pytest.mark.asyncio + async def test_transient_error_then_success(self, fast_poll): + """one RemoteProtocolError then COMPLETED — wait() returns normally.""" + import httpx + + ep, job = self._make_job() + + side_effects = [ + httpx.RemoteProtocolError("server disconnected"), + {"id": "j-1", "status": "COMPLETED", "output": {"r": 1}}, + ] + ep._api_get = AsyncMock(side_effect=side_effects) + + result = await job.wait() + + assert result is job + assert job._data["status"] == "COMPLETED" + assert job.output == {"r": 1} + assert ep._api_get.call_count == 2 + + @pytest.mark.asyncio + async def test_repeated_transient_errors_exceed_threshold(self, fast_poll): + """5 consecutive RemoteProtocolErrors — wait() re-raises the httpx error.""" + import httpx + + from runpod_flash.endpoint import _POLL_MAX_CONSECUTIVE_ERRORS + + ep, job = self._make_job() + ep._api_get = AsyncMock( + side_effect=httpx.RemoteProtocolError("server disconnected") + ) + + with pytest.raises(httpx.RemoteProtocolError): + await job.wait() + + assert ep._api_get.call_count == _POLL_MAX_CONSECUTIVE_ERRORS + + @pytest.mark.asyncio + async def test_counter_resets_on_successful_poll(self, fast_poll): + """error bursts under the threshold separated by successes do not abort.""" + import httpx + + ep, job = self._make_job() + + side_effects = [ + httpx.RemoteProtocolError("drop 1"), + httpx.RemoteProtocolError("drop 2"), + {"id": "j-1", "status": "IN_PROGRESS"}, + httpx.RemoteProtocolError("drop 3"), + httpx.RemoteProtocolError("drop 4"), + httpx.RemoteProtocolError("drop 5"), + httpx.RemoteProtocolError("drop 6"), + {"id": "j-1", "status": "IN_PROGRESS"}, + {"id": "j-1", "status": "COMPLETED", "output": {"r": 1}}, + ] + ep._api_get = AsyncMock(side_effect=side_effects) + + result = await job.wait() + + assert result is job + assert job._data["status"] == "COMPLETED" + assert ep._api_get.call_count == len(side_effects) + + @pytest.mark.asyncio + async def test_http_status_error_not_swallowed(self, fast_poll): + """4xx HTTPStatusError must propagate immediately (auth/config bugs).""" + import httpx + + ep, job = self._make_job() + + request = httpx.Request("GET", "https://api.runpod.ai/v2/ep-1/status/j-1") + response = httpx.Response(401, request=request) + ep._api_get = AsyncMock( + side_effect=httpx.HTTPStatusError( + "401 unauthorized", request=request, response=response + ) + ) + + with pytest.raises(httpx.HTTPStatusError): + await job.wait() + + # exactly one call: not retried + assert ep._api_get.call_count == 1 + + @pytest.mark.asyncio + async def test_timeout_still_authoritative(self, fast_poll, monkeypatch): + """when deadline is hit before threshold, raise TimeoutError not httpx error. + + Raises the threshold above the number of retries the deadline allows, so + the test actually exercises the retry path (multiple suppressed httpx + errors) before the deadline trips -- not just the pre-sleep guard. + """ + import httpx + + monkeypatch.setattr("runpod_flash.endpoint._POLL_MAX_CONSECUTIVE_ERRORS", 1000) + + ep, job = self._make_job() + ep._api_get = AsyncMock( + side_effect=httpx.RemoteProtocolError("server disconnected") + ) + + with pytest.raises(TimeoutError, match="did not complete within"): + await job.wait(timeout=0.05) + + # proves the retry path was exercised: status() was called and the + # httpx error was suppressed at least once before the deadline tripped. + assert ep._api_get.call_count >= 2 + + # -- Endpoint.run / runsync / cancel --