Merge pull request #56 from ejlangev/feature/allow-non-async-generators
Allow Usage without Async Generators
sysid committed May 14, 2023
2 parents 2d90414 + 15f5fb9 commit 97eb623
Showing 4 changed files with 126 additions and 0 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -95,6 +95,15 @@ return EventSourceResponse(
See example: `examples/error_handling.py`


### Sending Responses without Async Generators
Async generators can exhibit tricky error and cleanup behavior, especially when they are interrupted.

[Background: Cleanup in async generators](https://vorpus.org/blog/some-thoughts-on-asynchronous-api-design-in-a-post-asyncawait-world/#cleanup-in-generators-and-async-generators).

Example [`no_async_generators.py`](https://github.com/sysid/sse-starlette/pull/56#issue-1704495339) shows an alternative implementation that does not rely on async generators but instead uses memory channels (see `examples/no_async_generators.py`).
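
A condensed sketch of that pattern (cancellation handling omitted here; the full example covers it):

```python
import anyio
from functools import partial

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

@app.get("/endless")
async def endless(request):
    send_chan, recv_chan = anyio.create_memory_object_stream(10)

    async def event_publisher(inner_send_chan):
        # Closing the send end (on exiting this block) terminates the stream.
        async with inner_send_chan:
            i = 0
            while True:
                i += 1
                await inner_send_chan.send(dict(data=i))
                await anyio.sleep(1.0)

    return EventSourceResponse(
        recv_chan, data_sender_callable=partial(event_publisher, send_chan)
    )
```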


## Development, Contributing
1. install pipenv: `pip install pipenv`
2. install dependencies using pipenv: `pipenv install --dev -e .`
76 changes: 76 additions & 0 deletions examples/no_async_generators.py
@@ -0,0 +1,76 @@
import anyio
import logging
from anyio.streams.memory import MemoryObjectSendStream
from functools import partial

import trio
import uvicorn
from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse
from starlette.requests import Request

_log = logging.getLogger(__name__)
log_fmt = r"%(asctime)-15s %(levelname)s %(name)s %(funcName)s:%(lineno)d %(message)s"
datefmt = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(format=log_fmt, level=logging.DEBUG, datefmt=datefmt)

app = FastAPI()


@app.get("/endless")
async def endless(req: Request):
"""Simulates an endless stream
In case of server shutdown the running task has to be stopped via signal handler in order
to enable proper server shutdown. Otherwise, there will be dangling tasks preventing proper shutdown.
"""
send_chan, recv_chan = anyio.create_memory_object_stream(10)
async def event_publisher(inner_send_chan: MemoryObjectSendStream):
async with inner_send_chan:
try:
i = 0
while True:
i += 1
await inner_send_chan.send(dict(data=i))
await anyio.sleep(1.0)
except anyio.get_cancelled_exc_class() as e:
_log.info(f"Disconnected from client (via refresh/close) {req.client}")
with anyio.move_on_after(1, shield=True):
await inner_send_chan.send(dict(closing=True))
raise e

return EventSourceResponse(recv_chan, data_sender_callable=partial(event_publisher, send_chan))



@app.get("/endless-trio")
async def endless_trio(req: Request):
"""Simulates an endless stream
In case of server shutdown the running task has to be stopped via signal handler in order
to enable proper server shutdown. Otherwise, there will be dangling tasks preventing proper shutdown.
"""
raise Exception("Trio is not compatible with uvicorn, this code is for example purposes")

send_chan, recv_chan = trio.open_memory_channel(10)
async def event_publisher(inner_send_chan: trio.MemorySendChannel):
async with inner_send_chan:
try:
i = 0
while True:
i += 1
await inner_send_chan.send(dict(data=i))
await trio.sleep(1.0)
except trio.Cancelled as e:
_log.info(f"Disconnected from client (via refresh/close) {req.client}")
with anyio.move_on_after(1, shield=True):
# This may not make it
await inner_send_chan.send(dict(closing=True))
raise e

return EventSourceResponse(recv_chan, data_sender_callable=partial(event_publisher, send_chan))



if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080, log_level="trace", log_config=None)  # type: ignore
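
To try the example, run the file directly (it starts uvicorn on port 8080) and consume the stream with, e.g., `curl -N http://localhost:8080/endless`.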
6 changes: 6 additions & 0 deletions sse_starlette/sse.py
@@ -153,6 +153,7 @@ def __init__(
        ping: Optional[int] = None,
        sep: Optional[str] = None,
        ping_message_factory: Optional[Callable[[], ServerSentEvent]] = None,
        data_sender_callable: Optional[Callable[[], Coroutine[None, None, None]]] = None,
    ) -> None:
        self.sep = sep
        self.ping_message_factory = ping_message_factory
@@ -165,6 +166,7 @@ def __init__(
        self.status_code = status_code
        self.media_type = self.media_type if media_type is None else media_type
        self.background = background  # type: ignore  # follows https://github.com/encode/starlette/blob/master/starlette/responses.py
        self.data_sender_callable = data_sender_callable

        _headers = {}
        if headers is not None:  # pragma: no cover
@@ -239,6 +241,10 @@ async def wrap(func: Callable[[], Coroutine[None, None, None]]) -> None:
            task_group.start_soon(wrap, partial(self.stream_response, safe_send))
            task_group.start_soon(wrap, partial(self._ping, safe_send))
            task_group.start_soon(wrap, self.listen_for_exit_signal)

            if self.data_sender_callable:
                task_group.start_soon(self.data_sender_callable)

            await wrap(partial(self.listen_for_disconnect, receive))

        if self.background is not None:  # pragma: no cover, tested in StreamResponse
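Note the parameter's type: `data_sender_callable` is a zero-argument callable returning a coroutine, because `task_group.start_soon` invokes it without arguments. Per-request state such as the channel's send end is therefore bound up front, e.g. with `functools.partial`, as both the example above and the test below do.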
35 changes: 35 additions & 0 deletions tests/test_event_source_response.py
@@ -1,5 +1,7 @@
import asyncio
import logging
import math
from functools import partial

import anyio
import anyio.lowlevel
@@ -62,6 +64,39 @@ async def numbers(minimum, maximum):
    print(response.content)


@pytest.mark.parametrize(
    "input,expected",
    [
        ("integer", b"data: 1\r\n\r\n"),
        ("dict1", b"data: 1\r\n\r\n"),
        ("dict2", b"event: message\r\ndata: 1\r\n\r\n"),
    ],
)
def test_sync_memory_channel_event_source_response(input, expected):
    async def app(scope, receive, send):
        send_chan, recv_chan = anyio.create_memory_object_stream(math.inf)

        async def numbers(inner_send_chan, minimum, maximum):
            async with inner_send_chan:
                for i in range(minimum, maximum + 1):
                    await anyio.sleep(0.1)

                    if input == "integer":
                        await inner_send_chan.send(i)
                    elif input == "dict1":
                        await inner_send_chan.send(dict(data=i))
                    elif input == "dict2":
                        await inner_send_chan.send(dict(data=i, event="message"))

        response = EventSourceResponse(
            recv_chan, data_sender_callable=partial(numbers, send_chan, 1, 5), ping=0.2
        )  # type: ignore
        await response(scope, receive, send)

    client = TestClient(app)
    response = client.get("/")
    assert response.content.decode().count("ping") == 2
    assert expected in response.content
    print(response.content)


@pytest.mark.anyio
async def test_endless():
    async def app(scope, receive, send):
