From a9909c29ccac771c23c2ef22b89d10697b5256b9 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Fri, 15 Nov 2019 14:49:45 -0800 Subject: [PATCH] Move .execute() from Datasette to Database Refs #569 - I split this change out from #579 --- datasette/app.py | 90 ++++++--------------------- datasette/database.py | 137 +++++++++++++++++++++++++++++++----------- 2 files changed, 121 insertions(+), 106 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index 02fcf30303..119d0e1993 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -24,13 +24,11 @@ from .utils import ( QueryInterrupted, - Results, escape_css_string, escape_sqlite, get_plugins, module_from_path, sqlite3, - sqlite_timelimit, to_css_class, ) from .utils.asgi import ( @@ -42,13 +40,12 @@ asgi_send_json, asgi_send_redirect, ) -from .tracer import trace, AsgiTracer +from .tracer import AsgiTracer from .plugins import pm, DEFAULT_PLUGINS from .version import __version__ app_root = Path(__file__).parent.parent -connections = threading.local() MEMORY = object() ConfigOption = collections.namedtuple("ConfigOption", ("name", "default", "help")) @@ -336,6 +333,25 @@ def prepare_connection(self, conn): # pylint: disable=no-member pm.hook.prepare_connection(conn=conn) + async def execute( + self, + db_name, + sql, + params=None, + truncate=False, + custom_time_limit=None, + page_size=None, + log_sql_errors=True, + ): + return await self.databases[db_name].execute( + sql, + params=params, + truncate=truncate, + custom_time_limit=custom_time_limit, + page_size=page_size, + log_sql_errors=log_sql_errors, + ) + async def expand_foreign_keys(self, database, table, column, values): "Returns dict mapping (column, value) -> label" labeled_fks = {} @@ -477,72 +493,6 @@ def table_metadata(self, database, table): .get(table, {}) ) - async def execute_against_connection_in_thread(self, db_name, fn): - def in_thread(): - conn = getattr(connections, db_name, None) - if not conn: - conn = self.databases[db_name].connect() - self.prepare_connection(conn) - setattr(connections, db_name, conn) - return fn(conn) - - return await asyncio.get_event_loop().run_in_executor(self.executor, in_thread) - - async def execute( - self, - db_name, - sql, - params=None, - truncate=False, - custom_time_limit=None, - page_size=None, - log_sql_errors=True, - ): - """Executes sql against db_name in a thread""" - page_size = page_size or self.page_size - - def sql_operation_in_thread(conn): - time_limit_ms = self.sql_time_limit_ms - if custom_time_limit and custom_time_limit < time_limit_ms: - time_limit_ms = custom_time_limit - - with sqlite_timelimit(conn, time_limit_ms): - try: - cursor = conn.cursor() - cursor.execute(sql, params or {}) - max_returned_rows = self.max_returned_rows - if max_returned_rows == page_size: - max_returned_rows += 1 - if max_returned_rows and truncate: - rows = cursor.fetchmany(max_returned_rows + 1) - truncated = len(rows) > max_returned_rows - rows = rows[:max_returned_rows] - else: - rows = cursor.fetchall() - truncated = False - except sqlite3.OperationalError as e: - if e.args == ("interrupted",): - raise QueryInterrupted(e, sql, params) - if log_sql_errors: - print( - "ERROR: conn={}, sql = {}, params = {}: {}".format( - conn, repr(sql), params, e - ) - ) - raise - - if truncate: - return Results(rows, truncated, cursor.description) - - else: - return Results(rows, False, cursor.description) - - with trace("sql", database=db_name, sql=sql.strip(), params=params): - results = await self.execute_against_connection_in_thread( - db_name, sql_operation_in_thread - ) - return results - def register_renderers(self): """ Register output renderers which output data in custom formats. """ # Built-in renderers diff --git a/datasette/database.py b/datasette/database.py index 3a1cea9432..9a8ae4d434 100644 --- a/datasette/database.py +++ b/datasette/database.py @@ -1,17 +1,25 @@ +import asyncio +import contextlib from pathlib import Path +import threading +from .tracer import trace from .utils import ( QueryInterrupted, + Results, detect_fts, detect_primary_keys, detect_spatialite, get_all_foreign_keys, get_outbound_foreign_keys, + sqlite_timelimit, sqlite3, table_columns, ) from .inspect import inspect_hash +connections = threading.local() + class Database: def __init__(self, ds, path=None, is_mutable=False, is_memory=False): @@ -45,6 +53,73 @@ def connect(self): "file:{}?{}".format(self.path, qs), uri=True, check_same_thread=False ) + async def execute_against_connection_in_thread(self, fn): + def in_thread(): + conn = getattr(connections, self.name, None) + if not conn: + conn = self.connect() + self.ds.prepare_connection(conn) + setattr(connections, self.name, conn) + return fn(conn) + + return await asyncio.get_event_loop().run_in_executor( + self.ds.executor, in_thread + ) + + async def execute( + self, + sql, + params=None, + truncate=False, + custom_time_limit=None, + page_size=None, + log_sql_errors=True, + ): + """Executes sql against db_name in a thread""" + page_size = page_size or self.ds.page_size + + def sql_operation_in_thread(conn): + time_limit_ms = self.ds.sql_time_limit_ms + if custom_time_limit and custom_time_limit < time_limit_ms: + time_limit_ms = custom_time_limit + + with sqlite_timelimit(conn, time_limit_ms): + try: + cursor = conn.cursor() + cursor.execute(sql, params or {}) + max_returned_rows = self.ds.max_returned_rows + if max_returned_rows == page_size: + max_returned_rows += 1 + if max_returned_rows and truncate: + rows = cursor.fetchmany(max_returned_rows + 1) + truncated = len(rows) > max_returned_rows + rows = rows[:max_returned_rows] + else: + rows = cursor.fetchall() + truncated = False + except sqlite3.OperationalError as e: + if e.args == ("interrupted",): + raise QueryInterrupted(e, sql, params) + if log_sql_errors: + print( + "ERROR: conn={}, sql = {}, params = {}: {}".format( + conn, repr(sql), params, e + ) + ) + raise + + if truncate: + return Results(rows, truncated, cursor.description) + + else: + return Results(rows, False, cursor.description) + + with trace("sql", database=self.name, sql=sql.strip(), params=params): + results = await self.execute_against_connection_in_thread( + sql_operation_in_thread + ) + return results + @property def size(self): if self.is_memory: @@ -62,8 +137,7 @@ async def table_counts(self, limit=10): for table in await self.table_names(): try: table_count = ( - await self.ds.execute( - self.name, + await self.execute( "select count(*) from [{}]".format(table), custom_time_limit=limit, ) @@ -89,32 +163,30 @@ def name(self): return Path(self.path).stem async def table_exists(self, table): - results = await self.ds.execute( - self.name, - "select 1 from sqlite_master where type='table' and name=?", - params=(table,), + results = await self.execute( + "select 1 from sqlite_master where type='table' and name=?", params=(table,) ) return bool(results.rows) async def table_names(self): - results = await self.ds.execute( - self.name, "select name from sqlite_master where type='table'" + results = await self.execute( + "select name from sqlite_master where type='table'" ) return [r[0] for r in results.rows] async def table_columns(self, table): - return await self.ds.execute_against_connection_in_thread( - self.name, lambda conn: table_columns(conn, table) + return await self.execute_against_connection_in_thread( + lambda conn: table_columns(conn, table) ) async def primary_keys(self, table): - return await self.ds.execute_against_connection_in_thread( - self.name, lambda conn: detect_primary_keys(conn, table) + return await self.execute_against_connection_in_thread( + lambda conn: detect_primary_keys(conn, table) ) async def fts_table(self, table): - return await self.ds.execute_against_connection_in_thread( - self.name, lambda conn: detect_fts(conn, table) + return await self.execute_against_connection_in_thread( + lambda conn: detect_fts(conn, table) ) async def label_column_for_table(self, table): @@ -124,8 +196,8 @@ async def label_column_for_table(self, table): if explicit_label_column: return explicit_label_column # If a table has two columns, one of which is ID, then label_column is the other one - column_names = await self.ds.execute_against_connection_in_thread( - self.name, lambda conn: table_columns(conn, table) + column_names = await self.execute_against_connection_in_thread( + lambda conn: table_columns(conn, table) ) # Is there a name or title column? name_or_title = [c for c in column_names if c in ("name", "title")] @@ -141,8 +213,8 @@ async def label_column_for_table(self, table): return None async def foreign_keys_for_table(self, table): - return await self.ds.execute_against_connection_in_thread( - self.name, lambda conn: get_outbound_foreign_keys(conn, table) + return await self.execute_against_connection_in_thread( + lambda conn: get_outbound_foreign_keys(conn, table) ) async def hidden_table_names(self): @@ -150,18 +222,17 @@ async def hidden_table_names(self): hidden_tables = [ r[0] for r in ( - await self.ds.execute( - self.name, + await self.execute( """ select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%' - """, + """ ) ).rows ] - has_spatialite = await self.ds.execute_against_connection_in_thread( - self.name, detect_spatialite + has_spatialite = await self.execute_against_connection_in_thread( + detect_spatialite ) if has_spatialite: # Also hide Spatialite internal tables @@ -178,13 +249,12 @@ async def hidden_table_names(self): ] + [ r[0] for r in ( - await self.ds.execute( - self.name, + await self.execute( """ select name from sqlite_master where name like "idx_%" and type = "table" - """, + """ ) ).rows ] @@ -207,25 +277,20 @@ async def hidden_table_names(self): return hidden_tables async def view_names(self): - results = await self.ds.execute( - self.name, "select name from sqlite_master where type='view'" - ) + results = await self.execute("select name from sqlite_master where type='view'") return [r[0] for r in results.rows] async def get_all_foreign_keys(self): - return await self.ds.execute_against_connection_in_thread( - self.name, get_all_foreign_keys - ) + return await self.execute_against_connection_in_thread(get_all_foreign_keys) async def get_outbound_foreign_keys(self, table): - return await self.ds.execute_against_connection_in_thread( - self.name, lambda conn: get_outbound_foreign_keys(conn, table) + return await self.execute_against_connection_in_thread( + lambda conn: get_outbound_foreign_keys(conn, table) ) async def get_table_definition(self, table, type_="table"): table_definition_rows = list( - await self.ds.execute( - self.name, + await self.execute( "select sql from sqlite_master where name = :n and type=:t", {"n": table, "t": type_}, )