Skip to content

Commit

Permalink
Merge pull request #7 from ashleysommer/pickle_fixes
Browse files Browse the repository at this point in the history
Pickle fixes, and other changes
  • Loading branch information
ahopkins committed Jan 28, 2021
2 parents c22b4dd + 3cfed76 commit e4d33f9
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 62 deletions.
162 changes: 106 additions & 56 deletions sanic_testing/testing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import typing
from asyncio.events import get_event_loop
from functools import partial
from json import JSONDecodeError
from socket import socket
from types import SimpleNamespace
Expand Down Expand Up @@ -31,14 +32,16 @@ def __init__(
self.app = app
self.port = port
self.host = host
app.listener("after_server_start")(self._start_test_mode)
app.listener("before_server_stop")(self._end_test_mode)

@app.listener("after_server_start")
def _start_test_mode(sanic, *args, **kwargs):
sanic.test_mode = True
@classmethod
def _start_test_mode(cls, sanic, *args, **kwargs):
sanic.test_mode = True

@app.listener("before_server_end")
def _end_test_mode(sanic, *args, **kwargs):
sanic.test_mode = False
@classmethod
def _end_test_mode(cls, sanic, *args, **kwargs):
sanic.test_mode = False

def get_new_session(self, **kwargs) -> httpx.AsyncClient:
return httpx.AsyncClient(verify=False, **kwargs)
Expand Down Expand Up @@ -93,6 +96,38 @@ async def _local_request(self, method: str, url: str, *args, **kwargs):

return response

@classmethod
def _collect_request(cls, results, request):
if results[0] is None:
results[0] = request

async def _collect_response(
self,
method,
url,
exceptions,
results,
sanic,
loop,
**request_kwargs,
):
try:
response = await self._local_request(method, url, **request_kwargs)
results[-1] = response
if method == "websocket":
await response.ws.close()
except Exception as e:
logger.exception("Exception")
exceptions.append(e)
finally:
self.app.stop()

async def _error_handler(self, request, exception):
if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]:
return text("", exception.status_code, headers=exception.headers)
else:
return self.app.error_handler.default(request, exception)

def _sanic_endpoint_test(
self,
method: str = "get",
Expand All @@ -103,28 +138,17 @@ def _sanic_endpoint_test(
host: str = None,
*request_args,
**request_kwargs,
) -> typing.Tuple[typing.Union[Request, HTTPResponse]]:
) -> typing.Union[typing.Tuple[Request, HTTPResponse], HTTPResponse]:
results = [None, None]
exceptions = []

server_kwargs = server_kwargs or {"auto_reload": False}
_collect_request = partial(self._collect_request, results)

if gather_request:

def _collect_request(request):
if results[0] is None:
results[0] = request

self.app.request_middleware.appendleft(_collect_request)

@self.app.exception(MethodNotSupported)
async def error_handler(request, exception):
if request.method in ["HEAD", "PATCH", "PUT", "DELETE"]:
return text(
"", exception.status_code, headers=exception.headers
)
else:
return self.app.error_handler.default(request, exception)
self.app.exception(MethodNotSupported)(self._error_handler)

if self.port:
server_kwargs = dict(
Expand All @@ -140,9 +164,7 @@ async def error_handler(request, exception):
host, port = sock.getsockname()
self.port = port

if uri.startswith(
("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")
):
if uri.startswith(("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:")):
url = uri
else:
uri = uri if uri.startswith("/") else f"/{uri}"
Expand All @@ -152,20 +174,16 @@ async def error_handler(request, exception):
# known until this function is called, so fix that here
url = url.replace(":None/", f":{port}/")

@self.app.listener("after_server_start")
async def _collect_response(sanic, loop):
try:
response = await self._local_request(
method, url, *request_args, **request_kwargs
)
results[-1] = response
if method == "websocket":
await response.ws.close()
except Exception as e:
logger.exception("Exception")
exceptions.append(e)
finally:
self.app.stop()
self.app.listener("after_server_start")(
partial(
self._collect_response,
method,
url,
exceptions,
results,
**request_kwargs,
)
)

self.app.run(debug=debug, **server_kwargs)
self.app.listeners["after_server_start"].pop()
Expand All @@ -174,6 +192,11 @@ async def _collect_response(sanic, loop):
raise ValueError(f"Exception during request: {exceptions}")

if gather_request:
try:
self.app.request_middleware.remove(_collect_request)
except BaseException: # noqa
pass

try:
request, response = results
return request, response
Expand Down Expand Up @@ -233,6 +256,7 @@ def __init__(
base_url: str = ASGI_BASE_URL,
suppress_exceptions: bool = False,
) -> None:

app.__class__.__call__ = app_call_with_return
app.asgi = True

Expand All @@ -242,20 +266,25 @@ def __init__(

super().__init__(transport=transport, base_url=base_url)

self.gather_request = True
self.last_request = None

def _collect_request(request):
self.last_request = request
app.listener("after_server_start")(self._start_test_mode)
app.listener("before_server_stop")(self._end_test_mode)

@app.listener("after_server_start")
def _start_test_mode(sanic, *args, **kwargs):
sanic.test_mode = True
def _collect_request(self, request):
if self.gather_request:
self.last_request = request
else:
self.last_request = None

@app.listener("before_server_end")
def _end_test_mode(sanic, *args, **kwargs):
sanic.test_mode = False
@classmethod
def _start_test_mode(cls, sanic, *args, **kwargs):
sanic.test_mode = True

app.request_middleware.appendleft(_collect_request)
@classmethod
def _end_test_mode(cls, sanic, *args, **kwargs):
sanic.test_mode = False

async def request(self, method, url, gather_request=True, *args, **kwargs):

Expand All @@ -266,13 +295,25 @@ async def request(self, method, url, gather_request=True, *args, **kwargs):
scheme = "ws" if method == "websocket" else "http"
url = f"{scheme}://{ASGI_HOST}:{ASGI_PORT}{url}"

if self._collect_request not in self.sanic_app.request_middleware:
self.sanic_app.request_middleware.appendleft(self._collect_request)

self.gather_request = gather_request
response = await super().request(method, url, *args, **kwargs)
response.status = response.status_code
response.body = response.content
response.content_type = response.headers.get("content-type")
if gather_request:
return self.last_request, response
return response

return self.last_request, response
@classmethod
async def _ws_receive(cls):
return {}

@classmethod
async def _ws_send(cls, message):
pass

async def websocket(self, uri, subprotocols=None, *args, **kwargs):
scheme = "ws"
Expand All @@ -284,9 +325,7 @@ async def websocket(self, uri, subprotocols=None, *args, **kwargs):
headers.setdefault("sec-websocket-key", "testserver==")
headers.setdefault("sec-websocket-version", "13")
if subprotocols is not None:
headers.setdefault(
"sec-websocket-protocol", ", ".join(subprotocols)
)
headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))

scope = {
"type": "websocket",
Expand All @@ -300,12 +339,23 @@ async def websocket(self, uri, subprotocols=None, *args, **kwargs):
"subprotocols": subprotocols,
}

async def receive():
return {}
await self.sanic_app(scope, self._ws_receive, self._ws_send)

async def send(message):
pass
return None, {"opened": True}

await self.sanic_app(scope, receive, send)
def __getstate__(self):
# Cookies cannot be pickled, because they contain a ThreadLock
try:
del self._cookies
except AttributeError:
pass
return self.__dict__

return None, {"opened": True}
def __setstate__(self, d):
try:
del d["_cookies"]
except LookupError:
pass
self.__dict__.update(d)
# Need to create a new CookieJar when unpickling, because it was killed on Pickle
self._cookies = httpx.Cookies()
12 changes: 6 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
from sanic_testing import TestManager


def _basic_response(request):
return response.text("foo")

@pytest.fixture
def app():
sanic_app = Sanic(__name__)
TestManager(sanic_app)

@sanic_app.route(
sanic_app.route(
"/", methods=["GET", "POST", "PATCH", "PUT", "DELETE", "OPTIONS"]
)
def basic(request):
return response.text("foo")

)(_basic_response)
return sanic_app


@pytest.fixture
def manager():
sanic_app = Sanic(__name__)
return TestManager(sanic_app)

15 changes: 15 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest
import pickle
from sanic import Sanic
from sanic_testing import TestManager
from sanic_testing.testing import SanicASGITestClient, SanicTestClient
Expand All @@ -12,3 +14,16 @@ def test_manager_initialization(manager):
assert isinstance(manager.test_client, SanicTestClient)
assert isinstance(manager.asgi_client, SanicASGITestClient)
assert isinstance(manager, TestManager)

@pytest.mark.parametrize("protocol", [3, 4])
def test_pickle_app(protocol):
app = Sanic("test_pickle_app")
manager = TestManager(app)
assert app._test_manager == manager
my_dict = {"app": app}
my_pickled = pickle.dumps(my_dict, protocol=protocol)
del my_dict
del app
del manager
my_new_dict = pickle.loads(my_pickled)
assert my_new_dict["app"]._test_manager

0 comments on commit e4d33f9

Please sign in to comment.