From 6d991b6055c9a51d4f57ea6a0843924fb58ae9fc Mon Sep 17 00:00:00 2001 From: jhcipar Date: Tue, 17 Mar 2026 16:13:07 -0400 Subject: [PATCH 1/6] feat: poll request logs for qb workers when running --- src/runpod_flash/cli/utils/deployment.py | 29 ++- .../core/resources/request_logs.py | 189 ++++++++++++++++++ src/runpod_flash/core/resources/serverless.py | 44 +++- tests/unit/cli/utils/test_deployment.py | 95 +++++++++ tests/unit/resources/test_request_logs.py | 141 +++++++++++++ tests/unit/resources/test_serverless.py | 96 +++++++++ 6 files changed, 589 insertions(+), 5 deletions(-) create mode 100644 src/runpod_flash/core/resources/request_logs.py create mode 100644 tests/unit/resources/test_request_logs.py diff --git a/src/runpod_flash/cli/utils/deployment.py b/src/runpod_flash/cli/utils/deployment.py index 21314614..c2c2f8d3 100644 --- a/src/runpod_flash/cli/utils/deployment.py +++ b/src/runpod_flash/cli/utils/deployment.py @@ -15,6 +15,14 @@ log = logging.getLogger(__name__) +def _normalized_resource_attr(resource: Any, *names: str) -> str | None: + for name in names: + value = getattr(resource, name, None) + if isinstance(value, str) and value.strip(): + return value + return None + + async def upload_build(app_name: str, build_path: str | Path): app = await FlashApp.from_name(app_name) await app.upload_build(build_path) @@ -145,6 +153,14 @@ async def provision_resources_for_build( resources_endpoints[resource_name] = endpoint_url + endpoint_id = _normalized_resource_attr(deployed_resource, "endpoint_id", "id") + if endpoint_id: + manifest["resources"][resource_name]["endpoint_id"] = endpoint_id + + ai_key = _normalized_resource_attr(deployed_resource, "aiKey", "ai_key") + if ai_key: + manifest["resources"][resource_name]["aiKey"] = ai_key + # Track load balancer URL for prominent logging if manifest["resources"][resource_name].get("is_load_balanced"): lb_endpoint_url = endpoint_url @@ -277,6 +293,10 @@ async def reconcile_and_provision_resources( local_manifest["resources"][resource_name]["endpoint_id"] = ( state_config["endpoint_id"] ) + if "aiKey" in state_config: + local_manifest["resources"][resource_name]["aiKey"] = state_config[ + "aiKey" + ] if resource_name in state_manifest.get("resources_endpoints", {}): local_manifest.setdefault("resources_endpoints", {})[resource_name] = ( state_manifest["resources_endpoints"][resource_name] @@ -310,13 +330,18 @@ async def reconcile_and_provision_resources( deployed_resource = provisioning_results[i] # Extract endpoint info - endpoint_id = getattr(deployed_resource, "endpoint_id", None) - endpoint_url = getattr(deployed_resource, "endpoint_url", None) + endpoint_id = _normalized_resource_attr( + deployed_resource, "endpoint_id", "id" + ) + endpoint_url = _normalized_resource_attr(deployed_resource, "endpoint_url") + ai_key = _normalized_resource_attr(deployed_resource, "aiKey", "ai_key") if endpoint_id: local_manifest["resources"][resource_name]["endpoint_id"] = endpoint_id if endpoint_url: local_manifest["resources_endpoints"][resource_name] = endpoint_url + if ai_key: + local_manifest["resources"][resource_name]["aiKey"] = ai_key log.debug( f"{'Provisioned' if action_type == 'provision' else 'Updated'}: " diff --git a/src/runpod_flash/core/resources/request_logs.py b/src/runpod_flash/core/resources/request_logs.py new file mode 100644 index 00000000..8946bf48 --- /dev/null +++ b/src/runpod_flash/core/resources/request_logs.py @@ -0,0 +1,189 @@ +import logging +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +import dateutil +from typing import Any, List, Optional + +import httpx + +from runpod_flash.core.utils.http import get_authenticated_httpx_client + +log = logging.getLogger(__name__) + +API_BASE_URL = "https://api.runpod.ai" + + +def _format_log_timestamp(value: datetime) -> str: + return value.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z") + + +@dataclass +class QBRequestLogBatch: + lines: List[str] + matched_by_request_id: bool + worker_id: Optional[str] + +class QBRequestLogFetcher: + def __init__( + self, + timeout_seconds: float = 4.0, + max_lines: int = 25, + fallback_tail_lines: int = 10, + lookback_seconds: int = 5, + start_time: datetime = datetime.now(timezone.utc) + ): + self.timeout_seconds = timeout_seconds + self.max_lines = max_lines + self.fallback_tail_lines = fallback_tail_lines + self.lookback_seconds = lookback_seconds + self.start_time = start_time + self.fetched_until: datetime | None = None + self.seen = set() + + async def fetch_logs( + self, + endpoint_id: str, + endpoint_ai_key: str, + ): + if self.fetched_until: + self.start_time = self.fetched_until + fetch_until = datetime.now(timezone.utc) + logs_payload = await self._fetch_endpoint_logs(endpoint_id, endpoint_ai_key, fetch_until) + if not logs_payload: + return + + lines = self._extract_lines(logs_payload, fetch_until) + return QBRequestLogBatch( + lines = lines, + matched_by_request_id = False, + worker_id = None + ) + + async def _fetch_worker_id( + self, + endpoint_id: str, + request_id: str, + runpod_api_key: str, + ) -> Optional[str]: + url = f"{API_BASE_URL}/v2/{endpoint_id}/status/{request_id}" + + try: + async with get_authenticated_httpx_client( + timeout=self.timeout_seconds, + api_key_override=runpod_api_key, + ) as client: + response = await client.get(url) + response.raise_for_status() + payload = response.json() + except (httpx.HTTPError, ValueError) as exc: + log.debug("Failed to fetch worker for request %s: %s", request_id, exc) + return None + + worker_id = payload.get("workerId") + if not worker_id: + return None + return str(worker_id) + + async def _fetch_endpoint_logs( + self, + endpoint_id: str, + endpoint_ai_key: str, + end_utc: datetime, + start_utc: Optional[datetime] = None + ) -> Optional[dict[str, Any]]: + """ + fetch endpoint logs for a given time range, defaulting to the fetcher + configured start time + updates start utc when we successfully fetch logs + """ + url = f"{API_BASE_URL}/v2/{endpoint_id}/logs" + if not start_utc: + start_utc = self.start_time + + log.debug(f"fetching logs for time range: {start_utc} to {end_utc}") + params = { + "from": _format_log_timestamp(start_utc), + "to": _format_log_timestamp(end_utc), + "page": 0, + "pageSize": 200, + } + + try: + async with get_authenticated_httpx_client( + timeout=self.timeout_seconds, + api_key_override=endpoint_ai_key, + ) as client: + response = await client.get(url, params=params) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as exc: + body_preview = "" + if exc.response is not None: + body_preview = (exc.response.text or "")[:500] + log.debug( + "Failed to fetch endpoint logs for %s: %s | response_body=%s", + endpoint_id, + exc, + body_preview, + ) + return None + except (httpx.HTTPError, ValueError) as exc: + log.debug("Failed to fetch endpoint logs for %s: %s", endpoint_id, exc) + return None + + def _extract_lines(self, payload: dict[str, Any], end_time: datetime) -> List[str]: + """ + extract lines from a response payload from sls endpoint response + deduplicates based on already seen lines + """ + records = payload.get("data") + if not isinstance(records, list): + return [] + + max_seen_dt = self.start_time + lines: List[str] = [] + + for record in records: + if isinstance(record, str): + stripped = record.strip() + if stripped and stripped not in self.seen: + self.seen.add(stripped) + lines.append(record) + continue + + if not isinstance(record, dict): + continue + + line = ( + record.get("message") + or record.get("log") + or record.get("text") + or record.get("raw") + ) + + dt = record.get("dt") + + if dt: + parsed = dateutil.parser.parse(dt) + if parsed.tzinfo is None: + parsed = parsed.replace(tzinfo=timezone.utc) # treat naive as UTC + else: + parsed = parsed.astimezone(timezone.utc) + max_seen_dt = max(parsed, max_seen_dt) + + if isinstance(line, str): + stripped = line.strip() + if stripped and stripped not in self.seen: + stripped = stripped.replace("\\n", "") + self.seen.add(stripped) + lines.append(stripped) + if lines: + # lines are returned in time descending order + lines.reverse() + if max_seen_dt > self.start_time: + self.fetched_until = max_seen_dt + else: + # not all logs have a timestamp, assume we should refetch + self.fetched_until = self.start_time + + return lines diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index 1f374646..c2a47b0c 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -24,6 +24,7 @@ from .cpu import CpuInstanceType from .gpu import GpuGroup, GpuType from .network_volume import NetworkVolume, DataCenter +from .request_logs import QBRequestLogFetcher from .template import KeyValuePair, PodTemplate from .resource_manager import ResourceManager @@ -205,6 +206,27 @@ def endpoint_url(self) -> str: base_url = self.endpoint.rp_client.endpoint_url_base return f"{base_url}/{self.id}" + async def _emit_endpoint_logs( + self, + fetcher: QBRequestLogFetcher, + ): + if self.type != ServerlessType.QB: + return + + if not self.id or not self.aiKey: + return + + batch = await fetcher.fetch_logs( + endpoint_id=self.id, + endpoint_ai_key=self.aiKey, + ) + if not batch: + return False + + if batch.lines: + for line in batch.lines: + print(f"worker log: {line}") + @field_serializer("scalerType") def serialize_scaler_type( self, value: Optional[ServerlessScalerType] @@ -930,8 +952,6 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": job: Optional[Job] = None try: - # log.debug(f"[{self}] Payload: {payload}") - # Create a job using the endpoint log.info(f"{self} | API /run") job = await asyncio.to_thread(self.endpoint.run, request_input=payload) @@ -944,6 +964,7 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": attempt = 0 job_status = Status.UNKNOWN last_status = job_status + fetcher = QBRequestLogFetcher() # Poll for job status while True: @@ -963,13 +984,30 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": log.info(f"{log_subgroup} | Status: {job_status}") attempt = 0 + await self._emit_endpoint_logs( + fetcher=fetcher, + ) + last_status = job_status # Adjust polling pace appropriately - current_pace = get_backoff_delay(attempt) + current_pace = get_backoff_delay(attempt, max_seconds=5) if job_status in ("COMPLETED", "FAILED", "CANCELLED"): response = await asyncio.to_thread(job._fetch_job) + output = response.get("output") + if isinstance(output, dict): + stdout = output.get("stdout") + if isinstance(stdout, str): + kept = [] + for raw in stdout.splitlines(): + raw = raw.strip() + if not raw: + continue + if raw in fetcher.seen: + continue + kept.append(raw) + output["stdout"] = "\n".join(kept) return JobOutput(**response) except Exception as e: diff --git a/tests/unit/cli/utils/test_deployment.py b/tests/unit/cli/utils/test_deployment.py index 93e99c29..09b0acc3 100644 --- a/tests/unit/cli/utils/test_deployment.py +++ b/tests/unit/cli/utils/test_deployment.py @@ -483,3 +483,98 @@ async def test_deploy_succeeds_without_api_key_when_no_remote_calls(tmp_path): await reconcile_and_provision_resources( app, "build-123", "dev", local_manifest, show_progress=False ) + + +@pytest.mark.asyncio +async def test_provision_resources_persists_ai_key_to_manifest(mock_flash_app): + manifest = { + "resources": { + "cpu": {"resource_type": "ServerlessResource"}, + } + } + mock_flash_app.get_build_manifest.return_value = manifest + + deployed = MagicMock() + deployed.endpoint_url = "https://example.com/endpoint" + deployed.id = "endpoint-123" + deployed.aiKey = "ai-key-123" + + with ( + patch("runpod_flash.cli.utils.deployment.ResourceManager") as mock_manager_cls, + patch( + "runpod_flash.cli.utils.deployment.create_resource_from_manifest" + ) as mock_create_resource, + ): + mock_manager = MagicMock() + mock_manager.get_or_deploy_resource = AsyncMock(return_value=deployed) + mock_manager_cls.return_value = mock_manager + mock_create_resource.return_value = MagicMock() + + await provision_resources_for_build( + mock_flash_app, + "build-123", + "dev", + show_progress=False, + ) + + call_args = mock_flash_app.update_build_manifest.call_args + updated_manifest = call_args[0][1] + assert updated_manifest["resources"]["cpu"]["endpoint_id"] == "endpoint-123" + assert updated_manifest["resources"]["cpu"]["aiKey"] == "ai-key-123" + + +@pytest.mark.asyncio +async def test_reconciliation_copies_ai_key_from_state_manifest(tmp_path): + import json + + flash_dir = tmp_path / ".flash" + flash_dir.mkdir() + + local_manifest = { + "resources": { + "worker": { + "resource_type": "LiveServerless", + "config": "same", + "endpoint_id": "endpoint-123", + "aiKey": "ai-key-123", + }, + }, + "resources_endpoints": {}, + } + (flash_dir / "flash_manifest.json").write_text(json.dumps(local_manifest)) + + state_manifest = { + "resources": { + "worker": { + "resource_type": "LiveServerless", + "config": "same", + "endpoint_id": "endpoint-123", + "aiKey": "ai-key-123", + }, + }, + "resources_endpoints": { + "worker": "https://worker.api.runpod.ai", + }, + } + + app = AsyncMock() + app.get_build_manifest = AsyncMock(return_value=state_manifest) + app.update_build_manifest = AsyncMock() + + with ( + patch("pathlib.Path.cwd", return_value=tmp_path), + patch("runpod_flash.cli.utils.deployment.ResourceManager") as mock_manager_cls, + ): + mock_manager = MagicMock() + mock_manager.get_or_deploy_resource = AsyncMock() + mock_manager_cls.return_value = mock_manager + + await reconcile_and_provision_resources(app, "build-123", "dev", local_manifest) + + updated_manifest = app.update_build_manifest.call_args[0][1] + assert updated_manifest["resources"]["worker"]["endpoint_id"] == "endpoint-123" + assert updated_manifest["resources"]["worker"]["aiKey"] == "ai-key-123" + assert ( + updated_manifest["resources_endpoints"]["worker"] + == "https://worker.api.runpod.ai" + ) diff --git a/tests/unit/resources/test_request_logs.py b/tests/unit/resources/test_request_logs.py new file mode 100644 index 00000000..2cef0656 --- /dev/null +++ b/tests/unit/resources/test_request_logs.py @@ -0,0 +1,141 @@ +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from runpod_flash.core.resources.request_logs import ( + QBRequestLogBatch, + QBRequestLogFetcher, +) + + +def _make_async_client(mock_client: MagicMock) -> MagicMock: + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=mock_client) + cm.__aexit__ = AsyncMock(return_value=None) + return cm + + +@pytest.mark.asyncio +async def test_fetch_logs_returns_batch_with_chronological_unique_lines(): + fetcher = QBRequestLogFetcher( + max_lines=5, + start_time=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + logs_response = MagicMock() + logs_response.raise_for_status = MagicMock() + logs_response.json.return_value = { + "data": [ + {"message": "line two", "dt": "2026-01-01T00:00:02Z"}, + {"message": "line one", "dt": "2026-01-01T00:00:01Z"}, + {"message": "line two", "dt": "2026-01-01T00:00:02Z"}, + ] + } + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=logs_response) + + with patch( + "runpod_flash.core.resources.request_logs.get_authenticated_httpx_client", + return_value=_make_async_client(mock_client), + ): + batch = await fetcher.fetch_logs( + endpoint_id="endpoint-1", + endpoint_ai_key="ai-key", + ) + + assert isinstance(batch, QBRequestLogBatch) + assert batch is not None + assert batch.matched_by_request_id is False + assert batch.worker_id is None + assert batch.lines == ["line one", "line two"] + + +@pytest.mark.asyncio +async def test_fetch_logs_dedupes_seen_lines_across_calls(): + fetcher = QBRequestLogFetcher( + max_lines=5, + start_time=datetime(2026, 1, 1, tzinfo=timezone.utc), + ) + + first_logs_response = MagicMock() + first_logs_response.raise_for_status = MagicMock() + first_logs_response.json.return_value = { + "data": [ + {"message": "line two"}, + {"message": "line one"}, + ] + } + + second_logs_response = MagicMock() + second_logs_response.raise_for_status = MagicMock() + second_logs_response.json.return_value = { + "data": [ + {"message": "line three"}, + {"message": "line two"}, + ] + } + + mock_client = MagicMock() + mock_client.get = AsyncMock(side_effect=[first_logs_response, second_logs_response]) + + with patch( + "runpod_flash.core.resources.request_logs.get_authenticated_httpx_client", + return_value=_make_async_client(mock_client), + ): + first_batch = await fetcher.fetch_logs("endpoint-1", "ai-key") + second_batch = await fetcher.fetch_logs("endpoint-1", "ai-key") + + assert first_batch is not None + assert first_batch.lines == ["line one", "line two"] + assert second_batch is not None + assert second_batch.lines == ["line three"] + + +@pytest.mark.asyncio +async def test_fetch_logs_returns_none_on_http_error(): + fetcher = QBRequestLogFetcher() + + mock_client = MagicMock() + mock_client.get = AsyncMock(side_effect=httpx.ReadTimeout("timed out")) + + with patch( + "runpod_flash.core.resources.request_logs.get_authenticated_httpx_client", + return_value=_make_async_client(mock_client), + ): + batch = await fetcher.fetch_logs( + endpoint_id="endpoint-1", + endpoint_ai_key="ai-key", + ) + + assert batch is None + + +@pytest.mark.asyncio +async def test_endpoint_log_fetch_uses_v2_with_aikey_bearer_auth(): + fetcher = QBRequestLogFetcher() + + logs_response = MagicMock() + logs_response.raise_for_status = MagicMock() + logs_response.json.return_value = {"data": []} + + mock_client = MagicMock() + mock_client.get = AsyncMock(return_value=logs_response) + + with patch( + "runpod_flash.core.resources.request_logs.get_authenticated_httpx_client", + return_value=_make_async_client(mock_client), + ) as mock_client_factory: + await fetcher.fetch_logs( + endpoint_id="endpoint-1", + endpoint_ai_key="ai-key-123", + ) + + log_call = mock_client.get.await_args_list[0] + assert log_call.args[0] == "https://api.runpod.ai/v2/endpoint-1/logs" + assert "from" in log_call.kwargs["params"] + assert "to" in log_call.kwargs["params"] + assert "aikey" not in log_call.kwargs["params"] + assert mock_client_factory.call_args.kwargs["api_key_override"] == "ai-key-123" diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py index 8acd74b8..f58a9ae6 100644 --- a/tests/unit/resources/test_serverless.py +++ b/tests/unit/resources/test_serverless.py @@ -11,6 +11,7 @@ ServerlessResource, ServerlessEndpoint, ServerlessScalerType, + ServerlessType, CudaVersion, JobOutput, WorkersHealth, @@ -22,6 +23,7 @@ from runpod_flash.core.resources.gpu import GpuGroup from runpod_flash.core.resources.cpu import CpuInstanceType from runpod_flash.core.resources.network_volume import NetworkVolume, DataCenter +from runpod_flash.core.resources.request_logs import QBRequestLogBatch from runpod_flash.core.resources.template import PodTemplate @@ -821,6 +823,100 @@ async def test_run_async_success(self): assert result.id == "job-123" assert result.status == "COMPLETED" + @pytest.mark.asyncio + async def test_run_async_fetches_endpoint_logs_while_polling(self): + """Test run async polls endpoint logs every cycle until completion.""" + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "ai-key-123" + + mock_job = MagicMock() + mock_job.job_id = "job-123" + mock_job.status.side_effect = [ + "IN_QUEUE", + "IN_PROGRESS", + "IN_PROGRESS", + "COMPLETED", + ] + mock_job._fetch_job.return_value = { + "id": "job-123", + "workerId": "worker-456", + "status": "COMPLETED", + "delayTime": 1000, + "executionTime": 2000, + "output": {"result": "success"}, + } + + mock_endpoint = MagicMock() + mock_endpoint.run.return_value = mock_job + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + with patch("asyncio.sleep"): + with patch.object( + ServerlessResource, + "_emit_endpoint_logs", + new=AsyncMock(), + ) as mock_emit_logs: + await serverless.run({"input": "test"}) + + assert mock_emit_logs.await_count == 4 + fetchers = [call.kwargs["fetcher"] for call in mock_emit_logs.await_args_list] + assert len({id(fetcher) for fetcher in fetchers}) == 1 + + @pytest.mark.asyncio + async def test_emit_endpoint_logs_prints_worker_lines(self): + """Endpoint log emission prints each worker log line.""" + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "endpoint-ai-key" + + mock_fetcher = MagicMock() + mock_fetcher.fetch_logs = AsyncMock( + return_value=QBRequestLogBatch( + worker_id=None, + lines=["line-a", "line-b"], + matched_by_request_id=False, + ) + ) + + with patch("builtins.print") as mock_print: + await serverless._emit_endpoint_logs(fetcher=mock_fetcher) + + mock_fetcher.fetch_logs.assert_awaited_once_with( + endpoint_id="endpoint-123", + endpoint_ai_key="endpoint-ai-key", + ) + mock_print.assert_any_call("worker log: line-a") + mock_print.assert_any_call("worker log: line-b") + + @pytest.mark.asyncio + async def test_emit_endpoint_logs_skips_when_missing_required_fields(self): + """Endpoint log fetch is skipped unless QB endpoint has id and aiKey.""" + serverless = ServerlessResource(name="test") + mock_fetcher = MagicMock() + mock_fetcher.fetch_logs = AsyncMock(return_value=None) + + serverless.type = ServerlessType.QB + serverless.id = None + serverless.aiKey = "endpoint-ai-key" + await serverless._emit_endpoint_logs(fetcher=mock_fetcher) + + serverless.id = "endpoint-123" + serverless.aiKey = None + await serverless._emit_endpoint_logs(fetcher=mock_fetcher) + + serverless.type = ServerlessType.LB + serverless.aiKey = "endpoint-ai-key" + await serverless._emit_endpoint_logs(fetcher=mock_fetcher) + + mock_fetcher.fetch_logs.assert_not_awaited() + @pytest.mark.asyncio async def test_run_async_failure_cancels_job(self): """Test run async cancels job on exception.""" From 94650a504c6ec4fc120984b8f15e5ebeb2582585 Mon Sep 17 00:00:00 2001 From: jhcipar Date: Tue, 17 Mar 2026 16:16:07 -0400 Subject: [PATCH 2/6] chore: formatting --- .../core/resources/request_logs.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/runpod_flash/core/resources/request_logs.py b/src/runpod_flash/core/resources/request_logs.py index 8946bf48..5e3f4c62 100644 --- a/src/runpod_flash/core/resources/request_logs.py +++ b/src/runpod_flash/core/resources/request_logs.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from datetime import datetime, timedelta, timezone +from datetime import datetime, timezone import dateutil from typing import Any, List, Optional @@ -23,6 +23,7 @@ class QBRequestLogBatch: matched_by_request_id: bool worker_id: Optional[str] + class QBRequestLogFetcher: def __init__( self, @@ -30,7 +31,7 @@ def __init__( max_lines: int = 25, fallback_tail_lines: int = 10, lookback_seconds: int = 5, - start_time: datetime = datetime.now(timezone.utc) + start_time: datetime = datetime.now(timezone.utc), ): self.timeout_seconds = timeout_seconds self.max_lines = max_lines @@ -48,16 +49,16 @@ async def fetch_logs( if self.fetched_until: self.start_time = self.fetched_until fetch_until = datetime.now(timezone.utc) - logs_payload = await self._fetch_endpoint_logs(endpoint_id, endpoint_ai_key, fetch_until) + logs_payload = await self._fetch_endpoint_logs( + endpoint_id, endpoint_ai_key, fetch_until + ) if not logs_payload: return lines = self._extract_lines(logs_payload, fetch_until) return QBRequestLogBatch( - lines = lines, - matched_by_request_id = False, - worker_id = None - ) + lines=lines, matched_by_request_id=False, worker_id=None + ) async def _fetch_worker_id( self, @@ -89,7 +90,7 @@ async def _fetch_endpoint_logs( endpoint_id: str, endpoint_ai_key: str, end_utc: datetime, - start_utc: Optional[datetime] = None + start_utc: Optional[datetime] = None, ) -> Optional[dict[str, Any]]: """ fetch endpoint logs for a given time range, defaulting to the fetcher @@ -166,7 +167,7 @@ def _extract_lines(self, payload: dict[str, Any], end_time: datetime) -> List[st if dt: parsed = dateutil.parser.parse(dt) if parsed.tzinfo is None: - parsed = parsed.replace(tzinfo=timezone.utc) # treat naive as UTC + parsed = parsed.replace(tzinfo=timezone.utc) # treat naive as UTC else: parsed = parsed.astimezone(timezone.utc) max_seen_dt = max(parsed, max_seen_dt) From 6d930fff49ca9f77d289e88567a037f0493c5a6a Mon Sep 17 00:00:00 2001 From: jhcipar Date: Thu, 2 Apr 2026 14:55:23 -0400 Subject: [PATCH 3/6] chore: switch to pod level polling and endpoint status # Conflicts: # src/runpod_flash/core/resources/serverless.py --- .../core/resources/request_logs.py | 401 +++++++++++++----- src/runpod_flash/core/resources/serverless.py | 129 +++++- tests/unit/resources/test_request_logs.py | 240 ++++++++--- tests/unit/resources/test_serverless.py | 179 +++++++- 4 files changed, 762 insertions(+), 187 deletions(-) diff --git a/src/runpod_flash/core/resources/request_logs.py b/src/runpod_flash/core/resources/request_logs.py index 5e3f4c62..e2bc2d88 100644 --- a/src/runpod_flash/core/resources/request_logs.py +++ b/src/runpod_flash/core/resources/request_logs.py @@ -1,7 +1,9 @@ import logging +import os +import re from dataclasses import dataclass from datetime import datetime, timezone -import dateutil +from enum import Enum from typing import Any, List, Optional import httpx @@ -11,10 +13,30 @@ log = logging.getLogger(__name__) API_BASE_URL = "https://api.runpod.ai" +DEV_API_BASE_URL = "https://dev-api.runpod.ai" +HAPI_BASE_URL = "https://hapi.runpod.net" +DEV_HAPI_BASE_URL = "https://dev-hapi.runpod.net" +LOG_PREFIX_TIMESTAMP_RE = re.compile( + r"^(?P\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?Z)" +) -def _format_log_timestamp(value: datetime) -> str: - return value.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z") +def _resolve_hapi_base_url() -> str: + runpod_env = os.getenv("RUNPOD_ENV", "").lower() + if runpod_env == "dev": + return DEV_HAPI_BASE_URL + + api_base = os.getenv("RUNPOD_API_BASE_URL", "") + if DEV_API_BASE_URL in api_base: + return DEV_HAPI_BASE_URL + + return HAPI_BASE_URL + + +class QBRequestLogPhase(str, Enum): + WAITING_FOR_WORKER = "WAITING_FOR_WORKER" + WAITING_FOR_WORKER_INITIALIZATION = "WAITING_FOR_WORKER_INITIALIZATION" + STREAMING = "STREAMING" @dataclass @@ -22,6 +44,7 @@ class QBRequestLogBatch: lines: List[str] matched_by_request_id: bool worker_id: Optional[str] + phase: QBRequestLogPhase class QBRequestLogFetcher: @@ -30,91 +53,230 @@ def __init__( timeout_seconds: float = 4.0, max_lines: int = 25, fallback_tail_lines: int = 10, - lookback_seconds: int = 5, - start_time: datetime = datetime.now(timezone.utc), + lookback_seconds: int = 20, + start_time: Optional[datetime] = None, ): self.timeout_seconds = timeout_seconds self.max_lines = max_lines self.fallback_tail_lines = fallback_tail_lines self.lookback_seconds = lookback_seconds - self.start_time = start_time - self.fetched_until: datetime | None = None + self.start_time = start_time or datetime.now(timezone.utc) self.seen = set() + self.worker_id: Optional[str] = None + self.has_streamed_logs = False + self.has_primed_worker_logs = False async def fetch_logs( self, endpoint_id: str, - endpoint_ai_key: str, + request_id: str, + status_api_key: str, + pod_logs_api_key: str, + status_api_key_fallback: Optional[str] = None, ): - if self.fetched_until: - self.start_time = self.fetched_until - fetch_until = datetime.now(timezone.utc) - logs_payload = await self._fetch_endpoint_logs( - endpoint_id, endpoint_ai_key, fetch_until + status_payload = await self._fetch_status_payload( + endpoint_id=endpoint_id, + request_id=request_id, + status_api_key=status_api_key, + status_api_key_fallback=status_api_key_fallback, + ) + assigned_worker_id = self._worker_id_from_status_payload(status_payload) + + metrics_payload = await self._fetch_metrics_payload( + endpoint_id=endpoint_id, + status_api_key=status_api_key, + status_api_key_fallback=status_api_key_fallback, + ) + running_worker_ids = self._running_worker_ids_from_metrics(metrics_payload) + initializing_workers = self._initializing_worker_count(metrics_payload) + + matched_by_request_id = False + if assigned_worker_id: + self._set_worker_id(assigned_worker_id) + matched_by_request_id = True + elif not self.worker_id and running_worker_ids: + self._set_worker_id(running_worker_ids[0]) + + if not self.worker_id: + phase = ( + QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION + if initializing_workers > 0 + else QBRequestLogPhase.WAITING_FOR_WORKER + ) + return QBRequestLogBatch( + lines=[], + matched_by_request_id=False, + worker_id=None, + phase=phase, + ) + + logs_payload = await self._fetch_pod_logs( + worker_id=self.worker_id, + runpod_api_key=pod_logs_api_key, ) if not logs_payload: - return + return QBRequestLogBatch( + lines=[], + matched_by_request_id=matched_by_request_id, + worker_id=self.worker_id, + phase=QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION, + ) + + if not self.has_primed_worker_logs: + lines = self._extract_initial_lines(logs_payload, request_id=request_id) + self.has_primed_worker_logs = True + if lines: + self.has_streamed_logs = True + return QBRequestLogBatch( + lines=lines[-self.max_lines :], + matched_by_request_id=matched_by_request_id, + worker_id=self.worker_id, + phase=( + QBRequestLogPhase.STREAMING + if self.has_streamed_logs + else QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION + ), + ) - lines = self._extract_lines(logs_payload, fetch_until) + lines = self._extract_lines(logs_payload) + if lines: + self.has_streamed_logs = True + phase = ( + QBRequestLogPhase.STREAMING + if self.has_streamed_logs + else QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION + ) return QBRequestLogBatch( - lines=lines, matched_by_request_id=False, worker_id=None + lines=lines[-self.max_lines :], + matched_by_request_id=matched_by_request_id, + worker_id=self.worker_id, + phase=phase, ) - async def _fetch_worker_id( + async def _fetch_status_payload( self, endpoint_id: str, request_id: str, - runpod_api_key: str, - ) -> Optional[str]: + status_api_key: str, + status_api_key_fallback: Optional[str], + ) -> Optional[dict[str, Any]]: url = f"{API_BASE_URL}/v2/{endpoint_id}/status/{request_id}" + auth_keys = self._auth_candidates(status_api_key, status_api_key_fallback) - try: - async with get_authenticated_httpx_client( - timeout=self.timeout_seconds, - api_key_override=runpod_api_key, - ) as client: - response = await client.get(url) - response.raise_for_status() - payload = response.json() - except (httpx.HTTPError, ValueError) as exc: - log.debug("Failed to fetch worker for request %s: %s", request_id, exc) - return None + for auth_key in auth_keys: + try: + async with get_authenticated_httpx_client( + timeout=self.timeout_seconds, + api_key_override=auth_key, + ) as client: + response = await client.get(url) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as exc: + if exc.response is not None and exc.response.status_code == 401: + continue + log.debug( + "Failed to fetch worker for request %s: %s", + request_id, + exc, + ) + return None + except (httpx.HTTPError, ValueError) as exc: + log.debug("Failed to fetch worker for request %s: %s", request_id, exc) + return None + return None + + async def _fetch_metrics_payload( + self, + endpoint_id: str, + status_api_key: str, + status_api_key_fallback: Optional[str], + ) -> Optional[dict[str, Any]]: + url = f"{API_BASE_URL}/v1/{endpoint_id}/metrics" + auth_keys = self._auth_candidates(status_api_key, status_api_key_fallback) + + for auth_key in auth_keys: + try: + async with get_authenticated_httpx_client( + timeout=self.timeout_seconds, + api_key_override=auth_key, + ) as client: + response = await client.get(url) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as exc: + if exc.response is not None and exc.response.status_code == 401: + continue + log.debug( + "Failed to fetch endpoint metrics for %s: %s", endpoint_id, exc + ) + return None + except (httpx.HTTPError, ValueError) as exc: + log.debug( + "Failed to fetch endpoint metrics for %s: %s", endpoint_id, exc + ) + return None + + return None + + @staticmethod + def _worker_id_from_status_payload( + payload: Optional[dict[str, Any]], + ) -> Optional[str]: + if not payload: + return None worker_id = payload.get("workerId") if not worker_id: return None return str(worker_id) - async def _fetch_endpoint_logs( + @staticmethod + def _running_worker_ids_from_metrics( + payload: Optional[dict[str, Any]], + ) -> List[str]: + if not payload: + return [] + ready_workers = payload.get("readyWorkers") + if not isinstance(ready_workers, list): + return [] + return [str(worker) for worker in ready_workers if worker] + + @staticmethod + def _initializing_worker_count(payload: Optional[dict[str, Any]]) -> int: + if not payload: + return 0 + workers = payload.get("workers") + if not isinstance(workers, dict): + return 0 + initializing = workers.get("initializing", 0) + if isinstance(initializing, int): + return initializing + return 0 + + @staticmethod + def _auth_candidates( + primary_key: str, + fallback_key: Optional[str], + ) -> List[str]: + keys = [primary_key] + if fallback_key and fallback_key != primary_key: + keys.append(fallback_key) + return keys + + async def _fetch_pod_logs( self, - endpoint_id: str, - endpoint_ai_key: str, - end_utc: datetime, - start_utc: Optional[datetime] = None, + worker_id: str, + runpod_api_key: str, ) -> Optional[dict[str, Any]]: - """ - fetch endpoint logs for a given time range, defaulting to the fetcher - configured start time - updates start utc when we successfully fetch logs - """ - url = f"{API_BASE_URL}/v2/{endpoint_id}/logs" - if not start_utc: - start_utc = self.start_time - - log.debug(f"fetching logs for time range: {start_utc} to {end_utc}") - params = { - "from": _format_log_timestamp(start_utc), - "to": _format_log_timestamp(end_utc), - "page": 0, - "pageSize": 200, - } + url = f"{_resolve_hapi_base_url()}/v1/pod/{worker_id}/logs" try: async with get_authenticated_httpx_client( timeout=self.timeout_seconds, - api_key_override=endpoint_ai_key, + api_key_override=runpod_api_key, ) as client: - response = await client.get(url, params=params) + response = await client.get(url) response.raise_for_status() return response.json() except httpx.HTTPStatusError as exc: @@ -122,69 +284,108 @@ async def _fetch_endpoint_logs( if exc.response is not None: body_preview = (exc.response.text or "")[:500] log.debug( - "Failed to fetch endpoint logs for %s: %s | response_body=%s", - endpoint_id, + "Failed to fetch pod logs for %s: %s | response_body=%s", + worker_id, exc, body_preview, ) return None except (httpx.HTTPError, ValueError) as exc: - log.debug("Failed to fetch endpoint logs for %s: %s", endpoint_id, exc) + log.debug("Failed to fetch pod logs for %s: %s", worker_id, exc) return None - def _extract_lines(self, payload: dict[str, Any], end_time: datetime) -> List[str]: - """ - extract lines from a response payload from sls endpoint response - deduplicates based on already seen lines - """ - records = payload.get("data") - if not isinstance(records, list): + def _extract_lines(self, payload: dict[str, Any]) -> List[str]: + records = self._collect_records(payload) + if not records: return [] - max_seen_dt = self.start_time lines: List[str] = [] for record in records: - if isinstance(record, str): - stripped = record.strip() - if stripped and stripped not in self.seen: - self.seen.add(stripped) - lines.append(record) + if not isinstance(record, str): continue - if not isinstance(record, dict): + stripped = record.strip().replace("\\n", "") + if not stripped or stripped in self.seen: continue + self.seen.add(stripped) + lines.append(stripped) - line = ( - record.get("message") - or record.get("log") - or record.get("text") - or record.get("raw") - ) + return lines - dt = record.get("dt") - - if dt: - parsed = dateutil.parser.parse(dt) - if parsed.tzinfo is None: - parsed = parsed.replace(tzinfo=timezone.utc) # treat naive as UTC - else: - parsed = parsed.astimezone(timezone.utc) - max_seen_dt = max(parsed, max_seen_dt) - - if isinstance(line, str): - stripped = line.strip() - if stripped and stripped not in self.seen: - stripped = stripped.replace("\\n", "") - self.seen.add(stripped) - lines.append(stripped) - if lines: - # lines are returned in time descending order - lines.reverse() - if max_seen_dt > self.start_time: - self.fetched_until = max_seen_dt - else: - # not all logs have a timestamp, assume we should refetch - self.fetched_until = self.start_time + def _extract_initial_lines( + self, payload: dict[str, Any], request_id: str + ) -> List[str]: + records = self._collect_records(payload) + if not records: + return [] + + cutoff = self.start_time.timestamp() - self.lookback_seconds + lines: List[str] = [] + saw_recent_window_line = False + + for record in records: + if not isinstance(record, str): + continue + + stripped = record.strip().replace("\\n", "") + if not stripped: + continue + + if stripped in self.seen: + continue + self.seen.add(stripped) + + timestamp = self._parse_prefix_timestamp(stripped) + if timestamp is not None and timestamp.timestamp() < cutoff: + continue + + if timestamp is not None: + saw_recent_window_line = True + lines.append(stripped) + continue + + if request_id and request_id in stripped: + lines.append(stripped) + continue + + if saw_recent_window_line: + lines.append(stripped) + continue return lines + + def _set_worker_id(self, worker_id: str) -> None: + if self.worker_id == worker_id: + return + self.worker_id = worker_id + self.seen = set() + self.has_streamed_logs = False + self.has_primed_worker_logs = False + + @staticmethod + def _collect_records(payload: dict[str, Any]) -> List[Any]: + container_records = payload.get("container") + system_records = payload.get("system") + + records: list[Any] = [] + if isinstance(system_records, list): + records.extend(system_records) + if isinstance(container_records, list): + records.extend(container_records) + + return records + + @staticmethod + def _parse_prefix_timestamp(line: str) -> Optional[datetime]: + match = LOG_PREFIX_TIMESTAMP_RE.match(line) + if not match: + return None + + timestamp_text = match.group("timestamp") + normalized = timestamp_text.replace("Z", "+00:00") + + try: + return datetime.fromisoformat(normalized) + except ValueError: + return None diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index 01f6a529..eb2327fe 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -2,6 +2,8 @@ import json import logging import os +import re +from datetime import datetime, timezone from enum import Enum from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, Set @@ -30,11 +32,11 @@ from .environment import EnvironmentVars from .cpu import CpuInstanceType from .gpu import GpuGroup, GpuType - -from .request_logs import QBRequestLogFetcher from .network_volume import NetworkVolume, DataCenter, CPU_DATACENTERS +from .request_logs import QBRequestLogFetcher, QBRequestLogPhase from .template import KeyValuePair, PodTemplate from .resource_manager import ResourceManager +from ..credentials import get_api_key # Prefix applied to endpoint names during live provisioning @@ -54,6 +56,27 @@ def get_env_vars() -> Dict[str, str]: log = logging.getLogger(__name__) +POD_LOG_PREFIX_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?Z\s+") + + +def _is_prod_environment() -> bool: + env = os.getenv("RUNPOD_ENV") + if env: + return env.lower() == "prod" + api_base = os.getenv("RUNPOD_API_BASE_URL", "https://api.runpod.io") + return "api.runpod.io" in api_base or "api.runpod.ai" in api_base + + +def _normalize_stream_log_line(line: str) -> str: + normalized = line.strip() + if not normalized: + return "" + + if normalized.lower().startswith("worker log:"): + normalized = normalized.split(":", 1)[1].strip() + + normalized = POD_LOG_PREFIX_RE.sub("", normalized, count=1) + return normalized class ServerlessScalerType(Enum): @@ -222,24 +245,36 @@ def endpoint_url(self) -> str: async def _emit_endpoint_logs( self, fetcher: QBRequestLogFetcher, - ): + request_id: str, + ) -> Optional["QBRequestLogBatch"]: if self.type != ServerlessType.QB: - return + return None - if not self.id or not self.aiKey: - return + if not self.id: + return None + + pod_logs_api_key = get_api_key() + if not pod_logs_api_key: + return None + + status_api_key = self.aiKey or pod_logs_api_key batch = await fetcher.fetch_logs( endpoint_id=self.id, - endpoint_ai_key=self.aiKey, + request_id=request_id, + status_api_key=status_api_key, + pod_logs_api_key=pod_logs_api_key, + status_api_key_fallback=pod_logs_api_key, ) if not batch: - return False + return None if batch.lines: for line in batch.lines: print(f"worker log: {line}") + return batch + @field_serializer("scalerType") def serialize_scaler_type( self, value: Optional[ServerlessScalerType] @@ -1127,7 +1162,16 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": attempt = 0 job_status = Status.UNKNOWN last_status = job_status - fetcher = QBRequestLogFetcher() + fetcher = QBRequestLogFetcher(start_time=datetime.now(timezone.utc)) + last_log_state: ( + tuple[ + QBRequestLogPhase, + bool, + Optional[str], + ] + | None + ) = None + assigned_streaming_announced_worker: Optional[str] = None # Poll for job status while True: @@ -1147,28 +1191,91 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": log.info(f"{log_subgroup} | Status: {job_status}") attempt = 0 - await self._emit_endpoint_logs( + batch = await self._emit_endpoint_logs( fetcher=fetcher, + request_id=job.job_id, ) + if batch: + current_log_state = ( + batch.phase, + batch.matched_by_request_id, + batch.worker_id, + ) + state_changed = current_log_state != last_log_state + + if ( + batch.phase == QBRequestLogPhase.STREAMING + and batch.matched_by_request_id + and batch.worker_id + ): + if assigned_streaming_announced_worker != batch.worker_id: + log.info( + f"{log_subgroup} | Request assigned to worker {batch.worker_id}, streaming pod logs" + ) + assigned_streaming_announced_worker = batch.worker_id + elif state_changed: + if batch.phase == QBRequestLogPhase.WAITING_FOR_WORKER: + log.info( + f"{log_subgroup} | No workers available; check that your endpoint is properly configured and/or GPU availability for selected GPUs" + ) + elif ( + batch.phase + == QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION + ): + if batch.matched_by_request_id and batch.worker_id: + log.info( + f"{log_subgroup} | Request assigned to worker {batch.worker_id}, waiting for worker initialization/image pull logs" + ) + elif batch.worker_id: + log.info( + f"{log_subgroup} | Worker capacity detected, waiting for request assignment and worker initialization/image pull" + ) + else: + log.info( + f"{log_subgroup} | Waiting for worker initialization/image pull" + ) + elif batch.phase == QBRequestLogPhase.STREAMING: + log.info( + f"{log_subgroup} | Streaming endpoint startup logs while waiting for request assignment" + ) + + last_log_state = current_log_state + last_status = job_status # Adjust polling pace appropriately current_pace = get_backoff_delay(attempt, max_seconds=5) if job_status in ("COMPLETED", "FAILED", "CANCELLED"): + for _ in range(2): + await self._emit_endpoint_logs( + fetcher=fetcher, + request_id=job.job_id, + ) response = await asyncio.to_thread(job._fetch_job) output = response.get("output") if isinstance(output, dict): stdout = output.get("stdout") if isinstance(stdout, str): + seen_normalized = { + normalized + for line in fetcher.seen + if (normalized := _normalize_stream_log_line(line)) + } kept = [] for raw in stdout.splitlines(): raw = raw.strip() if not raw: continue - if raw in fetcher.seen: + + normalized_raw = _normalize_stream_log_line(raw) + if not normalized_raw: + continue + if normalized_raw in seen_normalized: continue + + seen_normalized.add(normalized_raw) kept.append(raw) output["stdout"] = "\n".join(kept) return JobOutput(**response) diff --git a/tests/unit/resources/test_request_logs.py b/tests/unit/resources/test_request_logs.py index 2cef0656..9619d3bf 100644 --- a/tests/unit/resources/test_request_logs.py +++ b/tests/unit/resources/test_request_logs.py @@ -5,8 +5,8 @@ import pytest from runpod_flash.core.resources.request_logs import ( - QBRequestLogBatch, QBRequestLogFetcher, + QBRequestLogPhase, ) @@ -18,24 +18,22 @@ def _make_async_client(mock_client: MagicMock) -> MagicMock: @pytest.mark.asyncio -async def test_fetch_logs_returns_batch_with_chronological_unique_lines(): - fetcher = QBRequestLogFetcher( - max_lines=5, - start_time=datetime(2026, 1, 1, tzinfo=timezone.utc), - ) - - logs_response = MagicMock() - logs_response.raise_for_status = MagicMock() - logs_response.json.return_value = { - "data": [ - {"message": "line two", "dt": "2026-01-01T00:00:02Z"}, - {"message": "line one", "dt": "2026-01-01T00:00:01Z"}, - {"message": "line two", "dt": "2026-01-01T00:00:02Z"}, - ] +async def test_waiting_for_workers_when_none_running_or_initializing(): + fetcher = QBRequestLogFetcher(start_time=datetime(2026, 1, 1, tzinfo=timezone.utc)) + + status_response = MagicMock() + status_response.raise_for_status = MagicMock() + status_response.json.return_value = {"status": "IN_QUEUE"} + + metrics_response = MagicMock() + metrics_response.raise_for_status = MagicMock() + metrics_response.json.return_value = { + "workers": {"initializing": 0}, + "readyWorkers": [], } mock_client = MagicMock() - mock_client.get = AsyncMock(return_value=logs_response) + mock_client.get = AsyncMock(side_effect=[status_response, metrics_response]) with patch( "runpod_flash.core.resources.request_logs.get_authenticated_httpx_client", @@ -43,99 +41,207 @@ async def test_fetch_logs_returns_batch_with_chronological_unique_lines(): ): batch = await fetcher.fetch_logs( endpoint_id="endpoint-1", - endpoint_ai_key="ai-key", + request_id="request-1", + status_api_key="status-key", + pod_logs_api_key="runpod-key", ) - assert isinstance(batch, QBRequestLogBatch) - assert batch is not None - assert batch.matched_by_request_id is False + assert batch.phase == QBRequestLogPhase.WAITING_FOR_WORKER assert batch.worker_id is None - assert batch.lines == ["line one", "line two"] + assert batch.lines == [] @pytest.mark.asyncio -async def test_fetch_logs_dedupes_seen_lines_across_calls(): - fetcher = QBRequestLogFetcher( - max_lines=5, - start_time=datetime(2026, 1, 1, tzinfo=timezone.utc), - ) +async def test_waiting_for_worker_initialization_when_workers_initializing(): + fetcher = QBRequestLogFetcher() - first_logs_response = MagicMock() - first_logs_response.raise_for_status = MagicMock() - first_logs_response.json.return_value = { - "data": [ - {"message": "line two"}, - {"message": "line one"}, - ] - } + status_response = MagicMock() + status_response.raise_for_status = MagicMock() + status_response.json.return_value = {"status": "IN_QUEUE"} - second_logs_response = MagicMock() - second_logs_response.raise_for_status = MagicMock() - second_logs_response.json.return_value = { - "data": [ - {"message": "line three"}, - {"message": "line two"}, - ] + metrics_response = MagicMock() + metrics_response.raise_for_status = MagicMock() + metrics_response.json.return_value = { + "workers": {"initializing": 1}, + "readyWorkers": [], } mock_client = MagicMock() - mock_client.get = AsyncMock(side_effect=[first_logs_response, second_logs_response]) + mock_client.get = AsyncMock(side_effect=[status_response, metrics_response]) with patch( "runpod_flash.core.resources.request_logs.get_authenticated_httpx_client", return_value=_make_async_client(mock_client), ): - first_batch = await fetcher.fetch_logs("endpoint-1", "ai-key") - second_batch = await fetcher.fetch_logs("endpoint-1", "ai-key") + batch = await fetcher.fetch_logs( + endpoint_id="endpoint-1", + request_id="request-1", + status_api_key="status-key", + pod_logs_api_key="runpod-key", + ) - assert first_batch is not None - assert first_batch.lines == ["line one", "line two"] - assert second_batch is not None - assert second_batch.lines == ["line three"] + assert batch.phase == QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION + assert batch.worker_id is None + assert batch.lines == [] @pytest.mark.asyncio -async def test_fetch_logs_returns_none_on_http_error(): - fetcher = QBRequestLogFetcher() +async def test_primes_existing_worker_logs_then_streams_new_lines(): + fetcher = QBRequestLogFetcher( + start_time=datetime(2026, 4, 2, 17, 14, 7, tzinfo=timezone.utc), + lookback_seconds=20, + ) + + status_1 = MagicMock() + status_1.raise_for_status = MagicMock() + status_1.json.return_value = {"status": "IN_QUEUE"} + + metrics_1 = MagicMock() + metrics_1.raise_for_status = MagicMock() + metrics_1.json.return_value = { + "workers": {"initializing": 0, "running": 1}, + "readyWorkers": ["worker-running-1"], + } + + old_logs = MagicMock() + old_logs.raise_for_status = MagicMock() + old_logs.json.return_value = { + "container": ["2026-04-02T17:14:05Z create container"], + "system": [ + "2026-04-02T16:38:18Z very old line", + '{"requestId": "request-1", "message": "Started.", "level": "INFO"}', + "ae1225 smoke: worker started", + ], + } + + status_2 = MagicMock() + status_2.raise_for_status = MagicMock() + status_2.json.return_value = {"status": "IN_QUEUE"} + + metrics_2 = MagicMock() + metrics_2.raise_for_status = MagicMock() + metrics_2.json.return_value = { + "workers": {"initializing": 0, "running": 1}, + "readyWorkers": ["worker-running-1"], + } + + new_logs = MagicMock() + new_logs.raise_for_status = MagicMock() + new_logs.json.return_value = { + "container": ["2026-04-02T17:14:08Z start container"], + "system": ["2026-04-02T17:14:05Z create container"], + } mock_client = MagicMock() - mock_client.get = AsyncMock(side_effect=httpx.ReadTimeout("timed out")) + mock_client.get = AsyncMock( + side_effect=[status_1, metrics_1, old_logs, status_2, metrics_2, new_logs] + ) with patch( "runpod_flash.core.resources.request_logs.get_authenticated_httpx_client", return_value=_make_async_client(mock_client), ): - batch = await fetcher.fetch_logs( + first_batch = await fetcher.fetch_logs( endpoint_id="endpoint-1", - endpoint_ai_key="ai-key", + request_id="request-1", + status_api_key="endpoint-ai-key", + pod_logs_api_key="runpod-key", ) + second_batch = await fetcher.fetch_logs( + endpoint_id="endpoint-1", + request_id="request-1", + status_api_key="endpoint-ai-key", + pod_logs_api_key="runpod-key", + ) + + assert first_batch.worker_id == "worker-running-1" + assert first_batch.phase == QBRequestLogPhase.STREAMING + assert first_batch.lines == [ + '{"requestId": "request-1", "message": "Started.", "level": "INFO"}', + "2026-04-02T17:14:05Z create container", + ] - assert batch is None + assert second_batch.worker_id == "worker-running-1" + assert second_batch.phase == QBRequestLogPhase.STREAMING + assert second_batch.lines == ["2026-04-02T17:14:08Z start container"] + assert second_batch.matched_by_request_id is False @pytest.mark.asyncio -async def test_endpoint_log_fetch_uses_v2_with_aikey_bearer_auth(): +async def test_status_uses_fallback_key_on_401(): fetcher = QBRequestLogFetcher() - logs_response = MagicMock() - logs_response.raise_for_status = MagicMock() - logs_response.json.return_value = {"data": []} + unauthorized = httpx.Response( + status_code=401, + request=httpx.Request( + "GET", "https://api.runpod.ai/v2/endpoint-1/status/request-1" + ), + ) + + status_response = MagicMock() + status_response.raise_for_status = MagicMock() + status_response.json.return_value = {"workerId": "worker-123"} + + metrics_response = MagicMock() + metrics_response.raise_for_status = MagicMock() + metrics_response.json.return_value = { + "workers": {"initializing": 0, "running": 1}, + "readyWorkers": ["worker-123"], + } + + pod_logs_response = MagicMock() + pod_logs_response.raise_for_status = MagicMock() + pod_logs_response.json.return_value = {"container": ["old"], "system": []} + + status_response_2 = MagicMock() + status_response_2.raise_for_status = MagicMock() + status_response_2.json.return_value = {"workerId": "worker-123"} + + metrics_response_2 = MagicMock() + metrics_response_2.raise_for_status = MagicMock() + metrics_response_2.json.return_value = { + "workers": {"initializing": 0, "running": 1}, + "readyWorkers": ["worker-123"], + } + + pod_logs_response_2 = MagicMock() + pod_logs_response_2.raise_for_status = MagicMock() + pod_logs_response_2.json.return_value = {"container": ["new"], "system": []} mock_client = MagicMock() - mock_client.get = AsyncMock(return_value=logs_response) + mock_client.get = AsyncMock( + side_effect=[ + httpx.HTTPStatusError( + "unauthorized", request=unauthorized.request, response=unauthorized + ), + status_response, + metrics_response, + pod_logs_response, + status_response_2, + metrics_response_2, + pod_logs_response_2, + ] + ) with patch( "runpod_flash.core.resources.request_logs.get_authenticated_httpx_client", return_value=_make_async_client(mock_client), - ) as mock_client_factory: + ): await fetcher.fetch_logs( endpoint_id="endpoint-1", - endpoint_ai_key="ai-key-123", + request_id="request-1", + status_api_key="endpoint-ai-key", + pod_logs_api_key="runpod-key", + status_api_key_fallback="runpod-key", + ) + second_batch = await fetcher.fetch_logs( + endpoint_id="endpoint-1", + request_id="request-1", + status_api_key="endpoint-ai-key", + pod_logs_api_key="runpod-key", + status_api_key_fallback="runpod-key", ) - log_call = mock_client.get.await_args_list[0] - assert log_call.args[0] == "https://api.runpod.ai/v2/endpoint-1/logs" - assert "from" in log_call.kwargs["params"] - assert "to" in log_call.kwargs["params"] - assert "aikey" not in log_call.kwargs["params"] - assert mock_client_factory.call_args.kwargs["api_key_override"] == "ai-key-123" + assert second_batch.worker_id == "worker-123" + assert second_batch.phase == QBRequestLogPhase.STREAMING + assert second_batch.lines == ["new"] diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py index b1f86c9f..534a9bdf 100644 --- a/tests/unit/resources/test_serverless.py +++ b/tests/unit/resources/test_serverless.py @@ -23,7 +23,10 @@ from runpod_flash.core.resources.gpu import GpuGroup from runpod_flash.core.resources.cpu import CpuInstanceType from runpod_flash.core.resources.network_volume import NetworkVolume, DataCenter -from runpod_flash.core.resources.request_logs import QBRequestLogBatch +from runpod_flash.core.resources.request_logs import ( + QBRequestLogBatch, + QBRequestLogPhase, +) from runpod_flash.core.resources.template import PodTemplate @@ -1116,6 +1119,58 @@ async def test_run_async_success(self): assert result.id == "job-123" assert result.status == "COMPLETED" + @pytest.mark.asyncio + async def test_run_async_dedupes_stdout_against_streamed_pod_logs(self): + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "endpoint-ai-key" + + mock_job = MagicMock() + mock_job.job_id = "job-123" + mock_job.status.side_effect = ["IN_QUEUE", "COMPLETED"] + mock_job._fetch_job.return_value = { + "id": "job-123", + "workerId": "worker-456", + "status": "COMPLETED", + "delayTime": 1000, + "executionTime": 2000, + "output": { + "stdout": "2026-04-02T18:18:10.165152015Z 2026-04-02 18:18:10,164 | DEBUG | aiohttp_retry | client.py:110 | Attempt 1 out of 3\n" + "2026-04-02 18:18:10,164 | DEBUG | aiohttp_retry | client.py:110 | Attempt 1 out of 3\n" + "unique stdout line" + }, + } + + mock_endpoint = MagicMock() + mock_endpoint.run.return_value = mock_job + + async def fake_emit(*, fetcher, request_id): + fetcher.seen.add( + "2026-04-02T18:18:10.165152015Z 2026-04-02 18:18:10,164 | DEBUG | aiohttp_retry | client.py:110 | Attempt 1 out of 3" + ) + return None + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + with patch("asyncio.sleep"): + with patch.object( + ServerlessResource, + "_emit_endpoint_logs", + new=AsyncMock(side_effect=fake_emit), + ): + with patch( + "runpod_flash.core.resources.serverless.get_api_key", + return_value="runpod-key-123", + ): + result = await serverless.run({"input": "test"}) + + assert isinstance(result, JobOutput) + assert result.output["stdout"] == "unique stdout line" + @pytest.mark.asyncio async def test_run_async_fetches_endpoint_logs_while_polling(self): """Test run async polls endpoint logs every cycle until completion.""" @@ -1157,9 +1212,81 @@ async def test_run_async_fetches_endpoint_logs_while_polling(self): ) as mock_emit_logs: await serverless.run({"input": "test"}) - assert mock_emit_logs.await_count == 4 + assert mock_emit_logs.await_count == 6 fetchers = [call.kwargs["fetcher"] for call in mock_emit_logs.await_args_list] assert len({id(fetcher) for fetcher in fetchers}) == 1 + request_ids = [ + call.kwargs["request_id"] for call in mock_emit_logs.await_args_list + ] + assert all(request_id == "job-123" for request_id in request_ids) + + @pytest.mark.asyncio + async def test_run_async_announces_assigned_worker_streaming_once(self): + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "ai-key-123" + + mock_job = MagicMock() + mock_job.job_id = "job-123" + mock_job.status.side_effect = [ + "IN_PROGRESS", + "IN_PROGRESS", + "IN_PROGRESS", + "COMPLETED", + ] + mock_job._fetch_job.return_value = { + "id": "job-123", + "workerId": "worker-456", + "status": "COMPLETED", + "delayTime": 1000, + "executionTime": 2000, + "output": {"result": "success"}, + } + + mock_endpoint = MagicMock() + mock_endpoint.run.return_value = mock_job + + assigned_batch = QBRequestLogBatch( + worker_id="worker-456", + lines=[], + matched_by_request_id=True, + phase=QBRequestLogPhase.STREAMING, + ) + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + with patch("asyncio.sleep"): + with patch.object( + ServerlessResource, + "_emit_endpoint_logs", + new=AsyncMock( + side_effect=[ + assigned_batch, + assigned_batch, + assigned_batch, + assigned_batch, + assigned_batch, + assigned_batch, + ] + ), + ): + with patch( + "runpod_flash.core.resources.serverless.log.info" + ) as mock_log_info: + await serverless.run({"input": "test"}) + + assigned_messages = [ + str(call.args[0]) + for call in mock_log_info.call_args_list + if call.args + and "Request assigned to worker worker-456, streaming pod logs" + in str(call.args[0]) + ] + assert len(assigned_messages) == 1 @pytest.mark.asyncio async def test_emit_endpoint_logs_prints_worker_lines(self): @@ -1175,22 +1302,35 @@ async def test_emit_endpoint_logs_prints_worker_lines(self): worker_id=None, lines=["line-a", "line-b"], matched_by_request_id=False, + phase=QBRequestLogPhase.STREAMING, ) ) - with patch("builtins.print") as mock_print: - await serverless._emit_endpoint_logs(fetcher=mock_fetcher) + with patch( + "runpod_flash.core.resources.serverless.get_api_key", + return_value="runpod-key-123", + ): + with patch("builtins.print") as mock_print: + batch = await serverless._emit_endpoint_logs( + fetcher=mock_fetcher, + request_id="job-123", + ) mock_fetcher.fetch_logs.assert_awaited_once_with( endpoint_id="endpoint-123", - endpoint_ai_key="endpoint-ai-key", + request_id="job-123", + status_api_key="endpoint-ai-key", + pod_logs_api_key="runpod-key-123", + status_api_key_fallback="runpod-key-123", ) + assert batch is not None + assert batch.phase == QBRequestLogPhase.STREAMING mock_print.assert_any_call("worker log: line-a") mock_print.assert_any_call("worker log: line-b") @pytest.mark.asyncio async def test_emit_endpoint_logs_skips_when_missing_required_fields(self): - """Endpoint log fetch is skipped unless QB endpoint has id and aiKey.""" + """Endpoint log fetch is skipped unless QB endpoint has id and API key.""" serverless = ServerlessResource(name="test") mock_fetcher = MagicMock() mock_fetcher.fetch_logs = AsyncMock(return_value=None) @@ -1198,15 +1338,36 @@ async def test_emit_endpoint_logs_skips_when_missing_required_fields(self): serverless.type = ServerlessType.QB serverless.id = None serverless.aiKey = "endpoint-ai-key" - await serverless._emit_endpoint_logs(fetcher=mock_fetcher) + with patch( + "runpod_flash.core.resources.serverless.get_api_key", + return_value="runpod-key-123", + ): + await serverless._emit_endpoint_logs( + fetcher=mock_fetcher, + request_id="job-123", + ) serverless.id = "endpoint-123" serverless.aiKey = None - await serverless._emit_endpoint_logs(fetcher=mock_fetcher) + with patch( + "runpod_flash.core.resources.serverless.get_api_key", + return_value=None, + ): + await serverless._emit_endpoint_logs( + fetcher=mock_fetcher, + request_id="job-123", + ) serverless.type = ServerlessType.LB serverless.aiKey = "endpoint-ai-key" - await serverless._emit_endpoint_logs(fetcher=mock_fetcher) + with patch( + "runpod_flash.core.resources.serverless.get_api_key", + return_value="runpod-key-123", + ): + await serverless._emit_endpoint_logs( + fetcher=mock_fetcher, + request_id="job-123", + ) mock_fetcher.fetch_logs.assert_not_awaited() From 18e4bc0e4dcdf218919bccb146241ee6fb74e117 Mon Sep 17 00:00:00 2001 From: jhcipar Date: Tue, 7 Apr 2026 14:23:13 -0400 Subject: [PATCH 4/6] chore: more descriptive endpoint log polling --- .../core/resources/request_logs.py | 55 +++- src/runpod_flash/core/resources/serverless.py | 49 ++- .../worker_availability_diagnostic.py | 280 ++++++++++++++++++ tests/unit/resources/test_serverless.py | 177 +++++++++++ .../test_worker_availability_diagnostic.py | 132 +++++++++ 5 files changed, 685 insertions(+), 8 deletions(-) create mode 100644 src/runpod_flash/core/resources/worker_availability_diagnostic.py create mode 100644 tests/unit/resources/test_worker_availability_diagnostic.py diff --git a/src/runpod_flash/core/resources/request_logs.py b/src/runpod_flash/core/resources/request_logs.py index e2bc2d88..5a827c87 100644 --- a/src/runpod_flash/core/resources/request_logs.py +++ b/src/runpod_flash/core/resources/request_logs.py @@ -2,6 +2,7 @@ import os import re from dataclasses import dataclass +from dataclasses import field from datetime import datetime, timezone from enum import Enum from typing import Any, List, Optional @@ -45,6 +46,8 @@ class QBRequestLogBatch: matched_by_request_id: bool worker_id: Optional[str] phase: QBRequestLogPhase + worker_metrics: dict[str, int] = field(default_factory=dict) + ready_worker_ids: List[str] = field(default_factory=list) class QBRequestLogFetcher: @@ -89,6 +92,8 @@ async def fetch_logs( ) running_worker_ids = self._running_worker_ids_from_metrics(metrics_payload) initializing_workers = self._initializing_worker_count(metrics_payload) + worker_metrics = self._worker_metrics_snapshot(metrics_payload) + ready_worker_ids = self._ready_worker_ids_from_metrics(metrics_payload) matched_by_request_id = False if assigned_worker_id: @@ -108,6 +113,8 @@ async def fetch_logs( matched_by_request_id=False, worker_id=None, phase=phase, + worker_metrics=worker_metrics, + ready_worker_ids=ready_worker_ids, ) logs_payload = await self._fetch_pod_logs( @@ -120,6 +127,8 @@ async def fetch_logs( matched_by_request_id=matched_by_request_id, worker_id=self.worker_id, phase=QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION, + worker_metrics=worker_metrics, + ready_worker_ids=ready_worker_ids, ) if not self.has_primed_worker_logs: @@ -136,6 +145,8 @@ async def fetch_logs( if self.has_streamed_logs else QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION ), + worker_metrics=worker_metrics, + ready_worker_ids=ready_worker_ids, ) lines = self._extract_lines(logs_payload) @@ -151,6 +162,8 @@ async def fetch_logs( matched_by_request_id=matched_by_request_id, worker_id=self.worker_id, phase=phase, + worker_metrics=worker_metrics, + ready_worker_ids=ready_worker_ids, ) async def _fetch_status_payload( @@ -193,8 +206,8 @@ async def _fetch_metrics_payload( status_api_key: str, status_api_key_fallback: Optional[str], ) -> Optional[dict[str, Any]]: - url = f"{API_BASE_URL}/v1/{endpoint_id}/metrics" auth_keys = self._auth_candidates(status_api_key, status_api_key_fallback) + url = f"{API_BASE_URL}/v2/{endpoint_id}/metrics" for auth_key in auth_keys: try: @@ -209,12 +222,18 @@ async def _fetch_metrics_payload( if exc.response is not None and exc.response.status_code == 401: continue log.debug( - "Failed to fetch endpoint metrics for %s: %s", endpoint_id, exc + "Failed to fetch endpoint metrics for %s via %s: %s", + endpoint_id, + url, + exc, ) return None except (httpx.HTTPError, ValueError) as exc: log.debug( - "Failed to fetch endpoint metrics for %s: %s", endpoint_id, exc + "Failed to fetch endpoint metrics for %s via %s: %s", + endpoint_id, + url, + exc, ) return None @@ -242,6 +261,36 @@ def _running_worker_ids_from_metrics( return [] return [str(worker) for worker in ready_workers if worker] + @staticmethod + def _ready_worker_ids_from_metrics(payload: Optional[dict[str, Any]]) -> List[str]: + if not payload: + return [] + ready_workers = payload.get("readyWorkers") + if not isinstance(ready_workers, list): + return [] + return [str(worker) for worker in ready_workers if worker] + + @staticmethod + def _worker_metrics_snapshot(payload: Optional[dict[str, Any]]) -> dict[str, int]: + base = { + "ready": 0, + "running": 0, + "idle": 0, + "initializing": 0, + "throttled": 0, + "unhealthy": 0, + } + if not payload: + return base + workers = payload.get("workers") + if not isinstance(workers, dict): + return base + for key in base: + value = workers.get(key) + if isinstance(value, int): + base[key] = value + return base + @staticmethod def _initializing_worker_count(payload: Optional[dict[str, Any]]) -> int: if not payload: diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index eb2327fe..f2e68a0a 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -34,6 +34,7 @@ from .gpu import GpuGroup, GpuType from .network_volume import NetworkVolume, DataCenter, CPU_DATACENTERS from .request_logs import QBRequestLogFetcher, QBRequestLogPhase +from .worker_availability_diagnostic import WorkerAvailabilityDiagnostic from .template import KeyValuePair, PodTemplate from .resource_manager import ResourceManager from ..credentials import get_api_key @@ -1172,10 +1173,14 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": | None ) = None assigned_streaming_announced_worker: Optional[str] = None + worker_availability_diagnostic = WorkerAvailabilityDiagnostic() + repeated_no_worker_message: Optional[str] = None + waiting_update_count = 0 # Poll for job status while True: await asyncio.sleep(current_pace) + emit_regular_update = False # Check job status job_status = await asyncio.to_thread(job.status) @@ -1183,9 +1188,9 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": if last_status == job_status: # nothing changed, increase the gap attempt += 1 - indicator = "." * (attempt // 2) if attempt % 2 == 0 else "" - if indicator: - log.info(f"{log_subgroup} | {indicator}") + if job_status != "IN_PROGRESS" and attempt % 2 == 0: + emit_regular_update = True + log.info(f"{log_subgroup} | {'.' * (attempt // 2)}") else: # status changed, reset the gap log.info(f"{log_subgroup} | Status: {job_status}") @@ -1209,6 +1214,8 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": and batch.matched_by_request_id and batch.worker_id ): + repeated_no_worker_message = None + waiting_update_count = 0 if assigned_streaming_announced_worker != batch.worker_id: log.info( f"{log_subgroup} | Request assigned to worker {batch.worker_id}, streaming pod logs" @@ -1216,13 +1223,25 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": assigned_streaming_announced_worker = batch.worker_id elif state_changed: if batch.phase == QBRequestLogPhase.WAITING_FOR_WORKER: - log.info( - f"{log_subgroup} | No workers available; check that your endpoint is properly configured and/or GPU availability for selected GPUs" + diagnostic = await worker_availability_diagnostic.diagnose( + self, + worker_metrics=batch.worker_metrics, ) + log.info(f"{log_subgroup} | {diagnostic.message}") + if diagnostic.reason in ( + "no_gpu_availability", + "workers_throttled", + ): + repeated_no_worker_message = diagnostic.message + else: + repeated_no_worker_message = None + waiting_update_count = 0 elif ( batch.phase == QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION ): + repeated_no_worker_message = None + waiting_update_count = 0 if batch.matched_by_request_id and batch.worker_id: log.info( f"{log_subgroup} | Request assigned to worker {batch.worker_id}, waiting for worker initialization/image pull logs" @@ -1236,12 +1255,32 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": f"{log_subgroup} | Waiting for worker initialization/image pull" ) elif batch.phase == QBRequestLogPhase.STREAMING: + repeated_no_worker_message = None + waiting_update_count = 0 log.info( f"{log_subgroup} | Streaming endpoint startup logs while waiting for request assignment" ) last_log_state = current_log_state + if emit_regular_update: + waiting_update_count += 1 + if waiting_update_count % 5 == 0: + worker_state = ( + batch.worker_id if batch and batch.worker_id else "None" + ) + worker_metrics = (batch.worker_metrics if batch else {}) or {} + assignment_state = ( + "assigned" + if batch and batch.matched_by_request_id + else "unassigned" + ) + log.info( + f"{log_subgroup} | Waiting for request: endpoint metrics: worker={worker_state}, assignment={assignment_state}, status={job_status}, workers={{ready:{worker_metrics.get('ready', 0)}, running:{worker_metrics.get('running', 0)}, idle:{worker_metrics.get('idle', 0)}, initializing:{worker_metrics.get('initializing', 0)}, throttled:{worker_metrics.get('throttled', 0)}, unhealthy:{worker_metrics.get('unhealthy', 0)}}}, readyWorkers={batch.ready_worker_ids if batch else []}" + ) + if repeated_no_worker_message: + log.info(f"{log_subgroup} | {repeated_no_worker_message}") + last_status = job_status # Adjust polling pace appropriately diff --git a/src/runpod_flash/core/resources/worker_availability_diagnostic.py b/src/runpod_flash/core/resources/worker_availability_diagnostic.py new file mode 100644 index 00000000..a4ea2bf7 --- /dev/null +++ b/src/runpod_flash/core/resources/worker_availability_diagnostic.py @@ -0,0 +1,280 @@ +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING + +from ..api.runpod import RunpodGraphQLClient + +if TYPE_CHECKING: + from .serverless import ServerlessResource + + +log = logging.getLogger(__name__) + + +GPU_STOCK_QUERY = """ +query ServerlessGpuTypes($lowestPriceInput: GpuLowestPriceInput, $gpuTypesInput: GpuTypeFilter) { + gpuTypes(input: $gpuTypesInput) { + id + displayName + lowestPrice(input: $lowestPriceInput) { + stockStatus + __typename + } + __typename + } +} +""" + + +CPU_STOCK_QUERY = """ +query SecureCpuTypes($cpuFlavorInput: CpuFlavorInput, $specificsInput: SpecificsInput) { + cpuFlavors(input: $cpuFlavorInput) { + id + specifics(input: $specificsInput) { + stockStatus + __typename + } + __typename + } +} +""" + + +@dataclass +class WorkerAvailabilityResult: + message: str + has_availability: Optional[bool] + reason: str + + +class WorkerAvailabilityDiagnostic: + async def diagnose( + self, + resource: "ServerlessResource", + worker_metrics: Optional[Dict[str, int]] = None, + ) -> WorkerAvailabilityResult: + if (resource.workersMax or 0) == 0: + return WorkerAvailabilityResult( + message="No compute available for your chosen configuration: your max workers are currently set to 0.", + has_availability=False, + reason="workers_max_zero", + ) + + compute_kind, compute_choice = self._selected_compute(resource) + if not compute_choice: + return WorkerAvailabilityResult( + message="No compute available for your chosen configuration.", + has_availability=None, + reason="no_compute_selected", + ) + + throttled_workers = (worker_metrics or {}).get("throttled", 0) + if throttled_workers > 0: + return WorkerAvailabilityResult( + message=( + f"Workers are currently throttled on endpoint for selected {compute_kind} {compute_choice}. " + "Consider raising max workers or changing gpu type." + ), + has_availability=True, + reason="workers_throttled", + ) + + locations = self._selected_locations(resource) + + if compute_kind == "gpu": + availability_by_location = await self._gpu_availability( + gpu_id=compute_choice, + gpu_count=resource.gpuCount or 1, + locations=locations, + ) + return self._build_message( + compute_kind="gpu", + compute_choice=compute_choice, + locations=locations, + availability_by_location=availability_by_location, + include_available_signal=True, + ) + + if compute_kind == "cpu": + availability_by_location = await self._cpu_availability( + instance_id=compute_choice, + locations=locations, + ) + return self._build_message( + compute_kind="cpu", + compute_choice=compute_choice, + locations=locations, + availability_by_location=availability_by_location, + include_available_signal=False, + ) + + return WorkerAvailabilityResult( + message="No compute available for your chosen configuration.", + has_availability=None, + reason="unknown", + ) + + def _build_message( + self, + compute_kind: str, + compute_choice: str, + locations: List[str], + availability_by_location: Dict[str, Optional[str]], + include_available_signal: bool, + ) -> WorkerAvailabilityResult: + has_availability = any(status for status in availability_by_location.values()) + + if not has_availability: + selected_locations = ", ".join(locations) if locations else "all locations" + return WorkerAvailabilityResult( + message=( + f"No workers available on endpoint: no {compute_kind} availability for {compute_kind} type {compute_choice} " + f"in selected locations ({selected_locations})." + ), + has_availability=False, + reason=f"no_{compute_kind}_availability", + ) + + if include_available_signal: + signal = self._summarize_stock_signal(availability_by_location) + return WorkerAvailabilityResult( + message=( + f"No workers available right now. Current availability signal " + f"for selected gpu {compute_choice}: {signal}." + ), + has_availability=True, + reason="gpu_has_availability", + ) + + return WorkerAvailabilityResult( + message=( + f"No workers available right now for selected {compute_kind} " + f"{compute_choice}." + ), + has_availability=True, + reason=f"{compute_kind}_has_availability", + ) + + async def _gpu_availability( + self, + gpu_id: str, + gpu_count: int, + locations: List[str], + ) -> Dict[str, Optional[str]]: + location_inputs = locations or [None] + availability_by_location: Dict[str, Optional[str]] = {} + + async with RunpodGraphQLClient() as client: + for location in location_inputs: + variables = { + "gpuTypesInput": {"ids": [gpu_id]}, + "lowestPriceInput": { + "dataCenterId": location, + "gpuCount": gpu_count, + "secureCloud": True, + "includeAiApi": True, + "allowedCudaVersions": [], + "compliance": [], + }, + } + key = location or "global" + try: + result = await client._execute_graphql(GPU_STOCK_QUERY, variables) + gpu_types = result.get("gpuTypes") or [] + first = gpu_types[0] if gpu_types else {} + lowest = first.get("lowestPrice") if isinstance(first, dict) else {} + status = ( + lowest.get("stockStatus") if isinstance(lowest, dict) else None + ) + availability_by_location[key] = status + except Exception as exc: + log.debug("GPU availability query failed for %s: %s", key, exc) + availability_by_location[key] = None + + return availability_by_location + + async def _cpu_availability( + self, + instance_id: str, + locations: List[str], + ) -> Dict[str, Optional[str]]: + flavor_id = self._cpu_flavor_id(instance_id) + if not flavor_id: + return {loc: None for loc in (locations or ["global"])} + + location_inputs = locations or [""] + availability_by_location: Dict[str, Optional[str]] = {} + + async with RunpodGraphQLClient() as client: + for location in location_inputs: + variables = { + "cpuFlavorInput": {"id": flavor_id}, + "specificsInput": { + "dataCenterId": location, + "instanceId": instance_id, + }, + } + key = location or "global" + try: + result = await client._execute_graphql(CPU_STOCK_QUERY, variables) + cpu_flavors = result.get("cpuFlavors") or [] + first = cpu_flavors[0] if cpu_flavors else {} + specifics = ( + first.get("specifics") if isinstance(first, dict) else {} + ) + status = ( + specifics.get("stockStatus") + if isinstance(specifics, dict) + else None + ) + availability_by_location[key] = status + except Exception as exc: + log.debug("CPU availability query failed for %s: %s", key, exc) + availability_by_location[key] = None + + return availability_by_location + + @staticmethod + def _selected_compute(resource: "ServerlessResource") -> Tuple[str, Optional[str]]: + if resource.instanceIds: + first_instance = resource.instanceIds[0] + choice = ( + first_instance.value + if hasattr(first_instance, "value") + else str(first_instance) + ) + return "cpu", choice + + gpu_ids = [ + part.strip() for part in (resource.gpuIds or "").split(",") if part.strip() + ] + if gpu_ids: + return "gpu", gpu_ids[0] + + return "unknown", None + + @staticmethod + def _selected_locations(resource: "ServerlessResource") -> List[str]: + return [ + part.strip() + for part in (resource.locations or "").split(",") + if part.strip() + ] + + @staticmethod + def _cpu_flavor_id(instance_id: str) -> Optional[str]: + if "-" not in instance_id: + return None + return instance_id.split("-", 1)[0] + + @staticmethod + def _summarize_stock_signal( + availability_by_location: Dict[str, Optional[str]], + ) -> str: + non_empty = [status for status in availability_by_location.values() if status] + if not non_empty: + return "unknown" + + priority = {"High": 3, "Medium": 2, "Low": 1} + best = max(non_empty, key=lambda status: priority.get(status, 0)) + return best diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py index 534a9bdf..b97b1d1d 100644 --- a/tests/unit/resources/test_serverless.py +++ b/tests/unit/resources/test_serverless.py @@ -27,6 +27,9 @@ QBRequestLogBatch, QBRequestLogPhase, ) +from runpod_flash.core.resources.worker_availability_diagnostic import ( + WorkerAvailabilityResult, +) from runpod_flash.core.resources.template import PodTemplate @@ -1288,6 +1291,180 @@ async def test_run_async_announces_assigned_worker_streaming_once(self): ] assert len(assigned_messages) == 1 + @pytest.mark.asyncio + async def test_run_async_repeats_no_gpu_availability_message_every_five_updates( + self, + ): + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "ai-key-123" + + mock_job = MagicMock() + mock_job.job_id = "job-123" + mock_job.status.side_effect = [ + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "COMPLETED", + ] + mock_job._fetch_job.return_value = { + "id": "job-123", + "workerId": "worker-456", + "status": "COMPLETED", + "delayTime": 1000, + "executionTime": 2000, + "output": {"result": "success"}, + } + + waiting_batch = QBRequestLogBatch( + worker_id=None, + lines=[], + matched_by_request_id=False, + phase=QBRequestLogPhase.WAITING_FOR_WORKER, + ) + + async def emit_waiting_batch(*, fetcher, request_id): + return waiting_batch + + mock_endpoint = MagicMock() + mock_endpoint.run.return_value = mock_job + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + with patch("asyncio.sleep"): + with patch.object( + ServerlessResource, + "_emit_endpoint_logs", + new=AsyncMock(side_effect=emit_waiting_batch), + ): + with patch( + "runpod_flash.core.resources.serverless.WorkerAvailabilityDiagnostic.diagnose", + new=AsyncMock( + return_value=WorkerAvailabilityResult( + message=( + "No workers available on endpoint: no gpu availability for gpu type NVIDIA GeForce RTX 4090" + ), + has_availability=False, + reason="no_gpu_availability", + ) + ), + ): + with patch( + "runpod_flash.core.resources.serverless.log.info" + ) as mock_log_info: + await serverless.run({"input": "test"}) + + no_worker_messages = [ + str(call.args[0]) + for call in mock_log_info.call_args_list + if call.args + and "No workers available on endpoint: no gpu availability for gpu type" + in str(call.args[0]) + ] + assert len(no_worker_messages) == 2 + + @pytest.mark.asyncio + async def test_run_async_stops_waiting_metrics_logs_after_in_progress(self): + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "ai-key-123" + + mock_job = MagicMock() + mock_job.job_id = "job-123" + mock_job.status.side_effect = [ + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_PROGRESS", + "IN_PROGRESS", + "IN_PROGRESS", + "IN_PROGRESS", + "COMPLETED", + ] + mock_job._fetch_job.return_value = { + "id": "job-123", + "workerId": "worker-456", + "status": "COMPLETED", + "delayTime": 1000, + "executionTime": 2000, + "output": {"result": "success"}, + } + + waiting_batch = QBRequestLogBatch( + worker_id=None, + lines=[], + matched_by_request_id=False, + phase=QBRequestLogPhase.WAITING_FOR_WORKER, + worker_metrics={ + "ready": 0, + "running": 0, + "idle": 0, + "initializing": 0, + "throttled": 2, + "unhealthy": 0, + }, + ) + mock_endpoint = MagicMock() + mock_endpoint.run.return_value = mock_job + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + with patch("asyncio.sleep"): + with patch.object( + ServerlessResource, + "_emit_endpoint_logs", + new=AsyncMock(return_value=waiting_batch), + ): + with patch( + "runpod_flash.core.resources.serverless.WorkerAvailabilityDiagnostic.diagnose", + new=AsyncMock( + return_value=WorkerAvailabilityResult( + message="Workers are currently throttled on endpoint for selected gpu NVIDIA GeForce RTX 4090. Consider raising max workers or changing gpu type.", + has_availability=True, + reason="workers_throttled", + ) + ), + ): + with patch( + "runpod_flash.core.resources.serverless.log.info" + ) as mock_log_info: + await serverless.run({"input": "test"}) + + metrics_logs = [ + str(call.args[0]) + for call in mock_log_info.call_args_list + if call.args + and "Waiting for request: endpoint metrics:" in str(call.args[0]) + ] + assert metrics_logs + assert not any("status=IN_PROGRESS" in line for line in metrics_logs) + @pytest.mark.asyncio async def test_emit_endpoint_logs_prints_worker_lines(self): """Endpoint log emission prints each worker log line.""" diff --git a/tests/unit/resources/test_worker_availability_diagnostic.py b/tests/unit/resources/test_worker_availability_diagnostic.py new file mode 100644 index 00000000..bc363d48 --- /dev/null +++ b/tests/unit/resources/test_worker_availability_diagnostic.py @@ -0,0 +1,132 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from runpod_flash.core.resources.cpu import CpuInstanceType +from runpod_flash.core.resources.serverless import ServerlessResource +from runpod_flash.core.resources.worker_availability_diagnostic import ( + WorkerAvailabilityDiagnostic, +) + + +def _make_client_context(mock_client: MagicMock) -> MagicMock: + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=mock_client) + cm.__aexit__ = AsyncMock(return_value=None) + return cm + + +@pytest.mark.asyncio +async def test_diagnose_returns_workers_max_zero_message(): + resource = ServerlessResource(name="test", workersMax=0) + + diagnostic = WorkerAvailabilityDiagnostic() + result = await diagnostic.diagnose(resource) + + assert result.has_availability is False + assert "max workers are currently set to 0" in result.message + assert result.reason == "workers_max_zero" + + +@pytest.mark.asyncio +async def test_diagnose_gpu_no_availability_includes_selected_locations(): + resource = ServerlessResource(name="test") + resource.gpuIds = "NVIDIA GeForce RTX 4090" + resource.locations = "EU-RO-1,US-GA-2" + + mock_client = MagicMock() + mock_client._execute_graphql = AsyncMock( + side_effect=[ + {"gpuTypes": [{"lowestPrice": {"stockStatus": None}}]}, + {"gpuTypes": [{"lowestPrice": {"stockStatus": None}}]}, + ] + ) + + with patch( + "runpod_flash.core.resources.worker_availability_diagnostic.RunpodGraphQLClient", + return_value=_make_client_context(mock_client), + ): + result = await WorkerAvailabilityDiagnostic().diagnose(resource) + + assert result.has_availability is False + assert ( + "No workers available on endpoint: no gpu availability for gpu type NVIDIA GeForce RTX 4090" + in result.message + ) + assert "EU-RO-1, US-GA-2" in result.message + assert result.reason == "no_gpu_availability" + + +@pytest.mark.asyncio +async def test_diagnose_gpu_availability_shows_signal_without_locations(): + resource = ServerlessResource(name="test") + resource.gpuIds = "NVIDIA GeForce RTX 4090" + resource.locations = "EU-RO-1,US-GA-2" + + mock_client = MagicMock() + mock_client._execute_graphql = AsyncMock( + side_effect=[ + {"gpuTypes": [{"lowestPrice": {"stockStatus": None}}]}, + {"gpuTypes": [{"lowestPrice": {"stockStatus": "Low"}}]}, + ] + ) + + with patch( + "runpod_flash.core.resources.worker_availability_diagnostic.RunpodGraphQLClient", + return_value=_make_client_context(mock_client), + ): + result = await WorkerAvailabilityDiagnostic().diagnose(resource) + + assert result.has_availability is True + assert ( + "Current availability signal for selected gpu NVIDIA GeForce RTX 4090: Low" + in result.message + ) + assert "EU-RO-1" not in result.message + assert "US-GA-2" not in result.message + assert result.reason == "gpu_has_availability" + + +@pytest.mark.asyncio +async def test_diagnose_cpu_no_availability_message(): + resource = ServerlessResource(name="test") + resource.instanceIds = [CpuInstanceType.CPU3G_2_8] + resource.locations = "EU-RO-1,US-GA-2" + + mock_client = MagicMock() + mock_client._execute_graphql = AsyncMock( + side_effect=[ + {"cpuFlavors": [{"specifics": {"stockStatus": None}}]}, + {"cpuFlavors": [{"specifics": {"stockStatus": None}}]}, + ] + ) + + with patch( + "runpod_flash.core.resources.worker_availability_diagnostic.RunpodGraphQLClient", + return_value=_make_client_context(mock_client), + ): + result = await WorkerAvailabilityDiagnostic().diagnose(resource) + + assert result.has_availability is False + assert ( + "No workers available on endpoint: no cpu availability for cpu type cpu3g-2-8" + in result.message + ) + assert "EU-RO-1, US-GA-2" in result.message + assert result.reason == "no_cpu_availability" + + +@pytest.mark.asyncio +async def test_diagnose_prefers_throttled_reason_over_no_availability(): + resource = ServerlessResource(name="test") + resource.gpuIds = "NVIDIA GeForce RTX 4090" + + result = await WorkerAvailabilityDiagnostic().diagnose( + resource, + worker_metrics={"throttled": 3}, + ) + + assert result.has_availability is True + assert result.reason == "workers_throttled" + assert "Workers are currently throttled on endpoint" in result.message + assert "Consider raising max workers or changing gpu type" in result.message From 9a66a5dfc4a6733ad3fdd016e1325f4482f9e287 Mon Sep 17 00:00:00 2001 From: jhcipar Date: Tue, 7 Apr 2026 17:20:31 -0400 Subject: [PATCH 5/6] chore: doc updates and aikey manifest persistence --- docs/Deployment_Architecture.md | 6 ++ src/runpod_flash/cli/docs/flash-deploy.md | 14 ++- src/runpod_flash/cli/docs/flash-logging.md | 6 ++ src/runpod_flash/cli/utils/deployment.py | 25 ++++- src/runpod_flash/core/api/runpod.py | 74 ++++++++++++++ .../core/resources/request_logs.py | 17 +--- src/runpod_flash/core/resources/serverless.py | 30 ++++-- .../worker_availability_diagnostic.py | 98 ++++++------------- tests/unit/cli/utils/test_deployment.py | 4 + .../test_worker_availability_diagnostic.py | 55 +++++++---- 10 files changed, 213 insertions(+), 116 deletions(-) diff --git a/docs/Deployment_Architecture.md b/docs/Deployment_Architecture.md index d0f64afc..d6988fdb 100644 --- a/docs/Deployment_Architecture.md +++ b/docs/Deployment_Architecture.md @@ -191,6 +191,12 @@ When `flash deploy` provisions endpoints: 3. The State Manager stores `{environment_id, resource_name} -> endpoint_id` 4. At runtime, the `ServiceRegistry` uses the manifest + State Manager to route calls +### Manifest credential handling + +- Runtime endpoint metadata (including API-returned `aiKey`) may be stored in the State Manager manifest for deployment reconciliation. +- Local `.flash/flash_manifest.json` is sanitized before it is written to disk and does not include `aiKey`. +- `RUNPOD_API_KEY` is sourced from environment/credential storage and injected into endpoint env when needed; it is not persisted in the local manifest. + See [Cross-Endpoint Routing](Cross_Endpoint_Routing.md) for the full runtime flow. ## Related Documentation diff --git a/src/runpod_flash/cli/docs/flash-deploy.md b/src/runpod_flash/cli/docs/flash-deploy.md index 0ee5380f..6a92a30a 100644 --- a/src/runpod_flash/cli/docs/flash-deploy.md +++ b/src/runpod_flash/cli/docs/flash-deploy.md @@ -138,9 +138,17 @@ The deploy command combines building and deploying your Flash application in a s - Registers endpoints in environment tracking 4. **Post-Deployment**: - - Displays deployment URLs and available routes - - Shows authentication and testing guidance - - Cleans up temporary build directory + - Displays deployment URLs and available routes + - Shows authentication and testing guidance + - Cleans up temporary build directory + +## Manifest and Credential Handling + +During deploy, Flash updates manifest metadata with runtime endpoint details (for example `endpoint_id`, endpoint URLs, and `aiKey` when returned by the API). + +- The manifest stored in State Manager keeps runtime metadata used for reconciliation. +- The local `.flash/flash_manifest.json` is sanitized before writing to disk and does not persist `aiKey`. +- `RUNPOD_API_KEY` continues to be resolved from credentials/env at runtime and is not stored in the local manifest. ## Build Options diff --git a/src/runpod_flash/cli/docs/flash-logging.md b/src/runpod_flash/cli/docs/flash-logging.md index 417abcf6..a5565225 100644 --- a/src/runpod_flash/cli/docs/flash-logging.md +++ b/src/runpod_flash/cli/docs/flash-logging.md @@ -28,6 +28,12 @@ Logs are written in the same format as console output, so you can grep through t - **Graceful degradation**: Continues with stdout-only if file logging fails - **Zero configuration**: Works out of the box with sensible defaults +### QB request log polling during `Endpoint.run(...)` + +- For queue-based (QB) endpoints, Flash polls endpoint status/metrics while waiting and streams worker log lines to stdout when available. +- Polling is used for async `run(...)` flows (not `runsync(...)`), and is skipped for non-QB endpoint types. +- If endpoint `aiKey` is unavailable, Flash falls back to your configured `RUNPOD_API_KEY`; without a key, log streaming is skipped. + ## Log Location By default, logs are written to: diff --git a/src/runpod_flash/cli/utils/deployment.py b/src/runpod_flash/cli/utils/deployment.py index 64850747..609efdf2 100644 --- a/src/runpod_flash/cli/utils/deployment.py +++ b/src/runpod_flash/cli/utils/deployment.py @@ -1,6 +1,7 @@ """Deployment environment management utilities.""" import asyncio +import copy import json import logging from typing import Dict, Any @@ -23,6 +24,19 @@ def _normalized_resource_attr(resource: Any, *names: str) -> str | None: return None +def _manifest_without_ai_keys(manifest: Dict[str, Any]) -> Dict[str, Any]: + sanitized_manifest = copy.deepcopy(manifest) + resources = sanitized_manifest.get("resources") + if not isinstance(resources, dict): + return sanitized_manifest + + for config in resources.values(): + if isinstance(config, dict): + config.pop("aiKey", None) + + return sanitized_manifest + + async def upload_build(app_name: str, build_path: str | Path): app = await FlashApp.from_name(app_name) await app.upload_build(build_path) @@ -338,9 +352,12 @@ async def reconcile_and_provision_resources( endpoint_id = _normalized_resource_attr( deployed_resource, "endpoint_id", "id" ) - endpoint_url = _normalized_resource_attr(deployed_resource, "endpoint_url") + endpoint_url = getattr(deployed_resource, "endpoint_url", None) + if isinstance(endpoint_url, str): + endpoint_url = endpoint_url.strip() or None + else: + endpoint_url = None ai_key = _normalized_resource_attr(deployed_resource, "aiKey", "ai_key") - if endpoint_id: local_manifest["resources"][resource_name]["endpoint_id"] = endpoint_id if endpoint_url: @@ -373,9 +390,11 @@ async def reconcile_and_provision_resources( f"Successfully provisioned: {provisioned}" ) + local_manifest_for_disk = _manifest_without_ai_keys(local_manifest) + # Write updated manifest back to local file manifest_path = Path.cwd() / ".flash" / "flash_manifest.json" - manifest_path.write_text(json.dumps(local_manifest, indent=2)) + manifest_path.write_text(json.dumps(local_manifest_for_disk, indent=2)) log.debug(f"Local manifest updated at {manifest_path.relative_to(Path.cwd())}") diff --git a/src/runpod_flash/core/api/runpod.py b/src/runpod_flash/core/api/runpod.py index 7ab92a45..fdca0a9d 100644 --- a/src/runpod_flash/core/api/runpod.py +++ b/src/runpod_flash/core/api/runpod.py @@ -364,6 +364,80 @@ async def get_gpu_types( result = await self._execute_graphql(query, variables) return result.get("gpuTypes", []) + async def get_gpu_lowest_price_stock_status( + self, + gpu_id: str, + gpu_count: int, + data_center_id: Optional[str] = None, + ) -> Optional[str]: + query = """ + query ServerlessGpuTypes($lowestPriceInput: GpuLowestPriceInput, $gpuTypesInput: GpuTypeFilter) { + gpuTypes(input: $gpuTypesInput) { + lowestPrice(input: $lowestPriceInput) { + stockStatus + } + } + } + """ + + variables = { + "gpuTypesInput": {"ids": [gpu_id]}, + "lowestPriceInput": { + "dataCenterId": data_center_id, + "gpuCount": gpu_count, + "secureCloud": True, + "includeAiApi": True, + "allowedCudaVersions": [], + "compliance": [], + }, + } + + result = await self._execute_graphql(query, variables) + gpu_types = result.get("gpuTypes") or [] + first = gpu_types[0] if gpu_types else {} + lowest = first.get("lowestPrice") if isinstance(first, dict) else {} + if not isinstance(lowest, dict): + return None + status = lowest.get("stockStatus") + if isinstance(status, str) and status.strip(): + return status.strip() + return None + + async def get_cpu_specific_stock_status( + self, + cpu_flavor_id: str, + instance_id: str, + data_center_id: str, + ) -> Optional[str]: + query = """ + query SecureCpuTypes($cpuFlavorInput: CpuFlavorInput, $specificsInput: SpecificsInput) { + cpuFlavors(input: $cpuFlavorInput) { + specifics(input: $specificsInput) { + stockStatus + } + } + } + """ + + variables = { + "cpuFlavorInput": {"id": cpu_flavor_id}, + "specificsInput": { + "dataCenterId": data_center_id, + "instanceId": instance_id, + }, + } + + result = await self._execute_graphql(query, variables) + cpu_flavors = result.get("cpuFlavors") or [] + first = cpu_flavors[0] if cpu_flavors else {} + specifics = first.get("specifics") if isinstance(first, dict) else {} + if not isinstance(specifics, dict): + return None + status = specifics.get("stockStatus") + if isinstance(status, str) and status.strip(): + return status.strip() + return None + async def get_endpoint(self, endpoint_id: str) -> Dict[str, Any]: """Get endpoint details.""" # Note: The schema doesn't show a specific endpoint query diff --git a/src/runpod_flash/core/resources/request_logs.py b/src/runpod_flash/core/resources/request_logs.py index 5a827c87..df7ce2ed 100644 --- a/src/runpod_flash/core/resources/request_logs.py +++ b/src/runpod_flash/core/resources/request_logs.py @@ -13,7 +13,7 @@ log = logging.getLogger(__name__) -API_BASE_URL = "https://api.runpod.ai" +API_BASE_URL = os.getenv("RUNPOD_API_BASE_URL", "https://api.runpod.ai").rstrip("/") DEV_API_BASE_URL = "https://dev-api.runpod.ai" HAPI_BASE_URL = "https://hapi.runpod.net" DEV_HAPI_BASE_URL = "https://dev-hapi.runpod.net" @@ -55,13 +55,11 @@ def __init__( self, timeout_seconds: float = 4.0, max_lines: int = 25, - fallback_tail_lines: int = 10, lookback_seconds: int = 20, start_time: Optional[datetime] = None, ): self.timeout_seconds = timeout_seconds self.max_lines = max_lines - self.fallback_tail_lines = fallback_tail_lines self.lookback_seconds = lookback_seconds self.start_time = start_time or datetime.now(timezone.utc) self.seen = set() @@ -90,7 +88,7 @@ async def fetch_logs( status_api_key=status_api_key, status_api_key_fallback=status_api_key_fallback, ) - running_worker_ids = self._running_worker_ids_from_metrics(metrics_payload) + running_worker_ids = self._ready_worker_ids_from_metrics(metrics_payload) initializing_workers = self._initializing_worker_count(metrics_payload) worker_metrics = self._worker_metrics_snapshot(metrics_payload) ready_worker_ids = self._ready_worker_ids_from_metrics(metrics_payload) @@ -250,17 +248,6 @@ def _worker_id_from_status_payload( return None return str(worker_id) - @staticmethod - def _running_worker_ids_from_metrics( - payload: Optional[dict[str, Any]], - ) -> List[str]: - if not payload: - return [] - ready_workers = payload.get("readyWorkers") - if not isinstance(ready_workers, list): - return [] - return [str(worker) for worker in ready_workers if worker] - @staticmethod def _ready_worker_ids_from_metrics(payload: Optional[dict[str, Any]]) -> List[str]: if not payload: diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index f2e68a0a..4d9358a2 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -33,7 +33,7 @@ from .cpu import CpuInstanceType from .gpu import GpuGroup, GpuType from .network_volume import NetworkVolume, DataCenter, CPU_DATACENTERS -from .request_logs import QBRequestLogFetcher, QBRequestLogPhase +from .request_logs import QBRequestLogBatch, QBRequestLogFetcher, QBRequestLogPhase from .worker_availability_diagnostic import WorkerAvailabilityDiagnostic from .template import KeyValuePair, PodTemplate from .resource_manager import ResourceManager @@ -60,14 +60,6 @@ def get_env_vars() -> Dict[str, str]: POD_LOG_PREFIX_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?Z\s+") -def _is_prod_environment() -> bool: - env = os.getenv("RUNPOD_ENV") - if env: - return env.lower() == "prod" - api_base = os.getenv("RUNPOD_API_BASE_URL", "https://api.runpod.io") - return "api.runpod.io" in api_base or "api.runpod.ai" in api_base - - def _normalize_stream_log_line(line: str) -> str: normalized = line.strip() if not normalized: @@ -1176,6 +1168,7 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": worker_availability_diagnostic = WorkerAvailabilityDiagnostic() repeated_no_worker_message: Optional[str] = None waiting_update_count = 0 + emitted_initial_wait_metrics = False # Poll for job status while True: @@ -1236,12 +1229,30 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": else: repeated_no_worker_message = None waiting_update_count = 0 + if ( + job_status != "IN_PROGRESS" + and not emitted_initial_wait_metrics + ): + worker_state = ( + batch.worker_id if batch.worker_id else "None" + ) + worker_metrics = batch.worker_metrics or {} + assignment_state = ( + "assigned" + if batch.matched_by_request_id + else "unassigned" + ) + log.info( + f"{log_subgroup} | Waiting for request: endpoint metrics: worker={worker_state}, assignment={assignment_state}, status={job_status}, workers={{ready:{worker_metrics.get('ready', 0)}, running:{worker_metrics.get('running', 0)}, idle:{worker_metrics.get('idle', 0)}, initializing:{worker_metrics.get('initializing', 0)}, throttled:{worker_metrics.get('throttled', 0)}, unhealthy:{worker_metrics.get('unhealthy', 0)}}}, readyWorkers={batch.ready_worker_ids}" + ) + emitted_initial_wait_metrics = True elif ( batch.phase == QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION ): repeated_no_worker_message = None waiting_update_count = 0 + emitted_initial_wait_metrics = False if batch.matched_by_request_id and batch.worker_id: log.info( f"{log_subgroup} | Request assigned to worker {batch.worker_id}, waiting for worker initialization/image pull logs" @@ -1257,6 +1268,7 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": elif batch.phase == QBRequestLogPhase.STREAMING: repeated_no_worker_message = None waiting_update_count = 0 + emitted_initial_wait_metrics = False log.info( f"{log_subgroup} | Streaming endpoint startup logs while waiting for request assignment" ) diff --git a/src/runpod_flash/core/resources/worker_availability_diagnostic.py b/src/runpod_flash/core/resources/worker_availability_diagnostic.py index a4ea2bf7..825eab1a 100644 --- a/src/runpod_flash/core/resources/worker_availability_diagnostic.py +++ b/src/runpod_flash/core/resources/worker_availability_diagnostic.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from ..api.runpod import RunpodGraphQLClient @@ -9,35 +9,7 @@ log = logging.getLogger(__name__) - - -GPU_STOCK_QUERY = """ -query ServerlessGpuTypes($lowestPriceInput: GpuLowestPriceInput, $gpuTypesInput: GpuTypeFilter) { - gpuTypes(input: $gpuTypesInput) { - id - displayName - lowestPrice(input: $lowestPriceInput) { - stockStatus - __typename - } - __typename - } -} -""" - - -CPU_STOCK_QUERY = """ -query SecureCpuTypes($cpuFlavorInput: CpuFlavorInput, $specificsInput: SpecificsInput) { - cpuFlavors(input: $cpuFlavorInput) { - id - specifics(input: $specificsInput) { - stockStatus - __typename - } - __typename - } -} -""" +AVAILABLE_STOCK_STATUSES = {"LOW", "MEDIUM", "HIGH"} @dataclass @@ -70,10 +42,11 @@ async def diagnose( throttled_workers = (worker_metrics or {}).get("throttled", 0) if throttled_workers > 0: + compute_label = "gpu type" if compute_kind == "gpu" else "cpu type" return WorkerAvailabilityResult( message=( f"Workers are currently throttled on endpoint for selected {compute_kind} {compute_choice}. " - "Consider raising max workers or changing gpu type." + f"Consider raising max workers or changing {compute_label}." ), has_availability=True, reason="workers_throttled", @@ -122,7 +95,10 @@ def _build_message( availability_by_location: Dict[str, Optional[str]], include_available_signal: bool, ) -> WorkerAvailabilityResult: - has_availability = any(status for status in availability_by_location.values()) + has_availability = any( + self._is_available_stock_status(status) + for status in availability_by_location.values() + ) if not has_availability: selected_locations = ", ".join(locations) if locations else "all locations" @@ -166,25 +142,12 @@ async def _gpu_availability( async with RunpodGraphQLClient() as client: for location in location_inputs: - variables = { - "gpuTypesInput": {"ids": [gpu_id]}, - "lowestPriceInput": { - "dataCenterId": location, - "gpuCount": gpu_count, - "secureCloud": True, - "includeAiApi": True, - "allowedCudaVersions": [], - "compliance": [], - }, - } key = location or "global" try: - result = await client._execute_graphql(GPU_STOCK_QUERY, variables) - gpu_types = result.get("gpuTypes") or [] - first = gpu_types[0] if gpu_types else {} - lowest = first.get("lowestPrice") if isinstance(first, dict) else {} - status = ( - lowest.get("stockStatus") if isinstance(lowest, dict) else None + status = await client.get_gpu_lowest_price_stock_status( + gpu_id=gpu_id, + gpu_count=gpu_count, + data_center_id=location, ) availability_by_location[key] = status except Exception as exc: @@ -207,25 +170,12 @@ async def _cpu_availability( async with RunpodGraphQLClient() as client: for location in location_inputs: - variables = { - "cpuFlavorInput": {"id": flavor_id}, - "specificsInput": { - "dataCenterId": location, - "instanceId": instance_id, - }, - } key = location or "global" try: - result = await client._execute_graphql(CPU_STOCK_QUERY, variables) - cpu_flavors = result.get("cpuFlavors") or [] - first = cpu_flavors[0] if cpu_flavors else {} - specifics = ( - first.get("specifics") if isinstance(first, dict) else {} - ) - status = ( - specifics.get("stockStatus") - if isinstance(specifics, dict) - else None + status = await client.get_cpu_specific_stock_status( + cpu_flavor_id=flavor_id, + instance_id=instance_id, + data_center_id=location, ) availability_by_location[key] = status except Exception as exc: @@ -275,6 +225,18 @@ def _summarize_stock_signal( if not non_empty: return "unknown" - priority = {"High": 3, "Medium": 2, "Low": 1} - best = max(non_empty, key=lambda status: priority.get(status, 0)) + priority = {"HIGH": 3, "MEDIUM": 2, "LOW": 1} + + def score(value: str) -> int: + normalized = value.strip().upper().replace("-", "_").replace(" ", "_") + return priority.get(normalized, 0) + + best = max(non_empty, key=score) return best + + @staticmethod + def _is_available_stock_status(status: Optional[str]) -> bool: + if not isinstance(status, str): + return False + normalized = status.strip().upper().replace("-", "_").replace(" ", "_") + return normalized in AVAILABLE_STOCK_STATUSES diff --git a/tests/unit/cli/utils/test_deployment.py b/tests/unit/cli/utils/test_deployment.py index 09b0acc3..16bf3dc2 100644 --- a/tests/unit/cli/utils/test_deployment.py +++ b/tests/unit/cli/utils/test_deployment.py @@ -578,3 +578,7 @@ async def test_reconciliation_copies_ai_key_from_state_manifest(tmp_path): updated_manifest["resources_endpoints"]["worker"] == "https://worker.api.runpod.ai" ) + + with open(flash_dir / "flash_manifest.json") as f: + persisted_manifest = json.load(f) + assert "aiKey" not in persisted_manifest["resources"]["worker"] diff --git a/tests/unit/resources/test_worker_availability_diagnostic.py b/tests/unit/resources/test_worker_availability_diagnostic.py index bc363d48..0807fe04 100644 --- a/tests/unit/resources/test_worker_availability_diagnostic.py +++ b/tests/unit/resources/test_worker_availability_diagnostic.py @@ -35,12 +35,7 @@ async def test_diagnose_gpu_no_availability_includes_selected_locations(): resource.locations = "EU-RO-1,US-GA-2" mock_client = MagicMock() - mock_client._execute_graphql = AsyncMock( - side_effect=[ - {"gpuTypes": [{"lowestPrice": {"stockStatus": None}}]}, - {"gpuTypes": [{"lowestPrice": {"stockStatus": None}}]}, - ] - ) + mock_client.get_gpu_lowest_price_stock_status = AsyncMock(side_effect=[None, None]) with patch( "runpod_flash.core.resources.worker_availability_diagnostic.RunpodGraphQLClient", @@ -64,12 +59,7 @@ async def test_diagnose_gpu_availability_shows_signal_without_locations(): resource.locations = "EU-RO-1,US-GA-2" mock_client = MagicMock() - mock_client._execute_graphql = AsyncMock( - side_effect=[ - {"gpuTypes": [{"lowestPrice": {"stockStatus": None}}]}, - {"gpuTypes": [{"lowestPrice": {"stockStatus": "Low"}}]}, - ] - ) + mock_client.get_gpu_lowest_price_stock_status = AsyncMock(side_effect=[None, "Low"]) with patch( "runpod_flash.core.resources.worker_availability_diagnostic.RunpodGraphQLClient", @@ -94,12 +84,7 @@ async def test_diagnose_cpu_no_availability_message(): resource.locations = "EU-RO-1,US-GA-2" mock_client = MagicMock() - mock_client._execute_graphql = AsyncMock( - side_effect=[ - {"cpuFlavors": [{"specifics": {"stockStatus": None}}]}, - {"cpuFlavors": [{"specifics": {"stockStatus": None}}]}, - ] - ) + mock_client.get_cpu_specific_stock_status = AsyncMock(side_effect=[None, None]) with patch( "runpod_flash.core.resources.worker_availability_diagnostic.RunpodGraphQLClient", @@ -130,3 +115,37 @@ async def test_diagnose_prefers_throttled_reason_over_no_availability(): assert result.reason == "workers_throttled" assert "Workers are currently throttled on endpoint" in result.message assert "Consider raising max workers or changing gpu type" in result.message + + +@pytest.mark.asyncio +async def test_diagnose_cpu_throttled_message_references_cpu_type(): + resource = ServerlessResource(name="test") + resource.instanceIds = [CpuInstanceType.CPU3G_2_8] + + result = await WorkerAvailabilityDiagnostic().diagnose( + resource, + worker_metrics={"throttled": 2}, + ) + + assert result.reason == "workers_throttled" + assert "changing cpu type" in result.message + + +@pytest.mark.asyncio +async def test_diagnose_treats_out_of_stock_as_unavailable(): + resource = ServerlessResource(name="test") + resource.gpuIds = "NVIDIA GeForce RTX 4090" + + mock_client = MagicMock() + mock_client.get_gpu_lowest_price_stock_status = AsyncMock( + side_effect=["OUT_OF_STOCK"] + ) + + with patch( + "runpod_flash.core.resources.worker_availability_diagnostic.RunpodGraphQLClient", + return_value=_make_client_context(mock_client), + ): + result = await WorkerAvailabilityDiagnostic().diagnose(resource) + + assert result.has_availability is False + assert result.reason == "no_gpu_availability" From 9a636d3577c58b39db5dee0eed6bc11ff361bc78 Mon Sep 17 00:00:00 2001 From: jhcipar Date: Wed, 8 Apr 2026 16:15:20 -0400 Subject: [PATCH 6/6] chore: PR feedback --- src/runpod_flash/cli/utils/deployment.py | 25 +++++++- src/runpod_flash/core/resources/serverless.py | 37 ++++++------ tests/unit/cli/utils/test_deployment.py | 50 ++++++++++++++++ tests/unit/resources/test_serverless.py | 59 +++++++++++++++++-- 4 files changed, 145 insertions(+), 26 deletions(-) diff --git a/src/runpod_flash/cli/utils/deployment.py b/src/runpod_flash/cli/utils/deployment.py index 609efdf2..62776b34 100644 --- a/src/runpod_flash/cli/utils/deployment.py +++ b/src/runpod_flash/cli/utils/deployment.py @@ -9,12 +9,18 @@ from pathlib import Path from runpod_flash.config import get_paths +from runpod_flash.core.resources.serverless import ServerlessResource from runpod_flash.core.resources.app import FlashApp from runpod_flash.core.resources.resource_manager import ResourceManager from runpod_flash.runtime.resource_provisioner import create_resource_from_manifest log = logging.getLogger(__name__) +RUNTIME_RESOURCE_FIELDS = set(ServerlessResource.RUNTIME_FIELDS) | { + "id", + "endpoint_id", +} + def _normalized_resource_attr(resource: Any, *names: str) -> str | None: for name in names: @@ -37,6 +43,13 @@ def _manifest_without_ai_keys(manifest: Dict[str, Any]) -> Dict[str, Any]: return sanitized_manifest +def _resource_config_for_compare(config: Dict[str, Any]) -> Dict[str, Any]: + compare_config = copy.deepcopy(config) + for field in RUNTIME_RESOURCE_FIELDS: + compare_config.pop(field, None) + return compare_config + + async def upload_build(app_name: str, build_path: str | Path): app = await FlashApp.from_name(app_name) await app.upload_build(build_path) @@ -288,9 +301,15 @@ async def reconcile_and_provision_resources( local_config = local_manifest["resources"][resource_name] state_config = state_manifest.get("resources", {}).get(resource_name, {}) - # Simple hash comparison for config changes - local_json = json.dumps(local_config, sort_keys=True) - state_json = json.dumps(state_config, sort_keys=True) + # Compare only user-managed config fields (exclude runtime metadata) + local_json = json.dumps( + _resource_config_for_compare(local_config), + sort_keys=True, + ) + state_json = json.dumps( + _resource_config_for_compare(state_config), + sort_keys=True, + ) # Check if endpoint exists in state manifest has_endpoint = resource_name in state_manifest.get("resources_endpoints", {}) diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index 4d9358a2..85c757af 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -3,6 +3,7 @@ import logging import os import re +from collections import Counter from datetime import datetime, timezone from enum import Enum from pathlib import Path @@ -264,7 +265,7 @@ async def _emit_endpoint_logs( if batch.lines: for line in batch.lines: - print(f"worker log: {line}") + log.info("worker log: %s", line) return batch @@ -1308,27 +1309,29 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": output = response.get("output") if isinstance(output, dict): stdout = output.get("stdout") - if isinstance(stdout, str): - seen_normalized = { + should_dedupe_stdout = ( + self.type == ServerlessType.QB + and fetcher.has_streamed_logs + and bool(fetcher.seen) + ) + if should_dedupe_stdout and isinstance(stdout, str): + seen_normalized_counts = Counter( normalized for line in fetcher.seen if (normalized := _normalize_stream_log_line(line)) - } + ) kept = [] - for raw in stdout.splitlines(): - raw = raw.strip() - if not raw: - continue - - normalized_raw = _normalize_stream_log_line(raw) - if not normalized_raw: - continue - if normalized_raw in seen_normalized: + for raw_line in stdout.splitlines(keepends=True): + normalized_raw = _normalize_stream_log_line(raw_line) + if ( + normalized_raw + and seen_normalized_counts.get(normalized_raw, 0) + > 0 + ): + seen_normalized_counts[normalized_raw] -= 1 continue - - seen_normalized.add(normalized_raw) - kept.append(raw) - output["stdout"] = "\n".join(kept) + kept.append(raw_line) + output["stdout"] = "".join(kept) return JobOutput(**response) except Exception as e: diff --git a/tests/unit/cli/utils/test_deployment.py b/tests/unit/cli/utils/test_deployment.py index 16bf3dc2..7e20656a 100644 --- a/tests/unit/cli/utils/test_deployment.py +++ b/tests/unit/cli/utils/test_deployment.py @@ -582,3 +582,53 @@ async def test_reconciliation_copies_ai_key_from_state_manifest(tmp_path): with open(flash_dir / "flash_manifest.json") as f: persisted_manifest = json.load(f) assert "aiKey" not in persisted_manifest["resources"]["worker"] + + +@pytest.mark.asyncio +async def test_reconciliation_ignores_runtime_fields_in_config_comparison(tmp_path): + import json + + flash_dir = tmp_path / ".flash" + flash_dir.mkdir() + + local_manifest = { + "resources": { + "worker": { + "resource_type": "LiveServerless", + "config": "same", + }, + }, + "resources_endpoints": {}, + } + (flash_dir / "flash_manifest.json").write_text(json.dumps(local_manifest)) + + state_manifest = { + "resources": { + "worker": { + "resource_type": "LiveServerless", + "config": "same", + "aiKey": "ai-key-123", + "endpoint_id": "endpoint-123", + "templateId": "template-123", + }, + }, + "resources_endpoints": { + "worker": "https://worker.api.runpod.ai", + }, + } + + app = AsyncMock() + app.get_build_manifest = AsyncMock(return_value=state_manifest) + app.update_build_manifest = AsyncMock() + + with ( + patch("pathlib.Path.cwd", return_value=tmp_path), + patch("runpod_flash.cli.utils.deployment.ResourceManager") as mock_manager_cls, + ): + mock_manager = MagicMock() + mock_manager.get_or_deploy_resource = AsyncMock() + mock_manager_cls.return_value = mock_manager + + await reconcile_and_provision_resources(app, "build-123", "dev", local_manifest) + + mock_manager.get_or_deploy_resource.assert_not_called() diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py index b97b1d1d..694f7d76 100644 --- a/tests/unit/resources/test_serverless.py +++ b/tests/unit/resources/test_serverless.py @@ -1149,6 +1149,7 @@ async def test_run_async_dedupes_stdout_against_streamed_pod_logs(self): mock_endpoint.run.return_value = mock_job async def fake_emit(*, fetcher, request_id): + fetcher.has_streamed_logs = True fetcher.seen.add( "2026-04-02T18:18:10.165152015Z 2026-04-02 18:18:10,164 | DEBUG | aiohttp_retry | client.py:110 | Attempt 1 out of 3" ) @@ -1172,7 +1173,53 @@ async def fake_emit(*, fetcher, request_id): result = await serverless.run({"input": "test"}) assert isinstance(result, JobOutput) - assert result.output["stdout"] == "unique stdout line" + assert result.output["stdout"] == ( + "2026-04-02 18:18:10,164 | DEBUG | aiohttp_retry | client.py:110 | Attempt 1 out of 3\n" + "unique stdout line" + ) + + @pytest.mark.asyncio + async def test_run_async_keeps_stdout_unchanged_when_no_streamed_logs(self): + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "endpoint-ai-key" + + original_stdout = "dup line\ndup line\n\n spaced line" + mock_job = MagicMock() + mock_job.job_id = "job-123" + mock_job.status.side_effect = ["IN_QUEUE", "COMPLETED"] + mock_job._fetch_job.return_value = { + "id": "job-123", + "workerId": "worker-456", + "status": "COMPLETED", + "delayTime": 1000, + "executionTime": 2000, + "output": {"stdout": original_stdout}, + } + + mock_endpoint = MagicMock() + mock_endpoint.run.return_value = mock_job + + async def fake_emit(*, fetcher, request_id): + fetcher.seen.add("dup line") + return None + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + with patch("asyncio.sleep"): + with patch.object( + ServerlessResource, + "_emit_endpoint_logs", + new=AsyncMock(side_effect=fake_emit), + ): + result = await serverless.run({"input": "test"}) + + assert isinstance(result, JobOutput) + assert result.output["stdout"] == original_stdout @pytest.mark.asyncio async def test_run_async_fetches_endpoint_logs_while_polling(self): @@ -1466,8 +1513,8 @@ async def test_run_async_stops_waiting_metrics_logs_after_in_progress(self): assert not any("status=IN_PROGRESS" in line for line in metrics_logs) @pytest.mark.asyncio - async def test_emit_endpoint_logs_prints_worker_lines(self): - """Endpoint log emission prints each worker log line.""" + async def test_emit_endpoint_logs_uses_logger_for_worker_lines(self): + """Endpoint log emission logs each worker line through logger.""" serverless = ServerlessResource(name="test") serverless.id = "endpoint-123" serverless.type = ServerlessType.QB @@ -1487,7 +1534,7 @@ async def test_emit_endpoint_logs_prints_worker_lines(self): "runpod_flash.core.resources.serverless.get_api_key", return_value="runpod-key-123", ): - with patch("builtins.print") as mock_print: + with patch("runpod_flash.core.resources.serverless.log.info") as mock_info: batch = await serverless._emit_endpoint_logs( fetcher=mock_fetcher, request_id="job-123", @@ -1502,8 +1549,8 @@ async def test_emit_endpoint_logs_prints_worker_lines(self): ) assert batch is not None assert batch.phase == QBRequestLogPhase.STREAMING - mock_print.assert_any_call("worker log: line-a") - mock_print.assert_any_call("worker log: line-b") + mock_info.assert_any_call("worker log: %s", "line-a") + mock_info.assert_any_call("worker log: %s", "line-b") @pytest.mark.asyncio async def test_emit_endpoint_logs_skips_when_missing_required_fields(self):