Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for ASGI specification compliance #50

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
16 changes: 16 additions & 0 deletions async_asgi_testclient/exceptions.py
@@ -0,0 +1,16 @@
"""Exceptions for TestClient

Base Exception class and sub-classed exceptions to make it easy
(and in some cases, possible at all) to handle errors in different
ways.
"""
from async_asgi_testclient.utils import Message
from typing import Optional


class TestClientError(Exception):
"""An error in async_asgi_testclient"""

def __init__(self, *args, message: Optional[Message] = None):
super().__init__(*args)
self.message = message
47 changes: 32 additions & 15 deletions async_asgi_testclient/testing.py
Expand Up @@ -23,6 +23,7 @@
OTHER DEALINGS IN THE SOFTWARE.
"""
from async_asgi_testclient.compatibility import guarantee_single_callable
from async_asgi_testclient.exceptions import TestClientError
from async_asgi_testclient.multipart import encode_multipart_formdata
from async_asgi_testclient.response import BytesRW
from async_asgi_testclient.response import Response
Expand Down Expand Up @@ -78,33 +79,46 @@ def __init__(
self._lifespan_task = None # Must keep hard reference to prevent gc

async def __aenter__(self):
self._lifespan_task = create_monitored_task(
self.application(
{"type": "lifespan", "asgi": {"version": "3.0"}},
self._lifespan_input_queue.get,
self._lifespan_output_queue.put,
),
self._lifespan_output_queue.put_nowait,
)
try:
self._lifespan_task = create_monitored_task(
self.application(
{"type": "lifespan", "asgi": {"version": "3.0"}},
self._lifespan_input_queue.get,
self._lifespan_output_queue.put,
),
self._lifespan_output_queue.put_nowait,
)
# Make sure there is time for the output queue to be processed
await self.send_lifespan("startup")
except TestClientError:
# Pass these through directly, so that test clients can assert on them
raise
except: # noqa
# Any other exception is (almost) definitely passed through from the app under test
# So it means the lifespan protocol is not supported.
self._lifespan_task = None
LucidDan marked this conversation as resolved.
Show resolved Hide resolved

await self.send_lifespan("startup")
return self

async def __aexit__(self, exc_type, exc, tb):
await self.send_lifespan("shutdown")
self._lifespan_task = None
# If task is None, lifespan protocol is disabled (not supported by app)
if self._lifespan_task is not None:
await self.send_lifespan("shutdown")
self._lifespan_task = None

async def send_lifespan(self, action):
await self._lifespan_input_queue.put({"type": f"lifespan.{action}"})
message = await receive(self._lifespan_output_queue, timeout=self.timeout)

if isinstance(message, Message):
raise Exception(f"{message.event} - {message.reason} - {message.task}")
raise TestClientError(
f"{message.event} - {message.reason} - {message.task}", message=message
)

if message["type"] == f"lifespan.{action}.complete":
pass
elif message["type"] == f"lifespan.{action}.failed":
raise Exception(message)
raise TestClientError(message, message=message)

def websocket_connect(self, path, headers=None, cookies=None):
return WebSocketSession(self, path, headers, cookies)
Expand Down Expand Up @@ -260,12 +274,15 @@ async def open(
message = await self.wait_response(receive_or_fail, "http.response.start")
response.status_code = message["status"]
response.headers = CIMultiDict(
[(k.decode("utf8"), v.decode("utf8")) for k, v in message["headers"]]
[
(k.decode("utf8"), v.decode("utf8"))
for k, v in message.get("headers", [])
]
)

# Receive initial response body
message = await self.wait_response(receive_or_fail, "http.response.body")
response.raw.write(message["body"])
response.raw.write(message.get("body", b""))
response._more_body = message.get("more_body", False)

# Consume the remaining response if not in stream
Expand Down
175 changes: 175 additions & 0 deletions async_asgi_testclient/tests/asgi_spec/conftest.py
@@ -0,0 +1,175 @@
"""Test setup for ASGI spec tests

Mock application used for testing ASGI standard compliance.
"""
from enum import Enum
from functools import partial
from sys import version_info as PY_VER # noqa

import pytest


class AppState(Enum):
PREINIT = 0
INIT = 1
READY = 2
SHUTDOWN = 3


class BaseMockApp(object):
"""A mock application object passed to TestClient for the tests"""

# Make it easy to override these for lifespan related test scenarios
lifespan_startup_message = {"type": "lifespan.startup.complete", "message": "OK"}
lifespan_shutdown_message = {"type": "lifespan.shutdown.complete", "message": "OK"}
use_lifespan = True

def __init__(self, **kwargs):
for k, v in kwargs:
setattr(self, k, v)
self.state = AppState.PREINIT

async def lifespan_startup(self, scope, receive, send, msg):
if self.state == AppState.READY:
# Technically, this isn't explicitly forbidden in the spec.
# But I think it should not happen.
raise RuntimeError("Received more than one lifespan.startup")
self.state = AppState.READY
return await send(self.lifespan_startup_message)

async def lifespan_shutdown(self, scope, receive, send, msg):
if self.state == AppState.SHUTDOWN:
# Technically, this isn't explicitly forbidden in the spec.
# But I think it should not happen.
raise RuntimeError("Received more than one lifespan.shutdown")
self.state = AppState.SHUTDOWN
return await send(self.lifespan_shutdown_message)

async def lifespan(self, scope, receive, send):
if not self.use_lifespan:
raise RuntimeError(f"Type '{scope['type']}' is not supported.")
while True:
try:
msg = await receive()
except RuntimeError as e:
if e.args == ("Event loop is closed",):
return
else:
raise

if msg["type"] == "lifespan.startup":
await self.lifespan_startup(scope, receive, send, msg)
elif msg["type"] == "lifespan.shutdown":
await self.lifespan_shutdown(scope, receive, send, msg)
else:
raise RuntimeError(f"Received unknown message type '{msg['type']}")
if self.state == AppState.SHUTDOWN:
return

async def http_request(self, scope, receive, send, msg):
# Default behaviour, just send a minimal response with OK to any request
await send({"type": "http.response.start", "headers": [], "status": 200})
await send({"type": "http.response.body", "body": b"OK"})

async def http_disconnect(self, scope, receive, send, msg):
raise RuntimeError(f"Received http.disconnect message {msg}")

async def http(self, scope, receive, send):
msg = []
# Receive http.requests until http.disconnect or more_body = False
while True:
msg.append(await receive())
if msg[-1]["type"] == "http.disconnect" or not msg[-1].get(
"more_body", False
):
Comment on lines +82 to +84
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check that if have not received a http.disconnect, it is a http.request?

Like:

Suggested change
if msg[-1]["type"] == "http.disconnect" or not msg[-1].get(
"more_body", False
):
if msg[-1]["type"] == "http.disconnect" or (msg[-1] == "http.request" and not msg[-1].get(
"more_body", False
)):

So we can catch unexpected events sent by asgi-testclient.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes this makes sense. That whole function could probably be rewritten to be more readable too, now that I look at it after a break. ;-)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great!

break
if msg[0]["type"] == "http.disconnect":
# Honestly this shouldn't really happen, but it's allowed in spec, so check.
return await self.http_disconnect(scope, receive, send, msg)
else:
return await self.http_request(scope, receive, send, msg)

async def websocket_connect(self, scope, receive, send, msg, msg_history):
await send({"type": "websocket.accept"})
return True

async def websocket_receive(self, scope, receive, send, msg, msg_history):
return True

async def websocket_disconnect(self, scope, receive, send, msg, msg_history):
return False

async def websocket(self, scope, receive, send):
msg_history = []
while True:
msg = await receive()

# Send websocket events to a handler
func = getattr(
self, msg["type"].replace(".", "_").replace("-", "__"), "handle_unknown"
)
res = await func(scope, receive, send, msg, msg_history)
msg_history.append(msg)

# If the event handler returns false, assume we closed the socket.
if msg["type"] == "websocket.disconnect" or not res:
return

async def handle_unknown(self, scope, receive, send):
if self.state != AppState.READY:
raise RuntimeError(
"Received another request before lifespan.startup.complete sent"
)
raise RuntimeError(f"Type '{scope['type']}' is not supported.")

async def handle_all(self, scope, receive, send):
# Do nothing unless something monkeypatches us
pass

async def asgi_call(self, scope, receive, send):
# Initial catch-all, for testing things like scope type itself
await self.handle_all(scope, receive, send)

if self.state == AppState.PREINIT:
if self.use_lifespan:
self.state = AppState.INIT
else:
self.state = AppState.READY
if self.state == AppState.SHUTDOWN:
raise RuntimeError(f"Got message after shutting down: {scope}")

# call hooks based on scope type, so we can monkeypatch them in tests
# the lifespan, http, and websocket protocol types all have simple methods already
# implemented.
func = getattr(
self, scope["type"].replace(".", "_").replace("-", "__"), "handle_unknown"
)
return await func(scope, receive, send)


class MockApp(BaseMockApp):
"""Modern ASGI single-callable app"""

async def __call__(self, scope, receive, send):
return await super().asgi_call(scope, receive, send)


class LegacyMockApp(BaseMockApp):
"""Legacy ASGI 'two-callable' app"""

def __call__(self, scope):
return partial(super().asgi_call, scope)


@pytest.fixture(scope="function")
def mock_app():
"""Create a mock ASGI App to test the TestClient against"""

return MockApp()


@pytest.fixture(scope="function")
def legacy_mock_app():
"""Create a mock legacy ASGI App to test the TestClient against"""

return LegacyMockApp()