diff --git a/datasette/app.py b/datasette/app.py index afe20258ca..eb5bd54101 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -3,7 +3,7 @@ from sanic.exceptions import NotFound from sanic.views import HTTPMethodView from sanic.request import RequestParameters -from jinja2 import Environment, FileSystemLoader +from jinja2 import Environment, FileSystemLoader, ChoiceLoader, PrefixLoader import re import sqlite3 from pathlib import Path @@ -43,15 +43,13 @@ class RenderMixin(HTTPMethodView): - def render(self, template, **context): + def render(self, templates, **context): return response.html( - self.jinja_env.get_template(template).render(**context) + self.jinja_env.select_template(templates).render(**context) ) class BaseView(RenderMixin): - template = None - def __init__(self, datasette): self.ds = datasette self.files = datasette.files @@ -166,6 +164,9 @@ def sql_operation_in_thread(): self.executor, sql_operation_in_thread ) + def get_templates(self, database, table=None): + assert NotImplemented + async def get(self, request, db_name, **kwargs): name, hash, should_redirect = self.resolve_db_name(db_name, **kwargs) if should_redirect: @@ -179,8 +180,8 @@ async def view_get(self, request, name, hash, **kwargs): as_json = False extra_template_data = {} start = time.time() - template = self.template status_code = 200 + templates = [] try: response_or_template_contexts = await self.data( request, name, hash, **kwargs @@ -188,7 +189,7 @@ async def view_get(self, request, name, hash, **kwargs): if isinstance(response_or_template_contexts, response.HTTPResponse): return response_or_template_contexts else: - data, extra_template_data = response_or_template_contexts + data, extra_template_data, templates = response_or_template_contexts except (sqlite3.OperationalError, InvalidSql) as e: data = { 'ok': False, @@ -196,8 +197,8 @@ async def view_get(self, request, name, hash, **kwargs): 'database': name, 'database_hash': hash, } - template = 'error.html' status_code = 400 + templates = ['error.html'] end = time.time() data['query_ms'] = (end - start) * 1000 for key in ('source', 'source_url', 'license', 'license_url'): @@ -246,7 +247,7 @@ async def view_get(self, request, name, hash, **kwargs): } } r = self.render( - template, + templates, **context, ) r.status = status_code @@ -300,7 +301,7 @@ async def get(self, request, as_json): ) else: return self.render( - 'index.html', + ['index.html'], databases=databases, metadata=self.ds.metadata, datasette_version=__version__, @@ -314,7 +315,6 @@ async def favicon(request): class DatabaseView(BaseView): - template = 'database.html' re_named_parameter = re.compile(':([a-zA-Z0-9_]+)') async def data(self, request, name, hash): @@ -331,7 +331,7 @@ async def data(self, request, name, hash): }, { 'database_hash': hash, 'show_hidden': request.args.get('_show_hidden'), - } + }, ('database-{}.html'.format(to_css_class(name)), 'database.html') async def custom_sql(self, request, name, hash): params = request.raw_args @@ -370,7 +370,7 @@ async def custom_sql(self, request, name, hash): 'database_hash': hash, 'custom_sql': True, 'named_parameter_values': named_parameter_values, - } + }, ('database-{}.html'.format(to_css_class(name)), 'database.html') class DatabaseDownload(BaseView): @@ -464,8 +464,6 @@ async def make_display_rows(self, database, database_hash, table, rows, display_ class TableView(RowTableShared): - template = 'table.html' - async def data(self, request, name, hash, table): table = urllib.parse.unquote_plus(table) pks = await self.pks_for_table(name, table) @@ -681,12 +679,13 @@ async def extra_template(): }, 'next': next_value and str(next_value) or None, 'next_url': next_url, - }, extra_template + }, extra_template, ( + 'table-{}-{}.html'.format(to_css_class(name), to_css_class(table)), + 'table.html' + ) class RowView(RowTableShared): - template = 'row.html' - async def data(self, request, name, hash, table, pk_path): table = urllib.parse.unquote_plus(table) pk_values = compound_pks_from_path(pk_path) @@ -733,7 +732,10 @@ async def template_data(): if 'foreign_key_tables' in (request.raw_args.get('_extras') or '').split(','): data['foreign_key_tables'] = await self.foreign_key_tables(name, table, pk_values) - return data, template_data + return data, template_data, ( + 'row-{}-{}.html'.format(to_css_class(name), to_css_class(table)), + 'row.html' + ) async def foreign_key_tables(self, name, table, pk_values): if len(pk_values) != 1: @@ -893,12 +895,19 @@ def inspect(self): def app(self): app = Sanic(__name__) - template_paths = [] + default_templates = str(app_root / 'datasette' / 'templates') if self.template_dir: - template_paths.append(self.template_dir) - template_paths.append(str(app_root / 'datasette' / 'templates')) + template_loader = ChoiceLoader([ + FileSystemLoader([self.template_dir, default_templates]), + # Support {% extends "default:table.html" %}: + PrefixLoader({ + 'default': FileSystemLoader(default_templates), + }, delimiter=':') + ]) + else: + template_loader = FileSystemLoader(default_templates) self.jinja_env = Environment( - loader=FileSystemLoader(template_paths), + loader=template_loader, autoescape=True, ) self.jinja_env.filters['escape_css_string'] = escape_css_string