-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Flama client and Lifespan refactor (#107)
- Loading branch information
Showing
40 changed files
with
1,433 additions
and
768 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
import asyncio | ||
import contextlib | ||
import functools | ||
import importlib.metadata | ||
import logging | ||
import typing as t | ||
from types import TracebackType | ||
|
||
import httpx | ||
|
||
from flama import types | ||
from flama.applications import Flama | ||
|
||
__all__ = ["Client", "AsyncClient", "LifespanContextManager"] | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class LifespanContextManager:
    """Drives the ASGI lifespan cycle of a local Flama application.

    Entering the context schedules the app's lifespan handler as a background
    task, sends the ``lifespan.startup`` event and waits (up to ``timeout``
    seconds) for ``lifespan.startup.complete``. Exiting sends
    ``lifespan.shutdown``, waits for its completion, then cancels and awaits
    the background task.

    :param app: Application whose lifespan will be run.
    :param timeout: Maximum seconds to wait for startup/shutdown completion.
    """

    def __init__(self, app: Flama, timeout: float = 60.0):
        self.app = app
        self.timeout = timeout
        self._startup_complete = asyncio.Event()
        self._shutdown_complete = asyncio.Event()
        # At most one startup and one shutdown message are ever queued.
        self._receive_queue: "asyncio.Queue[types.Message]" = asyncio.Queue(maxsize=2)
        self._exception: t.Optional[BaseException] = None
        self._task: t.Optional[asyncio.Task] = None

    async def _startup(self) -> None:
        """Send the startup event and wait for the app to acknowledge it.

        :raises BaseException: Re-raises any exception the app raised during
            startup (recorded by :meth:`_app_task`).
        :raises asyncio.TimeoutError: If startup does not complete in time.
        """
        await self._receive_queue.put(types.Message({"type": "lifespan.startup"}))
        await asyncio.wait_for(self._startup_complete.wait(), timeout=self.timeout)
        if self._exception:
            raise self._exception

    async def _shutdown(self) -> None:
        """Send the shutdown event and wait for the app to acknowledge it.

        :raises asyncio.TimeoutError: If shutdown does not complete in time.
        """
        await self._receive_queue.put(types.Message({"type": "lifespan.shutdown"}))
        await asyncio.wait_for(self._shutdown_complete.wait(), timeout=self.timeout)

    async def _receive(self) -> types.Message:
        """ASGI ``receive`` callable handed to the app."""
        return await self._receive_queue.get()

    async def _send(self, message: types.Message) -> None:
        """ASGI ``send`` callable; flags startup/shutdown completion."""
        if message["type"] == "lifespan.startup.complete":
            self._startup_complete.set()
        elif message["type"] == "lifespan.shutdown.complete":
            self._shutdown_complete.set()

    async def _app_task(self) -> None:
        """Run the app's lifespan handler, recording any failure.

        On error, both completion events are set so waiters in
        :meth:`_startup` / :meth:`_shutdown` are released immediately instead
        of blocking until timeout; the exception is stored for re-raising.
        """
        with contextlib.suppress(asyncio.CancelledError):
            scope = types.Scope({"type": "lifespan"})

            try:
                await self.app(scope, self._receive, self._send)
            except BaseException as exc:
                self._exception = exc
                self._startup_complete.set()
                self._shutdown_complete.set()

                raise

    def _run_app(self) -> None:
        """Schedule the lifespan task on the current event loop.

        Uses ``get_running_loop()``: this is only called from ``__aenter__``,
        where a loop is guaranteed to be running (``get_event_loop()`` is
        deprecated for this use).
        """
        self._task = asyncio.get_running_loop().create_task(self._app_task())

    async def _stop_app(self) -> None:
        """Cancel the lifespan task (if still running) and wait for it.

        Awaiting the task re-raises any exception it ended with.
        """
        assert self._task is not None

        if not self._task.done():
            self._task.cancel()

        await self._task

    async def __aenter__(self) -> "LifespanContextManager":
        self._run_app()

        try:
            await self._startup()
        except BaseException:
            # Startup failed or timed out: tear the task down before
            # propagating, so no orphan task outlives the context.
            await self._stop_app()
            raise

        return self

    async def __aexit__(
        self,
        exc_type: t.Optional[t.Type[BaseException]] = None,
        exc_value: t.Optional[BaseException] = None,
        traceback: t.Optional[TracebackType] = None,
    ):
        await self._shutdown()
        await self._stop_app()
|
||
|
||
class _BaseClient:
    """Shared initialization for the sync and async Flama clients.

    Optionally builds a local :class:`Flama` app from a set of flm model
    files, wires up a :class:`LifespanContextManager` for any local app, and
    forwards the remaining keyword arguments to the httpx client base class.

    :param app: Local application to serve requests against, if any.
    :param models: Sequence of ``(name, url, path)`` tuples of flm model
        files to mount on the (possibly freshly created) local app.
    :param kwargs: Passed through to the httpx client constructor.
    """

    def __init__(
        self,
        /,
        app: t.Optional[Flama] = None,
        models: t.Optional[t.Sequence[t.Tuple[str, str, str]]] = None,
        **kwargs,
    ):
        # Maps model name -> mount url; stays None when no models are given.
        self.models: t.Optional[t.Dict[str, str]] = None

        if models:
            app = app or Flama()
            registry: t.Dict[str, str] = {}
            for name, url, path in models:
                app.models.add_model(url, path, name)
                registry[name] = url

            self.models = registry

        # Only a local app needs its lifespan driven by the client.
        self.lifespan = LifespanContextManager(app) if app else None

        kwargs["app"] = app
        kwargs.setdefault("base_url", "http://localapp")
        kwargs["headers"] = {"user-agent": f"flama/{importlib.metadata.version('flama')}", **kwargs.get("headers", {})}

        super().__init__(**kwargs)
|
||
|
||
class Client(_BaseClient, httpx.Client):
    """A client for interacting with a Flama application either remote or local.
    This client can handle a local python object:
    >>> client = Client(app=Flama())
    Or connect to a remote API:
    >>> client = Client(base_url="https://foo.bar")
    Or generate a Flama application based on a set of flm model files:
    >>> client = Client(models=[("foo", "/foo/", "model_foo.flm"), ("bar", "/bar/", "model_bar.flm")])
    For initializing the application it's required to use it as a context manager:
    >>> with Client(app=Flama()) as client:
    >>> client.post(...)
    """

    def __enter__(self) -> "Client":
        super().__enter__()
        if self.lifespan:
            # Run the local app's lifespan startup before any request is made.
            # NOTE(review): get_event_loop() is relied on to return the SAME
            # loop here and in __exit__ (the lifespan task is bound to it);
            # the API is deprecated without a running loop on newer Pythons —
            # confirm supported Python versions.
            asyncio.get_event_loop().run_until_complete(self.lifespan.__aenter__())

        return self

    def __exit__(
        self,
        exc_type: t.Optional[t.Type[BaseException]] = None,
        exc_value: t.Optional[BaseException] = None,
        traceback: t.Optional[TracebackType] = None,
    ):
        # Shut the local app down (if any) before closing the httpx client.
        if self.lifespan:
            asyncio.get_event_loop().run_until_complete(self.lifespan.__aexit__(exc_type, exc_value, traceback))
        super().__exit__(exc_type, exc_value, traceback)

    def model_request(self, model: str, method: str, url: str, **kwargs) -> httpx.Response:
        """Send a request to the route where `model` is mounted, under the
        url registered for it in ``self.models``."""
        assert self.models, "No models found for request."
        return self.request(method, f"{self.models[model].rstrip('/')}{url}", **kwargs)

    # Shorthand endpoints for a mounted model's inspect and predict routes.
    model_inspect = functools.partialmethod(model_request, method="GET", url="/")
    model_predict = functools.partialmethod(model_request, method="POST", url="/predict/")
|
||
|
||
class AsyncClient(_BaseClient, httpx.AsyncClient):
    """An async client for interacting with a Flama application either remote or local.
    This client can handle a local python object:
    >>> client = AsyncClient(app=Flama())
    Or connect to a remote API:
    >>> client = AsyncClient(base_url="https://foo.bar")
    Or generate a Flama application based on a set of flm model files:
    >>> client = AsyncClient(models=[("foo", "/foo/", "model_foo.flm"), ("bar", "/bar/", "model_bar.flm")])
    For initializing the application it's required to use it as an async context manager:
    >>> async with AsyncClient(app=Flama()) as client:
    >>> client.post(...)
    """

    async def __aenter__(self) -> "AsyncClient":
        await super().__aenter__()
        if self.lifespan:
            # Run the local app's lifespan startup before serving requests.
            await self.lifespan.__aenter__()

        return self

    async def __aexit__(
        self,
        exc_type: t.Optional[t.Type[BaseException]] = None,
        exc_value: t.Optional[BaseException] = None,
        traceback: t.Optional[TracebackType] = None,
    ):
        # Shut the local app down (if any) before closing the httpx client.
        if self.lifespan:
            await self.lifespan.__aexit__(exc_type, exc_value, traceback)
        await super().__aexit__(exc_type, exc_value, traceback)

    async def model_request(self, model: str, method: str, url: str, **kwargs) -> httpx.Response:
        """Send a request to the route where `model` is mounted.

        Fix: the request is now awaited, so a single ``await`` yields an
        ``httpx.Response``; previously the un-awaited coroutine from
        ``self.request`` was returned and never executed.
        """
        assert self.models, "No models found for request."
        return await self.request(method, f"{self.models[model].rstrip('/')}{url}", **kwargs)

    # Shorthand endpoints for a mounted model's inspect and predict routes.
    model_inspect = functools.partialmethod(model_request, method="GET", url="/")
    model_predict = functools.partialmethod(model_request, method="POST", url="/predict/")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.