diff --git a/sanic_testing/testing.py b/sanic_testing/testing.py index 0e4a410..b795a87 100644 --- a/sanic_testing/testing.py +++ b/sanic_testing/testing.py @@ -1,6 +1,7 @@ import asyncio import typing from asyncio.events import get_event_loop +from functools import partial from json import JSONDecodeError from socket import socket from types import SimpleNamespace @@ -31,14 +32,16 @@ def __init__( self.app = app self.port = port self.host = host + app.listener("after_server_start")(self._start_test_mode) + app.listener("before_server_stop")(self._end_test_mode) - @app.listener("after_server_start") - def _start_test_mode(sanic, *args, **kwargs): - sanic.test_mode = True + @classmethod + def _start_test_mode(cls, sanic, *args, **kwargs): + sanic.test_mode = True - @app.listener("before_server_end") - def _end_test_mode(sanic, *args, **kwargs): - sanic.test_mode = False + @classmethod + def _end_test_mode(cls, sanic, *args, **kwargs): + sanic.test_mode = False def get_new_session(self, **kwargs) -> httpx.AsyncClient: return httpx.AsyncClient(verify=False, **kwargs) @@ -93,6 +96,38 @@ async def _local_request(self, method: str, url: str, *args, **kwargs): return response + @classmethod + def _collect_request(cls, results, request): + if results[0] is None: + results[0] = request + + async def _collect_response( + self, + method, + url, + exceptions, + results, + sanic, + loop, + **request_kwargs, + ): + try: + response = await self._local_request(method, url, **request_kwargs) + results[-1] = response + if method == "websocket": + await response.ws.close() + except Exception as e: + logger.exception("Exception") + exceptions.append(e) + finally: + self.app.stop() + + async def _error_handler(self, request, exception): + if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]: + return text("", exception.status_code, headers=exception.headers) + else: + return self.app.error_handler.default(request, exception) + def _sanic_endpoint_test( self, method: str = "get", @@ -103,28 +138,17 @@ def _sanic_endpoint_test( host: str = None, *request_args, **request_kwargs, - ) -> typing.Tuple[typing.Union[Request, HTTPResponse]]: + ) -> typing.Union[typing.Tuple[Request, HTTPResponse], HTTPResponse]: results = [None, None] exceptions = [] server_kwargs = server_kwargs or {"auto_reload": False} + _collect_request = partial(self._collect_request, results) if gather_request: - - def _collect_request(request): - if results[0] is None: - results[0] = request - self.app.request_middleware.appendleft(_collect_request) - @self.app.exception(MethodNotSupported) - async def error_handler(request, exception): - if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]: - return text( - "", exception.status_code, headers=exception.headers - ) - else: - return self.app.error_handler.default(request, exception) + self.app.exception(MethodNotSupported)(self._error_handler) if self.port: server_kwargs = dict( @@ -140,9 +164,7 @@ async def error_handler(request, exception): host, port = sock.getsockname() self.port = port - if uri.startswith( - ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") - ): + if uri.startswith(("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")): url = uri else: uri = uri if uri.startswith("/") else f"/{uri}" @@ -152,20 +174,16 @@ async def error_handler(request, exception): # known until this function is called, so fix that here url = url.replace(":None/", f":{port}/") - @self.app.listener("after_server_start") - async def _collect_response(sanic, loop): - try: - response = await self._local_request( - method, url, *request_args, **request_kwargs - ) - results[-1] = response - if method == "websocket": - await response.ws.close() - except Exception as e: - logger.exception("Exception") - exceptions.append(e) - finally: - self.app.stop() + self.app.listener("after_server_start")( + partial( + self._collect_response, + method, + url, + exceptions, + results, + **request_kwargs, + ) + ) self.app.run(debug=debug, **server_kwargs) self.app.listeners["after_server_start"].pop() @@ -174,6 +192,11 @@ async def _collect_response(sanic, loop): raise ValueError(f"Exception during request: {exceptions}") if gather_request: + try: + self.app.request_middleware.remove(_collect_request) + except BaseException: # noqa + pass + try: request, response = results return request, response @@ -233,6 +256,7 @@ def __init__( base_url: str = ASGI_BASE_URL, suppress_exceptions: bool = False, ) -> None: + app.__class__.__call__ = app_call_with_return app.asgi = True @@ -242,20 +266,25 @@ def __init__( super().__init__(transport=transport, base_url=base_url) + self.gather_request = True self.last_request = None - def _collect_request(request): - self.last_request = request + app.listener("after_server_start")(self._start_test_mode) + app.listener("before_server_stop")(self._end_test_mode) - @app.listener("after_server_start") - def _start_test_mode(sanic, *args, **kwargs): - sanic.test_mode = True + def _collect_request(self, request): + if self.gather_request: + self.last_request = request + else: + self.last_request = None - @app.listener("before_server_end") - def _end_test_mode(sanic, *args, **kwargs): - sanic.test_mode = False + @classmethod + def _start_test_mode(cls, sanic, *args, **kwargs): + sanic.test_mode = True - app.request_middleware.appendleft(_collect_request) + @classmethod + def _end_test_mode(cls, sanic, *args, **kwargs): + sanic.test_mode = False async def request(self, method, url, gather_request=True, *args, **kwargs): @@ -266,13 +295,25 @@ async def request(self, method, url, gather_request=True, *args, **kwargs): scheme = "ws" if method == "websocket" else "http" url = f"{scheme}://{ASGI_HOST}:{ASGI_PORT}{url}" + if self._collect_request not in self.sanic_app.request_middleware: + self.sanic_app.request_middleware.appendleft(self._collect_request) + self.gather_request = gather_request response = await super().request(method, url, *args, **kwargs) response.status = response.status_code response.body = response.content response.content_type = response.headers.get("content-type") + if gather_request: + return self.last_request, response + return response - return self.last_request, response + @classmethod + async def _ws_receive(cls): + return {} + + @classmethod + async def _ws_send(cls, message): + pass async def websocket(self, uri, subprotocols=None, *args, **kwargs): scheme = "ws" @@ -284,9 +325,7 @@ async def websocket(self, uri, subprotocols=None, *args, **kwargs): headers.setdefault("sec-websocket-key", "testserver==") headers.setdefault("sec-websocket-version", "13") if subprotocols is not None: - headers.setdefault( - "sec-websocket-protocol", ", ".join(subprotocols) - ) + headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) scope = { "type": "websocket", @@ -300,12 +339,23 @@ async def websocket(self, uri, subprotocols=None, *args, **kwargs): "subprotocols": subprotocols, } - async def receive(): - return {} + await self.sanic_app(scope, self._ws_receive, self._ws_send) - async def send(message): - pass + return None, {"opened": True} - await self.sanic_app(scope, receive, send) + def __getstate__(self): + # Cookies cannot be pickled, because they contain a ThreadLock + try: + del self._cookies + except AttributeError: + pass + return self.__dict__ - return None, {"opened": True} + def __setstate__(self, d): + try: + del d["_cookies"] + except LookupError: + pass + self.__dict__.update(d) + # Need to create a new CookieJar when unpickling, because it was killed on Pickle + self._cookies = httpx.Cookies() diff --git a/tests/conftest.py b/tests/conftest.py index 23f4dff..a7fb88e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,21 +3,21 @@ from sanic_testing import TestManager +def _basic_response(request): + return response.text("foo") + @pytest.fixture def app(): sanic_app = Sanic(__name__) TestManager(sanic_app) - @sanic_app.route( + sanic_app.route( "/", methods=["GET", "POST", "PATCH", "PUT", "DELETE", "OPTIONS"] - ) - def basic(request): - return response.text("foo") - + )(_basic_response) return sanic_app - @pytest.fixture def manager(): sanic_app = Sanic(__name__) return TestManager(sanic_app) + diff --git a/tests/test_basic.py b/tests/test_basic.py index 2dfe0d6..50865cd 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,3 +1,5 @@ +import pytest +import pickle from sanic import Sanic from sanic_testing import TestManager from sanic_testing.testing import SanicASGITestClient, SanicTestClient @@ -12,3 +14,16 @@ def test_manager_initialization(manager): assert isinstance(manager.test_client, SanicTestClient) assert isinstance(manager.asgi_client, SanicASGITestClient) assert isinstance(manager, TestManager) + +@pytest.mark.parametrize("protocol", [3, 4]) +def test_pickle_app(protocol): + app = Sanic("test_pickle_app") + manager = TestManager(app) + assert app._test_manager == manager + my_dict = {"app": app} + my_pickled = pickle.dumps(my_dict, protocol=protocol) + del my_dict + del app + del manager + my_new_dict = pickle.loads(my_pickled) + assert my_new_dict["app"]._test_manager