From d3f2fade88984dc6157b2ff69c24aa5a070f9716 Mon Sep 17 00:00:00 2001 From: Simon Willison Date: Sat, 15 Feb 2020 09:56:48 -0800 Subject: [PATCH] Refactored run_sanity_checks to check_connection(conn), refs #674 --- datasette/app.py | 20 -------------------- datasette/cli.py | 23 ++++++++++++++++++++++- datasette/utils/__init__.py | 25 +++++++++++++++++++++++++ tests/test_utils.py | 13 +++++++++++++ 4 files changed, 60 insertions(+), 21 deletions(-) diff --git a/datasette/app.py b/datasette/app.py index 053646e61a..b643c102a0 100644 --- a/datasette/app.py +++ b/datasette/app.py @@ -216,26 +216,6 @@ def add_database(self, name, db): def remove_database(self, name): self.databases.pop(name) - async def run_sanity_checks(self): - # Only one check right now, for Spatialite - for database_name, database in self.databases.items(): - # Run pragma_info on every table - for table in await database.table_names(): - try: - await self.execute( - database_name, - "PRAGMA table_info({});".format(escape_sqlite(table)), - ) - except sqlite3.OperationalError as e: - if e.args[0] == "no such module: VirtualSpatialIndex": - raise click.UsageError( - "It looks like you're trying to load a SpatiaLite" - " database without first loading the SpatiaLite module." - "\n\nRead more: https://datasette.readthedocs.io/en/latest/spatialite.html" - ) - else: - raise - def config(self, key): return self._config.get(key, None) diff --git a/datasette/cli.py b/datasette/cli.py index 67d00fbb2a..8d724c42f6 100644 --- a/datasette/cli.py +++ b/datasette/cli.py @@ -10,6 +10,9 @@ import sys from .app import Datasette, DEFAULT_CONFIG, CONFIG_OPTIONS, pm from .utils import ( + check_connection, + ConnectionProblem, + SpatialiteConnectionProblem, temporary_docker_directory, value_as_boolean, StaticMount, @@ -369,7 +372,25 @@ def serve( version_note=version_note, ) # Run async sanity checks - but only if we're not under pytest - asyncio.get_event_loop().run_until_complete(ds.run_sanity_checks()) + asyncio.get_event_loop().run_until_complete(check_databases(ds)) # Start the server uvicorn.run(ds.app(), host=host, port=port, log_level="info") + + +async def check_databases(ds): + # Run check_connection against every connected database + # to confirm they are all usable + for database in list(ds.databases.values()): + try: + await database.execute_against_connection_in_thread(check_connection) + except SpatialiteConnectionProblem: + raise click.UsageError( + "It looks like you're trying to load a SpatiaLite" + " database without first loading the SpatiaLite module." + "\n\nRead more: https://datasette.readthedocs.io/en/latest/spatialite.html" + ) + except ConnectionProblem as e: + raise click.UsageError( + "Connection to {} failed check: {}".format(database.path, str(e.args[0])) + ) diff --git a/datasette/utils/__init__.py b/datasette/utils/__init__.py index facbc4de49..79ac8e0218 100644 --- a/datasette/utils/__init__.py +++ b/datasette/utils/__init__.py @@ -790,3 +790,28 @@ def get(self, name, default=None): def getlist(self, name, default=None): "Return full list" return super().get(name, default) + + +class ConnectionProblem(Exception): + pass + + +class SpatialiteConnectionProblem(ConnectionProblem): + pass + + +def check_connection(conn): + tables = [ + r[0] + for r in conn.execute( + "select name from sqlite_master where type='table'" + ).fetchall() + ] + for table in tables: + try: + conn.execute("PRAGMA table_info({});".format(escape_sqlite(table)),) + except sqlite3.OperationalError as e: + if e.args[0] == "no such module: VirtualSpatialIndex": + raise SpatialiteConnectionProblem(e) + else: + raise ConnectionProblem(e) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8f00629190..df42ea5aea 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,6 +7,7 @@ from datasette.filters import Filters import json import os +import pathlib import pytest import sqlite3 import tempfile @@ -410,3 +411,15 @@ def test_format_bytes(bytes, expected): ) def test_escape_fts(query, expected): assert expected == utils.escape_fts(query) + + +def test_check_connection_spatialite_raises(): + path = str(pathlib.Path(__file__).parent / "spatialite.db") + conn = sqlite3.connect(path) + with pytest.raises(utils.SpatialiteConnectionProblem): + utils.check_connection(conn) + + +def test_check_connection_passes(): + conn = sqlite3.connect(":memory:") + utils.check_connection(conn)