Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 220 additions & 20 deletions src/runloop_api_client/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import asyncio
import inspect
import logging
import weakref
import platform
import warnings
import threading
import email.utils
from types import TracebackType
from random import random
Expand Down Expand Up @@ -90,6 +92,88 @@

log: logging.Logger = logging.getLogger(__name__)

# Shared HTTP transport state. We share transports (connection pools) rather
# than full httpx clients so each SDK instance keeps its own cookie jar and
# mutable client state. Refcounted wrappers close the real transport only
# when the last user releases it.
# The async transport is keyed by event loop because connections bind to the
# loop that created them and cannot be reused across asyncio.run() calls.
_pool_lock = threading.Lock()


class _SharedTransport(httpx.BaseTransport):
"""Refcounted wrapper: delegates to a real transport, closes it when refcount hits 0."""

def __init__(self, transport: httpx.BaseTransport) -> None:
self._transport = transport
self._refcount = 1
self._lock = threading.Lock()

@property
def refcount(self) -> int:
return self._refcount

def acquire(self) -> bool:
with self._lock:
if self._refcount <= 0:
return False
self._refcount += 1
return True

@override
def handle_request(self, request: httpx.Request) -> httpx.Response:
return self._transport.handle_request(request)

@override
def close(self) -> None:
should_close = False
with self._lock:
self._refcount -= 1
if self._refcount <= 0:
should_close = True
if should_close:
self._transport.close()


class _SharedAsyncTransport(httpx.AsyncBaseTransport):
"""Async refcounted wrapper: delegates to a real async transport."""

def __init__(self, transport: httpx.AsyncBaseTransport) -> None:
self._transport = transport
self._refcount = 1
self._lock = threading.Lock()

@property
def refcount(self) -> int:
return self._refcount

def acquire(self) -> bool:
with self._lock:
if self._refcount <= 0:
return False
self._refcount += 1
return True

@override
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
return await self._transport.handle_async_request(request)

@override
async def aclose(self) -> None:
should_close = False
with self._lock:
self._refcount -= 1
if self._refcount <= 0:
should_close = True
if should_close:
await self._transport.aclose()


_shared_sync_transport: _SharedTransport | None = None
_shared_async_transports: weakref.WeakKeyDictionary[asyncio.AbstractEventLoop, _SharedAsyncTransport] = (
weakref.WeakKeyDictionary()
)

# TODO: make base page type vars covariant
SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]")
AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]")
Expand Down Expand Up @@ -816,6 +900,7 @@ def __init__(self, **kwargs: Any) -> None:
kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
kwargs.setdefault("follow_redirects", True)
kwargs.setdefault("http2", True)
super().__init__(**kwargs)


Expand Down Expand Up @@ -845,6 +930,8 @@ def __del__(self) -> None:
class SyncAPIClient(BaseClient[httpx.Client, Stream[Any]]):
_client: httpx.Client
_default_stream_cls: type[Stream[Any]] | None = None
_uses_shared_pool: bool
_closed: bool

def __init__(
self,
Expand All @@ -857,6 +944,7 @@ def __init__(
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
_strict_response_validation: bool,
shared_http_pool: bool = True,
) -> None:
if not is_given(timeout):
# if the user passed in a custom http client with a non-default
Expand Down Expand Up @@ -886,24 +974,46 @@ def __init__(
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
)
self._client = http_client or SyncHttpxClientWrapper(
base_url=base_url,
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=cast(Timeout, timeout),
)

self._closed = False

if http_client is not None:
self._client = http_client
self._uses_shared_pool = False
elif shared_http_pool:
global _shared_sync_transport
with _pool_lock:
if _shared_sync_transport is None or not _shared_sync_transport.acquire():
_shared_sync_transport = _SharedTransport(
httpx.HTTPTransport(limits=DEFAULT_CONNECTION_LIMITS, http2=True),
)
self._client = SyncHttpxClientWrapper(
base_url=base_url,
timeout=cast(Timeout, timeout),
transport=_shared_sync_transport,
)
self._uses_shared_pool = True
else:
self._client = SyncHttpxClientWrapper(
base_url=base_url,
timeout=cast(Timeout, timeout),
)
self._uses_shared_pool = False

def is_closed(self) -> bool:
return self._client.is_closed
return self._closed or self._client.is_closed

def close(self) -> None:
"""Close the underlying HTTPX client.

The client will *not* be usable after this.
"""
# If an error is thrown while constructing a client, self._client
# may not be present
if hasattr(self, "_client"):
self._client.close()
if not hasattr(self, "_client"):
return
if self._closed:
return
self._closed = True
self._client.close()

def __enter__(self: _T) -> _T:
return self
Expand Down Expand Up @@ -1018,6 +1128,7 @@ def request(
max_retries=max_retries,
options=input_options,
response=None,
error=err,
)
continue

Expand All @@ -1032,6 +1143,7 @@ def request(
max_retries=max_retries,
options=input_options,
response=None,
error=err,
)
continue

Expand Down Expand Up @@ -1083,7 +1195,13 @@ def request(
)

def _sleep_for_retry(
self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
self,
*,
retries_taken: int,
max_retries: int,
options: FinalRequestOptions,
response: httpx.Response | None,
error: BaseException | None = None,
) -> None:
remaining_retries = max_retries - retries_taken
if remaining_retries == 1:
Expand All @@ -1092,7 +1210,23 @@ def _sleep_for_retry(
log.debug("%i retries left", remaining_retries)

timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
log.info("Retrying request to %s in %f seconds", options.url, timeout)
if response is not None:
log.info(
"Retrying request to %s in %f seconds (status %d)",
options.url,
timeout,
response.status_code,
)
elif error is not None:
log.info(
"Retrying request to %s in %f seconds (%s: %s)",
options.url,
timeout,
type(error).__name__,
error,
)
else:
log.info("Retrying request to %s in %f seconds", options.url, timeout)

time.sleep(timeout)

Expand Down Expand Up @@ -1428,6 +1562,8 @@ def __del__(self) -> None:
class AsyncAPIClient(BaseClient[httpx.AsyncClient, AsyncStream[Any]]):
_client: httpx.AsyncClient
_default_stream_cls: type[AsyncStream[Any]] | None = None
_uses_shared_pool: bool
_closed: bool

def __init__(
self,
Expand All @@ -1440,6 +1576,7 @@ def __init__(
http_client: httpx.AsyncClient | None = None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
shared_http_pool: bool = True,
) -> None:
if not is_given(timeout):
# if the user passed in a custom http client with a non-default
Expand Down Expand Up @@ -1469,20 +1606,59 @@ def __init__(
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
)
self._client = http_client or AsyncHttpxClientWrapper(
base_url=base_url,
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=cast(Timeout, timeout),
)

self._closed = False

if http_client is not None:
self._client = http_client
self._uses_shared_pool = False
elif shared_http_pool:
try:
loop: asyncio.AbstractEventLoop | None = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None:
with _pool_lock:
existing = _shared_async_transports.get(loop)
if existing is not None and existing.acquire():
transport: _SharedAsyncTransport = existing
else:
transport = _SharedAsyncTransport(
httpx.AsyncHTTPTransport(limits=DEFAULT_CONNECTION_LIMITS, http2=True),
)
_shared_async_transports[loop] = transport
self._client = AsyncHttpxClientWrapper(
base_url=base_url,
timeout=cast(Timeout, timeout),
transport=transport,
)
self._uses_shared_pool = True
else:
self._client = AsyncHttpxClientWrapper(
base_url=base_url,
timeout=cast(Timeout, timeout),
)
self._uses_shared_pool = False
else:
self._client = AsyncHttpxClientWrapper(
base_url=base_url,
timeout=cast(Timeout, timeout),
)
self._uses_shared_pool = False

def is_closed(self) -> bool:
return self._client.is_closed
return self._closed or self._client.is_closed

async def close(self) -> None:
"""Close the underlying HTTPX client.

The client will *not* be usable after this.
"""
if not hasattr(self, "_client"):
return
if self._closed:
return
self._closed = True
await self._client.aclose()

async def __aenter__(self: _T) -> _T:
Expand Down Expand Up @@ -1603,6 +1779,7 @@ async def request(
max_retries=max_retries,
options=input_options,
response=None,
error=err,
)
continue

Expand All @@ -1617,6 +1794,7 @@ async def request(
max_retries=max_retries,
options=input_options,
response=None,
error=err,
)
continue

Expand Down Expand Up @@ -1668,7 +1846,13 @@ async def request(
)

async def _sleep_for_retry(
self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
self,
*,
retries_taken: int,
max_retries: int,
options: FinalRequestOptions,
response: httpx.Response | None,
error: BaseException | None = None,
) -> None:
remaining_retries = max_retries - retries_taken
if remaining_retries == 1:
Expand All @@ -1677,7 +1861,23 @@ async def _sleep_for_retry(
log.debug("%i retries left", remaining_retries)

timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
log.info("Retrying request to %s in %f seconds", options.url, timeout)
if response is not None:
log.info(
"Retrying request to %s in %f seconds (status %d)",
options.url,
timeout,
response.status_code,
)
elif error is not None:
log.info(
"Retrying request to %s in %f seconds (%s: %s)",
options.url,
timeout,
type(error).__name__,
error,
)
else:
log.info("Retrying request to %s in %f seconds", options.url, timeout)

await anyio.sleep(timeout)

Expand Down
Loading
Loading