Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/runpod_flash/core/api/runpod.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,14 @@ def __init__(self, api_key: Optional[str] = None):
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create an aiohttp session."""
if self.session is None or self.session.closed:
from runpod_flash.core.utils.user_agent import get_user_agent

timeout = aiohttp.ClientTimeout(total=300) # 5 minute timeout
connector = aiohttp.TCPConnector(resolver=ThreadedResolver())
self.session = aiohttp.ClientSession(
timeout=timeout,
headers={
"User-Agent": get_user_agent(),
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
Expand Down Expand Up @@ -812,10 +815,13 @@ def __init__(self, api_key: Optional[str] = None):
async def _get_session(self) -> aiohttp.ClientSession:
"""Get or create an aiohttp session."""
if self.session is None or self.session.closed:
from runpod_flash.core.utils.user_agent import get_user_agent

timeout = aiohttp.ClientTimeout(total=300) # 5 minute timeout
self.session = aiohttp.ClientSession(
timeout=timeout,
headers={
"User-Agent": get_user_agent(),
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
Expand Down
14 changes: 12 additions & 2 deletions src/runpod_flash/core/resources/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,16 @@ async def download_tarball(self, environment_id: str, dest_file: str) -> None:
ValueError: If environment has no active artifact
requests.HTTPError: If download fails
"""
from runpod_flash.core.utils.user_agent import get_user_agent

await self._hydrate()
result = await self._get_active_artifact(environment_id)
url = result["downloadUrl"]

headers = {"User-Agent": get_user_agent()}

with open(dest_file, "wb") as stream:
with requests.get(url, stream=True) as resp:
with requests.get(url, stream=True, headers=headers) as resp:
resp.raise_for_status()
for chunk in resp.iter_content():
if chunk:
Expand Down Expand Up @@ -462,14 +467,19 @@ async def upload_build(self, tar_path: Union[str, Path]) -> Dict[str, Any]:
except json.JSONDecodeError as e:
raise ValueError(f"Invalid manifest JSON at {manifest_path}: {e}") from e

from runpod_flash.core.utils.user_agent import get_user_agent

await self._hydrate()
tarball_size = tar_path.stat().st_size

result = await self._get_tarball_upload_url(tarball_size)
url = result["uploadUrl"]
object_key = result["objectKey"]

headers = {"Content-Type": TARBALL_CONTENT_TYPE}
headers = {
"User-Agent": get_user_agent(),
"Content-Type": TARBALL_CONTENT_TYPE,
}

with tar_path.open("rb") as fh:
resp = requests.put(url, data=fh, headers=headers)
Expand Down
28 changes: 21 additions & 7 deletions src/runpod_flash/core/utils/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ def get_authenticated_httpx_client(
timeout: Optional[float] = None,
api_key_override: Optional[str] = None,
) -> httpx.AsyncClient:
"""Create httpx AsyncClient with RunPod authentication.
"""Create httpx AsyncClient with RunPod authentication and User-Agent.

Automatically includes:
- User-Agent header identifying flash client and version
- Authorization header if RUNPOD_API_KEY is set

Automatically includes Authorization header if RUNPOD_API_KEY is set.
This provides a centralized place to manage authentication headers for
all RunPod HTTP requests, avoiding repetitive manual header addition.

Expand All @@ -23,7 +26,7 @@ def get_authenticated_httpx_client(
Used for propagating API keys from mothership to worker endpoints.

Returns:
Configured httpx.AsyncClient with Authorization header
Configured httpx.AsyncClient with User-Agent and Authorization headers

Example:
async with get_authenticated_httpx_client() as client:
Expand All @@ -37,7 +40,11 @@ def get_authenticated_httpx_client(
async with get_authenticated_httpx_client(api_key_override=context_key) as client:
response = await client.post(url, json=data)
"""
headers = {}
from .user_agent import get_user_agent

headers = {
"User-Agent": get_user_agent(),
}
api_key = api_key_override or os.environ.get("RUNPOD_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
Expand All @@ -49,9 +56,12 @@ def get_authenticated_httpx_client(
def get_authenticated_requests_session(
api_key_override: Optional[str] = None,
) -> requests.Session:
"""Create requests Session with RunPod authentication.
"""Create requests Session with RunPod authentication and User-Agent.

Automatically includes:
- User-Agent header identifying flash client and version
- Authorization header if RUNPOD_API_KEY is set

Automatically includes Authorization header if RUNPOD_API_KEY is set.
Provides a centralized place to manage authentication headers for
synchronous RunPod HTTP requests.

Expand All @@ -60,7 +70,7 @@ def get_authenticated_requests_session(
Used for propagating API keys from mothership to worker endpoints.

Returns:
Configured requests.Session with Authorization header
Configured requests.Session with User-Agent and Authorization headers

Example:
session = get_authenticated_requests_session()
Expand All @@ -76,7 +86,11 @@ def get_authenticated_requests_session(
with contextlib.closing(get_authenticated_requests_session(api_key_override=context_key)) as session:
response = session.post(url, json=data)
"""
from .user_agent import get_user_agent

session = requests.Session()
session.headers["User-Agent"] = get_user_agent()

api_key = api_key_override or os.environ.get("RUNPOD_API_KEY")
if api_key:
session.headers["Authorization"] = f"Bearer {api_key}"
Expand Down
27 changes: 27 additions & 0 deletions src/runpod_flash/core/utils/user_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""User-Agent header generation for HTTP requests."""

import platform
from importlib.metadata import version


def get_user_agent() -> str:
"""Get the User-Agent string for flash HTTP requests.

Returns:
User-Agent string in format: Runpod Flash/<version> (Python <python_version>; <OS> <OS_version>; <arch>)

Example:
>>> get_user_agent()
'Runpod Flash/1.1.1 (Python 3.11.12; Darwin 25.2.0; arm64)'
"""
try:
pkg_version = version("runpod-flash")
except Exception:
pkg_version = "unknown"

python_version = platform.python_version()
os_name = platform.system()
os_version = platform.release()
arch = platform.machine()

return f"Runpod Flash/{pkg_version} (Python {python_version}; {os_name} {os_version}; {arch})"
46 changes: 46 additions & 0 deletions tests/unit/core/utils/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,28 @@ def test_get_authenticated_httpx_client_zero_timeout(self, monkeypatch):
assert client is not None
assert client.timeout.read == 0.0

def test_get_authenticated_httpx_client_includes_user_agent(self, monkeypatch):
"""Test client includes User-Agent header."""
monkeypatch.delenv("RUNPOD_API_KEY", raising=False)

client = get_authenticated_httpx_client()

assert client is not None
assert "User-Agent" in client.headers
assert client.headers["User-Agent"].startswith("Runpod Flash/")

def test_get_authenticated_httpx_client_user_agent_with_auth(self, monkeypatch):
"""Test client includes both User-Agent and Authorization headers."""
monkeypatch.setenv("RUNPOD_API_KEY", "test-key")

client = get_authenticated_httpx_client()

assert client is not None
assert "User-Agent" in client.headers
assert "Authorization" in client.headers
assert client.headers["User-Agent"].startswith("Runpod Flash/")
assert client.headers["Authorization"] == "Bearer test-key"


class TestGetAuthenticatedRequestsSession:
"""Test the get_authenticated_requests_session utility function."""
Expand Down Expand Up @@ -123,3 +145,27 @@ def test_get_authenticated_requests_session_is_valid_session(self, monkeypatch):

assert isinstance(session, requests.Session)
session.close()

def test_get_authenticated_requests_session_includes_user_agent(self, monkeypatch):
"""Test session includes User-Agent header."""
monkeypatch.delenv("RUNPOD_API_KEY", raising=False)

session = get_authenticated_requests_session()

assert session is not None
assert "User-Agent" in session.headers
assert session.headers["User-Agent"].startswith("Runpod Flash/")
session.close()

def test_get_authenticated_requests_session_user_agent_with_auth(self, monkeypatch):
"""Test session includes both User-Agent and Authorization headers."""
monkeypatch.setenv("RUNPOD_API_KEY", "test-key")

session = get_authenticated_requests_session()

assert session is not None
assert "User-Agent" in session.headers
assert "Authorization" in session.headers
assert session.headers["User-Agent"].startswith("Runpod Flash/")
assert session.headers["Authorization"] == "Bearer test-key"
session.close()
99 changes: 99 additions & 0 deletions tests/unit/core/utils/test_user_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Tests for user_agent module."""

import platform
import re


def test_get_user_agent_format():
"""Test User-Agent string format matches expected pattern."""
from runpod_flash.core.utils.user_agent import get_user_agent

ua = get_user_agent()

# Should match: "Runpod Flash/<version> (Python <py_version>; <OS> <OS_version>; <arch>)"
pattern = r"^Runpod Flash/[\w\.]+ \(Python [\d\.]+; \w+ [\w\.\-]+; [\w\d_]+\)$"
assert re.match(pattern, ua), f"User-Agent '{ua}' doesn't match expected format"


def test_get_user_agent_contains_version():
"""Test User-Agent includes version information."""
from runpod_flash.core.utils.user_agent import get_user_agent

ua = get_user_agent()

# Should start with "Runpod Flash/"
assert ua.startswith("Runpod Flash/"), (
f"User-Agent should start with 'Runpod Flash/', got: {ua}"
)

# Should contain version (either real version or 'unknown')
version_part = ua.split(" ")[2] # "Runpod Flash/<version> (Python ..."
assert version_part.startswith("(Python"), (
"User-Agent should contain Python version"
)


def test_get_user_agent_contains_python_version():
"""Test User-Agent includes Python version."""
from runpod_flash.core.utils.user_agent import get_user_agent

ua = get_user_agent()
python_version = platform.python_version()

assert f"Python {python_version}" in ua, (
f"User-Agent should contain Python {python_version}"
)


def test_get_user_agent_contains_os():
"""Test User-Agent includes OS name."""
from runpod_flash.core.utils.user_agent import get_user_agent

ua = get_user_agent()
os_name = platform.system()

assert os_name in ua, f"User-Agent should contain OS name {os_name}"


def test_get_user_agent_contains_os_version():
"""Test User-Agent includes OS version."""
from runpod_flash.core.utils.user_agent import get_user_agent

ua = get_user_agent()
os_version = platform.release()

assert os_version in ua, f"User-Agent should contain OS version {os_version}"


def test_get_user_agent_contains_architecture():
"""Test User-Agent includes CPU architecture."""
from runpod_flash.core.utils.user_agent import get_user_agent

ua = get_user_agent()
arch = platform.machine()

assert arch in ua, f"User-Agent should contain architecture {arch}"


def test_get_user_agent_structure():
"""Test User-Agent has correct structure."""
from runpod_flash.core.utils.user_agent import get_user_agent

ua = get_user_agent()

# Should have exactly one opening and closing parenthesis
assert ua.count("(") == 1, "User-Agent should have exactly one opening parenthesis"
assert ua.count(")") == 1, "User-Agent should have exactly one closing parenthesis"

# Should have exactly two semicolons (Python/OS separator, OS/arch separator)
assert ua.count(";") == 2, "User-Agent should have exactly two semicolons"


def test_get_user_agent_consistency():
"""Test User-Agent is consistent across multiple calls."""
from runpod_flash.core.utils.user_agent import get_user_agent

ua1 = get_user_agent()
ua2 = get_user_agent()

assert ua1 == ua2, "User-Agent should be consistent across calls"
Loading