Skip to content

Commit

Permalink
Merge pull request #1276 from josenavas/improve-sql-queues-system
Browse files Browse the repository at this point in the history
Add `Transaction` object
  • Loading branch information
squirrelo committed Jun 24, 2015
2 parents ff9f304 + af1cf92 commit 5f1c1b1
Show file tree
Hide file tree
Showing 2 changed files with 680 additions and 226 deletions.
357 changes: 299 additions & 58 deletions qiita_db/sql_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,60 +14,7 @@
:toctree: generated/
SQLConnectionHandler
Examples
--------
Transaction blocks are created by first creating a queue of SQL commands, then
adding commands to it. Finally, the execute command is called to execute the
entire queue of SQL commands. A single command is made up of SQL and sql_args.
SQL is the sql string in psycopg2 format with \%s markup, and sql_args is the
list or tuple of replacement items.
An example of a basic queue with two SQL commands in a single transaction:
from qiita_db.sql_connection import SQLConnectionHandler
conn_handler = SQLConnectionHandler() # doctest: +SKIP
conn_handler.create_queue("example_queue") # doctest: +SKIP
conn_handler.add_to_queue(
"example_queue", "INSERT INTO qiita.qiita_user (email, name, password,"
"phone) VALUES (%s, %s, %s, %s)",
['insert@foo.bar', 'Toy', 'pass', '111-111-11112']) # doctest: +SKIP
conn_handler.add_to_queue(
"example_queue", "UPDATE qiita.qiita_user SET user_level_id = 1, "
"phone = '222-222-2221' WHERE email = %s",
['insert@foo.bar']) # doctest: +SKIP
conn_handler.execute_queue("example_queue") # doctest: +SKIP
conn_handler.execute_fetchall(
"SELECT * from qiita.qiita_user WHERE email = %s",
['insert@foo.bar']) # doctest: +SKIP
[['insert@foo.bar', 1, 'pass', 'Toy', None, None, '222-222-2221', None, None,
None]] # doctest: +SKIP
You can also use results from a previous command in the queue in a later
command. If an item in the queue depends on a previous sql command's output,
use {#} notation as a placeholder for the value. The \# must be the
position of the result, e.g. if you return two things you can use \{0\}
to reference the first and \{1\} to reference the second. The results list
will continue to grow until one of the references is reached, then it
will be cleaned out.
Modifying the previous example to show this ability (Note the RETURNING added
to the first SQL command):
from qiita_db.sql_connection import SQLConnectionHandler
conn_handler = SQLConnectionHandler() # doctest: +SKIP
conn_handler.create_queue("example_queue") # doctest: +SKIP
conn_handler.add_to_queue(
"example_queue", "INSERT INTO qiita.qiita_user (email, name, password,"
"phone) VALUES (%s, %s, %s, %s) RETURNING email, password",
['insert@foo.bar', 'Toy', 'pass', '111-111-11112']) # doctest: +SKIP
conn_handler.add_to_queue(
"example_queue", "UPDATE qiita.qiita_user SET user_level_id = 1, "
"phone = '222-222-2221' WHERE email = %s AND password = %s",
['{0}', '{1}']) # doctest: +SKIP
conn_handler.execute_queue("example_queue") # doctest: +SKIP
conn_handler.execute_fetchall(
"SELECT * from qiita.qiita_user WHERE email = %s", ['insert@foo.bar'])
[['insert@foo.bar', 1, 'pass', 'Toy', None, None, '222-222-2221', None, None,
None]] # doctest: +SKIP
Transaction
"""
# -----------------------------------------------------------------------------
# Copyright (c) 2014--, The Qiita Development Team.
Expand All @@ -77,12 +24,10 @@
# The full license is in the file LICENSE, distributed with this software.
# -----------------------------------------------------------------------------

# DUMB mod so we can do the PR!

from __future__ import division
from contextlib import contextmanager
from itertools import chain
from functools import partial
from functools import partial, wraps
from tempfile import mktemp
from datetime import date, time, datetime
import re
Expand All @@ -91,7 +36,8 @@
OperationalError)
from psycopg2.extras import DictCursor
from psycopg2.extensions import (
ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_READ_COMMITTED)
ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_READ_COMMITTED,
TRANSACTION_STATUS_IDLE)

from qiita_core.qiita_settings import qiita_config

Expand Down Expand Up @@ -460,6 +406,7 @@ def execute_fetchall(self, sql, sql_args=None):

return result

# ---- Queue calls
def _check_queue_exists(self, queue_name):
"""Checks if queue `queue_name` exists in the handler
Expand Down Expand Up @@ -623,6 +570,7 @@ def get_temp_queue(self):
Returns
-------
str
The name of the queue
"""
Expand All @@ -633,3 +581,296 @@ def get_temp_queue(self):
self.create_queue(temp_queue_name)

return temp_queue_name


def _checker(func):
    """Decorator that only lets a method run inside the context manager

    The decorated method is forwarded to `func` when the instance attribute
    `_is_inside_context` is truthy at call time; otherwise a RuntimeError is
    raised and `func` is never invoked.
    """
    @wraps(func)
    def guarded(self, *args, **kwargs):
        # Inside the context: simply delegate to the wrapped method
        if self._is_inside_context:
            return func(self, *args, **kwargs)

        message = ("Operation not permitted. Transaction methods can only be "
                   "invoked within the context manager.")
        raise RuntimeError(message)
    return guarded


class Transaction(object):
    """A context manager that encapsulates a DB transaction

    A transaction is defined by a series of consecutive queries that need to
    be applied to the database as a single block.

    Parameters
    ----------
    name : str
        Name of the transaction.

    Notes
    -----
    When the execution leaves the context manager, any remaining queries in
    the transaction will be executed and committed. If the context is left
    because of an exception, the transaction is rolled back instead.

    The Transaction methods can only be executed inside a context, if they are
    invoked outside a context, a RuntimeError is raised.
    """

    # Matches an argument that is, as a whole, a placeholder of the form
    # '{query_index:row_index:value_index}'. Written as a raw string so that
    # `\d` reaches the regex engine instead of being an invalid str escape.
    _regex = re.compile(r"^{(\d+):(\d+):(\d+)}$")

    def __init__(self, name):
        # The name is useful for debugging, since we can identify the
        # failed queue in errors
        self._name = name
        # Pending (sql, sql_args) pairs, not yet sent to the database
        self._queries = []
        # One entry per executed query: the fetched rows, or None if the
        # query did not retrieve data
        self._results = []
        # Number of queries added so far
        # NOTE(review): commit() and rollback() clear `_queries` but leave
        # `index` and `_results` untouched - confirm this is intended
        self.index = 0
        self._conn_handler = SQLConnectionHandler()
        self._is_inside_context = False

    def __enter__(self):
        self._is_inside_context = True
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Finishes the transaction when the context is left

        Rolls back if the context is left due to an exception, executes the
        pending queries if there are any and, otherwise, commits if the
        connection reports a non-idle transaction status.
        """
        # We need to wrap the entire function in a try/finally because
        # at the end of the function we need to set _is_inside_context to false
        try:
            status = self._conn_handler._connection.get_transaction_status()
            if exc_type is not None:
                # An exception occurred during the execution of the transaction
                # Make sure that we leave the DB w/o any modification
                self.rollback()
            elif self._queries:
                # There are still queries to be executed, execute them
                # It is safe to use the execute method here, as internally is
                # wrapped in a try/except and rollbacks in case of failure
                self.execute()
            elif status != TRANSACTION_STATUS_IDLE:
                # There are no queries to be executed, however, the transaction
                # is still not committed. Commit it so the changes are not lost
                self.commit()
        finally:
            self._is_inside_context = False

    def _raise_execution_error(self, sql, sql_args, error):
        """Rollbacks the current transaction and raises a useful error

        The error message contains the name of the transaction, the failed
        query, the arguments of the failed query and the error generated.

        Parameters
        ----------
        sql : str
            The SQL query that failed
        sql_args : list
            The arguments of the failed SQL query
        error : object
            The error (or error message) to report; `str` is applied to it

        Raises
        ------
        ValueError
            Always
        """
        self.rollback()
        raise ValueError(
            "Error running SQL query in transaction %s:\n"
            "Query: %s\nArguments: %s\nError: %s\n"
            % (self._name, sql, str(sql_args), str(error)))

    def _replace_placeholders(self, sql, sql_args):
        """Replaces the placeholder in `sql_args` with the actual value

        Parameters
        ----------
        sql : str
            The SQL query
        sql_args : list
            The arguments of the SQL query. It is modified in place.

        Returns
        -------
        tuple of (str, list of objects)
            The input SQL query (unmodified) and the SQL arguments with the
            placeholder (if any) substituted with the actual value of the
            previous query

        Raises
        ------
        ValueError
            If a placeholder does not match any previous result
            If a placeholder points to a query that do not produce any result
        """
        for pos, arg in enumerate(sql_args):
            # Only strings can be placeholders
            if not isinstance(arg, str):
                continue
            placeholder = self._regex.search(arg)
            if not placeholder:
                continue

            # We do have a placeholder, get the query, row and value indexes
            q_idx, r_idx, v_idx = (int(g) for g in placeholder.groups())
            try:
                sql_args[pos] = self._results[q_idx][r_idx][v_idx]
            except IndexError:
                # A previous query that was expected to retrieve
                # some data from the DB did not return as many
                # values as expected
                self._raise_execution_error(
                    sql, sql_args,
                    "The placeholder {%d:%d:%d} does not match to "
                    "any previous result"
                    % (q_idx, r_idx, v_idx))
            except TypeError:
                # The query that the placeholder is pointing to
                # is not expected to retrieve any value
                # (e.g. an INSERT w/o RETURNING clause), so its stored
                # result is None and cannot be indexed
                self._raise_execution_error(
                    sql, sql_args,
                    "The placeholder {%d:%d:%d} is referring to "
                    "a SQL query that does not retrieve data"
                    % (q_idx, r_idx, v_idx))
        return sql, sql_args

    @_checker
    def add(self, sql, sql_args=None, many=False):
        """Add an sql query to the transaction

        If the current query needs a result of a previous query in the
        transaction, a placeholder of the form '{#:#:#}' can be used. The first
        number is the index of the previous SQL query in the transaction, the
        second number is the row from that query result and the third number is
        the index of the value within the query result row.
        The placeholder will be replaced by the actual value at execution time.

        Parameters
        ----------
        sql : str
            The sql query
        sql_args : list of objects, optional
            The arguments to the sql query
        many : bool, optional
            Whether or not we should add the query multiple times to the
            transaction

        Raises
        ------
        TypeError
            If `sql_args` is provided and is not a list
            If `many` is true and `sql_args` is None
        RuntimeError
            If invoked outside a context

        Notes
        -----
        If `many` is true, `sql_args` should be a list of lists, in which each
        list of the list contains the parameters for one SQL query of the many.
        Each element on the list is all the parameters for a single one of the
        many queries added. The amount of SQL queries added to the list is
        len(sql_args).

        The arguments are copied when added, so the lists passed by the caller
        are never modified by the placeholder substitution. If any element is
        invalid, no query is added to the transaction.
        """
        if not many:
            sql_args = [sql_args]
        elif sql_args is None:
            raise TypeError("sql_args should be a list of lists when many "
                            "is True. Found %s" % type(sql_args))

        # Validate everything first, so a bad element does not leave the
        # transaction with only part of the queries added
        new_queries = []
        for args in sql_args:
            if args:
                if not isinstance(args, list):
                    raise TypeError("sql_args should be a list. Found %s"
                                    % type(args))
                # Store a copy: placeholders are substituted in place at
                # execution time and the caller's list must stay untouched
                args = list(args)
            else:
                args = []
            new_queries.append((sql, args))

        self._queries.extend(new_queries)
        self.index += len(new_queries)

    def _execute(self, commit=True):
        """Internal function that actually executes the transaction

        The `execute` function exposed in the API wraps this one to make sure
        that we catch any exception that happens in here and we rollback the
        transaction

        Parameters
        ----------
        commit : bool, optional
            Whether to commit after running the pending queries

        Returns
        -------
        list
            The accumulated results: one item per executed query, holding
            the fetched rows or None if the query did not retrieve data
        """
        with self._conn_handler.get_postgres_cursor() as cur:
            for sql, sql_args in self._queries:
                sql, sql_args = self._replace_placeholders(sql, sql_args)

                # Execute the current SQL command
                try:
                    cur.execute(sql, sql_args)
                except Exception as e:
                    # We catch any exception as we want to make sure that we
                    # rollback every time that something went wrong
                    self._raise_execution_error(sql, sql_args, e)

                try:
                    res = cur.fetchall()
                except ProgrammingError:
                    # At this execution point, we don't know if the sql query
                    # that we executed should retrieve values from the database
                    # If the query was not supposed to retrieve any value
                    # (e.g. an INSERT without a RETURNING clause), it will
                    # raise a ProgrammingError. Otherwise it will just return
                    # an empty list
                    res = None
                except PostgresError as e:
                    # Some other error happened during the execution of the
                    # query, so we need to rollback
                    self._raise_execution_error(sql, sql_args, e)

                # Store the results of the current query
                self._results.append(res)

        # wipe out the already executed queries
        self._queries = []

        if commit:
            self.commit()

        return self._results

    @_checker
    def execute(self, commit=True):
        """Executes the transaction

        Parameters
        ----------
        commit : bool, optional
            Whether the transaction should be committed or not. Defaults
            to true.

        Returns
        -------
        list
            The results of all the SQL queries executed so far in the
            transaction. Each item holds the rows fetched for one query, or
            None if that query did not retrieve data.

        Raises
        ------
        RuntimeError
            If invoked outside a context

        Notes
        -----
        If any exception occurs during the execution transaction, a rollback
        is executed and no changes are reflected in the database
        """
        try:
            return self._execute(commit=commit)
        except Exception:
            self.rollback()
            raise

    @_checker
    def commit(self):
        """Commits the transaction and reset the queries

        Raises
        ------
        RuntimeError
            If invoked outside a context
        """
        self._conn_handler._connection.commit()
        # Reset the queries
        self._queries = []

    @_checker
    def rollback(self):
        """Rollbacks the transaction and reset the queries

        Raises
        ------
        RuntimeError
            If invoked outside a context
        """
        self._conn_handler._connection.rollback()
        # Reset the queries
        self._queries = []

0 comments on commit 5f1c1b1

Please sign in to comment.