Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add async support #76

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Virtualenv
.venv

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
8 changes: 8 additions & 0 deletions replicate/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,11 @@ def reload(self):
new_model = self._collection.get(self.id)
for k, v in new_model.dict().items():
setattr(self, k, v)

async def reload_async(self):
    """
    Refresh this object with the latest state from the server.

    Async counterpart of reload(): fetches the record by id via the
    owning collection and copies every field onto this instance.
    """
    refreshed = await self._collection.get_async(self.id)
    for field, value in refreshed.dict().items():
        setattr(self, field, value)
30 changes: 30 additions & 0 deletions replicate/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from json import JSONDecodeError

import httpx
import requests
from requests.adapters import HTTPAdapter, Retry

Expand All @@ -21,6 +22,9 @@ def __init__(self, api_token=None) -> None:
)
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))

max_retries: int = 5
self.httpx_transport = httpx.AsyncHTTPTransport(retries=max_retries)

# TODO: make thread safe
self.read_session = requests.Session()
read_retries = Retry(
Expand Down Expand Up @@ -78,6 +82,32 @@ def _request(self, method: str, path: str, **kwargs):
raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason}")
return resp

async def _request_async(self, method: str, path: str, **kwargs):
    """
    Issue an async HTTP request against the Replicate API.

    Args:
        method: HTTP verb, e.g. "GET" or "POST".
        path: API path appended to self.base_url.
        **kwargs: forwarded to httpx; an ``allow_redirects`` key is
            accepted for parity with the sync requests-based _request.

    Returns:
        httpx.Response: the successful (status < 400) response.

    Raises:
        ReplicateError: on any 4xx/5xx response, carrying the server's
            "detail" message when the body is JSON, else a generic one.
    """
    # Mirror requests.Session defaults: redirects are followed for every
    # method except HEAD. An explicit allow_redirects kwarg (requests
    # spelling) overrides; it is translated to httpx's follow_redirects
    # instead of being silently discarded as before.
    follow_redirects = kwargs.pop("allow_redirects", method != "HEAD")
    kwargs.setdefault("headers", {})
    kwargs["headers"].update(self._headers())

    async with httpx.AsyncClient(
        follow_redirects=follow_redirects,
        transport=self.httpx_transport,
    ) as client:
        resp = await client.request(method, self.base_url + path, **kwargs)

    if 400 <= resp.status_code < 600:
        try:
            raise ReplicateError(resp.json()["detail"])
        except (JSONDecodeError, KeyError):
            pass
        # httpx spells it reason_phrase; `resp.reason` raised
        # AttributeError here, masking the real HTTP error.
        raise ReplicateError(f"HTTP error: {resp.status_code, resp.reason_phrase}")
    return resp

def _headers(self):
return {
"Authorization": f"Token {self._api_token()}",
Expand Down
9 changes: 9 additions & 0 deletions replicate/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ def get(self, key):
def create(self, attrs=None):
raise NotImplementedError

async def list_async(self):
    """Async counterpart of list(); concrete collections must override."""
    raise NotImplementedError

async def get_async(self, key):
    """Async counterpart of get(); concrete collections must override."""
    raise NotImplementedError

async def create_async(self, attrs=None):
    """Async counterpart of create(); concrete collections must override."""
    raise NotImplementedError

def prepare_model(self, attrs):
"""
Create a model from a set of attributes.
Expand Down
67 changes: 59 additions & 8 deletions replicate/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,16 @@
import mimetypes
import os

import httpx
import requests


def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str:
def to_data_url(fh: io.IOBase) -> str:
"""
Lifted straight from cog.files
"""
fh.seek(0)

if output_file_prefix is not None:
name = getattr(fh, "name", "output")
url = output_file_prefix + os.path.basename(name)
resp = requests.put(url, files={"file": fh})
resp.raise_for_status()
return url

b = fh.read()
# The file handle is strings, not bytes
if isinstance(b, str):
Expand All @@ -31,3 +25,60 @@ def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str:
mime_type = "application/octet-stream"
s = encoded_body.decode("utf-8")
return f"data:{mime_type};base64,{s}"


def upload_file_to_server(fh: io.IOBase, output_file_prefix: str) -> str:
    """
    PUT the file to output_file_prefix + the handle's basename and return
    the resulting URL.

    Lifted straight from cog.files
    """
    fh.seek(0)

    filename = os.path.basename(getattr(fh, "name", "output"))
    destination = output_file_prefix + filename
    response = requests.put(destination, files={"file": fh})
    response.raise_for_status()
    return destination


def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str:
    """
    Serialize a file handle for the API: upload it to the server when an
    output_file_prefix is configured, otherwise inline it as a data URL.

    Lifted straight from cog.files
    """
    fh.seek(0)

    if output_file_prefix is None:
        return to_data_url(fh)
    return upload_file_to_server(fh, output_file_prefix)


async def upload_file_to_server_async(fh: io.IOBase, output_file_prefix: str) -> str:
    """
    Async variant of upload_file_to_server: PUT the file to
    output_file_prefix + the handle's basename and return that URL.

    Lifted straight from cog.files

    Raises:
        httpx.HTTPStatusError: if the server responds with a 4xx/5xx.
    """
    fh.seek(0)

    name = getattr(fh, "name", "output")
    url = output_file_prefix + os.path.basename(name)

    # httpx does not follow redirects by default
    async with httpx.AsyncClient(follow_redirects=True) as client:
        resp = await client.put(url, files={"file": fh})
    # Match the sync upload_file_to_server: surface HTTP errors rather
    # than silently returning the URL of a failed upload.
    resp.raise_for_status()

    return url


async def upload_file_async(fh: io.IOBase, output_file_prefix: str = None) -> str:
    """
    Async counterpart of upload_file: upload to the server when an
    output_file_prefix is configured, otherwise inline as a data URL.

    Lifted straight from cog.files
    """
    fh.seek(0)

    if output_file_prefix is None:
        return to_data_url(fh)
    return await upload_file_to_server_async(fh, output_file_prefix)
73 changes: 73 additions & 0 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import time
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

Expand Down Expand Up @@ -27,6 +28,34 @@ def wait(self):
time.sleep(self._client.poll_interval)
self.reload()

async def wait_async(self):
    """Poll the API until this prediction reaches a terminal state
    (succeeded, failed, or canceled)."""
    while self.status not in ["succeeded", "failed", "canceled"]:
        # Match the synchronous wait(): honor the client's configured
        # poll interval instead of a hard-coded 0.5s.
        await asyncio.sleep(self._client.poll_interval)
        await self.reload_async()

async def output_iterator_async(self) -> AsyncIterator[Any]:
    """
    Yield output items incrementally while the prediction runs, then the
    remainder once it finishes.

    Raises:
        ModelError: if the prediction ends in the "failed" state.
    """
    # TODO: check output is list
    previous_output = self.output or []
    while self.status not in ["succeeded", "failed", "canceled"]:
        output = self.output or []
        # Use a distinct loop variable: the previous code reused
        # `output`, so after the inner loop `previous_output` was set to
        # the *last yielded item* instead of the full list, corrupting
        # len(previous_output) for the next slice.
        for item in output[len(previous_output):]:
            yield item
        previous_output = output

        # Match wait(): honor the client's configured poll interval.
        await asyncio.sleep(self._client.poll_interval)
        await self.reload_async()

    if self.status == "failed":
        raise ModelError(self.error)

    output = self.output or []
    for item in output[len(previous_output):]:
        yield item


def output_iterator(self) -> Iterator[Any]:
# TODO: check output is list
previous_output = self.output or []
Expand All @@ -51,6 +80,10 @@ def cancel(self):
"""Cancel a currently running prediction"""
self._client._request("POST", f"/v1/predictions/{self.id}/cancel")

async def cancel_async(self):
    """Cancel a currently running prediction (async POST to its cancel endpoint)."""
    await self._client._request_async("POST", f"/v1/predictions/{self.id}/cancel")


class PredictionCollection(Collection):
model = Prediction
Expand Down Expand Up @@ -93,3 +126,43 @@ def list(self) -> List[Prediction]:
# HACK: resolve this? make it lazy somehow?
del prediction["version"]
return [self.prepare_model(obj) for obj in predictions]

async def create_async(
    self,
    version: Version,
    input: Dict[str, Any],
    webhook_completed: Optional[str] = None,
) -> Prediction:
    """
    Create a prediction for the given model version (async).

    Args:
        version: the model Version to run.
        input: prediction inputs; file handles are serialized via
            encode_json/upload_file.
        webhook_completed: optional URL the server notifies when the
            prediction finishes.

    Returns:
        Prediction: the newly created prediction, with ``version``
        replaced by the Version object itself.
    """
    # NOTE(review): encode_json uses the *blocking* upload_file here, so
    # file uploads stall the event loop — consider an async encode path
    # using upload_file_async from replicate.files.
    input = encode_json(input, upload_file=upload_file)
    body = {
        "version": version.id,
        "input": input,
    }
    if webhook_completed is not None:
        body["webhook_completed"] = webhook_completed

    resp = await self._client._request_async(
        "POST",
        "/v1/predictions",
        json=body,
    )

    obj = resp.json()
    obj["version"] = version
    return self.prepare_model(obj)

async def get_async(self, id: str) -> Prediction:
    """Fetch a single prediction by id (async)."""
    response = await self._client._request_async("GET", f"/v1/predictions/{id}")
    attrs = response.json()
    # HACK: resolve this? make it lazy somehow?
    del attrs["version"]
    return self.prepare_model(attrs)

async def list_async(self) -> List[Prediction]:
    """
    List recent predictions (async).

    Returns:
        List[Prediction]: the first page of results; pagination is not
        yet implemented.
    """
    # f-string prefix removed: the path has no placeholders (F541).
    resp = await self._client._request_async("GET", "/v1/predictions")
    # TODO: paginate
    predictions = resp.json()["results"]
    for prediction in predictions:
        # HACK: resolve this? make it lazy somehow?
        del prediction["version"]
    return [self.prepare_model(obj) for obj in predictions]
38 changes: 38 additions & 0 deletions replicate/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,25 @@ def predict(self, **kwargs) -> Union[Any, Iterator[Any]]:
raise ModelError(prediction.error)
return prediction.output


async def predict_async(self, **kwargs) -> Union[Any, Iterator[Any]]:
    """
    Run this model version asynchronously and return its output.

    Raises:
        ModelError: if the prediction finishes in the "failed" state.
    """
    # TODO: support args
    prediction = await self._client.predictions.create_async(version=self, input=kwargs)

    # FIXME: might just be a list, not an iterator. I wonder if we should differentiate?
    output_schema = self.get_transformed_schema()["components"]["schemas"]["Output"]
    is_cog_iterator = (
        output_schema.get("type") == "array"
        and output_schema.get("x-cog-array-type") == "iterator"
    )
    if is_cog_iterator:
        # Hand back an async iterator of incremental results.
        return prediction.output_iterator_async()

    await prediction.wait_async()
    if prediction.status == "failed":
        raise ModelError(prediction.error)
    return prediction.output

def get_transformed_schema(self):
schema = self.openapi_schema
schema = make_schema_backwards_compatible(schema, self.cog_version)
Expand All @@ -44,6 +63,25 @@ def __init__(self, client, model):
super().__init__(client=client)
self._model = model

# doesn't exist yet
async def get_async(self, id: str) -> Version:
    """
    Get a specific version.
    """
    endpoint = f"/v1/models/{self._model.username}/{self._model.name}/versions/{id}"
    response = await self._client._request_async("GET", endpoint)
    return self.prepare_model(response.json())

async def list_async(self) -> List[Version]:
    """
    Return a list of all versions for a model.
    """
    endpoint = f"/v1/models/{self._model.username}/{self._model.name}/versions"
    response = await self._client._request_async("GET", endpoint)
    versions = response.json()["results"]
    return [self.prepare_model(item) for item in versions]

# doesn't exist yet
def get(self, id: str) -> Version:
"""
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
packaging==21.3
pytest==7.1.2
pytest-asyncio==0.21.0
responses==0.21.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@
license="BSD",
url="https://github.com/replicate/replicate-python",
python_requires=">=3.6",
install_requires=["requests", "pydantic", "packaging"],
install_requires=["requests", "pydantic", "packaging", "httpx"],
classifiers=[],
)
20 changes: 20 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pytest

import replicate

@pytest.mark.asyncio
async def test_async_client():
    # NOTE(review): this is a live integration test — it hits the real
    # Replicate API (requires network access and an API token), so it
    # cannot run hermetically in CI; consider mocking the HTTP layer.
    model = replicate.models.get("creatorrr/instructor-large")
    version = await model.versions.get_async("bd2701dac1aea9d598bda71e6ae56b204287c0a79e2cadf96b1393127d044495")

    inputs = {
        # Text to embed
        'text': "Hello world! How are you doing?",

        # Embedding instruction
        'instruction': "Represent the following text",
    }

    output = await version.predict_async(**inputs)

    assert output["result"]