Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace requests with httpx #147

Merged
merged 9 commits into from Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitattributes
@@ -0,0 +1 @@
tests/cassettes/** binary
6 changes: 4 additions & 2 deletions pyproject.toml
Expand Up @@ -10,12 +10,14 @@ readme = "README.md"
license = { file = "LICENSE" }
authors = [{ name = "Replicate, Inc." }]
requires-python = ">=3.8"
dependencies = ["packaging", "pydantic>1", "requests>2"]
dependencies = ["packaging", "pydantic>1", "httpx>=0.21.0,<1"]
optional-dependencies = { dev = [
"black",
"mypy",
"pytest",
"responses",
"pytest-asyncio",
"pytest-recording",
"respx",
"ruff",
] }

Expand Down
279 changes: 180 additions & 99 deletions replicate/client.py
@@ -1,106 +1,77 @@
import os
import random
import re
from json import JSONDecodeError
from typing import Any, Dict, Iterator, Optional, Union

import requests
from requests.adapters import HTTPAdapter, Retry
from requests.cookies import RequestsCookieJar

from replicate.__about__ import __version__
from replicate.deployment import DeploymentCollection
from replicate.exceptions import ModelError, ReplicateError
from replicate.model import ModelCollection
from replicate.prediction import PredictionCollection
from replicate.training import TrainingCollection
import time
from datetime import datetime
from typing import (
Any,
Iterable,
Iterator,
Mapping,
Optional,
Union,
)

import httpx

from .__about__ import __version__
from .deployment import DeploymentCollection
from .exceptions import ModelError, ReplicateError
from .model import ModelCollection
from .prediction import PredictionCollection
from .training import TrainingCollection


class Client:
def __init__(self, api_token: Optional[str] = None) -> None:
"""A Replicate API client library"""

def __init__(
self,
api_token: Optional[str] = None,
*,
base_url: Optional[str] = None,
timeout: Optional[httpx.Timeout] = None,
**kwargs,
) -> None:
super().__init__()
# Client is instantiated at import time, so do as little as possible.
# This includes resolving environment variables -- they might be set programmatically.
self.api_token = api_token
self.base_url = os.environ.get(

api_token = api_token or os.environ.get("REPLICATE_API_TOKEN")

base_url = base_url or os.environ.get(
"REPLICATE_API_BASE_URL", "https://api.replicate.com"
)
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))

# TODO: make thread safe
self.read_session = _create_session()
read_retries = Retry(
total=5,
backoff_factor=2,
# Only retry 500s on GET so we don't unintionally mutute data
allowed_methods=["GET"],
# https://support.cloudflare.com/hc/en-us/articles/115003011431-Troubleshooting-Cloudflare-5XX-errors
status_forcelist=[
429,
500,
502,
503,
504,
520,
521,
522,
523,
524,
526,
527,
],
)
self.read_session.mount("http://", HTTPAdapter(max_retries=read_retries))
self.read_session.mount("https://", HTTPAdapter(max_retries=read_retries))

self.write_session = _create_session()
write_retries = Retry(
total=5,
backoff_factor=2,
allowed_methods=["POST", "PUT"],
# Only retry POST/PUT requests on rate limits, so we don't unintionally mutute data
status_forcelist=[429],
timeout = timeout or httpx.Timeout(
5.0, read=30.0, write=30.0, connect=5.0, pool=10.0
)
self.write_session.mount("http://", HTTPAdapter(max_retries=write_retries))
self.write_session.mount("https://", HTTPAdapter(max_retries=write_retries))

def _request(self, method: str, path: str, **kwargs) -> requests.Response:
# from requests.Session
if method in ["GET", "OPTIONS"]:
kwargs.setdefault("allow_redirects", True)
if method in ["HEAD"]:
kwargs.setdefault("allow_redirects", False)
kwargs.setdefault("headers", {})
kwargs["headers"].update(self._headers())
session = self.read_session
if method in ["POST", "PUT", "DELETE", "PATCH"]:
session = self.write_session
resp = session.request(method, self.base_url + path, **kwargs)
if 400 <= resp.status_code < 600:
try:
raise ReplicateError(resp.json()["detail"])
except (JSONDecodeError, KeyError):
pass
raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}")
return resp

def _headers(self) -> Dict[str, str]:
return {
"Authorization": f"Token {self._api_token()}",
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))

headers = {
"Authorization": f"Token {api_token}",
"User-Agent": f"replicate-python/{__version__}",
}

def _api_token(self) -> str:
token = self.api_token
# Evaluate lazily in case environment variable is set with dotenv, or something
if token is None:
token = os.environ.get("REPLICATE_API_TOKEN")
if not token:
raise ReplicateError(
"""No API token provided. You need to set the REPLICATE_API_TOKEN environment variable or create a client with `replicate.Client(api_token=...)`.
transport = kwargs.pop("transport", httpx.HTTPTransport())

You can find your API key on https://replicate.com"""
)
return token
self._client = self._build_client(
**kwargs,
base_url=base_url,
headers=headers,
timeout=timeout,
transport=RetryTransport(wrapped_transport=transport),
)

def _build_client(self, **kwargs) -> httpx.Client:
return httpx.Client(**kwargs)

def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
resp = self._client.request(method, path, **kwargs)

if 400 <= resp.status_code < 600:
raise ReplicateError(resp.json()["detail"])

return resp

@property
def models(self) -> ModelCollection:
Expand Down Expand Up @@ -152,19 +123,129 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:
return prediction.output


class _NonpersistentCookieJar(RequestsCookieJar):
"""
A cookie jar that doesn't persist cookies between requests.
# Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
class RetryTransport(httpx.AsyncBaseTransport, httpx.BaseTransport):
"""A custom HTTP transport that automatically retries requests using an exponential backoff strategy
for specific HTTP status codes and request methods.
"""

def set(self, name, value, **kwargs) -> None:
return
RETRYABLE_METHODS = frozenset(["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"])
RETRYABLE_STATUS_CODES = frozenset(
[
429, # Too Many Requests
503, # Service Unavailable
504, # Gateway Timeout
]
)
MAX_BACKOFF_WAIT = 60

def __init__(
self,
wrapped_transport: Union[httpx.BaseTransport, httpx.AsyncBaseTransport],
max_attempts: int = 10,
max_backoff_wait: float = MAX_BACKOFF_WAIT,
backoff_factor: float = 0.1,
jitter_ratio: float = 0.1,
retryable_methods: Optional[Iterable[str]] = None,
retry_status_codes: Optional[Iterable[int]] = None,
) -> None:
self._wrapped_transport = wrapped_transport

if jitter_ratio < 0 or jitter_ratio > 0.5:
raise ValueError(
f"jitter ratio should be between 0 and 0.5, actual {jitter_ratio}"
)

self.max_attempts = max_attempts
self.backoff_factor = backoff_factor
self.retryable_methods = (
frozenset(retryable_methods)
if retryable_methods
else self.RETRYABLE_METHODS
)
self.retry_status_codes = (
frozenset(retry_status_codes)
if retry_status_codes
else self.RETRYABLE_STATUS_CODES
)
self.jitter_ratio = jitter_ratio
self.max_backoff_wait = max_backoff_wait

def _calculate_sleep(
self, attempts_made: int, headers: Union[httpx.Headers, Mapping[str, str]]
) -> float:
retry_after_header = (headers.get("Retry-After") or "").strip()
if retry_after_header:
if retry_after_header.isdigit():
return float(retry_after_header)

try:
parsed_date = datetime.fromisoformat(retry_after_header).astimezone()
diff = (parsed_date - datetime.now().astimezone()).total_seconds()
if diff > 0:
return min(diff, self.max_backoff_wait)
except ValueError:
pass

backoff = self.backoff_factor * (2 ** (attempts_made - 1))
jitter = (backoff * self.jitter_ratio) * random.choice([1, -1]) # noqa: S311
total_backoff = backoff + jitter
return min(total_backoff, self.max_backoff_wait)

def handle_request(self, request: httpx.Request) -> httpx.Response:
response = self._wrapped_transport.handle_request(request) # type: ignore

if request.method not in self.retryable_methods:
return response

remaining_attempts = self.max_attempts - 1
attempts_made = 1

while True:
if (
remaining_attempts < 1
or response.status_code not in self.retry_status_codes
):
return response

response.close()

sleep_for = self._calculate_sleep(attempts_made, response.headers)
time.sleep(sleep_for)

response = self._wrapped_transport.handle_request(request) # type: ignore

attempts_made += 1
remaining_attempts -= 1

async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
response = await self._wrapped_transport.handle_async_request(request) # type: ignore

if request.method not in self.retryable_methods:
return response

remaining_attempts = self.max_attempts - 1
attempts_made = 1

while True:
if (
remaining_attempts < 1
or response.status_code not in self.retry_status_codes
):
return response

response.close()

sleep_for = self._calculate_sleep(attempts_made, response.headers)
time.sleep(sleep_for)

response = await self._wrapped_transport.handle_async_request(request) # type: ignore

def set_cookie(self, cookie, *args, **kwargs) -> None:
return
attempts_made += 1
remaining_attempts -= 1

async def aclose(self) -> None:
await self._wrapped_transport.aclose() # type: ignore

def _create_session() -> requests.Session:
s = requests.Session()
s.cookies = _NonpersistentCookieJar()
return s
def close(self) -> None:
self._wrapped_transport.close() # type: ignore
4 changes: 2 additions & 2 deletions replicate/files.py
Expand Up @@ -4,7 +4,7 @@
import os
from typing import Optional

import requests
import httpx


def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
Expand All @@ -24,7 +24,7 @@ def upload_file(fh: io.IOBase, output_file_prefix: Optional[str] = None) -> str:
if output_file_prefix is not None:
name = getattr(fh, "name", "output")
url = output_file_prefix + os.path.basename(name)
resp = requests.put(url, files={"file": fh}, timeout=None)
resp = httpx.put(url, files={"file": fh}, timeout=None) # type: ignore
resp.raise_for_status()
return url

Expand Down