Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 24 additions & 20 deletions src/runpod_flash/core/resources/serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,33 +982,37 @@ def _inject_template_env(self, key: str, value: str) -> None:
def _inject_runtime_template_vars(self) -> None:
"""Inject runtime env vars into template.env without mutating self.env.

For QB endpoints making remote calls: injects RUNPOD_API_KEY.
For LB endpoints: injects FLASH_MODULE_PATH.
For any endpoint making remote calls (QB or LB): injects RUNPOD_API_KEY.
For LB endpoints: also injects FLASH_MODULE_PATH.

Called by both _do_deploy (initial) and update (env changes) so
runtime vars survive template updates.
"""
env_dict = self.env or {}

if self.type == ServerlessType.QB:
if self._check_makes_remote_calls():
if "RUNPOD_API_KEY" not in env_dict:
from runpod_flash.core.credentials import get_api_key

api_key = get_api_key()
if api_key:
self._inject_template_env("RUNPOD_API_KEY", api_key)
log.debug(
f"{self.name}: Injected RUNPOD_API_KEY for remote calls "
f"(makes_remote_calls=True)"
)
else:
log.warning(
f"{self.name}: makes_remote_calls=True but RUNPOD_API_KEY not set. "
f"Remote calls to other endpoints will fail."
)
# Inject RUNPOD_API_KEY for any endpoint that makes remote calls,
# regardless of type. flash deploy gets the token via
# create_resource_from_manifest, but flash run provisions LB endpoints
# directly through lb_execute, which bypasses that path. Injecting here
# keeps deploy and run symmetric so cross-endpoint calls carry the
# token (SLS-336).
if self._check_makes_remote_calls() and "RUNPOD_API_KEY" not in env_dict:
from runpod_flash.core.credentials import get_api_key

api_key = get_api_key()
if api_key:
self._inject_template_env("RUNPOD_API_KEY", api_key)
log.debug(
f"{self.name}: Injected RUNPOD_API_KEY for remote calls "
f"(makes_remote_calls=True)"
)
else:
log.warning(
f"{self.name}: makes_remote_calls=True but RUNPOD_API_KEY not set. "
f"Remote calls to other endpoints will fail."
)

elif self.type == ServerlessType.LB:
if self.type == ServerlessType.LB:
module_path = self._get_module_path()
if module_path and "FLASH_MODULE_PATH" not in env_dict:
self._inject_template_env("FLASH_MODULE_PATH", module_path)
Expand Down
61 changes: 61 additions & 0 deletions tests/unit/resources/test_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -2658,6 +2658,67 @@ async def test_do_deploy_lb_injects_module_path_into_template_env(self):
type_entries = [e for e in template_env if e["key"] == "FLASH_ENDPOINT_TYPE"]
assert len(type_entries) == 0

@pytest.mark.asyncio
async def test_do_deploy_lb_injects_api_key_when_makes_remote_calls(self):
    """A remote-calling LB endpoint must carry RUNPOD_API_KEY in template.env.

    Regression guard for SLS-336. With ``flash run`` an LB route is
    provisioned straight through ``lb_execute``, skipping the manifest path
    that supplies the token during ``flash deploy``. If ``_do_deploy`` does
    not inject the key itself, the LB worker has no token and its
    cross-endpoint calls are rejected with HTTP 401.
    """
    from runpod_flash.core.resources.load_balancer_sls_resource import (
        LoadBalancerSlsResource,
    )

    lb_resource = LoadBalancerSlsResource(
        name="lb-remote-caller",
        imageName="test:latest",
        env={"LOG_LEVEL": "INFO"},
        flashboot=False,
    )

    # Canned response for the endpoint-save call made during deploy.
    saved_endpoint = {
        "id": "endpoint-lb-remote",
        "name": "lb-remote-caller",
        "templateId": "tpl-lb-remote",
        "gpuIds": "AMPERE_48",
        "allowedCudaVersions": "",
    }
    graphql_client = AsyncMock()
    graphql_client.save_endpoint = AsyncMock(return_value=saved_endpoint)

    # Build every patcher first so they can be entered in one flat block,
    # in the same order the test has always applied them.
    client_patch = patch(
        "runpod_flash.core.resources.serverless.RunpodGraphQLClient"
    )
    volume_patch = patch.object(
        ServerlessResource,
        "_ensure_network_volume_deployed",
        new=AsyncMock(),
    )
    deployed_patch = patch.object(
        LoadBalancerSlsResource, "is_deployed", return_value=False
    )
    remote_calls_patch = patch.object(
        ServerlessResource,
        "_check_makes_remote_calls",
        return_value=True,
    )
    api_key_patch = patch.dict(os.environ, {"RUNPOD_API_KEY": "test-lb-key-789"})

    with client_patch as client_cls, volume_patch, deployed_patch:
        # The client is used as an async context manager; hand back our mock.
        client_cls.return_value.__aenter__.return_value = graphql_client
        client_cls.return_value.__aexit__.return_value = None

        with remote_calls_patch, api_key_patch:
            await lb_resource._do_deploy()

    # Inspect the payload that was sent to save_endpoint.
    sent_payload = graphql_client.save_endpoint.call_args.args[0]
    env_entries = sent_payload.get("template", {}).get("env", [])
    key_entries = [
        entry for entry in env_entries if entry["key"] == "RUNPOD_API_KEY"
    ]
    assert len(key_entries) == 1
    assert key_entries[0]["value"] == "test-lb-key-789"


class TestBuildTemplateUpdatePayload:
"""Test _build_template_update_payload always includes env."""
Expand Down
Loading