Skip to content

Commit

Permalink
Refactored run_sanity_checks to check_connection(conn), refs #674
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Feb 15, 2020
1 parent f1442a8 commit d3f2fad
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 21 deletions.
20 changes: 0 additions & 20 deletions datasette/app.py
Expand Up @@ -216,26 +216,6 @@ def add_database(self, name, db):
def remove_database(self, name):
self.databases.pop(name)

async def run_sanity_checks(self):
# Only one check right now, for Spatialite
for database_name, database in self.databases.items():
# Run pragma_info on every table
for table in await database.table_names():
try:
await self.execute(
database_name,
"PRAGMA table_info({});".format(escape_sqlite(table)),
)
except sqlite3.OperationalError as e:
if e.args[0] == "no such module: VirtualSpatialIndex":
raise click.UsageError(
"It looks like you're trying to load a SpatiaLite"
" database without first loading the SpatiaLite module."
"\n\nRead more: https://datasette.readthedocs.io/en/latest/spatialite.html"
)
else:
raise

def config(self, key):
return self._config.get(key, None)

Expand Down
23 changes: 22 additions & 1 deletion datasette/cli.py
Expand Up @@ -10,6 +10,9 @@
import sys
from .app import Datasette, DEFAULT_CONFIG, CONFIG_OPTIONS, pm
from .utils import (
check_connection,
ConnectionProblem,
SpatialiteConnectionProblem,
temporary_docker_directory,
value_as_boolean,
StaticMount,
Expand Down Expand Up @@ -369,7 +372,25 @@ def serve(
version_note=version_note,
)
# Run async sanity checks - but only if we're not under pytest
asyncio.get_event_loop().run_until_complete(ds.run_sanity_checks())
asyncio.get_event_loop().run_until_complete(check_databases(ds))

# Start the server
uvicorn.run(ds.app(), host=host, port=port, log_level="info")


async def check_databases(ds):
# Run check_connection against every connected database
# to confirm they are all usable
for database in list(ds.databases.values()):
try:
await database.execute_against_connection_in_thread(check_connection)
except SpatialiteConnectionProblem:
raise click.UsageError(
"It looks like you're trying to load a SpatiaLite"
" database without first loading the SpatiaLite module."
"\n\nRead more: https://datasette.readthedocs.io/en/latest/spatialite.html"
)
except ConnectionProblem as e:
raise click.UsageError(
"Connection to {} failed check: {}".format(database.path, str(e.args[0]))
)
25 changes: 25 additions & 0 deletions datasette/utils/__init__.py
Expand Up @@ -790,3 +790,28 @@ def get(self, name, default=None):
def getlist(self, name, default=None):
"Return full list"
return super().get(name, default)


class ConnectionProblem(Exception):
pass


class SpatialiteConnectionProblem(ConnectionProblem):
pass


def check_connection(conn):
tables = [
r[0]
for r in conn.execute(
"select name from sqlite_master where type='table'"
).fetchall()
]
for table in tables:
try:
conn.execute("PRAGMA table_info({});".format(escape_sqlite(table)),)
except sqlite3.OperationalError as e:
if e.args[0] == "no such module: VirtualSpatialIndex":
raise SpatialiteConnectionProblem(e)
else:
raise ConnectionProblem(e)
13 changes: 13 additions & 0 deletions tests/test_utils.py
Expand Up @@ -7,6 +7,7 @@
from datasette.filters import Filters
import json
import os
import pathlib
import pytest
import sqlite3
import tempfile
Expand Down Expand Up @@ -410,3 +411,15 @@ def test_format_bytes(bytes, expected):
)
def test_escape_fts(query, expected):
assert expected == utils.escape_fts(query)


def test_check_connection_spatialite_raises():
path = str(pathlib.Path(__file__).parent / "spatialite.db")
conn = sqlite3.connect(path)
with pytest.raises(utils.SpatialiteConnectionProblem):
utils.check_connection(conn)


def test_check_connection_passes():
conn = sqlite3.connect(":memory:")
utils.check_connection(conn)

0 comments on commit d3f2fad

Please sign in to comment.