Skip to content

Commit

Permalink
✨ Flama client and Lifespan refactor (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Sep 19, 2023
1 parent 1110d01 commit 4c00cc6
Show file tree
Hide file tree
Showing 40 changed files with 1,433 additions and 768 deletions.
6 changes: 5 additions & 1 deletion flama/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def __init__(
modules: t.Optional[t.Set["Module"]] = None,
middleware: t.Optional[t.Sequence["Middleware"]] = None,
debug: bool = False,
events: t.Optional[t.Union[t.Dict[str, t.List[t.Callable]], Events]] = None,
events: t.Optional[
t.Union[t.Dict[str, t.List[t.Callable[..., t.Coroutine[t.Any, t.Any, None]]]], Events]
] = None,
lifespan: t.Optional[t.Callable[[t.Optional["Flama"]], t.AsyncContextManager]] = None,
title: str = "Flama",
version: str = "0.1.0",
Expand All @@ -53,6 +55,8 @@ def __init__(
:param schema_library: Schema library to use.
"""
self._debug = debug
self._status = types.AppStatus.NOT_INITIALIZED
self._shutdown = False

# Create Dependency Injector
self._injector = injection.Injector(
Expand Down
200 changes: 200 additions & 0 deletions flama/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import asyncio
import contextlib
import functools
import importlib.metadata
import logging
import typing as t
from types import TracebackType

import httpx

from flama import types
from flama.applications import Flama

__all__ = ["Client", "AsyncClient", "LifespanContextManager"]

logger = logging.getLogger(__name__)


class LifespanContextManager:
    """Async context manager that drives an ASGI app's lifespan protocol.

    On enter, the app is run as a background task, a ``lifespan.startup``
    message is sent, and we wait (up to ``timeout`` seconds) for the app to
    report ``lifespan.startup.complete``. On exit, ``lifespan.shutdown`` is
    sent, completion awaited, and the background task cancelled.

    :param app: Application whose lifespan will be driven.
    :param timeout: Seconds to wait for each of startup and shutdown.
    """

    def __init__(self, app: Flama, timeout: float = 60.0):
        self.app = app
        self.timeout = timeout
        self._startup_complete = asyncio.Event()
        self._shutdown_complete = asyncio.Event()
        # At most one startup and one shutdown message are ever queued.
        self._receive_queue: "asyncio.Queue[types.Message]" = asyncio.Queue(maxsize=2)
        self._exception: t.Optional[BaseException] = None
        self._task: t.Optional[asyncio.Task] = None

    async def _startup(self) -> None:
        """Send the startup message and wait for the app to acknowledge it.

        :raises BaseException: Whatever the app raised while starting up.
        """
        await self._receive_queue.put(types.Message({"type": "lifespan.startup"}))
        await asyncio.wait_for(self._startup_complete.wait(), timeout=self.timeout)
        if self._exception:
            raise self._exception

    async def _shutdown(self) -> None:
        """Send the shutdown message and wait for the app to acknowledge it."""
        await self._receive_queue.put(types.Message({"type": "lifespan.shutdown"}))
        await asyncio.wait_for(self._shutdown_complete.wait(), timeout=self.timeout)

    async def _receive(self) -> types.Message:
        """ASGI ``receive`` callable handed to the app."""
        return await self._receive_queue.get()

    async def _send(self, message: types.Message) -> None:
        """ASGI ``send`` callable handed to the app; records lifecycle milestones."""
        if message["type"] == "lifespan.startup.complete":
            self._startup_complete.set()
        elif message["type"] == "lifespan.shutdown.complete":
            self._shutdown_complete.set()

    async def _app_task(self) -> None:
        """Run the app's lifespan cycle, capturing any failure for ``_startup``."""
        with contextlib.suppress(asyncio.CancelledError):
            scope = types.Scope({"type": "lifespan"})

            try:
                await self.app(scope, self._receive, self._send)
            except BaseException as exc:
                # Remember the failure and unblock both waiters so the context
                # manager fails fast instead of hanging until the timeout.
                self._exception = exc
                self._startup_complete.set()
                self._shutdown_complete.set()

                raise

    def _run_app(self) -> None:
        """Schedule the app task on the running event loop.

        Uses ``asyncio.create_task`` (which requires a running loop) instead of
        the deprecated ``asyncio.get_event_loop().create_task``; this method is
        only ever called from ``__aenter__``, where a loop is guaranteed to be
        running.
        """
        self._task = asyncio.create_task(self._app_task())

    async def _stop_app(self) -> None:
        """Cancel the app task (if still running) and wait for it to finish."""
        assert self._task is not None

        if not self._task.done():
            self._task.cancel()

        await self._task

    async def __aenter__(self) -> "LifespanContextManager":
        self._run_app()

        try:
            await self._startup()
        except BaseException:
            # Startup failed or timed out: tear the app task down, then re-raise.
            await self._stop_app()
            raise

        return self

    async def __aexit__(
        self,
        exc_type: t.Optional[t.Type[BaseException]] = None,
        exc_value: t.Optional[BaseException] = None,
        traceback: t.Optional[TracebackType] = None,
    ):
        await self._shutdown()
        await self._stop_app()


class _BaseClient:
    """Shared initialization for ``Client`` and ``AsyncClient``.

    :param app: Local application to wrap; built on the fly when only model
        files are given.
    :param models: Models to serve, as ``(name, url, path)`` triples.
    :param kwargs: Forwarded to the underlying httpx client constructor.
    """

    def __init__(
        self,
        /,
        app: t.Optional[Flama] = None,
        models: t.Optional[t.Sequence[t.Tuple[str, str, str]]] = None,
        **kwargs,
    ):
        if models:
            app = Flama() if not app else app
            for name, url, path in models:
                app.models.add_model(url, path, name)

        # Map of model name -> mount url, used by the model_request helpers.
        # (Single assignment: the previous `self.models = None` was a dead
        # store, immediately overwritten below; `or []` replaces the odd
        # `or {}` dict-literal placeholder.)
        self.models: t.Optional[t.Dict[str, str]] = {name: url for name, url, _ in models or []}

        self.lifespan = LifespanContextManager(app) if app else None

        kwargs["app"] = app
        kwargs.setdefault("base_url", "http://localapp")
        kwargs["headers"] = {"user-agent": f"flama/{importlib.metadata.version('flama')}", **kwargs.get("headers", {})}

        super().__init__(**kwargs)


class Client(_BaseClient, httpx.Client):
    """A client for interacting with a Flama application either remote or local.

    This client can handle a local python object:
    >>> client = Client(app=Flama())
    Or connect to a remote API:
    >>> client = Client(base_url="https://foo.bar")
    Or generate a Flama application based on a set of flm model files:
    >>> client = Client(models=[("foo", "/foo/", "model_foo.flm"), ("bar", "/bar/", "model_bar.flm")])
    For initializing the application it's required to use it as a context manager:
    >>> with Client(app=Flama()) as client:
    ...     client.post(...)
    """

    def __enter__(self) -> "Client":
        # Open the underlying httpx client first, then run the wrapped app's
        # lifespan startup to completion on the event loop.
        super().__enter__()
        if self.lifespan:
            # NOTE(review): asyncio.get_event_loop() outside a running loop is
            # deprecated since Python 3.10; the same loop must be reused in
            # __exit__ for the lifespan events — confirm supported versions.
            asyncio.get_event_loop().run_until_complete(self.lifespan.__aenter__())

        return self

    def __exit__(
        self,
        exc_type: t.Optional[t.Type[BaseException]] = None,
        exc_value: t.Optional[BaseException] = None,
        traceback: t.Optional[TracebackType] = None,
    ):
        # Run lifespan shutdown before closing the underlying httpx client.
        if self.lifespan:
            asyncio.get_event_loop().run_until_complete(self.lifespan.__aexit__(exc_type, exc_value, traceback))
        super().__exit__(exc_type, exc_value, traceback)

    def model_request(self, model: str, method: str, url: str, **kwargs) -> httpx.Response:
        """Send a request to one of the served models.

        :param model: Name of the model (a key of ``self.models``).
        :param method: HTTP method.
        :param url: Path relative to the model's mount url.
        :return: The HTTP response.
        """
        assert self.models, "No models found for request."
        return self.request(method, f"{self.models[model].rstrip('/')}{url}", **kwargs)

    # Shorthand endpoints for model inspection and prediction.
    model_inspect = functools.partialmethod(model_request, method="GET", url="/")
    model_predict = functools.partialmethod(model_request, method="POST", url="/predict/")


class AsyncClient(_BaseClient, httpx.AsyncClient):
    """An async client for interacting with a Flama application either remote or local.

    This client can handle a local python object:
    >>> client = AsyncClient(app=Flama())
    Or connect to a remote API:
    >>> client = AsyncClient(base_url="https://foo.bar")
    Or generate a Flama application based on a set of flm model files:
    >>> client = AsyncClient(models=[("foo", "/foo/", "model_foo.flm"), ("bar", "/bar/", "model_bar.flm")])
    For initializing the application it's required to use it as an async context manager:
    >>> async with AsyncClient(app=Flama()) as client:
    ...     await client.post(...)
    """

    async def __aenter__(self) -> "AsyncClient":
        # Open the underlying httpx async client, then run lifespan startup.
        await super().__aenter__()
        if self.lifespan:
            await self.lifespan.__aenter__()

        return self

    async def __aexit__(
        self,
        exc_type: t.Optional[t.Type[BaseException]] = None,
        exc_value: t.Optional[BaseException] = None,
        traceback: t.Optional[TracebackType] = None,
    ):
        # Run lifespan shutdown before closing the underlying httpx client.
        if self.lifespan:
            await self.lifespan.__aexit__(exc_type, exc_value, traceback)
        await super().__aexit__(exc_type, exc_value, traceback)

    async def model_request(self, model: str, method: str, url: str, **kwargs) -> httpx.Response:
        """Send a request to one of the served models.

        Awaits the underlying request, so ``await client.model_request(...)``
        yields the response directly. (Previously this coroutine returned the
        un-awaited awaitable from ``self.request``, so a single ``await``
        never actually sent the request.)

        :param model: Name of the model (a key of ``self.models``).
        :param method: HTTP method.
        :param url: Path relative to the model's mount url.
        :return: The HTTP response.
        """
        assert self.models, "No models found for request."
        return await self.request(method, f"{self.models[model].rstrip('/')}{url}", **kwargs)

    # Shorthand endpoints for model inspection and prediction.
    model_inspect = functools.partialmethod(model_request, method="GET", url="/")
    model_predict = functools.partialmethod(model_request, method="POST", url="/predict/")
26 changes: 25 additions & 1 deletion flama/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def to_thread(func, /, *args, **kwargs):
t.TypeGuard = TypeGuard # type: ignore
t.ParamSpec = ParamSpec # type: ignore

__all__ = ["is_async", "run", "run"]
__all__ = ["is_async", "run", "run_task_group", "AsyncProcess"]

R = t.TypeVar("R", covariant=True)
P = t.ParamSpec("P") # type: ignore # PORT: Remove this comment when stop supporting 3.9
Expand Down Expand Up @@ -61,6 +61,30 @@ async def run(
return await asyncio.to_thread(func, *args, **kwargs) # type: ignore


if sys.version_info < (3, 11):  # PORT: Remove when stop supporting 3.10  # pragma: no cover

    async def run_task_group(*tasks: t.Coroutine[t.Any, t.Any, t.Any]) -> t.List[asyncio.Task]:
        """Run a group of tasks concurrently and wait for all of them to finish.

        NOTE(review): unlike the 3.11+ ``TaskGroup`` variant below,
        ``asyncio.wait`` does not re-raise task exceptions; failures remain
        stored on the returned tasks — confirm callers check
        ``task.exception()`` where it matters.

        :param tasks: Tasks to run.
        :result: Finished tasks.
        """
        tasks_list = [asyncio.create_task(task) for task in tasks]
        await asyncio.wait(tasks_list)
        return tasks_list

else:

    async def run_task_group(*tasks: t.Coroutine[t.Any, t.Any, t.Any]) -> t.List[asyncio.Task]:
        """Run a group of tasks concurrently and wait for all of them to finish.

        NOTE(review): per the asyncio docs, a failing task cancels its
        siblings here and an ``ExceptionGroup`` is raised, which diverges from
        the pre-3.11 branch above.

        :param tasks: Tasks to run.
        :result: Finished tasks.
        """
        async with asyncio.TaskGroup() as task_group:
            return [task_group.create_task(task) for task in tasks]


class AsyncProcess(multiprocessing.Process):
"""Multiprocessing Process class whose target is an async function."""

Expand Down
6 changes: 3 additions & 3 deletions flama/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
class Events:
"""Application events register."""

startup: t.List[t.Callable] = dataclasses.field(default_factory=list)
shutdown: t.List[t.Callable] = dataclasses.field(default_factory=list)
startup: t.List[t.Callable[..., t.Coroutine[t.Any, t.Any, None]]] = dataclasses.field(default_factory=list)
shutdown: t.List[t.Callable[..., t.Coroutine[t.Any, t.Any, None]]] = dataclasses.field(default_factory=list)

def register(self, event: str, handler: t.Callable) -> None:
"""Register a new event.
Expand All @@ -19,7 +19,7 @@ def register(self, event: str, handler: t.Callable) -> None:
getattr(self, event).append(handler)

@classmethod
def build(cls, **events: t.List[t.Callable]) -> "Events":
def build(cls, **events: t.List[t.Callable[..., t.Coroutine[t.Any, t.Any, None]]]) -> "Events":
"""Build events register from dict.
:param events: Events to register.
Expand Down
4 changes: 4 additions & 0 deletions flama/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
]


class ApplicationError(Exception):
    """Exception type for errors raised by the application itself."""


class DecodeError(Exception):
"""
Raised by a Codec when `decode` fails due to malformed syntax.
Expand Down
Loading

0 comments on commit 4c00cc6

Please sign in to comment.