Skip to content

Commit

Permalink
await_me_maybe utility function
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Sep 2, 2020
1 parent f65c456 commit 26b2922
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 34 deletions.
41 changes: 9 additions & 32 deletions datasette/app.py
Expand Up @@ -45,6 +45,7 @@

from .utils import (
async_call_with_supported_arguments,
await_me_maybe,
call_with_supported_arguments,
display_actor,
escape_css_string,
Expand Down Expand Up @@ -312,10 +313,7 @@ def __init__(

async def invoke_startup(self):
for hook in pm.hook.startup(datasette=self):
if callable(hook):
hook = hook()
if asyncio.iscoroutine(hook):
hook = await hook
await await_me_maybe(hook)

def sign(self, value, namespace="default"):
return URLSafeSerializer(self._secret, namespace).dumps(value)
Expand Down Expand Up @@ -400,10 +398,7 @@ async def get_canned_queries(self, database_name, actor):
for more_queries in pm.hook.canned_queries(
datasette=self, database=database_name, actor=actor,
):
if callable(more_queries):
more_queries = more_queries()
if asyncio.iscoroutine(more_queries):
more_queries = await more_queries
more_queries = await await_me_maybe(more_queries)
queries.update(more_queries or {})
# Fix any {"name": "select ..."} queries to be {"name": {"sql": "select ..."}}
for key in queries:
Expand Down Expand Up @@ -475,10 +470,7 @@ async def permission_allowed(self, actor, action, resource=None, default=False):
for check in pm.hook.permission_allowed(
datasette=self, actor=actor, action=action, resource=resource,
):
if callable(check):
check = check()
if asyncio.iscoroutine(check):
check = await check
check = await await_me_maybe(check)
if check is not None:
result = check
used_default = False
Expand Down Expand Up @@ -718,10 +710,7 @@ async def render_template(
request=request,
datasette=self,
):
if callable(extra_script):
extra_script = extra_script()
if asyncio.iscoroutine(extra_script):
extra_script = await extra_script
extra_script = await await_me_maybe(extra_script)
body_scripts.append(Markup(extra_script))

extra_template_vars = {}
Expand All @@ -735,10 +724,7 @@ async def render_template(
request=request,
datasette=self,
):
if callable(extra_vars):
extra_vars = extra_vars()
if asyncio.iscoroutine(extra_vars):
extra_vars = await extra_vars
extra_vars = await await_me_maybe(extra_vars)
assert isinstance(extra_vars, dict), "extra_vars is of type {}".format(
type(extra_vars)
)
Expand Down Expand Up @@ -786,10 +772,7 @@ async def _asset_urls(self, key, template, context, request, view_name):
request=request,
datasette=self,
):
if callable(hook):
hook = hook()
if asyncio.iscoroutine(hook):
hook = await hook
hook = await await_me_maybe(hook)
collected.extend(hook)
collected.extend(self.metadata(key) or [])
output = []
Expand Down Expand Up @@ -981,10 +964,7 @@ async def route_path(self, scope, receive, send, path):
default_actor = scope.get("actor") or None
actor = None
for actor in pm.hook.actor_from_request(datasette=self.ds, request=request):
if callable(actor):
actor = actor()
if asyncio.iscoroutine(actor):
actor = await actor
actor = await await_me_maybe(actor)
if actor:
break
scope_modifications["actor"] = actor or default_actor
Expand Down Expand Up @@ -1079,10 +1059,7 @@ async def handle_500(self, request, send, exception):
for custom_response in pm.hook.forbidden(
datasette=self.ds, request=request, message=message
):
if callable(custom_response):
custom_response = custom_response()
if asyncio.iscoroutine(custom_response):
custom_response = await custom_response
custom_response = await await_me_maybe(custom_response)
if custom_response is not None:
await custom_response.asgi_send(send)
return
Expand Down
9 changes: 9 additions & 0 deletions datasette/utils/__init__.py
@@ -1,3 +1,4 @@
import asyncio
from contextlib import contextmanager
from collections import OrderedDict
import base64
Expand Down Expand Up @@ -51,6 +52,14 @@
"""


async def await_me_maybe(value):
if callable(value):
value = value()
if asyncio.iscoroutine(value):
value = await value
return value


def urlsafe_components(token):
"Splits token on commas and URL decodes each component"
return [urllib.parse.unquote_plus(b) for b in token.split(",")]
Expand Down
4 changes: 2 additions & 2 deletions datasette/views/base.py
Expand Up @@ -12,6 +12,7 @@
from datasette.plugins import pm
from datasette.database import QueryInterrupted
from datasette.utils import (
await_me_maybe,
InvalidSql,
LimitedWriter,
call_with_supported_arguments,
Expand Down Expand Up @@ -492,8 +493,7 @@ async def view_get(self, request, database, hash, correct_hash_provided, **kwarg
request=request,
view_name=self.name,
)
if asyncio.iscoroutine(it_can_render):
it_can_render = await it_can_render
it_can_render = await await_me_maybe(it_can_render)
if it_can_render:
renderers[key] = path_with_format(
request, key, {**url_labels_extra}
Expand Down

0 comments on commit 26b2922

Please sign in to comment.