Skip to content

Commit

Permalink
Separate ASGI request and lifespan callables (#2646)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins committed Mar 17, 2023
1 parent 08a81c8 commit 5ee36fd
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 120 deletions.
15 changes: 10 additions & 5 deletions sanic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

from sanic.application.ext import setup_ext
from sanic.application.state import ApplicationState, ServerStage
from sanic.asgi import ASGIApp
from sanic.asgi import ASGIApp, Lifespan
from sanic.base.root import BaseSanic
from sanic.blueprint_group import BlueprintGroup
from sanic.blueprints import Blueprint
Expand Down Expand Up @@ -119,6 +119,7 @@ class Sanic(StaticHandleMixin, BaseSanic, StartupMixin, metaclass=TouchUpMeta):
)
__slots__ = (
"_asgi_app",
"_asgi_lifespan",
"_asgi_client",
"_blueprint_order",
"_delayed_tasks",
Expand Down Expand Up @@ -198,6 +199,8 @@ def __init__(
self.config.INSPECTOR = inspector

# Then we can do the rest
self._asgi_app: Optional[ASGIApp] = None
self._asgi_lifespan: Optional[Lifespan] = None
self._asgi_client: Any = None
self._blueprint_order: List[Blueprint] = []
self._delayed_tasks: List[str] = []
Expand Down Expand Up @@ -1349,12 +1352,14 @@ async def __call__(self, scope, receive, send):
three arguments: scope, receive, send. See the ASGI reference for more
details: https://asgi.readthedocs.io/en/latest
"""
self.asgi = True
if scope["type"] == "lifespan":
self.asgi = True
self.motd("")
self._asgi_app = await ASGIApp.create(self, scope, receive, send)
asgi_app = self._asgi_app
await asgi_app()
self._asgi_lifespan = Lifespan(self, scope, receive, send)
await self._asgi_lifespan()
else:
self._asgi_app = await ASGIApp.create(self, scope, receive, send)
await self._asgi_app()

_asgi_single_callable = True # We conform to ASGI 3.0 single-callable

Expand Down
179 changes: 87 additions & 92 deletions sanic/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,23 @@


class Lifespan:
def __init__(self, asgi_app: ASGIApp) -> None:
self.asgi_app = asgi_app
def __init__(
self, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
self.sanic_app = sanic_app
self.scope = scope
self.receive = receive
self.send = send

if (
"server.init.before"
in self.asgi_app.sanic_app.signal_router.name_index
):
if "server.init.before" in self.sanic_app.signal_router.name_index:
logger.debug(
'You have set a listener for "before_server_start" '
"in ASGI mode. "
"It will be executed as early as possible, but not before "
"the ASGI server is started.",
extra={"verbosity": 1},
)
if (
"server.shutdown.after"
in self.asgi_app.sanic_app.signal_router.name_index
):
if "server.shutdown.after" in self.sanic_app.signal_router.name_index:
logger.debug(
'You have set a listener for "after_server_stop" '
"in ASGI mode. "
Expand All @@ -57,11 +56,11 @@ async def startup(self) -> None:
in sequence since the ASGI lifespan protocol only supports a single
startup event.
"""
await self.asgi_app.sanic_app._startup()
await self.asgi_app.sanic_app._server_event("init", "before")
await self.asgi_app.sanic_app._server_event("init", "after")
await self.sanic_app._startup()
await self.sanic_app._server_event("init", "before")
await self.sanic_app._server_event("init", "after")

if not isinstance(self.asgi_app.sanic_app.config.USE_UVLOOP, Default):
if not isinstance(self.sanic_app.config.USE_UVLOOP, Default):
warnings.warn(
"You have set the USE_UVLOOP configuration option, but Sanic "
"cannot control the event loop when running in ASGI mode."
Expand All @@ -77,35 +76,33 @@ async def shutdown(self) -> None:
in sequence since the ASGI lifespan protocol only supports a single
shutdown event.
"""
await self.asgi_app.sanic_app._server_event("shutdown", "before")
await self.asgi_app.sanic_app._server_event("shutdown", "after")
await self.sanic_app._server_event("shutdown", "before")
await self.sanic_app._server_event("shutdown", "after")

async def __call__(
self, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> None:
message = await receive()
if message["type"] == "lifespan.startup":
try:
await self.startup()
except Exception as e:
error_logger.exception(e)
await send(
{"type": "lifespan.startup.failed", "message": str(e)}
)
else:
await send({"type": "lifespan.startup.complete"})

message = await receive()
if message["type"] == "lifespan.shutdown":
try:
await self.shutdown()
except Exception as e:
error_logger.exception(e)
await send(
{"type": "lifespan.shutdown.failed", "message": str(e)}
)
else:
await send({"type": "lifespan.shutdown.complete"})
async def __call__(self) -> None:
while True:
message = await self.receive()
if message["type"] == "lifespan.startup":
try:
await self.startup()
except Exception as e:
error_logger.exception(e)
await self.send(
{"type": "lifespan.startup.failed", "message": str(e)}
)
else:
await self.send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
try:
await self.shutdown()
except Exception as e:
error_logger.exception(e)
await self.send(
{"type": "lifespan.shutdown.failed", "message": str(e)}
)
else:
await self.send({"type": "lifespan.shutdown.complete"})
return


class ASGIApp:
Expand All @@ -117,19 +114,22 @@ class ASGIApp:
stage: Stage
response: Optional[BaseHTTPResponse]

def __init__(self) -> None:
self.ws = None

@classmethod
async def create(
cls, sanic_app, scope: ASGIScope, receive: ASGIReceive, send: ASGISend
) -> "ASGIApp":
cls,
sanic_app: Sanic,
scope: ASGIScope,
receive: ASGIReceive,
send: ASGISend,
) -> ASGIApp:
instance = cls()
instance.ws = None
instance.sanic_app = sanic_app
instance.transport = MockTransport(scope, receive, send)
instance.transport.loop = sanic_app.loop
instance.stage = Stage.IDLE
instance.response = None
instance.sanic_app.state.is_started = True
setattr(instance.transport, "add_task", sanic_app.loop.create_task)

headers = Header(
Expand All @@ -138,52 +138,47 @@ async def create(
for key, value in scope.get("headers", [])
]
)
instance.lifespan = Lifespan(instance)

if scope["type"] == "lifespan":
await instance.lifespan(scope, receive, send)
else:
path = (
scope["path"][1:]
if scope["path"].startswith("/")
else scope["path"]
)
url = "/".join([scope.get("root_path", ""), quote(path)])
url_bytes = url.encode("latin-1")
url_bytes += b"?" + scope["query_string"]

if scope["type"] == "http":
version = scope["http_version"]
method = scope["method"]
elif scope["type"] == "websocket":
version = "1.1"
method = "GET"

instance.ws = instance.transport.create_websocket_connection(
send, receive
)
else:
raise ServerError("Received unknown ASGI scope")

request_class = sanic_app.request_class or Request
instance.request = request_class(
url_bytes,
headers,
version,
method,
instance.transport,
sanic_app,
)
instance.request.stream = instance
instance.request_body = True
instance.request.conn_info = ConnInfo(instance.transport)

await sanic_app.dispatch(
"http.lifecycle.request",
inline=True,
context={"request": instance.request},
fail_not_found=False,
path = (
scope["path"][1:]
if scope["path"].startswith("/")
else scope["path"]
)
url = "/".join([scope.get("root_path", ""), quote(path)])
url_bytes = url.encode("latin-1")
url_bytes += b"?" + scope["query_string"]

if scope["type"] == "http":
version = scope["http_version"]
method = scope["method"]
elif scope["type"] == "websocket":
version = "1.1"
method = "GET"

instance.ws = instance.transport.create_websocket_connection(
send, receive
)
else:
raise ServerError("Received unknown ASGI scope")

request_class = sanic_app.request_class or Request
instance.request = request_class(
url_bytes,
headers,
version,
method,
instance.transport,
sanic_app,
)
instance.request.stream = instance # type: ignore
instance.request_body = True
instance.request.conn_info = ConnInfo(instance.transport)

await instance.sanic_app.dispatch(
"http.lifecycle.request",
inline=True,
context={"request": instance.request},
fail_not_found=False,
)

return instance

Expand Down
61 changes: 38 additions & 23 deletions tests/test_asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
import logging

from collections import deque, namedtuple
from unittest.mock import call

import pytest
import uvicorn

from sanic import Sanic
from sanic.application.state import Mode
from sanic.asgi import ASGIApp, MockTransport
from sanic.asgi import ASGIApp, Lifespan, MockTransport
from sanic.exceptions import BadRequest, Forbidden, ServiceUnavailable
from sanic.request import Request
from sanic.response import json, text
Expand Down Expand Up @@ -116,10 +117,6 @@ def install_signal_handlers(self):
stop_message,
) not in caplog.record_tuples

all_tasks = asyncio.all_tasks(asyncio.get_event_loop())
for task in all_tasks:
task.cancel()

assert before_server_start
assert after_server_start
assert before_server_stop
Expand Down Expand Up @@ -218,10 +215,6 @@ def install_signal_handlers(self):
stop_message,
) not in caplog.record_tuples

all_tasks = asyncio.all_tasks(asyncio.get_event_loop())
for task in all_tasks:
task.cancel()

assert before_server_start
assert after_server_start
assert before_server_stop
Expand Down Expand Up @@ -272,10 +265,6 @@ def install_signal_handlers(self):
with pytest.warns(UserWarning) as records:
server.run()

all_tasks = asyncio.all_tasks(asyncio.get_event_loop())
for task in all_tasks:
task.cancel()

msg = ""
for record in records:
_msg = str(record.message)
Expand Down Expand Up @@ -583,15 +572,28 @@ async def test_error_on_lifespan_exception_start(app, caplog):
async def before_server_start(_):
1 / 0

recv = AsyncMock(return_value={"type": "lifespan.startup"})
recv = AsyncMock(
side_effect=[
{"type": "lifespan.startup"},
{"type": "lifespan.shutdown"},
]
)
send = AsyncMock()
app.asgi = True

lifespan = Lifespan(app, {"type": "lifespan"}, recv, send)
with caplog.at_level(logging.ERROR):
await ASGIApp.create(app, {"type": "lifespan"}, recv, send)

send.assert_awaited_once_with(
{"type": "lifespan.startup.failed", "message": "division by zero"}
await lifespan()

send.assert_has_calls(
[
call(
{
"type": "lifespan.startup.failed",
"message": "division by zero",
}
)
]
)


Expand All @@ -601,13 +603,26 @@ async def test_error_on_lifespan_exception_stop(app: Sanic):
async def before_server_stop(_):
1 / 0

recv = AsyncMock(return_value={"type": "lifespan.shutdown"})
recv = AsyncMock(
side_effect=[
{"type": "lifespan.startup"},
{"type": "lifespan.shutdown"},
]
)
send = AsyncMock()
app.asgi = True
await app._startup()

await ASGIApp.create(app, {"type": "lifespan"}, recv, send)

send.assert_awaited_once_with(
{"type": "lifespan.shutdown.failed", "message": "division by zero"}
lifespan = Lifespan(app, {"type": "lifespan"}, recv, send)
await lifespan()

send.assert_has_calls(
[
call(
{
"type": "lifespan.shutdown.failed",
"message": "division by zero",
}
)
]
)

0 comments on commit 5ee36fd

Please sign in to comment.