diff --git a/qiita_db/sql_connection.py b/qiita_db/sql_connection.py index 49362f70b..97231fb08 100644 --- a/qiita_db/sql_connection.py +++ b/qiita_db/sql_connection.py @@ -14,60 +14,7 @@ :toctree: generated/ SQLConnectionHandler - -Examples --------- -Transaction blocks are created by first creating a queue of SQL commands, then -adding commands to it. Finally, the execute command is called to execute the -entire queue of SQL commands. A single command is made up of SQL and sql_args. -SQL is the sql string in psycopg2 format with \%s markup, and sql_args is the -list or tuple of replacement items. -An example of a basic queue with two SQL commands in a single transaction: - -from qiita_db.sql_connection import SQLConnectionHandler -conn_handler = SQLConnectionHandler # doctest: +SKIP -conn_handler.create_queue("example_queue") # doctest: +SKIP -conn_handler.add_to_queue( - "example_queue", "INSERT INTO qiita.qiita_user (email, name, password," - "phone) VALUES (%s, %s, %s, %s)", - ['insert@foo.bar', 'Toy', 'pass', '111-111-11112']) # doctest: +SKIP -conn_handler.add_to_queue( - "example_queue", "UPDATE qiita.qiita_user SET user_level_id = 1, " - "phone = '222-222-2221' WHERE email = %s", - ['insert@foo.bar']) # doctest: +SKIP -conn_handler.execute_queue("example_queue") # doctest: +SKIP -conn_handler.execute_fetchall( - "SELECT * from qiita.qiita_user WHERE email = %s", - ['insert@foo.bar']) # doctest: +SKIP -[['insert@foo.bar', 1, 'pass', 'Toy', None, None, '222-222-2221', None, None, - None]] # doctest: +SKIP - -You can also use results from a previous command in the queue in a later -command. If an item in the queue depends on a previous sql command's output, -use {#} notation as a placeholder for the value. The \# must be the -position of the result, e.g. if you return two things you can use \{0\} -to reference the first and \{1\} to referece the second. The results list -will continue to grow until one of the references is reached, then it -will be cleaned out. -Modifying the previous example to show this ability (Note the RETURNING added -to the first SQL command): - -from qiita_db.sql_connection import SQLConnectionHandler -conn_handler = SQLConnectionHandler # doctest: +SKIP -conn_handler.create_queue("example_queue") # doctest: +SKIP -conn_handler.add_to_queue( - "example_queue", "INSERT INTO qiita.qiita_user (email, name, password," - "phone) VALUES (%s, %s, %s, %s) RETURNING email, password", - ['insert@foo.bar', 'Toy', 'pass', '111-111-11112']) # doctest: +SKIP -conn_handler.add_to_queue( - "example_queue", "UPDATE qiita.qiita_user SET user_level_id = 1, " - "phone = '222-222-2221' WHERE email = %s AND password = %s", - ['{0}', '{1}']) # doctest: +SKIP -conn_handler.execute_queue("example_queue") # doctest: +SKIP -conn_handler.execute_fetchall( - "SELECT * from qiita.qiita_user WHERE email = %s", ['insert@foo.bar']) -[['insert@foo.bar', 1, 'pass', 'Toy', None, None, '222-222-2221', None, None, - None]] # doctest: +SKIP + Transaction """ # ----------------------------------------------------------------------------- # Copyright (c) 2014--, The Qiita Development Team. @@ -77,12 +24,10 @@ # The full license is in the file LICENSE, distributed with this software. # ----------------------------------------------------------------------------- -# DUMB mod so we can do the PR! - from __future__ import division from contextlib import contextmanager from itertools import chain -from functools import partial +from functools import partial, wraps from tempfile import mktemp from datetime import date, time, datetime import re @@ -91,7 +36,8 @@ OperationalError) from psycopg2.extras import DictCursor from psycopg2.extensions import ( - ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_READ_COMMITTED) + ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_READ_COMMITTED, + TRANSACTION_STATUS_IDLE) from qiita_core.qiita_settings import qiita_config @@ -460,6 +406,7 @@ def execute_fetchall(self, sql, sql_args=None): return result + # ---- Queue calls def _check_queue_exists(self, queue_name): """Checks if queue `queue_name` exists in the handler @@ -623,6 +570,7 @@ def get_temp_queue(self): Returns ------- + str The name of the queue """ @@ -633,3 +581,296 @@ def get_temp_queue(self): self.create_queue(temp_queue_name) return temp_queue_name + + +def _checker(func): + """Decorator to check that methods are executed inside the context""" + @wraps(func) + def wrapper(self, *args, **kwargs): + if not self._is_inside_context: + raise RuntimeError( + "Operation not permitted. Transaction methods can only be " + "invoked within the context manager.") + return func(self, *args, **kwargs) + return wrapper + + +class Transaction(object): + """A context manager that encapsulates a DB transaction + + A transaction is defined by a series of consecutive queries that need to + be applied to the database as a single block. + + Parameters + ---------- + name : str + Name of the transaction. + + Notes + ----- + When the execution leaves the context manager, any remaining queries in + the transaction will be executed and committed. + The Transaction methods can only be executed inside a context, if they are + invoked outside a context, a RuntimeError is raised. + """ + + _regex = re.compile("^{(\d+):(\d+):(\d+)}$") + + def __init__(self, name): + # The name is useful for debugging, since we can identify the + # failed queue in errors + self._name = name + self._queries = [] + self._results = [] + self.index = 0 + self._conn_handler = SQLConnectionHandler() + self._is_inside_context = False + + def __enter__(self): + self._is_inside_context = True + return self + + def __exit__(self, exc_type, exc_value, traceback): + # We need to wrap the entire function in a try/finally because + # at the end of the function we need to set _is_inside_context to false + try: + status = self._conn_handler._connection.get_transaction_status() + if exc_type is not None: + # An exception occurred during the execution of the transaction + # Make sure that we leave the DB w/o any modification + self.rollback() + elif self._queries: + # There are still queries to be executed, execute them + # It is safe to use the execute method here, as internally is + # wrapped in a try/except and rollbacks in case of failure + self.execute() + elif status != TRANSACTION_STATUS_IDLE: + # There are no queries to be executed, however, the transaction + # is still not committed. Commit it so the changes are not lost + self.commit() + finally: + self._is_inside_context = False + + def _raise_execution_error(self, sql, sql_args, error): + """Rollbacks the current transaction and raises a useful error + + The error message contains the name of the transaction, the failed + query, the arguments of the failed query and the error generated. + + Raises + ------ + ValueError + """ + self.rollback() + raise ValueError( + "Error running SQL query in transaction %s:\n" + "Query: %s\nArguments: %s\nError: %s\n" + % (self._name, sql, str(sql_args), str(error))) + + def _replace_placeholders(self, sql, sql_args): + """Replaces the placeholder in `sql_args` with the actual value + + Parameters + ---------- + sql : str + The SQL query + sql_args : list + The arguments of the SQL query + + Returns + ------- + tuple of (str, list of objects) + The input SQL query (unmodified) and the SQL arguments with the + placeholder (if any) substituted with the actual value of the + previous query + + Raises + ------ + ValueError + If a placeholder does not match any previous result + If a placeholder points to a query that do not produce any result + """ + for pos, arg in enumerate(sql_args): + # Check if we have a placeholder + if isinstance(arg, str): + placeholder = self._regex.search(arg) + if placeholder: + # We do have a placeholder, get the indexes + # Query index + q_idx = int(placeholder.group(1)) + # Row index + r_idx = int(placeholder.group(2)) + # Value index + v_idx = int(placeholder.group(3)) + try: + sql_args[pos] = self._results[q_idx][r_idx][v_idx] + except IndexError: + # A previous query that was expected to retrieve + # some data from the DB did not return as many + # values as expected + self._raise_execution_error( + sql, sql_args, + "The placeholder {%d:%d:%d} does not match to " + "any previous result" + % (q_idx, r_idx, v_idx)) + except TypeError: + # The query that the placeholder is pointing to + # is not expected to retrieve any value + # (e.g. an INSERT w/o RETURNING clause) + self._raise_execution_error( + sql, sql_args, + "The placeholder {%d:%d:%d} is referring to " + "a SQL query that does not retrieve data" + % (q_idx, r_idx, v_idx)) + return sql, sql_args + + @_checker + def add(self, sql, sql_args=None, many=False): + """Add an sql query to the transaction + + If the current query needs a result of a previous query in the + transaction, a placeholder of the form '{#:#:#}' can be used. The first + number is the index of the previous SQL query in the transaction, the + second number is the row from that query result and the third number is + the index of the value within the query result row. + The placeholder will be replaced by the actual value at execution time. + + Parameters + ---------- + sql : str + The sql query + sql_args : list of objects, optional + The arguments to the sql query + many : bool, optional + Whether or not we should add the query multiple times to the + transaction + + Raises + ------ + TypeError + If `sql_args` is provided and is not a list + RuntimeError + If invoked outside a context + + Notes + ----- + If `many` is true, `sql_args` should be a list of lists, in which each + list of the list contains the parameters for one SQL query of the many. + Each element on the list is all the parameters for a single one of the + many queries added. The amount of SQL queries added to the list is + len(sql_args). + """ + if not many: + sql_args = [sql_args] + + for args in sql_args: + if args: + if not isinstance(args, list): + raise TypeError("sql_args should be a list. Found %s" + % type(args)) + else: + args = [] + self._queries.append((sql, args)) + self.index += 1 + + def _execute(self, commit=True): + """Internal function that actually executes the transaction + + The `execute` function exposed in the API wraps this one to make sure + that we catch any exception that happens in here and we rollback the + transaction + """ + with self._conn_handler.get_postgres_cursor() as cur: + for sql, sql_args in self._queries: + sql, sql_args = self._replace_placeholders(sql, sql_args) + + # Execute the current SQL command + try: + cur.execute(sql, sql_args) + except Exception as e: + # We catch any exception as we want to make sure that we + # rollback every time that something went wrong + self._raise_execution_error(sql, sql_args, e) + + try: + res = cur.fetchall() + except ProgrammingError as e: + # At this execution point, we don't know if the sql query + # that we executed should retrieve values from the database + # If the query was not supposed to retrieve any value + # (e.g. an INSERT without a RETURNING clause), it will + # raise a ProgrammingError. Otherwise it will just return + # an empty list + res = None + except PostgresError as e: + # Some other error happened during the execution of the + # query, so we need to rollback + self._raise_execution_error(sql, sql_args, e) + + # Store the results of the current query + self._results.append(res) + + # wipe out the already executed queries + self._queries = [] + + if commit: + self.commit() + + return self._results + + @_checker + def execute(self, commit=True): + """Executes the transaction + + Parameters + ---------- + commit : bool, optional + Whether the transaction should be committed or not. Defaults + to true. + + Returns + ------- + list of DictCursor + The results of all the SQL queries in the transaction + + Raises + ------ + RuntimeError + If invoked outside a context + + Notes + ----- + If any exception occurs during the execution transaction, a rollback + is executed an no changes are reflected in the database + """ + try: + return self._execute(commit=commit) + except Exception: + self.rollback() + raise + + @_checker + def commit(self): + """Commits the transaction and reset the queries + + Raises + ------ + RuntimeError + If invoked outside a context + """ + self._conn_handler._connection.commit() + # Reset the queries + self._queries = [] + + @_checker + def rollback(self): + """Rollbacks the transaction and reset the queries + + Raises + ------ + RuntimeError + If invoked outside a context + """ + self._conn_handler._connection.rollback() + # Reset the queries + self._queries = [] diff --git a/qiita_db/test/test_sql_connection.py b/qiita_db/test/test_sql_connection.py index eedb0f08c..4cd551fb8 100644 --- a/qiita_db/test/test_sql_connection.py +++ b/qiita_db/test/test_sql_connection.py @@ -4,9 +4,10 @@ from psycopg2.extras import DictCursor from psycopg2 import connect from psycopg2.extensions import (ISOLATION_LEVEL_AUTOCOMMIT, - ISOLATION_LEVEL_READ_COMMITTED) + ISOLATION_LEVEL_READ_COMMITTED, + TRANSACTION_STATUS_IDLE) -from qiita_db.sql_connection import SQLConnectionHandler +from qiita_db.sql_connection import SQLConnectionHandler, Transaction from qiita_core.util import qiita_test_checker from qiita_core.qiita_settings import qiita_config @@ -18,7 +19,7 @@ @qiita_test_checker() -class TestConnHandler(TestCase): +class TestBase(TestCase): def setUp(self): # Add the test table to the database, so we can use it in the tests with connect(user=qiita_config.user, password=qiita_config.password, @@ -53,6 +54,8 @@ def _assert_sql_equal(self, exp): self.assertEqual(obs, exp) + +class TestConnHandler(TestBase): def test_init(self): obs = SQLConnectionHandler() self.assertEqual(obs.admin, 'no_admin') @@ -193,173 +196,383 @@ def test_execute_fetchall_with_sql_args(self): obs = self.conn_handler.execute_fetchall(sql, (True,)) self.assertEqual(obs, [['test1', True, 1], ['test2', True, 2]]) - def test_check_queue_exists(self): - self.assertFalse(self.conn_handler._check_queue_exists('foo')) - self.conn_handler.create_queue('foo') - self.assertTrue(self.conn_handler._check_queue_exists('foo')) - - def test_create_queue(self): - self.assertEqual(self.conn_handler.queues, {}) - self.conn_handler.create_queue("toy_queue") - self.assertEqual(self.conn_handler.queues, {'toy_queue': []}) - - def test_create_queue_error(self): - self.conn_handler.create_queue("test_queue") - with self.assertRaises(KeyError): - self.conn_handler.create_queue("test_queue") - - def test_list_queues(self): - self.assertEqual(self.conn_handler.list_queues(), []) - self.conn_handler.create_queue("test_queue") - self.assertEqual(self.conn_handler.list_queues(), ["test_queue"]) - - def test_add_to_queue(self): - self.conn_handler.create_queue("test_queue") - - sql1 = "INSERT INTO qiita.test_table (bool_column) VALUES (%s)" - sql_args1 = (True,) - self.conn_handler.add_to_queue("test_queue", sql1, sql_args1) - self.assertEqual(self.conn_handler.queues, - {"test_queue": [(sql1, sql_args1)]}) - - sql2 = "INSERT INTO qiita.test_table (int_column) VALUES (1)" - self.conn_handler.add_to_queue("test_queue", sql2) - self.assertEqual(self.conn_handler.queues, - {"test_queue": [(sql1, sql_args1), (sql2, None)]}) - - def test_add_to_queue_many(self): - self.conn_handler.create_queue("test_queue") - - sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" - sql_args = [(1,), (2,), (3,)] - self.conn_handler.add_to_queue("test_queue", sql, sql_args, many=True) - self.assertEqual(self.conn_handler.queues, - {"test_queue": [(sql, (1,)), (sql, (2,)), - (sql, (3,))]}) - - def test_add_to_queue_error(self): - with self.assertRaises(KeyError): - self.conn_handler.add_to_queue("foo", "SELECT 42") - - def test_execute_queue(self): - self.conn_handler.create_queue("test_queue") - sql = """INSERT INTO qiita.test_table (str_column, int_column) - VALUES (%s, %s)""" - self.conn_handler.add_to_queue("test_queue", sql, ['test_insert', '2']) - sql = """UPDATE qiita.test_table - SET int_column = 20, bool_column = FALSE - WHERE str_column = %s""" - self.conn_handler.add_to_queue("test_queue", sql, ['test_insert']) - obs = self.conn_handler.execute_queue("test_queue") - self.assertEqual(obs, []) - self._assert_sql_equal([("test_insert", False, 20)]) - - def test_execute_queue_many(self): - sql = """INSERT INTO qiita.test_table (str_column, int_column) - VALUES (%s, %s)""" - sql_args = [('insert1', 1), ('insert2', 2), ('insert3', 3)] - - self.conn_handler.create_queue("test_queue") - self.conn_handler.add_to_queue("test_queue", sql, sql_args, many=True) - sql = """UPDATE qiita.test_table - SET int_column = 20, bool_column = FALSE - WHERE str_column = %s""" - self.conn_handler.add_to_queue("test_queue", sql, ['insert2']) - obs = self.conn_handler.execute_queue('test_queue') - self.assertEqual(obs, []) - - self._assert_sql_equal([('insert1', True, 1), ('insert3', True, 3), - ('insert2', False, 20)]) - - def test_execute_queue_last_return(self): - self.conn_handler.create_queue("test_queue") - sql = """INSERT INTO qiita.test_table (str_column, int_column) - VALUES (%s, %s)""" - self.conn_handler.add_to_queue("test_queue", sql, ['test_insert', '2']) - sql = """UPDATE qiita.test_table SET bool_column = FALSE - WHERE str_column = %s RETURNING int_column""" - self.conn_handler.add_to_queue("test_queue", sql, ['test_insert']) - obs = self.conn_handler.execute_queue("test_queue") - self.assertEqual(obs, [2]) - - def test_execute_queue_placeholders(self): - self.conn_handler.create_queue("test_queue") - sql = """INSERT INTO qiita.test_table (int_column) VALUES (%s) - RETURNING str_column""" - self.conn_handler.add_to_queue("test_queue", sql, (2,)) - sql = """UPDATE qiita.test_table SET bool_column = FALSE - WHERE str_column = %s""" - self.conn_handler.add_to_queue("test_queue", sql, ('{0}',)) - obs = self.conn_handler.execute_queue("test_queue") - self.assertEqual(obs, []) - self._assert_sql_equal([('foo', False, 2)]) - - def test_execute_queue_placeholders_regex(self): - self.conn_handler.create_queue("test_queue") - sql = """INSERT INTO qiita.test_table (int_column) - VALUES (%s) RETURNING str_column""" - self.conn_handler.add_to_queue("test_queue", sql, (1,)) - sql = """UPDATE qiita.test_table SET str_column = %s - WHERE str_column = %s""" - self.conn_handler.add_to_queue("test_queue", sql, ("", "{0}")) - obs = self.conn_handler.execute_queue("test_queue") - self.assertEqual(obs, []) - self._assert_sql_equal([('', True, 1)]) - - def test_execute_queue_fail(self): - self.conn_handler.create_queue("test_queue") - sql = """INSERT INTO qiita.test_table (int_column) VALUES (%s)""" - self.conn_handler.add_to_queue("test_queue", sql, (2,)) - sql = """UPDATE qiita.test_table SET bool_column = False - WHERE str_column = %s""" - self.conn_handler.add_to_queue("test_queue", sql, ('{0}',)) - - with self.assertRaises(ValueError): - self.conn_handler.execute_queue("test_queue") - # make sure rollback correctly - self._assert_sql_equal([]) - - def test_execute_queue_error(self): - self.conn_handler.create_queue("test_queue") - sql = """INSERT INTO qiita.test_table (str_column, int_column) - VALUES (%s, %s)""" - self.conn_handler.add_to_queue("test_queue", sql, ['test_insert', '2']) - sql = """UPDATE qiita.test_table - SET int_column = 20, bool_column = FALSE - WHERE str_column = %s""" - self.conn_handler.add_to_queue("test_queue", sql, ['test_insert']) - with self.assertRaises(KeyError): - self.conn_handler.execute_queue("oops!") - - def test_huge_queue(self): - self.conn_handler.create_queue("test_queue") - # Add a lof of inserts to the queue - sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" - for x in range(1000): - self.conn_handler.add_to_queue("test_queue", sql, (x,)) - - # Make the queue fail with the last insert - sql = "INSERT INTO qiita.table_to_make (the_queue_to_fail) VALUES (1)" - self.conn_handler.add_to_queue("test_queue", sql) - - with self.assertRaises(ValueError): - self.conn_handler.execute_queue("test_queue") - - # make sure rollback correctly +class TestTransaction(TestBase): + def test_init(self): + with Transaction("test_init") as obs: + obs = Transaction("test_init") + self.assertEqual(obs._name, "test_init") + self.assertEqual(obs._queries, []) + self.assertEqual(obs._results, []) + self.assertEqual(obs.index, 0) + self.assertTrue( + isinstance(obs._conn_handler, SQLConnectionHandler)) + self.assertFalse(obs._is_inside_context) + + def test_replace_placeholders(self): + with Transaction("test_replace_placeholders") as trans: + trans._results = [[["res1", 1]], [["res2a", 2], ["res2b", 3]], + None, None, [["res5", 5]]] + sql = "SELECT 42" + obs_sql, obs_args = trans._replace_placeholders(sql, ["{0:0:0}"]) + self.assertEqual(obs_sql, sql) + self.assertEqual(obs_args, ["res1"]) + + obs_sql, obs_args = trans._replace_placeholders(sql, ["{1:0:0}"]) + self.assertEqual(obs_sql, sql) + self.assertEqual(obs_args, ["res2a"]) + + obs_sql, obs_args = trans._replace_placeholders(sql, ["{1:1:1}"]) + self.assertEqual(obs_sql, sql) + self.assertEqual(obs_args, [3]) + + obs_sql, obs_args = trans._replace_placeholders(sql, ["{4:0:0}"]) + self.assertEqual(obs_sql, sql) + self.assertEqual(obs_args, ["res5"]) + + obs_sql, obs_args = trans._replace_placeholders( + sql, ["foo", "{0:0:1}", "bar", "{1:0:1}"]) + self.assertEqual(obs_sql, sql) + self.assertEqual(obs_args, ["foo", 1, "bar", 2]) + + def test_replace_placeholders_index_error(self): + with Transaction("test_replace_placeholders_index_error") as trans: + trans._results = [[["res1", 1]], [["res2a", 2], ["res2b", 2]]] + + error_regex = ('The placeholder {0:0:3} does not match to any ' + 'previous result') + with self.assertRaisesRegexp(ValueError, error_regex): + trans._replace_placeholders("SELECT 42", ["{0:0:3}"]) + + error_regex = ('The placeholder {0:2:0} does not match to any ' + 'previous result') + with self.assertRaisesRegexp(ValueError, error_regex): + trans._replace_placeholders("SELECT 42", ["{0:2:0}"]) + + error_regex = ('The placeholder {2:0:0} does not match to any ' + 'previous result') + with self.assertRaisesRegexp(ValueError, error_regex): + trans._replace_placeholders("SELECT 42", ["{2:0:0}"]) + + def test_replace_placeholders_type_error(self): + with Transaction("test_replace_placeholders_type_error") as trans: + trans._results = [None] + + error_regex = ("The placeholder {0:0:0} is referring to a SQL " + "query that does not retrieve data") + with self.assertRaisesRegexp(ValueError, error_regex): + trans._replace_placeholders("SELECT 42", ["{0:0:0}"]) + + def test_add(self): + with Transaction("test_add") as trans: + self.assertEqual(trans._queries, []) + + sql1 = "INSERT INTO qiita.test_table (bool_column) VALUES (%s)" + args1 = [True] + trans.add(sql1, args1) + sql2 = "INSERT INTO qiita.test_table (int_column) VALUES (1)" + trans.add(sql2) + + exp = [(sql1, args1), (sql2, [])] + self.assertEqual(trans._queries, exp) + + # Remove queries so __exit__ doesn't try to execute it + trans._queries = [] + + def test_add_many(self): + with Transaction("test_add_many") as trans: + self.assertEqual(trans._queries, []) + + sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" + args = [[1], [2], [3]] + trans.add(sql, args, many=True) + + exp = [(sql, [1]), (sql, [2]), (sql, [3])] + self.assertEqual(trans._queries, exp) + + def test_add_error(self): + with Transaction("test_add_error") as trans: + + with self.assertRaises(TypeError): + trans.add("SELECT 42", (1,)) + + with self.assertRaises(TypeError): + trans.add("SELECT 42", {'foo': 'bar'}) + + with self.assertRaises(TypeError): + trans.add("SELECT 42", [(1,), (1,)], many=True) + + def test_execute(self): + with Transaction("test_execute") as trans: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s)""" + trans.add(sql, ["test_insert", 2]) + sql = """UPDATE qiita.test_table + SET int_column = %s, bool_column = %s + WHERE str_column = %s""" + trans.add(sql, [20, False, "test_insert"]) + obs = trans.execute() + self.assertEqual(obs, [None, None]) + self._assert_sql_equal([("test_insert", False, 20)]) + + def test_execute_many(self): + with Transaction("test_execute_many") as trans: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s)""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + trans.add(sql, args, many=True) + sql = """UPDATE qiita.test_table + SET int_column = %s, bool_column = %s + WHERE str_column = %s""" + trans.add(sql, [20, False, 'insert2']) + obs = trans.execute() + self.assertEqual(obs, [None, None, None, None]) + + self._assert_sql_equal([('insert1', True, 1), + ('insert3', True, 3), + ('insert2', False, 20)]) + + def test_execute_return(self): + with Transaction("test_execute_return") as trans: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + trans.add(sql, ['test_insert', 2]) + sql = """UPDATE qiita.test_table SET bool_column = %s + WHERE str_column = %s RETURNING int_column""" + trans.add(sql, [False, 'test_insert']) + obs = trans.execute() + self.assertEqual(obs, [[['test_insert', 2]], [[2]]]) + + def test_execute_return_many(self): + with Transaction("test_execute_return_many") as trans: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + trans.add(sql, args, many=True) + sql = """UPDATE qiita.test_table SET bool_column = %s + WHERE str_column = %s""" + trans.add(sql, [False, 'insert2']) + sql = "SELECT * FROM qiita.test_table" + trans.add(sql) + obs = trans.execute() + exp = [[['insert1', 1]], # First query of the many query + [['insert2', 2]], # Second query of the many query + [['insert3', 3]], # Third query of the many query + None, # Update query + [['insert1', True, 1], # First result select + ['insert3', True, 3], # Second result select + ['insert2', False, 2]]] # Third result select + self.assertEqual(obs, exp) + + def test_execute_placeholders(self): + with Transaction("test_execute_placeholders") as trans: + sql = """INSERT INTO qiita.test_table (int_column) VALUES (%s) + RETURNING str_column""" + trans.add(sql, [2]) + sql = """UPDATE qiita.test_table SET str_column = %s + WHERE str_column = %s""" + trans.add(sql, ["", "{0:0:0}"]) + obs = trans.execute() + self.assertEqual(obs, [[['foo']], None]) + self._assert_sql_equal([('', True, 2)]) + + def test_execute_error_bad_placeholder(self): + with Transaction("test_execute_error_bad_placeholder") as trans: + sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" + trans.add(sql, [2]) + sql = """UPDATE qiita.test_table SET bool_column = %s + WHERE str_column = %s""" + trans.add(sql, [False, "{0:0:0}"]) + + with self.assertRaises(ValueError): + trans.execute() + + # make sure rollback correctly + self._assert_sql_equal([]) + + def test_execute_error_no_result_placeholder(self): + with Transaction("test_execute_error_no_result_placeholder") as trans: + sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" + trans.add(sql, [[1], [2], [3]], many=True) + sql = """SELECT str_column FROM qiita.test_table + WHERE int_column = %s""" + trans.add(sql, [4]) + sql = """UPDATE qiita.test_table SET bool_column = %s + WHERE str_column = %s""" + trans.add(sql, [False, "{3:0:0}"]) + + with self.assertRaises(ValueError): + trans.execute() + + # make sure rollback correctly + self._assert_sql_equal([]) + + def test_execute_huge_transaction(self): + with Transaction("test_execute_huge_transaction") as trans: + # Add a lot of inserts to the transaction + sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" + for i in range(1000): + trans.add(sql, [i]) + # Add some updates to the transaction + sql = """UPDATE qiita.test_table SET bool_column = %s + WHERE int_column = %s""" + for i in range(500): + trans.add(sql, [False, i]) + # Make the transaction fail with the last insert + sql = """INSERT INTO qiita.table_to_make (the_trans_to_fail) + VALUES (1)""" + trans.add(sql) + + with self.assertRaises(ValueError): + trans.execute() + + # make sure rollback correctly + self._assert_sql_equal([]) + + def test_execute_commit_false(self): + with Transaction("test_execute_commit_false") as trans: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + trans.add(sql, args, many=True) + + obs = trans.execute(commit=False) + exp = [[['insert1', 1]], [['insert2', 2]], [['insert3', 3]]] + self.assertEqual(obs, exp) + + self._assert_sql_equal([]) + + trans.commit() + + self._assert_sql_equal([('insert1', True, 1), ('insert2', True, 2), + ('insert3', True, 3)]) + + def test_execute_commit_false_rollback(self): + with Transaction("test_execute_commit_false_rollback") as trans: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + trans.add(sql, args, many=True) + + obs = trans.execute(commit=False) + exp = [[['insert1', 1]], [['insert2', 2]], [['insert3', 3]]] + self.assertEqual(obs, exp) + + self._assert_sql_equal([]) + + trans.rollback() + + self._assert_sql_equal([]) + + def test_execute_commit_false_wipe_queries(self): + with Transaction("test_execute_commit_false_wipe_queries") as trans: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + trans.add(sql, args, many=True) + + obs = trans.execute(commit=False) + exp = [[['insert1', 1]], [['insert2', 2]], [['insert3', 3]]] + self.assertEqual(obs, exp) + + self._assert_sql_equal([]) + + sql = """UPDATE qiita.test_table SET bool_column = %s + WHERE str_column = %s""" + args = [False, 'insert2'] + trans.add(sql, args) + self.assertEqual(trans._queries, [(sql, args)]) + + trans.execute() + + self._assert_sql_equal([('insert1', True, 1), ('insert3', True, 3), + ('insert2', False, 2)]) + + def test_context_manager_rollback(self): + try: + with Transaction("test_context_manager_rollback") as trans: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + trans.add(sql, args, many=True) + + trans.execute(commit=False) + raise ValueError("Force exiting the context manager") + except ValueError: + pass self._assert_sql_equal([]) - - def test_get_temp_queue(self): - my_queue = self.conn_handler.get_temp_queue() - self.assertTrue(my_queue in self.conn_handler.list_queues()) - - self.conn_handler.add_to_queue(my_queue, - "SELECT * from qiita.qiita_user") - self.conn_handler.add_to_queue(my_queue, - "SELECT * from qiita.user_level") - self.conn_handler.execute_queue(my_queue) - - self.assertTrue(my_queue not in self.conn_handler.list_queues()) + self.assertEqual( + trans._conn_handler._connection.get_transaction_status(), + TRANSACTION_STATUS_IDLE) + + def test_context_manager_execute(self): + with Transaction("test_context_manager_no_commit") as trans: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + trans.add(sql, args, many=True) + self._assert_sql_equal([]) + + self._assert_sql_equal([('insert1', True, 1), ('insert2', True, 2), + ('insert3', True, 3)]) + self.assertEqual( + trans._conn_handler._connection.get_transaction_status(), + TRANSACTION_STATUS_IDLE) + + def test_context_manager_no_commit(self): + with Transaction("test_context_manager_no_commit") as trans: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + trans.add(sql, args, many=True) + + trans.execute(commit=False) + self._assert_sql_equal([]) + + self._assert_sql_equal([('insert1', True, 1), ('insert2', True, 2), + ('insert3', True, 3)]) + self.assertEqual( + trans._conn_handler._connection.get_transaction_status(), + TRANSACTION_STATUS_IDLE) + + def test_context_managet_checker(self): + t = Transaction("test_context_managet_checker") + + with self.assertRaises(RuntimeError): + t.add("SELECT 42") + + with self.assertRaises(RuntimeError): + t.execute() + + with self.assertRaises(RuntimeError): + t.commit() + + with self.assertRaises(RuntimeError): + t.rollback() + + with t: + t.add("SELECT 42") + + with self.assertRaises(RuntimeError): + t.execute() + + def test_index(self): + with Transaction("test_index") as trans: + self.assertEqual(trans.index, 0) + + trans.add("SELECT 42") + self.assertEqual(trans.index, 1) + + sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" + args = [[1], [2], [3]] + trans.add(sql, args, many=True) + self.assertEqual(trans.index, 4) + + trans.execute(commit=False) + self.assertEqual(trans.index, 4) + + trans.add(sql, args, many=True) + self.assertEqual(trans.index, 7) if __name__ == "__main__": main()