From 3cfed766c525f96267dde8bd8cba81b6cbda1e37 Mon Sep 17 00:00:00 2001 From: Ashley Sommer Date: Thu, 28 Jan 2021 14:50:17 +1000 Subject: [PATCH] Move nested subfuctions out of TestClient class methods. This fixes a weird App pickling error seen when using TestManager (specifically on the ASGITestClient) Add logic to remove the CookieJar from ASGITestClient before pickling, then add it back when unpickled. The cookie jar contains a ThreadLock which cannot be pickled. Fixes #3 Improve gather_request logic in TestClient and ASGITestClient, Fixes #4 Change "before_server_end" to "before_server_stop". Fixes #5 Fix return type on _sanic_endpoint_test, Fixes #6 Fixes #6 --- sanic_testing/testing.py | 162 +++++++++++++++++++++++++-------------- tests/conftest.py | 12 +-- tests/test_basic.py | 15 ++++ 3 files changed, 127 insertions(+), 62 deletions(-) diff --git a/sanic_testing/testing.py b/sanic_testing/testing.py index 0e4a410..b795a87 100644 --- a/sanic_testing/testing.py +++ b/sanic_testing/testing.py @@ -1,6 +1,7 @@ import asyncio import typing from asyncio.events import get_event_loop +from functools import partial from json import JSONDecodeError from socket import socket from types import SimpleNamespace @@ -31,14 +32,16 @@ def __init__( self.app = app self.port = port self.host = host + app.listener("after_server_start")(self._start_test_mode) + app.listener("before_server_stop")(self._end_test_mode) - @app.listener("after_server_start") - def _start_test_mode(sanic, *args, **kwargs): - sanic.test_mode = True + @classmethod + def _start_test_mode(cls, sanic, *args, **kwargs): + sanic.test_mode = True - @app.listener("before_server_end") - def _end_test_mode(sanic, *args, **kwargs): - sanic.test_mode = False + @classmethod + def _end_test_mode(cls, sanic, *args, **kwargs): + sanic.test_mode = False def get_new_session(self, **kwargs) -> httpx.AsyncClient: return httpx.AsyncClient(verify=False, **kwargs) @@ -93,6 +96,38 @@ async def _local_request(self, method: str, url: str, *args, **kwargs): return response + @classmethod + def _collect_request(cls, results, request): + if results[0] is None: + results[0] = request + + async def _collect_response( + self, + method, + url, + exceptions, + results, + sanic, + loop, + **request_kwargs, + ): + try: + response = await self._local_request(method, url, **request_kwargs) + results[-1] = response + if method == "websocket": + await response.ws.close() + except Exception as e: + logger.exception("Exception") + exceptions.append(e) + finally: + self.app.stop() + + async def _error_handler(self, request, exception): + if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]: + return text("", exception.status_code, headers=exception.headers) + else: + return self.app.error_handler.default(request, exception) + def _sanic_endpoint_test( self, method: str = "get", @@ -103,28 +138,17 @@ def _sanic_endpoint_test( host: str = None, *request_args, **request_kwargs, - ) -> typing.Tuple[typing.Union[Request, HTTPResponse]]: + ) -> typing.Union[typing.Tuple[Request, HTTPResponse], HTTPResponse]: results = [None, None] exceptions = [] server_kwargs = server_kwargs or {"auto_reload": False} + _collect_request = partial(self._collect_request, results) if gather_request: - - def _collect_request(request): - if results[0] is None: - results[0] = request - self.app.request_middleware.appendleft(_collect_request) - @self.app.exception(MethodNotSupported) - async def error_handler(request, exception): - if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]: - return text( - "", exception.status_code, headers=exception.headers - ) - else: - return self.app.error_handler.default(request, exception) + self.app.exception(MethodNotSupported)(self._error_handler) if self.port: server_kwargs = dict( @@ -140,9 +164,7 @@ async def error_handler(request, exception): host, port = sock.getsockname() self.port = port - if uri.startswith( - ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") - ): + if uri.startswith(("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")): url = uri else: uri = uri if uri.startswith("/") else f"/{uri}" @@ -152,20 +174,16 @@ async def error_handler(request, exception): # known until this function is called, so fix that here url = url.replace(":None/", f":{port}/") - @self.app.listener("after_server_start") - async def _collect_response(sanic, loop): - try: - response = await self._local_request( - method, url, *request_args, **request_kwargs - ) - results[-1] = response - if method == "websocket": - await response.ws.close() - except Exception as e: - logger.exception("Exception") - exceptions.append(e) - finally: - self.app.stop() + self.app.listener("after_server_start")( + partial( + self._collect_response, + method, + url, + exceptions, + results, + **request_kwargs, + ) + ) self.app.run(debug=debug, **server_kwargs) self.app.listeners["after_server_start"].pop() @@ -174,6 +192,11 @@ async def _collect_response(sanic, loop): raise ValueError(f"Exception during request: {exceptions}") if gather_request: + try: + self.app.request_middleware.remove(_collect_request) + except BaseException: # noqa + pass + try: request, response = results return request, response @@ -233,6 +256,7 @@ def __init__( base_url: str = ASGI_BASE_URL, suppress_exceptions: bool = False, ) -> None: + app.__class__.__call__ = app_call_with_return app.asgi = True @@ -242,20 +266,25 @@ def __init__( super().__init__(transport=transport, base_url=base_url) + self.gather_request = True self.last_request = None - def _collect_request(request): - self.last_request = request + app.listener("after_server_start")(self._start_test_mode) + app.listener("before_server_stop")(self._end_test_mode) - @app.listener("after_server_start") - def _start_test_mode(sanic, *args, **kwargs): - sanic.test_mode = True + def _collect_request(self, request): + if self.gather_request: + self.last_request = request + else: + self.last_request = None - @app.listener("before_server_end") - def _end_test_mode(sanic, *args, **kwargs): - sanic.test_mode = False + @classmethod + def _start_test_mode(cls, sanic, *args, **kwargs): + sanic.test_mode = True - app.request_middleware.appendleft(_collect_request) + @classmethod + def _end_test_mode(cls, sanic, *args, **kwargs): + sanic.test_mode = False async def request(self, method, url, gather_request=True, *args, **kwargs): @@ -266,13 +295,25 @@ async def request(self, method, url, gather_request=True, *args, **kwargs): scheme = "ws" if method == "websocket" else "http" url = f"{scheme}://{ASGI_HOST}:{ASGI_PORT}{url}" + if self._collect_request not in self.sanic_app.request_middleware: + self.sanic_app.request_middleware.appendleft(self._collect_request) + self.gather_request = gather_request response = await super().request(method, url, *args, **kwargs) response.status = response.status_code response.body = response.content response.content_type = response.headers.get("content-type") + if gather_request: + return self.last_request, response + return response - return self.last_request, response + @classmethod + async def _ws_receive(cls): + return {} + + @classmethod + async def _ws_send(cls, message): + pass async def websocket(self, uri, subprotocols=None, *args, **kwargs): scheme = "ws" @@ -284,9 +325,7 @@ async def websocket(self, uri, subprotocols=None, *args, **kwargs): headers.setdefault("sec-websocket-key", "testserver==") headers.setdefault("sec-websocket-version", "13") if subprotocols is not None: - headers.setdefault( - "sec-websocket-protocol", ", ".join(subprotocols) - ) + headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) scope = { "type": "websocket", @@ -300,12 +339,23 @@ async def websocket(self, uri, subprotocols=None, *args, **kwargs): "subprotocols": subprotocols, } - async def receive(): - return {} + await self.sanic_app(scope, self._ws_receive, self._ws_send) - async def send(message): - pass + return None, {"opened": True} - await self.sanic_app(scope, receive, send) + def __getstate__(self): + # Cookies cannot be pickled, because they contain a ThreadLock + try: + del self._cookies + except AttributeError: + pass + return self.__dict__ - return None, {"opened": True} + def __setstate__(self, d): + try: + del d["_cookies"] + except LookupError: + pass + self.__dict__.update(d) + # Need to create a new CookieJar when unpickling, because it was killed on Pickle + self._cookies = httpx.Cookies() diff --git a/tests/conftest.py b/tests/conftest.py index 23f4dff..a7fb88e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,21 +3,21 @@ from sanic_testing import TestManager +def _basic_response(request): + return response.text("foo") + @pytest.fixture def app(): sanic_app = Sanic(__name__) TestManager(sanic_app) - @sanic_app.route( + sanic_app.route( "/", methods=["GET", "POST", "PATCH", "PUT", "DELETE", "OPTIONS"] - ) - def basic(request): - return response.text("foo") - + )(_basic_response) return sanic_app - @pytest.fixture def manager(): sanic_app = Sanic(__name__) return TestManager(sanic_app) + diff --git a/tests/test_basic.py b/tests/test_basic.py index 2dfe0d6..50865cd 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,3 +1,5 @@ +import pytest +import pickle from sanic import Sanic from sanic_testing import TestManager from sanic_testing.testing import SanicASGITestClient, SanicTestClient @@ -12,3 +14,16 @@ def test_manager_initialization(manager): assert isinstance(manager.test_client, SanicTestClient) assert isinstance(manager.asgi_client, SanicASGITestClient) assert isinstance(manager, TestManager) + +@pytest.mark.parametrize("protocol", [3, 4]) +def test_pickle_app(protocol): + app = Sanic("test_pickle_app") + manager = TestManager(app) + assert app._test_manager == manager + my_dict = {"app": app} + my_pickled = pickle.dumps(my_dict, protocol=protocol) + del my_dict + del app + del manager + my_new_dict = pickle.loads(my_pickled) + assert my_new_dict["app"]._test_manager