From 51c39ac398901b7bcf958154f831d17240dfda73 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sun, 23 Jun 2019 20:13:09 -0700 Subject: [PATCH] Port Datasette from Sanic to ASGI + Uvicorn (#518) Datasette now uses ASGI internally, and no longer depends on Sanic. It now uses Uvicorn as the underlying HTTP server. This was thirteen months in the making... for full details see the issue: https://github.com/simonw/datasette/issues/272 And for a full sequence of commits plus commentary, see the pull request: https://github.com/simonw/datasette/pull/518 --- datasette/app.py | 245 +++++++------- datasette/cli.py | 3 +- datasette/renderer.py | 2 +- datasette/tracer.py | 81 ++++- datasette/{utils.py => utils/__init__.py} | 17 +- datasette/utils/asgi.py | 377 ++++++++++++++++++++++ datasette/views/base.py | 68 ++-- datasette/views/database.py | 9 +- datasette/views/index.py | 7 +- datasette/views/special.py | 8 +- datasette/views/table.py | 7 +- pytest.ini | 2 - setup.py | 6 +- tests/fixtures.py | 86 ++++- tests/test_api.py | 12 +- tests/test_csv.py | 8 +- tests/test_html.py | 25 ++ tests/test_utils.py | 12 +- 18 files changed, 769 insertions(+), 206 deletions(-) rename datasette/{utils.py => utils/__init__.py} (98%) create mode 100644 datasette/utils/asgi.py diff --git a/datasette/app.py b/datasette/app.py index 2ef7da41b5..4a8ead1ddf 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -1,11 +1,9 @@ import asyncio import collections import hashlib -import json import os import sys import threading -import time import traceback import urllib.parse from concurrent import futures @@ -14,10 +12,8 @@ import click from markupsafe import Markup from jinja2 import ChoiceLoader, Environment, FileSystemLoader, PrefixLoader -from sanic import Sanic, response -from sanic.exceptions import InvalidUsage, NotFound -from .views.base import DatasetteError, ureg +from .views.base import DatasetteError, ureg, AsgiRouter from .views.database import DatabaseDownload, DatabaseView from .views.index import IndexView from .views.special import JsonDataView @@ -36,7 +32,16 @@ sqlite_timelimit, to_css_class, ) -from .tracer import capture_traces, trace +from .utils.asgi import ( + AsgiLifespan, + NotFound, + asgi_static, + asgi_send, + asgi_send_html, + asgi_send_json, + asgi_send_redirect, +) +from .tracer import trace, AsgiTracer from .plugins import pm, DEFAULT_PLUGINS from .version import __version__ @@ -126,8 +131,8 @@ DEFAULT_CONFIG = {option.name: option.default for option in CONFIG_OPTIONS} -async def favicon(request): - return response.text("") +async def favicon(scope, receive, send): + await asgi_send(send, "", 200) class Datasette: @@ -413,6 +418,7 @@ def versions(self): "full": sys.version, }, "datasette": datasette_version, + "asgi": "3.0", "sqlite": { "version": sqlite_version, "fts_versions": fts_versions, @@ -543,21 +549,7 @@ def register_renderers(self): self.renderers[renderer["extension"]] = renderer["callback"] def app(self): - class TracingSanic(Sanic): - async def handle_request(self, request, write_callback, stream_callback): - if request.args.get("_trace"): - request["traces"] = [] - request["trace_start"] = time.time() - with capture_traces(request["traces"]): - await super().handle_request( - request, write_callback, stream_callback - ) - else: - await super().handle_request( - request, write_callback, stream_callback - ) - - app = TracingSanic(__name__) + "Returns an ASGI app function that serves the whole of Datasette" default_templates = str(app_root / "datasette" / "templates") 
template_paths = [] if self.template_dir: @@ -588,134 +580,127 @@ async def handle_request(self, request, write_callback, stream_callback): pm.hook.prepare_jinja2_environment(env=self.jinja_env) self.register_renderers() + + routes = [] + + def add_route(view, regex): + routes.append((regex, view)) + # Generate a regex snippet to match all registered renderer file extensions renderer_regex = "|".join(r"\." + key for key in self.renderers.keys()) - app.add_route(IndexView.as_view(self), r"/") + add_route(IndexView.as_asgi(self), r"/(?P(\.jsono?)?$)") # TODO: /favicon.ico and /-/static/ deserve far-future cache expires - app.add_route(favicon, "/favicon.ico") - app.static("/-/static/", str(app_root / "datasette" / "static")) + add_route(favicon, "/favicon.ico") + + add_route( + asgi_static(app_root / "datasette" / "static"), r"/-/static/(?P.*)$" + ) for path, dirname in self.static_mounts: - app.static(path, dirname) + add_route(asgi_static(dirname), r"/" + path + "/(?P.*)$") + # Mount any plugin static/ directories for plugin in get_plugins(pm): if plugin["static_path"]: - modpath = "/-/static-plugins/{}/".format(plugin["name"]) - app.static(modpath, plugin["static_path"]) - app.add_route( - JsonDataView.as_view(self, "metadata.json", lambda: self._metadata), - r"/-/metadata", + modpath = "/-/static-plugins/{}/(?P.*)$".format(plugin["name"]) + add_route(asgi_static(plugin["static_path"]), modpath) + add_route( + JsonDataView.as_asgi(self, "metadata.json", lambda: self._metadata), + r"/-/metadata(?P(\.json)?)$", ) - app.add_route( - JsonDataView.as_view(self, "versions.json", self.versions), - r"/-/versions", + add_route( + JsonDataView.as_asgi(self, "versions.json", self.versions), + r"/-/versions(?P(\.json)?)$", ) - app.add_route( - JsonDataView.as_view(self, "plugins.json", self.plugins), - r"/-/plugins", + add_route( + JsonDataView.as_asgi(self, "plugins.json", self.plugins), + r"/-/plugins(?P(\.json)?)$", ) - app.add_route( - JsonDataView.as_view(self, "config.json", lambda: self._config), - r"/-/config", + add_route( + JsonDataView.as_asgi(self, "config.json", lambda: self._config), + r"/-/config(?P(\.json)?)$", ) - app.add_route( - JsonDataView.as_view(self, "databases.json", self.connected_databases), - r"/-/databases", + add_route( + JsonDataView.as_asgi(self, "databases.json", self.connected_databases), + r"/-/databases(?P(\.json)?)$", ) - app.add_route( - DatabaseDownload.as_view(self), r"/" + add_route( + DatabaseDownload.as_asgi(self), r"/(?P[^/]+?)(?P\.db)$" ) - app.add_route( - DatabaseView.as_view(self), - r"/", + add_route( + DatabaseView.as_asgi(self), + r"/(?P[^/]+?)(?P" + + renderer_regex + + r"|.jsono|\.csv)?$", ) - app.add_route( - TableView.as_view(self), r"//" + add_route( + TableView.as_asgi(self), + r"/(?P[^/]+)/(?P[^/]+?$)", ) - app.add_route( - RowView.as_view(self), - r"///[^/]+)/(?P[^/]+?)/(?P[^/]+?)(?P" + renderer_regex - + r")?$>", + + r")?$", ) self.register_custom_units() - # On 404 with a trailing slash redirect to path without that slash: - # pylint: disable=unused-variable - @app.middleware("response") - def redirect_on_404_with_trailing_slash(request, original_response): - if original_response.status == 404 and request.path.endswith("/"): - path = request.path.rstrip("/") - if request.query_string: - path = "{}?{}".format(path, request.query_string) - return response.redirect(path) - - @app.middleware("response") - async def add_traces_to_response(request, response): - if request.get("traces") is None: - return - traces = request["traces"] - trace_info = 
{ - "request_duration_ms": 1000 * (time.time() - request["trace_start"]), - "sum_trace_duration_ms": sum(t["duration_ms"] for t in traces), - "num_traces": len(traces), - "traces": traces, - } - if "text/html" in response.content_type and b"" in response.body: - extra = json.dumps(trace_info, indent=2) - extra_html = "
{}
".format(extra).encode("utf8") - response.body = response.body.replace(b"", extra_html) - elif "json" in response.content_type and response.body.startswith(b"{"): - data = json.loads(response.body.decode("utf8")) - if "_trace" not in data: - data["_trace"] = trace_info - response.body = json.dumps(data).encode("utf8") - - @app.exception(Exception) - def on_exception(request, exception): - title = None - help = None - if isinstance(exception, NotFound): - status = 404 - info = {} - message = exception.args[0] - elif isinstance(exception, InvalidUsage): - status = 405 - info = {} - message = exception.args[0] - elif isinstance(exception, DatasetteError): - status = exception.status - info = exception.error_dict - message = exception.message - if exception.messagge_is_html: - message = Markup(message) - title = exception.title - else: - status = 500 - info = {} - message = str(exception) - traceback.print_exc() - templates = ["500.html"] - if status != 500: - templates = ["{}.html".format(status)] + templates - info.update( - {"ok": False, "error": message, "status": status, "title": title} - ) - if request is not None and request.path.split("?")[0].endswith(".json"): - r = response.json(info, status=status) - - else: - template = self.jinja_env.select_template(templates) - r = response.html(template.render(info), status=status) - if self.cors: - r.headers["Access-Control-Allow-Origin"] = "*" - return r - - # First time server starts up, calculate table counts for immutable databases - @app.listener("before_server_start") - async def setup_db(app, loop): + async def setup_db(): + # First time server starts up, calculate table counts for immutable databases for dbname, database in self.databases.items(): if not database.is_mutable: await database.table_counts(limit=60 * 60 * 1000) - return app + return AsgiLifespan( + AsgiTracer(DatasetteRouter(self, routes)), on_startup=setup_db + ) + + +class DatasetteRouter(AsgiRouter): + def __init__(self, datasette, routes): + self.ds = datasette + super().__init__(routes) + + async def handle_404(self, scope, receive, send): + # If URL has a trailing slash, redirect to URL without it + path = scope.get("raw_path", scope["path"].encode("utf8")) + if path.endswith(b"/"): + path = path.rstrip(b"/") + if scope["query_string"]: + path += b"?" 
+ scope["query_string"] + await asgi_send_redirect(send, path.decode("latin1")) + else: + await super().handle_404(scope, receive, send) + + async def handle_500(self, scope, receive, send, exception): + title = None + if isinstance(exception, NotFound): + status = 404 + info = {} + message = exception.args[0] + elif isinstance(exception, DatasetteError): + status = exception.status + info = exception.error_dict + message = exception.message + if exception.messagge_is_html: + message = Markup(message) + title = exception.title + else: + status = 500 + info = {} + message = str(exception) + traceback.print_exc() + templates = ["500.html"] + if status != 500: + templates = ["{}.html".format(status)] + templates + info.update({"ok": False, "error": message, "status": status, "title": title}) + headers = {} + if self.ds.cors: + headers["Access-Control-Allow-Origin"] = "*" + if scope["path"].split("?")[0].endswith(".json"): + await asgi_send_json(send, info, status=status, headers=headers) + else: + template = self.ds.jinja_env.select_template(templates) + await asgi_send_html( + send, template.render(info), status=status, headers=headers + ) diff --git a/datasette/cli.py b/datasette/cli.py index 0d47f47a0c..181b281c7c 100644 --- a/datasette/cli.py +++ b/datasette/cli.py @@ -1,4 +1,5 @@ import asyncio +import uvicorn import click from click import formatting from click_default_group import DefaultGroup @@ -354,4 +355,4 @@ def serve( asyncio.get_event_loop().run_until_complete(ds.run_sanity_checks()) # Start the server - ds.app().run(host=host, port=port, debug=debug) + uvicorn.run(ds.app(), host=host, port=port, log_level="info") diff --git a/datasette/renderer.py b/datasette/renderer.py index 417fecb54b..349c29223c 100644 --- a/datasette/renderer.py +++ b/datasette/renderer.py @@ -88,5 +88,5 @@ def json_renderer(args, data, view_name): content_type = "text/plain" else: body = json.dumps(data, cls=CustomJSONEncoder) - content_type = "application/json" + content_type = "application/json; charset=utf-8" return {"body": body, "status_code": status_code, "content_type": content_type} diff --git a/datasette/tracer.py b/datasette/tracer.py index c6fe0a00ba..e46a6fda4a 100644 --- a/datasette/tracer.py +++ b/datasette/tracer.py @@ -1,6 +1,7 @@ import asyncio from contextlib import contextmanager import time +import json import traceback tracers = {} @@ -32,15 +33,15 @@ def trace(type, **kwargs): start = time.time() yield end = time.time() - trace = { + trace_info = { "type": type, "start": start, "end": end, "duration_ms": (end - start) * 1000, "traceback": traceback.format_list(traceback.extract_stack(limit=6)[:-3]), } - trace.update(kwargs) - tracer.append(trace) + trace_info.update(kwargs) + tracer.append(trace_info) @contextmanager @@ -53,3 +54,77 @@ def capture_traces(tracer): tracers[task_id] = tracer yield del tracers[task_id] + + +class AsgiTracer: + # If the body is larger than this we don't attempt to append the trace + max_body_bytes = 1024 * 256 # 256 KB + + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if b"_trace=1" not in scope.get("query_string", b"").split(b"&"): + await self.app(scope, receive, send) + return + trace_start = time.time() + traces = [] + + accumulated_body = b"" + size_limit_exceeded = False + response_headers = [] + + async def wrapped_send(message): + nonlocal accumulated_body, size_limit_exceeded, response_headers + if message["type"] == "http.response.start": + response_headers = message["headers"] + await 
send(message) + return + + if message["type"] != "http.response.body" or size_limit_exceeded: + await send(message) + return + + # Accumulate body until the end or until size is exceeded + accumulated_body += message["body"] + if len(accumulated_body) > self.max_body_bytes: + await send( + { + "type": "http.response.body", + "body": accumulated_body, + "more_body": True, + } + ) + size_limit_exceeded = True + return + + if not message.get("more_body"): + # We have all the body - modify it and send the result + # TODO: What to do about Content-Type or other cases? + trace_info = { + "request_duration_ms": 1000 * (time.time() - trace_start), + "sum_trace_duration_ms": sum(t["duration_ms"] for t in traces), + "num_traces": len(traces), + "traces": traces, + } + try: + content_type = [ + v.decode("utf8") + for k, v in response_headers + if k.lower() == b"content-type" + ][0] + except IndexError: + content_type = "" + if "text/html" in content_type and b"" in accumulated_body: + extra = json.dumps(trace_info, indent=2) + extra_html = "
<pre>{}</pre></body>
".format(extra).encode("utf8") + accumulated_body = accumulated_body.replace(b"", extra_html) + elif "json" in content_type and accumulated_body.startswith(b"{"): + data = json.loads(accumulated_body.decode("utf8")) + if "_trace" not in data: + data["_trace"] = trace_info + accumulated_body = json.dumps(data).encode("utf8") + await send({"type": "http.response.body", "body": accumulated_body}) + + with capture_traces(traces): + await self.app(scope, receive, wrapped_send) diff --git a/datasette/utils.py b/datasette/utils/__init__.py similarity index 98% rename from datasette/utils.py rename to datasette/utils/__init__.py index 58746be424..94ccc23ed8 100644 --- a/datasette/utils.py +++ b/datasette/utils/__init__.py @@ -697,13 +697,13 @@ def __init__(self, writer, limit_mb): self.limit_bytes = limit_mb * 1024 * 1024 self.bytes_count = 0 - def write(self, bytes): + async def write(self, bytes): self.bytes_count += len(bytes) if self.limit_bytes and (self.bytes_count > self.limit_bytes): raise WriteLimitExceeded( "CSV contains more than {} bytes".format(self.limit_bytes) ) - self.writer.write(bytes) + await self.writer.write(bytes) _infinities = {float("inf"), float("-inf")} @@ -741,3 +741,16 @@ def format_bytes(bytes): return "{} {}".format(int(current), unit) else: return "{:.1f} {}".format(current, unit) + + +class RequestParameters(dict): + def get(self, name, default=None): + "Return first value in the list, if available" + try: + return super().get(name)[0] + except (KeyError, TypeError): + return default + + def getlist(self, name, default=None): + "Return full list" + return super().get(name, default) diff --git a/datasette/utils/asgi.py b/datasette/utils/asgi.py new file mode 100644 index 0000000000..fdf330ae79 --- /dev/null +++ b/datasette/utils/asgi.py @@ -0,0 +1,377 @@ +import json +from datasette.utils import RequestParameters +from mimetypes import guess_type +from urllib.parse import parse_qs, urlunparse +from pathlib import Path +from html import escape +import re +import aiofiles + + +class NotFound(Exception): + pass + + +class Request: + def __init__(self, scope): + self.scope = scope + + @property + def method(self): + return self.scope["method"] + + @property + def url(self): + return urlunparse( + (self.scheme, self.host, self.path, None, self.query_string, None) + ) + + @property + def scheme(self): + return self.scope.get("scheme") or "http" + + @property + def headers(self): + return dict( + [ + (k.decode("latin-1").lower(), v.decode("latin-1")) + for k, v in self.scope.get("headers") or [] + ] + ) + + @property + def host(self): + return self.headers.get("host") or "localhost" + + @property + def path(self): + return ( + self.scope.get("raw_path", self.scope["path"].encode("latin-1")) + ).decode("latin-1") + + @property + def query_string(self): + return (self.scope.get("query_string") or b"").decode("latin-1") + + @property + def args(self): + return RequestParameters(parse_qs(qs=self.query_string)) + + @property + def raw_args(self): + return {key: value[0] for key, value in self.args.items()} + + @classmethod + def fake(cls, path_with_query_string, method="GET", scheme="http"): + "Useful for constructing Request objects for tests" + path, _, query_string = path_with_query_string.partition("?") + scope = { + "http_version": "1.1", + "method": method, + "path": path, + "raw_path": path.encode("latin-1"), + "query_string": query_string.encode("latin-1"), + "scheme": scheme, + "type": "http", + } + return cls(scope) + + +class AsgiRouter: + def __init__(self, 
routes=None): + routes = routes or [] + self.routes = [ + # Compile any strings to regular expressions + ((re.compile(pattern) if isinstance(pattern, str) else pattern), view) + for pattern, view in routes + ] + + async def __call__(self, scope, receive, send): + # Because we care about "foo/bar" v.s. "foo%2Fbar" we decode raw_path ourselves + path = scope["raw_path"].decode("ascii") + for regex, view in self.routes: + match = regex.match(path) + if match is not None: + new_scope = dict(scope, url_route={"kwargs": match.groupdict()}) + try: + return await view(new_scope, receive, send) + except Exception as exception: + return await self.handle_500(scope, receive, send, exception) + return await self.handle_404(scope, receive, send) + + async def handle_404(self, scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 404, + "headers": [[b"content-type", b"text/html"]], + } + ) + await send({"type": "http.response.body", "body": b"

<h1>404</h1>
"}) + + async def handle_500(self, scope, receive, send, exception): + await send( + { + "type": "http.response.start", + "status": 404, + "headers": [[b"content-type", b"text/html"]], + } + ) + html = "

<h1>500</h1><pre>{}</pre>
".format(escape(repr(exception))) + await send({"type": "http.response.body", "body": html.encode("latin-1")}) + + +class AsgiLifespan: + def __init__(self, app, on_startup=None, on_shutdown=None): + self.app = app + on_startup = on_startup or [] + on_shutdown = on_shutdown or [] + if not isinstance(on_startup or [], list): + on_startup = [on_startup] + if not isinstance(on_shutdown or [], list): + on_shutdown = [on_shutdown] + self.on_startup = on_startup + self.on_shutdown = on_shutdown + + async def __call__(self, scope, receive, send): + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + for fn in self.on_startup: + await fn() + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + for fn in self.on_shutdown: + await fn() + await send({"type": "lifespan.shutdown.complete"}) + return + else: + await self.app(scope, receive, send) + + +class AsgiView: + def dispatch_request(self, request, *args, **kwargs): + handler = getattr(self, request.method.lower(), None) + return handler(request, *args, **kwargs) + + @classmethod + def as_asgi(cls, *class_args, **class_kwargs): + async def view(scope, receive, send): + # Uses scope to create a request object, then dispatches that to + # self.get(...) or self.options(...) along with keyword arguments + # that were already tucked into scope["url_route"]["kwargs"] by + # the router, similar to how Django Channels works: + # https://channels.readthedocs.io/en/latest/topics/routing.html#urlrouter + request = Request(scope) + self = view.view_class(*class_args, **class_kwargs) + response = await self.dispatch_request( + request, **scope["url_route"]["kwargs"] + ) + await response.asgi_send(send) + + view.view_class = cls + view.__doc__ = cls.__doc__ + view.__module__ = cls.__module__ + view.__name__ = cls.__name__ + return view + + +class AsgiStream: + def __init__(self, stream_fn, status=200, headers=None, content_type="text/plain"): + self.stream_fn = stream_fn + self.status = status + self.headers = headers or {} + self.content_type = content_type + + async def asgi_send(self, send): + # Remove any existing content-type header + headers = dict( + [(k, v) for k, v in self.headers.items() if k.lower() != "content-type"] + ) + headers["content-type"] = self.content_type + await send( + { + "type": "http.response.start", + "status": self.status, + "headers": [ + [key.encode("utf-8"), value.encode("utf-8")] + for key, value in headers.items() + ], + } + ) + w = AsgiWriter(send) + await self.stream_fn(w) + await send({"type": "http.response.body", "body": b""}) + + +class AsgiWriter: + def __init__(self, send): + self.send = send + + async def write(self, chunk): + await self.send( + { + "type": "http.response.body", + "body": chunk.encode("latin-1"), + "more_body": True, + } + ) + + +async def asgi_send_json(send, info, status=200, headers=None): + headers = headers or {} + await asgi_send( + send, + json.dumps(info), + status=status, + headers=headers, + content_type="application/json; charset=utf-8", + ) + + +async def asgi_send_html(send, html, status=200, headers=None): + headers = headers or {} + await asgi_send( + send, html, status=status, headers=headers, content_type="text/html" + ) + + +async def asgi_send_redirect(send, location, status=302): + await asgi_send( + send, + "", + status=status, + headers={"Location": location}, + content_type="text/html", + ) + + +async def asgi_send(send, content, status, headers=None, 
content_type="text/plain"): + await asgi_start(send, status, headers, content_type) + await send({"type": "http.response.body", "body": content.encode("latin-1")}) + + +async def asgi_start(send, status, headers=None, content_type="text/plain"): + headers = headers or {} + # Remove any existing content-type header + headers = dict([(k, v) for k, v in headers.items() if k.lower() != "content-type"]) + headers["content-type"] = content_type + await send( + { + "type": "http.response.start", + "status": status, + "headers": [ + [key.encode("latin1"), value.encode("latin1")] + for key, value in headers.items() + ], + } + ) + + +async def asgi_send_file( + send, filepath, filename=None, content_type=None, chunk_size=4096 +): + headers = {} + if filename: + headers["Content-Disposition"] = 'attachment; filename="{}"'.format(filename) + first = True + async with aiofiles.open(str(filepath), mode="rb") as fp: + if first: + await asgi_start( + send, + 200, + headers, + content_type or guess_type(str(filepath))[0] or "text/plain", + ) + first = False + more_body = True + while more_body: + chunk = await fp.read(chunk_size) + more_body = len(chunk) == chunk_size + await send( + {"type": "http.response.body", "body": chunk, "more_body": more_body} + ) + + +def asgi_static(root_path, chunk_size=4096, headers=None, content_type=None): + async def inner_static(scope, receive, send): + path = scope["url_route"]["kwargs"]["path"] + full_path = (Path(root_path) / path).absolute() + # Ensure full_path is within root_path to avoid weird "../" tricks + try: + full_path.relative_to(root_path) + except ValueError: + await asgi_send_html(send, "404", 404) + return + first = True + try: + await asgi_send_file(send, full_path, chunk_size=chunk_size) + except FileNotFoundError: + await asgi_send_html(send, "404", 404) + return + + return inner_static + + +class Response: + def __init__(self, body=None, status=200, headers=None, content_type="text/plain"): + self.body = body + self.status = status + self.headers = headers or {} + self.content_type = content_type + + async def asgi_send(self, send): + headers = {} + headers.update(self.headers) + headers["content-type"] = self.content_type + await send( + { + "type": "http.response.start", + "status": self.status, + "headers": [ + [key.encode("utf-8"), value.encode("utf-8")] + for key, value in headers.items() + ], + } + ) + body = self.body + if not isinstance(body, bytes): + body = body.encode("utf-8") + await send({"type": "http.response.body", "body": body}) + + @classmethod + def html(cls, body, status=200, headers=None): + return cls( + body, + status=status, + headers=headers, + content_type="text/html; charset=utf-8", + ) + + @classmethod + def text(cls, body, status=200, headers=None): + return cls( + body, + status=status, + headers=headers, + content_type="text/plain; charset=utf-8", + ) + + @classmethod + def redirect(cls, path, status=302, headers=None): + headers = headers or {} + headers["Location"] = path + return cls("", status=status, headers=headers) + + +class AsgiFileDownload: + def __init__( + self, filepath, filename=None, content_type="application/octet-stream" + ): + self.filepath = filepath + self.filename = filename + self.content_type = content_type + + async def asgi_send(self, send): + return await asgi_send_file(send, self.filepath, content_type=self.content_type) diff --git a/datasette/views/base.py b/datasette/views/base.py index 9db8cc762a..7acb7304cb 100644 --- a/datasette/views/base.py +++ b/datasette/views/base.py @@ -7,9 +7,8 @@ 
import jinja2 import pint -from sanic import response -from sanic.exceptions import NotFound -from sanic.views import HTTPMethodView + +from html import escape from datasette import __version__ from datasette.plugins import pm @@ -26,6 +25,14 @@ sqlite3, to_css_class, ) +from datasette.utils.asgi import ( + AsgiStream, + AsgiWriter, + AsgiRouter, + AsgiView, + NotFound, + Response, +) ureg = pint.UnitRegistry() @@ -49,7 +56,14 @@ def __init__( self.messagge_is_html = messagge_is_html -class BaseView(HTTPMethodView): +class BaseView(AsgiView): + ds = None + + async def head(self, *args, **kwargs): + response = await self.get(*args, **kwargs) + response.body = b"" + return response + def _asset_urls(self, key, template, context): # Flatten list-of-lists from plugins: seen_urls = set() @@ -104,7 +118,7 @@ def render(self, templates, **context): datasette=self.ds, ): body_scripts.append(jinja2.Markup(script)) - return response.html( + return Response.html( template.render( { **context, @@ -136,7 +150,7 @@ def __init__(self, datasette): self.ds = datasette def options(self, request, *args, **kwargs): - r = response.text("ok") + r = Response.text("ok") if self.ds.cors: r.headers["Access-Control-Allow-Origin"] = "*" return r @@ -146,7 +160,7 @@ def redirect(self, request, path, forward_querystring=True, remove_args=None): path = "{}?{}".format(path, request.query_string) if remove_args: path = path_with_removed_args(request, remove_args, path=path) - r = response.redirect(path) + r = Response.redirect(path) r.headers["Link"] = "<{}>; rel=preload".format(path) if self.ds.cors: r.headers["Access-Control-Allow-Origin"] = "*" @@ -195,17 +209,17 @@ async def async_table_exists(t): kwargs["table"] = table if _format: kwargs["as_format"] = ".{}".format(_format) - elif "table" in kwargs: + elif kwargs.get("table"): kwargs["table"] = urllib.parse.unquote_plus(kwargs["table"]) should_redirect = "/{}-{}".format(name, expected) - if "table" in kwargs: + if kwargs.get("table"): should_redirect += "/" + urllib.parse.quote_plus(kwargs["table"]) - if "pk_path" in kwargs: + if kwargs.get("pk_path"): should_redirect += "/" + kwargs["pk_path"] - if "as_format" in kwargs: + if kwargs.get("as_format"): should_redirect += kwargs["as_format"] - if "as_db" in kwargs: + if kwargs.get("as_db"): should_redirect += kwargs["as_db"] if ( @@ -246,7 +260,7 @@ async def as_csv(self, request, database, hash, **kwargs): response_or_template_contexts = await self.data( request, database, hash, **kwargs ) - if isinstance(response_or_template_contexts, response.HTTPResponse): + if isinstance(response_or_template_contexts, Response): return response_or_template_contexts else: data, _, _ = response_or_template_contexts @@ -282,13 +296,13 @@ async def stream_fn(r): if not first: data, _, _ = await self.data(request, database, hash, **kwargs) if first: - writer.writerow(headings) + await writer.writerow(headings) first = False next = data.get("next") for row in data["rows"]: if not expanded_columns: # Simple path - writer.writerow(row) + await writer.writerow(row) else: # Look for {"value": "label": } dicts and expand new_row = [] @@ -298,10 +312,10 @@ async def stream_fn(r): new_row.append(cell["label"]) else: new_row.append(cell) - writer.writerow(new_row) + await writer.writerow(new_row) except Exception as e: print("caught this", e) - r.write(str(e)) + await r.write(str(e)) return content_type = "text/plain; charset=utf-8" @@ -315,7 +329,7 @@ async def stream_fn(r): ) headers["Content-Disposition"] = disposition - return 
response.stream(stream_fn, headers=headers, content_type=content_type) + return AsgiStream(stream_fn, headers=headers, content_type=content_type) async def get_format(self, request, database, args): """ Determine the format of the response from the request, from URL @@ -363,7 +377,7 @@ async def view_get(self, request, database, hash, correct_hash_provided, **kwarg response_or_template_contexts = await self.data( request, database, hash, **kwargs ) - if isinstance(response_or_template_contexts, response.HTTPResponse): + if isinstance(response_or_template_contexts, Response): return response_or_template_contexts else: @@ -414,17 +428,11 @@ async def view_get(self, request, database, hash, correct_hash_provided, **kwarg if result is None: raise NotFound("No data") - response_args = { - "content_type": result.get("content_type", "text/plain"), - "status": result.get("status_code", 200), - } - - if type(result.get("body")) == bytes: - response_args["body_bytes"] = result.get("body") - else: - response_args["body"] = result.get("body") - - r = response.HTTPResponse(**response_args) + r = Response( + body=result.get("body"), + status=result.get("status_code", 200), + content_type=result.get("content_type", "text/plain"), + ) else: extras = {} if callable(extra_template_data): diff --git a/datasette/views/database.py b/datasette/views/database.py index a5b606f1c9..78af19c5c3 100644 --- a/datasette/views/database.py +++ b/datasette/views/database.py @@ -1,10 +1,9 @@ import os -from sanic import response - from datasette.utils import to_css_class, validate_sql_select +from datasette.utils.asgi import AsgiFileDownload -from .base import DataView, DatasetteError +from .base import DatasetteError, DataView class DatabaseView(DataView): @@ -79,8 +78,8 @@ async def view_get(self, request, database, hash, correct_hash_present, **kwargs if not db.path: raise DatasetteError("Cannot download database", status=404) filepath = db.path - return await response.file_stream( + return AsgiFileDownload( filepath, filename=os.path.basename(filepath), - mime_type="application/octet-stream", + content_type="application/octet-stream", ) diff --git a/datasette/views/index.py b/datasette/views/index.py index c9d15c3629..2c1c017a0a 100644 --- a/datasette/views/index.py +++ b/datasette/views/index.py @@ -1,9 +1,8 @@ import hashlib import json -from sanic import response - from datasette.utils import CustomJSONEncoder +from datasette.utils.asgi import Response from datasette.version import __version__ from .base import BaseView @@ -104,9 +103,9 @@ async def get(self, request, as_format): headers = {} if self.ds.cors: headers["Access-Control-Allow-Origin"] = "*" - return response.HTTPResponse( + return Response( json.dumps({db["name"]: db for db in databases}, cls=CustomJSONEncoder), - content_type="application/json", + content_type="application/json; charset=utf-8", headers=headers, ) else: diff --git a/datasette/views/special.py b/datasette/views/special.py index 91b577fc95..c4976bb225 100644 --- a/datasette/views/special.py +++ b/datasette/views/special.py @@ -1,5 +1,5 @@ import json -from sanic import response +from datasette.utils.asgi import Response from .base import BaseView @@ -17,8 +17,10 @@ async def get(self, request, as_format): headers = {} if self.ds.cors: headers["Access-Control-Allow-Origin"] = "*" - return response.HTTPResponse( - json.dumps(data), content_type="application/json", headers=headers + return Response( + json.dumps(data), + content_type="application/json; charset=utf-8", + headers=headers, ) 
else: diff --git a/datasette/views/table.py b/datasette/views/table.py index 14b8743ab9..06be5671de 100644 --- a/datasette/views/table.py +++ b/datasette/views/table.py @@ -3,13 +3,12 @@ import json import jinja2 -from sanic.exceptions import NotFound -from sanic.request import RequestParameters from datasette.plugins import pm from datasette.utils import ( CustomRow, QueryInterrupted, + RequestParameters, append_querystring, compound_keys_after_sql, escape_sqlite, @@ -24,6 +23,7 @@ urlsafe_components, value_as_boolean, ) +from datasette.utils.asgi import NotFound from datasette.filters import Filters from .base import DataView, DatasetteError, ureg @@ -219,8 +219,7 @@ async def data( if is_view: order_by = "" - # We roll our own query_string decoder because by default Sanic - # drops anything with an empty value e.g. ?name__exact= + # Ensure we don't drop anything with an empty value e.g. ?name__exact= args = RequestParameters( urllib.parse.parse_qs(request.query_string, keep_blank_values=True) ) diff --git a/pytest.ini b/pytest.ini index f2c8a6d297..aa292efc8d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,7 +4,5 @@ filterwarnings= ignore:Using or importing the ABCs::jinja2 # https://bugs.launchpad.net/beautifulsoup/+bug/1778909 ignore:Using or importing the ABCs::bs4.element - # Sanic verify_ssl=True - ignore:verify_ssl is deprecated::sanic # Python 3.7 PendingDeprecationWarning: Task.current_task() ignore:.*current_task.*:PendingDeprecationWarning diff --git a/setup.py b/setup.py index 60c1bcc51b..f66d03da6e 100644 --- a/setup.py +++ b/setup.py @@ -37,17 +37,18 @@ def get_version(): author="Simon Willison", license="Apache License, Version 2.0", url="https://github.com/simonw/datasette", - packages=find_packages(exclude='tests'), + packages=find_packages(exclude="tests"), package_data={"datasette": ["templates/*.html"]}, include_package_data=True, install_requires=[ "click>=6.7", "click-default-group==1.2", - "Sanic==0.7.0", "Jinja2==2.10.1", "hupper==1.0", "pint==0.8.1", "pluggy>=0.12.0", + "uvicorn>=0.8.1", + "aiofiles==0.4.0", ], entry_points=""" [console_scripts] @@ -60,6 +61,7 @@ def get_version(): "pytest-asyncio==0.10.0", "aiohttp==3.5.3", "beautifulsoup4==4.6.1", + "asgiref==3.1.2", ] + maybe_black }, diff --git a/tests/fixtures.py b/tests/fixtures.py index 04ac3c6826..00140f50e5 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,5 +1,7 @@ from datasette.app import Datasette from datasette.utils import sqlite3 +from asgiref.testing import ApplicationCommunicator +from asgiref.sync import async_to_sync import itertools import json import os @@ -10,16 +12,82 @@ import string import tempfile import time +from urllib.parse import unquote -class TestClient: - def __init__(self, sanic_test_client): - self.sanic_test_client = sanic_test_client +class TestResponse: + def __init__(self, status, headers, body): + self.status = status + self.headers = headers + self.body = body + + @property + def json(self): + return json.loads(self.text) + + @property + def text(self): + return self.body.decode("utf8") - def get(self, path, allow_redirects=True): - return self.sanic_test_client.get( - path, allow_redirects=allow_redirects, gather_request=False + +class TestClient: + max_redirects = 5 + + def __init__(self, asgi_app): + self.asgi_app = asgi_app + + @async_to_sync + async def get(self, path, allow_redirects=True, redirect_count=0, method="GET"): + return await self._get(path, allow_redirects, redirect_count, method) + + async def _get(self, path, allow_redirects=True, 
redirect_count=0, method="GET"): + query_string = b"" + if "?" in path: + path, _, query_string = path.partition("?") + query_string = query_string.encode("utf8") + instance = ApplicationCommunicator( + self.asgi_app, + { + "type": "http", + "http_version": "1.0", + "method": method, + "path": unquote(path), + "raw_path": path.encode("ascii"), + "query_string": query_string, + "headers": [[b"host", b"localhost"]], + }, + ) + await instance.send_input({"type": "http.request"}) + # First message back should be response.start with headers and status + messages = [] + start = await instance.receive_output(2) + messages.append(start) + assert start["type"] == "http.response.start" + headers = dict( + [(k.decode("utf8"), v.decode("utf8")) for k, v in start["headers"]] ) + status = start["status"] + # Now loop until we run out of response.body + body = b"" + while True: + message = await instance.receive_output(2) + messages.append(message) + assert message["type"] == "http.response.body" + body += message["body"] + if not message.get("more_body"): + break + response = TestResponse(status, headers, body) + if allow_redirects and response.status in (301, 302): + assert ( + redirect_count < self.max_redirects + ), "Redirected {} times, max_redirects={}".format( + redirect_count, self.max_redirects + ) + location = response.headers["Location"] + return await self._get( + location, allow_redirects=True, redirect_count=redirect_count + 1 + ) + return response def make_app_client( @@ -32,6 +100,7 @@ def make_app_client( is_immutable=False, extra_databases=None, inspect_data=None, + static_mounts=None, ): with tempfile.TemporaryDirectory() as tmpdir: filepath = os.path.join(tmpdir, filename) @@ -73,9 +142,10 @@ def make_app_client( plugins_dir=plugins_dir, config=config, inspect_data=inspect_data, + static_mounts=static_mounts, ) ds.sqlite_functions.append(("sleep", 1, lambda n: time.sleep(float(n)))) - client = TestClient(ds.app().test_client) + client = TestClient(ds.app()) client.ds = ds yield client @@ -88,7 +158,7 @@ def app_client(): @pytest.fixture(scope="session") def app_client_no_files(): ds = Datasette([]) - client = TestClient(ds.app().test_client) + client = TestClient(ds.app()) client.ds = ds yield client diff --git a/tests/test_api.py b/tests/test_api.py index 5c1bff15da..a32ed5e3b6 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -22,6 +22,7 @@ def test_homepage(app_client): response = app_client.get("/.json") assert response.status == 200 + assert "application/json; charset=utf-8" == response.headers["content-type"] assert response.json.keys() == {"fixtures": 0}.keys() d = response.json["fixtures"] assert d["name"] == "fixtures" @@ -771,8 +772,8 @@ def test_paginate_tables_and_views(app_client, path, expected_rows, expected_pag fetched.extend(response.json["rows"]) path = response.json["next_url"] if path: - assert response.json["next"] assert urllib.parse.urlencode({"_next": response.json["next"]}) in path + path = path.replace("http://localhost", "") assert count < 30, "Possible infinite loop detected" assert expected_rows == len(fetched) @@ -812,6 +813,8 @@ def test_paginate_compound_keys(app_client): response = app_client.get(path) fetched.extend(response.json["rows"]) path = response.json["next_url"] + if path: + path = path.replace("http://localhost", "") assert page < 100 assert 1001 == len(fetched) assert 21 == page @@ -833,6 +836,8 @@ def test_paginate_compound_keys_with_extra_filters(app_client): response = app_client.get(path) fetched.extend(response.json["rows"]) 
path = response.json["next_url"] + if path: + path = path.replace("http://localhost", "") assert 2 == page expected = [r[3] for r in generate_compound_rows(1001) if "d" in r[3]] assert expected == [f["content"] for f in fetched] @@ -881,6 +886,8 @@ def test_sortable(app_client, query_string, sort_key, human_description_en): assert human_description_en == response.json["human_description_en"] fetched.extend(response.json["rows"]) path = response.json["next_url"] + if path: + path = path.replace("http://localhost", "") assert 5 == page expected = list(generate_sortable_rows(201)) expected.sort(key=sort_key) @@ -1191,6 +1198,7 @@ def test_plugins_json(app_client): def test_versions_json(app_client): response = app_client.get("/-/versions.json") assert "python" in response.json + assert "3.0" == response.json.get("asgi") assert "version" in response.json["python"] assert "full" in response.json["python"] assert "datasette" in response.json @@ -1236,6 +1244,8 @@ def test_page_size_matching_max_returned_rows( fetched.extend(response.json["rows"]) assert len(response.json["rows"]) in (1, 50) path = response.json["next_url"] + if path: + path = path.replace("http://localhost", "") assert 201 == len(fetched) diff --git a/tests/test_csv.py b/tests/test_csv.py index cf0e6732b0..c3cdc24156 100644 --- a/tests/test_csv.py +++ b/tests/test_csv.py @@ -46,7 +46,7 @@ def test_table_csv(app_client): response = app_client.get("/fixtures/simple_primary_key.csv") assert response.status == 200 assert not response.headers.get("Access-Control-Allow-Origin") - assert "text/plain; charset=utf-8" == response.headers["Content-Type"] + assert "text/plain; charset=utf-8" == response.headers["content-type"] assert EXPECTED_TABLE_CSV == response.text @@ -59,7 +59,7 @@ def test_table_csv_cors_headers(app_client_with_cors): def test_table_csv_with_labels(app_client): response = app_client.get("/fixtures/facetable.csv?_labels=1") assert response.status == 200 - assert "text/plain; charset=utf-8" == response.headers["Content-Type"] + assert "text/plain; charset=utf-8" == response.headers["content-type"] assert EXPECTED_TABLE_WITH_LABELS_CSV == response.text @@ -68,14 +68,14 @@ def test_custom_sql_csv(app_client): "/fixtures.csv?sql=select+content+from+simple_primary_key+limit+2" ) assert response.status == 200 - assert "text/plain; charset=utf-8" == response.headers["Content-Type"] + assert "text/plain; charset=utf-8" == response.headers["content-type"] assert EXPECTED_CUSTOM_CSV == response.text def test_table_csv_download(app_client): response = app_client.get("/fixtures/simple_primary_key.csv?_dl=1") assert response.status == 200 - assert "text/csv; charset=utf-8" == response.headers["Content-Type"] + assert "text/csv; charset=utf-8" == response.headers["content-type"] expected_disposition = 'attachment; filename="simple_primary_key.csv"' assert expected_disposition == response.headers["Content-Disposition"] diff --git a/tests/test_html.py b/tests/test_html.py index 6b673c1366..32fa2fe3dd 100644 --- a/tests/test_html.py +++ b/tests/test_html.py @@ -8,6 +8,7 @@ METADATA, ) import json +import pathlib import pytest import re import urllib.parse @@ -16,6 +17,7 @@ def test_homepage(app_client_two_attached_databases): response = app_client_two_attached_databases.get("/") assert response.status == 200 + assert "text/html; charset=utf-8" == response.headers["content-type"] soup = Soup(response.body, "html.parser") assert "Datasette Fixtures" == soup.find("h1").text assert ( @@ -44,6 +46,29 @@ def 
test_homepage(app_client_two_attached_databases): ] == table_links +def test_http_head(app_client): + response = app_client.get("/", method="HEAD") + assert response.status == 200 + + +def test_static(app_client): + response = app_client.get("/-/static/app2.css") + assert response.status == 404 + response = app_client.get("/-/static/app.css") + assert response.status == 200 + assert "text/css" == response.headers["content-type"] + + +def test_static_mounts(): + for client in make_app_client( + static_mounts=[("custom-static", str(pathlib.Path(__file__).parent))] + ): + response = client.get("/custom-static/test_html.py") + assert response.status == 200 + response = client.get("/custom-static/not_exists.py") + assert response.status == 404 + + def test_memory_database_page(): for client in make_app_client(memory=True): response = client.get("/:memory:") diff --git a/tests/test_utils.py b/tests/test_utils.py index a5f603e646..e9e722b817 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,11 +3,11 @@ """ from datasette import utils +from datasette.utils.asgi import Request from datasette.filters import Filters import json import os import pytest -from sanic.request import Request import sqlite3 import tempfile from unittest.mock import patch @@ -53,7 +53,7 @@ def test_urlsafe_components(path, expected): ], ) def test_path_with_added_args(path, added_args, expected): - request = Request(path.encode("utf8"), {}, "1.1", "GET", None) + request = Request.fake(path) actual = utils.path_with_added_args(request, added_args) assert expected == actual @@ -67,11 +67,11 @@ def test_path_with_added_args(path, added_args, expected): ], ) def test_path_with_removed_args(path, args, expected): - request = Request(path.encode("utf8"), {}, "1.1", "GET", None) + request = Request.fake(path) actual = utils.path_with_removed_args(request, args) assert expected == actual # Run the test again but this time use the path= argument - request = Request("/".encode("utf8"), {}, "1.1", "GET", None) + request = Request.fake("/") actual = utils.path_with_removed_args(request, args, path=path) assert expected == actual @@ -84,7 +84,7 @@ def test_path_with_removed_args(path, args, expected): ], ) def test_path_with_replaced_args(path, args, expected): - request = Request(path.encode("utf8"), {}, "1.1", "GET", None) + request = Request.fake(path) actual = utils.path_with_replaced_args(request, args) assert expected == actual @@ -363,7 +363,7 @@ def test_table_columns(): ], ) def test_path_with_format(path, format, extra_qs, expected): - request = Request(path.encode("utf8"), {}, "1.1", "GET", None) + request = Request.fake(path) actual = utils.path_with_format(request, format, extra_qs) assert expected == actual
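For readers new to ASGI, the minimal sketch below (illustrative only, not code from this patch) shows the interface the port targets: an application is an async callable of (scope, receive, send) that emits an "http.response.start" message followed by "http.response.body" messages, and Uvicorn serves such a callable directly, just as datasette/cli.py now does with uvicorn.run(ds.app(), ...). The module-level app name, host and port below are arbitrary example values.

import uvicorn


async def app(scope, receive, send):
    # Handle plain HTTP requests only; Datasette's AsgiLifespan wrapper also
    # answers "lifespan" scopes so the server can run startup hooks.
    assert scope["type"] == "http"
    await send(
        {
            "type": "http.response.start",
            "status": 200,
            "headers": [[b"content-type", b"text/plain; charset=utf-8"]],
        }
    )
    await send({"type": "http.response.body", "body": b"Hello from ASGI"})


if __name__ == "__main__":
    # Same call shape as the new `datasette serve` implementation
    uvicorn.run(app, host="127.0.0.1", port=8001, log_level="info")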