Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.

Already on GitHub? Sign in to your account

db.execute_isolated_fn() method #2220

Merged
merged 6 commits into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 50 additions & 11 deletions datasette/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,25 +159,47 @@ def count_params(params):
kwargs["count"] = count
return results

async def execute_isolated_fn(self, fn):
# Open a new connection just for the duration of this function
# blocking the write queue to avoid any writes occurring during it
if self.ds.executor is None:
# non-threaded mode
isolated_connection = self.connect(write=True)
try:
result = fn(isolated_connection)
finally:
isolated_connection.close()
try:
self._all_file_connections.remove(isolated_connection)
except ValueError:
# Was probably a memory connection
pass
return result
else:
# Threaded mode - send to write thread
return await self._send_to_write_thread(fn, isolated_connection=True)

async def execute_write_fn(self, fn, block=True):
if self.ds.executor is None:
# non-threaded mode
if self._write_connection is None:
self._write_connection = self.connect(write=True)
self.ds._prepare_connection(self._write_connection, self.name)
return fn(self._write_connection)
else:
return await self._send_to_write_thread(fn, block)

# threaded mode
task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io")
async def _send_to_write_thread(self, fn, block=True, isolated_connection=False):
if self._write_queue is None:
self._write_queue = queue.Queue()
if self._write_thread is None:
self._write_thread = threading.Thread(
target=self._execute_writes, daemon=True
)
self._write_thread.start()
task_id = uuid.uuid5(uuid.NAMESPACE_DNS, "datasette.io")
reply_queue = janus.Queue()
self._write_queue.put(WriteTask(fn, task_id, reply_queue))
self._write_queue.put(WriteTask(fn, task_id, reply_queue, isolated_connection))
if block:
result = await reply_queue.async_q.get()
if isinstance(result, Exception):
Expand All @@ -202,12 +224,28 @@ def _execute_writes(self):
if conn_exception is not None:
result = conn_exception
else:
try:
result = task.fn(conn)
except Exception as e:
sys.stderr.write("{}\n".format(e))
sys.stderr.flush()
result = e
if task.isolated_connection:
isolated_connection = self.connect(write=True)
try:
result = task.fn(isolated_connection)
except Exception as e:
sys.stderr.write("{}\n".format(e))
sys.stderr.flush()
result = e
finally:
isolated_connection.close()
try:
self._all_file_connections.remove(isolated_connection)
except ValueError:
# Was probably a memory connection
pass
else:
try:
result = task.fn(conn)
except Exception as e:
sys.stderr.write("{}\n".format(e))
sys.stderr.flush()
result = e
task.reply_queue.sync_q.put(result)

async def execute_fn(self, fn):
Expand Down Expand Up @@ -515,12 +553,13 @@ def __repr__(self):


class WriteTask:
    """A unit of work queued for the database write thread.

    Attributes:
        fn: callable that will be invoked with a SQLite connection.
        task_id: UUID identifying this task.
        reply_queue: janus queue the result (or exception) is sent back on.
        isolated_connection: if True, run ``fn`` against a fresh dedicated
            connection that is closed afterwards, rather than the shared
            write connection.
    """

    # __slots__ keeps these small, frequently-created objects compact
    __slots__ = ("fn", "task_id", "reply_queue", "isolated_connection")

    def __init__(self, fn, task_id, reply_queue, isolated_connection):
        self.fn = fn
        self.task_id = task_id
        self.reply_queue = reply_queue
        self.isolated_connection = isolated_connection


class QueryInterrupted(Exception):
Expand Down
19 changes: 18 additions & 1 deletion docs/internals.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,7 @@ Like ``execute_write()`` but uses the ``sqlite3`` `conn.executemany() <https://d
.. _database_execute_write_fn:

await db.execute_write_fn(fn, block=True)
------------------------------------------
-----------------------------------------

This method works like ``.execute_write()``, but instead of a SQL statement you give it a callable Python function. Your function will be queued up and then called when the write connection is available, passing that connection as the argument to the function.

Expand Down Expand Up @@ -1054,6 +1054,23 @@ If you see ``OperationalError: database table is locked`` errors you should chec

If you specify ``block=False`` the method becomes fire-and-forget, queueing your function to be executed and then allowing your code after the call to ``.execute_write_fn()`` to continue running while the underlying thread waits for an opportunity to run your function. A UUID representing the queued task will be returned. Any exceptions in your code will be silently swallowed.

.. _database_execute_isolated_fn:

await db.execute_isolated_fn(fn)
--------------------------------

This method is similar to :ref:`execute_write_fn() <database_execute_write_fn>` but executes the provided function in an entirely isolated SQLite connection, which is opened, used and then closed again in a single call to this method.

The :ref:`prepare_connection() <plugin_hook_prepare_connection>` plugin hook is not executed against this connection.

This allows plugins to execute database operations that might conflict with how database connections are usually configured. For example, running a ``VACUUM`` operation while bypassing any restrictions placed by the `datasette-sqlite-authorizer <https://github.com/datasette/datasette-sqlite-authorizer>`__ plugin.

Plugins can also use this method to load potentially dangerous SQLite extensions, use them to perform an operation and then have them safely unloaded at the end of the call, without risk of exposing them to other connections.

Functions run using ``execute_isolated_fn()`` share the same queue as ``execute_write_fn()``, which guarantees that no writes can be executed at the same time as the isolated function is executing.

The return value of the function will be returned by this method. Any exceptions raised by the function will be raised out of the ``await`` line as well.

.. _database_close:

db.close()
Expand Down
8 changes: 3 additions & 5 deletions docs/metadata_doc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import textwrap
from yaml import safe_dump
from ruamel.yaml import round_trip_load
from ruamel.yaml import YAML


def metadata_example(cog, data=None, yaml=None):
Expand All @@ -11,8 +11,7 @@ def metadata_example(cog, data=None, yaml=None):
if yaml:
# dedent it first
yaml = textwrap.dedent(yaml).strip()
# round_trip_load to preserve key order:
data = round_trip_load(yaml)
data = YAML().load(yaml)
output_yaml = yaml
else:
output_yaml = safe_dump(data, sort_keys=False)
Expand All @@ -27,8 +26,7 @@ def metadata_example(cog, data=None, yaml=None):

def config_example(cog, input):
if type(input) is str:
# round_trip_load to preserve key order:
data = round_trip_load(input)
data = YAML().load(input)
output_yaml = input
else:
data = input
Expand Down
65 changes: 65 additions & 0 deletions tests/test_internals_database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Tests for the datasette.database.Database class
"""
from datasette.app import Datasette
from datasette.database import Database, Results, MultipleValues
from datasette.utils.sqlite import sqlite3
from datasette.utils import Column
Expand Down Expand Up @@ -519,6 +520,70 @@ def write_fn(conn):
app_client.ds.remove_database("immutable-db")


def table_exists(conn, name):
    """Return True if a table called *name* exists on *conn*.

    Checks both the main schema and the temporary schema, so temporary
    tables created on this connection are also detected.
    """
    sql = """
        with all_tables as (
            select name from sqlite_master where type = 'table'
            union all
            select name from temp.sqlite_master where type = 'table'
        )
        select 1 from all_tables where name = ?
    """
    rows = conn.execute(sql, (name,)).fetchall()
    return len(rows) > 0


def table_exists_checker(name):
    """Return a callable suitable for ``db.execute_fn()`` that reports
    whether a table called *name* exists on the connection it receives."""

    def check(conn):
        return table_exists(conn, name)

    return check


@pytest.mark.asyncio
@pytest.mark.parametrize("disable_threads", (False, True))
async def test_execute_isolated(db, disable_threads):
    # Exercise execute_isolated_fn() in both threaded mode (default) and
    # non-threaded mode (num_sql_threads=0), verifying the isolated
    # connection is truly separate from the shared write connection.
    if disable_threads:
        # Replace the fixture db with one backed by a non-threaded Datasette
        ds = Datasette(memory=True, settings={"num_sql_threads": 0})
        db = ds.add_database(Database(ds, memory_name="test_num_sql_threads_zero"))

    # Create temporary table in write
    await db.execute_write(
        "create temporary table created_by_write (id integer primary key)"
    )
    # Should stay visible to write connection
    assert await db.execute_write_fn(table_exists_checker("created_by_write"))

    def create_shared_table(conn):
        # Runs on the isolated connection: a non-temporary table created
        # here should persist and be visible to other connections
        conn.execute("create table shared (id integer primary key)")
        # And a temporary table that should not continue to exist
        conn.execute(
            "create temporary table created_by_isolated (id integer primary key)"
        )
        assert table_exists(conn, "created_by_isolated")
        # Also confirm that created_by_write does not exist
        return table_exists(conn, "created_by_write")

    # shared should not exist
    assert not await db.execute_fn(table_exists_checker("shared"))

    # Create it using isolated
    created_by_write_exists = await db.execute_isolated_fn(create_shared_table)
    assert not created_by_write_exists

    # shared SHOULD exist now
    assert await db.execute_fn(table_exists_checker("shared"))

    # created_by_isolated should not exist, even in write connection
    assert not await db.execute_write_fn(table_exists_checker("created_by_isolated"))

    # ... and a second call to isolated should not see that connection either
    assert not await db.execute_isolated_fn(table_exists_checker("created_by_isolated"))


@pytest.mark.asyncio
async def test_mtime_ns(db):
assert isinstance(db.mtime_ns, int)
Expand Down