diff --git a/qiita_core/util.py b/qiita_core/util.py index a887be093..735e4190e 100644 --- a/qiita_core/util.py +++ b/qiita_core/util.py @@ -7,6 +7,7 @@ # ----------------------------------------------------------------------------- from smtplib import SMTP, SMTP_SSL, SMTPException from future import standard_library +from functools import wraps from qiita_core.qiita_settings import qiita_config from qiita_db.sql_connection import SQLConnectionHandler @@ -83,3 +84,13 @@ def tearDown(self): return DecoratedClass return class_modifier + + +def execute_as_transaction(func): + """Decorator to make a method execute inside a transaction""" + @wraps(func) + def wrapper(*args, **kwargs): + from qiita_db.sql_connection import TRN + with TRN: + return func(*args, **kwargs) + return wrapper diff --git a/qiita_db/analysis.py b/qiita_db/analysis.py index a62cf73c7..a8a698910 100644 --- a/qiita_db/analysis.py +++ b/qiita_db/analysis.py @@ -28,7 +28,7 @@ from skbio.util import find_duplicates from qiita_core.exceptions import IncompetentQiitaDeveloperError -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN from .base import QiitaStatusObject from .data import ProcessedData from .study import Study @@ -76,13 +76,13 @@ class Analysis(QiitaStatusObject): _portal_table = "analysis_portal" _analysis_id_column = 'analysis_id' - def _lock_check(self, conn_handler): + def _lock_check(self): """Raises QiitaDBStatusError if analysis is not in_progress""" if self.check_status({"queued", "running", "public", "completed", "error"}): raise QiitaDBStatusError("Analysis is locked!") - def _status_setter_checks(self, conn_handler): + def _status_setter_checks(self): r"""Perform a check to make sure not setting status away from public """ if self.check_status({"public"}): @@ -102,14 +102,15 @@ def get_by_status(cls, status): set of int All analyses in the database with the given status """ - conn_handler = SQLConnectionHandler() - sql = """SELECT analysis_id FROM qiita.{0} - JOIN qiita.{0}_status USING (analysis_status_id) - JOIN qiita.analysis_portal USING (analysis_id) - JOIN qiita.portal_type USING (portal_type_id) - WHERE status = %s AND portal = %s""".format(cls._table) - return {x[0] for x in conn_handler.execute_fetchall( - sql, (status, qiita_config.portal))} + with TRN: + sql = """SELECT analysis_id + FROM qiita.{0} + JOIN qiita.{0}_status USING (analysis_status_id) + JOIN qiita.analysis_portal USING (analysis_id) + JOIN qiita.portal_type USING (portal_type_id) + WHERE status = %s AND portal = %s""".format(cls._table) + TRN.add(sql, [status, qiita_config.portal]) + return set(TRN.execute_fetchflatten()) @classmethod def create(cls, owner, name, description, parent=None, from_default=False): @@ -129,61 +130,66 @@ def create(cls, owner, name, description, parent=None, from_default=False): If True, use the default analysis to populate selected samples. Default False. 
""" - queue = "create_analysis" - conn_handler = SQLConnectionHandler() - conn_handler.create_queue(queue) - # TODO after demo: if exists() - # Needed since issue #292 exists - status_id = convert_to_id('in_construction', 'analysis_status', - 'status') - portal_id = convert_to_id(qiita_config.portal, 'portal_type', - 'portal') - if from_default: - # insert analysis and move samples into that new analysis - dflt_id = owner.default_analysis - sql = """INSERT INTO qiita.{0} - (email, name, description, analysis_status_id) - VALUES (%s, %s, %s, %s) - RETURNING analysis_id""".format(cls._table) - conn_handler.add_to_queue(queue, sql, ( - owner.id, name, description, status_id)) - # MAGIC NUMBER 3: command selection step - # needed so we skip the sample selection step - sql = """INSERT INTO qiita.analysis_workflow - (analysis_id, step) VALUES (%s, %s) - RETURNING %s""" - conn_handler.add_to_queue(queue, sql, ['{0}', 3, '{0}']) - sql = """UPDATE qiita.analysis_sample - SET analysis_id = %s - WHERE analysis_id = %s RETURNING %s""" - conn_handler.add_to_queue(queue, sql, ['{0}', dflt_id, '{0}']) - else: - # insert analysis information into table as "in construction" - sql = """INSERT INTO qiita.{0} - (email, name, description, analysis_status_id) - VALUES (%s, %s, %s, %s) - RETURNING analysis_id""".format(cls._table) - conn_handler.add_to_queue( - queue, sql, (owner.id, name, description, status_id)) - - # Add to both QIITA and given portal (if not QIITA) - sql = """INSERT INTO qiita.analysis_portal - (analysis_id, portal_type_id) - VALUES (%s, %s) RETURNING %s""" - args = [['{0}', portal_id, '{0}']] - if qiita_config.portal != 'QIITA': - qp_id = convert_to_id('QIITA', 'portal_type', 'portal') - args.append(['{0}', qp_id, '{0}']) - conn_handler.add_to_queue(queue, sql, args, many=True) - - # add parent if necessary - if parent: - sql = ("INSERT INTO qiita.analysis_chain (parent_id, child_id) " - "VALUES (%s, %s) RETURNING child_id") - conn_handler.add_to_queue(queue, sql, [parent.id, '{0}']) - - a_id = conn_handler.execute_queue(queue)[0] - return cls(a_id) + with TRN: + status_id = convert_to_id('in_construction', + 'analysis_status', 'status') + portal_id = convert_to_id(qiita_config.portal, 'portal_type', + 'portal') + + if from_default: + # insert analysis and move samples into that new analysis + dflt_id = owner.default_analysis + + # Get the analysis id placeholder + a_id_idx = TRN.index + a_id_ph = "{%s:0:0}" % a_id_idx + sql = """INSERT INTO qiita.{0} + (email, name, description, analysis_status_id) + VALUES (%s, %s, %s, %s) + RETURNING analysis_id""".format(cls._table) + TRN.add(sql, [owner.id, name, description, status_id]) + # MAGIC NUMBER 3: command selection step + # needed so we skip the sample selection step + sql = """INSERT INTO qiita.analysis_workflow + (analysis_id, step) + VALUES (%s, %s)""" + TRN.add(sql, [a_id_ph, 3]) + + sql = """UPDATE qiita.analysis_sample + SET analysis_id = %s + WHERE analysis_id = %s""" + TRN.add(sql, [a_id_ph, dflt_id]) + else: + # Get the analysis id placeholder + a_id_idx = TRN.index + a_id_ph = "{%s:0:0}" % a_id_idx + # insert analysis information into table as "in construction" + sql = """INSERT INTO qiita.{0} + (email, name, description, analysis_status_id) + VALUES (%s, %s, %s, %s) + RETURNING analysis_id""".format(cls._table) + TRN.add(sql, [owner.id, name, description, status_id]) + + # Add to both QIITA and given portal (if not QIITA) + sql = """INSERT INTO qiita.analysis_portal + (analysis_id, portal_type_id) + VALUES (%s, %s)""" + args = 
[[a_id_ph, portal_id]] + + if qiita_config.portal != 'QIITA': + qp_id = convert_to_id('QIITA', 'portal_type', 'portal') + args.append([a_id_ph, qp_id]) + TRN.add(sql, args, many=True) + + # add parent if necessary + if parent: + sql = """INSERT INTO qiita.analysis_chain + (parent_id, child_id) + VALUES (%s, %s)""" + TRN.add(sql, [parent.id, a_id_ph]) + + # The analysis id is in the `a_id_idx` query, first row, first elem + return cls(TRN.execute()[a_id_idx][0][0]) @classmethod def delete(cls, _id): @@ -199,41 +205,39 @@ def delete(cls, _id): QiitaDBUnknownIDError If the analysis id doesn't exist """ - # check if the analysis exist - if not cls.exists(_id): - raise QiitaDBUnknownIDError(_id, "analysis") - - queue = "delete_analysis_%d" % _id - conn_handler = SQLConnectionHandler() - conn_handler.create_queue(queue) + with TRN: + # check if the analysis exists + if not cls.exists(_id): + raise QiitaDBUnknownIDError(_id, "analysis") - sql = ("DELETE FROM qiita.analysis_filepath WHERE " - "{0} = {1}".format(cls._analysis_id_column, _id)) - conn_handler.add_to_queue(queue, sql) + sql = "DELETE FROM qiita.analysis_filepath WHERE {0} = %s".format( + cls._analysis_id_column) + args = [_id] + TRN.add(sql, args) - sql = ("DELETE FROM qiita.analysis_workflow WHERE " - "{0} = {1}".format(cls._analysis_id_column, _id)) - conn_handler.add_to_queue(queue, sql) + sql = "DELETE FROM qiita.analysis_workflow WHERE {0} = %s".format( + cls._analysis_id_column) + TRN.add(sql, args) - sql = ("DELETE FROM qiita.analysis_portal WHERE " - "{0} = {1}".format(cls._analysis_id_column, _id)) - conn_handler.add_to_queue(queue, sql) + sql = "DELETE FROM qiita.analysis_portal WHERE {0} = %s".format( + cls._analysis_id_column) + TRN.add(sql, args) - sql = ("DELETE FROM qiita.analysis_sample WHERE " - "{0} = {1}".format(cls._analysis_id_column, _id)) - conn_handler.add_to_queue(queue, sql) + sql = "DELETE FROM qiita.analysis_sample WHERE {0} = %s".format( + cls._analysis_id_column) + TRN.add(sql, args) - sql = ("DELETE FROM qiita.collection_analysis WHERE " - "{0} = {1}".format(cls._analysis_id_column, _id)) - conn_handler.add_to_queue(queue, sql) + sql = """DELETE FROM qiita.collection_analysis + WHERE {0} = %s""".format(cls._analysis_id_column) + TRN.add(sql, args) - # TODO: issue #1176 + # TODO: issue #1176 - sql = ("DELETE FROM qiita.{0} WHERE " - "{1} = {2}".format(cls._table, cls._analysis_id_column, _id)) - conn_handler.add_to_queue(queue, sql) + sql = """DELETE FROM qiita.{0} WHERE {1} = %s""".format( + cls._table, cls._analysis_id_column) + TRN.add(sql, args) - conn_handler.execute_queue(queue) + TRN.execute() @classmethod def exists(cls, analysis_id): @@ -249,15 +253,17 @@ def exists(cls, analysis_id): bool True if exists, false otherwise.
""" - conn_handler = SQLConnectionHandler() - - return conn_handler.execute_fetchone( - """SELECT EXISTS(SELECT * FROM qiita.{0} - JOIN qiita.analysis_portal USING (analysis_id) - JOIN qiita.portal_type USING (portal_type_id) - WHERE {1}=%s AND portal=%s)""".format( - cls._table, cls._analysis_id_column), - (analysis_id, qiita_config.portal))[0] + with TRN: + sql = """SELECT EXISTS( + SELECT * + FROM qiita.{0} + JOIN qiita.analysis_portal USING (analysis_id) + JOIN qiita.portal_type USING (portal_type_id) + WHERE {1}=%s + AND portal=%s)""".format(cls._table, + cls._analysis_id_column) + TRN.add(sql, [analysis_id, qiita_config.portal]) + return TRN.execute_fetchlast() # ---- Properties ---- @property @@ -269,10 +275,11 @@ def owner(self): str Name of the Analysis """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT email FROM qiita.{0} WHERE " - "analysis_id = %s".format(self._table)) - return conn_handler.execute_fetchone(sql, (self._id, ))[0] + with TRN: + sql = "SELECT email FROM qiita.{0} WHERE analysis_id = %s".format( + self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def name(self): @@ -283,10 +290,11 @@ def name(self): str Name of the Analysis """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT name FROM qiita.{0} WHERE " - "analysis_id = %s".format(self._table)) - return conn_handler.execute_fetchone(sql, (self._id, ))[0] + with TRN: + sql = "SELECT name FROM qiita.{0} WHERE analysis_id = %s".format( + self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def _portals(self): @@ -297,11 +305,13 @@ def _portals(self): str Name of the portal """ - conn_handler = SQLConnectionHandler() - sql = """SELECT portal FROM qiita.analysis_portal - JOIN qiita.portal_type USING (portal_type_id) - WHERE analysis_id = %s""".format(self._table) - return [x[0] for x in conn_handler.execute_fetchall(sql, [self._id])] + with TRN: + sql = """SELECT portal + FROM qiita.analysis_portal + JOIN qiita.portal_type USING (portal_type_id) + WHERE analysis_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @property def timestamp(self): @@ -312,18 +322,20 @@ def timestamp(self): datetime Timestamp of the Analysis """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT timestamp FROM qiita.{0} WHERE " - "analysis_id = %s".format(self._table)) - return conn_handler.execute_fetchone(sql, (self._id, ))[0] + with TRN: + sql = """SELECT timestamp FROM qiita.{0} + WHERE analysis_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def description(self): """Returns the description of the analysis""" - conn_handler = SQLConnectionHandler() - sql = ("SELECT description FROM qiita.{0} WHERE " - "analysis_id = %s".format(self._table)) - return conn_handler.execute_fetchone(sql, (self._id, ))[0] + with TRN: + sql = """SELECT description FROM qiita.{0} + WHERE analysis_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @description.setter def description(self, description): @@ -339,11 +351,12 @@ def description(self, description): QiitaDBStatusError Analysis is public """ - conn_handler = SQLConnectionHandler() - self._lock_check(conn_handler) - sql = ("UPDATE qiita.{0} SET description = %s WHERE " - "analysis_id = %s".format(self._table)) - conn_handler.execute(sql, (description, self._id)) + with TRN: + self._lock_check() + sql = """UPDATE qiita.{0} SET description = %s + WHERE analysis_id = 
%s""".format(self._table) + TRN.add(sql, [description, self._id]) + TRN.execute() @property def samples(self): @@ -354,14 +367,17 @@ def samples(self): dict Format is {processed_data_id: [sample_id, sample_id, ...]} """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT processed_data_id, sample_id FROM qiita.analysis_sample" - " WHERE analysis_id = %s ORDER BY processed_data_id") - ret_samples = defaultdict(list) - # turn into dict of samples keyed to processed_data_id - for pid, sample in conn_handler.execute_fetchall(sql, (self._id, )): - ret_samples[pid].append(sample) - return ret_samples + with TRN: + sql = """SELECT processed_data_id, sample_id + FROM qiita.analysis_sample + WHERE analysis_id = %s + ORDER BY processed_data_id""" + ret_samples = defaultdict(list) + TRN.add(sql, [self._id]) + # turn into dict of samples keyed to processed_data_id + for pid, sample in TRN.execute_fetchindex(): + ret_samples[pid].append(sample) + return ret_samples @property def dropped_samples(self): @@ -372,25 +388,28 @@ def dropped_samples(self): dict of sets Format is {processed_data_id: {sample_id, sample_id, ...}, ...} """ - bioms = self.biom_tables - if not bioms: - return {} - - # get all samples selected for the analysis, converting lists to - # sets for fast searching. Overhead less this way for large analyses - all_samples = {k: set(v) for k, v in viewitems(self.samples)} - - for biom, filepath in viewitems(bioms): - table = load_table(filepath) - # remove the samples from the sets as they are found in the table - proc_data_ids = set(sample['Processed_id'] - for sample in table.metadata()) - ids = set(table.ids()) - for proc_data_id in proc_data_ids: - all_samples[proc_data_id] = all_samples[proc_data_id] - ids - - # what's left are unprocessed samples, so return - return all_samples + with TRN: + bioms = self.biom_tables + if not bioms: + return {} + + # get all samples selected for the analysis, converting lists to + # sets for fast searching. 
This way overhead is lower + # for large analyses + all_samples = {k: set(v) for k, v in viewitems(self.samples)} + + for biom, filepath in viewitems(bioms): + table = load_table(filepath) + # remove the samples from the sets as they + # are found in the table + proc_data_ids = set(sample['Processed_id'] + for sample in table.metadata()) + ids = set(table.ids()) + for proc_data_id in proc_data_ids: + all_samples[proc_data_id] = all_samples[proc_data_id] - ids + + # what's left are unprocessed samples, so return + return all_samples @property def data_types(self): @@ -401,13 +420,15 @@ def data_types(self): list of str Data types in the analysis """ - sql = ("SELECT DISTINCT data_type from qiita.data_type d JOIN " - "qiita.processed_data p ON p.data_type_id = d.data_type_id " - "JOIN qiita.analysis_sample a ON p.processed_data_id = " - "a.processed_data_id WHERE a.analysis_id = %s ORDER BY " - "data_type") - conn_handler = SQLConnectionHandler() - return [x[0] for x in conn_handler.execute_fetchall(sql, (self._id, ))] + with TRN: + sql = """SELECT DISTINCT data_type + FROM qiita.data_type + JOIN qiita.processed_data USING (data_type_id) + JOIN qiita.analysis_sample USING (processed_data_id) + WHERE analysis_id = %s + ORDER BY data_type""" + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @property def shared_with(self): @@ -418,10 +439,11 @@ def shared_with(self): list of str User ids analysis is shared with """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT email FROM qiita.analysis_users WHERE " - "analysis_id = %s") - return [u[0] for u in conn_handler.execute_fetchall(sql, (self._id, ))] + with TRN: + sql = """SELECT email FROM qiita.analysis_users + WHERE analysis_id = %s""" + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @property def all_associated_filepath_ids(self): @@ -431,28 +453,22 @@ def all_associated_filepath_ids(self): ------- set of int """ - conn_handler = SQLConnectionHandler() - sql = """SELECT f.filepath_id - FROM qiita.filepath f JOIN - qiita.analysis_filepath af ON f.filepath_id = af.filepath_id - WHERE af.analysis_id = %s""" - filepaths = {row[0] - for row in conn_handler.execute_fetchall(sql, [self._id])} - - sql = """SELECT fp.filepath_id - FROM qiita.analysis_job aj - JOIN qiita.job j ON aj.job_id = j.job_id - JOIN qiita.job_results_filepath jrfp ON aj.job_id = jrfp.job_id - JOIN qiita.filepath fp ON jrfp.filepath_id = fp.filepath_id - WHERE aj.analysis_id = %s""" - - job_filepaths = {row[0] - for row in conn_handler.execute_fetchall(sql, - [self._id])} - - filepaths = filepaths.union(job_filepaths) - - return filepaths + with TRN: + sql = """SELECT filepath_id + FROM qiita.filepath + JOIN qiita.analysis_filepath USING (filepath_id) + WHERE analysis_id = %s""" + TRN.add(sql, [self._id]) + filepaths = set(TRN.execute_fetchflatten()) + + sql = """SELECT filepath_id + FROM qiita.analysis_job + JOIN qiita.job USING (job_id) + JOIN qiita.job_results_filepath USING (job_id) + JOIN qiita.filepath USING (filepath_id) + WHERE analysis_id = %s""" + TRN.add(sql, [self._id]) + return filepaths.union(TRN.execute_fetchflatten()) @property def biom_tables(self): @@ -463,20 +479,22 @@ def biom_tables(self): dict Dictionary in the form {data_type: full BIOM filepath} """ - conn_handler = SQLConnectionHandler() - fptypeid = convert_to_id("biom", "filepath_type") - sql = ("SELECT dt.data_type, f.filepath FROM qiita.filepath f JOIN " - "qiita.analysis_filepath af ON f.filepath_id = af.filepath_id " - "JOIN qiita.data_type dt ON dt.data_type_id =
af.data_type_id " - "WHERE af.analysis_id = %s AND f.filepath_type_id = %s") - tables = conn_handler.execute_fetchall(sql, (self._id, fptypeid)) - if not tables: - return {} - ret_tables = {} - _, base_fp = get_mountpoint(self._table)[0] - for fp in tables: - ret_tables[fp[0]] = join(base_fp, fp[1]) - return ret_tables + with TRN: + fptypeid = convert_to_id("biom", "filepath_type") + sql = """SELECT data_type, filepath + FROM qiita.filepath + JOIN qiita.analysis_filepath USING (filepath_id) + JOIN qiita.data_type USING (data_type_id) + WHERE analysis_id = %s AND filepath_type_id = %s""" + TRN.add(sql, [self._id, fptypeid]) + tables = TRN.execute_fetchindex() + if not tables: + return {} + ret_tables = {} + _, base_fp = get_mountpoint(self._table)[0] + for fp in tables: + ret_tables[fp[0]] = join(base_fp, fp[1]) + return ret_tables @property def mapping_file(self): @@ -487,43 +505,64 @@ def mapping_file(self): str or None full filepath to the mapping file or None if not generated """ - conn_handler = SQLConnectionHandler() - fptypeid = convert_to_id("plain_text", "filepath_type") - sql = ("SELECT f.filepath FROM qiita.filepath f JOIN " - "qiita.analysis_filepath af ON f.filepath_id = af.filepath_id " - "WHERE af.analysis_id = %s AND f.filepath_type_id = %s") - mapping_fp = conn_handler.execute_fetchone(sql, (self._id, fptypeid)) - if not mapping_fp: - return None - - _, base_fp = get_mountpoint(self._table)[0] - return join(base_fp, mapping_fp[0]) + with TRN: + fptypeid = convert_to_id("plain_text", "filepath_type") + sql = """SELECT filepath + FROM qiita.filepath + JOIN qiita.analysis_filepath USING (filepath_id) + WHERE analysis_id = %s AND filepath_type_id = %s""" + TRN.add(sql, [self._id, fptypeid]) + mapping_fp = TRN.execute_fetchindex() + if not mapping_fp: + return None + + _, base_fp = get_mountpoint(self._table)[0] + return join(base_fp, mapping_fp[0][0]) @property def step(self): - conn_handler = SQLConnectionHandler() - self._lock_check(conn_handler) - sql = "SELECT step from qiita.analysis_workflow WHERE analysis_id = %s" - try: - return conn_handler.execute_fetchone(sql, (self._id,))[0] - except TypeError: - raise ValueError("Step not set yet!") + """Returns the current step of the analysis + + Returns + ------- + str + The current step of the analysis + + Raises + ------ + ValueError + If the step is not set up + """ + with TRN: + self._lock_check() + sql = """SELECT step FROM qiita.analysis_workflow + WHERE analysis_id = %s""" + TRN.add(sql, [self._id]) + try: + return TRN.execute_fetchlast() + except IndexError: + raise ValueError("Step not set yet!") @step.setter def step(self, value): - conn_handler = SQLConnectionHandler() - self._lock_check(conn_handler) - sql = ("SELECT EXISTS(SELECT analysis_id from qiita.analysis_workflow " - "WHERE analysis_id = %s)") - step_exists = conn_handler.execute_fetchone(sql, (self._id,))[0] - - if step_exists: - sql = ("UPDATE qiita.analysis_workflow SET step = %s WHERE " - "analysis_id = %s") - else: - sql = ("INSERT INTO qiita.analysis_workflow (step, analysis_id) " - "VALUES (%s, %s)") - conn_handler.execute(sql, (value, self._id)) + with TRN: + self._lock_check() + sql = """SELECT EXISTS( + SELECT analysis_id + FROM qiita.analysis_workflow + WHERE analysis_id = %s)""" + TRN.add(sql, [self._id]) + step_exists = TRN.execute_fetchlast() + + if step_exists: + sql = """UPDATE qiita.analysis_workflow SET step = %s + WHERE analysis_id = %s""" + else: + sql = """INSERT INTO qiita.analysis_workflow + (step, analysis_id) + VALUES (%s, %s)""" + 
TRN.add(sql, [value, self._id]) + TRN.execute() @property def jobs(self): @@ -534,11 +573,11 @@ def jobs(self): list of ints Job ids for jobs in analysis. Empty list if no jobs attached. """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT job_id FROM qiita.analysis_job WHERE " - "analysis_id = %s".format(self._table)) - job_ids = conn_handler.execute_fetchall(sql, (self._id, )) - return [job_id[0] for job_id in job_ids] + with TRN: + sql = """SELECT job_id FROM qiita.analysis_job + WHERE analysis_id = %s""" + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @property def pmid(self): @@ -549,11 +588,11 @@ def pmid(self): str or None returns the PMID or None if none is attached """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT pmid FROM qiita.{0} WHERE " - "analysis_id = %s".format(self._table)) - pmid = conn_handler.execute_fetchone(sql, (self._id, ))[0] - return pmid + with TRN: + sql = "SELECT pmid FROM qiita.{0} WHERE analysis_id = %s".format( + self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @pmid.setter def pmid(self, pmid): @@ -573,11 +612,12 @@ def pmid(self, pmid): ----- An analysis should only ever have one PMID attached to it. """ - conn_handler = SQLConnectionHandler() - self._lock_check(conn_handler) - sql = ("UPDATE qiita.{0} SET pmid = %s WHERE " - "analysis_id = %s".format(self._table)) - conn_handler.execute(sql, (pmid, self._id)) + with TRN: + self._lock_check() + sql = """UPDATE qiita.{0} SET pmid = %s + WHERE analysis_id = %s""".format(self._table) + TRN.add(sql, [pmid, self._id]) + TRN.execute() # @property # def parent(self): @@ -602,12 +642,13 @@ def has_access(self, user): bool Whether user has access to analysis or not """ - # if admin or superuser, just return true - if user.level in {'superuser', 'admin'}: - return True + with TRN: + # if admin or superuser, just return true + if user.level in {'superuser', 'admin'}: + return True - return self._id in Analysis.get_by_status('public') | \ - user.private_analyses | user.shared_analyses + return self._id in Analysis.get_by_status('public') | \ + user.private_analyses | user.shared_analyses def summary_data(self): """Return number of studies, processed data, and samples selected @@ -617,14 +658,16 @@ def summary_data(self): dict counts keyed to their relevant type """ - sql = """SELECT COUNT(DISTINCT study_id) as studies, - COUNT(DISTINCT processed_data_id) as processed_data, - COUNT(DISTINCT sample_id) as samples - FROM qiita.study_processed_data - JOIN qiita.analysis_sample USING (processed_data_id) - WHERE analysis_id = %s""" - conn_handler = SQLConnectionHandler() - return dict(conn_handler.execute_fetchone(sql, [self._id])) + with TRN: + sql = """SELECT + COUNT(DISTINCT study_id) as studies, + COUNT(DISTINCT processed_data_id) as processed_data, + COUNT(DISTINCT sample_id) as samples + FROM qiita.study_processed_data + JOIN qiita.analysis_sample USING (processed_data_id) + WHERE analysis_id = %s""" + TRN.add(sql, [self._id]) + return dict(TRN.execute_fetchindex()[0]) def share(self, user): """Share the analysis with another user Parameters ---------- user: User object The user to share the analysis with """ - conn_handler = SQLConnectionHandler() - self._lock_check(conn_handler) - - # Make sure the analysis is not already shared with the given user - if user.id in self.shared_with: - return + with TRN: + self._lock_check() - sql = ("INSERT INTO qiita.analysis_users (analysis_id, email) VALUES " - "(%s, %s)") + #
Make sure the analysis is not already shared with the given user + if user.id in self.shared_with: + return - conn_handler.execute(sql, (self._id, user.id)) + sql = """INSERT INTO qiita.analysis_users (analysis_id, email) + VALUES (%s, %s)""" + TRN.add(sql, [self._id, user.id]) + TRN.execute() def unshare(self, user): """Unshare the analysis with another user @@ -654,13 +697,13 @@ def unshare(self, user): user: User object The user to unshare the analysis with """ - conn_handler = SQLConnectionHandler() - self._lock_check(conn_handler) + with TRN: + self._lock_check() - sql = ("DELETE FROM qiita.analysis_users WHERE analysis_id = %s AND " - "email = %s") - - conn_handler.execute(sql, (self._id, user.id)) + sql = """DELETE FROM qiita.analysis_users + WHERE analysis_id = %s AND email = %s""" + TRN.add(sql, [self._id, user.id]) + TRN.execute() def add_samples(self, samples): """Adds samples to the analysis @@ -671,22 +714,24 @@ def add_samples(self, samples): samples and the processed data id they come from in form {processed_data_id: [sample1, sample2, ...], ...} """ - conn_handler = SQLConnectionHandler() - self._lock_check(conn_handler) - - for pid, samps in viewitems(samples): - # get previously selected samples for pid and filter them out - sql = """SELECT sample_id FROM qiita.analysis_sample - WHERE processed_data_id = %s and analysis_id = %s""" - prev_selected = [x[0] for x in - conn_handler.execute_fetchall(sql, - (pid, self._id))] - - select = set(samps).difference(prev_selected) - sql = ("INSERT INTO qiita.analysis_sample " - "(analysis_id, processed_data_id, sample_id) VALUES " - "({}, %s, %s)".format(self._id)) - conn_handler.executemany(sql, [x for x in product([pid], select)]) + with TRN: + self._lock_check() + + for pid, samps in viewitems(samples): + # get previously selected samples for pid and filter them out + sql = """SELECT sample_id + FROM qiita.analysis_sample + WHERE processed_data_id = %s AND analysis_id = %s""" + TRN.add(sql, [pid, self._id]) + prev_selected = TRN.execute_fetchflatten() + + select = set(samps).difference(prev_selected) + sql = """INSERT INTO qiita.analysis_sample + (analysis_id, processed_data_id, sample_id) + VALUES (%s, %s, %s)""" + args = [[self._id, pid, s] for s in select] + TRN.add(sql, args, many=True) + TRN.execute() def remove_samples(self, proc_data=None, samples=None): """Removes samples from the analysis @@ -709,29 +754,32 @@ def remove_samples(self, proc_data=None, samples=None): If both are passed, the given samples are removed from the given processed data ids """ - conn_handler = SQLConnectionHandler() - self._lock_check(conn_handler) - if proc_data and samples: - sql = ("DELETE FROM qiita.analysis_sample WHERE analysis_id = %s " - "AND processed_data_id = %s AND sample_id = %s") - remove = [] - # build tuples for what samples to remove from what processed data - for proc_id in proc_data: - for sample_id in samples: - remove.append((self._id, proc_id, sample_id)) - elif proc_data: - sql = ("DELETE FROM qiita.analysis_sample WHERE analysis_id = %s " - "AND processed_data_id = %s") - remove = [(self._id, p) for p in proc_data] - elif samples: - sql = ("DELETE FROM qiita.analysis_sample WHERE analysis_id = %s " - "AND sample_id = %s") - remove = [(self._id, s) for s in samples] - else: - raise IncompetentQiitaDeveloperError( - "Must provide list of samples and/or proc_data for removal!") - - conn_handler.executemany(sql, remove) + with TRN: + self._lock_check() + if proc_data and samples: + sql = """DELETE FROM qiita.analysis_sample + 
WHERE analysis_id = %s + AND processed_data_id = %s + AND sample_id = %s""" + # build tuples for what samples to remove from what + # processed data + args = [[self._id, p, s] + for p, s in product(proc_data, samples)] + elif proc_data: + sql = """DELETE FROM qiita.analysis_sample + WHERE analysis_id = %s AND processed_data_id = %s""" + args = [[self._id, p] for p in proc_data] + elif samples: + sql = """DELETE FROM qiita.analysis_sample + WHERE analysis_id = %s AND sample_id = %s""" + args = [[self._id, s] for s in samples] + else: + raise IncompetentQiitaDeveloperError( + "Must provide list of samples and/or proc_data for " + "removal") + + TRN.add(sql, args, many=True) + TRN.execute() def build_files(self, rarefaction_depth=None): """Builds biom and mapping files needed for analysis @@ -754,127 +802,136 @@ def build_files(self, rarefaction_depth=None): Creates biom tables for each requested data type Creates mapping file for requested samples """ - if rarefaction_depth is not None: - if type(rarefaction_depth) is not int: - raise TypeError("rarefaction_depth must be in integer") - if rarefaction_depth <= 0: - raise ValueError("rarefaction_depth must be greater than 0") + with TRN: + if rarefaction_depth is not None: + if type(rarefaction_depth) is not int: + raise TypeError("rarefaction_depth must be an integer") + if rarefaction_depth <= 0: + raise ValueError( + "rarefaction_depth must be greater than 0") - samples = self._get_samples() - self._build_mapping_file(samples) - self._build_biom_tables(samples, rarefaction_depth) + samples = self._get_samples() + self._build_mapping_file(samples) + self._build_biom_tables(samples, rarefaction_depth) def _get_samples(self): """Retrieves dict of samples to proc_data_id for the analysis""" - conn_handler = SQLConnectionHandler() - sql = ("SELECT processed_data_id, array_agg(sample_id ORDER BY " - "sample_id) FROM qiita.analysis_sample WHERE analysis_id = %s " - "GROUP BY processed_data_id") - return dict(conn_handler.execute_fetchall(sql, [self._id])) + with TRN: + sql = """SELECT processed_data_id, array_agg( + sample_id ORDER BY sample_id) + FROM qiita.analysis_sample + WHERE analysis_id = %s + GROUP BY processed_data_id""" + TRN.add(sql, [self._id]) + return dict(TRN.execute_fetchindex()) def _build_biom_tables(self, samples, rarefaction_depth): """Build tables and add them to the analysis""" - # filter and combine all study BIOM tables needed for each data type - new_tables = {dt: None for dt in self.data_types} - base_fp = get_work_base_dir() - for pid, samps in viewitems(samples): - # one biom table attached to each processed data object - proc_data = ProcessedData(pid) - proc_data_fp = proc_data.get_filepaths()[0][1] - table_fp = join(base_fp, proc_data_fp) - table = load_table(table_fp) - # HACKY WORKAROUND FOR DEMO.
Issue # 246 - # make sure samples not in biom table are not filtered for - table_samps = set(table.ids()) - filter_samps = table_samps.intersection(samps) - # add the metadata column for study the samples come from - study_meta = {'Study': Study(proc_data.study).title, - 'Processed_id': proc_data.id} - samples_meta = {sid: study_meta for sid in filter_samps} - # filter for just the wanted samples and merge into new table - # this if/else setup avoids needing a blank table to start merges - table.filter(filter_samps, axis='sample', inplace=True) - table.add_metadata(samples_meta, axis='sample') - data_type = proc_data.data_type() - if new_tables[data_type] is None: - new_tables[data_type] = table - else: - new_tables[data_type] = new_tables[data_type].merge(table) - - # add the new tables to the analysis - _, base_fp = get_mountpoint(self._table)[0] - for dt, biom_table in viewitems(new_tables): - # rarefy, if specified - if rarefaction_depth is not None: - biom_table = biom_table.subsample(rarefaction_depth) - # write out the file - biom_fp = join(base_fp, "%d_analysis_%s.biom" % (self._id, dt)) - with biom_open(biom_fp, 'w') as f: - biom_table.to_hdf5(f, "Analysis %s Datatype %s" % - (self._id, dt)) - self._add_file("%d_analysis_%s.biom" % (self._id, dt), - "biom", data_type=dt) + with TRN: + # filter and combine all study BIOM tables needed for + # each data type + new_tables = {dt: None for dt in self.data_types} + base_fp = get_work_base_dir() + for pid, samps in viewitems(samples): + # one biom table attached to each processed data object + proc_data = ProcessedData(pid) + proc_data_fp = proc_data.get_filepaths()[0][1] + table_fp = join(base_fp, proc_data_fp) + table = load_table(table_fp) + # HACKY WORKAROUND FOR DEMO. Issue # 246 + # make sure samples not in biom table are not filtered for + table_samps = set(table.ids()) + filter_samps = table_samps.intersection(samps) + # add the metadata column for study the samples come from + study_meta = {'Study': Study(proc_data.study).title, + 'Processed_id': proc_data.id} + samples_meta = {sid: study_meta for sid in filter_samps} + # filter for just the wanted samples and merge into new table + # this if/else setup avoids needing a blank table to + # start merges + table.filter(filter_samps, axis='sample', inplace=True) + table.add_metadata(samples_meta, axis='sample') + data_type = proc_data.data_type() + if new_tables[data_type] is None: + new_tables[data_type] = table + else: + new_tables[data_type] = new_tables[data_type].merge(table) + + # add the new tables to the analysis + _, base_fp = get_mountpoint(self._table)[0] + for dt, biom_table in viewitems(new_tables): + # rarefy, if specified + if rarefaction_depth is not None: + biom_table = biom_table.subsample(rarefaction_depth) + # write out the file + biom_fp = join(base_fp, "%d_analysis_%s.biom" % (self._id, dt)) + with biom_open(biom_fp, 'w') as f: + biom_table.to_hdf5(f, "Analysis %s Datatype %s" % + (self._id, dt)) + self._add_file("%d_analysis_%s.biom" % (self._id, dt), + "biom", data_type=dt) def _build_mapping_file(self, samples): """Builds the combined mapping file for all samples Code modified slightly from qiime.util.MetadataMap.__add__""" - conn_handler = SQLConnectionHandler() - all_sample_ids = set() - sql = """SELECT filepath_id, filepath - FROM qiita.filepath - JOIN qiita.prep_template_filepath USING (filepath_id) - JOIN qiita.prep_template_preprocessed_data - USING (prep_template_id) - JOIN qiita.preprocessed_processed_data - USING (preprocessed_data_id) - JOIN 
qiita.filepath_type USING (filepath_type_id) - WHERE processed_data_id = %s - AND filepath_type = 'qiime_map' - ORDER BY filepath_id DESC""" - _id, fp = get_mountpoint('templates')[0] - to_concat = [] - - for pid, samples in viewitems(samples): - if len(samples) != len(set(samples)): - duplicates = find_duplicates(samples) - raise QiitaDBError("Duplicate sample ids found: %s" - % ', '.join(duplicates)) - # Get the QIIME mapping file - qiime_map_fp = conn_handler.execute_fetchall(sql, (pid,))[0][1] - # Parse the mapping file - qiime_map = pd.read_csv( - join(fp, qiime_map_fp), sep='\t', keep_default_na=False, - na_values=['unknown'], index_col=False, - converters=defaultdict(lambda: str)) - qiime_map.set_index('#SampleID', inplace=True, drop=True) - qiime_map = qiime_map.loc[samples] - - duplicates = all_sample_ids.intersection(qiime_map.index) - if duplicates or len(samples) != len(set(samples)): - # Duplicate samples so raise error - raise QiitaDBError("Duplicate sample ids found: %s" - % ', '.join(duplicates)) - all_sample_ids.update(qiime_map.index) - to_concat.append(qiime_map) - - merged_map = pd.concat(to_concat) - - cols = merged_map.columns.values.tolist() - cols.remove('BarcodeSequence') - cols.remove('LinkerPrimerSequence') - cols.remove('Description') - new_cols = ['BarcodeSequence', 'LinkerPrimerSequence'] - new_cols.extend(cols) - new_cols.append('Description') - merged_map = merged_map[new_cols] - - # Save the mapping file - _, base_fp = get_mountpoint(self._table)[0] - mapping_fp = join(base_fp, "%d_analysis_mapping.txt" % self._id) - merged_map.to_csv(mapping_fp, index_label='#SampleID', - na_rep='unknown', sep='\t') - - self._add_file("%d_analysis_mapping.txt" % self._id, "plain_text") + with TRN: + all_sample_ids = set() + sql = """SELECT filepath_id, filepath + FROM qiita.filepath + JOIN qiita.prep_template_filepath USING (filepath_id) + JOIN qiita.prep_template_preprocessed_data + USING (prep_template_id) + JOIN qiita.preprocessed_processed_data + USING (preprocessed_data_id) + JOIN qiita.filepath_type USING (filepath_type_id) + WHERE processed_data_id = %s + AND filepath_type = 'qiime_map' + ORDER BY filepath_id DESC""" + _id, fp = get_mountpoint('templates')[0] + to_concat = [] + + for pid, samples in viewitems(samples): + if len(samples) != len(set(samples)): + duplicates = find_duplicates(samples) + raise QiitaDBError("Duplicate sample ids found: %s" + % ', '.join(duplicates)) + # Get the QIIME mapping file + TRN.add(sql, [pid]) + qiime_map_fp = TRN.execute_fetchindex()[0][1] + # Parse the mapping file + qiime_map = pd.read_csv( + join(fp, qiime_map_fp), sep='\t', keep_default_na=False, + na_values=['unknown'], index_col=False, + converters=defaultdict(lambda: str)) + qiime_map.set_index('#SampleID', inplace=True, drop=True) + qiime_map = qiime_map.loc[samples] + + duplicates = all_sample_ids.intersection(qiime_map.index) + if duplicates or len(samples) != len(set(samples)): + # Duplicate samples so raise error + raise QiitaDBError("Duplicate sample ids found: %s" + % ', '.join(duplicates)) + all_sample_ids.update(qiime_map.index) + to_concat.append(qiime_map) + + merged_map = pd.concat(to_concat) + + cols = merged_map.columns.values.tolist() + cols.remove('BarcodeSequence') + cols.remove('LinkerPrimerSequence') + cols.remove('Description') + new_cols = ['BarcodeSequence', 'LinkerPrimerSequence'] + new_cols.extend(cols) + new_cols.append('Description') + merged_map = merged_map[new_cols] + + # Save the mapping file + _, base_fp = get_mountpoint(self._table)[0] + 
mapping_fp = join(base_fp, "%d_analysis_mapping.txt" % self._id) + merged_map.to_csv(mapping_fp, index_label='#SampleID', + na_rep='unknown', sep='\t') + + self._add_file("%d_analysis_mapping.txt" % self._id, "plain_text") def _add_file(self, filename, filetype, data_type=None): """adds analysis item to database Parameters ---------- filename : str filename to add to analysis filetype : {plain_text, biom} data_type : str, optional """ - conn_handler = SQLConnectionHandler() - - filetype_id = convert_to_id(filetype, 'filepath_type') - _, mp = get_mountpoint('analysis')[0] - fpid = insert_filepaths([ - (join(mp, filename), filetype_id)], -1, 'analysis', 'filepath', - conn_handler, move_files=False)[0] - - col = "" - dtid = "" - if data_type: - col = ",data_type_id" - dtid = ",%d" % convert_to_id(data_type, "data_type") - - sql = ("INSERT INTO qiita.analysis_filepath (analysis_id, filepath_id" - "{0}) VALUES (%s, %s{1})".format(col, dtid)) - conn_handler.execute(sql, (self._id, fpid)) + with TRN: + filetype_id = convert_to_id(filetype, 'filepath_type') + _, mp = get_mountpoint('analysis')[0] + fpid = insert_filepaths([ + (join(mp, filename), filetype_id)], -1, 'analysis', 'filepath', + move_files=False)[0] + + col = "" + dtid = "" + if data_type: + col = ", data_type_id" + dtid = ", %d" % convert_to_id(data_type, "data_type") + + sql = """INSERT INTO qiita.analysis_filepath + (analysis_id, filepath_id{0}) + VALUES (%s, %s{1})""".format(col, dtid) + TRN.add(sql, [self._id, fpid]) + TRN.execute() class Collection(QiitaStatusObject): @@ -936,7 +994,7 @@ class Collection(QiitaStatusObject): _highlight_table = "collection_job" _share_table = "collection_users" - def _status_setter_checks(self, conn_handler): + def _status_setter_checks(self): r"""Perform a check to make sure not setting status away from public """ if self.check_status(("public", )): @@ -955,11 +1013,11 @@ def create(cls, owner, name, description=None): description : str, optional Brief description of the collection's overarching goal """ - conn_handler = SQLConnectionHandler() - - sql = ("INSERT INTO qiita.{0} (email, name, description) " - "VALUES (%s, %s, %s)".format(cls._table)) - conn_handler.execute(sql, [owner.id, name, description]) + with TRN: + sql = """INSERT INTO qiita.{0} (email, name, description) + VALUES (%s, %s, %s)""".format(cls._table) + TRN.add(sql, [owner.id, name, description]) + TRN.execute() @classmethod def delete(cls, id_): @@ -975,81 +1033,85 @@ def delete(cls, id_): QiitaDBStatusError Trying to delete a public collection """ - conn_handler = SQLConnectionHandler() - if cls(id_).status == "public": - raise QiitaDBStatusError("Can't delete public collection!") - - queue = "remove_collection_%d" % id_ - conn_handler.create_queue(queue) + with TRN: + if cls(id_).status == "public": + raise QiitaDBStatusError("Can't delete public collection!") - for table in (cls._analysis_table, cls._highlight_table, - cls._share_table, cls._table): - conn_handler.add_to_queue( - queue, "DELETE FROM qiita.{0} WHERE " - "collection_id = %s".format(table), [id_]) + sql = "DELETE FROM qiita.{0} WHERE collection_id = %s" + for table in (cls._analysis_table, cls._highlight_table, + cls._share_table, cls._table): + TRN.add(sql.format(table), [id_]) - conn_handler.execute_queue(queue) + TRN.execute() # --- Properties --- @property def name(self): - sql = ("SELECT name FROM qiita.{0} WHERE " - "collection_id = %s".format(self._table)) - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone(sql,
[self._id])[0] + with TRN: + sql = "SELECT name FROM qiita.{0} WHERE collection_id = %s".format( + self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @name.setter def name(self, value): - conn_handler = SQLConnectionHandler() - self._status_setter_checks(conn_handler) + with TRN: + self._status_setter_checks() - sql = ("UPDATE qiita.{0} SET name = %s WHERE " - "collection_id = %s".format(self._table)) - conn_handler.execute(sql, [value, self._id]) + sql = """UPDATE qiita.{0} SET name = %s + WHERE collection_id = %s""".format(self._table) + TRN.add(sql, [value, self._id]) + TRN.execute() @property def description(self): - sql = ("SELECT description FROM qiita.{0} WHERE " - "collection_id = %s".format(self._table)) - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone(sql, [self._id])[0] + with TRN: + sql = """SELECT description FROM qiita.{0} + WHERE collection_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @description.setter def description(self, value): - conn_handler = SQLConnectionHandler() - self._status_setter_checks(conn_handler) + with TRN: + self._status_setter_checks() - sql = ("UPDATE qiita.{0} SET description = %s WHERE " - "collection_id = %s".format(self._table)) - conn_handler.execute(sql, [value, self._id]) + sql = """UPDATE qiita.{0} SET description = %s + WHERE collection_id = %s""".format(self._table) + TRN.add(sql, [value, self._id]) + TRN.execute() @property def owner(self): - sql = ("SELECT email FROM qiita.{0} WHERE " - "collection_id = %s".format(self._table)) - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone(sql, [self._id])[0] + with TRN: + sql = """SELECT email FROM qiita.{0} + WHERE collection_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def analyses(self): - sql = ("SELECT analysis_id FROM qiita.{0} WHERE " - "collection_id = %s".format(self._analysis_table)) - conn_handler = SQLConnectionHandler() - return [x[0] for x in conn_handler.execute_fetchall(sql, [self._id])] + with TRN: + sql = """SELECT analysis_id FROM qiita.{0} + WHERE collection_id = %s""".format(self._analysis_table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @property def highlights(self): - sql = ("SELECT job_id FROM qiita.{0} WHERE " - "collection_id = %s".format(self._highlight_table)) - conn_handler = SQLConnectionHandler() - return [x[0] for x in conn_handler.execute_fetchall(sql, [self._id])] + with TRN: + sql = """SELECT job_id FROM qiita.{0} + WHERE collection_id = %s""".format(self._highlight_table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @property def shared_with(self): - sql = ("SELECT email FROM qiita.{0} WHERE " - "collection_id = %s".format(self._share_table)) - conn_handler = SQLConnectionHandler() - return [x[0] for x in conn_handler.execute_fetchall(sql, [self._id])] + with TRN: + sql = """SELECT email FROM qiita.{0} + WHERE collection_id = %s""".format(self._share_table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() # --- Functions --- def add_analysis(self, analysis): @@ -1059,12 +1121,13 @@ def add_analysis(self, analysis): ---------- analysis : Analysis object """ - conn_handler = SQLConnectionHandler() - self._status_setter_checks(conn_handler) + with TRN: + self._status_setter_checks() - sql = ("INSERT INTO qiita.{0} (analysis_id, collection_id) " - "VALUES (%s, %s)".format(self._analysis_table)) - conn_handler.execute(sql, [analysis.id, 
self._id]) + sql = """INSERT INTO qiita.{0} (analysis_id, collection_id) + VALUES (%s, %s)""".format(self._analysis_table) + TRN.add(sql, [analysis.id, self._id]) + TRN.execute() def remove_analysis(self, analysis): """Remove an analysis from the collection object @@ -1073,13 +1136,14 @@ def remove_analysis(self, analysis): ---------- analysis : Analysis object """ - conn_handler = SQLConnectionHandler() - self._status_setter_checks(conn_handler) + with TRN: + self._status_setter_checks() - sql = ("DELETE FROM qiita.{0} WHERE analysis_id = %s AND " - "collection_id = %s".format(self._analysis_table)) - conn_handler = SQLConnectionHandler() - conn_handler.execute(sql, [analysis.id, self._id]) + sql = """DELETE FROM qiita.{0} + WHERE analysis_id = %s + AND collection_id = %s""".format(self._analysis_table) + TRN.add(sql, [analysis.id, self._id]) + TRN.execute() def highlight_job(self, job): """Marks a job as important to the collection @@ -1088,13 +1152,13 @@ def highlight_job(self, job): ---------- job : Job object """ - conn_handler = SQLConnectionHandler() - self._status_setter_checks(conn_handler) + with TRN: + self._status_setter_checks() - sql = ("INSERT INTO qiita.{0} (job_id, collection_id) " - "VALUES (%s, %s)".format(self._highlight_table)) - conn_handler = SQLConnectionHandler() - conn_handler.execute(sql, [job.id, self._id]) + sql = """INSERT INTO qiita.{0} (job_id, collection_id) + VALUES (%s, %s)""".format(self._highlight_table) + TRN.add(sql, [job.id, self._id]) + TRN.execute() def remove_highlight(self, job): """Removes job importance from the collection @@ -1103,13 +1167,14 @@ def remove_highlight(self, job): ---------- job : Job object """ - conn_handler = SQLConnectionHandler() - self._status_setter_checks(conn_handler) + with TRN: + self._status_setter_checks() - sql = ("DELETE FROM qiita.{0} WHERE job_id = %s AND " - "collection_id = %s".format(self._highlight_table)) - conn_handler = SQLConnectionHandler() - conn_handler.execute(sql, [job.id, self._id]) + sql = """DELETE FROM qiita.{0} + WHERE job_id = %s + AND collection_id = %s""".format(self._highlight_table) + TRN.add(sql, [job.id, self._id]) + TRN.execute() def share(self, user): """Shares the collection with another user @@ -1118,13 +1183,13 @@ def share(self, user): ---------- user : User object """ - conn_handler = SQLConnectionHandler() - self._status_setter_checks(conn_handler) + with TRN: + self._status_setter_checks() - sql = ("INSERT INTO qiita.{0} (email, collection_id) " - "VALUES (%s, %s)".format(self._share_table)) - conn_handler = SQLConnectionHandler() - conn_handler.execute(sql, [user.id, self._id]) + sql = """INSERT INTO qiita.{0} (email, collection_id) + VALUES (%s, %s)""".format(self._share_table) + TRN.add(sql, [user.id, self._id]) + TRN.execute() def unshare(self, user): """Unshares the collection with another user @@ -1133,10 +1198,11 @@ def unshare(self, user): ---------- user : User object """ - conn_handler = SQLConnectionHandler() - self._status_setter_checks(conn_handler) - - sql = ("DELETE FROM qiita.{0} WHERE " - "email = %s AND collection_id = %s".format(self._share_table)) - conn_handler = SQLConnectionHandler() - conn_handler.execute(sql, [user.id, self._id]) + with TRN: + self._status_setter_checks() + + sql = """DELETE FROM qiita.{0} + WHERE email = %s + AND collection_id = %s""".format(self._share_table) + TRN.add(sql, [user.id, self._id]) + TRN.execute() diff --git a/qiita_db/base.py b/qiita_db/base.py index 35cb0208d..54d4fe65e 100644 --- a/qiita_db/base.py +++ b/qiita_db/base.py 
@@ -26,9 +26,10 @@ # ----------------------------------------------------------------------------- from __future__ import division + from qiita_core.exceptions import IncompetentQiitaDeveloperError from qiita_core.qiita_settings import qiita_config -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN from .exceptions import (QiitaDBNotImplementedError, QiitaDBUnknownIDError, QiitaDBError) @@ -131,11 +132,12 @@ def _check_id(self, id_): the other classes. However, still defining here as there is only one subclass that doesn't follow this convention and it can override this. """ - conn_handler = SQLConnectionHandler() - - return conn_handler.execute_fetchone( - "SELECT EXISTS(SELECT * FROM qiita.{0} WHERE " - "{0}_id=%s)".format(self._table), (id_, ))[0] + with TRN: + sql = """SELECT EXISTS( + SELECT * FROM qiita.{0} + WHERE {0}_id=%s)""".format(self._table) + TRN.add(sql, [id_]) + return TRN.execute_fetchlast() def _check_portal(self, id_): """Checks that object is accessible in current portal @@ -149,15 +151,15 @@ def _check_portal(self, id_): # assume not portal limited object return True - conn_handler = SQLConnectionHandler() - - return conn_handler.execute_fetchone( - """SELECT EXISTS( - SELECT * from qiita.{0} - JOIN qiita.portal_type using (portal_type_id) - WHERE {1}_id = %s AND portal = %s)""".format( - self._portal_table, self._table), - [id_, qiita_config.portal])[0] + with TRN: + sql = """SELECT EXISTS( + SELECT * + FROM qiita.{0} + JOIN qiita.portal_type USING (portal_type_id) + WHERE {1}_id = %s AND portal = %s + )""".format(self._portal_table, self._table) + TRN.add(sql, [id_, qiita_config.portal]) + return TRN.execute_fetchlast() def __init__(self, id_): r"""Initializes the object @@ -171,14 +173,15 @@ def __init__(self, id_): QiitaDBUnknownIDError If `id_` does not correspond to any object """ - self._check_subclass() - if not self._check_id(id_): - raise QiitaDBUnknownIDError(id_, self._table) + with TRN: + self._check_subclass() + if not self._check_id(id_): + raise QiitaDBUnknownIDError(id_, self._table) - if not self._check_portal(id_): - raise QiitaDBError("%s with id %d inaccessible in current portal: " - "%s" % (self.__class__.__name__, id_, - qiita_config.portal)) + if not self._check_portal(id_): + raise QiitaDBError( + "%s with id %d inaccessible in current portal: %s" + % (self.__class__.__name__, id_, qiita_config.portal)) self._id = id_ @@ -216,18 +219,16 @@ class QiitaStatusObject(QiitaObject): @property def status(self): r"""String with the current status of the analysis""" - # Check that self._table is actually defined - self._check_subclass() - # Get the DB status of the object - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT status FROM qiita.{0}_status WHERE {0}_status_id = " - "(SELECT {0}_status_id FROM qiita.{0} WHERE " - "{0}_id = %s)".format(self._table), - (self._id, ))[0] - - def _status_setter_checks(self, conn_handler): + with TRN: + sql = """SELECT status FROM qiita.{0}_status + WHERE {0}_status_id = ( + SELECT {0}_status_id FROM qiita.{0} + WHERE {0}_id = %s)""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() + + def _status_setter_checks(self): r"""Perform any extra checks that needed to be done before setting the object status on the database. 
Should be overwritten by the subclasses """ @@ -242,18 +243,18 @@ def status(self, status): status: str The new object status """ - # Check that self._table is actually defined - self._check_subclass() - - # Perform any extra checks needed before we update the status in the DB - conn_handler = SQLConnectionHandler() - self._status_setter_checks(conn_handler) - - # Update the status of the object - conn_handler.execute( - "UPDATE qiita.{0} SET {0}_status_id = " - "(SELECT {0}_status_id FROM qiita.{0}_status WHERE status = %s) " - "WHERE {0}_id = %s".format(self._table), (status, self._id)) + with TRN: + # Perform any extra checks needed before + # we update the status in the DB + self._status_setter_checks() + + # Update the status of the object + sql = """UPDATE qiita.{0} SET {0}_status_id = ( + SELECT {0}_status_id FROM qiita.{0}_status + WHERE status = %s) + WHERE {0}_id = %s""".format(self._table) + TRN.add(sql, [status, self._id]) + TRN.execute() def check_status(self, status, exclude=False): r"""Checks status of object. @@ -285,21 +286,20 @@ def check_status(self, status, exclude=False): Table setup: foo: foo_status_id ----> foo_status: foo_status_id, status """ - # Check that self._table is actually defined - self._check_subclass() - - # Get all available statuses - conn_handler = SQLConnectionHandler() - - statuses = [x[0] for x in conn_handler.execute_fetchall( - "SELECT DISTINCT status FROM qiita.{0}_status".format(self._table), - (self._id, ))] - - # Check that all the provided statuses are valid statuses - if set(status).difference(statuses): - raise ValueError("%s are not valid status values" - % set(status).difference(statuses)) - - # Get the DB status of the object - dbstatus = self.status - return dbstatus not in status if exclude else dbstatus in status + with TRN: + # Get all available statuses + sql = "SELECT DISTINCT status FROM qiita.{0}_status".format( + self._table) + TRN.add(sql) + # Flatten the rows returned by the query into the list + # of available statuses + avail_status = [x[0] for x in TRN.execute_fetchindex()] + + # Check that all the provided statuses are valid + if set(status).difference(avail_status): + raise ValueError("%s are not valid status values" + % set(status).difference(avail_status)) + + # Get the DB status of the object + dbstatus = self.status + return dbstatus not in status if exclude else dbstatus in status diff --git a/qiita_db/commands.py b/qiita_db/commands.py index a3237e407..4004ae12d 100644 --- a/qiita_db/commands.py +++ b/qiita_db/commands.py @@ -24,7 +24,7 @@ load_template_to_dataframe) from .parameters import (PreprocessedIlluminaParams, Preprocessed454Params, ProcessedSortmernaParams) -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN with standard_library.hooks(): from configparser import ConfigParser @@ -80,25 +80,25 @@ def get_optional(name): if optvalue is not None: infodict[value] = optvalue - emp_person_name_email = get_optional('emp_person_name') - if emp_person_name_email is not None: - emp_name, emp_email, emp_affiliation = emp_person_name_email.split(',') - infodict['emp_person_id'] = StudyPerson.create(emp_name.strip(), - emp_email.strip(), - emp_affiliation.strip()) - lab_name_email = get_optional('lab_person') - if lab_name_email is not None: - lab_name, lab_email, lab_affiliation = lab_name_email.split(',') - infodict['lab_person_id'] = StudyPerson.create(lab_name.strip(), - lab_email.strip(), - lab_affiliation.strip()) + with TRN: + emp_person_name_email =
get_optional('emp_person_name') + if emp_person_name_email is not None: + emp_name, emp_email, emp_affiliation = \ + emp_person_name_email.split(',') + infodict['emp_person_id'] = StudyPerson.create( + emp_name.strip(), emp_email.strip(), emp_affiliation.strip()) + lab_name_email = get_optional('lab_person') + if lab_name_email is not None: + lab_name, lab_email, lab_affiliation = lab_name_email.split(',') + infodict['lab_person_id'] = StudyPerson.create( + lab_name.strip(), lab_email.strip(), lab_affiliation.strip()) - pi_name_email = infodict.pop('principal_investigator') - pi_name, pi_email, pi_affiliation = pi_name_email.split(',', 2) - infodict['principal_investigator_id'] = StudyPerson.create( - pi_name.strip(), pi_email.strip(), pi_affiliation.strip()) + pi_name_email = infodict.pop('principal_investigator') + pi_name, pi_email, pi_affiliation = pi_name_email.split(',', 2) + infodict['principal_investigator_id'] = StudyPerson.create( + pi_name.strip(), pi_email.strip(), pi_affiliation.strip()) - return Study.create(User(owner), title, efo_ids, infodict) + return Study.create(User(owner), title, efo_ids, infodict) def load_preprocessed_data_from_cmd(study_id, params_table, filedir, @@ -128,14 +128,17 @@ def load_preprocessed_data_from_cmd(study_id, params_table, filedir, data_type : str The data type of the template """ - fp_types_dict = get_filepath_types() - fp_type = fp_types_dict[filepathtype] - filepaths = [(join(filedir, fp), fp_type) for fp in listdir(filedir)] - pt = None if prep_template_id is None else PrepTemplate(prep_template_id) - return PreprocessedData.create( - Study(study_id), params_table, params_id, filepaths, prep_template=pt, - submitted_to_insdc_status=submitted_to_insdc_status, - data_type=data_type) + with TRN: + fp_types_dict = get_filepath_types() + fp_type = fp_types_dict[filepathtype] + filepaths = [(join(filedir, fp), fp_type) for fp in listdir(filedir)] + pt = (None if prep_template_id is None + else PrepTemplate(prep_template_id)) + return PreprocessedData.create( + Study(study_id), params_table, params_id, filepaths, + prep_template=pt, + submitted_to_insdc_status=submitted_to_insdc_status, + data_type=data_type) def load_sample_template_from_cmd(sample_temp_path, study_id): @@ -149,7 +152,6 @@ def load_sample_template_from_cmd(sample_temp_path, study_id): The study id to which the sample template belongs """ sample_temp = load_template_to_dataframe(sample_temp_path) - return SampleTemplate.create(sample_temp, Study(study_id)) @@ -192,16 +194,17 @@ def load_raw_data_cmd(filepaths, filepath_types, filetype, prep_template_ids): raise ValueError("Please pass exactly one filepath_type for each " "and every filepath") - filetypes_dict = get_filetypes() - filetype_id = filetypes_dict[filetype] + with TRN: + filetypes_dict = get_filetypes() + filetype_id = filetypes_dict[filetype] - filepath_types_dict = get_filepath_types() - filepath_types = [filepath_types_dict[x] for x in filepath_types] + filepath_types_dict = get_filepath_types() + filepath_types = [filepath_types_dict[x] for x in filepath_types] - prep_templates = [PrepTemplate(x) for x in prep_template_ids] + prep_templates = [PrepTemplate(x) for x in prep_template_ids] - return RawData.create(filetype_id, prep_templates, - filepaths=list(zip(filepaths, filepath_types))) + return RawData.create(filetype_id, prep_templates, + filepaths=list(zip(filepaths, filepath_types))) def load_processed_data_cmd(fps, fp_types, processed_params_table_name, @@ -235,25 +238,21 @@ def load_processed_data_cmd(fps, 
fp_types, processed_params_table_name, raise ValueError("Please pass exactly one fp_type for each " "and every fp") - fp_types_dict = get_filepath_types() - fp_types = [fp_types_dict[x] for x in fp_types] + with TRN: + fp_types_dict = get_filepath_types() + fp_types = [fp_types_dict[x] for x in fp_types] - if preprocessed_data_id is not None: - preprocessed_data = PreprocessedData(preprocessed_data_id) - else: - preprocessed_data = None + preprocessed_data = (None if preprocessed_data_id is None + else PreprocessedData(preprocessed_data_id)) - if study_id is not None: - study = Study(study_id) - else: - study = None + study = None if study_id is None else Study(study_id) - if processed_date is not None: - processed_date = parse(processed_date) + if processed_date is not None: + processed_date = parse(processed_date) - return ProcessedData.create(processed_params_table_name, - processed_params_id, list(zip(fps, fp_types)), - preprocessed_data, study, processed_date) + return ProcessedData.create( + processed_params_table_name, processed_params_id, + list(zip(fps, fp_types)), preprocessed_data, study, processed_date) def load_parameters_from_cmd(name, fp, table): @@ -348,77 +347,63 @@ def update_preprocessed_data_from_cmd(sl_out_dir, study_id, ppd_id=None): "%s" % (sl_out_dir, ', '.join(missing_files))) # Get the preprocessed data to be updated - study = Study(study_id) - ppds = study.preprocessed_data() - if not ppds: - raise ValueError("Study %s does not have any preprocessed data") - - if ppd_id: - if ppd_id not in ppds: - raise ValueError("The preprocessed data %d does not exist in " - "study %d. Available preprocessed data: %s" - % (ppd_id, study_id, ', '.join(map(str, ppds)))) - ppd = PreprocessedData(ppd_id) - else: - ppd = PreprocessedData(sorted(ppds)[0]) - - # We need to loop through the fps list to get the db filepaths that we - # need to modify - fps = defaultdict(list) - for fp_id, fp, fp_type in sorted(ppd.get_filepaths()): - fps[fp_type].append((fp_id, fp)) - - fps_to_add = [] - fps_to_modify = [] - keys = ['preprocessed_fasta', 'preprocessed_fastq', 'preprocessed_demux', - 'log'] - - for key in keys: - if key in fps: - db_id, db_fp = fps[key][0] - fp_checksum = compute_checksum(new_fps[key]) - fps_to_modify.append((db_id, db_fp, new_fps[key], fp_checksum)) + with TRN: + study = Study(study_id) + ppds = study.preprocessed_data() + if not ppds: + raise ValueError("Study %s does not have any preprocessed data" + % study_id) + + if ppd_id: + if ppd_id not in ppds: + raise ValueError( + "The preprocessed data %d does not exist in " + "study %d. 
Available preprocessed data: %s" + % (ppd_id, study_id, ', '.join(map(str, ppds)))) + ppd = PreprocessedData(ppd_id) else: - fps_to_add.append( - (new_fps[key], convert_to_id(key, 'filepath_type'))) - - # Insert the new files in the database, if any - if fps_to_add: - ppd.add_filepaths(fps_to_add) - - # Update the files and the database - conn_handler = SQLConnectionHandler() - # Create a queue so we can execute all the modifications on the DB in - # a transaction block - queue_name = "update_ppd_%d" % ppd.id - conn_handler.create_queue(queue_name) - sql = "UPDATE qiita.filepath SET checksum=%s WHERE filepath_id=%s" - bkp_files = [] - for db_id, db_fp, new_fp, checksum in fps_to_modify: - # Move the db_file in case something goes wrong - bkp_fp = "%s.bkp" % db_fp - move(db_fp, bkp_fp) - bkp_files.append((bkp_fp, db_fp)) - - # Start the update for the current file - # Move the file to the database location - move(new_fp, db_fp) - # Add the SQL instruction to the DB - conn_handler.add_to_queue(queue_name, sql, (checksum, db_id)) - - # Execute the queue - try: - conn_handler.execute_queue(queue_name) - except Exception: - # We need to catch any exception so we can restore the db files - for bkp_fp, db_fp in bkp_files: - move(bkp_fp, db_fp) - # Using just raise so the original traceback is shown - raise - - # Since the files and the database have been updated correctly, - # remove the backup files - for bkp_fp, _ in bkp_files: - remove(bkp_fp) - - return ppd + ppd = PreprocessedData(sorted(ppds)[0]) + + # We need to loop through the fps list to get the db filepaths that we + # need to modify + fps = defaultdict(list) + for fp_id, fp, fp_type in sorted(ppd.get_filepaths()): + fps[fp_type].append((fp_id, fp)) + + fps_to_add = [] + fps_to_modify = [] + keys = ['preprocessed_fasta', 'preprocessed_fastq', + 'preprocessed_demux', 'log'] + + for key in keys: + if key in fps: + db_id, db_fp = fps[key][0] + fp_checksum = compute_checksum(new_fps[key]) + fps_to_modify.append((db_id, db_fp, new_fps[key], fp_checksum)) + else: + fps_to_add.append( + (new_fps[key], convert_to_id(key, 'filepath_type'))) + + # Insert the new files in the database, if any + if fps_to_add: + ppd.add_filepaths(fps_to_add) + + sql = "UPDATE qiita.filepath SET checksum=%s WHERE filepath_id=%s" + for db_id, db_fp, new_fp, checksum in fps_to_modify: + # Move the db_file in case something goes wrong + bkp_fp = "%s.bkp" % db_fp + move(db_fp, bkp_fp) + + # Start the update for the current file + # Move the file to the database location + move(new_fp, db_fp) + # Add the SQL instruction to the DB + TRN.add(sql, [checksum, db_id]) + + # In case that a rollback occurs, we need to restore the files + TRN.add_post_rollback_func(move, bkp_fp, db_fp) + # In case of commit, we can remove the backup files + TRN.add_post_commit_func(remove, bkp_fp) + + TRN.execute() + + return ppd diff --git a/qiita_db/data.py b/qiita_db/data.py index 9901c9e25..aed308fc5 100644 --- a/qiita_db/data.py +++ b/qiita_db/data.py @@ -84,7 +84,7 @@ from qiita_core.exceptions import IncompetentQiitaDeveloperError from .base import QiitaObject from .logger import LogEntry -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN from .exceptions import QiitaDBError, QiitaDBUnknownIDError, QiitaDBStatusError from .util import (exists_dynamic_table, insert_filepaths, convert_to_id, convert_from_id, get_filepath_id, get_mountpoint, @@ -112,16 +112,13 @@ class BaseData(QiitaObject): _data_filepath_table = None _data_filepath_column = None - def 
_link_data_filepaths(self, fp_ids, conn_handler): + def _link_data_filepaths(self, fp_ids): r"""Links the data `data_id` with its filepaths `fp_ids` in the DB - connected with `conn_handler` Parameters ---------- fp_ids : list of ints The filepaths ids to connect the data - conn_handler : SQLConnectionHandler - The connection handler object connected to the DB Raises ------ @@ -130,43 +127,39 @@ def _link_data_filepaths(self, fp_ids, conn_handler): not define the class attributes _data_filepath_table and _data_filepath_column """ - # Create the list of SQL values to add - values = [(self.id, fp_id) for fp_id in fp_ids] - # Add all rows at once - conn_handler.executemany( - "INSERT INTO qiita.{0} ({1}, filepath_id) " - "VALUES (%s, %s)".format(self._data_filepath_table, - self._data_filepath_column), values) + with TRN: + # Create the list of SQL values to add + values = [[self.id, fp_id] for fp_id in fp_ids] + # Add all rows at once + sql = """INSERT INTO qiita.{0} ({1}, filepath_id) + VALUES (%s, %s)""".format(self._data_filepath_table, + self._data_filepath_column) + TRN.add(sql, values, many=True) + TRN.execute() def add_filepaths(self, filepaths): r"""Populates the DB tables for storing the filepaths and connects the `self` objects with these filepaths""" - # Check that this function has been called from a subclass - self._check_subclass() - - # Check if the connection handler has been provided. Create a new - # one if not. - conn_handler = SQLConnectionHandler() - - # Update the status of the current object - self._set_link_filepaths_status("linking") - - try: - # Add the filepaths to the database - fp_ids = insert_filepaths(filepaths, self._id, self._table, - self._filepath_table, conn_handler) - - # Connect the raw data with its filepaths - self._link_data_filepaths(fp_ids, conn_handler) - except Exception as e: - # Something went wrong, update the status - self._set_link_filepaths_status("failed: %s" % e) - LogEntry.create('Runtime', str(e), - info={self.__class__.__name__: self.id}) - raise e - - # Filepaths successfully added, update the status - self._set_link_filepaths_status("idle") + with TRN: + # Update the status of the current object + self._set_link_filepaths_status("linking") + + try: + # Add the filepaths to the database + fp_ids = insert_filepaths(filepaths, self._id, self._table, + self._filepath_table) + + # Connect the raw data with its filepaths + self._link_data_filepaths(fp_ids) + except Exception as e: + # Something went wrong, update the status + self._set_link_filepaths_status("failed: %s" % e) + LogEntry.create('Runtime', str(e), + info={self.__class__.__name__: self.id}) + raise e + + # Filepaths successfully added, update the status + self._set_link_filepaths_status("idle") def get_filepaths(self): r"""Returns the filepaths and filetypes associated with the data object @@ -177,44 +170,44 @@ def get_filepaths(self): A list of (filepath_id, path, filetype) with all the paths associated with the current data """ - self._check_subclass() - # We need a connection handler to the database - conn_handler = SQLConnectionHandler() - # Retrieve all the (path, id) tuples related with the current data - # object. We need to first check the _data_filepath_table to get the - # filepath ids of the filepath associated with the current data object. - # We then can query the filepath table to get those paths. 
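
# TRN.add(..., many=True), as used in _link_data_filepaths above, queues the
# same parameterized statement once per row, replacing the old
# conn_handler.executemany call. A minimal sketch (table name hypothetical):

from qiita_db.sql_connection import TRN

def link_filepaths(data_id, fp_ids):
    with TRN:
        values = [[data_id, fp_id] for fp_id in fp_ids]
        sql = """INSERT INTO qiita.foo_filepath (foo_id, filepath_id)
                 VALUES (%s, %s)"""
        TRN.add(sql, values, many=True)
        TRN.execute()
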
- db_paths = conn_handler.execute_fetchall( - "SELECT filepath_id, filepath, filepath_type_id " - "FROM qiita.{0} WHERE " - "filepath_id IN (SELECT filepath_id FROM qiita.{1} WHERE " - "{2}=%(id)s)".format(self._filepath_table, - self._data_filepath_table, - self._data_filepath_column), {'id': self.id}) - - _, fb = get_mountpoint(self._table)[0] - base_fp = partial(join, fb) - - return [(fpid, base_fp(fp), convert_from_id(fid, "filepath_type")) - for fpid, fp, fid in db_paths] + with TRN: + # Retrieve all the (path, id) tuples related with the current data + # object. We need to first check the _data_filepath_table to get + # the filepath ids of the filepath associated with the current data + # object. We then can query the filepath table to get those paths. + sql = """SELECT filepath_id, filepath, filepath_type_id + FROM qiita.{0} + WHERE filepath_id IN ( + SELECT filepath_id + FROM qiita.{1} + WHERE {2}=%s)""".format(self._filepath_table, + self._data_filepath_table, + self._data_filepath_column) + TRN.add(sql, [self.id]) + db_paths = TRN.execute_fetchindex() + + _, fb = get_mountpoint(self._table)[0] + base_fp = partial(join, fb) + + return [(fpid, base_fp(fp), convert_from_id(fid, "filepath_type")) + for fpid, fp, fid in db_paths] def get_filepath_ids(self): - self._check_subclass() - conn_handler = SQLConnectionHandler() - db_ids = conn_handler.execute_fetchall( - "SELECT filepath_id FROM qiita.{0} WHERE " - "{1}=%(id)s".format(self._data_filepath_table, - self._data_filepath_column), {'id': self.id}) - return [fp_id[0] for fp_id in db_ids] + with TRN: + sql = "SELECT filepath_id FROM qiita.{0} WHERE {1}=%s".format( + self._data_filepath_table, self._data_filepath_column) + TRN.add(sql, [self.id]) + return TRN.execute_fetchflatten() @property def link_filepaths_status(self): - self._check_subclass() - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT link_filepaths_status FROM qiita.{0} " - "WHERE {1}=%s".format(self._table, self._data_filepath_column), - (self._id,))[0] + with TRN: + sql = """SELECT link_filepaths_status + FROM qiita.{0} + WHERE {1}=%s""".format(self._table, + self._data_filepath_column) + TRN.add(sql, [self.id]) + return TRN.execute_fetchlast() def _set_link_filepaths_status(self, status): """Updates the link_filepaths_status of the object @@ -229,19 +222,19 @@ def _set_link_filepaths_status(self, status): ValueError If the status is unknown """ - self._check_subclass() - if (status not in ('idle', 'linking', 'unlinking') and - not status.startswith('failed')): - msg = 'Unknown status: %s' % status - LogEntry.create('Runtime', msg, - info={self.__class__.__name__: self.id}) - raise ValueError(msg) - - conn_handler = SQLConnectionHandler() - conn_handler.execute( - "UPDATE qiita.{0} SET link_filepaths_status = %s " - "WHERE {1} = %s".format(self._table, self._data_filepath_column), - (status, self._id)) + with TRN: + if (status not in ('idle', 'linking', 'unlinking') and + not status.startswith('failed')): + msg = 'Unknown status: %s' % status + LogEntry.create('Runtime', msg, + info={self.__class__.__name__: self.id}) + raise ValueError(msg) + + sql = """UPDATE qiita.{0} SET link_filepaths_status = %s + WHERE {1} = %s""".format(self._table, + self._data_filepath_column) + TRN.add(sql, [status, self._id]) + TRN.execute() @classmethod def exists(cls, object_id): @@ -257,12 +250,12 @@ def exists(cls, object_id): bool True if exists, false otherwise. 
""" - conn_handler = SQLConnectionHandler() - - return conn_handler.execute_fetchone( - "SELECT EXISTS(SELECT * FROM qiita.{0} WHERE " - "{1}=%s)".format(cls._table, cls._data_filepath_column), - (object_id, ))[0] + with TRN: + cls._check_subclass() + sql = "SELECT EXISTS(SELECT * FROM qiita.{0} WHERE {1}=%s)".format( + cls._table, cls._data_filepath_column) + TRN.add(sql, [object_id]) + return TRN.execute_fetchlast() class RawData(BaseData): @@ -310,43 +303,44 @@ def create(cls, filetype, prep_templates, filepaths): QiitaDBError If any of the passed prep templates already have a raw data id """ - conn_handler = SQLConnectionHandler() - # We first need to check if the passed prep templates don't have - # a raw data already attached to them - sql = """SELECT EXISTS( - SELECT * - FROM qiita.prep_template - WHERE prep_template_id IN ({}) - AND raw_data_id IS NOT NULL)""".format( - ', '.join(['%s'] * len(prep_templates))) - exists = conn_handler.execute_fetchone( - sql, [pt.id for pt in prep_templates])[0] - if exists: - raise QiitaDBError( - "Cannot create raw data because the passed prep templates " - "already have a raw data associated with it. " - "Prep templates: %s" - % ', '.join([str(pt.id) for pt in prep_templates])) - - # Add the raw data to the database, and get the raw data id back - rd_id = conn_handler.execute_fetchone( - "INSERT INTO qiita.{0} (filetype_id) VALUES (%s) " - "RETURNING raw_data_id".format(cls._table), (filetype,))[0] - - # Instantiate the object with the new id - rd = cls(rd_id) - - # Connect the raw data with its prep templates - values = [(rd_id, pt.id) for pt in prep_templates] - sql = """UPDATE qiita.prep_template - SET raw_data_id = %s WHERE prep_template_id = %s""" - conn_handler.executemany(sql, values) - - # If file paths have been provided, add them to the raw data object - if filepaths: + with TRN: + # We first need to check if the passed prep templates don't have + # a raw data already attached to them + sql = """SELECT EXISTS( + SELECT * + FROM qiita.prep_template + WHERE prep_template_id IN %s + AND raw_data_id IS NOT NULL)""".format( + ', '.join(['%s'] * len(prep_templates))) + TRN.add(sql, [tuple(pt.id for pt in prep_templates)]) + exists = TRN.execute_fetchlast() + if exists: + raise QiitaDBError( + "Cannot create raw data because the passed prep templates " + "already have a raw data associated with it. 
" + "Prep templates: %s" + % ', '.join([str(pt.id) for pt in prep_templates])) + + # Add the raw data to the database, and get the raw data id back + sql = """INSERT INTO qiita.{0} (filetype_id) VALUES (%s) + RETURNING raw_data_id""".format(cls._table) + TRN.add(sql, [filetype]) + rd_id = TRN.execute_fetchlast() + + # Instantiate the object with the new id + rd = cls(rd_id) + + # Connect the raw data with its prep templates + values = [[rd_id, pt.id] for pt in prep_templates] + sql = """UPDATE qiita.prep_template + SET raw_data_id = %s WHERE prep_template_id = %s""" + TRN.add(sql, values, many=True) + TRN.execute() + + # Link the files with the raw data object rd.add_filepaths(filepaths) - return rd + return rd @classmethod def delete(cls, raw_data_id, prep_template_id): @@ -367,48 +361,47 @@ def delete(cls, raw_data_id, prep_template_id): If the raw data is not linked to that prep_template_id If the raw data has files linked """ - conn_handler = SQLConnectionHandler() - - # check if the raw data exist - if not cls.exists(raw_data_id): - raise QiitaDBUnknownIDError(raw_data_id, "raw data") - - # Check if the raw data is linked to the prep template - sql = """SELECT EXISTS( - SELECT * FROM qiita.prep_template - WHERE prep_template_id = %s AND raw_data_id = %s)""" - pt_rd_exists = conn_handler.execute_fetchone( - sql, (prep_template_id, raw_data_id)) - if not pt_rd_exists: - raise QiitaDBError( - "Raw data %d is not linked to prep template %d or the prep " - "template doesn't exist" % (raw_data_id, prep_template_id)) - - # Check to how many prep templates the raw data is still linked. - # If last one, check that are no linked files - raw_data_count = conn_handler.execute_fetchone( - "SELECT COUNT(*) FROM qiita.prep_template WHERE " - "raw_data_id = %s", (raw_data_id,))[0] - if raw_data_count == 1 and RawData(raw_data_id).get_filepath_ids(): - raise QiitaDBError( - "Raw data (%d) can't be remove because it has linked files. " - "To remove it, first unlink files." % raw_data_id) - - # delete - queue = "DELETE_%d_%d" % (raw_data_id, prep_template_id) - conn_handler.create_queue(queue) - sql = """UPDATE qiita.prep_template - SET raw_data_id = %s - WHERE prep_template_id = %s""" - conn_handler.add_to_queue(queue, sql, (None, prep_template_id)) - - # If there is no other prep template pointing to the raw data, it can - # be removed - if raw_data_count == 1: - sql = "DELETE FROM qiita.raw_data WHERE raw_data_id = %s" - conn_handler.add_to_queue(queue, sql, (raw_data_id,)) - - conn_handler.execute_queue(queue) + with TRN: + # check if the raw data exist + if not cls.exists(raw_data_id): + raise QiitaDBUnknownIDError(raw_data_id, "raw data") + + # Check if the raw data is linked to the prep template + sql = """SELECT EXISTS( + SELECT * FROM qiita.prep_template + WHERE prep_template_id = %s AND raw_data_id = %s)""" + TRN.add(sql, [prep_template_id, raw_data_id]) + pt_rd_exists = TRN.execute_fetchlast() + if not pt_rd_exists: + raise QiitaDBError( + "Raw data %d is not linked to prep template %d or the " + "prep template doesn't exist" + % (raw_data_id, prep_template_id)) + + # Check to how many prep templates the raw data is still linked. + # If last one, check that are no linked files + sql = """SELECT COUNT(*) FROM qiita.prep_template + WHERE raw_data_id = %s""" + TRN.add(sql, [raw_data_id]) + raw_data_count = TRN.execute_fetchlast() + if raw_data_count == 1 and RawData(raw_data_id).get_filepath_ids(): + raise QiitaDBError( + "Raw data (%d) can't be removed because it has linked " + "files. 
To remove it, first unlink files." % raw_data_id) + + # delete + sql = """UPDATE qiita.prep_template + SET raw_data_id = %s + WHERE prep_template_id = %s""" + TRN.add(sql, [None, prep_template_id]) + + # If there is no other prep template pointing to the raw data, it + # can be removed + if raw_data_count == 1: + sql = "DELETE FROM qiita.raw_data WHERE raw_data_id = %s" + TRN.add(sql, [raw_data_id]) + + TRN.execute() @property def studies(self): @@ -419,13 +412,13 @@ def studies(self): list of int The list of study ids to which the raw data belongs to """ - conn_handler = SQLConnectionHandler() - sql = """SELECT study_id - FROM qiita.study_prep_template - JOIN qiita.prep_template USING (prep_template_id) - WHERE raw_data_id = %s""" - ids = conn_handler.execute_fetchall(sql, (self.id,)) - return [id[0] for id in ids] + with TRN: + sql = """SELECT study_id + FROM qiita.study_prep_template + JOIN qiita.prep_template USING (prep_template_id) + WHERE raw_data_id = %s""" + TRN.add(sql, [self.id]) + return TRN.execute_fetchflatten() @property def filetype(self): @@ -436,12 +429,13 @@ def filetype(self): str The raw data's filetype """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT f.type FROM qiita.filetype f JOIN qiita.{0} r ON " - "f.filetype_id = r.filetype_id WHERE " - "r.raw_data_id=%s".format(self._table), - (self._id,))[0] + with TRN: + sql = """SELECT f.type + FROM qiita.filetype f + JOIN qiita.{0} r ON f.filetype_id = r.filetype_id + WHERE r.raw_data_id=%s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() def data_types(self, ret_id=False): """Returns the list of data_types or data_type_ids @@ -456,20 +450,23 @@ def data_types(self, ret_id=False): list of str or int string values of data_type or ints if data_type_id """ - ret = "_id" if ret_id else "" - conn_handler = SQLConnectionHandler() - data_types = conn_handler.execute_fetchall( - "SELECT d.data_type{0} FROM qiita.data_type d JOIN " - "qiita.prep_template p ON p.data_type_id = d.data_type_id " - "WHERE p.raw_data_id = %s".format(ret), (self._id, )) - return [dt[0] for dt in data_types] + with TRN: + ret = "_id" if ret_id else "" + sql = """SELECT d.data_type{0} + FROM qiita.data_type d + JOIN qiita.prep_template p + ON p.data_type_id = d.data_type_id + WHERE p.raw_data_id = %s""".format(ret) + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @property def prep_templates(self): - conn_handler = SQLConnectionHandler() - sql = ("SELECT prep_template_id FROM qiita.prep_template " - "WHERE raw_data_id = %s ORDER BY prep_template_id") - return [x[0] for x in conn_handler.execute_fetchall(sql, (self._id,))] + with TRN: + sql = """SELECT prep_template_id FROM qiita.prep_template + WHERE raw_data_id = %s ORDER BY prep_template_id""" + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() def _is_preprocessed(self): """Returns whether the RawData has been preprocessed or not @@ -479,23 +476,23 @@ def _is_preprocessed(self): bool whether the RawData has been preprocessed or not """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT EXISTS(SELECT * FROM qiita.prep_template_preprocessed_data" - " PTPD JOIN qiita.prep_template PT ON PT.prep_template_id = " - "PTPD.prep_template_id WHERE PT.raw_data_id = %s)", (self._id,))[0] - - def _remove_filepath(self, fp, conn_handler, queue): + with TRN: + sql = """SELECT EXISTS( + SELECT * + FROM qiita.prep_template_preprocessed_data PTPD + JOIN qiita.prep_template PT + ON 
PT.prep_template_id = PTPD.prep_template_id + WHERE PT.raw_data_id = %s)""" + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() + + def _remove_filepath(self, fp): """Removes the filepath from the RawData Parameters ---------- fp : str The filepath to remove - conn_handler : SQLConnectionHandler - The connection handler object connected to the DB - queue : str - The queue to use in the conn_handler Raises ------ @@ -506,44 +503,47 @@ def _remove_filepath(self, fp, conn_handler, queue): ValueError If fp does not belong to the raw data """ - # If the RawData has been already preprocessed, we cannot remove any - # file - raise an error - if self._is_preprocessed(): - msg = ("Cannot clear all the filepaths from raw data %s, it has " - "been already preprocessed" % self._id) - self._set_link_filepaths_status("failed: %s" % msg) - raise QiitaDBError(msg) - - # The filepath belongs to one or more prep templates - prep_templates = self.prep_templates - if len(prep_templates) > 1: - msg = ("Can't clear all the filepaths from raw data %s because " - "it has been used with other prep templates: %s. If you " - "want to remove it, first remove the raw data from the " - "other prep templates." - % (self._id, ', '.join(map(str, prep_templates)))) - self._set_link_filepaths_status("failed: %s" % msg) - raise QiitaDBError(msg) - - # Get the filpeath id - fp_id = get_filepath_id(self._table, fp) - fp_is_mine = conn_handler.execute_fetchone( - "SELECT EXISTS(SELECT * FROM qiita.{0} WHERE filepath_id=%s AND " - "{1}=%s)".format(self._data_filepath_table, - self._data_filepath_column), - (fp_id, self._id))[0] - - if not fp_is_mine: - msg = ("The filepath %s does not belong to raw data %s" - % (fp, self._id)) - self._set_link_filepaths_status("failed: %s" % msg) - raise ValueError(msg) - - # We can remove the file - sql = "DELETE FROM qiita.{0} WHERE filepath_id=%s".format( - self._data_filepath_table) - sql_args = (fp_id,) - conn_handler.add_to_queue(queue, sql, sql_args) + with TRN: + # If the RawData has already been preprocessed, we cannot remove + # any file - raise an error + if self._is_preprocessed(): + msg = ("Cannot clear all the filepaths from raw data %s, it " + "has already been preprocessed" % self._id) + self._set_link_filepaths_status("failed: %s" % msg) + raise QiitaDBError(msg) + + # The filepath belongs to one or more prep templates + prep_templates = self.prep_templates + if len(prep_templates) > 1: + msg = ("Can't clear all the filepaths from raw data %s " + "because it has been used with other prep templates: " + "%s. If you want to remove it, first remove the raw " + "data from the other prep templates." 
+ % (self._id, ', '.join(map(str, prep_templates)))) + self._set_link_filepaths_status("failed: %s" % msg) + raise QiitaDBError(msg) + + # Get the filepath id + fp_id = get_filepath_id(self._table, fp) + sql = """SELECT EXISTS( + SELECT * + FROM qiita.{0} + WHERE filepath_id=%s AND {1}=%s + )""".format(self._data_filepath_table, + self._data_filepath_column) + TRN.add(sql, [fp_id, self._id]) + fp_is_mine = TRN.execute_fetchlast() + + if not fp_is_mine: + msg = ("The filepath %s does not belong to raw data %s" + % (fp, self._id)) + self._set_link_filepaths_status("failed: %s" % msg) + raise ValueError(msg) + + # We can remove the file + sql = "DELETE FROM qiita.{0} WHERE filepath_id=%s".format( + self._data_filepath_table) + TRN.add(sql, [fp_id]) def clear_filepaths(self): """Removes all the filepaths attached to the RawData @@ -553,34 +553,29 @@ def clear_filepaths(self): QiitaDBError If the RawData has been already preprocessed """ - conn_handler = SQLConnectionHandler() - - queue = "%s_clear_fps" % self.id - conn_handler.create_queue(queue) - - self._set_link_filepaths_status("unlinking") - - filepaths = self.get_filepaths() - for _, fp, _ in filepaths: - self._remove_filepath(fp, conn_handler, queue) - - try: - # Execute all the queue - conn_handler.execute_queue(queue) - except Exception as e: - self._set_link_filepaths_status("failed: %s" % e) - LogEntry.create('Runtime', str(e), - info={self.__class__.__name__: self.id}) - raise e - - # We can already update the status to done, as the files have been - # unlinked, the move_filepaths_to_upload_folder call will not change - # the status of the raw data object - self._set_link_filepaths_status("idle") - - # Move the files, if they are not used, if you get to this point - # self.studies should only have one element, thus self.studies[0] - move_filepaths_to_upload_folder(self.studies[0], filepaths) + with TRN: + self._set_link_filepaths_status("unlinking") + + filepaths = self.get_filepaths() + for _, fp, _ in filepaths: + self._remove_filepath(fp) + + try: + TRN.execute() + except Exception as e: + self._set_link_filepaths_status("failed: %s" % e) + LogEntry.create('Runtime', str(e), + info={self.__class__.__name__: self.id}) + raise e + + # We can already update the status to done, as the files have been + # unlinked, the move_filepaths_to_upload_folder call will not + # change the status of the raw data object + self._set_link_filepaths_status("idle") + + # Move the files if they are not used; if you get to this point + # self.studies should only have one element, thus self.studies[0] + move_filepaths_to_upload_folder(self.studies[0], filepaths) def status(self, study): """The status of the raw data within the given study @@ -609,28 +604,28 @@ def status(self, study): We then check the processed data generated to infer the status of the raw data. 
""" - if self._id not in study.raw_data(): - raise QiitaDBStatusError( - "The study %s does not have access to the raw data %s" - % (study.id, self.id)) - - conn_handler = SQLConnectionHandler() - sql = """SELECT processed_data_status - FROM qiita.processed_data_status pds - JOIN qiita.processed_data pd - USING (processed_data_status_id) - JOIN qiita.preprocessed_processed_data ppd_pd - USING (processed_data_id) - JOIN qiita.prep_template_preprocessed_data pt_ppd - USING (preprocessed_data_id) - JOIN qiita.prep_template pt - USING (prep_template_id) - JOIN qiita.study_processed_data spd - USING (processed_data_id) - WHERE pt.raw_data_id=%s AND spd.study_id=%s""" - pd_statuses = conn_handler.execute_fetchall(sql, (self._id, study.id)) - - return infer_status(pd_statuses) + with TRN: + if self._id not in study.raw_data(): + raise QiitaDBStatusError( + "The study %s does not have access to the raw data %s" + % (study.id, self.id)) + + sql = """SELECT processed_data_status + FROM qiita.processed_data_status pds + JOIN qiita.processed_data pd + USING (processed_data_status_id) + JOIN qiita.preprocessed_processed_data ppd_pd + USING (processed_data_id) + JOIN qiita.prep_template_preprocessed_data pt_ppd + USING (preprocessed_data_id) + JOIN qiita.prep_template pt + USING (prep_template_id) + JOIN qiita.study_processed_data spd + USING (processed_data_id) + WHERE pt.raw_data_id=%s AND spd.study_id=%s""" + TRN.add(sql, [self._id, study.id]) + + return infer_status(TRN.execute_fetchindex()) class PreprocessedData(BaseData): @@ -701,75 +696,67 @@ def create(cls, study, preprocessed_params_table, preprocessed_params_id, IncompetentQiitaDeveloperError If data_type does not match that of prep_template passed """ - conn_handler = SQLConnectionHandler() - - # Sanity checks for the preprocesses_data data_type - if ((data_type and prep_template) and - data_type != prep_template.data_type): - raise IncompetentQiitaDeveloperError( - "data_type passed does not match prep_template data_type!") - elif data_type is None and prep_template is None: - raise IncompetentQiitaDeveloperError("Neither data_type nor " - "prep_template passed!") - elif prep_template: - # prep_template passed but no data_type, - # so set to prep_template data_type - data_type = prep_template.data_type(ret_id=True) - else: - # only data_type, so need id from the text - data_type = convert_to_id(data_type, "data_type") - - # Check that the preprocessed_params_table exists - if not exists_dynamic_table(preprocessed_params_table, "preprocessed_", - "_params", conn_handler): - raise IncompetentQiitaDeveloperError( - "Preprocessed params table '%s' does not exists!" 
- % preprocessed_params_table) - - # Add the preprocessed data to the database, - # and get the preprocessed data id back - ppd_id = conn_handler.execute_fetchone( - "INSERT INTO qiita.{0} (preprocessed_params_table, " - "preprocessed_params_id, submitted_to_insdc_status, data_type_id, " - "ebi_submission_accession, ebi_study_accession) VALUES " - "(%(param_table)s, %(param_id)s, %(insdc)s, %(data_type)s, " - "%(ebi_submission_accession)s, %(ebi_study_accession)s) " - "RETURNING preprocessed_data_id".format(cls._table), - {'param_table': preprocessed_params_table, - 'param_id': preprocessed_params_id, - 'insdc': submitted_to_insdc_status, - 'data_type': data_type, - 'ebi_submission_accession': ebi_submission_accession, - 'ebi_study_accession': ebi_study_accession})[0] - ppd = cls(ppd_id) - - # Connect the preprocessed data with its study - conn_handler.execute( - "INSERT INTO qiita.{0} (study_id, preprocessed_data_id) " - "VALUES (%s, %s)".format(ppd._study_preprocessed_table), - (study.id, ppd.id)) - - # If the prep template was provided, connect the preprocessed data - # with the prep_template - if prep_template is not None: - q = conn_handler.get_temp_queue() - conn_handler.add_to_queue( - q, - "INSERT INTO qiita.{0} (prep_template_id, " - "preprocessed_data_id) VALUES " - "(%s, %s)".format(cls._template_preprocessed_table), - (prep_template.id, ppd_id)) - conn_handler.add_to_queue( - q, - """UPDATE qiita.prep_template - SET preprocessing_status = 'success' - WHERE prep_template_id = %s""", - [prep_template.id]) - conn_handler.execute_queue(q) - - # Add the filepaths to the database and connect them - ppd.add_filepaths(filepaths) - return ppd + with TRN: + # Sanity checks for the preprocessed_data data_type + if ((data_type and prep_template) and + data_type != prep_template.data_type): + raise IncompetentQiitaDeveloperError( + "data_type passed does not match prep_template data_type!") + elif data_type is None and prep_template is None: + raise IncompetentQiitaDeveloperError( + "Neither data_type nor prep_template passed!") + elif prep_template: + # prep_template passed but no data_type, + # so set to prep_template data_type + data_type = prep_template.data_type(ret_id=True) + else: + # only data_type, so need id from the text + data_type = convert_to_id(data_type, "data_type") + + # Check that the preprocessed_params_table exists + if not exists_dynamic_table(preprocessed_params_table, + "preprocessed_", "_params"): + raise IncompetentQiitaDeveloperError( + "Preprocessed params table '%s' does not exist!" 
+ % preprocessed_params_table) + + # Add the preprocessed data to the database, + # and get the preprocessed data id back + sql = """INSERT INTO qiita.{0} ( + preprocessed_params_table, preprocessed_params_id, + submitted_to_insdc_status, data_type_id, + ebi_submission_accession, ebi_study_accession) + VALUES (%s, %s, %s, %s, %s, %s) + RETURNING preprocessed_data_id""".format(cls._table) + TRN.add(sql, [preprocessed_params_table, preprocessed_params_id, + submitted_to_insdc_status, data_type, + ebi_submission_accession, ebi_study_accession]) + ppd_id = TRN.execute_fetchlast() + ppd = cls(ppd_id) + + # Connect the preprocessed data with its study + sql = """INSERT INTO qiita.{0} (study_id, preprocessed_data_id) + VALUES (%s, %s)""".format(ppd._study_preprocessed_table) + TRN.add(sql, [study.id, ppd.id]) + + # If the prep template was provided, connect the preprocessed data + # with the prep_template + if prep_template is not None: + sql = """INSERT INTO qiita.{0} + (prep_template_id, preprocessed_data_id) + VALUES (%s, %s)""".format( + cls._template_preprocessed_table) + TRN.add(sql, [prep_template.id, ppd_id]) + + sql = """UPDATE qiita.prep_template + SET preprocessing_status = 'success' + WHERE prep_template_id = %s""" + TRN.add(sql, [prep_template.id]) + + TRN.execute() + # Add the filepaths to the database and connect them + ppd.add_filepaths(filepaths) + return ppd @classmethod def delete(cls, ppd_id): @@ -789,72 +776,77 @@ def delete(cls, ppd_id): QiitaDBError If the preprocessed data has been processed """ - valid_submission_states = ['not submitted', 'failed'] - ppd = cls(ppd_id) - if ppd.status != 'sandbox': - raise QiitaDBStatusError( - "Illegal operation on non sandboxed preprocessed data") - elif ppd.submitted_to_vamps_status() not in valid_submission_states: - raise QiitaDBStatusError( - "Illegal operation. This preprocessed data has or is being " - "added to VAMPS.") - elif ppd.submitted_to_insdc_status() not in valid_submission_states: - raise QiitaDBStatusError( - "Illegal operation. 
This preprocessed data has or is being " - "added to EBI.") - - conn_handler = SQLConnectionHandler() - - processed_data = [str(n[0]) for n in conn_handler.execute_fetchall( - "SELECT processed_data_id FROM qiita.preprocessed_processed_data " - "WHERE preprocessed_data_id = {0} ORDER BY " - "processed_data_id".format(ppd_id))] - - if processed_data: - raise QiitaDBError( - "Preprocessed data %d cannot be removed because it was used " - "to generate the following processed data: %s" % ( - ppd_id, ', '.join(processed_data))) - - # delete - queue = "delete_preprocessed_data_%d" % ppd_id - conn_handler.create_queue(queue) - - sql = ("DELETE FROM qiita.prep_template_preprocessed_data WHERE " - "preprocessed_data_id = {0}".format(ppd_id)) - conn_handler.add_to_queue(queue, sql) - - sql = ("DELETE FROM qiita.preprocessed_filepath WHERE " - "preprocessed_data_id = {0}".format(ppd_id)) - conn_handler.add_to_queue(queue, sql) - - sql = ("DELETE FROM qiita.study_preprocessed_data WHERE " - "preprocessed_data_id = {0}".format(ppd_id)) - conn_handler.add_to_queue(queue, sql) - - sql = ("DELETE FROM qiita.preprocessed_data WHERE " - "preprocessed_data_id = {0}".format(ppd_id)) - conn_handler.add_to_queue(queue, sql) - - conn_handler.execute_queue(queue) + with TRN: + valid_submission_states = ['not submitted', 'failed'] + ppd = cls(ppd_id) + + if ppd.status != 'sandbox': + raise QiitaDBStatusError( + "Illegal operation on non sandboxed preprocessed data") + elif ppd.submitted_to_vamps_status() not in \ + valid_submission_states: + raise QiitaDBStatusError( + "Illegal operation. This preprocessed data has or is " + "being added to VAMPS.") + elif ppd.submitted_to_insdc_status() not in \ + valid_submission_states: + raise QiitaDBStatusError( + "Illegal operation. This preprocessed data has or is " + "being added to EBI.") + + sql = """SELECT processed_data_id + FROM qiita.preprocessed_processed_data + WHERE preprocessed_data_id = %s + ORDER BY processed_data_id""".format() + TRN.add(sql, [ppd_id]) + processed_data = TRN.execute_fetchflatten() + + if processed_data: + raise QiitaDBError( + "Preprocessed data %d cannot be removed because it was " + "used to generate the following processed data: %s" % ( + ppd_id, ', '.join(map(str, processed_data)))) + + # delete + sql = """DELETE FROM qiita.prep_template_preprocessed_data + WHERE preprocessed_data_id = %s""" + args = [ppd_id] + TRN.add(sql, args) + + sql = """DELETE FROM qiita.preprocessed_filepath + WHERE preprocessed_data_id = %s""" + TRN.add(sql, args) + + sql = """DELETE FROM qiita.study_preprocessed_data + WHERE preprocessed_data_id = %s""" + TRN.add(sql, args) + + sql = """DELETE FROM qiita.preprocessed_data + WHERE preprocessed_data_id = %s""" + TRN.add(sql, args) + + TRN.execute() @property def processed_data(self): r"""The processed data list generated from this preprocessed data""" - conn_handler = SQLConnectionHandler() - processed_ids = conn_handler.execute_fetchall( - "SELECT processed_data_id FROM qiita.preprocessed_processed_data " - "WHERE preprocessed_data_id = %s", (self.id,)) - return [pid[0] for pid in processed_ids] + with TRN: + sql = """SELECT processed_data_id + FROM qiita.preprocessed_processed_data + WHERE preprocessed_data_id = %s""" + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @property def prep_template(self): r"""The prep template used to generate the preprocessed data""" - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT prep_template_id FROM qiita.{0} WHERE " - 
"preprocessed_data_id=%s".format( - self._template_preprocessed_table), (self._id,))[0] + with TRN: + sql = """SELECT prep_template_id + FROM qiita.{0} + WHERE preprocessed_data_id=%s""".format( + self._template_preprocessed_table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def study(self): @@ -863,12 +855,14 @@ def study(self): Returns ------- int - The study id to which this preprocessed data belongs to""" - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT study_id FROM qiita.{0} WHERE " - "preprocessed_data_id=%s".format(self._study_preprocessed_table), - [self._id])[0] + The study id to which this preprocessed data belongs to + """ + with TRN: + sql = """SELECT study_id FROM qiita.{0} + WHERE preprocessed_data_id=%s""".format( + self._study_preprocessed_table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def ebi_submission_accession(self): @@ -879,10 +873,12 @@ def ebi_submission_accession(self): str The ebi submission accession of this preprocessed data """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT ebi_submission_accession FROM qiita.{0} " - "WHERE preprocessed_data_id=%s".format(self._table), (self.id,))[0] + with TRN: + sql = """SELECT ebi_submission_accession + FROM qiita.{0} + WHERE preprocessed_data_id=%s""".format(self._table) + TRN.add(sql, [self.id]) + return TRN.execute_fetchlast() @property def ebi_study_accession(self): @@ -893,10 +889,12 @@ def ebi_study_accession(self): str The ebi study accession of this preprocessed data """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT ebi_study_accession FROM qiita.{0} " - "WHERE preprocessed_data_id=%s".format(self._table), (self.id,))[0] + with TRN: + sql = """SELECT ebi_study_accession + FROM qiita.{0} + WHERE preprocessed_data_id=%s""".format(self._table) + TRN.add(sql, [self.id]) + return TRN.execute_fetchlast() @ebi_submission_accession.setter def ebi_submission_accession(self, new_ebi_submission_accession): @@ -907,11 +905,12 @@ def ebi_submission_accession(self, new_ebi_submission_accession): new_ebi_submission_accession: str The new ebi submission accession """ - conn_handler = SQLConnectionHandler() - - sql = ("UPDATE qiita.{0} SET ebi_submission_accession = %s WHERE " - "preprocessed_data_id = %s").format(self._table) - conn_handler.execute(sql, (new_ebi_submission_accession, self._id)) + with TRN: + sql = """UPDATE qiita.{0} + SET ebi_submission_accession = %s + WHERE preprocessed_data_id = %s""".format(self._table) + TRN.add(sql, [new_ebi_submission_accession, self._id]) + TRN.execute() @ebi_study_accession.setter def ebi_study_accession(self, new_ebi_study_accession): @@ -922,11 +921,12 @@ def ebi_study_accession(self, new_ebi_study_accession): new_ebi_study_accession: str The new ebi study accession """ - conn_handler = SQLConnectionHandler() - - sql = ("UPDATE qiita.{0} SET ebi_study_accession = %s WHERE " - "preprocessed_data_id = %s").format(self._table) - conn_handler.execute(sql, (new_ebi_study_accession, self._id)) + with TRN: + sql = """UPDATE qiita.{0} + SET ebi_study_accession = %s + WHERE preprocessed_data_id = %s""".format(self._table) + TRN.add(sql, [new_ebi_study_accession, self._id]) + TRN.execute() def data_type(self, ret_id=False): """Returns the data_type or data_type_id @@ -941,14 +941,15 @@ def data_type(self, ret_id=False): str or int string value of data_type or data_type_id """ - conn_handler = SQLConnectionHandler() - 
ret = "_id" if ret_id else "" - data_type = conn_handler.execute_fetchone( - "SELECT d.data_type{0} FROM qiita.data_type d JOIN " - "qiita.{1} p ON p.data_type_id = d.data_type_id WHERE" - " p.preprocessed_data_id = %s".format(ret, self._table), - (self._id, )) - return data_type[0] + with TRN: + ret = "_id" if ret_id else "" + sql = """SELECT d.data_type{0} + FROM qiita.data_type d + JOIN qiita.{1} p ON p.data_type_id = d.data_type_id + WHERE p.preprocessed_data_id = %s""".format( + ret, self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() def submitted_to_insdc_status(self): r"""Tells if the raw data has been submitted to INSDC @@ -958,10 +959,12 @@ def submitted_to_insdc_status(self): str One of {'not submitted', 'submitting', 'success', 'failed'} """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT submitted_to_insdc_status FROM qiita.{0} " - "WHERE preprocessed_data_id=%s".format(self._table), (self.id,))[0] + with TRN: + sql = """SELECT submitted_to_insdc_status + FROM qiita.{0} + WHERE preprocessed_data_id=%s""".format(self._table) + TRN.add(sql, [self.id]) + return TRN.execute_fetchlast() def update_insdc_status(self, state, study_acc=None, submission_acc=None): r"""Update the INSDC submission status @@ -985,28 +988,28 @@ def update_insdc_status(self, state, study_acc=None, submission_acc=None): If ``state`` is ``success`` and either ``study_acc`` or ``submission_acc`` are ``None``. """ - if state not in ('not submitted', 'submitting', 'success', 'failed'): - raise ValueError("Unknown state: %s" % state) - - conn_handler = SQLConnectionHandler() - - if state == 'success': - if study_acc is None or submission_acc is None: - raise ValueError("study_acc or submission_acc is None!") - - conn_handler.execute(""" - UPDATE qiita.{0} - SET (submitted_to_insdc_status, - ebi_study_accession, - ebi_submission_accession) = (%s, %s, %s) - WHERE preprocessed_data_id=%s""".format(self._table), - (state, study_acc, submission_acc, self.id)) - else: - conn_handler.execute(""" - UPDATE qiita.{0} - SET submitted_to_insdc_status = %s - WHERE preprocessed_data_id=%s""".format(self._table), - (state, self.id)) + with TRN: + if state not in ('not submitted', 'submitting', 'success', + 'failed'): + raise ValueError("Unknown state: %s" % state) + + if state == 'success': + if study_acc is None or submission_acc is None: + raise ValueError("study_acc or submission_acc is None!") + + sql = """UPDATE qiita.{0} + SET (submitted_to_insdc_status, + ebi_study_accession, + ebi_submission_accession) = (%s, %s, %s) + WHERE preprocessed_data_id=%s""".format(self._table) + TRN.add(sql, [state, study_acc, submission_acc, self.id]) + else: + sql = """UPDATE qiita.{0} + SET submitted_to_insdc_status = %s + WHERE preprocessed_data_id=%s""".format(self._table) + TRN.add(sql, [state, self.id]) + + TRN.execute() def submitted_to_vamps_status(self): r"""Tells if the raw data has been submitted to VAMPS @@ -1016,10 +1019,12 @@ def submitted_to_vamps_status(self): str One of {'not submitted', 'submitting', 'success', 'failed'} """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT submitted_to_vamps_status FROM qiita.{0} " - "WHERE preprocessed_data_id=%s".format(self._table), (self.id,))[0] + with TRN: + sql = """SELECT submitted_to_vamps_status + FROM qiita.{0} + WHERE preprocessed_data_id=%s""".format(self._table) + TRN.add(sql, [self.id]) + return TRN.execute_fetchlast() def update_vamps_status(self, status): r"""Update the 
VAMPS submission status @@ -1037,10 +1042,12 @@ def update_vamps_status(self, status): if status not in ('not submitted', 'submitting', 'success', 'failed'): raise ValueError("Unknown status: %s" % status) - conn_handler = SQLConnectionHandler() - conn_handler.execute( - """UPDATE qiita.{0} SET submitted_to_vamps_status = %s WHERE - preprocessed_data_id=%s""".format(self._table), (status, self.id)) + with TRN: + sql = """UPDATE qiita.{0} + SET submitted_to_vamps_status = %s + WHERE preprocessed_data_id=%s""".format(self._table) + TRN.add(sql, [status, self.id]) + TRN.execute() @property def processing_status(self): @@ -1051,10 +1058,12 @@ def processing_status(self): str One of {'not_processed', 'processing', 'processed', 'failed'} """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT processing_status FROM qiita.{0} WHERE " - "preprocessed_data_id=%s".format(self._table), (self.id,))[0] + with TRN: + sql = """SELECT processing_status + FROM qiita.{0} + WHERE preprocessed_data_id=%s""".format(self._table) + TRN.add(sql, [self.id]) + return TRN.execute_fetchlast() @processing_status.setter def processing_status(self, state): @@ -1073,11 +1082,11 @@ def processing_status(self, state): if (state not in ('not_processed', 'processing', 'processed') and not state.startswith('failed')): raise ValueError('Unknown state: %s' % state) - - conn_handler = SQLConnectionHandler() - conn_handler.execute( - "UPDATE qiita.{0} SET processing_status=%s WHERE " - "preprocessed_data_id=%s".format(self._table), (state, self.id)) + with TRN: + sql = """UPDATE qiita.{0} SET processing_status=%s + WHERE preprocessed_data_id=%s""".format(self._table) + TRN.add(sql, [state, self.id]) + TRN.execute() @property def status(self): @@ -1095,17 +1104,17 @@ def status(self): data has been generated with this preprocessed data; then the status is 'sandbox'. 
""" - conn_handler = SQLConnectionHandler() - sql = """SELECT processed_data_status - FROM qiita.processed_data_status pds - JOIN qiita.processed_data pd - USING (processed_data_status_id) - JOIN qiita.preprocessed_processed_data ppd_pd - USING (processed_data_id) - WHERE ppd_pd.preprocessed_data_id=%s""" - pd_statuses = conn_handler.execute_fetchall(sql, (self._id,)) + with TRN: + sql = """SELECT processed_data_status + FROM qiita.processed_data_status pds + JOIN qiita.processed_data pd + USING (processed_data_status_id) + JOIN qiita.preprocessed_processed_data ppd_pd + USING (processed_data_id) + WHERE ppd_pd.preprocessed_data_id=%s""" + TRN.add(sql, [self._id]) - return infer_status(pd_statuses) + return infer_status(TRN.execute_fetchindex()) class ProcessedData(BaseData): @@ -1146,18 +1155,13 @@ def get_by_status(cls, status): list of int All the processed data ids that match the given status """ - conn_handler = SQLConnectionHandler() - sql = """SELECT processed_data_id FROM qiita.processed_data pd - JOIN qiita.processed_data_status pds - USING (processed_data_status_id) - WHERE pds.processed_data_status=%s""" - result = conn_handler.execute_fetchall(sql, (status,)) - if result: - pds = set(x[0] for x in result) - else: - pds = set() - - return pds + with TRN: + sql = """SELECT processed_data_id FROM qiita.processed_data pd + JOIN qiita.processed_data_status pds + USING (processed_data_status_id) + WHERE pds.processed_data_status=%s""" + TRN.add(sql, [status]) + return set(TRN.execute_fetchflatten()) @classmethod def get_by_status_grouped_by_study(cls, status): @@ -1175,17 +1179,18 @@ def get_by_status_grouped_by_study(cls, status): processed data ids that belong to that study and match the given status """ - conn_handler = SQLConnectionHandler() - sql = """SELECT spd.study_id, - array_agg(pd.processed_data_id ORDER BY pd.processed_data_id) - FROM qiita.processed_data pd - JOIN qiita.processed_data_status pds - USING (processed_data_status_id) - JOIN qiita.study_processed_data spd - USING (processed_data_id) - WHERE pds.processed_data_status = %s - GROUP BY spd.study_id;""" - return dict(conn_handler.execute_fetchall(sql, (status,))) + with TRN: + sql = """SELECT spd.study_id, + array_agg(pd.processed_data_id ORDER BY pd.processed_data_id) + FROM qiita.processed_data pd + JOIN qiita.processed_data_status pds + USING (processed_data_status_id) + JOIN qiita.study_processed_data spd + USING (processed_data_id) + WHERE pds.processed_data_status = %s + GROUP BY spd.study_id;""" + TRN.add(sql, [status]) + return dict(TRN.execute_fetchindex()) @classmethod def create(cls, processed_params_table, processed_params_id, filepaths, @@ -1220,73 +1225,74 @@ def create(cls, processed_params_table, processed_params_id, filepaths, If `preprocessed_data` and `study` are provided at the same time If `preprocessed_data` and `study` are not provided """ - conn_handler = SQLConnectionHandler() - if preprocessed_data is not None: - if study is not None: - raise IncompetentQiitaDeveloperError( - "You should provide either preprocessed_data or study, " - "but not both") - elif data_type is not None and \ - data_type != preprocessed_data.data_type(): - raise IncompetentQiitaDeveloperError( - "data_type passed does not match preprocessed_data " - "data_type!") + with TRN: + if preprocessed_data is not None: + if study is not None: + raise IncompetentQiitaDeveloperError( + "You should provide either preprocessed_data or " + "study, but not both") + elif data_type is not None and \ + data_type != 
preprocessed_data.data_type(): + raise IncompetentQiitaDeveloperError( + "data_type passed does not match preprocessed_data " + "data_type!") + else: + data_type = preprocessed_data.data_type(ret_id=True) else: - if study is None: - raise IncompetentQiitaDeveloperError( - "You should provide either a preprocessed_data or a study") - if data_type is None: + if study is None: + raise IncompetentQiitaDeveloperError( + "You should provide either a preprocessed_data or " + "a study") + if data_type is None: + raise IncompetentQiitaDeveloperError( + "You must provide either a preprocessed_data, a " + "data_type, or both") + else: + data_type = convert_to_id(data_type, "data_type") + + # We first check that the processed_params_table exists + if not exists_dynamic_table(processed_params_table, + "processed_params_", ""): + raise IncompetentQiitaDeveloperError( + "Processed params table %s does not exist!" + % processed_params_table) + + # Check if we have received a date: + if processed_date is None: + processed_date = datetime.now() + + # Add the processed data to the database, + # and get the processed data id back + sql = """INSERT INTO qiita.{0} + (processed_params_table, processed_params_id, + processed_date, data_type_id) + VALUES (%s, %s, %s, %s) + RETURNING processed_data_id""".format(cls._table) + TRN.add(sql, [processed_params_table, processed_params_id, + processed_date, data_type]) + pd_id = TRN.execute_fetchlast() + + pd = cls(pd_id) + + if preprocessed_data is not None: + sql = """INSERT INTO qiita.{0} + (preprocessed_data_id, processed_data_id) + VALUES (%s, %s)""".format( + cls._preprocessed_processed_table) + TRN.add(sql, [preprocessed_data.id, pd_id]) + TRN.execute() + study_id = preprocessed_data.study else: - study_id = preprocessed_data.study - else: - data_type = convert_to_id(data_type, "data_type") + study_id = study.id + + # Connect the processed data with the study + sql = """INSERT INTO qiita.{0} (study_id, processed_data_id) + VALUES (%s, %s)""".format(cls._study_processed_table) + TRN.add(sql, [study_id, pd_id]) + TRN.execute()
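
# TRN.execute() above flushes the queued statements inside the still-open
# transaction without committing it; the commit itself happens only when the
# outermost `with TRN:` block exits cleanly. That makes RETURNING values
# (like pd_id) available mid-block while any later failure still rolls the
# whole insert back. A minimal sketch (table and column hypothetical):

from qiita_db.sql_connection import TRN

with TRN:
    TRN.add("""INSERT INTO qiita.foo (name) VALUES (%s)
               RETURNING foo_id""", ['example'])
    foo_id = TRN.execute_fetchlast()  # statements ran; transaction still open
    # later TRN.add(...) calls can still fail and roll back the INSERT

- # We first check that the processed_params_table exists - if not exists_dynamic_table(processed_params_table, - "processed_params_", "", conn_handler): - raise IncompetentQiitaDeveloperError( - "Processed params table %s does not exists!" 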
- % processed_params_table) - - # Check if we have received a date: - if processed_date is None: - processed_date = datetime.now() - - # Add the processed data to the database, - # and get the processed data id back - pd_id = conn_handler.execute_fetchone( - "INSERT INTO qiita.{0} (processed_params_table, " - "processed_params_id, processed_date, data_type_id) VALUES (" - "%(param_table)s, %(param_id)s, %(date)s, %(data_type)s) RETURNING" - " processed_data_id".format(cls._table), - {'param_table': processed_params_table, - 'param_id': processed_params_id, - 'date': processed_date, - 'data_type': data_type})[0] - - pd = cls(pd_id) - - if preprocessed_data is not None: - conn_handler.execute( - "INSERT INTO qiita.{0} (preprocessed_data_id, " - "processed_data_id) VALUES " - "(%s, %s)".format(cls._preprocessed_processed_table), - (preprocessed_data.id, pd_id)) - study_id = preprocessed_data.study - else: - study_id = study.id - - # Connect the processed data with the study - conn_handler.execute( - "INSERT INTO qiita.{0} (study_id, processed_data_id) VALUES " - "(%s, %s)".format(cls._study_processed_table), - (study_id, pd_id)) - - pd.add_filepaths(filepaths) - return cls(pd_id) + pd.add_filepaths(filepaths) + return cls(pd_id) @classmethod def delete(cls, processed_data_id): @@ -1304,53 +1310,55 @@ def delete(cls, processed_data_id): QiitaDBError If the processed data has analyses """ - if cls(processed_data_id).status != 'sandbox': - raise QiitaDBStatusError( - "Illegal operation on non sandboxed processed data") + with TRN: + if cls(processed_data_id).status != 'sandbox': + raise QiitaDBStatusError( + "Illegal operation on non sandboxed processed data") - conn_handler = SQLConnectionHandler() + sql = """SELECT DISTINCT name + FROM qiita.analysis + JOIN qiita.analysis_sample USING (analysis_id) + WHERE processed_data_id = %s ORDER BY name""" + TRN.add(sql, [processed_data_id]) - analyses = [str(n[0]) for n in conn_handler.execute_fetchall( - "SELECT DISTINCT name FROM qiita.analysis JOIN " - "qiita.analysis_sample USING (analysis_id) WHERE " - "processed_data_id = {0} ORDER BY name".format(processed_data_id))] + analyses = TRN.execute_fetchflatten() - if analyses: - raise QiitaDBError( - "Processed data %d cannot be removed because it is linked to " - "the following analysis: %s" % (processed_data_id, - ', '.join(analyses))) + if analyses: + raise QiitaDBError( + "Processed data %d cannot be removed because it is linked " + "to the following analysis: %s" + % (processed_data_id, ', '.join(analyses))) - # delete - queue = "delete_processed_data_%d" % processed_data_id - conn_handler.create_queue(queue) + # delete + sql = """DELETE FROM qiita.preprocessed_processed_data + WHERE processed_data_id = %s""" + args = [processed_data_id] + TRN.add(sql, args) - sql = ("DELETE FROM qiita.preprocessed_processed_data WHERE " - "processed_data_id = {0}".format(processed_data_id)) - conn_handler.add_to_queue(queue, sql) + sql = """DELETE FROM qiita.processed_filepath + WHERE processed_data_id = %s""" + TRN.add(sql, args) - sql = ("DELETE FROM qiita.processed_filepath WHERE " - "processed_data_id = {0}".format(processed_data_id)) - conn_handler.add_to_queue(queue, sql) + sql = """DELETE FROM qiita.study_processed_data + WHERE processed_data_id = %s""" + TRN.add(sql, args) - sql = ("DELETE FROM qiita.study_processed_data WHERE " - "processed_data_id = {0}".format(processed_data_id)) - conn_handler.add_to_queue(queue, sql) + sql = """DELETE FROM qiita.processed_data + WHERE processed_data_id = %s""" + 
TRN.add(sql, args) - sql = ("DELETE FROM qiita.processed_data WHERE " - "processed_data_id = {0}".format(processed_data_id)) - conn_handler.add_to_queue(queue, sql) - - conn_handler.execute_queue(queue) + TRN.execute() @property def preprocessed_data(self): r"""The preprocessed data id used to generate the processed data""" - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT preprocessed_data_id FROM qiita.{0} WHERE " - "processed_data_id=%s".format(self._preprocessed_processed_table), - [self._id])[0] + with TRN: + sql = """SELECT preprocessed_data_id + FROM qiita.{0} + WHERE processed_data_id=%s""".format( + self._preprocessed_processed_table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def study(self): @@ -1360,11 +1368,13 @@ def study(self): ------- int The study id to which this processed data belongs""" - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT study_id FROM qiita.{0} WHERE " - "processed_data_id=%s".format(self._study_processed_table), - [self._id])[0] + with TRN: + sql = """SELECT study_id + FROM qiita.{0} + WHERE processed_data_id=%s""".format( + self._study_processed_table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() def data_type(self, ret_id=False): """Returns the data_type or data_type_id @@ -1379,14 +1389,14 @@ def data_type(self, ret_id=False): str or int string value of data_type or data_type_id """ - conn_handler = SQLConnectionHandler() - ret = "_id" if ret_id else "" - data_type = conn_handler.execute_fetchone( - "SELECT d.data_type{0} FROM qiita.data_type d JOIN " - "qiita.{1} p ON p.data_type_id = d.data_type_id WHERE" - " p.processed_data_id = %s".format(ret, self._table), - (self._id, )) - return data_type[0] + with TRN: + ret = "_id" if ret_id else "" + sql = """SELECT d.data_type{0} + FROM qiita.data_type d + JOIN qiita.{1} p ON p.data_type_id = d.data_type_id + WHERE p.processed_data_id = %s""".format(ret, self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def processing_info(self): @@ -1398,44 +1408,47 @@ def processing_info(self): Parameter settings keyed to the parameter, along with date and algorithm used """ - # Get processed date and the info for the dynamic table - conn_handler = SQLConnectionHandler() - sql = """SELECT processed_date, processed_params_table, - processed_params_id FROM qiita.{0} - WHERE processed_data_id=%s""".format(self._table) - static_info = conn_handler.execute_fetchone(sql, (self.id,)) - - # Get the info from the dynamic table, including reference used - sql = """SELECT * from qiita.{0} - JOIN qiita.reference USING (reference_id) - WHERE processed_params_id = {1} - """.format(static_info['processed_params_table'], - static_info['processed_params_id']) - dynamic_info = dict(conn_handler.execute_fetchone(sql)) - - # replace reference filepath_ids with full filepaths - # figure out what columns have filepaths and what don't - ref_fp_cols = {'sequence_filepath', 'taxonomy_filepath', - 'tree_filepath'} - fp_ids = [str(dynamic_info[col]) for col in ref_fp_cols - if dynamic_info[col] is not None] - # Get the filepaths and create dict of fpid to filepath - sql = ("SELECT filepath_id, filepath FROM qiita.filepath WHERE " - "filepath_id IN ({})").format(','.join(fp_ids)) - lookup = {fp[0]: fp[1] for fp in conn_handler.execute_fetchall(sql)} - # Loop through and replace ids - for key in ref_fp_cols: - if dynamic_info[key] is not None: - dynamic_info[key] = lookup[dynamic_info[key]] - - # 
add missing info to the dictionary and remove id column info
-        dynamic_info['processed_date'] = static_info['processed_date']
-        dynamic_info['algorithm'] = static_info[
-            'processed_params_table'].split('_')[-1]
-        del dynamic_info['processed_params_id']
-        del dynamic_info['reference_id']
-
-        return dynamic_info
+        with TRN:
+            # Get processed date and the info for the dynamic table
+            sql = """SELECT processed_date, processed_params_table,
+                         processed_params_id FROM qiita.{0}
+                     WHERE processed_data_id=%s""".format(self._table)
+            TRN.add(sql, [self.id])
+            static_info = TRN.execute_fetchindex()[0]
+
+            # Get the info from the dynamic table, including reference used
+            sql = """SELECT * FROM qiita.{0}
+                     JOIN qiita.reference USING (reference_id)
+                     WHERE processed_params_id = %s""".format(
+                static_info['processed_params_table'])
+            TRN.add(sql, [static_info['processed_params_id']])
+            dynamic_info = dict(TRN.execute_fetchindex()[0])
+
+            # replace reference filepath_ids with full filepaths
+            # figure out what columns have filepaths and what don't
+            ref_fp_cols = {'sequence_filepath', 'taxonomy_filepath',
+                           'tree_filepath'}
+            fp_ids = tuple(dynamic_info[col] for col in ref_fp_cols
+                           if dynamic_info[col] is not None)
+            # Get the filepaths and create dict of fpid to filepath
+            sql = """SELECT filepath_id, filepath
+                     FROM qiita.filepath
+                     WHERE filepath_id IN %s"""
+            TRN.add(sql, [fp_ids])
+            lookup = {fp[0]: fp[1] for fp in TRN.execute_fetchindex()}
+            # Loop through and replace ids
+            for key in ref_fp_cols:
+                if dynamic_info[key] is not None:
+                    dynamic_info[key] = lookup[dynamic_info[key]]
+
+            # add missing info to the dictionary and remove id column info
+            dynamic_info['processed_date'] = static_info['processed_date']
+            dynamic_info['algorithm'] = static_info[
+                'processed_params_table'].split('_')[-1]
+            del dynamic_info['processed_params_id']
+            del dynamic_info['reference_id']
+
+            return dynamic_info
 
     @property
     def samples(self):
@@ -1446,27 +1459,33 @@ def samples(self):
         set
             all sample_ids available for the processed data
         """
-        conn_handler = SQLConnectionHandler()
-        # Get the prep template id for teh dynamic table lookup
-        sql = """SELECT ptp.prep_template_id FROM
-            qiita.prep_template_preprocessed_data ptp JOIN
-            qiita.preprocessed_processed_data ppd USING (preprocessed_data_id)
-            WHERE ppd.processed_data_id = %s"""
-        prep_id = conn_handler.execute_fetchone(sql, [self._id])[0]
-
-        # Get samples from dynamic table
-        sql = "SELECT sample_id FROM qiita.prep_%d" % prep_id
-        return set(s[0] for s in conn_handler.execute_fetchall(sql))
+        with TRN:
+            # Get the prep template id for the dynamic table lookup
+            sql = """SELECT ptp.prep_template_id
+                     FROM qiita.prep_template_preprocessed_data ptp
+                     JOIN qiita.preprocessed_processed_data ppd
+                        USING (preprocessed_data_id)
+                     WHERE ppd.processed_data_id = %s"""
+            TRN.add(sql, [self._id])
+            prep_id = TRN.execute_fetchlast()
+
+            # Get samples from dynamic table
+            sql = """SELECT sample_id
+                     FROM qiita.prep_template_sample
+                     WHERE prep_template_id=%s"""
+            TRN.add(sql, [prep_id])
+            return set(TRN.execute_fetchflatten())
 
     @property
     def status(self):
-        conn_handler = SQLConnectionHandler()
-        sql = """SELECT pds.processed_data_status
-                 FROM qiita.processed_data_status pds
-                 JOIN qiita.processed_data pd
-                 USING (processed_data_status_id)
-                 WHERE pd.processed_data_id=%s"""
-        return conn_handler.execute_fetchone(sql, (self._id,))[0]
+        with TRN:
+            sql = """SELECT pds.processed_data_status
+                     FROM qiita.processed_data_status pds
+                     JOIN qiita.processed_data pd
+                        USING 
(processed_data_status_id) + WHERE pd.processed_data_id=%s""" + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @status.setter def status(self, status): @@ -1482,14 +1501,14 @@ def status(self, status): QiitaDBStatusError If the processed data status is public """ - if self.status == 'public': - raise QiitaDBStatusError( - "Illegal operation on public processed data") - - conn_handler = SQLConnectionHandler() + with TRN: + if self.status == 'public': + raise QiitaDBStatusError( + "Illegal operation on public processed data") - status_id = convert_to_id(status, 'processed_data_status') + status_id = convert_to_id(status, 'processed_data_status') - sql = """UPDATE qiita.{0} SET processed_data_status_id = %s - WHERE processed_data_id=%s""".format(self._table) - conn_handler.execute(sql, (status_id, self._id)) + sql = """UPDATE qiita.{0} SET processed_data_status_id = %s + WHERE processed_data_id=%s""".format(self._table) + TRN.add(sql, [status_id, self._id]) + TRN.execute() diff --git a/qiita_db/environment_manager.py b/qiita_db/environment_manager.py index f13ebf9dc..fb969cf94 100644 --- a/qiita_db/environment_manager.py +++ b/qiita_db/environment_manager.py @@ -19,7 +19,7 @@ from qiita_core.exceptions import QiitaEnvironmentError from qiita_core.qiita_settings import qiita_config -from .sql_connection import SQLConnectionHandler +from .sql_connection import SQLConnectionHandler, TRN from .reference import Reference from natsort import natsorted @@ -55,34 +55,36 @@ def _check_db_exists(db, conn_handler): return (db,) in dbs -def create_layout_and_patch(conn, verbose=False): +def create_layout_and_patch(verbose=False): r"""Builds the SQL layout and applies all the patches Parameters ---------- - conn : SQLConnectionHandler - The handler connected to the DB verbose : bool, optional If true, print the current step. Default: False. 
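+
+    Notes
+    -----
+    The layout and every patch are applied inside a single transaction, so
+    a failure at any step should leave the database unmodified.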
""" - if verbose: - print('Building SQL layout') - # Create the schema - with open(LAYOUT_FP, 'U') as f: - conn.execute(f.read()) + with TRN: + if verbose: + print('Building SQL layout') + # Create the schema + with open(LAYOUT_FP, 'U') as f: + TRN.add(f.read()) + TRN.execute() - if verbose: - print('Patching Database...') - patch(verbose=verbose) + if verbose: + print('Patching Database...') + patch(verbose=verbose) -def _populate_test_db(conn): +def _populate_test_db(): print('Populating database with demo data') - with open(POPULATE_FP, 'U') as f: - conn.execute(f.read()) + with TRN: + with open(POPULATE_FP, 'U') as f: + TRN.add(f.read()) + TRN.execute() -def _add_ontology_data(conn): +def _add_ontology_data(): print ('Loading Ontology Data') if not exists(reference_base_dir): mkdir(reference_base_dir) @@ -101,18 +103,21 @@ def _add_ontology_data(conn): raise IOError("Error: Could not fetch ontologies file from %s" % url) - with gzip.open(fp, 'rb') as f: - conn.execute(f.read()) + with TRN: + with gzip.open(fp, 'rb') as f: + TRN.add(f.read()) + TRN.execute() -def _insert_processed_params(conn, ref): - sortmerna_sql = """INSERT INTO qiita.processed_params_sortmerna - (reference_id, sortmerna_e_value, sortmerna_max_pos, - similarity, sortmerna_coverage, threads) - VALUES - (%s, 1, 10000, 0.97, 0.97, 1)""" - - conn.execute(sortmerna_sql, [ref._id]) +def _insert_processed_params(ref): + with TRN: + sortmerna_sql = """INSERT INTO qiita.processed_params_sortmerna + (reference_id, sortmerna_e_value, sortmerna_max_pos, + similarity, sortmerna_coverage, threads) + VALUES + (%s, 1, 10000, 0.97, 0.97, 1)""" + TRN.add(sortmerna_sql, [ref._id]) + TRN.execute() def _download_reference_files(conn): @@ -142,11 +147,11 @@ def _download_reference_files(conn): except: raise IOError("Error: Could not fetch %s file from %s" % (file_type, url)) + with TRN: + ref = Reference.create('Greengenes', '13_8', files['sequence'][0], + files['taxonomy'][0], files['tree'][0]) - ref = Reference.create('Greengenes', '13_8', files['sequence'][0], - files['taxonomy'][0], files['tree'][0]) - - _insert_processed_params(conn, ref) + _insert_processed_params(ref) def make_environment(load_ontologies, download_reference, add_demo_user): @@ -192,76 +197,80 @@ def make_environment(load_ontologies, download_reference, add_demo_user): admin_conn.autocommit = False del admin_conn - - # Connect to the postgres server, but this time to the just created db - conn = SQLConnectionHandler() - - print('Inserting database metadata') - # Build the SQL layout into the database - with open(SETTINGS_FP, 'U') as f: - conn.execute(f.read()) - - # Insert the settings values to the database - conn.execute("INSERT INTO settings (test, base_data_dir, base_work_dir) " - "VALUES (%s, %s, %s)", - (qiita_config.test_environment, qiita_config.base_data_dir, - qiita_config.working_dir)) - - create_layout_and_patch(conn, verbose=True) - - if load_ontologies: - _add_ontology_data(conn) - - # these values can only be added if the environment is being loaded - # with the ontologies, thus this cannot exist inside intialize.sql - # because otherwise loading the ontologies would be a requirement - ontology = Ontology(convert_to_id('ENA', 'ontology')) - ontology.add_user_defined_term('Amplicon Sequencing') - - if download_reference: - _download_reference_files(conn) - - # we don't do this if it's a test environment because populate.sql - # already adds this user... 
-    if add_demo_user and not qiita_config.test_environment:
-        conn.execute("""
-            INSERT INTO qiita.qiita_user (email, user_level_id, password,
-                                          name, affiliation, address, phone)
-            VALUES
-            ('demo@microbio.me', 4,
-             '$2a$12$gnUi8Qg.0tvW243v889BhOBhWLIHyIJjjgaG6dxuRJkUM8nXG9Efe',
-             'Demo', 'Qitta Dev', '1345 Colorado Avenue', '303-492-1984')""")
-        analysis_id = conn.execute_fetchone("""
-            INSERT INTO qiita.analysis (email, name, description, dflt,
-                                        analysis_status_id)
-            VALUES
-            ('demo@microbio.me', 'demo@microbio.me-dflt', 'dflt', 't', 1)
-            RETURNING analysis_id
-            """)[0]
-
-        # Add default analysis to all portals
-        args = []
-        sql = "SELECT portal_type_id FROM qiita.portal_type"
-        for portal_id in conn.execute_fetchall(sql):
-            args.append([analysis_id, portal_id[0]])
-
-        sql = """INSERT INTO qiita.analysis_portal
-                    (analysis_id, portal_type_id)
-                 VALUES (%s, %s)"""
-        conn.executemany(sql, args)
-
-        print('Demo user successfully created')
-
-    if qiita_config.test_environment:
-        _populate_test_db(conn)
-        print('Test environment successfully created')
-    else:
-        print('Production environment successfully created')
+    SQLConnectionHandler.close()
+
+    with TRN:
+        print('Inserting database metadata')
+        # Build the SQL layout into the database
+        with open(SETTINGS_FP, 'U') as f:
+            TRN.add(f.read())
+        TRN.execute()
+
+        # Insert the settings values to the database
+        sql = """INSERT INTO settings (test, base_data_dir, base_work_dir)
+                 VALUES (%s, %s, %s)"""
+        TRN.add(sql, [qiita_config.test_environment,
+                      qiita_config.base_data_dir,
+                      qiita_config.working_dir])
+        TRN.execute()
+
+        create_layout_and_patch(verbose=True)
+
+        if load_ontologies:
+            _add_ontology_data()
+
+            # these values can only be added if the environment is being loaded
+            # with the ontologies, thus this cannot exist inside initialize.sql
+            # because otherwise loading the ontologies would be a requirement
+            ontology = Ontology(convert_to_id('ENA', 'ontology'))
+            ontology.add_user_defined_term('Amplicon Sequencing')
+
+        if download_reference:
+            _download_reference_files()
+
+        # we don't do this if it's a test environment because populate.sql
+        # already adds this user... 
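+        # (the password below is a pre-computed bcrypt hash for the demo
+        # account, not the plain text password)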
+        if add_demo_user and not qiita_config.test_environment:
+            sql = """INSERT INTO qiita.qiita_user (email, user_level_id,
+                                                   password, name, affiliation,
+                                                   address, phone)
+                     VALUES
+                     ('demo@microbio.me', 4,
+                      '$2a$12$gnUi8Qg.0tvW243v889BhOBhWLIHyIJjjgaG6dxuRJkUM8nXG9Efe',
+                      'Demo', 'Qitta Dev', '1345 Colorado Avenue', '303-492-1984')"""
+            TRN.add(sql)
+            sql = """INSERT INTO qiita.analysis (email, name, description,
+                                                 dflt, analysis_status_id)
+                     VALUES ('demo@microbio.me', 'demo@microbio.me-dflt',
+                             'dflt', 't', 1)
+                     RETURNING analysis_id"""
+            TRN.add(sql)
+            analysis_id = TRN.execute_fetchlast()
+
+            # Add default analysis to all portals
+            sql = "SELECT portal_type_id FROM qiita.portal_type"
+            TRN.add(sql)
+            args = [[analysis_id, p_id] for p_id in TRN.execute_fetchflatten()]
+            sql = """INSERT INTO qiita.analysis_portal
+                        (analysis_id, portal_type_id)
+                     VALUES (%s, %s)"""
+            TRN.add(sql, args, many=True)
+            TRN.execute()
+
+            print('Demo user successfully created')
+
+        if qiita_config.test_environment:
+            _populate_test_db()
+            print('Test environment successfully created')
+        else:
+            print('Production environment successfully created')
 
 
 def drop_environment(ask_for_confirmation):
     """Drops the database specified in the configuration
     """
+    # The transaction has an open connection to the database, so we need
+    # to close it in order to drop the environment
+    TRN.close()
     # Connect to the postgres server
     conn = SQLConnectionHandler()
     settings_sql = "SELECT test FROM settings"
@@ -292,34 +301,31 @@ def drop_environment(ask_for_confirmation):
         print('ABORTING')
 
 
-def drop_and_rebuild_tst_database(conn_handler):
+def drop_and_rebuild_tst_database():
     """Drops the qiita schema and rebuilds the test database
-
-    Parameters
-    ----------
-    conn_handler : SQLConnectionHandler
-        The handler connected to the database
     """
-    # Drop the schema
-    conn_handler.execute("DROP SCHEMA IF EXISTS qiita CASCADE")
-    # Set the database to unpatched
-    conn_handler.execute("UPDATE settings SET current_patch = 'unpatched'")
-    # Create the database and apply patches
-    create_layout_and_patch(conn_handler)
-    # Populate the database
-    with open(POPULATE_FP, 'U') as f:
-        conn_handler.execute(f.read())
+    with TRN:
+        # Drop the schema
+        TRN.add("DROP SCHEMA IF EXISTS qiita CASCADE")
+        # Set the database to unpatched
+        TRN.add("UPDATE settings SET current_patch = 'unpatched'")
+        # Create the database and apply patches
        create_layout_and_patch()
+        # Populate the database
+        with open(POPULATE_FP, 'U') as f:
+            TRN.add(f.read())
+
+        TRN.execute()
 
 
 def reset_test_database(wrapped_fn):
     """Decorator that drops the qiita schema, rebuilds and repopulates the
    schema with test data, then executes wrapped_fn
     """
-    conn_handler = SQLConnectionHandler()
 
     def decorated_wrapped_fn(*args, **kwargs):
         # Reset the test database
-        drop_and_rebuild_tst_database(conn_handler)
+        drop_and_rebuild_tst_database()
         # Execute the wrapped function
         return wrapped_fn(*args, **kwargs)
 
@@ -356,41 +362,40 @@ def patch(patches_dir=PATCHES_DIR, verbose=False):
     Pulls the current patch from the settings table and applies all subsequent
     patches found in the patches directory.
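+
+    A Python patch with the same base name as an SQL patch (for instance a
+    hypothetical 25.sql / python_patches/25.py pair) is run via execfile()
+    right after its SQL counterpart is applied.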
""" - conn = SQLConnectionHandler() + with TRN: + TRN.add("SELECT current_patch FROM settings") + current_patch = TRN.execute_fetchlast() + current_sql_patch_fp = join(patches_dir, current_patch) + corresponding_py_patch = partial(join, patches_dir, 'python_patches') + + sql_glob = join(patches_dir, '*.sql') + sql_patch_files = natsorted(glob(sql_glob)) + + if current_patch == 'unpatched': + next_patch_index = 0 + elif current_sql_patch_fp not in sql_patch_files: + raise RuntimeError("Cannot find patch file %s" % current_patch) + else: + next_patch_index = sql_patch_files.index(current_sql_patch_fp) + 1 - current_patch = conn.execute_fetchone( - "select current_patch from settings")[0] - current_sql_patch_fp = join(patches_dir, current_patch) - corresponding_py_patch = partial(join, patches_dir, 'python_patches') + patch_update_sql = "UPDATE settings SET current_patch = %s" - sql_glob = join(patches_dir, '*.sql') - sql_patch_files = natsorted(glob(sql_glob)) + for sql_patch_fp in sql_patch_files[next_patch_index:]: + sql_patch_filename = basename(sql_patch_fp) + py_patch_fp = corresponding_py_patch( + splitext(basename(sql_patch_fp))[0] + '.py') + py_patch_filename = basename(py_patch_fp) - if current_patch == 'unpatched': - next_patch_index = 0 - elif current_sql_patch_fp not in sql_patch_files: - raise RuntimeError("Cannot find patch file %s" % current_patch) - else: - next_patch_index = sql_patch_files.index(current_sql_patch_fp) + 1 - - patch_update_sql = "update settings set current_patch = %s" - - for sql_patch_fp in sql_patch_files[next_patch_index:]: - sql_patch_filename = basename(sql_patch_fp) - py_patch_fp = corresponding_py_patch( - splitext(basename(sql_patch_fp))[0] + '.py') - py_patch_filename = basename(py_patch_fp) - conn.create_queue(sql_patch_filename) - with open(sql_patch_fp, 'U') as patch_file: - if verbose: - print('\tApplying patch %s...' % sql_patch_filename) - conn.add_to_queue(sql_patch_filename, patch_file.read()) - conn.add_to_queue(sql_patch_filename, patch_update_sql, - [sql_patch_filename]) - - conn.execute_queue(sql_patch_filename) - - if exists(py_patch_fp): - if verbose: - print('\t\tApplying python patch %s...' % py_patch_filename) - execfile(py_patch_fp) + with open(sql_patch_fp, 'U') as patch_file: + if verbose: + print('\tApplying patch %s...' % sql_patch_filename) + TRN.add(patch_file.read()) + TRN.add(patch_update_sql, [sql_patch_filename]) + + TRN.execute() + + if exists(py_patch_fp): + if verbose: + print('\t\tApplying python patch %s...' 
+ % py_patch_filename) + execfile(py_patch_fp) diff --git a/qiita_db/job.py b/qiita_db/job.py index ec040579a..7f7036d97 100644 --- a/qiita_db/job.py +++ b/qiita_db/job.py @@ -26,7 +26,7 @@ # ----------------------------------------------------------------------------- from __future__ import division from json import loads -from os.path import join, relpath +from os.path import join, relpath, isdir from os import remove from glob import glob from shutil import rmtree @@ -36,7 +36,7 @@ from .base import QiitaStatusObject from .util import (insert_filepaths, convert_to_id, get_db_files_base_dir, params_dict_to_json, get_mountpoint) -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN from .logger import LogEntry from .exceptions import QiitaDBStatusError, QiitaDBDuplicateError @@ -60,16 +60,16 @@ class Job(QiitaStatusObject): """ _table = "job" - def _lock_job(self, conn_handler): + def _lock_job(self): """Raises QiitaDBStatusError if study is public""" if self.check_status(("completed", "error")): raise QiitaDBStatusError("Can't change status of finished job!") - def _status_setter_checks(self, conn_handler): + def _status_setter_checks(self): r"""Perform a check to make sure not setting status away from completed or errored """ - self._lock_job(conn_handler) + self._lock_job() @staticmethod def get_commands(): @@ -108,43 +108,56 @@ def exists(cls, datatype, command, options, analysis, If return_existing is True, the Job object of the matching job or None if none exists """ - conn_handler = SQLConnectionHandler() - # check passed arguments and grab analyses for matching jobs - datatype_id = convert_to_id(datatype, "data_type") - sql = "SELECT command_id FROM qiita.command WHERE name = %s" - command_id = conn_handler.execute_fetchone(sql, (command, ))[0] - opts_json = params_dict_to_json(options) - sql = ("SELECT DISTINCT aj.analysis_id, aj.job_id FROM " - "qiita.analysis_job aj JOIN qiita.{0} j ON aj.job_id = j.job_id" - " WHERE j.data_type_id = %s AND j.command_id = %s " - "AND j.options = %s".format(cls._table)) - analyses = conn_handler.execute_fetchall( - sql, (datatype_id, command_id, opts_json)) - if not analyses and return_existing: - # stop looking since we have no possible matches - return False, None - elif not analyses: - return False + with TRN: + # check passed arguments and grab analyses for matching jobs + datatype_id = convert_to_id(datatype, "data_type") + sql = "SELECT command_id FROM qiita.command WHERE name = %s" + TRN.add(sql, [command]) + command_id = TRN.execute_fetchlast() + + opts_json = params_dict_to_json(options) + sql = """SELECT DISTINCT analysis_id, job_id + FROM qiita.analysis_job + JOIN qiita.{0} USING (job_id) + WHERE data_type_id = %s + AND command_id = %s + AND options = %s""".format(cls._table) + TRN.add(sql, [datatype_id, command_id, opts_json]) + analyses = TRN.execute_fetchindex() + + if not analyses and return_existing: + # stop looking since we have no possible matches + return False, None + elif not analyses: + return False + + # build the samples dict as list of samples keyed to + # their proc_data_id + sql = """SELECT processed_data_id, array_agg( + sample_id ORDER BY sample_id) + FROM qiita.analysis_sample + WHERE analysis_id = %s GROUP BY processed_data_id""" + TRN.add(sql, [analysis.id]) + samples = dict(TRN.execute_fetchindex()) + + # check passed analyses' samples dict against all found analyses + matched_job = None + for aid, jid in analyses: + # build the samples dict for a found analysis + TRN.add(sql, [aid]) 
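+
+                # execute_fetchindex() runs the queries queued so far and
+                # returns the rows of the most recently added one, here the
+                # samples of the candidate analysis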
+                comp_samples = dict(TRN.execute_fetchindex())
+
+                # compare samples and stop checking if a match is found
+                matched_samples = samples == comp_samples
+                if matched_samples:
+                    matched_job = jid
+                    break
+
+            if return_existing:
+                return matched_samples, (cls(matched_job) if matched_job
+                                         else None)
 
-        # build the samples dict as list of samples keyed to their proc_data_id
-        sql = ("SELECT processed_data_id, array_agg(sample_id ORDER BY "
-               "sample_id) FROM qiita.analysis_sample WHERE analysis_id = %s "
-               "GROUP BY processed_data_id")
-        samples = dict(conn_handler.execute_fetchall(sql, [analysis.id]))
-        # check passed analyses' samples dict against all found analyses
-        matched_job = None
-        for aid, jid in analyses:
-            # build the samples dict for a found analysis
-            comp_samples = dict(conn_handler.execute_fetchall(sql, [aid]))
-            # compare samples and stop checking if a match is found
-            matched_samples = True if samples == comp_samples else False
-            if matched_samples:
-                matched_job = jid
-                break
-
-        if return_existing:
-            return matched_samples, (cls(matched_job) if matched_job else None)
-        return matched_samples
+            return matched_samples
 
     @classmethod
     def delete(cls, jobid):
@@ -163,36 +176,42 @@ def delete(cls, jobid):
        filepath and job_results_filepath tables. All the job's files on the
        filesystem will also be removed.
        """
-        conn_handler = SQLConnectionHandler()
-        # store filepath info for later use
-        sql = ("SELECT f.filepath, f.filepath_id FROM qiita.filepath f JOIN "
-               "qiita.job_results_filepath jf ON jf.filepath_id = "
-               "f.filepath_id WHERE jf.job_id = %s")
-        filepaths = conn_handler.execute_fetchall(sql, [jobid])
-
-        # remove fiepath links in DB
-        conn_handler.execute("DELETE FROM qiita.job_results_filepath WHERE "
-                             "job_id = %s", [jobid])
-        sql = "DELETE FROM qiita.filepath WHERE"
-        for x in range(len(filepaths)):
-            sql = ' '.join((sql, "filepath_id = %s"))
-        conn_handler.execute(sql, [fp[1] for fp in filepaths])
-
-        # remove job
-        conn_handler.execute("DELETE FROM qiita.analysis_job WHERE "
-                             "job_id = %s", [jobid])
-        conn_handler.execute("DELETE FROM qiita.collection_job WHERE "
-                             "job_id = %s", [jobid])
-        conn_handler.execute("DELETE FROM qiita.job WHERE job_id = %s",
-                             [jobid])
-
-        # remove files/folders attached to job
-        _, basedir = get_mountpoint("job")[0]
-        for fp in filepaths:
-            try:
-                rmtree(join(basedir, fp[0]))
-            except OSError:
-                remove(join(basedir, fp[0]))
+        with TRN:
+            # store filepath info for later use
+            sql = """SELECT filepath, filepath_id
+                     FROM qiita.filepath
+                     JOIN qiita.job_results_filepath USING (filepath_id)
+                     WHERE job_id = %s"""
+            args = [jobid]
+            TRN.add(sql, args)
+            filepaths = TRN.execute_fetchindex()
+
+            # remove filepath links in DB
+            sql = "DELETE FROM qiita.job_results_filepath WHERE job_id = %s"
+            TRN.add(sql, args)
+
+            sql = "DELETE FROM qiita.filepath WHERE filepath_id IN %s"
+            TRN.add(sql, [tuple(fp[1] for fp in filepaths)])
+
+            # remove job
+            sql = "DELETE FROM qiita.analysis_job WHERE job_id = %s"
+            TRN.add(sql, args)
+            sql = "DELETE FROM qiita.collection_job WHERE job_id = %s"
+            TRN.add(sql, args)
+            sql = "DELETE FROM qiita.job WHERE job_id = %s"
+            TRN.add(sql, args)
+
+            TRN.execute()
+
+            # remove files/folders attached to job
+            _, basedir = get_mountpoint("job")[0]
+            path_builder = partial(join, basedir)
+            for fp, _ in filepaths:
+                fp = path_builder(fp)
+                if isdir(fp):
+                    TRN.add_post_commit_func(rmtree, fp)
+                else:
+                    TRN.add_post_commit_func(remove, fp)
 
     @classmethod
     def 
create(cls, datatype, command, options, analysis, return_existing is False and an exact duplicate of the job already exists in the DB. """ - analysis_sql = ("INSERT INTO qiita.analysis_job (analysis_id, job_id) " - "VALUES (%s, %s)") - exists, job = cls.exists(datatype, command, options, analysis, - return_existing=True) - conn_handler = SQLConnectionHandler() - if exists: - if return_existing: - # add job to analysis - conn_handler.execute(analysis_sql, (analysis.id, job.id)) - return job - else: - raise QiitaDBDuplicateError( - "Job", "datatype: %s, command: %s, options: %s, " - "analysis: %s" % (datatype, command, options, analysis.id)) - - # Get the datatype and command ids from the strings - datatype_id = convert_to_id(datatype, "data_type") - sql = "SELECT command_id FROM qiita.command WHERE name = %s" - command_id = conn_handler.execute_fetchone(sql, (command, ))[0] - opts_json = params_dict_to_json(options) - - # Create the job and return it - sql = ("INSERT INTO qiita.{0} (data_type_id, job_status_id, " - "command_id, options) VALUES " - "(%s, %s, %s, %s) RETURNING job_id").format(cls._table) - job_id = conn_handler.execute_fetchone(sql, (datatype_id, 1, - command_id, opts_json))[0] - - # add job to analysis - conn_handler.execute(analysis_sql, (analysis.id, job_id)) - - return cls(job_id) + with TRN: + analysis_sql = """INSERT INTO qiita.analysis_job + (analysis_id, job_id) VALUES (%s, %s)""" + exists, job = cls.exists(datatype, command, options, analysis, + return_existing=True) + + if exists: + if return_existing: + # add job to analysis + TRN.add(analysis_sql, [analysis.id, job.id]) + TRN.execute() + return job + else: + raise QiitaDBDuplicateError( + "Job", "datatype: %s, command: %s, options: %s, " + "analysis: %s" + % (datatype, command, options, analysis.id)) + + # Get the datatype and command ids from the strings + datatype_id = convert_to_id(datatype, "data_type") + sql = "SELECT command_id FROM qiita.command WHERE name = %s" + TRN.add(sql, [command]) + command_id = TRN.execute_fetchlast() + opts_json = params_dict_to_json(options) + + # Create the job and return it + sql = """INSERT INTO qiita.{0} (data_type_id, job_status_id, + command_id, options) + VALUES (%s, %s, %s, %s) + RETURNING job_id""".format(cls._table) + TRN.add(sql, [datatype_id, 1, command_id, opts_json]) + job_id = TRN.execute_fetchlast() + + # add job to analysis + TRN.add(analysis_sql, [analysis.id, job_id]) + TRN.execute() + + return cls(job_id) @property def datatype(self): - sql = ("SELECT data_type from qiita.data_type WHERE data_type_id = " - "(SELECT data_type_id from qiita.{0} WHERE " - "job_id = %s)".format(self._table)) - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone(sql, (self._id, ))[0] + with TRN: + sql = """SELECT data_type + FROM qiita.data_type + WHERE data_type_id = ( + SELECT data_type_id + FROM qiita.{0} + WHERE job_id = %s)""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def command(self): @@ -272,11 +301,16 @@ def command(self): str command run by the job """ - sql = ("SELECT name, command from qiita.command WHERE command_id = " - "(SELECT command_id from qiita.{0} WHERE " - "job_id = %s)".format(self._table)) - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone(sql, (self._id, )) + with TRN: + sql = """SELECT name, command + FROM qiita.command + WHERE command_id = ( + SELECT command_id + FROM qiita.{0} + WHERE job_id = %s)""".format(self._table) + TRN.add(sql, [self._id]) + # We 
only want the first row (the only one present) + return TRN.execute_fetchindex()[0] @property def options(self): @@ -287,21 +321,29 @@ def options(self): dict options in the format {option: setting} """ - sql = ("SELECT options FROM qiita.{0} WHERE " - "job_id = %s".format(self._table)) - conn_handler = SQLConnectionHandler() - db_opts = conn_handler.execute_fetchone(sql, (self._id, ))[0] - opts = loads(db_opts) if db_opts else {} - sql = ("SELECT command, output from qiita.command WHERE command_id = (" - "SELECT command_id from qiita.{0} WHERE " - "job_id = %s)".format(self._table)) - db_comm = conn_handler.execute_fetchone(sql, (self._id, )) - out_opt = loads(db_comm[1]) - basedir = get_db_files_base_dir() - join_f = partial(join, join(basedir, "job")) - for k in out_opt: - opts[k] = join_f("%s_%s_%s" % (self._id, db_comm[0], k.strip("-"))) - return opts + with TRN: + sql = """SELECT options FROM qiita.{0} + WHERE job_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + db_opts = TRN.execute_fetchlast() + opts = loads(db_opts) if db_opts else {} + + sql = """SELECT command, output + FROM qiita.command + WHERE command_id = ( + SELECT command_id + FROM qiita.{0} + WHERE job_id = %s)""".format(self._table) + TRN.add(sql, [self._id]) + db_comm = TRN.execute_fetchindex()[0] + + out_opt = loads(db_comm[1]) + basedir = get_db_files_base_dir() + join_f = partial(join, join(basedir, "job")) + for k in out_opt: + opts[k] = join_f("%s_%s_%s" % (self._id, db_comm[0], + k.strip("-"))) + return opts @options.setter def options(self, opts): @@ -312,16 +354,17 @@ def options(self, opts): opts: dict The options for the command in format {option: value} """ - conn_handler = SQLConnectionHandler() - # make sure job is editable - self._lock_job(conn_handler) - - # JSON the options dictionary - opts_json = params_dict_to_json(opts) - # Add the options to the job - sql = ("UPDATE qiita.{0} SET options = %s WHERE " - "job_id = %s").format(self._table) - conn_handler.execute(sql, (opts_json, self._id)) + with TRN: + # make sure job is editable + self._lock_job() + + # JSON the options dictionary + opts_json = params_dict_to_json(opts) + # Add the options to the job + sql = """UPDATE qiita.{0} SET options = %s + WHERE job_id = %s""".format(self._table) + TRN.add(sql, [opts_json, self._id]) + TRN.execute() @property def results(self): @@ -333,33 +376,34 @@ def results(self): Filepaths to the result files """ # Select results filepaths and filepath types from the database - conn_handler = SQLConnectionHandler() - _, basedir = get_mountpoint('job')[0] - results = conn_handler.execute_fetchall( - "SELECT fp.filepath, fpt.filepath_type FROM qiita.filepath fp " - "JOIN qiita.filepath_type fpt ON fp.filepath_type_id = " - "fpt.filepath_type_id JOIN qiita.job_results_filepath jrfp ON " - "fp.filepath_id = jrfp.filepath_id WHERE jrfp.job_id = %s", - (self._id, )) - - def add_html(basedir, check_dir, result_fps): - for res in glob(join(basedir, check_dir, "*.htm")) + \ - glob(join(basedir, check_dir, "*.html")): - result_fps.append(relpath(res, basedir)) - - # create new list, with relative paths from db base - result_fps = [] - for fp in results: - if fp[1] == "directory": - # directory, so all html files in it are results - # first, see if we have any in the main directory - add_html(basedir, fp[0], result_fps) - # now do all subdirectories - add_html(basedir, join(fp[0], "*"), result_fps) - else: - # result is exact filepath given - result_fps.append(fp[0]) - return result_fps + with TRN: + _, basedir = 
get_mountpoint('job')[0] + sql = """SELECT filepath, filepath_type + FROM qiita.filepath + JOIN qiita.filepath_type USING (filepath_type_id) + JOIN qiita.job_results_filepath USING (filepath_id) + WHERE job_id = %s""" + TRN.add(sql, [self._id]) + results = TRN.execute_fetchindex() + + def add_html(basedir, check_dir, result_fps): + for res in glob(join(basedir, check_dir, "*.htm")) + \ + glob(join(basedir, check_dir, "*.html")): + result_fps.append(relpath(res, basedir)) + + # create new list, with relative paths from db base + result_fps = [] + for fp in results: + if fp[1] == "directory": + # directory, so all html files in it are results + # first, see if we have any in the main directory + add_html(basedir, fp[0], result_fps) + # now do all subdirectories + add_html(basedir, join(fp[0], "*"), result_fps) + else: + # result is exact filepath given + result_fps.append(fp[0]) + return result_fps @property def error(self): @@ -370,16 +414,12 @@ def error(self): str or None error message/traceback for a job, or None if none exists """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT log_id FROM qiita.{0} " - "WHERE job_id = %s".format(self._table)) - logging_id = conn_handler.execute_fetchone(sql, (self._id, ))[0] - if logging_id is None: - ret = None - else: - ret = LogEntry(logging_id) - - return ret + with TRN: + sql = "SELECT log_id FROM qiita.{0} WHERE job_id = %s".format( + self._table) + TRN.add(sql, [self._id]) + logging_id = TRN.execute_fetchlast() + return LogEntry(logging_id) if logging_id is not None else None # --- Functions --- def set_error(self, msg): @@ -390,18 +430,17 @@ def set_error(self, msg): msg : str Error message/stacktrace if available """ - conn_handler = SQLConnectionHandler() - log_entry = LogEntry.create('Runtime', msg, - info={'job': self._id}) - self._lock_job(conn_handler) - err_id = conn_handler.execute_fetchone( - "SELECT job_status_id FROM qiita.job_status WHERE " - "status = 'error'")[0] - # attach the error to the job and set to error - sql = ("UPDATE qiita.{0} SET log_id = %s, job_status_id = %s WHERE " - "job_id = %s".format(self._table)) - - conn_handler.execute(sql, (log_entry.id, err_id, self._id)) + with TRN: + log_entry = LogEntry.create('Runtime', msg, + info={'job': self._id}) + self._lock_job() + + err_id = convert_to_id('error', 'job_status', 'status') + # attach the error to the job and set to error + sql = """UPDATE qiita.{0} SET log_id = %s, job_status_id = %s + WHERE job_id = %s""".format(self._table) + TRN.add(sql, [log_entry.id, err_id, self._id]) + TRN.execute() def add_results(self, results): """Adds a list of results to the results @@ -418,19 +457,20 @@ def add_results(self, results): Curently available file types are: biom, directory, plain_text """ - # add filepaths to the job - conn_handler = SQLConnectionHandler() - self._lock_job(conn_handler) - # convert all file type text to file type ids - res_ids = [(fp, convert_to_id(fptype, "filepath_type")) - for fp, fptype in results] - file_ids = insert_filepaths(res_ids, self._id, self._table, - "filepath", conn_handler, move_files=False) - - # associate filepaths with job - sql = ("INSERT INTO qiita.{0}_results_filepath (job_id, filepath_id) " - "VALUES (%s, %s)".format(self._table)) - conn_handler.executemany(sql, [(self._id, fid) for fid in file_ids]) + with TRN: + self._lock_job() + # convert all file type text to file type ids + res_ids = [(fp, convert_to_id(fptype, "filepath_type")) + for fp, fptype in results] + file_ids = insert_filepaths(res_ids, self._id, self._table, 
+ "filepath", move_files=False) + + # associate filepaths with job + sql = """INSERT INTO qiita.{0}_results_filepath + (job_id, filepath_id) + VALUES (%s, %s)""".format(self._table) + TRN.add(sql, [[self._id, fid] for fid in file_ids], many=True) + TRN.execute() class Command(object): @@ -457,11 +497,12 @@ def create_list(cls): ------- list of Command objects """ - conn_handler = SQLConnectionHandler() - commands = conn_handler.execute_fetchall("SELECT * FROM qiita.command") - # create the list of command objects - return [cls(c["name"], c["command"], c["input"], c["required"], - c["optional"], c["output"]) for c in commands] + with TRN: + TRN.add("SELECT * FROM qiita.command") + commands = TRN.execute_fetchindex() + # create the list of command objects + return [cls(c["name"], c["command"], c["input"], c["required"], + c["optional"], c["output"]) for c in commands] @classmethod def get_commands_by_datatype(cls, datatypes=None): @@ -482,28 +523,32 @@ def get_commands_by_datatype(cls, datatypes=None): If no datatypes are passed, the function will default to returning all datatypes available. """ - conn_handler = SQLConnectionHandler() - # get the ids of the datatypes to get commands for - if datatypes is not None: - datatype_info = [(convert_to_id(dt, "data_type"), dt) - for dt in datatypes] - else: - datatype_info = conn_handler.execute_fetchall( - "SELECT data_type_id, data_type from qiita.data_type") - - commands = defaultdict(list) - # get commands for each datatype - sql = ("SELECT C.* FROM qiita.command C JOIN qiita.command_data_type " - "CD on C.command_id = CD.command_id WHERE CD.data_type_id = %s") - for dt_id, dt in datatype_info: - comms = conn_handler.execute_fetchall(sql, (dt_id, )) - for comm in comms: - commands[dt].append(cls(comm["name"], comm["command"], - comm["input"], - comm["required"], - comm["optional"], - comm["output"])) - return commands + with TRN: + # get the ids of the datatypes to get commands for + if datatypes is not None: + datatype_info = [(convert_to_id(dt, "data_type"), dt) + for dt in datatypes] + else: + sql = "SELECT data_type_id, data_type from qiita.data_type" + TRN.add(sql) + datatype_info = TRN.execute_fetchindex() + + commands = defaultdict(list) + # get commands for each datatype + sql = """SELECT C.* + FROM qiita.command C + JOIN qiita.command_data_type USING (command_id) + WHERE data_type_id = %s""" + for dt_id, dt in datatype_info: + TRN.add(sql, [dt_id]) + comms = TRN.execute_fetchindex() + for comm in comms: + commands[dt].append(cls(comm["name"], comm["command"], + comm["input"], + comm["required"], + comm["optional"], + comm["output"])) + return commands def __eq__(self, other): if type(self) != type(other): diff --git a/qiita_db/logger.py b/qiita_db/logger.py index 7a77e4212..6c0593ad3 100644 --- a/qiita_db/logger.py +++ b/qiita_db/logger.py @@ -28,7 +28,7 @@ from json import loads, dumps from qiita_db.util import convert_to_id -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN from .base import QiitaObject @@ -63,13 +63,13 @@ def newest_records(cls, numrecords=100): list of LogEntry objects list of the log entries """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT logging_id FROM qiita.{0} ORDER BY logging_id DESC " - "LIMIT %s".format(cls._table)) - ids = [x[0] - for x in conn_handler.execute_fetchall(sql, (numrecords, ))] + with TRN: + sql = """SELECT logging_id + FROM qiita.{0} + ORDER BY logging_id DESC LIMIT %s""".format(cls._table) + TRN.add(sql, [numrecords]) - return [cls(i) for i in 
ids] + return [cls(i) for i in TRN.execute_fetchflatten()] @classmethod def create(cls, severity, msg, info=None): @@ -97,14 +97,14 @@ def create(cls, severity, msg, info=None): info = dumps([info]) - conn_handler = SQLConnectionHandler() - sql = ("INSERT INTO qiita.{} (time, severity_id, msg, information) " - "VALUES (NOW(), %s, %s, %s) " - "RETURNING logging_id".format(cls._table)) - severity_id = convert_to_id(severity, "severity") - id_ = conn_handler.execute_fetchone(sql, (severity_id, msg, info))[0] + with TRN: + sql = """INSERT INTO qiita.{} (time, severity_id, msg, information) + VALUES (NOW(), %s, %s, %s) + RETURNING logging_id""".format(cls._table) + severity_id = convert_to_id(severity, "severity") + TRN.add(sql, [severity_id, msg, info]) - return cls(id_) + return cls(TRN.execute_fetchlast()) @property def severity(self): @@ -115,11 +115,11 @@ def severity(self): int This is a key to the SEVERITY table """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT severity_id FROM qiita.{} WHERE " - "logging_id = %s".format(self._table)) - - return conn_handler.execute_fetchone(sql, (self.id,))[0] + with TRN: + sql = """SELECT severity_id FROM qiita.{} + WHERE logging_id = %s""".format(self._table) + TRN.add(sql, [self.id]) + return TRN.execute_fetchlast() @property def time(self): @@ -129,12 +129,12 @@ def time(self): ------- datetime """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT time FROM qiita.{} " - "WHERE logging_id = %s".format(self._table)) - timestamp = conn_handler.execute_fetchone(sql, (self.id,))[0] + with TRN: + sql = "SELECT time FROM qiita.{} WHERE logging_id = %s".format( + self._table) + TRN.add(sql, [self.id]) - return timestamp + return TRN.execute_fetchlast() @property def info(self): @@ -153,12 +153,12 @@ def info(self): - When `info` is added, keys can be of any type, but upon retrieval, they will be of type str """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT information FROM qiita.{} " - "WHERE logging_id = %s".format(self._table)) - info = conn_handler.execute_fetchone(sql, (self.id,))[0] + with TRN: + sql = """SELECT information FROM qiita.{} WHERE + logging_id = %s""".format(self._table) + TRN.add(sql, [self.id]) - return loads(info) + return loads(TRN.execute_fetchlast()) @property def msg(self): @@ -168,21 +168,22 @@ def msg(self): ------- str """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT msg FROM qiita.{0} " - "WHERE logging_id = %s".format(self._table)) + with TRN: + sql = "SELECT msg FROM qiita.{0} WHERE logging_id = %s".format( + self._table) + TRN.add(sql, [self.id]) - return conn_handler.execute_fetchone(sql, (self.id,))[0] + return TRN.execute_fetchlast() def clear_info(self): """Resets the list of info dicts to be an empty list """ - conn_handler = SQLConnectionHandler() - sql = ("UPDATE qiita.{} set information = %s " - "WHERE logging_id = %s".format(self._table)) - new_info = dumps([]) + with TRN: + sql = """UPDATE qiita.{} SET information = %s + WHERE logging_id = %s""".format(self._table) + TRN.add(sql, [dumps([]), self.id]) - conn_handler.execute(sql, (new_info, self.id)) + TRN.execute() def add_info(self, info): """Adds new information to the info associated with this LogEntry @@ -197,11 +198,12 @@ def add_info(self, info): - When `info` is added, keys can be of any type, but upon retrieval, they will be of type str """ - conn_handler = SQLConnectionHandler() - current_info = self.info - current_info.append(info) - new_info = dumps(current_info) - - sql = ("UPDATE qiita.{} SET information = %s " - 
"WHERE logging_id = %s".format(self._table)) - conn_handler.execute(sql, (new_info, self.id)) + with TRN: + current_info = self.info + current_info.append(info) + new_info = dumps(current_info) + + sql = """UPDATE qiita.{} SET information = %s + WHERE logging_id = %s""".format(self._table) + TRN.add(sql, [new_info, self.id]) + TRN.execute() diff --git a/qiita_db/meta_util.py b/qiita_db/meta_util.py index 225b2f3c0..578367c76 100644 --- a/qiita_db/meta_util.py +++ b/qiita_db/meta_util.py @@ -27,7 +27,7 @@ from .study import Study from .data import RawData, PreprocessedData, ProcessedData from .analysis import Analysis -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN from .metadata_template import PrepTemplate, SampleTemplate @@ -45,8 +45,9 @@ def _get_data_fpids(constructor, object_id): ------- set of int """ - obj = constructor(object_id) - return {fpid for fpid, _, _ in obj.get_filepaths()} + with TRN: + obj = constructor(object_id) + return {fpid for fpid, _, _ in obj.get_filepaths()} def get_accessible_filepath_ids(user): @@ -72,93 +73,93 @@ def get_accessible_filepath_ids(user): Admins have access to all files, so all filepath ids are returned for admins """ - - if user.level == "admin": - # admins have access all files - conn_handler = SQLConnectionHandler() - fpids = conn_handler.execute_fetchall("SELECT filepath_id FROM " - "qiita.filepath") - return set(f[0] for f in fpids) - - # First, the studies - # There are private and shared studies - study_ids = user.user_studies | user.shared_studies - - filepath_ids = set() - for study_id in study_ids: - study = Study(study_id) - - # For each study, there are raw, preprocessed, and processed filepaths - raw_data_ids = study.raw_data() - preprocessed_data_ids = study.preprocessed_data() - processed_data_ids = study.processed_data() - - constructor_data_ids = ((RawData, raw_data_ids), - (PreprocessedData, preprocessed_data_ids), - (ProcessedData, processed_data_ids)) - - for constructor, data_ids in constructor_data_ids: - for data_id in data_ids: - filepath_ids.update(_get_data_fpids(constructor, data_id)) - - # adding prep and sample templates - prep_fp_ids = [] - for rdid in study.raw_data(): - for pt_id in RawData(rdid).prep_templates: - # related to https://github.com/biocore/qiita/issues/596 - if PrepTemplate.exists(pt_id): - for _id, _ in PrepTemplate(pt_id).get_filepaths(): - prep_fp_ids.append(_id) - filepath_ids.update(prep_fp_ids) - - if SampleTemplate.exists(study_id): - sample_fp_ids = [_id for _id, _ - in SampleTemplate(study_id).get_filepaths()] - filepath_ids.update(sample_fp_ids) - - # Next, the public processed data - processed_data_ids = ProcessedData.get_by_status('public') - for pd_id in processed_data_ids: - processed_data = ProcessedData(pd_id) - - # Add the filepaths of the processed data - pd_fps = (fpid for fpid, _, _ in processed_data.get_filepaths()) - filepath_ids.update(pd_fps) - - # Each processed data has a preprocessed data - ppd = PreprocessedData(processed_data.preprocessed_data) - ppd_fps = (fpid for fpid, _, _ in ppd.get_filepaths()) - filepath_ids.update(ppd_fps) - - # Each preprocessed data has a prep template - pt_id = ppd.prep_template - # related to https://github.com/biocore/qiita/issues/596 - if PrepTemplate.exists(pt_id): - pt = PrepTemplate(pt_id) - pt_fps = (fpid for fpid, _ in pt.get_filepaths()) - filepath_ids.update(pt_fps) - - # Each prep template has a raw data - rd = RawData(pt.raw_data) - rd_fps = (fpid for fpid, _, _ in rd.get_filepaths()) - 
filepath_ids.update(rd_fps)
-
-        # And each processed data has a study, which has a sample template
-        st_id = processed_data.study
-        if SampleTemplate.exists(st_id):
-            sample_fp_ids = (_id for _id, _
-                             in SampleTemplate(st_id).get_filepaths())
-            filepath_ids.update(sample_fp_ids)
-
-    # Next, analyses
-    # Same as before, there are public, private, and shared
-    analysis_ids = Analysis.get_by_status('public') | user.private_analyses | \
-        user.shared_analyses
-
-    for analysis_id in analysis_ids:
-        analysis = Analysis(analysis_id)
-
-        # For each analysis, there are mapping, biom, and job result filepaths
-        filepath_ids.update(analysis.all_associated_filepath_ids)
-
-    return filepath_ids
+    with TRN:
+        if user.level == "admin":
+            # admins have access to all files
+            TRN.add("SELECT filepath_id FROM qiita.filepath")
+            return set(TRN.execute_fetchflatten())
+
+        # First, the studies
+        # There are private and shared studies
+        study_ids = user.user_studies | user.shared_studies
+
+        filepath_ids = set()
+        for study_id in study_ids:
+            study = Study(study_id)
+
+            # For each study, there are raw, preprocessed, and
+            # processed filepaths
+            raw_data_ids = study.raw_data()
+            preprocessed_data_ids = study.preprocessed_data()
+            processed_data_ids = study.processed_data()
+
+            constructor_data_ids = ((RawData, raw_data_ids),
+                                    (PreprocessedData, preprocessed_data_ids),
+                                    (ProcessedData, processed_data_ids))
+
+            for constructor, data_ids in constructor_data_ids:
+                for data_id in data_ids:
+                    filepath_ids.update(_get_data_fpids(constructor, data_id))
+
+            # adding prep and sample templates
+            prep_fp_ids = []
+            for rdid in study.raw_data():
+                for pt_id in RawData(rdid).prep_templates:
+                    # related to https://github.com/biocore/qiita/issues/596
+                    if PrepTemplate.exists(pt_id):
+                        for _id, _ in PrepTemplate(pt_id).get_filepaths():
+                            prep_fp_ids.append(_id)
+            filepath_ids.update(prep_fp_ids)
+
+            if SampleTemplate.exists(study_id):
+                sample_fp_ids = [_id for _id, _
+                                 in SampleTemplate(study_id).get_filepaths()]
+                filepath_ids.update(sample_fp_ids)
+
+        # Next, the public processed data
+        processed_data_ids = ProcessedData.get_by_status('public')
+        for pd_id in processed_data_ids:
+            processed_data = ProcessedData(pd_id)
+
+            # Add the filepaths of the processed data
+            pd_fps = (fpid for fpid, _, _ in processed_data.get_filepaths())
+            filepath_ids.update(pd_fps)
+
+            # Each processed data has a preprocessed data
+            ppd = PreprocessedData(processed_data.preprocessed_data)
+            ppd_fps = (fpid for fpid, _, _ in ppd.get_filepaths())
+            filepath_ids.update(ppd_fps)
+
+            # Each preprocessed data has a prep template
+            pt_id = ppd.prep_template
+            # related to https://github.com/biocore/qiita/issues/596
+            if PrepTemplate.exists(pt_id):
+                pt = PrepTemplate(pt_id)
+                pt_fps = (fpid for fpid, _ in pt.get_filepaths())
+                filepath_ids.update(pt_fps)
+
+                # Each prep template has a raw data
+                rd = RawData(pt.raw_data)
+                rd_fps = (fpid for fpid, _, _ in rd.get_filepaths())
+                filepath_ids.update(rd_fps)
+
+            # And each processed data has a study, which has a sample template
+            st_id = processed_data.study
+            if SampleTemplate.exists(st_id):
+                sample_fp_ids = (_id for _id, _
+                                 in SampleTemplate(st_id).get_filepaths())
+                filepath_ids.update(sample_fp_ids)
+
+        # Next, analyses
+        # Same as before, there are public, private, and shared
+        analysis_ids = Analysis.get_by_status('public') | \
+            user.private_analyses | user.shared_analyses
+
+        for analysis_id in analysis_ids:
+            analysis = Analysis(analysis_id)
+
+            # For each analysis, there are mapping, biom, and job 
result + # filepaths + filepath_ids.update(analysis.all_associated_filepath_ids) + + return filepath_ids diff --git a/qiita_db/metadata_template/base_metadata_template.py b/qiita_db/metadata_template/base_metadata_template.py index eceecd984..4f5435710 100644 --- a/qiita_db/metadata_template/base_metadata_template.py +++ b/qiita_db/metadata_template/base_metadata_template.py @@ -52,7 +52,7 @@ QiitaDBNotImplementedError, QiitaDBError, QiitaDBWarning, QiitaDBDuplicateHeaderError) from qiita_db.base import QiitaObject -from qiita_db.sql_connection import SQLConnectionHandler +from qiita_db.sql_connection import TRN from qiita_db.util import (exists_table, get_table_cols, get_mountpoint, insert_filepaths) from qiita_db.logger import LogEntry @@ -172,21 +172,18 @@ def exists(cls, sample_id, md_template): bool True if already exists. False otherwise. """ - cls._check_subclass() - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT EXISTS(SELECT * FROM qiita.{0} WHERE sample_id=%s AND " - "{1}=%s)".format(cls._table, cls._id_column), - (sample_id, md_template.id))[0] - - def _get_categories(self, conn_handler): + with TRN: + cls._check_subclass() + sql = """SELECT EXISTS( + SELECT * FROM qiita.{0} + WHERE sample_id=%s AND {1}=%s + )""".format(cls._table, cls._id_column) + TRN.add(sql, [sample_id, md_template.id]) + return TRN.execute_fetchlast() + + def _get_categories(self): r"""Returns all the available metadata categories for the sample - Parameters - ---------- - conn_handler : SQLConnectionHandler - The connection handler object connected to the DB - Returns ------- set of str @@ -197,7 +194,6 @@ def _get_categories(self, conn_handler): # Remove the sample_id column as this column is used internally for # data storage and it doesn't actually belong to the metadata cols.remove('sample_id') - return set(cols) def _to_dict(self): @@ -208,16 +204,16 @@ def _to_dict(self): dict of {str: str} A dictionary of the form {category: value} """ - conn_handler = SQLConnectionHandler() - d = dict(conn_handler.execute_fetchone( - "SELECT * from qiita.{0} WHERE " - "sample_id=%s".format(self._dynamic_table), - (self._id, ))) + with TRN: + sql = "SELECT * FROM qiita.{0} WHERE sample_id=%s".format( + self._dynamic_table) + TRN.add(sql, [self._id]) + d = dict(TRN.execute_fetchindex()[0]) - # Remove the sample_id, is not part of the metadata - del d['sample_id'] + # Remove the sample_id, is not part of the metadata + del d['sample_id'] - return d + return d def __len__(self): r"""Returns the number of metadata categories @@ -227,9 +223,8 @@ def __len__(self): int The number of metadata categories """ - conn_handler = SQLConnectionHandler() # return the number of columns - return len(self._get_categories(conn_handler)) + return len(self._get_categories()) def __getitem__(self, key): r"""Returns the value of the metadata category `key` @@ -253,21 +248,21 @@ def __getitem__(self, key): -------- get """ - conn_handler = SQLConnectionHandler() - key = key.lower() - if key not in self._get_categories(conn_handler): - # The key is not available for the sample, so raise a KeyError - raise KeyError("Metadata category %s does not exists for sample %s" - " in template %d" % - (key, self._id, self._md_template.id)) - - sql = """SELECT {0} FROM qiita.{1} - WHERE sample_id=%s""".format(key, self._dynamic_table) - - return conn_handler.execute_fetchone(sql, (self._id, ))[0] - - def add_setitem_queries(self, column, value, conn_handler, queue): - """Adds the SQL queries needed to set a value 
to the provided queue
+        with TRN:
+            key = key.lower()
+            if key not in self._get_categories():
+                # The key is not available for the sample, so raise a KeyError
+                raise KeyError(
+                    "Metadata category %s does not exist for sample %s"
+                    " in template %d" % (key, self._id, self._md_template.id))
+
+            sql = """SELECT {0} FROM qiita.{1}
+                     WHERE sample_id=%s""".format(key, self._dynamic_table)
+            TRN.add(sql, [self._id])
+            return TRN.execute_fetchlast()
+
+    def setitem(self, column, value):
+        """Sets `value` as value for the given `column`
 
         Parameters
         ----------
@@ -276,26 +271,22 @@ def add_setitem_queries(self, column, value, conn_handler, queue):
         value : str
             The value to set. This is expected to be a str on the assumption
             that psycopg2 will cast as necessary when updating.
-        conn_handler : SQLConnectionHandler
-            The connection handler object connected to the DB
-        queue : str
-            The queue where the SQL statements will be added
 
         Raises
        ------
         QiitaDBColumnError
             If the column does not exist in the table
         """
-        # Check if the column exist in the table
-        if column not in self._get_categories(conn_handler):
-            raise QiitaDBColumnError("Column %s does not exist in %s" %
-                                     (column, self._dynamic_table))
-
-        sql = """UPDATE qiita.{0}
-                 SET {1}=%s
-                 WHERE sample_id=%s""".format(self._dynamic_table, column)
+        with TRN:
+            # Check if the column exists in the table
+            if column not in self._get_categories():
+                raise QiitaDBColumnError("Column %s does not exist in %s" %
+                                         (column, self._dynamic_table))
 
-        conn_handler.add_to_queue(queue, sql, (value, self._id))
+            sql = """UPDATE qiita.{0}
+                     SET {1}=%s
+                     WHERE sample_id=%s""".format(self._dynamic_table, column)
+            TRN.add(sql, [value, self._id])
 
     def __setitem__(self, column, value):
         r"""Sets the metadata value for the category `column`
@@ -313,36 +304,33 @@ def __setitem__(self, column, value):
         ValueError
             If the value type does not match the one in the DB
         """
-        conn_handler = SQLConnectionHandler()
-        queue_name = "set_item_%s" % self._id
-        conn_handler.create_queue(queue_name)
-
-        self.add_setitem_queries(column, value, conn_handler, queue_name)
-
-        try:
-            conn_handler.execute_queue(queue_name)
-        except ValueError as e:
-            # catching error so we can check if the error is due to different
-            # column type or something else
-            value_type = type_lookup(type(value))
-
-            sql = """SELECT udt_name
-                     FROM information_schema.columns
-                     WHERE column_name = %s
-                     AND table_schema = 'qiita'
-                     AND (table_name = %s OR table_name = %s)"""
-            column_type = conn_handler.execute_fetchone(
-                sql, (column, self._table, self._dynamic_table))
-
-            if column_type != value_type:
-                raise ValueError(
-                    'The new value being added to column: "{0}" is "{1}" '
-                    '(type: "{2}"). However, this column in the DB is of '
-                    'type "{3}". Please change the value in your updated '
-                    'template or reprocess your template.'.format(
-                        column, value, value_type, column_type))
-
-            raise e
+        with TRN:
+            self.setitem(column, value)
+
+            try:
+                TRN.execute()
+            except ValueError as e:
+                # catching error so we can check if the error is due to
+                # different column type or something else
+                value_type = type_lookup(type(value))
+
+                sql = """SELECT udt_name
+                         FROM information_schema.columns
+                         WHERE column_name = %s
+                         AND table_schema = 'qiita'
+                         AND (table_name = %s OR table_name = %s)"""
+                TRN.add(sql, [column, self._table, self._dynamic_table])
+                column_type = TRN.execute_fetchlast()
+
+                if column_type != value_type:
+                    raise ValueError(
+                        'The new value being added to column: "{0}" is "{1}" '
+                        '(type: "{2}"). 
However, this column in the DB is of ' + 'type "{3}". Please change the value in your updated ' + 'template or reprocess your template.'.format( + column, value, value_type, column_type)) + + raise e def __delitem__(self, key): r"""Removes the sample with sample id `key` from the database @@ -366,8 +354,7 @@ def __iter__(self): -------- keys """ - conn_handler = SQLConnectionHandler() - return iter(self._get_categories(conn_handler)) + return iter(self._get_categories()) def __contains__(self, key): r"""Checks if the metadata category `key` is present @@ -382,8 +369,7 @@ def __contains__(self, key): bool True if the metadata category `key` is present, false otherwise """ - conn_handler = SQLConnectionHandler() - return key.lower() in self._get_categories(conn_handler) + return key.lower() in self._get_categories() def keys(self): r"""Iterator over the metadata categories @@ -486,14 +472,11 @@ class MetadataTemplate(QiitaObject): def _check_id(self, id_): r"""Checks that the MetadataTemplate id_ exists on the database""" - self._check_subclass() - - conn_handler = SQLConnectionHandler() - - return conn_handler.execute_fetchone( - "SELECT EXISTS(SELECT * FROM qiita.{0} WHERE " - "{1}=%s)".format(self._table, self._id_column), - (id_, ))[0] + with TRN: + sql = "SELECT EXISTS(SELECT * FROM qiita.{0} WHERE {1}=%s)".format( + self._table, self._id_column) + TRN.add(sql, [id_]) + return TRN.execute_fetchlast() @classmethod def _table_name(cls, obj_id): @@ -587,9 +570,8 @@ def _clean_validate_template(cls, md_template, study_id, restriction_dict): return md_template @classmethod - def _add_common_creation_steps_to_queue(cls, md_template, obj_id, - conn_handler, queue_name): - r"""Adds the common creation steps to the queue in conn_handler + def _common_creation_steps(cls, md_template, obj_id): + r"""Executes the common creation steps Parameters ---------- @@ -597,158 +579,159 @@ def _add_common_creation_steps_to_queue(cls, md_template, obj_id, The metadata template file contents indexed by sample ids obj_id : int The id of the object being created - conn_handler : SQLConnectionHandler - The connection handler object connected to the DB - queue_name : str - The queue where the SQL statements will be added """ - cls._check_subclass() + with TRN: + cls._check_subclass() + + # Get some useful information from the metadata template + sample_ids = md_template.index.tolist() + headers = sorted(md_template.keys().tolist()) + + # Insert values on template_sample table + values = [[obj_id, s_id] for s_id in sample_ids] + sql = """INSERT INTO qiita.{0} ({1}, sample_id) + VALUES (%s, %s)""".format(cls._table, cls._id_column) + TRN.add(sql, values, many=True) + + # Insert rows on *_columns table + datatypes = get_datatypes(md_template.ix[:, headers]) + # psycopg2 requires a list of tuples, in which each tuple is a set + # of values to use in the string formatting of the query. We have + # all the values in different lists (but in the same order) so use + # zip to create the list of tuples that psycopg2 requires. 
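# --- Editor's note (illustrative sketch, not part of the patch) ---
# The zip idiom used throughout this method transposes parallel
# per-column lists into one parameter row per execution, which is the
# shape TRN.add(..., many=True) expects. With hypothetical values:
#
#     obj_id    = 2
#     headers   = ['barcode', 'platform']
#     datatypes = ['varchar', 'varchar']
#     values = [[obj_id, h, d] for h, d in zip(headers, datatypes)]
#     # -> [[2, 'barcode', 'varchar'], [2, 'platform', 'varchar']]
#
# Each inner list binds the three %s placeholders of the INSERT once.
# -------------------------------------------------------------------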
+ values = [[obj_id, h, d] for h, d in zip(headers, datatypes)] + sql = """INSERT INTO qiita.{0} ({1}, column_name, column_type) + VALUES (%s, %s, %s)""".format(cls._column_table, + cls._id_column) + TRN.add(sql, values, many=True) + + # Create table with custom columns + table_name = cls._table_name(obj_id) + column_datatype = ["%s %s" % (col, dtype) + for col, dtype in zip(headers, datatypes)] + sql = """CREATE TABLE qiita.{0} ( + sample_id varchar NOT NULL, {1}, + CONSTRAINT fk_{0} FOREIGN KEY (sample_id) + REFERENCES qiita.study_sample (sample_id) + ON UPDATE CASCADE + )""".format(table_name, ', '.join(column_datatype)) + TRN.add(sql) + + # Insert values on custom table + values = as_python_types(md_template, headers) + values.insert(0, sample_ids) + values = [list(v) for v in zip(*values)] + sql = """INSERT INTO qiita.{0} (sample_id, {1}) + VALUES (%s, {2})""".format( + table_name, ", ".join(headers), + ', '.join(["%s"] * len(headers))) + TRN.add(sql, values, many=True) + + # Execute all the steps + TRN.execute() - # Get some useful information from the metadata template - sample_ids = md_template.index.tolist() - headers = sorted(md_template.keys().tolist()) - - # Insert values on template_sample table - values = [(obj_id, s_id) for s_id in sample_ids] - sql = "INSERT INTO qiita.{0} ({1}, sample_id) VALUES (%s, %s)".format( - cls._table, cls._id_column) - conn_handler.add_to_queue(queue_name, sql, values, many=True) - - # Insert rows on *_columns table - datatypes = get_datatypes(md_template.ix[:, headers]) - # psycopg2 requires a list of tuples, in which each tuple is a set - # of values to use in the string formatting of the query. We have all - # the values in different lists (but in the same order) so use zip - # to create the list of tuples that psycopg2 requires. 
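# --- Editor's note (illustrative sketch, not part of the patch) ---
# The removed lines below show the queue API this patch retires:
# create_queue(), add_to_queue(), execute_queue(). The TRN replacement
# above reduces that to the following shape, assuming (as the rest of
# the patch implies) that nested `with TRN:` blocks join one enclosing
# transaction:
#
#     with TRN:
#         TRN.add(sql_single, params)          # one statement
#         TRN.add(sql_bulk, rows, many=True)   # one row per execution
#         TRN.execute()  # all statements commit or roll back together
#
# -------------------------------------------------------------------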
- values = [(obj_id, h, d) for h, d in zip(headers, datatypes)] - sql = ("INSERT INTO qiita.{0} ({1}, column_name, column_type) " - "VALUES (%s, %s, %s)").format(cls._column_table, cls._id_column) - conn_handler.add_to_queue(queue_name, sql, values, many=True) - - # Create table with custom columns - table_name = cls._table_name(obj_id) - column_datatype = ["%s %s" % (col, dtype) - for col, dtype in zip(headers, datatypes)] - conn_handler.add_to_queue( - queue_name, - "CREATE TABLE qiita.{0} (" - "sample_id varchar NOT NULL, {1}, " - "CONSTRAINT fk_{0} FOREIGN KEY (sample_id) " - "REFERENCES qiita.study_sample (sample_id) " - "ON UPDATE CASCADE)".format( - table_name, ', '.join(column_datatype))) - - # Insert values on custom table - values = as_python_types(md_template, headers) - values.insert(0, sample_ids) - values = [v for v in zip(*values)] - conn_handler.add_to_queue( - queue_name, - "INSERT INTO qiita.{0} (sample_id, {1}) " - "VALUES (%s, {2})".format(table_name, ", ".join(headers), - ', '.join(["%s"] * len(headers))), - values, many=True) - - def _add_common_extend_steps_to_queue(self, md_template, conn_handler, - queue_name): - r"""Adds the common extend steps to the queue in conn_handler + def _common_extend_steps(self, md_template): + r"""executes the common extend steps Parameters ---------- md_template : DataFrame The metadata template file contents indexed by sample ids - conn_handler : SQLConnectionHandler - The connection handler object connected to the DB - queue_name : str - The queue where the SQL statements will be added Raises ------ QiitaDBError If no new samples or new columns are present in `md_template` """ - # Check if we are adding new samples - sample_ids = md_template.index.tolist() - curr_samples = set(self.keys()) - existing_samples = curr_samples.intersection(sample_ids) - new_samples = set(sample_ids).difference(existing_samples) - - # Check if we are adding new columns - headers = md_template.keys().tolist() - new_cols = set(headers).difference(self.categories()) - - if not new_cols and not new_samples: - raise QiitaDBError( - "No new samples or new columns found in the template. If you " - "want to update existing values, you should use the 'update' " - "functionality.") - - table_name = self._table_name(self._id) - if new_cols: - # If we are adding new columns, add them first (simplifies code) - # Sorting the new columns to enforce an order - new_cols = sorted(new_cols) - datatypes = get_datatypes(md_template.ix[:, new_cols]) - sql_cols = """INSERT INTO qiita.{0} ({1}, column_name, column_type) - VALUES (%s, %s, %s)""".format(self._column_table, - self._id_column) - sql_alter = """ALTER TABLE qiita.{0} ADD COLUMN {1} {2}""" - for category, dtype in zip(new_cols, datatypes): - conn_handler.add_to_queue( - queue_name, sql_cols, (self._id, category, dtype)) - conn_handler.add_to_queue( - queue_name, sql_alter.format(table_name, category, dtype)) - - if existing_samples: + with TRN: + # Check if we are adding new samples + sample_ids = md_template.index.tolist() + curr_samples = set(self.keys()) + existing_samples = curr_samples.intersection(sample_ids) + new_samples = set(sample_ids).difference(existing_samples) + + # Check if we are adding new columns + headers = md_template.keys().tolist() + new_cols = set(headers).difference(self.categories()) + + if not new_cols and not new_samples: + raise QiitaDBError( + "No new samples or new columns found in the template. 
" + "If you want to update existing values, you should use " + "the 'update' functionality.") + + table_name = self._table_name(self._id) + if new_cols: + # If we are adding new columns, add them first (simplifies + # code). Sorting the new columns to enforce an order + new_cols = sorted(new_cols) + datatypes = get_datatypes(md_template.ix[:, new_cols]) + sql_cols = """INSERT INTO qiita.{0} + ({1}, column_name, column_type) + VALUES (%s, %s, %s)""".format(self._column_table, + self._id_column) + sql_alter = """ALTER TABLE qiita.{0} ADD COLUMN {1} {2}""" + for category, dtype in zip(new_cols, datatypes): + TRN.add(sql_cols, [self._id, category, dtype]) + TRN.add(sql_alter.format(table_name, category, dtype)) + + if existing_samples: + warnings.warn( + "No values have been modified for existing samples " + "(%s). However, the following columns have been added " + "to them: '%s'" + % (len(existing_samples), ", ".join(new_cols)), + QiitaDBWarning) + # The values for the new columns are the only ones that get + # added to the database. None of the existing values will + # be modified (see update for that functionality) + min_md_template = \ + md_template[new_cols].loc[existing_samples] + values = as_python_types(min_md_template, new_cols) + values.append(existing_samples) + # psycopg2 requires a list of iterable, in which each + # iterable is a set of values to use in the string + # formatting of the query. We have all the values in + # different lists (but in the same order) so use zip to + # create the list of iterable that psycopg2 requires. + values = [list(v) for v in zip(*values)] + set_str = ["{0} = %s".format(col) for col in new_cols] + sql = """UPDATE qiita.{0} + SET {1} + WHERE sample_id=%s""".format(table_name, + ",".join(set_str)) + TRN.add(sql, values, many=True) + elif existing_samples: warnings.warn( - "No values have been modified for existing samples (%s). " - "However, the following columns have been added to them: " - "'%s'" % (len(existing_samples), ", ".join(new_cols)), + "%d samples already exist in the template and " + "their values won't be modified" % len(existing_samples), QiitaDBWarning) - # The values for the new columns are the only ones that get - # added to the database. None of the existing values will be - # modified (see update for that functionality) - min_md_template = md_template[new_cols].loc[existing_samples] - values = as_python_types(min_md_template, new_cols) - values.append(existing_samples) - # psycopg2 requires a list of tuples, in which each tuple is a - # set of values to use in the string formatting of the query. - # We have all the values in different lists (but in the same - # order) so use zip to create the list of tuples that psycopg2 - # requires. 
- values = [v for v in zip(*values)] - set_str = ["{0} = %s".format(col) for col in new_cols] - sql = """UPDATE qiita.{0} - SET {1} - WHERE sample_id=%s""".format(table_name, - ",".join(set_str)) - conn_handler.add_to_queue(queue_name, sql, values, many=True) - elif existing_samples: - warnings.warn( - "%d samples already exist in the template and " - "their values won't be modified" % len(existing_samples), - QiitaDBWarning) - - if new_samples: - new_samples = sorted(new_samples) - # At this point we only want the information from the new samples - md_template = md_template.loc[new_samples] - # Insert values on required columns - values = [(self._id, s_id) for s_id in new_samples] - sql = """INSERT INTO qiita.{0} ({1}, sample_id) - VALUES (%s, %s)""".format(self._table, self._id_column) - conn_handler.add_to_queue(queue_name, sql, values, many=True) - - # Insert values on custom table - values = as_python_types(md_template, headers) - values.insert(0, new_samples) - values = [v for v in zip(*values)] - sql = """INSERT INTO qiita.{0} (sample_id, {1}) - VALUES (%s, {2})""".format( - table_name, ", ".join(headers), - ', '.join(["%s"] * len(headers))) - conn_handler.add_to_queue(queue_name, sql, values, many=True) + if new_samples: + new_samples = sorted(new_samples) + # At this point we only want the information + # from the new samples + md_template = md_template.loc[new_samples] + + # Insert values on required columns + values = [[self._id, s_id] for s_id in new_samples] + sql = """INSERT INTO qiita.{0} ({1}, sample_id) + VALUES (%s, %s)""".format(self._table, + self._id_column) + TRN.add(sql, values, many=True) + + # Insert values on custom table + values = as_python_types(md_template, headers) + values.insert(0, new_samples) + values = [list(v) for v in zip(*values)] + sql = """INSERT INTO qiita.{0} (sample_id, {1}) + VALUES (%s, {2})""".format( + table_name, ", ".join(headers), + ', '.join(["%s"] * len(headers))) + TRN.add(sql, values, many=True) + + # Execute all the steps + TRN.execute() @classmethod def exists(cls, obj_id): @@ -765,26 +748,21 @@ def exists(cls, obj_id): True if already exists. False otherwise. 
""" cls._check_subclass() - return exists_table(cls._table_name(obj_id), SQLConnectionHandler()) + return exists_table(cls._table_name(obj_id)) - def _get_sample_ids(self, conn_handler): + def _get_sample_ids(self): r"""Returns all the available samples for the metadata template - Parameters - ---------- - conn_handler : SQLConnectionHandler - The connection handler object connected to the DB - Returns ------- set of str The set of all available sample ids """ - sample_ids = conn_handler.execute_fetchall( - "SELECT sample_id FROM qiita.{0} WHERE " - "{1}=%s".format(self._table, self._id_column), - (self._id, )) - return set(sample_id[0] for sample_id in sample_ids) + with TRN: + sql = "SELECT sample_id FROM qiita.{0} WHERE {1}=%s".format( + self._table, self._id_column) + TRN.add(sql, [self._id]) + return set(TRN.execute_fetchflatten()) def __len__(self): r"""Returns the number of samples in the metadata template @@ -794,8 +772,7 @@ def __len__(self): int The number of samples in the metadata template """ - conn_handler = SQLConnectionHandler() - return len(self._get_sample_ids(conn_handler)) + return len(self._get_sample_ids()) def __getitem__(self, key): r"""Returns the metadata values for sample id `key` @@ -819,11 +796,12 @@ def __getitem__(self, key): -------- get """ - if key in self: - return self._sample_cls(key, self) - else: - raise KeyError("Sample id %s does not exists in template %d" - % (key, self._id)) + with TRN: + if key in self: + return self._sample_cls(key, self) + else: + raise KeyError("Sample id %s does not exists in template %d" + % (key, self._id)) def __setitem__(self, key, value): r"""Sets the metadata values for sample id `key` @@ -859,8 +837,7 @@ def __iter__(self): -------- keys """ - conn_handler = SQLConnectionHandler() - return iter(self._get_sample_ids(conn_handler)) + return iter(self._get_sample_ids()) def __contains__(self, key): r"""Checks if the sample id `key` is present in the metadata template @@ -876,8 +853,7 @@ def __contains__(self, key): True if the sample id `key` is in the metadata template, false otherwise """ - conn_handler = SQLConnectionHandler() - return key in self._get_sample_ids(conn_handler) + return key in self._get_sample_ids() def keys(self): r"""Iterator over the sorted sample ids @@ -901,9 +877,9 @@ def values(self): Iterator Iterator over Sample obj """ - conn_handler = SQLConnectionHandler() - return iter(self._sample_cls(sample_id, self) - for sample_id in self._get_sample_ids(conn_handler)) + with TRN: + return iter(self._sample_cls(sample_id, self) + for sample_id in self._get_sample_ids()) def items(self): r"""Iterator over (sample_id, values) tuples, in sample id order @@ -913,9 +889,9 @@ def items(self): Iterator Iterator over (sample_ids, values) tuples """ - conn_handler = SQLConnectionHandler() - return iter((sample_id, self._sample_cls(sample_id, self)) - for sample_id in self._get_sample_ids(conn_handler)) + with TRN: + return iter((sample_id, self._sample_cls(sample_id, self)) + for sample_id in self._get_sample_ids()) def get(self, key): r"""Returns the metadata values for sample id `key`, or None if the @@ -990,17 +966,19 @@ def to_file(self, fp, samples=None): If supplied, only the specified samples will be written to the file """ - df = self.to_dataframe() - if samples is not None: - df = df.loc[samples] + with TRN: + df = self.to_dataframe() + if samples is not None: + df = df.loc[samples] - # Sorting the dataframe so multiple serializations of the metadata - # template are consistent. 
- df.sort_index(axis=0, inplace=True) - df.sort_index(axis=1, inplace=True) + # Sorting the dataframe so multiple serializations of the metadata + # template are consistent. + df.sort_index(axis=0, inplace=True) + df.sort_index(axis=1, inplace=True) - # Store the template in a file - df.to_csv(fp, index_label='sample_name', na_rep="", sep='\t') + # Store the template in a file + df.to_csv(fp, index_label='sample_name', na_rep="", sep='\t', + encoding='utf-8') def to_dataframe(self): """Returns the metadata template as a dataframe @@ -1010,69 +988,63 @@ def to_dataframe(self): pandas DataFrame The metadata in the template,indexed on sample id """ - conn_handler = SQLConnectionHandler() - cols = sorted(get_table_cols(self._table_name(self._id))) - # Get all metadata for the template - sql = "SELECT {0} FROM qiita.{1}".format(", ".join(cols), - self._table_name(self.id)) - meta = conn_handler.execute_fetchall(sql, (self._id,)) + with TRN: + cols = sorted(get_table_cols(self._table_name(self._id))) + # Get all metadata for the template + sql = "SELECT {0} FROM qiita.{1}".format(", ".join(cols), + self._table_name(self.id)) + TRN.add(sql, [self._id]) + meta = TRN.execute_fetchindex() - # Create the dataframe and clean it up a bit - df = pd.DataFrame((list(x) for x in meta), columns=cols) - df.set_index('sample_id', inplace=True, drop=True) + # Create the dataframe and clean it up a bit + df = pd.DataFrame((list(x) for x in meta), columns=cols) + df.set_index('sample_id', inplace=True, drop=True) - return df + return df def add_filepath(self, filepath, fp_id=None): r"""Populates the DB tables for storing the filepath and connects the `self` objects with this filepath""" - # Check that this function has been called from a subclass - self._check_subclass() - - # Check if the connection handler has been provided. Create a new - # one if not. - conn_handler = SQLConnectionHandler() - fp_id = self._fp_id if fp_id is None else fp_id - - try: - fpp_id = insert_filepaths([(filepath, fp_id)], None, - "templates", "filepath", conn_handler, - move_files=False)[0] - values = (self._id, fpp_id) - conn_handler.execute( - "INSERT INTO qiita.{0} ({1}, filepath_id) " - "VALUES (%s, %s)".format( - self._filepath_table, self._id_column), values) - except Exception as e: - LogEntry.create('Runtime', str(e), - info={self.__class__.__name__: self.id}) - raise e + with TRN: + fp_id = self._fp_id if fp_id is None else fp_id + + try: + fpp_id = insert_filepaths([(filepath, fp_id)], None, + "templates", "filepath", + move_files=False)[0] + sql = """INSERT INTO qiita.{0} ({1}, filepath_id) + VALUES (%s, %s)""".format(self._filepath_table, + self._id_column) + TRN.add(sql, [self._id, fpp_id]) + TRN.execute() + except Exception as e: + LogEntry.create('Runtime', str(e), + info={self.__class__.__name__: self.id}) + raise e def get_filepaths(self): r"""Retrieves the list of (filepath_id, filepath)""" - # Check that this function has been called from a subclass - self._check_subclass() - - # Check if the connection handler has been provided. Create a new - # one if not. 
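# --- Editor's note (illustrative sketch, not part of the patch) ---
# to_dataframe() above turns the fetched rows into a sample-indexed
# DataFrame. A self-contained analogue with hypothetical data:
#
#     import pandas as pd
#     cols = ['sample_id', 'season']
#     meta = [('1.SKB8.640193', 'winter'), ('1.SKD8.640184', 'summer')]
#     df = pd.DataFrame((list(x) for x in meta), columns=cols)
#     df.set_index('sample_id', inplace=True, drop=True)
#
# -------------------------------------------------------------------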
- conn_handler = SQLConnectionHandler() - - try: - filepath_ids = conn_handler.execute_fetchall( - "SELECT filepath_id, filepath FROM qiita.filepath WHERE " - "filepath_id IN (SELECT filepath_id FROM qiita.{0} WHERE " - "{1}=%s) ORDER BY filepath_id DESC".format( - self._filepath_table, self._id_column), - (self.id, )) - except Exception as e: - LogEntry.create('Runtime', str(e), - info={self.__class__.__name__: self.id}) - raise e - - _, fb = get_mountpoint('templates')[0] - base_fp = partial(join, fb) - - return [(fpid, base_fp(fp)) for fpid, fp in filepath_ids] + with TRN: + try: + sql = """SELECT filepath_id, filepath + FROM qiita.filepath + WHERE filepath_id IN ( + SELECT filepath_id FROM qiita.{0} + WHERE {1}=%s) + ORDER BY filepath_id DESC""".format( + self._filepath_table, self._id_column) + + TRN.add(sql, [self.id]) + filepath_ids = TRN.execute_fetchindex() + except Exception as e: + LogEntry.create('Runtime', str(e), + info={self.__class__.__name__: self.id}) + raise e + + _, fb = get_mountpoint('templates')[0] + base_fp = partial(join, fb) + + return [(fpid, base_fp(fp)) for fpid, fp in filepath_ids] def categories(self): """Identifies the metadata columns present in a template @@ -1103,50 +1075,52 @@ def update(self, md_template): If md_template and db do not have the same column headers If self.can_be_updated is not True """ - conn_handler = SQLConnectionHandler() - - # Clean and validate the metadata template given - new_map = self._clean_validate_template(md_template, self.study_id, - self.columns_restrictions) - # Retrieving current metadata - current_map = self._transform_to_dict(conn_handler.execute_fetchall( - "SELECT * FROM qiita.{0}".format(self._table_name(self.id)))) - current_map = pd.DataFrame.from_dict(current_map, orient='index') - - # simple validations of sample ids and column names - samples_diff = set(new_map.index).difference(current_map.index) - if samples_diff: - raise QiitaDBError('The new template differs from what is stored ' - 'in database by these samples names: %s' - % ', '.join(samples_diff)) - columns_diff = set(new_map.columns).difference(current_map.columns) - if columns_diff: - raise QiitaDBError('The new template differs from what is stored ' - 'in database by these columns names: %s' - % ', '.join(columns_diff)) - - # here we are comparing two dataframes following: - # http://stackoverflow.com/a/17095620/4228285 - current_map.sort(axis=0, inplace=True) - current_map.sort(axis=1, inplace=True) - new_map.sort(axis=0, inplace=True) - new_map.sort(axis=1, inplace=True) - map_diff = (current_map != new_map).stack() - map_diff = map_diff[map_diff] - map_diff.index.names = ['id', 'column'] - changed_cols = map_diff.index.get_level_values('column').unique() - - if not self.can_be_updated(columns=set(changed_cols)): - raise QiitaDBError('The new template is modifying fields that ' - 'cannot be modified. Try removing the target ' - 'gene fields or deleting the processed data. 
' - 'You are trying to modify: %s' - % ', '.join(changed_cols)) - - for col in changed_cols: - self.update_category(col, new_map[col].to_dict()) - - self.generate_files() + with TRN: + # Clean and validate the metadata template given + new_map = self._clean_validate_template(md_template, self.study_id, + self.columns_restrictions) + # Retrieving current metadata + sql = "SELECT * FROM qiita.{0}".format(self._table_name(self.id)) + TRN.add(sql) + current_map = self._transform_to_dict(TRN.execute_fetchindex()) + current_map = pd.DataFrame.from_dict(current_map, orient='index') + + # simple validations of sample ids and column names + samples_diff = set(new_map.index).difference(current_map.index) + if samples_diff: + raise QiitaDBError( + 'The new template differs from what is stored ' + 'in database by these samples names: %s' + % ', '.join(samples_diff)) + columns_diff = set(new_map.columns).difference(current_map.columns) + if columns_diff: + raise QiitaDBError( + 'The new template differs from what is stored ' + 'in database by these columns names: %s' + % ', '.join(columns_diff)) + + # here we are comparing two dataframes following: + # http://stackoverflow.com/a/17095620/4228285 + current_map.sort(axis=0, inplace=True) + current_map.sort(axis=1, inplace=True) + new_map.sort(axis=0, inplace=True) + new_map.sort(axis=1, inplace=True) + map_diff = (current_map != new_map).stack() + map_diff = map_diff[map_diff] + map_diff.index.names = ['id', 'column'] + changed_cols = map_diff.index.get_level_values('column').unique() + + if not self.can_be_updated(columns=set(changed_cols)): + raise QiitaDBError( + 'The new template is modifying fields that cannot be ' + 'modified. Try removing the target gene fields or ' + 'deleting the processed data. You are trying to modify: %s' + % ', '.join(changed_cols)) + + for col in changed_cols: + self.update_category(col, new_map[col].to_dict()) + + self.generate_files() def update_category(self, category, samples_and_values): """Update an existing column @@ -1169,49 +1143,48 @@ def update_category(self, category, samples_and_values): If one of the new values cannot be inserted in the DB due to different types """ - if not set(self.keys()).issuperset(samples_and_values): - missing = set(self.keys()) - set(samples_and_values) - table_name = self._table_name(self._id) - raise QiitaDBUnknownIDError(missing, table_name) - - conn_handler = SQLConnectionHandler() - queue_name = "update_category_%s_%s" % (self._id, category) - conn_handler.create_queue(queue_name) - - for k, v in viewitems(samples_and_values): - sample = self[k] - sample.add_setitem_queries(category, v, conn_handler, queue_name) - - try: - conn_handler.execute_queue(queue_name) - except ValueError as e: - # catching error so we can check if the error is due to different - # column type or something else - - value_types = set(type_lookup(type(value)) - for value in viewvalues(samples_and_values)) - - sql = """SELECT udt_name - FROM information_schema.columns - WHERE column_name = %s - AND table_schema = 'qiita' - AND (table_name = %s OR table_name = %s)""" - column_type = conn_handler.execute_fetchone( - sql, (category, self._table, self._table_name(self._id))) - - if any([column_type != vt for vt in value_types]): - value_str = ', '.join( - [str(value) for value in viewvalues(samples_and_values)]) - value_types_str = ', '.join(value_types) - - raise ValueError( - 'The new values being added to column: "%s" are "%s" ' - '(types: "%s"). However, this column in the DB is of ' - 'type "%s". 
Please change the values in your updated ' - 'template or reprocess your template.' - % (category, value_str, value_types_str, column_type)) - - raise e + with TRN: + if not set(self.keys()).issuperset(samples_and_values): + missing = set(self.keys()) - set(samples_and_values) + table_name = self._table_name(self._id) + raise QiitaDBUnknownIDError(missing, table_name) + + for k, v in viewitems(samples_and_values): + sample = self[k] + sample.setitem(category, v) + + try: + TRN.execute() + except ValueError as e: + # catching error so we can check if the error is due to + # different column type or something else + + value_types = set(type_lookup(type(value)) + for value in viewvalues(samples_and_values)) + + sql = """SELECT udt_name + FROM information_schema.columns + WHERE column_name = %s + AND table_schema = 'qiita' + AND (table_name = %s OR table_name = %s)""" + TRN.add(sql, + [category, self._table, self._table_name(self._id)]) + column_type = TRN.execute_fetchlast() + + if any([column_type != vt for vt in value_types]): + value_str = ', '.join( + [str(value) + for value in viewvalues(samples_and_values)]) + value_types_str = ', '.join(value_types) + + raise ValueError( + 'The new values being added to column: "%s" are "%s" ' + '(types: "%s"). However, this column in the DB is of ' + 'type "%s". Please change the values in your updated ' + 'template or reprocess your template.' + % (category, value_str, value_types_str, column_type)) + + raise e def check_restrictions(self, restrictions): """Checks if the template fulfills the restrictions diff --git a/qiita_db/metadata_template/prep_template.py b/qiita_db/metadata_template/prep_template.py index 9065a6ba6..898d75ca6 100644 --- a/qiita_db/metadata_template/prep_template.py +++ b/qiita_db/metadata_template/prep_template.py @@ -20,7 +20,7 @@ from qiita_db.exceptions import (QiitaDBColumnError, QiitaDBUnknownIDError, QiitaDBError, QiitaDBExecutionError, QiitaDBWarning) -from qiita_db.sql_connection import SQLConnectionHandler +from qiita_db.sql_connection import TRN from qiita_db.ontology import Ontology from qiita_db.util import (convert_to_id, convert_from_id, get_mountpoint, infer_status) @@ -102,77 +102,68 @@ def create(cls, md_template, study, data_type, investigation_type=None): If the investigation_type is not valid If a required column is missing in md_template """ - # If the investigation_type is supplied, make sure it is one of - # the recognized investigation types - if investigation_type is not None: - cls.validate_investigation_type(investigation_type) - - # Get a connection handler - conn_handler = SQLConnectionHandler() - queue_name = "CREATE_PREP_TEMPLATE_%d_%d" % (study.id, id(md_template)) - conn_handler.create_queue(queue_name) - - # Check if the data_type is the id or the string - if isinstance(data_type, (int, long)): - data_type_id = data_type - data_type_str = convert_from_id(data_type, "data_type") - else: - data_type_id = convert_to_id(data_type, "data_type") - data_type_str = data_type - - pt_cols = PREP_TEMPLATE_COLUMNS - if data_type_str in TARGET_GENE_DATA_TYPES: - pt_cols = deepcopy(PREP_TEMPLATE_COLUMNS) - pt_cols.update(PREP_TEMPLATE_COLUMNS_TARGET_GENE) - - md_template = cls._clean_validate_template(md_template, study.id, - pt_cols) - - # Insert the metadata template - # We need the prep_id for multiple calls below, which currently is not - # supported by the queue system. 
Thus, executing this outside the queue - prep_id = conn_handler.execute_fetchone( - "INSERT INTO qiita.prep_template " - "(data_type_id, investigation_type) " - "VALUES (%s, %s) RETURNING prep_template_id", - (data_type_id, investigation_type))[0] - - cls._add_common_creation_steps_to_queue(md_template, prep_id, - conn_handler, queue_name) - - # Link the prep template with the study - sql = ("INSERT INTO qiita.study_prep_template " - "(study_id, prep_template_id) VALUES (%s, %s)") - conn_handler.add_to_queue(queue_name, sql, (study.id, prep_id)) - - try: - conn_handler.execute_queue(queue_name) - except Exception: - # Clean up row from qiita.prep_template - conn_handler.execute( - "DELETE FROM qiita.prep_template where " - "{0} = %s".format(cls._id_column), (prep_id,)) - - # Check if sample IDs present here but not in sample template - sql = ("SELECT sample_id from qiita.study_sample WHERE " - "study_id = %s") - # Get list of study sample IDs, prep template study IDs, - # and their intersection - prep_samples = set(md_template.index.values) - unknown_samples = prep_samples.difference( - s[0] for s in conn_handler.execute_fetchall(sql, [study.id])) - if unknown_samples: - raise QiitaDBExecutionError( - 'Samples found in prep template but not sample template: ' - '%s' % ', '.join(unknown_samples)) - - # some other error we haven't seen before so raise it - raise - - pt = cls(prep_id) - pt.generate_files() - - return pt + with TRN: + # If the investigation_type is supplied, make sure it is one of + # the recognized investigation types + if investigation_type is not None: + cls.validate_investigation_type(investigation_type) + + # Check if the data_type is the id or the string + if isinstance(data_type, (int, long)): + data_type_id = data_type + data_type_str = convert_from_id(data_type, "data_type") + else: + data_type_id = convert_to_id(data_type, "data_type") + data_type_str = data_type + + pt_cols = PREP_TEMPLATE_COLUMNS + if data_type_str in TARGET_GENE_DATA_TYPES: + pt_cols = deepcopy(PREP_TEMPLATE_COLUMNS) + pt_cols.update(PREP_TEMPLATE_COLUMNS_TARGET_GENE) + + md_template = cls._clean_validate_template(md_template, study.id, + pt_cols) + + # Insert the metadata template + sql = """INSERT INTO qiita.prep_template + (data_type_id, investigation_type) + VALUES (%s, %s) + RETURNING prep_template_id""" + TRN.add(sql, [data_type_id, investigation_type]) + prep_id = TRN.execute_fetchlast() + + try: + cls._common_creation_steps(md_template, prep_id) + except Exception: + # Check if sample IDs present here but not in sample template + sql = """SELECT sample_id from qiita.study_sample + WHERE study_id = %s""" + # Get list of study sample IDs, prep template study IDs, + # and their intersection + TRN.add(sql, [study.id]) + prep_samples = set(md_template.index.values) + unknown_samples = prep_samples.difference( + TRN.execute_fetchflatten()) + if unknown_samples: + raise QiitaDBExecutionError( + 'Samples found in prep template but not sample ' + 'template: %s' % ', '.join(unknown_samples)) + + # some other error we haven't seen before so raise it + raise + + # Link the prep template with the study + sql = """INSERT INTO qiita.study_prep_template + (study_id, prep_template_id) + VALUES (%s, %s)""" + TRN.add(sql, [study.id, prep_id]) + + TRN.execute() + + pt = cls(prep_id) + pt.generate_files() + + return pt @classmethod def validate_investigation_type(self, investigation_type): @@ -188,12 +179,13 @@ def validate_investigation_type(self, investigation_type): QiitaDBColumnError The investigation type 
is not in the ENA ontology """ - ontology = Ontology(convert_to_id('ENA', 'ontology')) - terms = ontology.terms + ontology.user_defined_terms - if investigation_type not in terms: - raise QiitaDBColumnError("'%s' is Not a valid investigation_type. " - "Choose from: %s" % (investigation_type, - ', '.join(terms))) + with TRN: + ontology = Ontology(convert_to_id('ENA', 'ontology')) + terms = ontology.terms + ontology.user_defined_terms + if investigation_type not in terms: + raise QiitaDBColumnError( + "'%s' is Not a valid investigation_type. Choose from: %s" + % (investigation_type, ', '.join(terms))) @classmethod def delete(cls, id_): @@ -212,62 +204,65 @@ def delete(cls, id_): QiitaDBUnknownIDError If no prep template with id = id_ exists """ - table_name = cls._table_name(id_) - conn_handler = SQLConnectionHandler() - - if not cls.exists(id_): - raise QiitaDBUnknownIDError(id_, cls.__name__) - - preprocessed_data_exists = conn_handler.execute_fetchone( - "SELECT EXISTS(SELECT * FROM qiita.prep_template_preprocessed_data" - " WHERE prep_template_id=%s)", (id_,))[0] - - if preprocessed_data_exists: - raise QiitaDBExecutionError("Cannot remove prep template %d " - "because a preprocessed data has been" - " already generated using it." % id_) - - sql = """SELECT ( - SELECT raw_data_id - FROM qiita.prep_template - WHERE prep_template_id=%s) - IS NOT NULL""" - raw_data_attached = conn_handler.execute_fetchone(sql, (id_,))[0] - if raw_data_attached: - raise QiitaDBExecutionError( - "Cannot remove prep template %d because it has raw data " - "associated with it" % id_) - - # Delete the prep template filepaths - conn_handler.execute( - "DELETE FROM qiita.prep_template_filepath WHERE " - "prep_template_id = %s", (id_, )) - - # Drop the prep_X table - conn_handler.execute( - "DROP TABLE qiita.{0}".format(table_name)) - - # Remove the rows from prep_template_samples - conn_handler.execute( - "DELETE FROM qiita.{0} where {1} = %s".format(cls._table, - cls._id_column), - (id_,)) - - # Remove the rows from prep_columns - conn_handler.execute( - "DELETE FROM qiita.{0} where {1} = %s".format(cls._column_table, - cls._id_column), - (id_,)) - - # Remove the row from study_prep_template - conn_handler.execute( - "DELETE FROM qiita.study_prep_template " - "WHERE {0} = %s".format(cls._id_column), (id_,)) - - # Remove the row from prep_template - conn_handler.execute( - "DELETE FROM qiita.prep_template where " - "{0} = %s".format(cls._id_column), (id_,)) + with TRN: + table_name = cls._table_name(id_) + + if not cls.exists(id_): + raise QiitaDBUnknownIDError(id_, cls.__name__) + + sql = """SELECT EXISTS( + SELECT * FROM qiita.prep_template_preprocessed_data + WHERE prep_template_id=%s)""" + args = [id_] + TRN.add(sql, args) + preprocessed_data_exists = TRN.execute_fetchlast() + + if preprocessed_data_exists: + raise QiitaDBExecutionError( + "Cannot remove prep template %d because a preprocessed " + "data has been already generated using it." 
% id_) + + sql = """SELECT ( + SELECT raw_data_id + FROM qiita.prep_template + WHERE prep_template_id=%s) + IS NOT NULL""" + TRN.add(sql, args) + raw_data_attached = TRN.execute_fetchlast() + if raw_data_attached: + raise QiitaDBExecutionError( + "Cannot remove prep template %d because it has raw data " + "associated with it" % id_) + + # Delete the prep template filepaths + sql = """DELETE FROM qiita.prep_template_filepath + WHERE prep_template_id = %s""" + TRN.add(sql, args) + + # Drop the prep_X table + TRN.add("DROP TABLE qiita.{0}".format(table_name)) + + # Remove the rows from prep_template_samples + sql = "DELETE FROM qiita.{0} WHERE {1} = %s".format( + cls._table, cls._id_column) + TRN.add(sql, args) + + # Remove the rows from prep_columns + sql = "DELETE FROM qiita.{0} where {1} = %s".format( + cls._column_table, cls._id_column) + TRN.add(sql, args) + + # Remove the row from study_prep_template + sql = """DELETE FROM qiita.study_prep_template + WHERE {0} = %s""".format(cls._id_column) + TRN.add(sql, args) + + # Remove the row from prep_template + sql = "DELETE FROM qiita.prep_template WHERE {0} = %s".format( + cls._id_column) + TRN.add(sql, args) + + TRN.execute() def data_type(self, ret_id=False): """Returns the data_type or the data_type id @@ -282,12 +277,15 @@ def data_type(self, ret_id=False): str or int string value of data_type or data_type_id if ret_id is True """ - ret = "_id" if ret_id else "" - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT d.data_type{0} FROM qiita.data_type d JOIN " - "qiita.prep_template p ON p.data_type_id = d.data_type_id WHERE " - "p.prep_template_id=%s".format(ret), (self.id,))[0] + with TRN: + ret = "_id" if ret_id else "" + sql = """SELECT d.data_type{0} + FROM qiita.data_type d + JOIN qiita.prep_template p + ON p.data_type_id = d.data_type_id + WHERE p.prep_template_id=%s""".format(ret) + TRN.add(sql, [self.id]) + return TRN.execute_fetchlast() @property def columns_restrictions(self): @@ -325,55 +323,61 @@ def can_be_updated(self, columns): the columns being updated are not part of PREP_TEMPLATE_COLUMNS_TARGET_GENE """ - if (not self.preprocessed_data or - self.data_type() not in TARGET_GENE_DATA_TYPES): - return True + with TRN: + if (not self.preprocessed_data or + self.data_type() not in TARGET_GENE_DATA_TYPES): + return True - tg_columns = set(chain.from_iterable( - [v.columns for v in - viewvalues(PREP_TEMPLATE_COLUMNS_TARGET_GENE)])) + tg_columns = set(chain.from_iterable( + [v.columns for v in + viewvalues(PREP_TEMPLATE_COLUMNS_TARGET_GENE)])) - if not columns & tg_columns: - return True + if not columns & tg_columns: + return True - return False + return False @property def raw_data(self): - conn_handler = SQLConnectionHandler() - result = conn_handler.execute_fetchone( - "SELECT raw_data_id FROM qiita.prep_template " - "WHERE prep_template_id=%s", (self.id,)) - if result: - return result[0] - return None + with TRN: + sql = """SELECT raw_data_id FROM qiita.prep_template + WHERE prep_template_id=%s""" + TRN.add(sql, [self.id]) + result = TRN.execute_fetchindex() + if result: + # If there is any result, it will be in the first row + # and in the first element of that row, thus [0][0] + return result[0][0] + return None @raw_data.setter def raw_data(self, raw_data): - conn_handler = SQLConnectionHandler() - sql = """SELECT ( - SELECT raw_data_id - FROM qiita.prep_template - WHERE prep_template_id=%s) - IS NOT NULL""" - exists = conn_handler.execute_fetchone(sql, (self.id,))[0] - if exists: - 
raise QiitaDBError( - "Prep template %d already has a raw data associated" - % self.id) - sql = """UPDATE qiita.prep_template - SET raw_data_id = %s - WHERE prep_template_id = %s""" - conn_handler.execute(sql, (raw_data.id, self.id)) + with TRN: + sql = """SELECT ( + SELECT raw_data_id + FROM qiita.prep_template + WHERE prep_template_id=%s) + IS NOT NULL""" + TRN.add(sql, [self.id]) + exists = TRN.execute_fetchlast() + if exists: + raise QiitaDBError( + "Prep template %d already has a raw data associated" + % self.id) + sql = """UPDATE qiita.prep_template + SET raw_data_id = %s + WHERE prep_template_id = %s""" + TRN.add(sql, [raw_data.id, self.id]) + TRN.execute() @property def preprocessed_data(self): - conn_handler = SQLConnectionHandler() - prep_datas = conn_handler.execute_fetchall( - "SELECT preprocessed_data_id FROM " - "qiita.prep_template_preprocessed_data WHERE prep_template_id=%s", - (self.id,)) - return [x[0] for x in prep_datas] + with TRN: + sql = """SELECT preprocessed_data_id + FROM qiita.prep_template_preprocessed_data + WHERE prep_template_id=%s""" + TRN.add(sql, [self.id]) + return TRN.execute_fetchflatten() @property def preprocessing_status(self): @@ -384,10 +388,11 @@ def preprocessing_status(self): str One of {'not_preprocessed', 'preprocessing', 'success', 'failed'} """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT preprocessing_status FROM qiita.prep_template " - "WHERE {0}=%s".format(self._id_column), (self.id,))[0] + with TRN: + sql = """SELECT preprocessing_status FROM qiita.prep_template + WHERE {0}=%s""".format(self._id_column) + TRN.add(sql, [self.id]) + return TRN.execute_fetchlast() @preprocessing_status.setter def preprocessing_status(self, state): @@ -406,20 +411,19 @@ def preprocessing_status(self, state): if (state not in ('not_preprocessed', 'preprocessing', 'success') and not state.startswith('failed:')): raise ValueError('Unknown state: %s' % state) - - conn_handler = SQLConnectionHandler() - - conn_handler.execute( - "UPDATE qiita.prep_template SET preprocessing_status = %s " - "WHERE {0} = %s".format(self._id_column), - (state, self.id)) + with TRN: + sql = """UPDATE qiita.prep_template SET preprocessing_status = %s + WHERE {0} = %s""".format(self._id_column) + TRN.add(sql, [state, self.id]) + TRN.execute() @property def investigation_type(self): - conn_handler = SQLConnectionHandler() - sql = ("SELECT investigation_type FROM qiita.prep_template " - "WHERE {0} = %s".format(self._id_column)) - return conn_handler.execute_fetchone(sql, [self._id])[0] + with TRN: + sql = """SELECT investigation_type FROM qiita.prep_template + WHERE {0} = %s""".format(self._id_column) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @investigation_type.setter def investigation_type(self, investigation_type): @@ -435,15 +439,14 @@ def investigation_type(self, investigation_type): QiitaDBColumnError If the investigation type is not a valid ENA ontology """ - if investigation_type is not None: - self.validate_investigation_type(investigation_type) - - conn_handler = SQLConnectionHandler() + with TRN: + if investigation_type is not None: + self.validate_investigation_type(investigation_type) - conn_handler.execute( - "UPDATE qiita.prep_template SET investigation_type = %s " - "WHERE {0} = %s".format(self._id_column), - (investigation_type, self.id)) + sql = """UPDATE qiita.prep_template SET investigation_type = %s + WHERE {0} = %s""".format(self._id_column) + TRN.add(sql, [investigation_type, self.id]) + TRN.execute() 
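# --- Editor's note (illustrative sketch, not part of the patch) ---
# The properties in this file now follow one idiom: getters queue a
# single SELECT and return TRN.execute_fetchlast(); setters queue an
# UPDATE and call TRN.execute() so the write is applied. Schematically
# (table and column names hypothetical):
#
#     @property
#     def some_field(self):
#         with TRN:
#             sql = "SELECT some_field FROM qiita.some_table WHERE id=%s"
#             TRN.add(sql, [self._id])
#             return TRN.execute_fetchlast()
#
#     @some_field.setter
#     def some_field(self, value):
#         with TRN:
#             sql = "UPDATE qiita.some_table SET some_field=%s WHERE id=%s"
#             TRN.add(sql, [value, self._id])
#             TRN.execute()
#
# -------------------------------------------------------------------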
@property def study_id(self): @@ -454,31 +457,28 @@ def study_id(self): int The ID of the study with which this prep template is associated """ - conn = SQLConnectionHandler() - sql = ("SELECT study_id FROM qiita.study_prep_template " - "WHERE prep_template_id=%s") - study_id = conn.execute_fetchone(sql, (self.id,)) - if study_id: - return study_id[0] - else: - raise QiitaDBError("No studies found associated with prep " - "template ID %d" % self._id) + with TRN: + sql = """SELECT study_id FROM qiita.study_prep_template + WHERE prep_template_id=%s""" + TRN.add(sql, [self.id]) + return TRN.execute_fetchlast() def generate_files(self): r"""Generates all the files that contain data from this template """ - # figuring out the filepath of the prep template - _id, fp = get_mountpoint('templates')[0] - fp = join(fp, '%d_prep_%d_%s.txt' % (self.study_id, self._id, - strftime("%Y%m%d-%H%M%S"))) - # storing the template - self.to_file(fp) + with TRN: + # figuring out the filepath of the prep template + _id, fp = get_mountpoint('templates')[0] + fp = join(fp, '%d_prep_%d_%s.txt' % (self.study_id, self._id, + strftime("%Y%m%d-%H%M%S"))) + # storing the template + self.to_file(fp) - # adding the fp to the object - self.add_filepath(fp) + # adding the fp to the object + self.add_filepath(fp) - # creating QIIME mapping file - self.create_qiime_mapping_file() + # creating QIIME mapping file + self.create_qiime_mapping_file() def create_qiime_mapping_file(self): """This creates the QIIME mapping file and links it in the db. @@ -503,91 +503,93 @@ def create_qiime_mapping_file(self): QIIME-required columns, we are going to create them and populate them with the value XXQIITAXX. """ - rename_cols = { - 'barcode': 'BarcodeSequence', - 'primer': 'LinkerPrimerSequence', - 'description': 'Description', - } - - if 'reverselinkerprimer' in self.categories(): - rename_cols['reverselinkerprimer'] = 'ReverseLinkerPrimer' - new_cols = ['BarcodeSequence', 'LinkerPrimerSequence', - 'ReverseLinkerPrimer'] - else: - new_cols = ['BarcodeSequence', 'LinkerPrimerSequence'] - - # getting the latest sample template - conn_handler = SQLConnectionHandler() - sql = """SELECT filepath_id, filepath - FROM qiita.filepath - JOIN qiita.sample_template_filepath - USING (filepath_id) - WHERE study_id=%s - ORDER BY filepath_id DESC""" - sample_template_fname = conn_handler.execute_fetchall( - sql, (self.study_id,))[0][1] - _, fp = get_mountpoint('templates')[0] - sample_template_fp = join(fp, sample_template_fname) - - # reading files via pandas - st = load_template_to_dataframe(sample_template_fp) - pt = self.to_dataframe() - - st_sample_names = set(st.index) - pt_sample_names = set(pt.index) - - if not pt_sample_names.issubset(st_sample_names): - raise ValueError( - "Prep template is not a sub set of the sample template, files" - "%s - samples: %s" - % (sample_template_fp, - ', '.join(pt_sample_names-st_sample_names))) - - mapping = pt.join(st, lsuffix="_prep") - mapping.rename(columns=rename_cols, inplace=True) - - # Pre-populate the QIIME-required columns with the value XXQIITAXX - index = mapping.index - placeholder = ['XXQIITAXX'] * len(index) - missing = [] - for val in viewvalues(rename_cols): - if val not in mapping: - missing.append(val) - mapping[val] = pd.Series(placeholder, index=index) - - if missing: - warnings.warn( - "Some columns required to generate a QIIME-compliant mapping " - "file are not present in the template. A placeholder value " - "(XXQIITAXX) has been used to populate these columns. 
Missing " - "columns: %s" % ', '.join(missing), - QiitaDBWarning) - - # Gets the orginal mapping columns and readjust the order to comply - # with QIIME requirements - cols = mapping.columns.values.tolist() - cols.remove('BarcodeSequence') - cols.remove('LinkerPrimerSequence') - cols.remove('Description') - new_cols.extend(cols) - new_cols.append('Description') - mapping = mapping[new_cols] - - # figuring out the filepath for the QIIME map file - _id, fp = get_mountpoint('templates')[0] - filepath = join(fp, '%d_prep_%d_qiime_%s.txt' % (self.study_id, - self.id, strftime("%Y%m%d-%H%M%S"))) - - # Save the mapping file - mapping.to_csv(filepath, index_label='#SampleID', na_rep='', - sep='\t') - - # adding the fp to the object - self.add_filepath( - filepath, - fp_id=convert_to_id("qiime_map", "filepath_type")) - - return filepath + with TRN: + rename_cols = { + 'barcode': 'BarcodeSequence', + 'primer': 'LinkerPrimerSequence', + 'description': 'Description', + } + + if 'reverselinkerprimer' in self.categories(): + rename_cols['reverselinkerprimer'] = 'ReverseLinkerPrimer' + new_cols = ['BarcodeSequence', 'LinkerPrimerSequence', + 'ReverseLinkerPrimer'] + else: + new_cols = ['BarcodeSequence', 'LinkerPrimerSequence'] + + # getting the latest sample template + sql = """SELECT filepath_id, filepath + FROM qiita.filepath + JOIN qiita.sample_template_filepath + USING (filepath_id) + WHERE study_id=%s + ORDER BY filepath_id DESC""" + TRN.add(sql, [self.study_id]) + # We know that the good filepath is the one in the first row + # because we sorted them in the SQL query + sample_template_fname = TRN.execute_fetchindex()[0][1] + _, fp = get_mountpoint('templates')[0] + sample_template_fp = join(fp, sample_template_fname) + + # reading files via pandas + st = load_template_to_dataframe(sample_template_fp) + pt = self.to_dataframe() + + st_sample_names = set(st.index) + pt_sample_names = set(pt.index) + + if not pt_sample_names.issubset(st_sample_names): + raise ValueError( + "Prep template is not a sub set of the sample template, " + "file: %s - samples: %s" + % (sample_template_fp, + ', '.join(pt_sample_names-st_sample_names))) + + mapping = pt.join(st, lsuffix="_prep") + mapping.rename(columns=rename_cols, inplace=True) + + # Pre-populate the QIIME-required columns with the value XXQIITAXX + index = mapping.index + placeholder = ['XXQIITAXX'] * len(index) + missing = [] + for val in viewvalues(rename_cols): + if val not in mapping: + missing.append(val) + mapping[val] = pd.Series(placeholder, index=index) + + if missing: + warnings.warn( + "Some columns required to generate a QIIME-compliant " + "mapping file are not present in the template. A " + "placeholder value (XXQIITAXX) has been used to populate " + "these columns. 
Missing columns: %s" % ', '.join(missing), + QiitaDBWarning) + + # Gets the orginal mapping columns and readjust the order to comply + # with QIIME requirements + cols = mapping.columns.values.tolist() + cols.remove('BarcodeSequence') + cols.remove('LinkerPrimerSequence') + cols.remove('Description') + new_cols.extend(cols) + new_cols.append('Description') + mapping = mapping[new_cols] + + # figuring out the filepath for the QIIME map file + _id, fp = get_mountpoint('templates')[0] + filepath = join(fp, '%d_prep_%d_qiime_%s.txt' % (self.study_id, + self.id, strftime("%Y%m%d-%H%M%S"))) + + # Save the mapping file + mapping.to_csv(filepath, index_label='#SampleID', na_rep='', + sep='\t', encoding='utf-8') + + # adding the fp to the object + self.add_filepath( + filepath, + fp_id=convert_to_id("qiime_map", "filepath_type")) + + return filepath @property def status(self): @@ -605,19 +607,19 @@ def status(self): data has been generated with this prep template; then the status is 'sandbox'. """ - conn_handler = SQLConnectionHandler() - sql = """SELECT processed_data_status - FROM qiita.processed_data_status pds - JOIN qiita.processed_data pd - USING (processed_data_status_id) - JOIN qiita.preprocessed_processed_data ppd_pd - USING (processed_data_id) - JOIN qiita.prep_template_preprocessed_data pt_ppd - USING (preprocessed_data_id) - WHERE pt_ppd.prep_template_id=%s""" - pd_statuses = conn_handler.execute_fetchall(sql, (self._id,)) - - return infer_status(pd_statuses) + with TRN: + sql = """SELECT processed_data_status + FROM qiita.processed_data_status pds + JOIN qiita.processed_data pd + USING (processed_data_status_id) + JOIN qiita.preprocessed_processed_data ppd_pd + USING (processed_data_id) + JOIN qiita.prep_template_preprocessed_data pt_ppd + USING (preprocessed_data_id) + WHERE pt_ppd.prep_template_id=%s""" + TRN.add(sql, [self._id]) + + return infer_status(TRN.execute_fetchindex()) @property def qiime_map_fp(self): @@ -628,15 +630,17 @@ def qiime_map_fp(self): str The filepath of the QIIME mapping file """ - conn_handler = SQLConnectionHandler() - - sql = """SELECT filepath_id, filepath - FROM qiita.filepath - JOIN qiita.{0} USING (filepath_id) - JOIN qiita.filepath_type USING (filepath_type_id) - WHERE {1} = %s AND filepath_type = 'qiime_map' - ORDER BY filepath_id DESC""".format(self._filepath_table, - self._id_column) - fn = conn_handler.execute_fetchall(sql, (self._id,))[0][1] - base_dir = get_mountpoint('templates')[0][1] - return join(base_dir, fn) + with TRN: + sql = """SELECT filepath_id, filepath + FROM qiita.filepath + JOIN qiita.{0} USING (filepath_id) + JOIN qiita.filepath_type USING (filepath_type_id) + WHERE {1} = %s AND filepath_type = 'qiime_map' + ORDER BY filepath_id DESC""".format(self._filepath_table, + self._id_column) + TRN.add(sql, [self._id]) + # We know that the good filepath is the one in the first row + # because we sorted them in the SQL query + fn = TRN.execute_fetchindex()[0][1] + base_dir = get_mountpoint('templates')[0][1] + return join(base_dir, fn) diff --git a/qiita_db/metadata_template/sample_template.py b/qiita_db/metadata_template/sample_template.py index 0ea29e9da..9202b0714 100644 --- a/qiita_db/metadata_template/sample_template.py +++ b/qiita_db/metadata_template/sample_template.py @@ -13,7 +13,7 @@ from qiita_core.exceptions import IncompetentQiitaDeveloperError from qiita_db.exceptions import (QiitaDBDuplicateError, QiitaDBError, QiitaDBUnknownIDError) -from qiita_db.sql_connection import SQLConnectionHandler +from qiita_db.sql_connection 
import TRN from qiita_db.util import get_mountpoint, convert_to_id from qiita_db.study import Study from qiita_db.data import RawData @@ -78,11 +78,11 @@ def metadata_headers(): list Alphabetical list of all metadata headers available """ - conn_handler = SQLConnectionHandler() - return [x[0] for x in - conn_handler.execute_fetchall( - "SELECT DISTINCT column_name FROM qiita.study_sample_columns " - "ORDER BY column_name")] + with TRN: + sql = """SELECT DISTINCT column_name + FROM qiita.study_sample_columns ORDER BY column_name""" + TRN.add(sql) + return TRN.execute_fetchflatten() @classmethod def create(cls, md_template, study): @@ -95,29 +95,23 @@ def create(cls, md_template, study): study : Study The study to which the sample template belongs to. """ - cls._check_subclass() + with TRN: + cls._check_subclass() - # Check that we don't have a MetadataTemplate for study - if cls.exists(study.id): - raise QiitaDBDuplicateError(cls.__name__, 'id: %d' % study.id) + # Check that we don't have a MetadataTemplate for study + if cls.exists(study.id): + raise QiitaDBDuplicateError(cls.__name__, 'id: %d' % study.id) - conn_handler = SQLConnectionHandler() - queue_name = "CREATE_SAMPLE_TEMPLATE_%d" % study.id - conn_handler.create_queue(queue_name) + # Clean and validate the metadata template given + md_template = cls._clean_validate_template(md_template, study.id, + SAMPLE_TEMPLATE_COLUMNS) - # Clean and validate the metadata template given - md_template = cls._clean_validate_template(md_template, study.id, - SAMPLE_TEMPLATE_COLUMNS) + cls._common_creation_steps(md_template, study.id) - cls._add_common_creation_steps_to_queue(md_template, study.id, - conn_handler, queue_name) + st = cls(study.id) + st.generate_files() - conn_handler.execute_queue(queue_name) - - st = cls(study.id) - st.generate_files() - - return st + return st @classmethod def delete(cls, id_): @@ -135,46 +129,40 @@ def delete(cls, id_): QiitaDBError If the study that owns this sample template has raw datas """ - cls._check_subclass() + with TRN: + cls._check_subclass() - if not cls.exists(id_): - raise QiitaDBUnknownIDError(id_, cls.__name__) + if not cls.exists(id_): + raise QiitaDBUnknownIDError(id_, cls.__name__) - raw_datas = [str(rd) for rd in Study(cls(id_).study_id).raw_data()] - if raw_datas: - raise QiitaDBError("Sample template can not be erased because " - "there are raw datas (%s) associated." 
% - ', '.join(raw_datas)) + # Check if there is any PrepTemplate + sql = """SELECT EXISTS(SELECT * FROM qiita.study_prep_template + WHERE study_id=%s)""" + TRN.add(sql, [id_]) + has_prep_templates = TRN.execute_fetchlast() + if has_prep_templates: + raise QiitaDBError("Sample template can not be erased because " + "there are prep templates associated.") - table_name = cls._table_name(id_) - conn_handler = SQLConnectionHandler() + table_name = cls._table_name(id_) - # Delete the sample template filepaths - queue = "delete_sample_template_%d" % id_ - conn_handler.create_queue(queue) + # Delete the sample template filepaths + sql = """DELETE FROM qiita.sample_template_filepath + WHERE study_id = %s""" + args = [id_] + TRN.add(sql, args) - conn_handler.add_to_queue( - queue, - "DELETE FROM qiita.sample_template_filepath WHERE study_id = %s", - (id_, )) + TRN.add("DROP TABLE qiita.{0}".format(table_name)) - conn_handler.add_to_queue( - queue, - "DROP TABLE qiita.{0}".format(table_name)) + sql = "DELETE FROM qiita.{0} WHERE {1} = %s".format( + cls._table, cls._id_column) + TRN.add(sql, args) - conn_handler.add_to_queue( - queue, - "DELETE FROM qiita.{0} where {1} = %s".format(cls._table, - cls._id_column), - (id_,)) + sql = "DELETE FROM qiita.{0} WHERE {1} = %s".format( + cls._column_table, cls._id_column) + TRN.add(sql, args) - conn_handler.add_to_queue( - queue, - "DELETE FROM qiita.{0} where {1} = %s".format(cls._column_table, - cls._id_column), - (id_,)) - - conn_handler.execute_queue(queue) + TRN.execute() @property def study_id(self): @@ -225,19 +213,20 @@ def can_be_updated(self, **kwargs): def generate_files(self): r"""Generates all the files that contain data from this template """ - # figuring out the filepath of the sample template - _id, fp = get_mountpoint('templates')[0] - fp = join(fp, '%d_%s.txt' % (self.id, strftime("%Y%m%d-%H%M%S"))) - # storing the sample template - self.to_file(fp) + with TRN: + # figuring out the filepath of the sample template + _id, fp = get_mountpoint('templates')[0] + fp = join(fp, '%d_%s.txt' % (self.id, strftime("%Y%m%d-%H%M%S"))) + # storing the sample template + self.to_file(fp) - # adding the fp to the object - self.add_filepath(fp) + # adding the fp to the object + self.add_filepath(fp) - # generating all new QIIME mapping files - for rd_id in Study(self.id).raw_data(): - for pt_id in RawData(rd_id).prep_templates: - PrepTemplate(pt_id).generate_files() + # generating all new QIIME mapping files + for rd_id in Study(self.id).raw_data(): + for pt_id in RawData(rd_id).prep_templates: + PrepTemplate(pt_id).generate_files() def extend(self, md_template): """Adds the given sample template to the current one @@ -247,15 +236,10 @@ def extend(self, md_template): md_template : DataFrame The metadata template file contents indexed by samples Ids """ - conn_handler = SQLConnectionHandler() - queue_name = "EXTEND_SAMPLE_TEMPLATE_%d" % self.id - conn_handler.create_queue(queue_name) - - md_template = self._clean_validate_template(md_template, self.study_id, - SAMPLE_TEMPLATE_COLUMNS) + with TRN: + md_template = self._clean_validate_template( + md_template, self.study_id, SAMPLE_TEMPLATE_COLUMNS) - self._add_common_extend_steps_to_queue(md_template, conn_handler, - queue_name) - conn_handler.execute_queue(queue_name) + self._common_extend_steps(md_template) - self.generate_files() + self.generate_files() diff --git a/qiita_db/metadata_template/test/test_base_metadata_template.py b/qiita_db/metadata_template/test/test_base_metadata_template.py index 
f5b320f79..fa6dd8957 100644 --- a/qiita_db/metadata_template/test/test_base_metadata_template.py +++ b/qiita_db/metadata_template/test/test_base_metadata_template.py @@ -50,12 +50,11 @@ def test_table_name(self): with self.assertRaises(IncompetentQiitaDeveloperError): MetadataTemplate._table_name(self.study) - def test_add_common_creation_steps_to_queue(self): - """_add_common_creation_steps_to_queue raises an error from base class + def test_common_creation_steps(self): + """common_creation_steps raises an error from base class """ with self.assertRaises(IncompetentQiitaDeveloperError): - MetadataTemplate._add_common_creation_steps_to_queue( - None, 1, None, "") + MetadataTemplate._common_creation_steps(None, 1) def test_clean_validate_template(self): """_clean_validate_template raises an error from base class""" diff --git a/qiita_db/metadata_template/test/test_prep_template.py b/qiita_db/metadata_template/test/test_prep_template.py index 77c2363ab..9d082ceca 100644 --- a/qiita_db/metadata_template/test/test_prep_template.py +++ b/qiita_db/metadata_template/test/test_prep_template.py @@ -27,7 +27,6 @@ QiitaDBColumnError, QiitaDBWarning, QiitaDBError) -from qiita_db.sql_connection import SQLConnectionHandler from qiita_db.study import Study from qiita_db.data import RawData, ProcessedData from qiita_db.util import exists_table, get_mountpoint, get_count @@ -55,45 +54,6 @@ def setUp(self): class TestPrepSampleReadOnly(BaseTestPrepSample): - def test_add_setitem_queries_error(self): - conn_handler = SQLConnectionHandler() - queue = "test_queue" - conn_handler.create_queue(queue) - - with self.assertRaises(QiitaDBColumnError): - self.tester.add_setitem_queries( - 'COL_DOES_NOT_EXIST', 'Foo', conn_handler, queue) - - def test_add_setitem_queries_required(self): - conn_handler = SQLConnectionHandler() - queue = "test_queue" - conn_handler.create_queue(queue) - - self.tester.add_setitem_queries( - 'center_name', 'FOO', conn_handler, queue) - - obs = conn_handler.queues[queue] - sql = """UPDATE qiita.prep_1 - SET center_name=%s - WHERE sample_id=%s""" - exp = [(sql, ('FOO', '1.SKB8.640193'))] - self.assertEqual(obs, exp) - - def test_add_setitem_queries_dynamic(self): - conn_handler = SQLConnectionHandler() - queue = "test_queue" - conn_handler.create_queue(queue) - - self.tester.add_setitem_queries( - 'barcode', 'AAAAAAAAAAAA', conn_handler, queue) - - obs = conn_handler.queues[queue] - sql = """UPDATE qiita.prep_1 - SET barcode=%s - WHERE sample_id=%s""" - exp = [(sql, ('AAAAAAAAAAAA', '1.SKB8.640193'))] - self.assertEqual(obs, exp) - def test_init_unknown_error(self): """Init errors if the PrepSample id is not found in the template""" with self.assertRaises(QiitaDBUnknownIDError): @@ -139,8 +99,7 @@ def test_exists_false(self): def test_get_categories(self): """Correctly returns the set of category headers""" - conn_handler = SQLConnectionHandler() - obs = self.tester._get_categories(conn_handler) + obs = self.tester._get_categories() self.assertEqual(obs, self.exp_categories) def test_len(self): @@ -419,8 +378,7 @@ def test_exists_false(self): def test_get_sample_ids(self): """get_sample_ids returns the correct set of sample ids""" - conn_handler = SQLConnectionHandler() - obs = self.tester._get_sample_ids(conn_handler) + obs = self.tester._get_sample_ids() self.assertEqual(obs, self.exp_sample_ids) def test_len(self): @@ -585,99 +543,6 @@ def test_to_dataframe(self): u'samp_size', u'sequencing_meth', u'illumina_technology', u'sample_center', u'pcr_primers', u'study_center'}) - def 
test_add_common_creation_steps_to_queue(self): - """add_common_creation_steps_to_queue adds the correct sql statements - """ - metadata_dict = { - '2.SKB8.640193': {'center_name': 'ANL', - 'center_project_name': 'Test Project', - 'emp_status': 'EMP', - 'str_column': 'Value for sample 1', - 'linkerprimersequence': 'GTGCCAGCMGCCGCGGTAA', - 'barcodesequence': 'GTCCGCAAGTTA', - 'run_prefix': "s_G1_L001_sequences", - 'platform': 'ILLUMINA', - 'library_construction_protocol': 'AAAA', - 'experiment_design_description': 'BBBB'}, - '2.SKD8.640184': {'center_name': 'ANL', - 'center_project_name': 'Test Project', - 'emp_status': 'EMP', - 'str_column': 'Value for sample 2', - 'linkerprimersequence': 'GTGCCAGCMGCCGCGGTAA', - 'barcodesequence': 'CGTAGAGCTCTC', - 'run_prefix': "s_G1_L001_sequences", - 'platform': 'ILLUMINA', - 'library_construction_protocol': 'AAAA', - 'experiment_design_description': 'BBBB'}, - } - metadata = pd.DataFrame.from_dict(metadata_dict, orient='index') - - conn_handler = SQLConnectionHandler() - queue_name = "TEST_QUEUE" - conn_handler.create_queue(queue_name) - PrepTemplate._add_common_creation_steps_to_queue( - metadata, 2, conn_handler, queue_name) - - sql_insert_common = ( - 'INSERT INTO qiita.prep_template_sample ' - '(prep_template_id, sample_id) VALUES (%s, %s)') - sql_insert_common_params_1 = (2, '2.SKB8.640193') - sql_insert_common_params_2 = (2, '2.SKD8.640184') - - sql_insert_prep_columns = ( - 'INSERT INTO qiita.prep_columns ' - '(prep_template_id, column_name, column_type) ' - 'VALUES (%s, %s, %s)') - - sql_create_table = ( - 'CREATE TABLE qiita.prep_2 ' - '(sample_id varchar NOT NULL, barcodesequence varchar, ' - 'center_name varchar, center_project_name varchar, ' - 'emp_status varchar, experiment_design_description varchar, ' - 'library_construction_protocol varchar, ' - 'linkerprimersequence varchar, platform varchar, ' - 'run_prefix varchar, str_column varchar, ' - 'CONSTRAINT fk_prep_2 FOREIGN KEY (sample_id) REFERENCES ' - 'qiita.study_sample (sample_id) ON UPDATE CASCADE)') - - sql_insert_dynamic = ( - 'INSERT INTO qiita.prep_2 ' - '(sample_id, barcodesequence, center_name, center_project_name, ' - 'emp_status, experiment_design_description, ' - 'library_construction_protocol, linkerprimersequence, platform, ' - 'run_prefix, str_column) ' - 'VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)') - - sql_insert_dynamic_params_1 = ( - '2.SKB8.640193', 'GTCCGCAAGTTA', 'ANL', 'Test Project', 'EMP', - 'BBBB', 'AAAA', 'GTGCCAGCMGCCGCGGTAA', 'ILLUMINA', - 's_G1_L001_sequences', 'Value for sample 1') - sql_insert_dynamic_params_2 = ( - '2.SKD8.640184', 'CGTAGAGCTCTC', 'ANL', 'Test Project', 'EMP', - 'BBBB', 'AAAA', 'GTGCCAGCMGCCGCGGTAA', 'ILLUMINA', - 's_G1_L001_sequences', 'Value for sample 2') - - exp = [ - (sql_insert_common, sql_insert_common_params_1), - (sql_insert_common, sql_insert_common_params_2), - (sql_insert_prep_columns, (2, 'barcodesequence', 'varchar')), - (sql_insert_prep_columns, (2, 'center_name', 'varchar')), - (sql_insert_prep_columns, (2, 'center_project_name', 'varchar')), - (sql_insert_prep_columns, (2, 'emp_status', 'varchar')), - (sql_insert_prep_columns, - (2, 'experiment_design_description', 'varchar')), - (sql_insert_prep_columns, - (2, 'library_construction_protocol', 'varchar')), - (sql_insert_prep_columns, (2, 'linkerprimersequence', 'varchar')), - (sql_insert_prep_columns, (2, 'platform', 'varchar')), - (sql_insert_prep_columns, (2, 'run_prefix', 'varchar')), - (sql_insert_prep_columns, (2, 'str_column', 'varchar')), - 
(sql_create_table, None), - (sql_insert_dynamic, sql_insert_dynamic_params_1), - (sql_insert_dynamic, sql_insert_dynamic_params_2)] - - self.assertEqual(conn_handler.queues[queue_name], exp) - def test_clean_validate_template_error_bad_chars(self): """Raises an error if there are invalid characters in the sample names """ @@ -890,7 +755,7 @@ def test_create_error_cleanup(self): WHERE prep_template_id=%s)""" self.assertFalse(self.conn_handler.execute_fetchone(sql, (exp_id,))[0]) - self.assertFalse(exists_table("prep_%d" % exp_id, self.conn_handler)) + self.assertFalse(exists_table("prep_%d" % exp_id)) def _common_creation_checks(self, new_id, pt, fp_count): # The returned object has the correct id @@ -939,7 +804,7 @@ def _common_creation_checks(self, new_id, pt, fp_count): self.assertItemsEqual(obs, exp) # The new table exists - self.assertTrue(exists_table("prep_%s" % new_id, self.conn_handler)) + self.assertTrue(exists_table("prep_%s" % new_id)) # The new table hosts the correct values obs = [dict(o) for o in self.conn_handler.execute_fetchall( @@ -1093,7 +958,7 @@ def test_create_warning(self): self.assertItemsEqual(obs, exp) # The new table exists - self.assertTrue(exists_table("prep_%s" % new_id, self.conn_handler)) + self.assertTrue(exists_table("prep_%s" % new_id)) # The new table hosts the correct values obs = [dict(o) for o in self.conn_handler.execute_fetchall( diff --git a/qiita_db/metadata_template/test/test_sample_template.py b/qiita_db/metadata_template/test/test_sample_template.py index f7adea4cb..00bdcb1f4 100644 --- a/qiita_db/metadata_template/test/test_sample_template.py +++ b/qiita_db/metadata_template/test/test_sample_template.py @@ -24,7 +24,6 @@ QiitaDBDuplicateHeaderError, QiitaDBColumnError, QiitaDBError, QiitaDBWarning) -from qiita_db.sql_connection import SQLConnectionHandler from qiita_db.study import Study, StudyPerson from qiita_db.user import User from qiita_db.util import exists_table, get_count @@ -53,45 +52,6 @@ def setUp(self): class TestSampleReadOnly(BaseTestSample): - def test_add_setitem_queries_error(self): - conn_handler = SQLConnectionHandler() - queue = "test_queue" - conn_handler.create_queue(queue) - - with self.assertRaises(QiitaDBColumnError): - self.tester.add_setitem_queries( - 'COL_DOES_NOT_EXIST', 0.30, conn_handler, queue) - - def test_add_setitem_queries_required(self): - conn_handler = SQLConnectionHandler() - queue = "test_queue" - conn_handler.create_queue(queue) - - self.tester.add_setitem_queries( - 'physical_specimen_remaining', True, conn_handler, queue) - - obs = conn_handler.queues[queue] - sql = """UPDATE qiita.sample_1 - SET physical_specimen_remaining=%s - WHERE sample_id=%s""" - exp = [(sql, (True, '1.SKB8.640193'))] - self.assertEqual(obs, exp) - - def test_add_setitem_queries_dynamic(self): - conn_handler = SQLConnectionHandler() - queue = "test_queue" - conn_handler.create_queue(queue) - - self.tester.add_setitem_queries( - 'tot_nitro', '1234.5', conn_handler, queue) - - obs = conn_handler.queues[queue] - sql = """UPDATE qiita.sample_1 - SET tot_nitro=%s - WHERE sample_id=%s""" - exp = [(sql, ('1234.5', '1.SKB8.640193'))] - self.assertEqual(obs, exp) - def test_init_unknown_error(self): """Init raises an error if the sample id is not found in the template """ @@ -138,8 +98,7 @@ def test_exists_false(self): def test_get_categories(self): """Correctly returns the set of category headers""" - conn_handler = SQLConnectionHandler() - obs = self.tester._get_categories(conn_handler) + obs = self.tester._get_categories() 
self.assertEqual(obs, self.exp_categories) def test_len(self): @@ -556,8 +515,7 @@ def test_exists_true(self): def test_get_sample_ids(self): """get_sample_ids returns the correct set of sample ids""" - conn_handler = SQLConnectionHandler() - obs = self.tester._get_sample_ids(conn_handler) + obs = self.tester._get_sample_ids() self.assertEqual(obs, self.exp_sample_ids) def test_len(self): @@ -692,118 +650,6 @@ def test_get_none(self): """get returns none if the sample id is not present""" self.assertTrue(self.tester.get('Not_a_Sample') is None) - def test_add_common_creation_steps_to_queue(self): - """add_common_creation_steps_to_queue adds the correct sql statements - """ - metadata_dict = { - '2.Sample1': {'physical_specimen_location': 'location1', - 'physical_specimen_remaining': True, - 'dna_extracted': True, - 'sample_type': 'type1', - 'collection_timestamp': - datetime(2014, 5, 29, 12, 24, 51), - 'host_subject_id': 'NotIdentified', - 'description': 'Test Sample 1', - 'str_column': 'Value for sample 1', - 'int_column': 1, - 'latitude': 42.42, - 'longitude': 41.41}, - '2.Sample2': {'physical_specimen_location': 'location1', - 'physical_specimen_remaining': True, - 'dna_extracted': True, - 'sample_type': 'type1', - 'int_column': 2, - 'collection_timestamp': - datetime(2014, 5, 29, 12, 24, 51), - 'host_subject_id': 'NotIdentified', - 'description': 'Test Sample 2', - 'str_column': 'Value for sample 2', - 'latitude': 4.2, - 'longitude': 1.1}, - '2.Sample3': {'physical_specimen_location': 'location1', - 'physical_specimen_remaining': True, - 'dna_extracted': True, - 'sample_type': 'type1', - 'collection_timestamp': - datetime(2014, 5, 29, 12, 24, 51), - 'host_subject_id': 'NotIdentified', - 'description': 'Test Sample 3', - 'str_column': 'Value for sample 3', - 'int_column': 3, - 'latitude': 4.8, - 'longitude': 4.41}, - } - metadata = pd.DataFrame.from_dict(metadata_dict, orient='index') - - conn_handler = SQLConnectionHandler() - queue_name = "TEST_QUEUE" - conn_handler.create_queue(queue_name) - SampleTemplate._add_common_creation_steps_to_queue( - metadata, 2, conn_handler, queue_name) - - sql_insert_required = ( - 'INSERT INTO qiita.study_sample (study_id, sample_id) ' - 'VALUES (%s, %s)') - - sql_insert_sample_cols = ( - 'INSERT INTO qiita.study_sample_columns ' - '(study_id, column_name, column_type) ' - 'VALUES (%s, %s, %s)') - - sql_crate_table = ( - 'CREATE TABLE qiita.sample_2 (sample_id varchar NOT NULL, ' - 'collection_timestamp timestamp, description varchar, ' - 'dna_extracted bool, host_subject_id varchar, int_column integer, ' - 'latitude float8, longitude float8, ' - 'physical_specimen_location varchar, ' - 'physical_specimen_remaining bool, sample_type varchar, ' - 'str_column varchar, ' - 'CONSTRAINT fk_sample_2 FOREIGN KEY (sample_id) REFERENCES ' - 'qiita.study_sample (sample_id) ON UPDATE CASCADE)') - - sql_insert_dynamic = ( - 'INSERT INTO qiita.sample_2 ' - '(sample_id, collection_timestamp, description, dna_extracted, ' - 'host_subject_id, int_column, latitude, longitude, ' - 'physical_specimen_location, physical_specimen_remaining, ' - 'sample_type, str_column) ' - 'VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)') - sql_insert_dynamic_params_1 = ( - '2.Sample1', datetime(2014, 5, 29, 12, 24, 51), 'Test Sample 1', - True, 'NotIdentified', 1, 42.42, 41.41, 'location1', True, 'type1', - 'Value for sample 1') - sql_insert_dynamic_params_2 = ( - '2.Sample2', datetime(2014, 5, 29, 12, 24, 51), 'Test Sample 2', - True, 'NotIdentified', 2, 4.2, 1.1, 'location1', 
True, 'type1', - 'Value for sample 2') - sql_insert_dynamic_params_3 = ( - '2.Sample3', datetime(2014, 5, 29, 12, 24, 51), 'Test Sample 3', - True, 'NotIdentified', 3, 4.8, 4.41, 'location1', True, 'type1', - 'Value for sample 3') - - exp = [ - (sql_insert_required, (2, '2.Sample1')), - (sql_insert_required, (2, '2.Sample2')), - (sql_insert_required, (2, '2.Sample3')), - (sql_insert_sample_cols, (2, 'collection_timestamp', 'timestamp')), - (sql_insert_sample_cols, (2, 'description', 'varchar')), - (sql_insert_sample_cols, (2, 'dna_extracted', 'bool')), - (sql_insert_sample_cols, (2, 'host_subject_id', 'varchar')), - (sql_insert_sample_cols, (2, 'int_column', 'integer')), - (sql_insert_sample_cols, (2, 'latitude', 'float8')), - (sql_insert_sample_cols, (2, 'longitude', 'float8')), - (sql_insert_sample_cols, - (2, 'physical_specimen_location', 'varchar')), - (sql_insert_sample_cols, - (2, 'physical_specimen_remaining', 'bool')), - (sql_insert_sample_cols, (2, 'sample_type', 'varchar')), - (sql_insert_sample_cols, (2, 'str_column', 'varchar')), - (sql_crate_table, None), - (sql_insert_dynamic, sql_insert_dynamic_params_1), - (sql_insert_dynamic, sql_insert_dynamic_params_2), - (sql_insert_dynamic, sql_insert_dynamic_params_3)] - self.assertEqual(conn_handler.queues[queue_name], exp) - def test_clean_validate_template_error_bad_chars(self): """Raises an error if there are invalid characters in the sample names """ @@ -987,7 +833,7 @@ def test_create_error_cleanup(self): self.conn_handler.execute_fetchone(sql, (self.new_study.id,))[0]) self.assertFalse( - exists_table("sample_%d" % self.new_study.id, self.conn_handler)) + exists_table("sample_%d" % self.new_study.id)) def test_create(self): """Creates a new SampleTemplate""" @@ -1028,7 +874,7 @@ def test_create(self): self.assertEqual(obs, exp) # The new table exists - self.assertTrue(exists_table("sample_%s" % new_id, self.conn_handler)) + self.assertTrue(exists_table("sample_%s" % new_id)) # The new table hosts the correct values sql = "SELECT * FROM qiita.sample_{0}".format(new_id) @@ -1117,7 +963,7 @@ def test_create_int_prefix(self): self.assertEqual(obs, exp) # The new table exists - self.assertTrue(exists_table("sample_%s" % new_id, self.conn_handler)) + self.assertTrue(exists_table("sample_%s" % new_id)) # The new table hosts the correct values sql = "SELECT * FROM qiita.sample_{0}".format(new_id) @@ -1206,7 +1052,7 @@ def test_create_str_prefixes(self): self.assertEqual(obs, exp) # The new table exists - self.assertTrue(exists_table("sample_%s" % new_id, self.conn_handler)) + self.assertTrue(exists_table("sample_%s" % new_id)) # The new table hosts the correct values sql = "SELECT * FROM qiita.sample_{0}".format(new_id) @@ -1296,7 +1142,7 @@ def test_create_already_prefixed_samples(self): self.assertEqual(obs, exp) # The new table exists - self.assertTrue(exists_table("sample_%s" % new_id, self.conn_handler)) + self.assertTrue(exists_table("sample_%s" % new_id)) # The new table hosts the correct values sql = "SELECT * FROM qiita.sample_{0}".format(new_id) diff --git a/qiita_db/ontology.py b/qiita_db/ontology.py index 5b269d3ec..6091a8d80 100644 --- a/qiita_db/ontology.py +++ b/qiita_db/ontology.py @@ -28,7 +28,7 @@ from .base import QiitaObject from .util import convert_from_id -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN class Ontology(QiitaObject): @@ -42,30 +42,33 @@ class Ontology(QiitaObject): _table = 'ontology' def __contains__(self, value): - conn_handler = SQLConnectionHandler() - sql = 
"""SELECT EXISTS (SELECT * FROM qiita.term t JOIN qiita.{0} o - on t.ontology_id = o.ontology_id WHERE o.ontology_id = %s and - term = %s)""".format(self._table) - - return conn_handler.execute_fetchone(sql, (self._id, value))[0] + with TRN: + sql = """SELECT EXISTS ( + SELECT * + FROM qiita.term t + JOIN qiita.{0} o ON t.ontology_id = o.ontology_id + WHERE o.ontology_id = %s + AND term = %s)""".format(self._table) + TRN.add(sql, [self._id, value]) + return TRN.execute_fetchlast() @property def terms(self): - conn_handler = SQLConnectionHandler() - sql = """SELECT term FROM qiita.term WHERE ontology_id = %s AND - user_defined = false""" - - return [row[0] for row in - conn_handler.execute_fetchall(sql, [self.id])] + with TRN: + sql = """SELECT term + FROM qiita.term + WHERE ontology_id = %s AND user_defined = false""" + TRN.add(sql, [self.id]) + return TRN.execute_fetchflatten() @property def user_defined_terms(self): - conn_handler = SQLConnectionHandler() - sql = """SELECT term FROM qiita.term WHERE ontology_id = %s AND - user_defined = true""" - - return [row[0] for row in - conn_handler.execute_fetchall(sql, [self.id])] + with TRN: + sql = """SELECT term + FROM qiita.term + WHERE ontology_id = %s AND user_defined = true""" + TRN.add(sql, [self.id]) + return TRN.execute_fetchflatten() @property def shortname(self): @@ -79,18 +82,16 @@ def add_user_defined_term(self, term): term : str New user defined term to add into a given ontology """ + with TRN: + # we don't need to add an existing term + terms = self.user_defined_terms + self.terms - # we don't need to add an existing term - terms = self.user_defined_terms + self.terms - - if term not in terms: - conn_handler = SQLConnectionHandler() - sql = """INSERT INTO qiita.term - (ontology_id, term, user_defined) - VALUES - (%s, %s, true);""" - - conn_handler.execute(sql, [self.id, term]) + if term not in terms: + sql = """INSERT INTO qiita.term + (ontology_id, term, user_defined) + VALUES (%s, %s, true);""" + TRN.add(sql, [self.id, term]) + TRN.execute() def term_type(self, term): """Get the type of a given ontology term @@ -108,16 +109,16 @@ def term_type(self, term): user-defined and 'not_ontology' if the term is not part of the ontology. 
""" - conn_handler = SQLConnectionHandler() - sql = """SELECT user_defined FROM - qiita.term - WHERE term = %s AND ontology_id = %s - """ - result = conn_handler.execute_fetchone(sql, [term, self.id]) - - if result is None: - return 'not_ontology' - elif result[0]: - return 'user_defined' - elif not result[0]: - return 'ontology' + with TRN: + sql = """SELECT user_defined FROM + qiita.term + WHERE term = %s AND ontology_id = %s""" + TRN.add(sql, [term, self.id]) + result = TRN.execute_fetchindex() + + if not result: + return 'not_ontology' + elif result[0][0]: + return 'user_defined' + elif not result[0][0]: + return 'ontology' diff --git a/qiita_db/parameters.py b/qiita_db/parameters.py index eba0b8341..f30fe95ca 100644 --- a/qiita_db/parameters.py +++ b/qiita_db/parameters.py @@ -8,7 +8,7 @@ from __future__ import division from .base import QiitaObject -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN from .util import get_table_cols_w_type, get_table_cols from .exceptions import QiitaDBDuplicateError @@ -33,38 +33,40 @@ def _check_columns(cls, **kwargs): @classmethod def exists(cls, **kwargs): r"""Check if the parameter set already exists on the DB""" - cls._check_columns(**kwargs) + with TRN: + cls._check_columns(**kwargs) - conn_handler = SQLConnectionHandler() + cols = ["{} = %s".format(col) for col in kwargs] + sql = "SELECT EXISTS(SELECT * FROM qiita.{0} WHERE {1})".format( + cls._table, ' AND '.join(cols)) - cols = ["{} = %s".format(col) for col in kwargs] - - return conn_handler.execute_fetchone( - "SELECT EXISTS(SELECT * FROM qiita.{0} WHERE {1})".format( - cls._table, ' AND '.join(cols)), - kwargs.values())[0] + TRN.add(sql, kwargs.values()) + return TRN.execute_fetchlast() @classmethod def create(cls, param_set_name, **kwargs): r"""Adds a new parameter set to the DB""" - cls._check_columns(**kwargs) + with TRN: + cls._check_columns(**kwargs) - conn_handler = SQLConnectionHandler() + vals = kwargs.values() + vals.insert(0, param_set_name) - vals = kwargs.values() - vals.insert(0, param_set_name) + if cls.exists(**kwargs): + raise QiitaDBDuplicateError( + cls.__name__, "Values: %s" % kwargs) - if cls.exists(**kwargs): - raise QiitaDBDuplicateError(cls.__name__, "Values: %s" % kwargs) + sql = """INSERT INTO qiita.{0} (param_set_name, {1}) + VALUES (%s, {2}) + RETURNING {3}""".format( + cls._table, + ', '.join(kwargs), + ', '.join(['%s'] * len(kwargs)), + cls._column_id) - id_ = conn_handler.execute_fetchone( - "INSERT INTO qiita.{0} (param_set_name, {1}) VALUES (%s, {2}) " - "RETURNING {3}".format( - cls._table, ', '.join(kwargs), - ', '.join(['%s'] * len(kwargs)), cls._column_id), - vals)[0] + TRN.add(sql, vals) - return cls(id_) + return cls(TRN.execute_fetchlast()) @classmethod def iter(cls): @@ -75,11 +77,12 @@ def iter(cls): generator Yields a parameter instance """ - conn_handler = SQLConnectionHandler() - sql = "SELECT {0} FROM qiita.{1}".format(cls._column_id, cls._table) - - for result in conn_handler.execute_fetchall(sql): - yield cls(result[0]) + with TRN: + sql = "SELECT {0} FROM qiita.{1}".format(cls._column_id, + cls._table) + TRN.add(sql) + for result in TRN.execute_fetchflatten(): + yield cls(result) @property def name(self): @@ -90,11 +93,11 @@ def name(self): str The name of the parameter set """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT param_set_name FROM qiita.{0} WHERE {1} = %s".format( - self._table, self._column_id), - (self.id,))[0] + with TRN: + sql = "SELECT param_set_name 
FROM qiita.{0} WHERE {1} = %s".format( + self._table, self._column_id) + TRN.add(sql, [self.id]) + return TRN.execute_fetchlast() @property def values(self): @@ -105,16 +108,17 @@ def values(self): dict Dictionary with the parameter values keyed by parameter name """ - conn_handler = SQLConnectionHandler() - result = dict(conn_handler.execute_fetchone( - "SELECT * FROM qiita.{0} WHERE {1} = %s".format( - self._table, self._column_id), - (self.id,))) - # Remove the parameter id and the parameter name as those are used - # internally, and they are not passed to the processing step - del result[self._column_id] - del result['param_set_name'] - return result + with TRN: + sql = "SELECT * FROM qiita.{0} WHERE {1} = %s".format( + self._table, self._column_id) + TRN.add(sql, [self.id]) + # There should be only one row + result = dict(TRN.execute_fetchindex()[0]) + # Remove the parameter id and the parameter name as those are used + # internally, and they are not passed to the processing step + del result[self._column_id] + del result['param_set_name'] + return result def _check_id(self, id_): r"""Check that the provided ID actually exists in the database @@ -129,21 +133,20 @@ def _check_id(self, id_): This function overwrites the base function, as sql layout doesn't follow the same conventions done in the other classes. """ - self._check_subclass() - - conn_handler = SQLConnectionHandler() - - return conn_handler.execute_fetchone( - "SELECT EXISTS(SELECT * FROM qiita.{0} WHERE {1} = %s)".format( - self._table, self._column_id), - (id_, ))[0] + with TRN: + sql = """SELECT EXISTS( + SELECT * FROM qiita.{0} + WHERE {1} = %s)""".format(self._table, self._column_id) + TRN.add(sql, [id_]) + return TRN.execute_fetchlast() def _get_values_as_dict(self): r"""""" - conn_handler = SQLConnectionHandler() - return dict(conn_handler.execute_fetchone( - "SELECT * FROM qiita.{0} WHERE {1}=%s".format( - self._table, self._column_id), (self.id,))) + with TRN: + sql = "SELECT * FROM qiita.{0} WHERE {1}=%s".format( + self._table, self._column_id) + TRN.add(sql, [self.id]) + return dict(TRN.execute_fetchindex()[0]) def to_str(self): r"""Generates a string with the parameter values @@ -153,22 +156,23 @@ def to_str(self): str The string with all the parameters """ - table_cols = get_table_cols_w_type(self._table) - table_cols.remove([self._column_id, 'bigint']) + with TRN: + table_cols = get_table_cols_w_type(self._table) + table_cols.remove([self._column_id, 'bigint']) - values = self._get_values_as_dict() + values = self._get_values_as_dict() - result = [] - for p_name, p_type in sorted(table_cols): - if p_name in self._ignore_cols: - continue - if p_type == 'boolean': - if values[p_name]: - result.append("--%s" % p_name) - else: - result.append("--%s %s" % (p_name, values[p_name])) + result = [] + for p_name, p_type in sorted(table_cols): + if p_name in self._ignore_cols: + continue + if p_type == 'boolean': + if values[p_name]: + result.append("--%s" % p_name) + else: + result.append("--%s %s" % (p_name, values[p_name])) - return " ".join(result) + return " ".join(result) class PreprocessedIlluminaParams(BaseParameters): @@ -197,11 +201,11 @@ class ProcessedSortmernaParams(BaseParameters): @property def reference(self): """"Returns the reference id used on this parameter set""" - conn_handler = SQLConnectionHandler() - - return conn_handler.execute_fetchone( - "SELECT reference_id FROM qiita.{0} WHERE {1}=%s".format( - self._table, self._column_id), (self.id,))[0] + with TRN: + sql = "SELECT reference_id FROM 
qiita.{0} WHERE {1}=%s".format( + self._table, self._column_id) + TRN.add(sql, [self.id]) + return TRN.execute_fetchlast() def to_file(self, f): r"""Writes the parameters to a file in QIIME parameters file format diff --git a/qiita_db/portal.py b/qiita_db/portal.py index fb819c451..9a65cb470 100644 --- a/qiita_db/portal.py +++ b/qiita_db/portal.py @@ -7,7 +7,7 @@ # ----------------------------------------------------------------------------- import warnings -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN from .util import convert_to_id from .base import QiitaObject from .exceptions import (QiitaDBError, QiitaDBDuplicateError, QiitaDBWarning, @@ -33,9 +33,10 @@ class Portal(QiitaObject): _table = 'portal_type' def __init__(self, portal): - self.portal = portal - portal_id = convert_to_id(portal, 'portal_type', 'portal') - super(Portal, self).__init__(portal_id) + with TRN: + self.portal = portal + portal_id = convert_to_id(portal, 'portal_type', 'portal') + super(Portal, self).__init__(portal_id) @staticmethod def list_portals(): @@ -51,10 +52,13 @@ def list_portals(): This does not return the QIITA portal in the list, as it is a required portal that can not be edited. """ - sql = """SELECT portal FROM qiita.portal_type WHERE portal != 'QIITA' - ORDER BY portal""" - conn_handler = SQLConnectionHandler() - return [x[0] for x in conn_handler.execute_fetchall(sql)] + with TRN: + sql = """SELECT portal + FROM qiita.portal_type + WHERE portal != 'QIITA' + ORDER BY portal""" + TRN.add(sql) + return TRN.execute_fetchflatten() @classmethod def create(cls, portal, desc): @@ -72,39 +76,42 @@ def create(cls, portal, desc): QiitaDBDuplicateError Portal name already exists """ - if cls.exists(portal): - raise QiitaDBDuplicateError("Portal", portal) - - # Add portal and default analyses for all users - sql = """DO $do$ - DECLARE - pid bigint; - eml varchar; - aid bigint; - BEGIN - INSERT INTO qiita.portal_type (portal, portal_description) - VALUES (%s, %s) RETURNING portal_type_id INTO pid; - - FOR eml IN - SELECT email FROM qiita.qiita_user - LOOP - INSERT INTO qiita.analysis - (email, name, description, dflt, analysis_status_id) - VALUES (eml, eml || '-dflt', 'dflt', true, 1) RETURNING - analysis_id INTO aid; - - INSERT INTO qiita.analysis_workflow (analysis_id, step) - VALUES (aid, 2); - - INSERT INTO qiita.analysis_portal - (analysis_id, portal_type_id) - VALUES (aid, pid); - END LOOP; - END $do$;""" - conn_handler = SQLConnectionHandler() - conn_handler.execute(sql, [portal, desc]) - - return cls(portal) + with TRN: + if cls.exists(portal): + raise QiitaDBDuplicateError("Portal", portal) + + # Add portal and default analyses for all users + sql = """DO $do$ + DECLARE + pid bigint; + eml varchar; + aid bigint; + BEGIN + INSERT INTO qiita.portal_type (portal, portal_description) + VALUES (%s, %s) + RETURNING portal_type_id INTO pid; + + FOR eml IN + SELECT email FROM qiita.qiita_user + LOOP + INSERT INTO qiita.analysis + (email, name, description, dflt, + analysis_status_id) + VALUES (eml, eml || '-dflt', 'dflt', true, 1) + RETURNING analysis_id INTO aid; + + INSERT INTO qiita.analysis_workflow (analysis_id, step) + VALUES (aid, 2); + + INSERT INTO qiita.analysis_portal + (analysis_id, portal_type_id) + VALUES (aid, pid); + END LOOP; + END $do$;""" + TRN.add(sql, [portal, desc]) + TRN.execute() + + return cls(portal) @staticmethod def delete(portal): @@ -120,53 +127,58 @@ def delete(portal): QiitaDBError Portal has analyses or studies attached to it """ - 
conn_handler = SQLConnectionHandler() - # Check if attached to any studies - portal_id = convert_to_id(portal, 'portal_type', 'portal') - sql = """SELECT study_id from qiita.study_portal - WHERE portal_type_id = %s""" - studies = conn_handler.execute_fetchall(sql, [portal_id]) - if studies: - raise QiitaDBError( - " Cannot delete portal '%s', studies still attached: %s" % - (portal, ', '.join([str(s[0]) for s in studies]))) - - # Check if attached to any analyses - sql = """SELECT analysis_id from qiita.analysis_portal - JOIN qiita.analysis USING (analysis_id) - WHERE portal_type_id = %s and dflt = FALSE""" - analyses = conn_handler.execute_fetchall(sql, [portal_id]) - if analyses: - raise QiitaDBError( - " Cannot delete portal '%s', analyses still attached: %s" % - (portal, ', '.join([str(a[0]) for a in analyses]))) - - # Remove portal and default analyses for all users - sql = """DO $do$ - DECLARE - aid bigint; - BEGIN - FOR aid IN - SELECT analysis_id FROM qiita.analysis_portal - JOIN qiita.analysis USING (analysis_id) - WHERE portal_type_id = %s and dflt = True - LOOP - DELETE FROM qiita.analysis_portal - WHERE analysis_id = aid; - - DELETE FROM qiita.analysis_workflow - WHERE analysis_id = aid; - - DELETE FROM qiita.analysis_sample - WHERE analysis_id = aid; - - DELETE FROM qiita.analysis - WHERE analysis_id = aid; - END LOOP; - DELETE FROM qiita.portal_type WHERE portal_type_id = %s; - END $do$;""" - conn_handler = SQLConnectionHandler() - conn_handler.execute(sql, [portal_id] * 2) + with TRN: + # Check if attached to any studies + portal_id = convert_to_id(portal, 'portal_type', 'portal') + sql = """SELECT study_id + FROM qiita.study_portal + WHERE portal_type_id = %s""" + TRN.add(sql, [portal_id]) + studies = TRN.execute_fetchflatten() + if studies: + raise QiitaDBError( + " Cannot delete portal '%s', studies still attached: %s" % + (portal, ', '.join(map(str, studies)))) + + # Check if attached to any analyses + sql = """SELECT analysis_id + FROM qiita.analysis_portal + JOIN qiita.analysis USING (analysis_id) + WHERE portal_type_id = %s AND dflt = FALSE""" + TRN.add(sql, [portal_id]) + analyses = TRN.execute_fetchflatten() + if analyses: + raise QiitaDBError( + " Cannot delete portal '%s', analyses still attached: %s" % + (portal, ', '.join(map(str, analyses)))) + + # Remove portal and default analyses for all users + sql = """DO $do$ + DECLARE + aid bigint; + BEGIN + FOR aid IN + SELECT analysis_id + FROM qiita.analysis_portal + JOIN qiita.analysis USING (analysis_id) + WHERE portal_type_id = %s AND dflt = True + LOOP + DELETE FROM qiita.analysis_portal + WHERE analysis_id = aid; + + DELETE FROM qiita.analysis_workflow + WHERE analysis_id = aid; + + DELETE FROM qiita.analysis_sample + WHERE analysis_id = aid; + + DELETE FROM qiita.analysis + WHERE analysis_id = aid; + END LOOP; + DELETE FROM qiita.portal_type WHERE portal_type_id = %s; + END $do$;""" + TRN.add(sql, [portal_id] * 2) + TRN.execute() @staticmethod def exists(portal): @@ -197,22 +209,22 @@ def get_studies(self): set of int All study ids in the database that match the given portal """ - conn_handler = SQLConnectionHandler() - sql = """SELECT study_id FROM qiita.study_portal - WHERE portal_type_id = %s""" - return {x[0] for x in - conn_handler.execute_fetchall(sql, [self._id])} + with TRN: + sql = """SELECT study_id FROM qiita.study_portal + WHERE portal_type_id = %s""" + TRN.add(sql, [self._id]) + return set(TRN.execute_fetchflatten()) def _check_studies(self, studies): - conn_handler = SQLConnectionHandler() - # 
Check if any study IDs given do not exist. - sql = "SELECT study_id from qiita.study WHERE study_id IN %s" - existing = [x[0] for x in conn_handler.execute_fetchall( - sql, [tuple(studies)])] - if len(existing) != len(studies): - bad = map(str, set(studies).difference(existing)) - raise QiitaDBError("The following studies do not exist: %s" % - ", ".join(bad)) + with TRN: + # Check if any study IDs given do not exist. + sql = "SELECT study_id FROM qiita.study WHERE study_id IN %s" + TRN.add(sql, [tuple(studies)]) + existing = TRN.execute_fetchflatten() + if len(existing) != len(studies): + bad = map(str, set(studies).difference(existing)) + raise QiitaDBError("The following studies do not exist: %s" % + ", ".join(bad)) def add_studies(self, studies): """Adds studies to given portal @@ -229,27 +241,30 @@ def add_studies(self, studies): QiitaDBWarning Some studies already exist in the given portal """ - self._check_studies(studies) - - conn_handler = SQLConnectionHandler() - # Clean list of studies down to ones not associated with portal already - sql = """SELECT study_id from qiita.study_portal - WHERE portal_type_id = %s AND study_id IN %s""" - duplicates = [x[0] for x in conn_handler.execute_fetchall( - sql, [self._id, tuple(studies)])] - - if len(duplicates) > 0: - warnings.warn("The following studies are already part of %s: %s" % - (self.portal, ', '.join(map(str, duplicates))), - QiitaDBWarning) - - # Add cleaned list to the portal - clean_studies = set(studies).difference(duplicates) - sql = """INSERT INTO qiita.study_portal (study_id, portal_type_id) - VALUES (%s, %s)""" - if len(clean_studies) != 0: - conn_handler.executemany( - sql, [(s, self._id) for s in clean_studies]) + with TRN: + self._check_studies(studies) + + # Clean list of studies down to ones not associated + # with portal already + sql = """SELECT study_id + FROM qiita.study_portal + WHERE portal_type_id = %s AND study_id IN %s""" + TRN.add(sql, [self._id, tuple(studies)]) + duplicates = TRN.execute_fetchflatten() + + if len(duplicates) > 0: + warnings.warn( + "The following studies are already part of %s: %s" + % (self.portal, ', '.join(map(str, duplicates))), + QiitaDBWarning) + + # Add cleaned list to the portal + clean_studies = set(studies).difference(duplicates) + sql = """INSERT INTO qiita.study_portal (study_id, portal_type_id) + VALUES (%s, %s)""" + if len(clean_studies) != 0: + TRN.add(sql, [[s, self._id] for s in clean_studies], many=True) + TRN.execute() def remove_studies(self, studies): """Removes studies from given portal @@ -271,36 +286,41 @@ def remove_studies(self, studies): """ if self.portal == "QIITA": raise ValueError('Can not remove from main QIITA portal!') - self._check_studies(studies) - - conn_handler = SQLConnectionHandler() - # Make sure study not used in analysis in portal - sql = """SELECT DISTINCT study_id FROM qiita.study_processed_data - JOIN qiita.analysis_sample USING (processed_data_id) - JOIN qiita.analysis_portal USING (analysis_id) - WHERE portal_type_id = %s AND study_id IN %s""" - analysed = [x[0] for x in conn_handler.execute_fetchall( - sql, [self.id, tuple(studies)])] - if analysed: - raise QiitaDBError("The following studies are used in an analysis " - "on portal %s and can't be removed: %s" % - (self.portal, ", ".join(map(str, analysed)))) - - # Clean list of studies down to ones associated with portal already - sql = """SELECT study_id from qiita.study_portal - WHERE portal_type_id = %s AND study_id IN %s""" - clean_studies = [x[0] for x in conn_handler.execute_fetchall( - 
sql, [self._id, tuple(studies)])] - - if len(clean_studies) != len(studies): - rem = map(str, set(studies).difference(clean_studies)) - warnings.warn("The following studies are not part of %s: %s" % - (self.portal, ', '.join(rem)), QiitaDBWarning) - - sql = """DELETE FROM qiita.study_portal - WHERE study_id IN %s AND portal_type_id = %s""" - if len(clean_studies) != 0: - conn_handler.execute(sql, [tuple(studies), self._id]) + + with TRN: + self._check_studies(studies) + + # Make sure study not used in analysis in portal + sql = """SELECT DISTINCT study_id + FROM qiita.study_processed_data + JOIN qiita.analysis_sample USING (processed_data_id) + JOIN qiita.analysis_portal USING (analysis_id) + WHERE portal_type_id = %s AND study_id IN %s""" + TRN.add(sql, [self.id, tuple(studies)]) + analysed = TRN.execute_fetchflatten() + if analysed: + raise QiitaDBError( + "The following studies are used in an analysis on portal " + "%s and can't be removed: %s" + % (self.portal, ", ".join(map(str, analysed)))) + + # Clean list of studies down to ones associated with portal already + sql = """SELECT study_id + FROM qiita.study_portal + WHERE portal_type_id = %s AND study_id IN %s""" + TRN.add(sql, [self._id, tuple(studies)]) + clean_studies = TRN.execute_fetchflatten() + + if len(clean_studies) != len(studies): + rem = map(str, set(studies).difference(clean_studies)) + warnings.warn("The following studies are not part of %s: %s" % + (self.portal, ', '.join(rem)), QiitaDBWarning) + + sql = """DELETE FROM qiita.study_portal + WHERE study_id IN %s AND portal_type_id = %s""" + if len(clean_studies) != 0: + TRN.add(sql, [tuple(studies), self._id]) + TRN.execute() def get_analyses(self): """Returns analysis id for all Analyses belonging to a portal @@ -310,33 +330,37 @@ def get_analyses(self): set of int All analysis ids in the database that match the given portal """ - conn_handler = SQLConnectionHandler() - sql = """SELECT analysis_id FROM qiita.analysis_portal - WHERE portal_type_id = %s""" - return {x[0] for x in - conn_handler.execute_fetchall(sql, [self._id])} + with TRN: + sql = """SELECT analysis_id + FROM qiita.analysis_portal + WHERE portal_type_id = %s""" + TRN.add(sql, [self._id]) + return set(TRN.execute_fetchflatten()) def _check_analyses(self, analyses): - conn_handler = SQLConnectionHandler() - # Check if any analysis IDs given do not exist. - sql = "SELECT analysis_id from qiita.analysis WHERE analysis_id IN %s" - existing = [x[0] for x in conn_handler.execute_fetchall( - sql, [tuple(analyses)])] - if len(existing) != len(analyses): - bad = map(str, set(analyses).difference(existing)) - raise QiitaDBError("The following analyses do not exist: %s" % - ", ".join(bad)) - - # Check if any analyses given are default - sql = ("SELECT analysis_id from qiita.analysis WHERE analysis_id IN %s" - " AND dflt = True") - default = [x[0] for x in conn_handler.execute_fetchall( - sql, [tuple(analyses)])] - if len(default) > 0: - bad = map(str, set(analyses).difference(default)) - raise QiitaDBError( - "The following analyses are default and can't be deleted or " - "assigned to another portal: %s" % ", ".join(bad)) + with TRN: + # Check if any analysis IDs given do not exist. 
+            sql = """SELECT analysis_id
+                     FROM qiita.analysis
+                     WHERE analysis_id IN %s"""
+            TRN.add(sql, [tuple(analyses)])
+            existing = TRN.execute_fetchflatten()
+            if len(existing) != len(analyses):
+                bad = map(str, set(analyses).difference(existing))
+                raise QiitaDBError("The following analyses do not exist: %s"
+                                   % ", ".join(bad))
+
+            # Check if any analyses given are default
+            sql = """SELECT analysis_id
+                     FROM qiita.analysis
+                     WHERE analysis_id IN %s AND dflt = True"""
+            TRN.add(sql, [tuple(analyses)])
+            default = TRN.execute_fetchflatten()
+            if len(default) > 0:
+                bad = map(str, set(analyses).difference(default))
+                raise QiitaDBError(
+                    "The following analyses are default and can't be deleted "
+                    "or assigned to another portal: %s" % ", ".join(bad))
 
     def add_analyses(self, analyses):
         """Adds analyses to given portal
@@ -355,45 +379,52 @@
         QiitaDBWarning
             Some analyses already exist in the given portal
         """
-        self._check_analyses(analyses)
-
-        conn_handler = SQLConnectionHandler()
-        if self.portal != "QIITA":
-            # Make sure new portal has access to all studies in analysis
-            sql = """SELECT DISTINCT analysis_id from qiita.analysis_sample
-                JOIN qiita.study_processed_data USING (processed_data_id)
-                WHERE study_id NOT IN (
-                    SELECT study_id from qiita.study_portal
-                    WHERE portal_type_id = %s)
-                AND analysis_id IN %s ORDER BY analysis_id"""
-            missing_info = [x[0] for x in conn_handler.execute_fetchall(
-                sql, [self._id, tuple(analyses)])]
-            if missing_info:
-                raise QiitaDBError("Portal %s is mising studies used in the "
-                                   "following analyses: %s" %
-                                   (self.portal,
-                                    ", ".join(map(str, missing_info))))
-
-        # Clean list of analyses to ones not already associated with portal
-        sql = """SELECT analysis_id from qiita.analysis_portal
-            JOIN qiita.analysis USING (analysis_id)
-            WHERE portal_type_id = %s AND analysis_id IN %s
-            AND dflt != TRUE"""
-        duplicates = [x[0] for x in conn_handler.execute_fetchall(
-            sql, [self._id, tuple(analyses)])]
-
-        if len(duplicates) > 0:
-            warnings.warn("The following analyses are already part of %s: %s" %
-                          (self.portal, ', '.join(map(str, duplicates))),
-                          QiitaDBWarning)
-
-        sql = """INSERT INTO qiita.analysis_portal
-            (analysis_id, portal_type_id)
-            VALUES (%s, %s)"""
-        clean_analyses = set(analyses).difference(duplicates)
-        if len(clean_analyses) != 0:
-            conn_handler.executemany(
-                sql, [(a, self._id) for a in clean_analyses])
+        with TRN:
+            self._check_analyses(analyses)
+
+            if self.portal != "QIITA":
+                # Make sure new portal has access to all studies in analysis
+                sql = """SELECT DISTINCT analysis_id
+                         FROM qiita.analysis_sample
+                         JOIN qiita.study_processed_data
+                             USING (processed_data_id)
+                         WHERE study_id NOT IN (
+                                SELECT study_id
+                                FROM qiita.study_portal
+                                WHERE portal_type_id = %s)
+                            AND analysis_id IN %s
+                         ORDER BY analysis_id"""
+                TRN.add(sql, [self._id, tuple(analyses)])
+                missing_info = TRN.execute_fetchflatten()
+                if missing_info:
+                    raise QiitaDBError(
+                        "Portal %s is missing studies used in the following "
+                        "analyses: %s"
+                        % (self.portal, ", ".join(map(str, missing_info))))
+
+            # Clean list of analyses to ones not already associated with portal
+            sql = """SELECT analysis_id
+                     FROM qiita.analysis_portal
+                     JOIN qiita.analysis USING (analysis_id)
+                     WHERE portal_type_id = %s AND analysis_id IN %s
+                        AND dflt != TRUE"""
+            TRN.add(sql, [self._id, tuple(analyses)])
+            duplicates = TRN.execute_fetchflatten()
+
+            if len(duplicates) > 0:
+                warnings.warn(
+                    "The following analyses are already part of %s: %s"
+                    % (self.portal, ', 
'.join(map(str, duplicates))), + QiitaDBWarning) + + sql = """INSERT INTO qiita.analysis_portal + (analysis_id, portal_type_id) + VALUES (%s, %s)""" + clean_analyses = set(analyses).difference(duplicates) + if len(clean_analyses) != 0: + TRN.add(sql, [[a, self._id] for a in clean_analyses], + many=True) + TRN.execute() def remove_analyses(self, analyses): """Removes analyses from given portal @@ -410,25 +441,27 @@ def remove_analyses(self, analyses): QiitaDBWarning Some analyses already do not exist in the given portal """ - self._check_analyses(analyses) - if self.portal == "QIITA": - raise ValueError('Can not remove from main QIITA portal!') - - conn_handler = SQLConnectionHandler() - # Clean list of analyses to ones already associated with portal - sql = """SELECT analysis_id from qiita.analysis_portal - JOIN qiita.analysis USING (analysis_id) - WHERE portal_type_id = %s AND analysis_id IN %s - AND dflt != TRUE""" - clean_analyses = [x[0] for x in conn_handler.execute_fetchall( - sql, [self._id, tuple(analyses)])] - - if len(clean_analyses) != len(analyses): - rem = map(str, set(analyses).difference(clean_analyses)) - warnings.warn("The following analyses are not part of %s: %s" % - (self.portal, ', '.join(rem)), QiitaDBWarning) - - sql = """DELETE FROM qiita.analysis_portal - WHERE analysis_id IN %s AND portal_type_id = %s""" - if len(clean_analyses) != 0: - conn_handler.execute(sql, [tuple(clean_analyses), self._id]) + with TRN: + self._check_analyses(analyses) + if self.portal == "QIITA": + raise ValueError('Can not remove from main QIITA portal!') + + # Clean list of analyses to ones already associated with portal + sql = """SELECT analysis_id + FROM qiita.analysis_portal + JOIN qiita.analysis USING (analysis_id) + WHERE portal_type_id = %s AND analysis_id IN %s + AND dflt != TRUE""" + TRN.add(sql, [self._id, tuple(analyses)]) + clean_analyses = TRN.execute_fetchflatten() + + if len(clean_analyses) != len(analyses): + rem = map(str, set(analyses).difference(clean_analyses)) + warnings.warn("The following analyses are not part of %s: %s" + % (self.portal, ', '.join(rem)), QiitaDBWarning) + + sql = """DELETE FROM qiita.analysis_portal + WHERE analysis_id IN %s AND portal_type_id = %s""" + if len(clean_analyses) != 0: + TRN.add(sql, [tuple(clean_analyses), self._id]) + TRN.execute() diff --git a/qiita_db/reference.py b/qiita_db/reference.py index 5e96bcda0..5eb3f6feb 100644 --- a/qiita_db/reference.py +++ b/qiita_db/reference.py @@ -12,7 +12,7 @@ from .exceptions import QiitaDBDuplicateError from .util import (insert_filepaths, convert_to_id, get_mountpoint) -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN class Reference(QiitaObject): @@ -62,41 +62,42 @@ def create(cls, name, version, seqs_fp, tax_fp=None, tree_fp=None): If the reference database with name `name` and version `version` already exists on the system """ - if cls.exists(name, version): - raise QiitaDBDuplicateError("Reference", - "Name: %s, Version: %s" - % (name, version)) - - conn_handler = SQLConnectionHandler() - - seq_id = insert_filepaths([(seqs_fp, convert_to_id("reference_seqs", - "filepath_type"))], - "%s_%s" % (name, version), "reference", - "filepath", conn_handler)[0] - - # Check if the database has taxonomy file - tax_id = None - if tax_fp: - fps = [(tax_fp, convert_to_id("reference_tax", "filepath_type"))] - tax_id = insert_filepaths(fps, "%s_%s" % (name, version), - "reference", "filepath", conn_handler)[0] - - # Check if the database has tree file - tree_id = None - if 
tree_fp: - fps = [(tree_fp, convert_to_id("reference_tree", "filepath_type"))] - tree_id = insert_filepaths(fps, "%s_%s" % (name, version), - "reference", "filepath", - conn_handler)[0] - - # Insert the actual object to the db - ref_id = conn_handler.execute_fetchone( - "INSERT INTO qiita.{0} (reference_name, reference_version, " - "sequence_filepath, taxonomy_filepath, tree_filepath) VALUES " - "(%s, %s, %s, %s, %s) RETURNING reference_id".format(cls._table), - (name, version, seq_id, tax_id, tree_id))[0] - - return cls(ref_id) + with TRN: + if cls.exists(name, version): + raise QiitaDBDuplicateError("Reference", + "Name: %s, Version: %s" + % (name, version)) + + seq_id = insert_filepaths( + [(seqs_fp, convert_to_id("reference_seqs", "filepath_type"))], + "%s_%s" % (name, version), "reference", "filepath")[0] + + # Check if the database has taxonomy file + tax_id = None + if tax_fp: + fps = [(tax_fp, + convert_to_id("reference_tax", "filepath_type"))] + tax_id = insert_filepaths(fps, "%s_%s" % (name, version), + "reference", "filepath")[0] + + # Check if the database has tree file + tree_id = None + if tree_fp: + fps = [(tree_fp, + convert_to_id("reference_tree", "filepath_type"))] + tree_id = insert_filepaths(fps, "%s_%s" % (name, version), + "reference", "filepath")[0] + + # Insert the actual object to the db + sql = """INSERT INTO qiita.{0} + (reference_name, reference_version, sequence_filepath, + taxonomy_filepath, tree_filepath) + VALUES (%s, %s, %s, %s, %s) + RETURNING reference_id""".format(cls._table) + TRN.add(sql, [name, version, seq_id, tax_id, tree_id]) + id_ = TRN.execute_fetchlast() + + return cls(id_) @classmethod def exists(cls, name, version): @@ -114,54 +115,62 @@ def exists(cls, name, version): QiitaDBNotImplementedError If the method is not overwritten by a subclass """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT EXISTS(SELECT * FROM qiita.{0} WHERE " - "reference_name=%s AND reference_version=%s)".format(cls._table), - (name, version))[0] + with TRN: + sql = """SELECT EXISTS( + SELECT * FROM qiita.{0} + WHERE reference_name=%s + AND reference_version=%s)""".format(cls._table) + TRN.add(sql, [name, version]) + return TRN.execute_fetchlast() @property def name(self): - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT reference_name FROM qiita.{0} WHERE " - "reference_id = %s".format(self._table), (self._id,))[0] - _, basefp = get_mountpoint('reference')[0] + with TRN: + sql = """SELECT reference_name FROM qiita.{0} + WHERE reference_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def version(self): - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - "SELECT reference_version FROM qiita.{0} WHERE " - "reference_id = %s".format(self._table), (self._id,))[0] - _, basefp = get_mountpoint('reference')[0] + with TRN: + sql = """SELECT reference_version FROM qiita.{0} + WHERE reference_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def sequence_fp(self): - conn_handler = SQLConnectionHandler() - rel_path = conn_handler.execute_fetchone( - "SELECT f.filepath FROM qiita.filepath f JOIN qiita.{0} r ON " - "r.sequence_filepath=f.filepath_id WHERE " - "r.reference_id=%s".format(self._table), (self._id,))[0] - _, basefp = get_mountpoint('reference')[0] - return join(basefp, rel_path) + with TRN: + sql = """SELECT f.filepath + FROM qiita.filepath f + JOIN qiita.{0} 
r ON r.sequence_filepath=f.filepath_id + WHERE r.reference_id=%s""".format(self._table) + TRN.add(sql, [self._id]) + rel_path = TRN.execute_fetchlast() + _, basefp = get_mountpoint('reference')[0] + return join(basefp, rel_path) @property def taxonomy_fp(self): - conn_handler = SQLConnectionHandler() - rel_path = conn_handler.execute_fetchone( - "SELECT f.filepath FROM qiita.filepath f JOIN qiita.{0} r ON " - "r.taxonomy_filepath=f.filepath_id WHERE " - "r.reference_id=%s".format(self._table), (self._id,))[0] - _, basefp = get_mountpoint('reference')[0] - return join(basefp, rel_path) + with TRN: + sql = """SELECT f.filepath + FROM qiita.filepath f + JOIN qiita.{0} r ON r.taxonomy_filepath=f.filepath_id + WHERE r.reference_id=%s""".format(self._table) + TRN.add(sql, [self._id]) + rel_path = TRN.execute_fetchlast() + _, basefp = get_mountpoint('reference')[0] + return join(basefp, rel_path) @property def tree_fp(self): - conn_handler = SQLConnectionHandler() - rel_path = conn_handler.execute_fetchone( - "SELECT f.filepath FROM qiita.filepath f JOIN qiita.{0} r ON " - "r.tree_filepath=f.filepath_id WHERE " - "r.reference_id=%s".format(self._table), (self._id,))[0] - _, basefp = get_mountpoint('reference')[0] - return join(basefp, rel_path) + with TRN: + sql = """SELECT f.filepath + FROM qiita.filepath f + JOIN qiita.{0} r ON r.tree_filepath=f.filepath_id + WHERE r.reference_id=%s""".format(self._table) + TRN.add(sql, [self._id]) + rel_path = TRN.execute_fetchlast() + _, basefp = get_mountpoint('reference')[0] + return join(basefp, rel_path) diff --git a/qiita_db/search.py b/qiita_db/search.py index 6cc87ec35..6e7488cb0 100644 --- a/qiita_db/search.py +++ b/qiita_db/search.py @@ -71,8 +71,8 @@ from future.utils import viewitems from qiita_db.util import scrub_data, convert_type, get_table_cols +from qiita_db.sql_connection import TRN from qiita_core.qiita_settings import qiita_config -from qiita_db.sql_connection import SQLConnectionHandler from qiita_db.study import Study from qiita_db.data import ProcessedData from qiita_db.exceptions import QiitaDBIncompatibleDatatypeError @@ -193,27 +193,31 @@ def __call__(self, searchstr, user): Metadata column names and string searches are case-sensitive """ - study_sql, sample_sql, meta_headers = \ - self._parse_study_search_string(searchstr, True) - conn_handler = SQLConnectionHandler() - # get all studies containing the metadata headers requested - study_ids = {x[0] for x in conn_handler.execute_fetchall(study_sql)} - # strip to only studies user has access to - if user.level not in {'admin', 'dev', 'superuser'}: - study_ids = study_ids.intersection(Study.get_by_status('public') | - user.user_studies | - user.shared_studies) - - results = {} - # run search on each study to get out the matching samples - for sid in study_ids: - study_res = conn_handler.execute_fetchall(sample_sql.format(sid)) - if study_res: - # only add study to results if actually has samples in results - results[sid] = study_res - self.results = results - self.meta_headers = meta_headers - return results, meta_headers + with TRN: + study_sql, sample_sql, meta_headers = \ + self._parse_study_search_string(searchstr, True) + + # get all studies containing the metadata headers requested + TRN.add(study_sql) + study_ids = set(TRN.execute_fetchflatten()) + # strip to only studies user has access to + if user.level not in {'admin', 'dev', 'superuser'}: + study_ids = study_ids.intersection( + Study.get_by_status('public') | user.user_studies | + user.shared_studies) + + results = {} + # 
run search on each study to get out the matching samples + for sid in study_ids: + TRN.add(sample_sql.format(sid)) + study_res = TRN.execute_fetchindex() + if study_res: + # only add study to results if actually has samples + # in results + results[sid] = study_res + self.results = results + self.meta_headers = meta_headers + return results, meta_headers def _parse_study_search_string(self, searchstr, only_with_processed_data=False): @@ -376,32 +380,34 @@ def filter_by_processed_data(self, datatypes=None): sample_id, column headers are the metadata categories searched over """ - if datatypes is not None: - # convert to set for easy lookups - datatypes = set(datatypes) - study_proc_ids = {} - proc_data_samples = {} - samples_meta = {} - headers = {c: val for c, val in enumerate(self.meta_headers)} - for study_id, study_meta in viewitems(self.results): - # add metadata to dataframe and dict - # use from_dict because pandas doesn't like cursor objects - samples_meta[study_id] = pd.DataFrame.from_dict( - {s[0]: s[1:] for s in study_meta}, orient='index') - samples_meta[study_id].rename(columns=headers, inplace=True) - # set up study-based data needed - study = Study(study_id) - study_sample_ids = {s[0] for s in study_meta} - study_proc_ids[study_id] = defaultdict(list) - for proc_data_id in study.processed_data(): - proc_data = ProcessedData(proc_data_id) - datatype = proc_data.data_type() - # skip processed data if it doesn't fit the given datatypes - if datatypes is not None and datatype not in datatypes: - continue - filter_samps = proc_data.samples.intersection(study_sample_ids) - if filter_samps: - proc_data_samples[proc_data_id] = sorted(filter_samps) - study_proc_ids[study_id][datatype].append(proc_data_id) - - return study_proc_ids, proc_data_samples, samples_meta + with TRN: + if datatypes is not None: + # convert to set for easy lookups + datatypes = set(datatypes) + study_proc_ids = {} + proc_data_samples = {} + samples_meta = {} + headers = {c: val for c, val in enumerate(self.meta_headers)} + for study_id, study_meta in viewitems(self.results): + # add metadata to dataframe and dict + # use from_dict because pandas doesn't like cursor objects + samples_meta[study_id] = pd.DataFrame.from_dict( + {s[0]: s[1:] for s in study_meta}, orient='index') + samples_meta[study_id].rename(columns=headers, inplace=True) + # set up study-based data needed + study = Study(study_id) + study_sample_ids = {s[0] for s in study_meta} + study_proc_ids[study_id] = defaultdict(list) + for proc_data_id in study.processed_data(): + proc_data = ProcessedData(proc_data_id) + datatype = proc_data.data_type() + # skip processed data if it doesn't fit the given datatypes + if datatypes is not None and datatype not in datatypes: + continue + filter_samps = proc_data.samples.intersection( + study_sample_ids) + if filter_samps: + proc_data_samples[proc_data_id] = sorted(filter_samps) + study_proc_ids[study_id][datatype].append(proc_data_id) + + return study_proc_ids, proc_data_samples, samples_meta diff --git a/qiita_db/sql_connection.py b/qiita_db/sql_connection.py index 6efdc90f6..d3b9ef8dc 100644 --- a/qiita_db/sql_connection.py +++ b/qiita_db/sql_connection.py @@ -7,6 +7,10 @@ This modules provides wrappers for the psycopg2 module to allow easy use of transaction blocks and SQL execution/data retrieval. +This module provides the variable TRN, which is the transaction available +to use in the system. The singleton pattern is applied and this works as long +as the system remains single-threaded. 
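
For readers new to this API: the usage pattern that the rest of this patch applies everywhere is `with TRN:` plus `TRN.add(...)`. A minimal sketch of that pattern, assuming a configured Qiita install; the query is borrowed from the portal code above, and the function name is illustrative only:

    from qiita_db.sql_connection import TRN

    def all_user_emails():
        # Entering the context opens (or joins) the singleton transaction;
        # TRN.add queues a query instead of running it immediately.
        with TRN:
            TRN.add("SELECT email FROM qiita.qiita_user")
            # execute_fetchflatten is used throughout this patch where the
            # old code did [x[0] for x in execute_fetchall(...)]: it runs
            # the queued queries and flattens the last result into a list.
            return TRN.execute_fetchflatten()

On a clean exit from the outermost context the transaction commits; an unhandled exception rolls it back, as the Transaction class below implements.
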
+ Classes ------- @@ -14,6 +18,7 @@ :toctree: generated/ SQLConnectionHandler + Transaction Examples -------- @@ -79,8 +84,7 @@ from __future__ import division from contextlib import contextmanager from itertools import chain -from functools import partial -from tempfile import mktemp +from functools import partial, wraps from datetime import date, time, datetime import re @@ -88,7 +92,8 @@ OperationalError) from psycopg2.extras import DictCursor from psycopg2.extensions import ( - ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_READ_COMMITTED) + ISOLATION_LEVEL_AUTOCOMMIT, ISOLATION_LEVEL_READ_COMMITTED, + TRANSACTION_STATUS_IDLE) from qiita_core.qiita_settings import qiita_config @@ -457,176 +462,516 @@ def execute_fetchall(self, sql, sql_args=None): return result - def _check_queue_exists(self, queue_name): - """Checks if queue `queue_name` exists in the handler - Parameters - ---------- - queue_name : str - The name of the queue +def _checker(func): + """Decorator to check that methods are executed inside the context""" + @wraps(func) + def wrapper(self, *args, **kwargs): + if self._contexts_entered == 0: + raise RuntimeError( + "Operation not permitted. Transaction methods can only be " + "invoked within the context manager.") + return func(self, *args, **kwargs) + return wrapper + + +class Transaction(object): + """A context manager that encapsulates a DB transaction + + A transaction is defined by a series of consecutive queries that need to + be applied to the database as a single block. + + Raises + ------ + RuntimeError + If the transaction methods are invoked outside a context. + + Notes + ----- + When the execution leaves the context manager, any remaining queries in + the transaction will be executed and committed. + """ + + _regex = re.compile("^{(\d+):(\d+):(\d+)}$") + + def __init__(self): + self._queries = [] + self._results = [] + self._contexts_entered = 0 + self._connection = None + self._post_commit_funcs = [] + self._post_rollback_funcs = [] + + def _open_connection(self): + # If the connection already exists and is not closed, don't do anything + if self._connection is not None and self._connection.closed == 0: + return + + try: + self._connection = connect(user=qiita_config.user, + password=qiita_config.password, + database=qiita_config.database, + host=qiita_config.host, + port=qiita_config.port) + except OperationalError as e: + # catch three known common exceptions and raise runtime errors + try: + etype = e.message.split(':')[1].split()[0] + except IndexError: + # we received a really unanticipated error without a colon + etype = '' + if etype == 'database': + etext = ('This is likely because the database `%s` has not ' + 'been created or has been dropped.' % + qiita_config.database) + elif etype == 'role': + etext = ('This is likely because the user string `%s` ' + 'supplied in your configuration file `%s` is ' + 'incorrect or not an authorized postgres user.' % + (qiita_config.user, qiita_config.conf_fp)) + elif etype == 'Connection': + etext = ('This is likely because postgres isn\'t ' + 'running. 
Check that postgres is correctly ' + 'installed and is running.') + else: + # we received a really unanticipated error with a colon + etext = '' + ebase = ('An OperationalError with the following message occurred' + '\n\n\t%s\n%s For more information, review `INSTALL.md`' + ' in the Qiita installation base directory.') + raise RuntimeError(ebase % (e.message, etext)) + + def close(self): + if self._connection is not None: + self._connection.close() + + @contextmanager + def _get_cursor(self): + """Returns a postgres cursor Returns ------- - bool - True if queue `queue_name` exist in the handler. False otherwise. + psycopg2.cursor + The psycopg2 cursor + + Raises + ------ + RuntimeError + if the cursor cannot be created """ - return queue_name in self.queues + self._open_connection() - def create_queue(self, queue_name): - """Add a new queue to the connection + try: + with self._connection.cursor(cursor_factory=DictCursor) as cur: + yield cur + except PostgresError as e: + raise RuntimeError("Cannot get postgres cursor: %s" % e) - Parameters - ---------- - queue_name : str - Name of the new queue + def __enter__(self): + self._open_connection() + self._contexts_entered += 1 + return self + + def _clean_up(self, exc_type): + if exc_type is not None: + # An exception occurred during the execution of the transaction + # Make sure that we leave the DB w/o any modification + self.rollback() + elif self._queries: + # There are still queries to be executed, execute them + # It is safe to use the execute method here, as internally it is + # wrapped in a try/except and rolls back in case of failure + self.execute() + self.commit() + elif self._connection.get_transaction_status() != \ + TRANSACTION_STATUS_IDLE: + # There are no queries to be executed, however, the transaction + # is still not committed. Commit it so the changes are not lost + self.commit() + + def __exit__(self, exc_type, exc_value, traceback): + # We only need to perform some action if this is the last context + # that we are leaving + if self._contexts_entered == 1: + # We need to wrap the entire function in a try/finally because + # at the end we need to decrement _contexts_entered + try: + self._clean_up(exc_type) + finally: + self._contexts_entered -= 1 + else: + self._contexts_entered -= 1 + + def _raise_execution_error(self, sql, sql_args, error): + """Rolls back the current transaction and raises a useful error + The error message contains the failed query, the arguments of the + failed query and the error generated. Raises ------ - KeyError - Queue name already exists + ValueError """ - if self._check_queue_exists(queue_name): - raise KeyError("Queue %s already exists" % queue_name) + self.rollback() + raise ValueError( + "Error running SQL query:\n" + "Query: %s\nArguments: %s\nError: %s\n" + % (sql, str(sql_args), str(error))) - self.queues[queue_name] = []
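
To make the placeholder mechanics concrete, a sketch of chaining an
INSERT ... RETURNING into a later query through the '{query:row:value}'
placeholder handled by _replace_placeholders below (the column lists are
trimmed for brevity and would not satisfy the real schema's constraints):

    with TRN:
        s_idx = TRN.index  # index this query will get in the transaction
        TRN.add("""INSERT INTO qiita.study (email, study_title)
                   VALUES (%s, %s) RETURNING study_id""",
                ['owner@foo.bar', 'placeholder demo'])
        # {s_idx:0:0} -> row 0, value 0 of the result of query s_idx
        TRN.add("INSERT INTO qiita.study_pmid (study_id, pmid) "
                "VALUES (%s, %s)", ["{%d:0:0}" % s_idx, '123456'])
        TRN.execute()
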
+ def _replace_placeholders(self, sql, sql_args): + """Replaces the placeholder in `sql_args` with the actual value - def list_queues(self): - """Returns list of all queue names currently in handler + Parameters + ---------- + sql : str + The SQL query + sql_args : list + The arguments of the SQL query Returns ------- - list of str - names of queues in handler - """ - return self.queues.keys() + tuple of (str, list of objects) + The input SQL query (unmodified) and the SQL arguments with the + placeholder (if any) substituted with the actual value of the + previous query - def add_to_queue(self, queue, sql, sql_args=None, many=False): - """Add an sql command to the end of a queue + Raises + ------ + ValueError + If a placeholder does not match any previous result + If a placeholder points to a query that does not produce any result + """ + for pos, arg in enumerate(sql_args): + # Check if we have a placeholder + if isinstance(arg, str): + placeholder = self._regex.search(arg) + if placeholder: + # We do have a placeholder, get the indexes + # Query index + q_idx = int(placeholder.group(1)) + # Row index + r_idx = int(placeholder.group(2)) + # Value index + v_idx = int(placeholder.group(3)) + try: + sql_args[pos] = self._results[q_idx][r_idx][v_idx] + except IndexError: + # A previous query that was expected to retrieve + # some data from the DB did not return as many + # values as expected + self._raise_execution_error( + sql, sql_args, + "The placeholder {%d:%d:%d} does not match " + "any previous result" + % (q_idx, r_idx, v_idx)) + except TypeError: + # The query that the placeholder is pointing to + # is not expected to retrieve any value + # (e.g. an INSERT w/o RETURNING clause) + self._raise_execution_error( + sql, sql_args, + "The placeholder {%d:%d:%d} is referring to " + "a SQL query that does not retrieve data" + % (q_idx, r_idx, v_idx)) + + # If sql_args is an empty list, psycopg2 doesn't work correctly if we + # are passing '%' characters to the SQL query in a LIKE statement + sql_args = sql_args if sql_args else None + return sql, sql_args + + @_checker + def add(self, sql, sql_args=None, many=False): + """Add an sql query to the transaction + + If the current query needs a result of a previous query in the + transaction, a placeholder of the form '{#:#:#}' can be used. The first + number is the index of the previous SQL query in the transaction, the + second number is the row from that query result and the third number is + the index of the value within the query result row. + The placeholder will be replaced by the actual value at execution time. Parameters ---------- - queue : str - name of queue adding to sql : str - sql command to run - sql_args : list, tuple or dict, optional - the arguments to fill sql command with + The sql query + sql_args : list of objects, optional + The arguments to the sql query many : bool, optional - Whether or not this should be treated as an executemany command.
- Default False + Whether or not we should add the query multiple times to the + transaction Raises ------ - KeyError - queue does not exist - """ - if not self._check_queue_exists(queue): - raise KeyError("Queue '%s' does not exist" % queue) + TypeError + If `sql_args` is provided and is not a list + RuntimeError + If invoked outside a context + Notes + ----- + If `many` is True, `sql_args` should be a list of lists, where each + inner list holds the arguments for one of the queries. The query is + added to the transaction once per inner list, i.e. len(sql_args) + times. + """ if not many: sql_args = [sql_args] for args in sql_args: - self._check_sql_args(args) - self.queues[queue].append((sql, args)) - - def _rollback_raise_error(self, queue, sql, sql_args, e): - self._connection.rollback() - # wipe out queue since it has an error in it - del self.queues[queue] - raise ValueError( - "Error running SQL query in queue %s: %s\nARGS: %s\nError: %s" - % (queue, sql, str(sql_args), e)) - - def execute_queue(self, queue): - """Executes all sql in a queue in a single transaction block - - Parameters - ---------- - queue : str - Name of queue to execute - - Notes - ----- - Does not support executemany command. Instead, enter the multiple - SQL commands as multiple entries in the queue. - - Raises - ------ - KetError - If queue does not exist - IndexError - If a sql argument placeholder does not correspond to the result of - any previously-executed query. + if args: + if not isinstance(args, list): + raise TypeError("sql_args should be a list. Found %s" + % type(args)) + else: + args = [] + self._queries.append((sql, args))
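
Since every public method is wrapped by _checker, using the transaction
outside of a `with TRN:` block fails fast instead of silently touching the
connection; a quick sketch of the expected behavior:

    TRN.add("SELECT 42")  # raises RuntimeError: outside the context manager

    with TRN:
        TRN.add("SELECT 42")
        TRN.execute_fetchlast()  # returns 42
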
+ + def _execute(self): + """Internal function that actually executes the transaction + The `execute` function exposed in the API wraps this one to make sure + that any exception raised in here is caught and the transaction is + rolled back """ - if not self._check_queue_exists(queue): - raise KeyError("Queue '%s' does not exist" % queue) + with self._get_cursor() as cur: + for sql, sql_args in self._queries: + sql, sql_args = self._replace_placeholders(sql, sql_args) - with self.get_postgres_cursor() as cur: - results = [] - clear_res = False - for sql, sql_args in self.queues[queue]: - if sql_args is not None: - # The user can provide a tuple, make sure that it is a - # list, so we can assign the item - sql_args = list(sql_args) - for pos, arg in enumerate(sql_args): - # check if previous results needed and replace - if isinstance(arg, str): - result = self._regex.search(arg) - if result: - result_pos = int(result.group(1)) - try: - sql_args[pos] = results[result_pos] - except IndexError: - self._rollback_raise_error( - queue, sql, sql_args, - "The index provided as a placeholder " - "%d does not correspond to any " - "previous result" % result_pos) - clear_res = True - # wipe out results if needed and reset clear_res - if clear_res: - results = [] - clear_res = False - # Fire off the SQL command try: cur.execute(sql, sql_args) except Exception as e: - self._rollback_raise_error(queue, sql, sql_args, e) + # We catch any exception as we want to make sure that we + # rollback every time that something went wrong + self._raise_execution_error(sql, sql_args, e) - # fetch results if available and append to results list try: res = cur.fetchall() except ProgrammingError as e: # At this execution point, we don't know if the sql query - that we executed was a INSERT or a SELECT. If it was a - SELECT and there is nothing to fetch, it will return an - empty list. However, if it was a INSERT it will raise a - ProgrammingError, so we catch that one and pass. - pass + # that we executed should retrieve values from the database. + # If the query was not supposed to retrieve any value + # (e.g. an INSERT without a RETURNING clause), it will + # raise a ProgrammingError. Otherwise it will just return + # an empty list. + res = None except PostgresError as e: - self._rollback_raise_error(queue, sql, sql_args, e) - else: - # append all results linearly - results.extend(flatten(res)) - self._connection.commit() - # wipe out queue since finished - del self.queues[queue] - return results - - def get_temp_queue(self): - """Get a queue name that did not exist when this function was called + # Some other error happened during the execution of the + # query, so we need to rollback + self._raise_execution_error(sql, sql_args, e) + + # Store the results of the current query + self._results.append(res) + + # wipe out the already executed queries + self._queries = [] + + return self._results + + @_checker + def execute(self): + """Executes the transaction + + Returns + ------- + list of DictCursor + The results of all the SQL queries in the transaction + + Raises + ------ + RuntimeError + If invoked outside a context + + Notes + ----- + If any exception occurs during the execution of the transaction, a + rollback is executed and no changes are reflected in the database. + Calling execute never commits the transaction; it is committed + automatically when leaving the context + + See Also + -------- + execute_fetchlast + execute_fetchindex + execute_fetchflatten + """ + try: + return self._execute() + except Exception: + self.rollback() + raise + + @_checker + def execute_fetchlast(self): + """Executes the transaction and returns the last result + + This is a convenient function that is equivalent to + `self.execute()[-1][0][0]` Returns ------- - str - The name of the queue + object + The first value of the first row of the last SQL query executed + + See Also + -------- + execute + execute_fetchindex + execute_fetchflatten """ - temp_queue_name = mktemp() - while temp_queue_name in self.queues: - temp_queue_name = mktemp() + return self.execute()[-1][0][0] + + @_checker + def execute_fetchindex(self, idx=-1): + """Executes the transaction and returns the results of the `idx` query + + This is a convenient function that is equivalent to + `self.execute()[idx]` - self.create_queue(temp_queue_name) + Parameters + ---------- + idx : int, optional + The index of the query to return the result. It defaults to -1, the + last query. + + Returns + ------- + DictCursor + The results of the `idx` query in the transaction + + See Also + -------- + execute + execute_fetchlast + execute_fetchflatten + """ + return self.execute()[idx]
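
A sketch of how the three fetch helpers relate to execute (the table is one
used elsewhere in this patch; values are illustrative):

    with TRN:
        TRN.add("SELECT study_id, pmid FROM qiita.study_pmid")
        rows = TRN.execute_fetchindex()  # same as TRN.execute()[-1]
        # TRN.execute_fetchflatten() -> [sid1, pmid1, sid2, pmid2, ...]
        # TRN.execute_fetchlast()    -> rows[0][0]
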
+ + @_checker + def execute_fetchflatten(self, idx=-1): + """Executes the transaction and returns the flattened results of the + `idx` query + + This is a convenient function that is equivalent to + `chain.from_iterable(self.execute()[idx])` + + Parameters + ---------- + idx : int, optional + The index of the query to return the result. It defaults to -1, the + last query. + + Returns + ------- + list of objects + The flattened results of the `idx` query + + See Also + -------- + execute + execute_fetchlast + execute_fetchindex + """ + return list(chain.from_iterable(self.execute()[idx])) + + def _funcs_executor(self, funcs, func_str): + error_msg = [] + for f, args, kwargs in funcs: + try: + f(*args, **kwargs) + except Exception as e: + error_msg.append(str(e)) + # The two function lists are mutually exclusive. Once one of them + # has been executed, we can reset both of them. + self._post_commit_funcs = [] + self._post_rollback_funcs = [] + if error_msg: + raise RuntimeError( + "An error occurred during the post %s commands:\n%s" + % (func_str, "\n".join(error_msg))) + + @_checker + def commit(self): + """Commits the transaction and resets the queries + + Raises + ------ + RuntimeError + If invoked outside a context + """ + # Reset the queries and the results (this also resets the index) + self._queries = [] + self._results = [] + try: + self._connection.commit() + except Exception: + self._connection.close() + raise + # Execute the post commit functions + self._funcs_executor(self._post_commit_funcs, "commit") + + @_checker + def rollback(self): + """Rolls back the transaction and resets the queries + + Raises + ------ + RuntimeError + If invoked outside a context + """ + # Reset the queries and the results (this also resets the index) + self._queries = [] + self._results = [] + try: + self._connection.rollback() + except Exception: + self._connection.close() + raise + # Execute the post rollback functions + self._funcs_executor(self._post_rollback_funcs, "rollback") + + @property + def index(self): + return len(self._queries) + len(self._results) + + @_checker + def add_post_commit_func(self, func, *args, **kwargs): + """Adds a post commit function + + The function added will be executed after the next commit in the + transaction, unless a rollback is executed. This is useful, for + example, to perform some filesystem clean up once the transaction is + committed. + + Parameters + ---------- + func : function + The function to add to the post commit functions + args : tuple + The arguments of the function + kwargs : dict + The keyword arguments of the function + """ + self._post_commit_funcs.append((func, args, kwargs))
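
A sketch of the intended use of the two hooks, tying filesystem cleanup to
the fate of the transaction (fp_id and fp are hypothetical values):

    from os import remove

    with TRN:
        TRN.add("DELETE FROM qiita.filepath WHERE filepath_id = %s", [fp_id])
        # unlink the file only if the DELETE actually commits
        TRN.add_post_commit_func(remove, fp)
        TRN.execute()
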
+ + @_checker + def add_post_rollback_func(self, func, *args, **kwargs): + """Adds a post rollback function + + The function added will be executed after the next rollback in the + transaction, unless a commit is executed. This is useful, for example, + to restore the filesystem in case a rollback occurs, avoiding leaving + the database and the filesystem in an out of sync state. + + Parameters + ---------- + func : function + The function to add to the post rollback functions + args : tuple + The arguments of the function + kwargs : dict + The keyword arguments of the function + """ + self._post_rollback_funcs.append((func, args, kwargs)) - return temp_queue_name +# Singleton pattern, create the transaction for the entire system +TRN = Transaction() diff --git a/qiita_db/study.py b/qiita_db/study.py index fb03c4723..f0c03ac90 100644 --- a/qiita_db/study.py +++ b/qiita_db/study.py @@ -102,10 +102,11 @@ from qiita_core.exceptions import IncompetentQiitaDeveloperError from qiita_core.qiita_settings import qiita_config from .base import QiitaObject -from .exceptions import (QiitaDBStatusError, QiitaDBColumnError, QiitaDBError) +from .exceptions import (QiitaDBStatusError, QiitaDBColumnError, QiitaDBError, + QiitaDBDuplicateError) from .util import (check_required_columns, check_table_cols, convert_to_id, get_environmental_packages, get_table_cols, infer_status) -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN from .util import exists_table @@ -151,7 +152,7 @@ class Study(QiitaObject): get_table_cols('study'), get_table_cols('study_status'), get_table_cols('timeseries_type'), get_table_cols('study_pmid'))) - def _lock_non_sandbox(self, conn_handler): + def _lock_non_sandbox(self): """Raises QiitaDBStatusError if study is non-sandboxed""" if self.status != 'sandbox': raise QiitaDBStatusError("Illegal operation on non-sandbox study!") @@ -159,18 +160,17 @@ def _lock_non_sandbox(self, conn_handler): @property def status(self): r"""The status is inferred by the status of its processed data""" - conn_handler = SQLConnectionHandler() - # Get the status of all its processed data - sql = """SELECT processed_data_status - FROM qiita.processed_data_status pds - JOIN qiita.processed_data pd - USING (processed_data_status_id) - JOIN qiita.study_processed_data spd - USING (processed_data_id) - WHERE spd.study_id = %s""" - pd_statuses = conn_handler.execute_fetchall(sql, (self._id,)) - - return infer_status(pd_statuses) + with TRN: + # Get the status of all its processed data + sql = """SELECT processed_data_status + FROM qiita.processed_data_status pds + JOIN qiita.processed_data pd + USING (processed_data_status_id) + JOIN qiita.study_processed_data spd + USING (processed_data_id) + WHERE spd.study_id = %s""" + TRN.add(sql, [self._id]) + return infer_status(TRN.execute_fetchindex()) @classmethod def get_by_status(cls, status): @@ -186,30 +186,32 @@ def get_by_status(cls, status): set of int All study ids in the database that match the given status """ - conn_handler = SQLConnectionHandler() - sql = """SELECT study_id FROM qiita.study_processed_data - JOIN qiita.processed_data using (processed_data_id) - JOIN qiita.processed_data_status - USING (processed_data_status_id) - JOIN qiita.study_portal USING (study_id) - JOIN qiita.portal_type USING (portal_type_id) - WHERE processed_data_status=%s AND portal = %s""" - studies = {x[0] for x in - conn_handler.execute_fetchall( - sql, (status, qiita_config.portal))} - # If status is sandbox, all the studies that are not present in the - # study_processed_data are also sandbox - if status == 'sandbox': - sql = """SELECT study_id FROM qiita.study - JOIN qiita.study_portal USING (study_id) - JOIN qiita.portal_type USING (portal_type_id) - WHERE portal = %s AND study_id NOT IN ( - SELECT study_id FROM qiita.study_processed_data)""" - extra_studies = {x[0] for x in conn_handler.execute_fetchall( - sql,
[qiita_config.portal])} - studies = studies.union(extra_studies) - - return studies + with TRN: + sql = """SELECT study_id + FROM qiita.study_processed_data + JOIN qiita.processed_data USING (processed_data_id) + JOIN qiita.processed_data_status + USING (processed_data_status_id) + JOIN qiita.study_portal USING (study_id) + JOIN qiita.portal_type USING (portal_type_id) + WHERE processed_data_status=%s AND portal = %s""" + TRN.add(sql, [status, qiita_config.portal]) + studies = set(TRN.execute_fetchflatten()) + # If status is sandbox, all the studies that are not present in the + # study_processed_data are also sandbox + if status == 'sandbox': + sql = """SELECT study_id + FROM qiita.study + JOIN qiita.study_portal USING (study_id) + JOIN qiita.portal_type USING (portal_type_id) + WHERE portal = %s + AND study_id NOT IN ( + SELECT study_id + FROM qiita.study_processed_data)""" + TRN.add(sql, [qiita_config.portal]) + studies = studies.union(TRN.execute_fetchflatten()) + + return studies @classmethod def get_info(cls, study_ids=None, info_cols=None): @@ -236,26 +238,30 @@ def get_info(cls, study_ids=None, info_cols=None): search_cols = ",".join(sorted(cls._info_cols.intersection(info_cols))) - sql = """SELECT {0} FROM ( - qiita.study - JOIN qiita.timeseries_type USING (timeseries_type_id) - LEFT JOIN (SELECT study_id, array_agg(pmid ORDER BY pmid) as - pmid FROM qiita.study_pmid GROUP BY study_id) sp USING (study_id) - JOIN qiita.study_portal USING (study_id) - JOIN qiita.portal_type USING (portal_type_id) - ) WHERE portal = '{1}'""".format(search_cols, qiita_config.portal) - - if study_ids is not None: - sql = "{0} AND study_id in ({1})".format( - sql, ','.join(str(s) for s in study_ids)) - - conn_handler = SQLConnectionHandler() - res = conn_handler.execute_fetchall(sql) - if study_ids is not None and (res is None or - len(res) != len(study_ids)): + with TRN: + sql = """SELECT {0} + FROM ( + qiita.study + JOIN qiita.timeseries_type USING (timeseries_type_id) + LEFT JOIN ( + SELECT study_id, array_agg(pmid ORDER BY pmid) AS + pmid + FROM qiita.study_pmid + GROUP BY study_id) sp USING (study_id) + JOIN qiita.study_portal USING (study_id) + JOIN qiita.portal_type USING (portal_type_id)) + WHERE portal = %s""".format(search_cols) + + args = [qiita_config.portal] + if study_ids is not None: + sql = "{0} AND study_id IN %s".format(sql) + args.append(tuple(study_ids)) + + TRN.add(sql, args) + res = TRN.execute_fetchindex() + if study_ids is not None and len(res) != len(study_ids): raise QiitaDBError('Non-portal-accessible studies asked for!') - - return res if res is not None else [] + return res @classmethod def exists(cls, study_title): @@ -270,11 +276,13 @@ def exists(cls, study_title): ------- bool """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT exists(select study_id from qiita.{} WHERE " - "study_title = %s)").format(cls._table) - - return conn_handler.execute_fetchone(sql, [study_title])[0] + with TRN: + sql = """SELECT EXISTS( + SELECT study_id + FROM qiita.{} + WHERE study_title = %s)""".format(cls._table) + TRN.add(sql, [study_title]) + return TRN.execute_fetchlast() @classmethod def create(cls, owner, title, efo, info, investigation=None): @@ -302,6 +310,8 @@ def create(cls, owner, title, efo, info, investigation=None): IncompetentQiitaDeveloperError email, study_id, study_status_id, or study_title passed as a key empty efo list passed + QiitaDBDuplicateError + If a study with the given title already exists Notes ----- @@ -317,59 +327,72 @@ def create(cls, owner, title, 
efo, info, investigation=None): if not efo: raise IncompetentQiitaDeveloperError("Need EFO information!") - # add default values to info - insertdict = deepcopy(info) - insertdict['email'] = owner.id - insertdict['study_title'] = title - if "reprocess" not in insertdict: - insertdict['reprocess'] = False - - # No nuns allowed - insertdict = {k: v for k, v in viewitems(insertdict) if v is not None} - - conn_handler = SQLConnectionHandler() - # make sure dictionary only has keys for available columns in db - check_table_cols(conn_handler, insertdict, cls._table) - # make sure reqired columns in dictionary - check_required_columns(conn_handler, insertdict, cls._table) - - # Insert study into database - sql = ("INSERT INTO qiita.{0} ({1}) VALUES ({2}) RETURNING " - "study_id".format(cls._table, ','.join(insertdict), - ','.join(['%s'] * len(insertdict)))) - # make sure data in same order as sql column names, and ids are used - data = [] - for col in insertdict: - if isinstance(insertdict[col], QiitaObject): - data.append(insertdict[col].id) - else: - data.append(insertdict[col]) - - study_id = conn_handler.execute_fetchone(sql, data)[0] - - # insert efo information into database - sql = ("INSERT INTO qiita.{0}_experimental_factor (study_id, " - "efo_id) VALUES (%s, %s)".format(cls._table)) - conn_handler.executemany(sql, [(study_id, e) for e in efo]) - - # Add to both QIITA and given portal (if not QIITA) - portal_id = convert_to_id(qiita_config.portal, 'portal_type', 'portal') - sql = """INSERT INTO qiita.study_portal - (study_id, portal_type_id) - VALUES (%s, %s)""" - args = [[study_id, portal_id]] - if qiita_config.portal != 'QIITA': - qp_id = convert_to_id('QIITA', 'portal_type', 'portal') - args.append([study_id, qp_id]) - conn_handler.executemany(sql, args) - - # add study to investigation if necessary - if investigation: - sql = ("INSERT INTO qiita.investigation_study (investigation_id, " - "study_id) VALUES (%s, %s)") - conn_handler.execute(sql, (investigation.id, study_id)) - - return cls(study_id) + with TRN: + if cls.exists(title): + raise QiitaDBDuplicateError("Study", "title: %s" % title) + + # add default values to info + insertdict = deepcopy(info) + insertdict['email'] = owner.id + insertdict['study_title'] = title + if "reprocess" not in insertdict: + insertdict['reprocess'] = False + + # No nuns allowed + insertdict = {k: v for k, v in viewitems(insertdict) + if v is not None} + + # make sure dictionary only has keys for available columns in db + check_table_cols(insertdict, cls._table) + # make sure required columns are in the dictionary + check_required_columns(insertdict, cls._table) + + # Insert study into database + sql = """INSERT INTO qiita.{0} ({1}) + VALUES ({2}) RETURNING study_id""".format( + cls._table, ','.join(insertdict), + ','.join(['%s'] * len(insertdict))) + + # make sure data in same order as sql column names, + # and ids are used + data = [] + for col in insertdict: + if isinstance(insertdict[col], QiitaObject): + data.append(insertdict[col].id) + else: + data.append(insertdict[col]) + + TRN.add(sql, data) + study_id = TRN.execute_fetchlast() + + # insert efo information into database + sql = """INSERT INTO qiita.{0}_experimental_factor + (study_id, efo_id) + VALUES (%s, %s)""".format(cls._table) + TRN.add(sql, [[study_id, e] for e in efo], many=True) + + # Add to both QIITA and given portal (if not QIITA) + portal_id = convert_to_id(qiita_config.portal, 'portal_type', + 'portal') + sql = """INSERT INTO qiita.study_portal (study_id, portal_type_id) + VALUES (%s,
%s)""" + args = [[study_id, portal_id]] + if qiita_config.portal != 'QIITA': + qp_id = convert_to_id('QIITA', 'portal_type', 'portal') + args.append([study_id, qp_id]) + TRN.add(sql, args, many=True) + TRN.execute() + + # add study to investigation if necessary + if investigation: + sql = """INSERT INTO qiita.investigation_study + (investigation_id, study_id) + VALUES (%s, %s)""" + TRN.add(sql, [investigation.id, study_id]) + + TRN.execute() + + return cls(study_id) @classmethod def delete(cls, id_): @@ -385,57 +408,43 @@ def delete(cls, id_): QiitaDBError If the sample_(id_) table exists means a sample template exists """ - cls._check_subclass() - - # checking that the id_ exists - cls(id_) + with TRN: + # checking that the id_ exists + cls(id_) - conn_handler = SQLConnectionHandler() - if exists_table('sample_%d' % id_, conn_handler): - raise QiitaDBError('Study "%s" cannot be erased because it has a ' - 'sample template' % cls(id_).title) + if exists_table('sample_%d' % id_): + raise QiitaDBError( + 'Study "%s" cannot be erased because it has a ' + 'sample template' % cls(id_).title) - queue = "delete_study_%d" % id_ - conn_handler.create_queue(queue) + sql = "DELETE FROM qiita.study_sample_columns WHERE study_id = %s" + args = [id_] + TRN.add(sql, args) - conn_handler.add_to_queue( - queue, - "DELETE FROM qiita.study_sample_columns WHERE study_id = %s", - (id_, )) + sql = "DELETE FROM qiita.study_portal WHERE study_id = %s" + TRN.add(sql, args) - conn_handler.add_to_queue( - queue, - "DELETE FROM qiita.study_portal WHERE study_id = %s", - (id_, )) + sql = """DELETE FROM qiita.study_experimental_factor + WHERE study_id = %s""" + TRN.add(sql, args) - conn_handler.add_to_queue( - queue, - "DELETE FROM qiita.study_experimental_factor WHERE study_id = %s", - (id_, )) + sql = "DELETE FROM qiita.study_pmid WHERE study_id = %s" + TRN.add(sql, args) - conn_handler.add_to_queue( - queue, - "DELETE FROM qiita.study_pmid WHERE study_id = %s", (id_, )) + sql = """DELETE FROM qiita.study_environmental_package + WHERE study_id = %s""" + TRN.add(sql, args) - conn_handler.add_to_queue( - queue, - "DELETE FROM qiita.study_environmental_package WHERE study_id = " - "%s", (id_, )) + sql = "DELETE FROM qiita.study_users WHERE study_id = %s" + TRN.add(sql, args) - conn_handler.add_to_queue( - queue, - "DELETE FROM qiita.study_users WHERE study_id = %s", (id_, )) + sql = "DELETE FROM qiita.investigation_study WHERE study_id = %s" + TRN.add(sql, args) - conn_handler.add_to_queue( - queue, - "DELETE FROM qiita.investigation_study WHERE study_id = " - "%s", (id_, )) + sql = "DELETE FROM qiita.study WHERE study_id = %s" + TRN.add(sql, args) - conn_handler.add_to_queue( - queue, - "DELETE FROM qiita.study WHERE study_id = %s", (id_, )) - - conn_handler.execute_queue(queue) + TRN.execute() # --- Attributes --- @@ -448,10 +457,11 @@ def title(self): str Title of study """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT study_title FROM qiita.{0} WHERE " - "study_id = %s".format(self._table)) - return conn_handler.execute_fetchone(sql, (self._id, ))[0] + with TRN: + sql = """SELECT study_title FROM qiita.{0} + WHERE study_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @title.setter def title(self, title): @@ -462,10 +472,11 @@ def title(self, title): title : str The new study title """ - conn_handler = SQLConnectionHandler() - sql = ("UPDATE qiita.{0} SET study_title = %s WHERE " - "study_id = %s".format(self._table)) - return conn_handler.execute(sql, (title, 
self._id)) + with TRN: + sql = """UPDATE qiita.{0} SET study_title = %s + WHERE study_id = %s""".format(self._table) + TRN.add(sql, [title, self._id]) + return TRN.execute() @property def info(self): @@ -476,16 +487,18 @@ def info(self): dict info of study keyed to column names """ - conn_handler = SQLConnectionHandler() - sql = "SELECT * FROM qiita.{0} WHERE study_id = %s".format(self._table) - info = dict(conn_handler.execute_fetchone(sql, (self._id, ))) - # remove non-info items from info - for item in self._non_info: - info.pop(item) - # This is an optional column, but should not be considered part of the - # info - info.pop('study_id') - return info + with TRN: + sql = "SELECT * FROM qiita.{0} WHERE study_id = %s".format( + self._table) + TRN.add(sql, [self._id]) + info = dict(TRN.execute_fetchindex()[0]) + # remove non-info items from info + for item in self._non_info: + info.pop(item) + # This is an optional column, but should not be considered part + # of the info + info.pop('study_id') + return info @info.setter def info(self, info): @@ -513,36 +526,37 @@ def info(self, info): raise QiitaDBColumnError("non info keys passed: %s" % self._non_info.intersection(info)) - conn_handler = SQLConnectionHandler() - - if 'timeseries_type_id' in info: - # We only lock if the timeseries type changes - self._lock_non_sandbox(conn_handler) - - # make sure dictionary only has keys for available columns in db - check_table_cols(conn_handler, info, self._table) - - sql_vals = [] - data = [] - # build query with data values in correct order for SQL statement - for key, val in viewitems(info): - sql_vals.append("{0} = %s".format(key)) - if isinstance(val, QiitaObject): - data.append(val.id) - else: - data.append(val) - data.append(self._id) - - sql = ("UPDATE qiita.{0} SET {1} WHERE " - "study_id = %s".format(self._table, ','.join(sql_vals))) - conn_handler.execute(sql, data) + with TRN: + if 'timeseries_type_id' in info: + # We only lock if the timeseries type changes + self._lock_non_sandbox() + + # make sure dictionary only has keys for available columns in db + check_table_cols(info, self._table) + + sql_vals = [] + data = [] + # build query with data values in correct order for SQL statement + for key, val in viewitems(info): + sql_vals.append("{0} = %s".format(key)) + if isinstance(val, QiitaObject): + data.append(val.id) + else: + data.append(val) + data.append(self._id) + + sql = "UPDATE qiita.{0} SET {1} WHERE study_id = %s".format( + self._table, ','.join(sql_vals)) + TRN.add(sql, data) + TRN.execute() @property def efo(self): - conn_handler = SQLConnectionHandler() - sql = ("SELECT efo_id FROM qiita.{0}_experimental_factor WHERE " - "study_id = %s".format(self._table)) - return [x[0] for x in conn_handler.execute_fetchall(sql, (self._id, ))] + with TRN: + sql = """SELECT efo_id FROM qiita.{0}_experimental_factor + WHERE study_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @efo.setter def efo(self, efo_vals): @@ -560,16 +574,18 @@ def efo(self, efo_vals): """ if not efo_vals: raise IncompetentQiitaDeveloperError("Need EFO information!") - conn_handler = SQLConnectionHandler() - self._lock_non_sandbox(conn_handler) - # wipe out any EFOs currently attached to study - sql = ("DELETE FROM qiita.{0}_experimental_factor WHERE " - "study_id = %s".format(self._table)) - conn_handler.execute(sql, (self._id, )) - # insert new EFO information into database - sql = ("INSERT INTO qiita.{0}_experimental_factor (study_id, " - "efo_id) VALUES (%s, 
%s)".format(self._table)) - conn_handler.executemany(sql, [(self._id, efo) for efo in efo_vals]) + with TRN: + self._lock_non_sandbox() + # wipe out any EFOs currently attached to study + sql = """DELETE FROM qiita.{0}_experimental_factor + WHERE study_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + # insert new EFO information into database + sql = """INSERT INTO qiita.{0}_experimental_factor + (study_id, efo_id) + VALUES (%s, %s)""".format(self._table) + TRN.add(sql, [[self._id, efo] for efo in efo_vals], many=True) + TRN.execute() @property def shared_with(self): @@ -580,10 +596,11 @@ def shared_with(self): list of User ids Users the study is shared with """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT email FROM qiita.{0}_users WHERE " - "study_id = %s".format(self._table)) - return [x[0] for x in conn_handler.execute_fetchall(sql, (self._id,))] + with TRN: + sql = """SELECT email FROM qiita.{0}_users + WHERE study_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @property def pmids(self): @@ -594,10 +611,11 @@ def pmids(self): list of str list of all the PMIDs """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT pmid FROM qiita.{0}_pmid WHERE " - "study_id = %s".format(self._table)) - return [x[0] for x in conn_handler.execute_fetchall(sql, (self._id, ))] + with TRN: + sql = "SELECT pmid FROM qiita.{0}_pmid WHERE study_id = %s".format( + self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @pmids.setter def pmids(self, values): @@ -617,25 +635,18 @@ def pmids(self, values): if not isinstance(values, list): raise TypeError('pmids should be a list') - # Get the connection to the database - conn_handler = SQLConnectionHandler() - - # Create a queue for the operations that we need to do - queue = "%d_pmid_setter" % self._id - conn_handler.create_queue(queue) - - # Delete the previous pmids associated with the study - sql = "DELETE FROM qiita.study_pmid WHERE study_id=%s" - sql_args = (self._id,) - conn_handler.add_to_queue(queue, sql, sql_args) + with TRN: + # Delete the previous pmids associated with the study + sql = "DELETE FROM qiita.study_pmid WHERE study_id=%s" + TRN.add(sql, [self._id]) - # Set the new ones - sql = "INSERT INTO qiita.study_pmid (study_id, pmid) VALUES (%s, %s)" - sql_args = [(self._id, val) for val in values] - conn_handler.add_to_queue(queue, sql, sql_args, many=True) + # Set the new ones + sql = """INSERT INTO qiita.study_pmid (study_id, pmid) + VALUES (%s, %s)""" + sql_args = [[self._id, val] for val in values] + TRN.add(sql, sql_args, many=True) - # Execute the queue - conn_handler.execute_queue(queue) + TRN.execute() @property def investigation(self): @@ -645,11 +656,14 @@ def investigation(self): ------- Investigation id """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT investigation_id FROM qiita.investigation_study WHERE " - "study_id = %s") - inv = conn_handler.execute_fetchone(sql, (self._id, )) - return inv[0] if inv is not None else inv + with TRN: + sql = """SELECT investigation_id FROM qiita.investigation_study + WHERE study_id = %s""" + TRN.add(sql, [self._id]) + inv = TRN.execute_fetchindex() + # If this study belongs to an investigation it will be in + # the first value of the first row [0][0] + return inv[0][0] if inv else None @property def sample_template(self): @@ -669,13 +683,14 @@ def data_types(self): ------- list of str """ - conn_handler = SQLConnectionHandler() - sql = """SELECT DISTINCT data_type - FROM qiita.study_prep_template 
- JOIN qiita.prep_template USING (prep_template_id) - JOIN qiita.data_type USING (data_type_id) - WHERE study_id = %s""" - return [x[0] for x in conn_handler.execute_fetchall(sql, (self._id,))] + with TRN: + sql = """SELECT DISTINCT data_type + FROM qiita.study_prep_template + JOIN qiita.prep_template USING (prep_template_id) + JOIN qiita.data_type USING (data_type_id) + WHERE study_id = %s""" + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @property def owner(self): @@ -686,11 +701,11 @@ def owner(self): str The email (id) of the user that owns this study """ - conn_handler = SQLConnectionHandler() - sql = """select email from qiita.{} where study_id = %s""".format( - self._table) - - return conn_handler.execute_fetchone(sql, [self._id])[0] + with TRN: + sql = """SELECT email FROM qiita.{} WHERE study_id = %s""".format( + self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def environmental_packages(self): @@ -701,12 +716,12 @@ def environmental_packages(self): list of str The environmental package names associated with the study """ - conn_handler = SQLConnectionHandler() - env_pkgs = conn_handler.execute_fetchall( - "SELECT environmental_package_name FROM " - "qiita.study_environmental_package WHERE study_id = %s", - (self._id,)) - return [pkg[0] for pkg in env_pkgs] + with TRN: + sql = """SELECT environmental_package_name + FROM qiita.study_environmental_package + WHERE study_id = %s""" + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() @environmental_packages.setter def environmental_packages(self, values): @@ -724,43 +739,38 @@ def environmental_packages(self, values): ValueError If any environmental packages listed on values is not recognized """ - # Get the connection to the database - conn_handler = SQLConnectionHandler() - - # The environmental packages can be changed only if the study is - # sandboxed - self._lock_non_sandbox(conn_handler) - - # Check that a list is actually passed - if not isinstance(values, list): - raise TypeError('Environmental packages should be a list') - - # Get all the environmental packages - env_pkgs = [pkg[0] for pkg in get_environmental_packages()] - - # Check that all the passed values are valid environmental packages - missing = set(values).difference(env_pkgs) - if missing: - raise ValueError('Environmetal package(s) not recognized: %s' - % ', '.join(missing)) - - # Create a queue for the operations that we need to do - queue = "%d_env_pkgs_setter" % self._id - conn_handler.create_queue(queue) - - # Delete the previous environmental packages associated with the study - sql = "DELETE FROM qiita.study_environmental_package WHERE study_id=%s" - sql_args = (self._id,) - conn_handler.add_to_queue(queue, sql, sql_args) - - # Set the new ones - sql = ("INSERT INTO qiita.study_environmental_package " - "(study_id, environmental_package_name) VALUES (%s, %s)") - sql_args = [(self._id, val) for val in values] - conn_handler.add_to_queue(queue, sql, sql_args, many=True) - - # Execute the queue - conn_handler.execute_queue(queue) + with TRN: + # The environmental packages can be changed only if the study is + # sandboxed + self._lock_non_sandbox() + + # Check that a list is actually passed + if not isinstance(values, list): + raise TypeError('Environmental packages should be a list') + + # Get all the environmental packages + env_pkgs = [pkg[0] for pkg in get_environmental_packages()] + + # Check that all the passed values are valid environmental packages + missing = set(values).difference(env_pkgs) + 
if missing: + raise ValueError('Environmental package(s) not recognized: %s' + % ', '.join(missing)) + + # Delete the previous environmental packages associated with + # the study + sql = """DELETE FROM qiita.study_environmental_package + WHERE study_id=%s""" + TRN.add(sql, [self._id]) + + # Set the new ones + sql = """INSERT INTO qiita.study_environmental_package + (study_id, environmental_package_name) + VALUES (%s, %s)""" + sql_args = [[self._id, val] for val in values] + TRN.add(sql, sql_args, many=True) + + TRN.execute() @property def _portals(self): @@ -771,11 +781,13 @@ def _portals(self): list of str Portal names study is associated with """ - sql = """SELECT portal from qiita.portal_type - JOIN qiita.study_portal USING (portal_type_id) - WHERE study_id = %s""" - conn_handler = SQLConnectionHandler() - return [x[0] for x in conn_handler.execute_fetchall(sql, [self._id])] + with TRN: + sql = """SELECT portal + FROM qiita.portal_type + JOIN qiita.study_portal USING (portal_type_id) + WHERE study_id = %s""" + TRN.add(sql, [self._id]) + return TRN.execute_fetchflatten() # --- methods --- def raw_data(self, data_type=None): @@ -790,18 +802,20 @@ def raw_data(self, data_type=None): ------- list of RawData ids """ - spec_data = "" - if data_type: - spec_data = " AND data_type_id = %d" % convert_to_id(data_type, - "data_type") - conn_handler = SQLConnectionHandler() - sql = """SELECT raw_data_id - FROM qiita.study_prep_template - JOIN qiita.prep_template USING (prep_template_id) - JOIN qiita.raw_data USING (raw_data_id) - WHERE study_id = %s{0}""".format(spec_data) - - return [x[0] for x in conn_handler.execute_fetchall(sql, (self._id,))] + with TRN: + spec_data = "" + args = [self._id] + if data_type: + spec_data = " AND data_type_id = %s" + args.append(convert_to_id(data_type, "data_type")) + + sql = """SELECT raw_data_id + FROM qiita.study_prep_template + JOIN qiita.prep_template USING (prep_template_id) + JOIN qiita.raw_data USING (raw_data_id) + WHERE study_id = %s{0}""".format(spec_data) + TRN.add(sql, args) + return TRN.execute_fetchflatten() def prep_templates(self, data_type=None): """Return list of prep template ids @@ -816,17 +830,19 @@ ------- list of PrepTemplate ids """ - spec_data = "" - if data_type: - spec_data = " AND data_type_id = %s" % convert_to_id(data_type, - "data_type") - - conn_handler = SQLConnectionHandler() - sql = """SELECT prep_template_id - FROM qiita.study_prep_template - JOIN qiita.prep_template USING (prep_template_id) - WHERE study_id = %s{0}""".format(spec_data) - return [x[0] for x in conn_handler.execute_fetchall(sql, (self._id,))] + with TRN: + spec_data = "" + args = [self._id] + if data_type: + spec_data = " AND data_type_id = %s" + args.append(convert_to_id(data_type, "data_type")) + + sql = """SELECT prep_template_id + FROM qiita.study_prep_template + JOIN qiita.prep_template USING (prep_template_id) + WHERE study_id = %s{0}""".format(spec_data) + TRN.add(sql, args) + return TRN.execute_fetchflatten()
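
Note that psycopg2 only understands %s-style placeholders regardless of the
parameter type, so now that the optional data-type filter travels through
`args` instead of being %-formatted into the string, the appended fragment
must use %s. The shared pattern, sketched with the names used by the
accessors above:

    with TRN:
        spec_data = ""
        args = [study_id]
        if data_type is not None:
            spec_data = " AND data_type_id = %s"  # psycopg2 placeholder
            args.append(convert_to_id(data_type, "data_type"))
        TRN.add("SELECT prep_template_id FROM qiita.study_prep_template "
                "WHERE study_id = %s" + spec_data, args)
        ids = TRN.execute_fetchflatten()
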
def preprocessed_data(self, data_type=None): """ Returns list of data ids for preprocessed data info @@ -840,14 +856,18 @@ def preprocessed_data(self, data_type=None): ------- list of PreprocessedData ids """ - spec_data = "" - if data_type: - spec_data = " AND data_type_id = %d" % convert_to_id(data_type, - "data_type") - conn_handler = SQLConnectionHandler() - sql = ("SELECT preprocessed_data_id FROM qiita.study_preprocessed_data" - " WHERE study_id = %s{0}".format(spec_data)) - return [x[0] for x in conn_handler.execute_fetchall(sql, (self._id,))] + with TRN: + spec_data = "" + args = [self._id] + if data_type: + spec_data = " AND data_type_id = %s" + args.append(convert_to_id(data_type, "data_type")) + + sql = """SELECT preprocessed_data_id + FROM qiita.study_preprocessed_data + WHERE study_id = %s{0}""".format(spec_data) + TRN.add(sql, args) + return TRN.execute_fetchflatten() def processed_data(self, data_type=None): """ Returns list of data ids for processed data info @@ -861,16 +881,20 @@ def processed_data(self, data_type=None): ------- list of ProcessedData ids """ - spec_data = "" - if data_type: - spec_data = " AND p.data_type_id = %d" % convert_to_id(data_type, - "data_type") - conn_handler = SQLConnectionHandler() - sql = ("SELECT p.processed_data_id FROM qiita.processed_data p JOIN " - "qiita.study_processed_data sp ON p.processed_data_id = " - "sp.processed_data_id WHERE " - "sp.study_id = %s{0}".format(spec_data)) - return [x[0] for x in conn_handler.execute_fetchall(sql, (self._id,))] + with TRN: + spec_data = "" + args = [self._id] + if data_type: + spec_data = " AND p.data_type_id = %s" + args.append(convert_to_id(data_type, "data_type")) + + sql = """SELECT p.processed_data_id + FROM qiita.processed_data p + JOIN qiita.study_processed_data sp + ON p.processed_data_id = sp.processed_data_id + WHERE sp.study_id = %s{0}""".format(spec_data) + TRN.add(sql, args) + return TRN.execute_fetchflatten() def add_pmid(self, pmid): """Adds PMID to study @@ -880,10 +904,11 @@ def add_pmid(self, pmid): pmid : str pmid to associate with study """ - conn_handler = SQLConnectionHandler() - sql = ("INSERT INTO qiita.{0}_pmid (study_id, pmid) " - "VALUES (%s, %s)".format(self._table)) - conn_handler.execute(sql, (self._id, pmid)) + with TRN: + sql = """INSERT INTO qiita.{0}_pmid (study_id, pmid) + VALUES (%s, %s)""".format(self._table) + TRN.add(sql, [self._id, pmid]) + TRN.execute() def has_access(self, user, no_public=False): """Returns whether the given user has access to the study @@ -901,15 +926,16 @@ bool Whether user has access to study or not """ - # if admin or superuser, just return true - if user.level in {'superuser', 'admin'}: - return True + with TRN: + # if admin or superuser, just return true + if user.level in {'superuser', 'admin'}: + return True - if no_public: - return self._id in user.user_studies | user.shared_studies - else: - return self._id in user.user_studies | user.shared_studies \ - | self.get_by_status('public') + if no_public: + return self._id in user.user_studies | user.shared_studies + else: + return self._id in user.user_studies | user.shared_studies \ + | self.get_by_status('public') def share(self, user): """Share the study with another user @@ -919,19 +945,18 @@ user: User object The user to share the study with """ - conn_handler = SQLConnectionHandler() - - # Make sure the study is not already shared with the given user - if user.id in self.shared_with: - return - # Do not allow the study to be shared with the owner - if user.id == self.owner: - return - - sql = ("INSERT INTO qiita.study_users (study_id, email) VALUES " - "(%s, %s)") - - conn_handler.execute(sql, (self._id, user.id)) + with TRN: + # Make sure the study is not already shared with the given user + if user.id in self.shared_with: + return + # Do not allow the study to be shared with the owner + if user.id == self.owner: + return + + sql = """INSERT INTO qiita.study_users (study_id, email) + VALUES (%s, %s)""" +
TRN.add(sql, [self._id, user.id]) + TRN.execute() def unshare(self, user): """Unshare the study with another user @@ -941,12 +966,11 @@ def unshare(self, user): user: User object The user to unshare the study with """ - conn_handler = SQLConnectionHandler() - - sql = ("DELETE FROM qiita.study_users WHERE study_id = %s AND " - "email = %s") - - conn_handler.execute(sql, (self._id, user.id)) + with TRN: + sql = """DELETE FROM qiita.study_users + WHERE study_id = %s AND email = %s""" + TRN.add(sql, [self._id, user.id]) + TRN.execute() class StudyPerson(QiitaObject): @@ -977,13 +1001,13 @@ def iter(cls): Yields a `StudyPerson` object for each person in the database, in order of ascending study_person_id """ - conn = SQLConnectionHandler() - sql = "select study_person_id from qiita.{} order by study_person_id" - results = conn.execute_fetchall(sql.format(cls._table)) + with TRN: + sql = """SELECT study_person_id FROM qiita.{} + ORDER BY study_person_id""".format(cls._table) + TRN.add(sql) - for result in results: - ID = result[0] - yield StudyPerson(ID) + for id_ in TRN.execute_fetchflatten(): + yield StudyPerson(id_) @classmethod def exists(cls, name, affiliation): @@ -1001,10 +1025,13 @@ def exists(cls, name, affiliation): bool True if person exists else false """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT exists(SELECT * FROM qiita.{0} WHERE " - "name = %s AND affiliation = %s)".format(cls._table)) - return conn_handler.execute_fetchone(sql, (name, affiliation))[0] + with TRN: + sql = """SELECT EXISTS( + SELECT * FROM qiita.{0} + WHERE name = %s + AND affiliation = %s)""".format(cls._table) + TRN.add(sql, [name, affiliation]) + return TRN.execute_fetchlast() @classmethod def create(cls, name, email, affiliation, address=None, phone=None): @@ -1028,23 +1055,22 @@ def create(cls, name, email, affiliation, address=None, phone=None): New StudyPerson object """ - if cls.exists(name, affiliation): - sql = ("SELECT study_person_id from qiita.{0} WHERE name = %s and" - " affiliation = %s".format(cls._table)) - conn_handler = SQLConnectionHandler() - spid = conn_handler.execute_fetchone(sql, (name, affiliation)) - - # Doesn't exist so insert new person - else: - sql = ("INSERT INTO qiita.{0} (name, email, affiliation, address, " - "phone) VALUES" - " (%s, %s, %s, %s, %s) RETURNING " - "study_person_id".format(cls._table)) - conn_handler = SQLConnectionHandler() - spid = conn_handler.execute_fetchone(sql, (name, email, - affiliation, address, - phone)) - return cls(spid[0]) + with TRN: + if cls.exists(name, affiliation): + sql = """SELECT study_person_id + FROM qiita.{0} + WHERE name = %s + AND affiliation = %s""".format(cls._table) + args = [name, affiliation] + else: + sql = """INSERT INTO qiita.{0} (name, email, affiliation, + address, phone) + VALUES (%s, %s, %s, %s, %s) + RETURNING study_person_id""".format(cls._table) + args = [name, email, affiliation, address, phone] + + TRN.add(sql, args) + return cls(TRN.execute_fetchlast()) # Properties @property @@ -1056,10 +1082,11 @@ def name(self): str Name of person """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT name FROM qiita.{0} WHERE " - "study_person_id = %s".format(self._table)) - return conn_handler.execute_fetchone(sql, (self._id, ))[0] + with TRN: + sql = """SELECT name FROM qiita.{0} + WHERE study_person_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def email(self): @@ -1070,10 +1097,11 @@ def email(self): str Email of person """ - conn_handler = 
SQLConnectionHandler() - sql = ("SELECT email FROM qiita.{0} WHERE " - "study_person_id = %s".format(self._table)) - return conn_handler.execute_fetchone(sql, (self._id, ))[0] + with TRN: + sql = """SELECT email FROM qiita.{0} + WHERE study_person_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def affiliation(self): @@ -1084,10 +1112,11 @@ def affiliation(self): str Affiliation of person """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT affiliation FROM qiita.{0} WHERE " - "study_person_id = %s".format(self._table)) - return conn_handler.execute_fetchone(sql, [self._id])[0] + with TRN: + sql = """SELECT affiliation FROM qiita.{0} + WHERE study_person_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def address(self): @@ -1098,10 +1127,11 @@ def address(self): str or None address or None if no address in database """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT address FROM qiita.{0} WHERE study_person_id =" - " %s".format(self._table)) - return conn_handler.execute_fetchone(sql, (self._id, ))[0] + with TRN: + sql = """SELECT address FROM qiita.{0} + WHERE study_person_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @address.setter def address(self, value): @@ -1112,10 +1142,11 @@ def address(self, value): value : str New address for person """ - conn_handler = SQLConnectionHandler() - sql = ("UPDATE qiita.{0} SET address = %s WHERE " - "study_person_id = %s".format(self._table)) - conn_handler.execute(sql, (value, self._id)) + with TRN: + sql = """UPDATE qiita.{0} SET address = %s + WHERE study_person_id = %s""".format(self._table) + TRN.add(sql, [value, self._id]) + TRN.execute() @property def phone(self): @@ -1126,10 +1157,11 @@ def phone(self): str or None phone or None if no address in database """ - conn_handler = SQLConnectionHandler() - sql = ("SELECT phone FROM qiita.{0} WHERE " - "study_person_id = %s".format(self._table)) - return conn_handler.execute_fetchone(sql, (self._id, ))[0] + with TRN: + sql = """SELECT phone FROM qiita.{0} + WHERE study_person_id = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @phone.setter def phone(self, value): @@ -1140,7 +1172,8 @@ def phone(self, value): value : str New phone number for person """ - conn_handler = SQLConnectionHandler() - sql = ("UPDATE qiita.{0} SET phone = %s WHERE " - "study_person_id = %s".format(self._table)) - conn_handler.execute(sql, (value, self._id)) + with TRN: + sql = """UPDATE qiita.{0} SET phone = %s + WHERE study_person_id = %s""".format(self._table) + TRN.add(sql, [value, self._id]) + TRN.execute() diff --git a/qiita_db/support_files/patches/python_patches/14.py b/qiita_db/support_files/patches/python_patches/14.py index a974a11c8..bdeb70a39 100644 --- a/qiita_db/support_files/patches/python_patches/14.py +++ b/qiita_db/support_files/patches/python_patches/14.py @@ -6,69 +6,57 @@ from os.path import basename -from skbio.util import flatten - -from qiita_db.sql_connection import SQLConnectionHandler +from qiita_db.sql_connection import TRN from qiita_db.metadata_template import PrepTemplate -conn_handler = SQLConnectionHandler() - -sql = "SELECT prep_template_id FROM qiita.prep_template" -all_ids = conn_handler.execute_fetchall(sql) - -q_name = 'unlink-bad-mapping-files' -conn_handler.create_queue(q_name) - -# remove all the bad mapping files -for prep_template_id in all_ids: - - prep_template_id = 
prep_template_id[0] - pt = PrepTemplate(prep_template_id) - fps = pt.get_filepaths() - - # get the QIIME mapping file, note that the way to figure out what is and - # what's not a qiime mapping file is to check for the existance of the - # word qiime in the basename of the file path, hacky but that's the way - # it is being done in qiita_pet/uimodules/raw_data_tab.py - mapping_files = [f for f in fps if '_qiime_' in basename(f[1])] - - table = 'prep_template_filepath' - column = 'prep_template_id' - - # unlink all the qiime mapping files for this prep template object - for mf in mapping_files: - - # (1) get the ids that we are going to delete. - # because of the FK restriction, we cannot just delete the ids - ids = conn_handler.execute_fetchall( 'SELECT filepath_id FROM qiita.{0} WHERE ' '{1}=%s and filepath_id=%s'.format(table, column), (pt.id, mf[0])) - ids = flatten(ids) - - # (2) delete the entries from the prep_template_filepath table - conn_handler.add_to_queue( q_name, "DELETE FROM qiita.{0} " "WHERE {1}=%s and filepath_id=%s;".format(table, column), (pt.id, mf[0])) - - # (3) delete the entries from the filepath table - conn_handler.add_to_queue( q_name, "DELETE FROM qiita.filepath WHERE " "filepath_id IN ({0});".format(', '.join(map(str, ids)))) - -try: - conn_handler.execute_queue(q_name) -except Exception as e: - raise - -# create correct versions of the mapping files -for prep_template_id in all_ids: - - prep_template_id = prep_template_id[0] - pt = PrepTemplate(prep_template_id) - - # we can guarantee that all the filepaths will be prep templates so - # we can just generate the qiime mapping files - for _, fpt in pt.get_filepaths(): - pt.create_qiime_mapping_file(fpt) +with TRN: + sql = "SELECT prep_template_id FROM qiita.prep_template" + TRN.add(sql) + all_ids = TRN.execute_fetchflatten() + + # remove all the bad mapping files + for prep_template_id in all_ids: + pt = PrepTemplate(prep_template_id) + fps = pt.get_filepaths() + + # get the QIIME mapping file, note that the way to figure out what is + # and what's not a qiime mapping file is to check for the existence of + # the word qiime in the basename of the file path, hacky but that's + # the way it is being done in qiita_pet/uimodules/raw_data_tab.py + mapping_files = [f for f in fps if '_qiime_' in basename(f[1])] + + table = 'prep_template_filepath' + column = 'prep_template_id' + + # unlink all the qiime mapping files for this prep template object + for mf in mapping_files: + + # (1) get the ids that we are going to delete. 
+ # because of the FK restriction, we cannot just delete the ids + sql = """SELECT filepath_id + FROM qiita.{0} + WHERE {1}=%s AND filepath_id=%s""".format(table, column) + TRN.add(sql, [pt.id, mf[0]]) + ids = TRN.execute_fetchflatten() + + # (2) delete the entries from the prep_template_filepath table + sql = """DELETE FROM qiita.{0} + WHERE {1}=%s and filepath_id=%s""".format(table, column) + TRN.add(sql, [pt.id, mf[0]]) + + # (3) delete the entries from the filepath table + sql = "DELETE FROM qiita.filepath WHERE filepath_id IN %s" + TRN.add(sql, [tuple(ids)]) + + TRN.execute() + + # create correct versions of the mapping files + for prep_template_id in all_ids: + pt = PrepTemplate(prep_template_id) + + # we can guarantee that all the filepaths will be prep templates so + # we can just generate the qiime mapping files + for _, fpt in pt.get_filepaths(): + pt.create_qiime_mapping_file(fpt) diff --git a/qiita_db/support_files/patches/python_patches/15.py b/qiita_db/support_files/patches/python_patches/15.py index 25a620948..2be84220f 100644 --- a/qiita_db/support_files/patches/python_patches/15.py +++ b/qiita_db/support_files/patches/python_patches/15.py @@ -4,23 +4,26 @@ from os.path import basename, dirname from qiita_db.util import get_mountpoint -from qiita_db.sql_connection import SQLConnectionHandler +from qiita_db.sql_connection import TRN -conn_handler = SQLConnectionHandler() +with TRN: + sql = """SELECT f.* + FROM qiita.filepath f + JOIN qiita.analysis_filepath afp + ON f.filepath_id = afp.filepath_id""" + TRN.add(sql) + filepaths = TRN.execute_fetchindex() -filepaths = conn_handler.execute_fetchall( 'SELECT f.* from qiita.filepath f JOIN qiita.analysis_filepath afp ON ' 'f.filepath_id = afp.filepath_id') + # retrieve relative filepaths as dictionary for matching + mountpoints = {m[1].rstrip('/\\'): m[0] for m in get_mountpoint( 'analysis', retrieve_all=True)} -# retrieve relative filepaths as dictionary for matching -mountpoints = {m[1].rstrip('/\\'): m[0] for m in get_mountpoint( 'analysis', retrieve_all=True)} + sql = """UPDATE qiita.filepath SET filepath = %s, data_directory_id = %s + WHERE filepath_id = %s""" + for filepath in filepaths: + filename = basename(filepath['filepath']) + # find the ID of the analysis filepath used + mp_id = mountpoints[dirname(filepath['filepath']).rstrip('/\\')] + TRN.add(sql, [filename, mp_id, filepath['filepath_id']]) -for filepath in filepaths: - filename = basename(filepath['filepath']) - # find the ID of the analysis filepath used - mp_id = mountpoints[dirname(filepath['filepath']).rstrip('/\\')] - conn_handler.execute( 'UPDATE qiita.filepath SET filepath = %s, data_directory_id = %s WHERE' ' filepath_id = %s', [filename, mp_id, filepath['filepath_id']]) + TRN.execute() diff --git a/qiita_db/support_files/patches/python_patches/23.py b/qiita_db/support_files/patches/python_patches/23.py index cd0292893..c7039a692 100644 --- a/qiita_db/support_files/patches/python_patches/23.py +++ b/qiita_db/support_files/patches/python_patches/23.py @@ -1,20 +1,19 @@ # Mar 27, 2015 # Need to re-generate the files, given that some headers have changed -from qiita_db.sql_connection import SQLConnectionHandler +from qiita_db.sql_connection import TRN from qiita_db.metadata_template import SampleTemplate, PrepTemplate -conn_handler = SQLConnectionHandler() +with TRN: + # Get all the sample templates + TRN.add("SELECT DISTINCT study_id from qiita.study_sample") + study_ids = TRN.execute_fetchflatten() -# 
Get all the sample templates -sql = """SELECT DISTINCT study_id from qiita.study_sample""" -study_ids = {s[0] for s in conn_handler.execute_fetchall(sql)} + for s_id in study_ids: + SampleTemplate(s_id).generate_files() -for s_id in study_ids: - SampleTemplate(s_id).generate_files() - -# Get all the prep templates -sql = """SELECT prep_template_id from qiita.prep_template""" -prep_ids = {p[0] for p in conn_handler.execute_fetchall(sql)} -for prep_id in prep_ids: - PrepTemplate(prep_id).generate_files() + # Get all the prep templates + TRN.add("SELECT DISTINCT prep_template_id from qiita.prep_template") + prep_ids = TRN.execute_fetchflatten() + for prep_id in prep_ids: + PrepTemplate(prep_id).generate_files() diff --git a/qiita_db/support_files/patches/python_patches/25.py b/qiita_db/support_files/patches/python_patches/25.py index c54c3557a..ba7fbb817 100644 --- a/qiita_db/support_files/patches/python_patches/25.py +++ b/qiita_db/support_files/patches/python_patches/25.py @@ -4,121 +4,101 @@ # make the RawData to be effectively just a container for the raw files, # which is how it was acting previously. -from qiita_db.sql_connection import SQLConnectionHandler +from qiita_db.sql_connection import TRN from qiita_db.data import RawData from qiita_db.util import move_filepaths_to_upload_folder -conn_handler = SQLConnectionHandler() -queue = "PATCH_25" -conn_handler.create_queue(queue) - -# the system may contain raw data with no prep template associated to it. -# Retrieve all those raw data ids -sql = """SELECT raw_data_id - FROM qiita.raw_data - WHERE raw_data_id NOT IN ( - SELECT DISTINCT raw_data_id FROM qiita.prep_template);""" -rd_ids = [x[0] for x in conn_handler.execute_fetchall(sql)] - -# We will delete those RawData. However, if they have files attached, we should -# move them to the uploads folder of the study -sql_detach = """DELETE FROM qiita.study_raw_data - WHERE raw_data_id = %s AND study_id = %s""" -sql_unlink = "DELETE FROM qiita.raw_filepath WHERE raw_data_id = %s" -sql_delete = "DELETE FROM qiita.raw_data WHERE raw_data_id = %s" -sql_studies = """SELECT study_id FROM qiita.study_raw_data - WHERE raw_data_id = %s""" -move_files = [] -for rd_id in rd_ids: - rd = RawData(rd_id) - filepaths = rd.get_filepaths() - studies = [s[0] for s in conn_handler.execute_fetchall(sql_studies, - (rd_id,))] - if filepaths: - # we need to move the files to a study. We chose the one with lower - # study id. Currently there is no case in the live database in which a - # RawData with no prep templates is attached to more than one study, - # but I think it is better to normalize this just in case - move_files.append((min(studies), filepaths)) - - # To delete the RawData we first need to unlink all the files - conn_handler.add_to_queue(queue, sql_unlink, (rd_id,)) - - # Then, remove the raw data from all the studies - for st_id in studies: - conn_handler.add_to_queue(queue, sql_detach, (rd_id, st_id)) - - conn_handler.add_to_queue(queue, sql_delete, (rd_id,)) - -# We can now perform all changes in the DB. Although these changes can be -# done in an SQL patch, they are done here because we need to execute the -# previous clean up in the database before we can actually execute the SQL -# patch. 
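# The queue-based conn_handler calls removed throughout these python_patches
# are replaced by the shared TRN transaction object. A minimal sketch of the
# new pattern, assuming the TRN singleton exported by qiita_db.sql_connection
# and the behavior exercised in test_sql_connection.py further down:

from qiita_db.sql_connection import TRN

with TRN:
    # add() only queues the SQL; nothing is sent to the database yet
    TRN.add("SELECT study_id FROM qiita.study")
    # execute_fetchflatten() runs every queued query and flattens the
    # last result set into a plain list of values
    study_ids = TRN.execute_fetchflatten()
# leaving the outermost context commits the transaction; an unhandled
# exception rolls the whole transaction back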
-sql = """CREATE TABLE qiita.study_prep_template ( - study_id bigint NOT NULL, - prep_template_id bigint NOT NULL, - CONSTRAINT idx_study_prep_template - PRIMARY KEY ( study_id, prep_template_id ) - ); - -CREATE INDEX idx_study_prep_template_0 - ON qiita.study_prep_template ( study_id ); - -CREATE INDEX idx_study_prep_template_1 - ON qiita.study_prep_template ( prep_template_id ); - -COMMENT ON TABLE qiita.study_prep_template IS - 'links study to its prep templates'; - -ALTER TABLE qiita.study_prep_template - ADD CONSTRAINT fk_study_prep_template_study - FOREIGN KEY ( study_id ) REFERENCES qiita.study( study_id ); - -ALTER TABLE qiita.study_prep_template - ADD CONSTRAINT fk_study_prep_template_pt - FOREIGN KEY ( prep_template_id ) - REFERENCES qiita.prep_template( prep_template_id ); - --- Connect the existing prep templates in the system with their studies -DO $do$ -DECLARE - vals RECORD; -BEGIN -FOR vals IN - SELECT prep_template_id, study_id - FROM qiita.prep_template - JOIN qiita.study_raw_data USING (raw_data_id) -LOOP - INSERT INTO qiita.study_prep_template (study_id, prep_template_id) - VALUES (vals.study_id, vals.prep_template_id); -END LOOP; -END $do$; - ---- Drop the study_raw__data table as it's not longer used -DROP TABLE qiita.study_raw_data; - --- The raw_data_id column now can be nullable -ALTER TABLE qiita.prep_template - ALTER COLUMN raw_data_id DROP NOT NULL; -""" -conn_handler.add_to_queue(queue, sql) -conn_handler.execute_queue(queue) - -# After the changes in the database have been performed, move the files -# to the uploads folder -errors = [] -for st_id, fps in move_files: - try: - move_filepaths_to_upload_folder(st_id, fps) - except Exception, e: - # An error here is unlikely. However, it's possible and there is no - # clean way that we can unroll all the previous changes in the DB. - errors.append((st_id, fps, str(e))) - -# Show the user any error that could have been generated during the files -# movement -if errors: - print ("The following errors where generated when trying to move files " - "to the upload folder") - for st_id, fps, e in errors: - print "Study: %d, Filepaths: %s, Error: %s" % (st_id, fps, e) +with TRN: + # the system may contain raw data with no prep template associated to it. + # Retrieve all those raw data ids + sql = """SELECT raw_data_id + FROM qiita.raw_data + WHERE raw_data_id NOT IN ( + SELECT DISTINCT raw_data_id FROM qiita.prep_template);""" + TRN.add(sql) + rd_ids = TRN.execute_fetchflatten() + + # We will delete those RawData. However, if they have files attached, we + # should move them to the uploads folder of the study + sql_detach = """DELETE FROM qiita.study_raw_data + WHERE raw_data_id = %s AND study_id = %s""" + sql_unlink = "DELETE FROM qiita.raw_filepath WHERE raw_data_id = %s" + sql_delete = "DELETE FROM qiita.raw_data WHERE raw_data_id = %s" + sql_studies = """SELECT study_id FROM qiita.study_raw_data + WHERE raw_data_id = %s""" + move_files = [] + for rd_id in rd_ids: + rd = RawData(rd_id) + filepaths = rd.get_filepaths() + TRN.add(sql_studies, [rd_id]) + studies = TRN.execute_fetchflatten() + if filepaths: + # we need to move the files to a study. We chose the one with lower + # study id. 
Currently there is no case in the live database in + # which a RawData with no prep templates is attached to more than + # one study, but I think it is better to normalize this just + # in case + move_filepaths_to_upload_folder(min(studies), filepaths) + + # To delete the RawData we first need to unlink all the files + TRN.add(sql_unlink, [rd_id]) + + # Then, remove the raw data from all the studies + for st_id in studies: + TRN.add(sql_detach, [rd_id, st_id]) + + TRN.add(sql_delete, [rd_id]) + + # We can now perform all changes in the DB. Although these changes can be + # done in an SQL patch, they are done here because we need to execute the + # previous clean up in the database before we can actually execute the SQL + # patch. + sql = """CREATE TABLE qiita.study_prep_template ( + study_id bigint NOT NULL, + prep_template_id bigint NOT NULL, + CONSTRAINT idx_study_prep_template + PRIMARY KEY ( study_id, prep_template_id ) + ); + + CREATE INDEX idx_study_prep_template_0 + ON qiita.study_prep_template ( study_id ); + + CREATE INDEX idx_study_prep_template_1 + ON qiita.study_prep_template ( prep_template_id ); + + COMMENT ON TABLE qiita.study_prep_template IS + 'links study to its prep templates'; + + ALTER TABLE qiita.study_prep_template + ADD CONSTRAINT fk_study_prep_template_study + FOREIGN KEY ( study_id ) REFERENCES qiita.study( study_id ); + + ALTER TABLE qiita.study_prep_template + ADD CONSTRAINT fk_study_prep_template_pt + FOREIGN KEY ( prep_template_id ) + REFERENCES qiita.prep_template( prep_template_id ); + + -- Connect the existing prep templates in the system with their studies + DO $do$ + DECLARE + vals RECORD; + BEGIN + FOR vals IN + SELECT prep_template_id, study_id + FROM qiita.prep_template + JOIN qiita.study_raw_data USING (raw_data_id) + LOOP + INSERT INTO qiita.study_prep_template (study_id, prep_template_id) + VALUES (vals.study_id, vals.prep_template_id); + END LOOP; + END $do$; + + --- Drop the study_raw__data table as it's not longer used + DROP TABLE qiita.study_raw_data; + + -- The raw_data_id column now can be nullable + ALTER TABLE qiita.prep_template + ALTER COLUMN raw_data_id DROP NOT NULL; + """ + TRN.add(sql) + TRN.execute() diff --git a/qiita_db/support_files/patches/python_patches/6.py b/qiita_db/support_files/patches/python_patches/6.py index b8d39b1d7..e7497c9db 100644 --- a/qiita_db/support_files/patches/python_patches/6.py +++ b/qiita_db/support_files/patches/python_patches/6.py @@ -6,29 +6,28 @@ from time import strftime from qiita_db.util import get_mountpoint -from qiita_db.sql_connection import SQLConnectionHandler +from qiita_db.sql_connection import TRN from qiita_db.metadata_template import SampleTemplate, PrepTemplate -conn_handler = SQLConnectionHandler() +with TRN: + _id, fp_base = get_mountpoint('templates')[0] -_id, fp_base = get_mountpoint('templates')[0] + TRN.add("SELECT study_id FROM qiita.study") + for study_id in TRN.execute_fetchflatten(): + if SampleTemplate.exists(study_id): + st = SampleTemplate(study_id) + fp = join(fp_base, + '%d_%s.txt' % (study_id, strftime("%Y%m%d-%H%M%S"))) + st.to_file(fp) + st.add_filepath(fp) -for study_id in conn_handler.execute_fetchall( - "SELECT study_id FROM qiita.study"): - study_id = study_id[0] - if SampleTemplate.exists(study_id): - st = SampleTemplate(study_id) - fp = join(fp_base, '%d_%s.txt' % (study_id, strftime("%Y%m%d-%H%M%S"))) - st.to_file(fp) - st.add_filepath(fp) + TRN.add("SELECT prep_template_id FROM qiita.prep_template") + for prep_template_id in TRN.execute_fetchflatten(): + pt = 
PrepTemplate(prep_template_id) + study_id = pt.study_id -for prep_template_id in conn_handler.execute_fetchall( - "SELECT prep_template_id FROM qiita.prep_template"): - prep_template_id = prep_template_id[0] - pt = PrepTemplate(prep_template_id) - study_id = pt.study_id - - fp = join(fp_base, '%d_prep_%d_%s.txt' % (pt.study_id, prep_template_id, - strftime("%Y%m%d-%H%M%S"))) - pt.to_file(fp) - pt.add_filepath(fp) + fp = join(fp_base, + '%d_prep_%d_%s.txt' % (pt.study_id, prep_template_id, + strftime("%Y%m%d-%H%M%S"))) + pt.to_file(fp) + pt.add_filepath(fp) diff --git a/qiita_db/support_files/patches/python_patches/7.py b/qiita_db/support_files/patches/python_patches/7.py index cdaf9a4c1..7b82afed9 100644 --- a/qiita_db/support_files/patches/python_patches/7.py +++ b/qiita_db/support_files/patches/python_patches/7.py @@ -3,18 +3,16 @@ # prep templates from qiita_db.util import get_mountpoint -from qiita_db.sql_connection import SQLConnectionHandler +from qiita_db.sql_connection import TRN from qiita_db.metadata_template import PrepTemplate -conn_handler = SQLConnectionHandler() +with TRN: + _id, fp_base = get_mountpoint('templates')[0] -_id, fp_base = get_mountpoint('templates')[0] + TRN.add("SELECT prep_template_id FROM qiita.prep_template") + for prep_template_id in TRN.execute_fetchflatten(): + pt = PrepTemplate(prep_template_id) + study_id = pt.study_id -for prep_template_id in conn_handler.execute_fetchall( - "SELECT prep_template_id FROM qiita.prep_template"): - prep_template_id = prep_template_id[0] - pt = PrepTemplate(prep_template_id) - study_id = pt.study_id - - for _, fpt in pt.get_filepaths(): - pt.create_qiime_mapping_file(fpt) + for _, fpt in pt.get_filepaths(): + pt.create_qiime_mapping_file(fpt) diff --git a/qiita_db/test/test_analysis.py b/qiita_db/test/test_analysis.py index 953ecef04..575174506 100644 --- a/qiita_db/test/test_analysis.py +++ b/qiita_db/test/test_analysis.py @@ -61,11 +61,11 @@ def test_lock_check(self): "A New Analysis") new.status = status with self.assertRaises(QiitaDBStatusError): - new._lock_check(self.conn_handler) + new._lock_check() def test_lock_check_ok(self): self.analysis.status = "in_construction" - self.analysis._lock_check(self.conn_handler) + self.analysis._lock_check() def test_status_setter_checks(self): self.analysis.status = "public" diff --git a/qiita_db/test/test_commands.py b/qiita_db/test/test_commands.py index c6ef0c9d5..4d161ff07 100644 --- a/qiita_db/test/test_commands.py +++ b/qiita_db/test/test_commands.py @@ -736,11 +736,13 @@ def test_update_preprocessed_data_from_cmd_ppd(self): PY_PATCH = """ from qiita_db.study import Study +from qiita_db.sql_connection import TRN study = Study(1) -conn = SQLConnectionHandler() -conn.executemany( - "INSERT INTO qiita.patchtest10 (testing) VALUES (%s)", - [[study.id], [study.id*100]]) + +with TRN: + sql = "INSERT INTO qiita.patchtest10 (testing) VALUES (%s)" + TRN.add(sql, [[study.id], [study.id*100]], many=True) + TRN.execute() """ PARAMETERS = """max_bad_run_length\t3 diff --git a/qiita_db/test/test_sql_connection.py b/qiita_db/test/test_sql_connection.py index eedb0f08c..47e049b56 100644 --- a/qiita_db/test/test_sql_connection.py +++ b/qiita_db/test/test_sql_connection.py @@ -1,12 +1,16 @@ from unittest import TestCase, main +from os import remove, close +from os.path import exists +from tempfile import mkstemp from psycopg2._psycopg import connection from psycopg2.extras import DictCursor from psycopg2 import connect from psycopg2.extensions import (ISOLATION_LEVEL_AUTOCOMMIT, - 
ISOLATION_LEVEL_READ_COMMITTED) + ISOLATION_LEVEL_READ_COMMITTED, + TRANSACTION_STATUS_IDLE) -from qiita_db.sql_connection import SQLConnectionHandler +from qiita_db.sql_connection import SQLConnectionHandler, Transaction, TRN from qiita_core.util import qiita_test_checker from qiita_core.qiita_settings import qiita_config @@ -18,7 +22,7 @@ @qiita_test_checker() -class TestConnHandler(TestCase): +class TestBase(TestCase): def setUp(self): # Add the test table to the database, so we can use it in the tests with connect(user=qiita_config.user, password=qiita_config.password, @@ -26,6 +30,12 @@ def setUp(self): database=qiita_config.database) as con: with con.cursor() as cur: cur.execute(DB_TEST_TABLE) + self._files_to_remove = [] + + def tearDown(self): + for fp in self._files_to_remove: + if exists(fp): + remove(fp) def _populate_test_table(self): """Aux function that populates the test table""" @@ -53,6 +63,8 @@ def _assert_sql_equal(self, exp): self.assertEqual(obs, exp) + +class TestConnHandler(TestBase): def test_init(self): obs = SQLConnectionHandler() self.assertEqual(obs.admin, 'no_admin') @@ -193,173 +205,552 @@ def test_execute_fetchall_with_sql_args(self): obs = self.conn_handler.execute_fetchall(sql, (True,)) self.assertEqual(obs, [['test1', True, 1], ['test2', True, 2]]) - def test_check_queue_exists(self): - self.assertFalse(self.conn_handler._check_queue_exists('foo')) - self.conn_handler.create_queue('foo') - self.assertTrue(self.conn_handler._check_queue_exists('foo')) - - def test_create_queue(self): - self.assertEqual(self.conn_handler.queues, {}) - self.conn_handler.create_queue("toy_queue") - self.assertEqual(self.conn_handler.queues, {'toy_queue': []}) - - def test_create_queue_error(self): - self.conn_handler.create_queue("test_queue") - with self.assertRaises(KeyError): - self.conn_handler.create_queue("test_queue") - - def test_list_queues(self): - self.assertEqual(self.conn_handler.list_queues(), []) - self.conn_handler.create_queue("test_queue") - self.assertEqual(self.conn_handler.list_queues(), ["test_queue"]) - - def test_add_to_queue(self): - self.conn_handler.create_queue("test_queue") - - sql1 = "INSERT INTO qiita.test_table (bool_column) VALUES (%s)" - sql_args1 = (True,) - self.conn_handler.add_to_queue("test_queue", sql1, sql_args1) - self.assertEqual(self.conn_handler.queues, - {"test_queue": [(sql1, sql_args1)]}) - sql2 = "INSERT INTO qiita.test_table (int_column) VALUES (1)" - self.conn_handler.add_to_queue("test_queue", sql2) - self.assertEqual(self.conn_handler.queues, - {"test_queue": [(sql1, sql_args1), (sql2, None)]}) +class TestTransaction(TestBase): + def test_init(self): + obs = Transaction() + self.assertEqual(obs._queries, []) + self.assertEqual(obs._results, []) + self.assertEqual(obs.index, 0) + self.assertEqual(obs._connection, None) + self.assertEqual(obs._contexts_entered, 0) + with obs: + pass + self.assertTrue(isinstance(obs._connection, connection)) - def test_add_to_queue_many(self): - self.conn_handler.create_queue("test_queue") + def test_replace_placeholders(self): + with TRN: + TRN._results = [ + [["res1", 1]], [["res2a", 2], ["res2b", 3]], None, None, + [["res5", 5]]] + sql = "SELECT 42" + obs_sql, obs_args = TRN._replace_placeholders( + sql, ["{0:0:0}"]) + self.assertEqual(obs_sql, sql) + self.assertEqual(obs_args, ["res1"]) + + obs_sql, obs_args = TRN._replace_placeholders( + sql, ["{1:0:0}"]) + self.assertEqual(obs_sql, sql) + self.assertEqual(obs_args, ["res2a"]) + + obs_sql, obs_args = TRN._replace_placeholders( + sql, 
["{1:1:1}"]) + self.assertEqual(obs_sql, sql) + self.assertEqual(obs_args, [3]) + + obs_sql, obs_args = TRN._replace_placeholders( + sql, ["{4:0:0}"]) + self.assertEqual(obs_sql, sql) + self.assertEqual(obs_args, ["res5"]) + + obs_sql, obs_args = TRN._replace_placeholders( + sql, ["foo", "{0:0:1}", "bar", "{1:0:1}"]) + self.assertEqual(obs_sql, sql) + self.assertEqual(obs_args, ["foo", 1, "bar", 2]) + + def test_replace_placeholders_index_error(self): + with TRN: + TRN._results = [ + [["res1", 1]], [["res2a", 2], ["res2b", 2]]] + + error_regex = ('The placeholder {0:0:3} does not match to any ' + 'previous result') + with self.assertRaisesRegexp(ValueError, error_regex): + TRN._replace_placeholders("SELECT 42", ["{0:0:3}"]) + + error_regex = ('The placeholder {0:2:0} does not match to any ' + 'previous result') + with self.assertRaisesRegexp(ValueError, error_regex): + TRN._replace_placeholders("SELECT 42", ["{0:2:0}"]) + + error_regex = ('The placeholder {2:0:0} does not match to any ' + 'previous result') + with self.assertRaisesRegexp(ValueError, error_regex): + TRN._replace_placeholders("SELECT 42", ["{2:0:0}"]) + + def test_replace_placeholders_type_error(self): + with TRN: + TRN._results = [None] + + error_regex = ("The placeholder {0:0:0} is referring to a SQL " + "query that does not retrieve data") + with self.assertRaisesRegexp(ValueError, error_regex): + TRN._replace_placeholders("SELECT 42", ["{0:0:0}"]) + + def test_add(self): + with TRN: + self.assertEqual(TRN._queries, []) + + sql1 = "INSERT INTO qiita.test_table (bool_column) VALUES (%s)" + args1 = [True] + TRN.add(sql1, args1) + sql2 = "INSERT INTO qiita.test_table (int_column) VALUES (1)" + TRN.add(sql2) + + exp = [(sql1, args1), (sql2, [])] + self.assertEqual(TRN._queries, exp) + + # Remove queries so __exit__ doesn't try to execute it + TRN._queries = [] + + def test_add_many(self): + with TRN: + self.assertEqual(TRN._queries, []) + + sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" + args = [[1], [2], [3]] + TRN.add(sql, args, many=True) + + exp = [(sql, [1]), (sql, [2]), (sql, [3])] + self.assertEqual(TRN._queries, exp) + + def test_add_error(self): + with TRN: + + with self.assertRaises(TypeError): + TRN.add("SELECT 42", (1,)) + + with self.assertRaises(TypeError): + TRN.add("SELECT 42", {'foo': 'bar'}) + + with self.assertRaises(TypeError): + TRN.add("SELECT 42", [(1,), (1,)], many=True) + + def test_execute(self): + with TRN: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s)""" + TRN.add(sql, ["test_insert", 2]) + sql = """UPDATE qiita.test_table + SET int_column = %s, bool_column = %s + WHERE str_column = %s""" + TRN.add(sql, [20, False, "test_insert"]) + obs = TRN.execute() + self.assertEqual(obs, [None, None]) + self._assert_sql_equal([]) - sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" - sql_args = [(1,), (2,), (3,)] - self.conn_handler.add_to_queue("test_queue", sql, sql_args, many=True) - self.assertEqual(self.conn_handler.queues, - {"test_queue": [(sql, (1,)), (sql, (2,)), - (sql, (3,))]}) - - def test_add_to_queue_error(self): - with self.assertRaises(KeyError): - self.conn_handler.add_to_queue("foo", "SELECT 42") - - def test_execute_queue(self): - self.conn_handler.create_queue("test_queue") - sql = """INSERT INTO qiita.test_table (str_column, int_column) - VALUES (%s, %s)""" - self.conn_handler.add_to_queue("test_queue", sql, ['test_insert', '2']) - sql = """UPDATE qiita.test_table - SET int_column = 20, bool_column = FALSE - WHERE str_column = 
%s""" - self.conn_handler.add_to_queue("test_queue", sql, ['test_insert']) - obs = self.conn_handler.execute_queue("test_queue") - self.assertEqual(obs, []) self._assert_sql_equal([("test_insert", False, 20)]) - def test_execute_queue_many(self): - sql = """INSERT INTO qiita.test_table (str_column, int_column) - VALUES (%s, %s)""" - sql_args = [('insert1', 1), ('insert2', 2), ('insert3', 3)] - - self.conn_handler.create_queue("test_queue") - self.conn_handler.add_to_queue("test_queue", sql, sql_args, many=True) - sql = """UPDATE qiita.test_table - SET int_column = 20, bool_column = FALSE - WHERE str_column = %s""" - self.conn_handler.add_to_queue("test_queue", sql, ['insert2']) - obs = self.conn_handler.execute_queue('test_queue') - self.assertEqual(obs, []) - - self._assert_sql_equal([('insert1', True, 1), ('insert3', True, 3), + def test_execute_many(self): + with TRN: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s)""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + TRN.add(sql, args, many=True) + sql = """UPDATE qiita.test_table + SET int_column = %s, bool_column = %s + WHERE str_column = %s""" + TRN.add(sql, [20, False, 'insert2']) + obs = TRN.execute() + self.assertEqual(obs, [None, None, None, None]) + + self._assert_sql_equal([]) + + self._assert_sql_equal([('insert1', True, 1), + ('insert3', True, 3), ('insert2', False, 20)]) - def test_execute_queue_last_return(self): - self.conn_handler.create_queue("test_queue") - sql = """INSERT INTO qiita.test_table (str_column, int_column) - VALUES (%s, %s)""" - self.conn_handler.add_to_queue("test_queue", sql, ['test_insert', '2']) - sql = """UPDATE qiita.test_table SET bool_column = FALSE - WHERE str_column = %s RETURNING int_column""" - self.conn_handler.add_to_queue("test_queue", sql, ['test_insert']) - obs = self.conn_handler.execute_queue("test_queue") - self.assertEqual(obs, [2]) - - def test_execute_queue_placeholders(self): - self.conn_handler.create_queue("test_queue") - sql = """INSERT INTO qiita.test_table (int_column) VALUES (%s) - RETURNING str_column""" - self.conn_handler.add_to_queue("test_queue", sql, (2,)) - sql = """UPDATE qiita.test_table SET bool_column = FALSE - WHERE str_column = %s""" - self.conn_handler.add_to_queue("test_queue", sql, ('{0}',)) - obs = self.conn_handler.execute_queue("test_queue") - self.assertEqual(obs, []) - self._assert_sql_equal([('foo', False, 2)]) - - def test_execute_queue_placeholders_regex(self): - self.conn_handler.create_queue("test_queue") - sql = """INSERT INTO qiita.test_table (int_column) - VALUES (%s) RETURNING str_column""" - self.conn_handler.add_to_queue("test_queue", sql, (1,)) - sql = """UPDATE qiita.test_table SET str_column = %s - WHERE str_column = %s""" - self.conn_handler.add_to_queue("test_queue", sql, ("", "{0}")) - obs = self.conn_handler.execute_queue("test_queue") - self.assertEqual(obs, []) - self._assert_sql_equal([('', True, 1)]) - - def test_execute_queue_fail(self): - self.conn_handler.create_queue("test_queue") - sql = """INSERT INTO qiita.test_table (int_column) VALUES (%s)""" - self.conn_handler.add_to_queue("test_queue", sql, (2,)) - sql = """UPDATE qiita.test_table SET bool_column = False - WHERE str_column = %s""" - self.conn_handler.add_to_queue("test_queue", sql, ('{0}',)) + def test_execute_return(self): + with TRN: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + TRN.add(sql, ['test_insert', 2]) + sql = """UPDATE qiita.test_table SET 
bool_column = %s + WHERE str_column = %s RETURNING int_column""" + TRN.add(sql, [False, 'test_insert']) + obs = TRN.execute() + self.assertEqual(obs, [[['test_insert', 2]], [[2]]]) + + def test_execute_return_many(self): + with TRN: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + TRN.add(sql, args, many=True) + sql = """UPDATE qiita.test_table SET bool_column = %s + WHERE str_column = %s""" + TRN.add(sql, [False, 'insert2']) + sql = "SELECT * FROM qiita.test_table" + TRN.add(sql) + obs = TRN.execute() + exp = [[['insert1', 1]], # First query of the many query + [['insert2', 2]], # Second query of the many query + [['insert3', 3]], # Third query of the many query + None, # Update query + [['insert1', True, 1], # First result select + ['insert3', True, 3], # Second result select + ['insert2', False, 2]]] # Third result select + self.assertEqual(obs, exp) + + def test_execute_placeholders(self): + with TRN: + sql = """INSERT INTO qiita.test_table (int_column) VALUES (%s) + RETURNING str_column""" + TRN.add(sql, [2]) + sql = """UPDATE qiita.test_table SET str_column = %s + WHERE str_column = %s""" + TRN.add(sql, ["", "{0:0:0}"]) + obs = TRN.execute() + self.assertEqual(obs, [[['foo']], None]) + self._assert_sql_equal([]) + + self._assert_sql_equal([('', True, 2)]) + + def test_execute_error_bad_placeholder(self): + with TRN: + sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" + TRN.add(sql, [2]) + sql = """UPDATE qiita.test_table SET bool_column = %s + WHERE str_column = %s""" + TRN.add(sql, [False, "{0:0:0}"]) + + with self.assertRaises(ValueError): + TRN.execute() + + # make sure the transaction rolled back correctly + self._assert_sql_equal([]) + + def test_execute_error_no_result_placeholder(self): + with TRN: + sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" + TRN.add(sql, [[1], [2], [3]], many=True) + sql = """SELECT str_column FROM qiita.test_table + WHERE int_column = %s""" + TRN.add(sql, [4]) + sql = """UPDATE qiita.test_table SET bool_column = %s + WHERE str_column = %s""" + TRN.add(sql, [False, "{3:0:0}"]) + + with self.assertRaises(ValueError): + TRN.execute() + + # make sure the transaction rolled back correctly + self._assert_sql_equal([]) + + def test_execute_huge_transaction(self): + with TRN: + # Add a lot of inserts to the transaction + sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" + for i in range(1000): + TRN.add(sql, [i]) + # Add some updates to the transaction + sql = """UPDATE qiita.test_table SET bool_column = %s + WHERE int_column = %s""" + for i in range(500): + TRN.add(sql, [False, i]) + # Make the transaction fail with the last insert + sql = """INSERT INTO qiita.table_to_make (the_trans_to_fail) + VALUES (1)""" + TRN.add(sql) + + with self.assertRaises(ValueError): + TRN.execute() + + # make sure the transaction rolled back correctly + self._assert_sql_equal([]) + + def test_execute_commit_false(self): + with TRN: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + TRN.add(sql, args, many=True) + + obs = TRN.execute() + exp = [[['insert1', 1]], [['insert2', 2]], [['insert3', 3]]] + self.assertEqual(obs, exp) + + self._assert_sql_equal([]) + + TRN.commit() + + self._assert_sql_equal([('insert1', True, 1), ('insert2', True, 2), + ('insert3', True, 3)]) + + def test_execute_commit_false_rollback(self): + with TRN: + sql = 
"""INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + TRN.add(sql, args, many=True) + + obs = TRN.execute() + exp = [[['insert1', 1]], [['insert2', 2]], [['insert3', 3]]] + self.assertEqual(obs, exp) + + self._assert_sql_equal([]) + + TRN.rollback() + + self._assert_sql_equal([]) + + def test_execute_commit_false_wipe_queries(self): + with TRN: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + TRN.add(sql, args, many=True) + + obs = TRN.execute() + exp = [[['insert1', 1]], [['insert2', 2]], [['insert3', 3]]] + self.assertEqual(obs, exp) + + self._assert_sql_equal([]) + + sql = """UPDATE qiita.test_table SET bool_column = %s + WHERE str_column = %s""" + args = [False, 'insert2'] + TRN.add(sql, args) + self.assertEqual(TRN._queries, [(sql, args)]) + + TRN.execute() + self._assert_sql_equal([]) - with self.assertRaises(ValueError): - self.conn_handler.execute_queue("test_queue") - - # make sure rollback correctly - self._assert_sql_equal([]) - - def test_execute_queue_error(self): - self.conn_handler.create_queue("test_queue") - sql = """INSERT INTO qiita.test_table (str_column, int_column) - VALUES (%s, %s)""" - self.conn_handler.add_to_queue("test_queue", sql, ['test_insert', '2']) - sql = """UPDATE qiita.test_table - SET int_column = 20, bool_column = FALSE - WHERE str_column = %s""" - self.conn_handler.add_to_queue("test_queue", sql, ['test_insert']) - with self.assertRaises(KeyError): - self.conn_handler.execute_queue("oops!") - - def test_huge_queue(self): - self.conn_handler.create_queue("test_queue") - # Add a lof of inserts to the queue - sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" - for x in range(1000): - self.conn_handler.add_to_queue("test_queue", sql, (x,)) - - # Make the queue fail with the last insert - sql = "INSERT INTO qiita.table_to_make (the_queue_to_fail) VALUES (1)" - self.conn_handler.add_to_queue("test_queue", sql) - - with self.assertRaises(ValueError): - self.conn_handler.execute_queue("test_queue") - - # make sure rollback correctly + self._assert_sql_equal([('insert1', True, 1), ('insert3', True, 3), + ('insert2', False, 2)]) + + def test_execute_fetchlast(self): + with TRN: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + TRN.add(sql, args, many=True) + + sql = """SELECT EXISTS( + SELECT * FROM qiita.test_table WHERE int_column=%s)""" + TRN.add(sql, [2]) + self.assertTrue(TRN.execute_fetchlast()) + + def test_execute_fetchindex(self): + with TRN: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + TRN.add(sql, args, many=True) + self.assertEqual(TRN.execute_fetchindex(), [['insert3', 3]]) + + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert4', 4], ['insert5', 5], ['insert6', 6]] + TRN.add(sql, args, many=True) + self.assertEqual(TRN.execute_fetchindex(3), [['insert4', 4]]) + + def test_execute_fetchflatten(self): + with TRN: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s)""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] 
+ TRN.add(sql, args, many=True) + + sql = "SELECT str_column, int_column FROM qiita.test_table" + TRN.add(sql) + + sql = "SELECT int_column FROM qiita.test_table" + TRN.add(sql) + obs = TRN.execute_fetchflatten() + self.assertEqual(obs, [1, 2, 3]) + + sql = "SELECT 42" + TRN.add(sql) + obs = TRN.execute_fetchflatten(idx=3) + self.assertEqual(obs, ['insert1', 1, 'insert2', 2, 'insert3', 3]) + + def test_context_manager_rollback(self): + try: + with TRN: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + TRN.add(sql, args, many=True) + + TRN.execute() + raise ValueError("Force exiting the context manager") + except ValueError: + pass self._assert_sql_equal([]) + self.assertEqual( + TRN._connection.get_transaction_status(), + TRANSACTION_STATUS_IDLE) + + def test_context_manager_execute(self): + with TRN: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + TRN.add(sql, args, many=True) + self._assert_sql_equal([]) + + self._assert_sql_equal([('insert1', True, 1), ('insert2', True, 2), + ('insert3', True, 3)]) + self.assertEqual( + TRN._connection.get_transaction_status(), + TRANSACTION_STATUS_IDLE) + + def test_context_manager_no_commit(self): + with TRN: + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + TRN.add(sql, args, many=True) + + TRN.execute() + self._assert_sql_equal([]) + + self._assert_sql_equal([('insert1', True, 1), ('insert2', True, 2), + ('insert3', True, 3)]) + self.assertEqual( + TRN._connection.get_transaction_status(), + TRANSACTION_STATUS_IDLE) + + def test_context_manager_multiple(self): + self.assertEqual(TRN._contexts_entered, 0) + + with TRN: + self.assertEqual(TRN._contexts_entered, 1) + + TRN.add("SELECT 42") + with TRN: + self.assertEqual(TRN._contexts_entered, 2) + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + TRN.add(sql, args, many=True) + + # We exited the second context, nothing should have been executed + self.assertEqual(TRN._contexts_entered, 1) + self.assertEqual( + TRN._connection.get_transaction_status(), + TRANSACTION_STATUS_IDLE) + self._assert_sql_equal([]) + + # We have exited the first context, everything should have been + # executed and committed + self.assertEqual(TRN._contexts_entered, 0) + self._assert_sql_equal([('insert1', True, 1), ('insert2', True, 2), + ('insert3', True, 3)]) + self.assertEqual( + TRN._connection.get_transaction_status(), + TRANSACTION_STATUS_IDLE) + + def test_context_manager_multiple_2(self): + self.assertEqual(TRN._contexts_entered, 0) + + def tester(): + self.assertEqual(TRN._contexts_entered, 1) + with TRN: + self.assertEqual(TRN._contexts_entered, 2) + sql = """SELECT EXISTS( + SELECT * FROM qiita.test_table WHERE int_column=%s)""" + TRN.add(sql, [2]) + self.assertTrue(TRN.execute_fetchlast()) + self.assertEqual(TRN._contexts_entered, 1) + + with TRN: + self.assertEqual(TRN._contexts_entered, 1) + sql = """INSERT INTO qiita.test_table (str_column, int_column) + VALUES (%s, %s) RETURNING str_column, int_column""" + args = [['insert1', 1], ['insert2', 2], ['insert3', 3]] + TRN.add(sql, args, many=True) + tester() + 
self.assertEqual(TRN._contexts_entered, 1) + self._assert_sql_equal([]) + + self.assertEqual(TRN._contexts_entered, 0) + self._assert_sql_equal([('insert1', True, 1), ('insert2', True, 2), + ('insert3', True, 3)]) + self.assertEqual( + TRN._connection.get_transaction_status(), + TRANSACTION_STATUS_IDLE) + + def test_post_commit_funcs(self): + fd, fp = mkstemp() + close(fd) + self._files_to_remove.append(fp) + + def func(fp): + with open(fp, 'w') as f: + f.write('\n') + + with TRN: + TRN.add("SELECT 42") + TRN.add_post_commit_func(func, fp) + + self.assertTrue(exists(fp)) + + def test_post_commit_funcs_error(self): + def func(): + raise ValueError() + + with self.assertRaises(RuntimeError): + with TRN: + TRN.add("SELECT 42") + TRN.add_post_commit_func(func) + + def test_post_rollback_funcs(self): + fd, fp = mkstemp() + close(fd) + self._files_to_remove.append(fp) + + def func(fp): + with open(fp, 'w') as f: + f.write('\n') + + with TRN: + TRN.add("SELECT 42") + TRN.add_post_rollback_func(func, fp) + TRN.rollback() + + self.assertTrue(exists(fp)) + + def test_post_rollback_funcs_error(self): + def func(): + raise ValueError() + + with self.assertRaises(RuntimeError): + with TRN: + TRN.add("SELECT 42") + TRN.add_post_rollback_func(func) + TRN.rollback() + + def test_context_manager_checker(self): + with self.assertRaises(RuntimeError): + TRN.add("SELECT 42") + + with self.assertRaises(RuntimeError): + TRN.execute() + + with self.assertRaises(RuntimeError): + TRN.commit() + + with self.assertRaises(RuntimeError): + TRN.rollback() + + with TRN: + TRN.add("SELECT 42") + + with self.assertRaises(RuntimeError): + TRN.execute() + + def test_index(self): + with TRN: + self.assertEqual(TRN.index, 0) + + TRN.add("SELECT 42") + self.assertEqual(TRN.index, 1) + + sql = "INSERT INTO qiita.test_table (int_column) VALUES (%s)" + args = [[1], [2], [3]] + TRN.add(sql, args, many=True) + self.assertEqual(TRN.index, 4) + + TRN.execute() + self.assertEqual(TRN.index, 4) + + TRN.add(sql, args, many=True) + self.assertEqual(TRN.index, 7) - def test_get_temp_queue(self): - my_queue = self.conn_handler.get_temp_queue() - self.assertTrue(my_queue in self.conn_handler.list_queues()) - - self.conn_handler.add_to_queue(my_queue, - "SELECT * from qiita.qiita_user") - self.conn_handler.add_to_queue(my_queue, - "SELECT * from qiita.user_level") - self.conn_handler.execute_queue(my_queue) - - self.assertTrue(my_queue not in self.conn_handler.list_queues()) + self.assertEqual(TRN.index, 0) if __name__ == "__main__": main() diff --git a/qiita_db/test/test_study.py b/qiita_db/test/test_study.py index 4f9f5a9ab..6cf9c448a 100644 --- a/qiita_db/test/test_study.py +++ b/qiita_db/test/test_study.py @@ -13,7 +13,7 @@ from qiita_db.util import convert_to_id from qiita_db.exceptions import ( QiitaDBColumnError, QiitaDBStatusError, QiitaDBError, - QiitaDBUnknownIDError) + QiitaDBUnknownIDError, QiitaDBDuplicateError) # ----------------------------------------------------------------------------- # Copyright (c) 2014--, The Qiita Development Team. 
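# The transaction tests above exercise TRN's "{index:row:column}" placeholder
# syntax, which lets a later query in the same transaction reuse a value
# returned by an earlier one. A minimal sketch of the idea, assuming a fresh
# transaction (so the INSERT is query 0) and the semantics shown in
# test_replace_placeholders and test_execute_placeholders:

from qiita_db.sql_connection import TRN

with TRN:
    TRN.add("""INSERT INTO qiita.test_table (int_column) VALUES (2)
               RETURNING int_column""")
    # "{0:0:0}" names query 0, row 0, column 0 of the results gathered so
    # far, so the UPDATE reuses the int_column value the INSERT returned
    TRN.add("""UPDATE qiita.test_table SET bool_column = %s
               WHERE int_column = %s""", [False, "{0:0:0}"])
    TRN.execute()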
@@ -46,12 +46,17 @@ def test_iter(self): ('empDude', 'emp_dude@foo.bar', 'broad', None, '444-222-3333'), ('PIDude', 'PI_dude@foo.bar', 'Wash U', '123 PI street', None)] for i, person in enumerate(StudyPerson.iter()): - self.assertTrue(person.id == i+1) - self.assertTrue(person.name == expected[i][0]) - self.assertTrue(person.email == expected[i][1]) - self.assertTrue(person.affiliation == expected[i][2]) - self.assertTrue(person.address == expected[i][3]) - self.assertTrue(person.phone == expected[i][4]) + self.assertEqual(person.id, i+1) + self.assertEqual(person.name, expected[i][0]) + self.assertEqual(person.email, expected[i][1]) + self.assertEqual(person.affiliation, expected[i][2]) + self.assertEqual(person.address, expected[i][3]) + self.assertEqual(person.phone, expected[i][4]) + + def test_exists(self): + self.assertTrue(StudyPerson.exists('LabDude', 'knight lab')) + self.assertFalse(StudyPerson.exists('AnotherDude', 'knight lab')) + self.assertFalse(StudyPerson.exists('LabDude', 'Another lab')) def test_create_studyperson_already_exists(self): obs = StudyPerson.create('LabDude', 'lab_dude@foo.bar', 'knight lab') @@ -327,6 +332,13 @@ def test_exists(self): 'Cannabis Soils')) self.assertFalse(Study.exists('Not Cannabis Soils')) + def test_create_duplicate(self): + with self.assertRaises(QiitaDBDuplicateError): + Study.create( + User('test@foo.bar'), + 'Identification of the Microbiomes for Cannabis Soils', + [1], self.info) + def test_create_study_min_data(self): """Insert a study into the database""" before = datetime.now() diff --git a/qiita_db/test/test_util.py b/qiita_db/test/test_util.py index 06e3b1054..2e4a5a49e 100644 --- a/qiita_db/test/test_util.py +++ b/qiita_db/test/test_util.py @@ -37,7 +37,8 @@ filepath_id_to_rel_path, filepath_ids_to_rel_paths, move_filepaths_to_upload_folder, move_upload_files_to_trash, - check_access_to_analysis_result, infer_status) + check_access_to_analysis_result, infer_status, + get_preprocessed_params_tables) @qiita_test_checker() @@ -64,13 +65,12 @@ def test_params_dict_to_json(self): def test_check_required_columns(self): # Doesn't do anything if correct info passed, only errors if wrong info - check_required_columns(self.conn_handler, self.required, self.table) + check_required_columns(self.required, self.table) def test_check_required_columns_fail(self): self.required.remove('study_title') with self.assertRaises(QiitaDBColumnError): - check_required_columns(self.conn_handler, self.required, - self.table) + check_required_columns(self.required, self.table) def test_get_lat_longs(self): exp = [ @@ -107,13 +107,12 @@ def test_get_lat_longs(self): def test_check_table_cols(self): # Doesn't do anything if correct info passed, only errors if wrong info - check_table_cols(self.conn_handler, self.required, self.table) + check_table_cols(self.required, self.table) def test_check_table_cols_fail(self): self.required.append('BADTHINGNOINHERE') with self.assertRaises(QiitaDBColumnError): - check_table_cols(self.conn_handler, self.required, - self.table) + check_table_cols(self.required, self.table) def test_get_table_cols(self): obs = get_table_cols("qiita_user") @@ -140,43 +139,36 @@ def test_get_table_cols_w_type(self): def test_exists_table(self): """Correctly checks if a table exists""" # True cases - self.assertTrue(exists_table("filepath", self.conn_handler)) - self.assertTrue(exists_table("qiita_user", self.conn_handler)) - self.assertTrue(exists_table("analysis", self.conn_handler)) - self.assertTrue(exists_table("prep_1", self.conn_handler)) 
- self.assertTrue(exists_table("sample_1", self.conn_handler)) + self.assertTrue(exists_table("filepath")) + self.assertTrue(exists_table("qiita_user")) + self.assertTrue(exists_table("analysis")) + self.assertTrue(exists_table("prep_1")) + self.assertTrue(exists_table("sample_1")) # False cases - self.assertFalse(exists_table("sample_2", self.conn_handler)) - self.assertFalse(exists_table("prep_2", self.conn_handler)) - self.assertFalse(exists_table("foo_table", self.conn_handler)) - self.assertFalse(exists_table("bar_table", self.conn_handler)) + self.assertFalse(exists_table("sample_2")) + self.assertFalse(exists_table("prep_2")) + self.assertFalse(exists_table("foo_table")) + self.assertFalse(exists_table("bar_table")) def test_exists_dynamic_table(self): """Correctly checks if a dynamic table exists""" # True cases self.assertTrue(exists_dynamic_table( "preprocessed_sequence_illumina_params", "preprocessed_", - "_params", self.conn_handler)) - self.assertTrue(exists_dynamic_table("prep_1", "prep_", "", - self.conn_handler)) - self.assertTrue(exists_dynamic_table("filepath", "", "", - self.conn_handler)) + "_params")) + self.assertTrue(exists_dynamic_table("prep_1", "prep_", "")) + self.assertTrue(exists_dynamic_table("filepath", "", "")) # False cases self.assertFalse(exists_dynamic_table( - "preprocessed_foo_params", "preprocessed_", "_params", - self.conn_handler)) + "preprocessed_foo_params", "preprocessed_", "_params")) self.assertFalse(exists_dynamic_table( - "preprocessed__params", "preprocessed_", "_params", - self.conn_handler)) + "preprocessed__params", "preprocessed_", "_params")) self.assertFalse(exists_dynamic_table( - "foo_params", "preprocessed_", "_params", - self.conn_handler)) + "foo_params", "preprocessed_", "_params")) self.assertFalse(exists_dynamic_table( - "preprocessed_foo", "preprocessed_", "_params", - self.conn_handler)) + "preprocessed_foo", "preprocessed_", "_params")) self.assertFalse(exists_dynamic_table( - "foo", "preprocessed_", "_params", - self.conn_handler)) + "foo", "preprocessed_", "_params")) def test_convert_to_id(self): """Tests that ids are returned correctly""" @@ -250,6 +242,13 @@ def test_check_count(self): self.assertTrue(check_count('qiita.study_person', 3)) self.assertFalse(check_count('qiita.study_person', 2)) + def test_get_preprocessed_params_tables(self): + obs = get_preprocessed_params_tables() + exp = ['preprocessed_sequence_454_params', + 'preprocessed_sequence_illumina_params', + 'preprocessed_spectra_params'] + self.assertEqual(obs, exp) + def test_get_processed_params_tables(self): obs = get_processed_params_tables() self.assertEqual(obs, ['processed_params_sortmerna', @@ -264,8 +263,7 @@ def test_insert_filepaths(self): exp_new_id = 1 + self.conn_handler.execute_fetchone( "SELECT count(1) FROM qiita.filepath")[0] - obs = insert_filepaths([(fp, 1)], 1, "raw_data", "filepath", - self.conn_handler) + obs = insert_filepaths([(fp, 1)], 1, "raw_data", "filepath") self.assertEqual(obs, [exp_new_id]) # Check that the files have been copied correctly @@ -291,7 +289,7 @@ def test_insert_filepaths_string(self): exp_new_id = 1 + self.conn_handler.execute_fetchone( "SELECT count(1) FROM qiita.filepath")[0] obs = insert_filepaths([(fp, "raw_forward_seqs")], 1, "raw_data", - "filepath", self.conn_handler) + "filepath") self.assertEqual(obs, [exp_new_id]) # Check that the files have been copied correctly @@ -307,54 +305,6 @@ def test_insert_filepaths_string(self): exp = [[exp_new_id, exp_fp, 1, '852952723', 1, 5]] self.assertEqual(obs, exp) - 
def test_insert_filepaths_queue(self): - fd, fp = mkstemp() - close(fd) - with open(fp, "w") as f: - f.write("\n") - self.files_to_remove.append(fp) - - # create and populate queue - self.conn_handler.create_queue("toy_queue") - self.conn_handler.add_to_queue( - "toy_queue", "INSERT INTO qiita.qiita_user (email, name, password," - "phone) VALUES (%s, %s, %s, %s)", - ['insert@foo.bar', 'Toy', 'pass', '111-111-1111']) - - exp_new_id = 1 + self.conn_handler.execute_fetchone( - "SELECT count(1) FROM qiita.filepath")[0] - insert_filepaths([(fp, "raw_forward_seqs")], 1, "raw_data", - "filepath", self.conn_handler, queue='toy_queue') - - self.conn_handler.add_to_queue( - "toy_queue", "INSERT INTO qiita.raw_filepath (raw_data_id, " - "filepath_id) VALUES (1, %s)", ['{0}']) - self.conn_handler.execute_queue("toy_queue") - - # check that the user was added to the DB - obs = self.conn_handler.execute_fetchall( - "SELECT * from qiita.qiita_user WHERE email = %s", - ['insert@foo.bar']) - exp = [['insert@foo.bar', 5, 'pass', 'Toy', None, None, '111-111-1111', - None, None, None]] - self.assertEqual(obs, exp) - - # Check that the filepaths have been added to the DB - obs = self.conn_handler.execute_fetchall( - "SELECT * FROM qiita.filepath WHERE filepath_id=%d" % exp_new_id) - exp_fp = "1_%s" % basename(fp) - exp = [[exp_new_id, exp_fp, 1, '852952723', 1, 5]] - self.assertEqual(obs, exp) - - # check that raw_filpath data was added to the DB - obs = self.conn_handler.execute_fetchall( - """SELECT * - FROM qiita.raw_filepath - WHERE filepath_id=%d""" % exp_new_id) - exp_fp = "1_%s" % basename(fp) - exp = [[1, exp_new_id]] - self.assertEqual(obs, exp) - def _common_purge_filpeaths_test(self): # Get all the filepaths so we can test if they've been removed or not sql_fp = "SELECT filepath, data_directory_id FROM qiita.filepath" diff --git a/qiita_db/user.py b/qiita_db/user.py index 6e9b59830..da67d7846 100644 --- a/qiita_db/user.py +++ b/qiita_db/user.py @@ -34,7 +34,7 @@ IncompetentQiitaDeveloperError) from qiita_core.qiita_settings import qiita_config from .base import QiitaObject -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN from .util import (create_rand_string, check_table_cols, hash_password, convert_to_id) from .exceptions import (QiitaDBColumnError, QiitaDBDuplicateError, @@ -81,13 +81,11 @@ def _check_id(self, id_): This function overwrites the base function, as sql layout doesn't follow the same conventions done in the other classes. 
""" - self._check_subclass() - - conn_handler = SQLConnectionHandler() - - return conn_handler.execute_fetchone( - "SELECT EXISTS(SELECT * FROM qiita.qiita_user WHERE " - "email = %s)", (id_, ))[0] + with TRN: + sql = """SELECT EXISTS( + SELECT * FROM qiita.qiita_user WHERE email = %s)""" + TRN.add(sql, [id_]) + return TRN.execute_fetchlast() @classmethod def iter(cls): @@ -99,11 +97,12 @@ def iter(cls): Yields a user ID (email) for each user in the database, in order of ascending ID """ - conn_handler = SQLConnectionHandler() - sql = """select email from qiita.{}""".format(cls._table) - - for result in conn_handler.execute_fetchall(sql): - yield result[0] + with TRN: + sql = """select email from qiita.{}""".format(cls._table) + TRN.add(sql) + # Using [-1] to get the results of the last SQL query + for result in TRN.execute_fetchindex(): + yield result[0] @classmethod def login(cls, email, password): @@ -129,31 +128,33 @@ def login(cls, email, password): IncorrectPasswordError Password passed is not correct for user """ - # see if user exists - if not cls.exists(email): - raise IncorrectEmailError("Email not valid: %s" % email) - - if not validate_password(password): - raise IncorrectPasswordError("Password not valid!") - - # pull password out of database - conn_handler = SQLConnectionHandler() - sql = ("SELECT password, user_level_id FROM qiita.{0} WHERE " - "email = %s".format(cls._table)) - info = conn_handler.execute_fetchone(sql, (email, )) - - # verify user email verification - # MAGIC NUMBER 5 = unverified email - if int(info[1]) == 5: - return False - - # verify password - dbpass = info[0] - hashed = hash_password(password, dbpass) - if hashed == dbpass: - return cls(email) - else: - raise IncorrectPasswordError("Password not valid!") + with TRN: + # see if user exists + if not cls.exists(email): + raise IncorrectEmailError("Email not valid: %s" % email) + + if not validate_password(password): + raise IncorrectPasswordError("Password not valid!") + + # pull password out of database + sql = ("SELECT password, user_level_id FROM qiita.{0} WHERE " + "email = %s".format(cls._table)) + TRN.add(sql, [email]) + # Using [0] because there is only one row + info = TRN.execute_fetchindex()[0] + + # verify user email verification + # MAGIC NUMBER 5 = unverified email + if int(info[1]) == 5: + return False + + # verify password + dbpass = info[0] + hashed = hash_password(password, dbpass) + if hashed == dbpass: + return cls(email) + else: + raise IncorrectPasswordError("Password not valid!") @classmethod def exists(cls, email): @@ -164,13 +165,15 @@ def exists(cls, email): email : str the email of the user """ - if not validate_email(email): - raise IncorrectEmailError("Email string not valid: %s" % email) - conn_handler = SQLConnectionHandler() + with TRN: + if not validate_email(email): + raise IncorrectEmailError("Email string not valid: %s" % email) - return conn_handler.execute_fetchone( - "SELECT EXISTS(SELECT * FROM qiita.{0} WHERE " - "email = %s)".format(cls._table), (email, ))[0] + sql = """SELECT EXISTS( + SELECT * FROM qiita.{0} + WHERE email = %s)""".format(cls._table) + TRN.add(sql, [email]) + return TRN.execute_fetchlast() @classmethod def create(cls, email, password, info=None): @@ -194,48 +197,44 @@ def create(cls, email, password, info=None): QiitaDBDuplicateError User already exists """ - # validate email and password for new user - if not validate_email(email): - raise IncorrectEmailError("Bad email given: %s" % email) - if not validate_password(password): - raise 
IncorrectPasswordError("Bad password given!") - - # make sure user does not already exist - if cls.exists(email): - raise QiitaDBDuplicateError("User", "email: %s" % email) - - # make sure non-info columns aren't passed in info dict - if info: - if cls._non_info.intersection(info): - raise QiitaDBColumnError("non info keys passed: %s" % - cls._non_info.intersection(info)) - else: - info = {} - - # create email verification code and hashed password to insert - # add values to info - info["email"] = email - info["password"] = hash_password(password) - info["user_verify_code"] = create_rand_string(20, punct=False) - - # make sure keys in info correspond to columns in table - conn_handler = SQLConnectionHandler() - check_table_cols(conn_handler, info, cls._table) - - # build info to insert making sure columns and data are in same order - # for sql insertion - columns = info.keys() - values = [info[col] for col in columns] - queue = "add_user_%s" % email - conn_handler.create_queue(queue) - # crete user - sql = "INSERT INTO qiita.{0} ({1}) VALUES ({2})".format( - cls._table, ','.join(columns), ','.join(['%s'] * len(values))) - conn_handler.add_to_queue(queue, sql, values) - - conn_handler.execute_queue(queue) - - return cls(email) + with TRN: + # validate email and password for new user + if not validate_email(email): + raise IncorrectEmailError("Bad email given: %s" % email) + if not validate_password(password): + raise IncorrectPasswordError("Bad password given!") + + # make sure user does not already exist + if cls.exists(email): + raise QiitaDBDuplicateError("User", "email: %s" % email) + + # make sure non-info columns aren't passed in info dict + if info: + if cls._non_info.intersection(info): + raise QiitaDBColumnError("non info keys passed: %s" % + cls._non_info.intersection(info)) + else: + info = {} + + # create email verification code and hashed password to insert + # add values to info + info["email"] = email + info["password"] = hash_password(password) + info["user_verify_code"] = create_rand_string(20, punct=False) + + # make sure keys in info correspond to columns in table + check_table_cols(info, cls._table) + + # build info to insert making sure columns and data are in + # same order for sql insertion + columns = info.keys() + values = [info[col] for col in columns] + # crete user + sql = "INSERT INTO qiita.{0} ({1}) VALUES ({2})".format( + cls._table, ','.join(columns), ','.join(['%s'] * len(values))) + TRN.add(sql, values) + + return cls(email) @classmethod def verify_code(cls, email, code, code_type): @@ -261,55 +260,59 @@ def verify_code(cls, email, code, code_type): QiitaDBError User has no code of the given type """ - if code_type == 'create': - column = 'user_verify_code' - elif code_type == 'reset': - column = 'pass_reset_code' - else: - raise IncompetentQiitaDeveloperError( - "code_type must be 'create' or 'reset' Uknown type %s" % - code_type) - sql = ("SELECT {1} from qiita.{0} where email" - " = %s".format(cls._table, column)) - conn_handler = SQLConnectionHandler() - db_code = conn_handler.execute_fetchone(sql, (email,)) - - # If the query didn't return anything, then there's no way the code - # can match - if db_code is None: - return False - - db_code = db_code[0] - if not db_code: - raise QiitaDBError("No %s code for user %s" % - (column, email)) - - correct_code = db_code == code - - if correct_code and code_type == "create": - # verify the user - level = convert_to_id('user', 'user_level', 'name') - sql = """UPDATE qiita.{} SET user_level_id = %s - WHERE email 
= %s""".format(cls._table) - conn_handler.execute(sql, (level, email)) - # create user default sample holders once verified - # create one per portal - analysis_sql = """INSERT INTO qiita.analysis - (email, name, description, dflt, analysis_status_id) - VALUES (%s, %s, %s, %s, 1)""" - sql = "SELECT portal_type_id from qiita.portal_type" - analysis_args = [ - (email, '%s-dflt-%d' % (email, portal[0]), 'dflt', True) - for portal in conn_handler.execute_fetchall(sql)] - conn_handler.executemany(analysis_sql, analysis_args) - - if correct_code: - # wipe out code so we know it was used - sql = """UPDATE qiita.{0} SET {1} = '' - WHERE email = %s""".format(cls._table, column) - conn_handler.execute(sql, [email]) - - return correct_code + with TRN: + if code_type == 'create': + column = 'user_verify_code' + elif code_type == 'reset': + column = 'pass_reset_code' + else: + raise IncompetentQiitaDeveloperError( + "code_type must be 'create' or 'reset' Uknown type %s" + % code_type) + sql = "SELECT {0} FROM qiita.{1} WHERE email = %s".format( + column, cls._table) + TRN.add(sql, [email]) + db_code = TRN.execute_fetchindex() + + if not db_code: + return False + + db_code = db_code[0][0] + if db_code is None: + raise QiitaDBError("No %s code for user %s" % + (column, email)) + + correct_code = db_code == code + + if correct_code: + sql = """UPDATE qiita.{0} SET {1} = NULL + WHERE email = %s""".format(cls._table, column) + TRN.add(sql, [email]) + + if code_type == "create": + # verify the user + level = convert_to_id('user', 'user_level', 'name') + sql = """UPDATE qiita.{} SET user_level_id = %s + WHERE email = %s""".format(cls._table) + TRN.add(sql, [level, email]) + + # create user default sample holders once verified + # create one per portal + sql = "SELECT portal_type_id from qiita.portal_type" + TRN.add(sql) + + an_sql = """INSERT INTO qiita.analysis + (email, name, description, dflt, + analysis_status_id) + VALUES (%s, %s, %s, %s, 1)""" + an_args = [ + [email, '%s-dflt-%d' % (email, portal), 'dflt', True] + for portal in TRN.execute_fetchflatten()] + TRN.add(an_sql, an_args, many=True) + + TRN.execute() + + return correct_code # ---properties--- @property @@ -320,23 +323,29 @@ def email(self): @property def level(self): """The level of privileges of the user""" - conn_handler = SQLConnectionHandler() - sql = ("SELECT ul.name from qiita.user_level ul JOIN qiita.{0} u ON " - "ul.user_level_id = u.user_level_id WHERE " - "u.email = %s".format(self._table)) - return conn_handler.execute_fetchone(sql, (self._id, ))[0] + with TRN: + sql = """SELECT ul.name + FROM qiita.user_level ul + JOIN qiita.{0} u + ON ul.user_level_id = u.user_level_id + WHERE u.email = %s""".format(self._table) + TRN.add(sql, [self._id]) + return TRN.execute_fetchlast() @property def info(self): """Dict with any other information attached to the user""" - conn_handler = SQLConnectionHandler() - sql = "SELECT * from qiita.{0} WHERE email = %s".format(self._table) - # Need direct typecast from psycopg2 dict to standard dict - info = dict(conn_handler.execute_fetchone(sql, (self._id, ))) - # Remove non-info columns - for col in self._non_info: - info.pop(col) - return info + with TRN: + sql = "SELECT * from qiita.{0} WHERE email = %s".format( + self._table) + # Need direct typecast from psycopg2 dict to standard dict + TRN.add(sql, [self._id]) + # [0] retrieves the first row (the only one present) + info = dict(TRN.execute_fetchindex()[0]) + # Remove non-info columns + for col in self._non_info: + info.pop(col) + return info 
@info.setter def info(self, info): @@ -346,84 +355,84 @@ def info(self, info): ---------- info : dict """ - # make sure non-info columns aren't passed in info dict - if self._non_info.intersection(info): - raise QiitaDBColumnError("non info keys passed!") - - # make sure keys in info correspond to columns in table - conn_handler = SQLConnectionHandler() - check_table_cols(conn_handler, info, self._table) - - # build sql command and data to update - sql_insert = [] - data = [] - # items used for py3 compatability - for key, val in info.items(): - sql_insert.append("{0} = %s".format(key)) - data.append(val) - data.append(self._id) - - sql = ("UPDATE qiita.{0} SET {1} WHERE " - "email = %s".format(self._table, ','.join(sql_insert))) - conn_handler.execute(sql, data) + with TRN: + # make sure non-info columns aren't passed in info dict + if self._non_info.intersection(info): + raise QiitaDBColumnError("non info keys passed!") + + # make sure keys in info correspond to columns in table + check_table_cols(info, self._table) + + # build sql command and data to update + sql_insert = [] + data = [] + # items used for py3 compatibility + for key, val in info.items(): + sql_insert.append("{0} = %s".format(key)) + data.append(val) + data.append(self._id) + + sql = ("UPDATE qiita.{0} SET {1} WHERE " + "email = %s".format(self._table, ','.join(sql_insert))) + TRN.add(sql, data) + TRN.execute() @property def default_analysis(self): - sql = """SELECT analysis_id FROM qiita.analysis - JOIN qiita.analysis_portal USING (analysis_id) - JOIN qiita.portal_type USING (portal_type_id) - WHERE email = %s AND dflt = true AND portal = %s""" - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchone( - sql, [self._id, qiita_config.portal])[0] + with TRN: + sql = """SELECT analysis_id + FROM qiita.analysis + JOIN qiita.analysis_portal USING (analysis_id) + JOIN qiita.portal_type USING (portal_type_id) + WHERE email = %s AND dflt = true AND portal = %s""" + TRN.add(sql, [self._id, qiita_config.portal]) + return TRN.execute_fetchlast() @property def user_studies(self): """Returns a list of study ids owned by the user""" - sql = """SELECT study_id FROM qiita.study - JOIN qiita.study_portal USING (study_id) - JOIN qiita.portal_type USING (portal_type_id) - WHERE email = %s AND portal = %s""" - conn_handler = SQLConnectionHandler() - study_ids = conn_handler.execute_fetchall( - sql, (self._id, qiita_config.portal)) - return {s[0] for s in study_ids} + with TRN: + sql = """SELECT study_id + FROM qiita.study + JOIN qiita.study_portal USING (study_id) + JOIN qiita.portal_type USING (portal_type_id) + WHERE email = %s AND portal = %s""" + TRN.add(sql, [self._id, qiita_config.portal]) + return set(TRN.execute_fetchflatten()) @property def shared_studies(self): """Returns a list of study ids shared with the user""" - sql = """SELECT study_id FROM qiita.study_users - JOIN qiita.study_portal USING (study_id) - JOIN qiita.portal_type USING (portal_type_id) - WHERE email = %s and portal = %s""" - conn_handler = SQLConnectionHandler() - study_ids = conn_handler.execute_fetchall( - sql, (self._id, qiita_config.portal)) - return {s[0] for s in study_ids} + with TRN: + sql = """SELECT study_id + FROM qiita.study_users + JOIN qiita.study_portal USING (study_id) + JOIN qiita.portal_type USING (portal_type_id) + WHERE email = %s and portal = %s""" + TRN.add(sql, [self._id, qiita_config.portal]) + return set(TRN.execute_fetchflatten()) @property def private_analyses(self): """Returns a list of private analysis ids owned
by the user""" - sql = """SELECT analysis_id FROM qiita.analysis - JOIN qiita.analysis_portal USING (analysis_id) - JOIN qiita.portal_type USING (portal_type_id) - WHERE email = %s AND dflt = false AND portal = %s""" - conn_handler = SQLConnectionHandler() - analysis_ids = conn_handler.execute_fetchall( - sql, (self._id, qiita_config.portal)) - return {a[0] for a in analysis_ids} + with TRN: + sql = """SELECT analysis_id FROM qiita.analysis + JOIN qiita.analysis_portal USING (analysis_id) + JOIN qiita.portal_type USING (portal_type_id) + WHERE email = %s AND dflt = false AND portal = %s""" + TRN.add(sql, [self._id, qiita_config.portal]) + return set(TRN.execute_fetchflatten()) @property def shared_analyses(self): """Returns a list of analysis ids shared with the user""" - sql = """SELECT analysis_id FROM qiita.analysis_users - JOIN qiita.analysis_portal USING (analysis_id) - JOIN qiita.portal_type USING (portal_type_id) - WHERE email = %s AND portal = %s""" - conn_handler = SQLConnectionHandler() - analysis_ids = conn_handler.execute_fetchall( - sql, (self._id, qiita_config.portal)) - return {a[0] for a in analysis_ids} + with TRN: + sql = """SELECT analysis_id FROM qiita.analysis_users + JOIN qiita.analysis_portal USING (analysis_id) + JOIN qiita.portal_type USING (portal_type_id) + WHERE email = %s AND portal = %s""" + TRN.add(sql, [self._id, qiita_config.portal]) + return set(TRN.execute_fetchflatten()) # ------- methods --------- def change_password(self, oldpass, newpass): @@ -441,23 +450,25 @@ def change_password(self, oldpass, newpass): bool password changed or not """ - conn_handler = SQLConnectionHandler() - dbpass = conn_handler.execute_fetchone( - "SELECT password FROM qiita.{0} WHERE email = %s".format( - self._table), (self._id, ))[0] - if dbpass == hash_password(oldpass, dbpass): - self._change_pass(newpass) - return True - return False + with TRN: + sql = "SELECT password FROM qiita.{0} WHERE email = %s".format( + self._table) + TRN.add(sql, [self._id]) + dbpass = TRN.execute_fetchlast() + if dbpass == hash_password(oldpass, dbpass): + self._change_pass(newpass) + return True + return False def generate_reset_code(self): """Generates a password reset code for user""" - reset_code = create_rand_string(20, punct=False) - sql = ("UPDATE qiita.{0} SET pass_reset_code = %s, " - "pass_reset_timestamp = NOW() WHERE email = %s".format( - self._table)) - conn_handler = SQLConnectionHandler() - conn_handler.execute(sql, (reset_code, self._id)) + with TRN: + reset_code = create_rand_string(20, punct=False) + sql = """UPDATE qiita.{0} + SET pass_reset_code = %s, pass_reset_timestamp = NOW() + WHERE email = %s""".format(self._table) + TRN.add(sql, [reset_code, self._id]) + TRN.execute() def change_forgot_password(self, code, newpass): """Changes the password if the code is valid @@ -474,19 +485,22 @@ def change_forgot_password(self, code, newpass): bool password changed or not """ - if self.verify_code(self._id, code, "reset"): - self._change_pass(newpass) - return True - return False + with TRN: + if self.verify_code(self._id, code, "reset"): + self._change_pass(newpass) + return True + return False def _change_pass(self, newpass): - if not validate_password(newpass): - raise IncorrectPasswordError("Bad password given!") - - sql = ("UPDATE qiita.{0} SET password=%s, pass_reset_code=NULL WHERE " - "email = %s".format(self._table)) - conn_handler = SQLConnectionHandler() - conn_handler.execute(sql, (hash_password(newpass), self._id)) + with TRN: + if not validate_password(newpass): + 
raise IncorrectPasswordError("Bad password given!") + + sql = """UPDATE qiita.{0} + SET password=%s, pass_reset_code = NULL + WHERE email = %s""".format(self._table) + TRN.add(sql, [hash_password(newpass), self._id]) + TRN.execute() def validate_email(email): diff --git a/qiita_db/util.py b/qiita_db/util.py index 01f204e86..2fd825d61 100644 --- a/qiita_db/util.py +++ b/qiita_db/util.py @@ -51,10 +51,11 @@ from shutil import move, rmtree from json import dumps from datetime import datetime +from itertools import chain from qiita_core.exceptions import IncompetentQiitaDeveloperError from .exceptions import QiitaDBColumnError, QiitaDBError, QiitaDBLookupError -from .sql_connection import SQLConnectionHandler +from .sql_connection import TRN def params_dict_to_json(options): @@ -143,16 +144,17 @@ def get_filetypes(key='type'): If `key` is "type", dict is of the form {type: filetype_id} If `key` is "filetype_id", dict is of the form {filetype_id: type} """ - con = SQLConnectionHandler() - if key == 'type': - cols = 'type, filetype_id' - elif key == 'filetype_id': - cols = 'filetype_id, type' - else: - raise QiitaDBColumnError("Unknown key. Pass either 'type' or " - "'filetype_id'.") - sql = 'select {} from qiita.filetype'.format(cols) - return dict(con.execute_fetchall(sql)) + with TRN: + if key == 'type': + cols = 'type, filetype_id' + elif key == 'filetype_id': + cols = 'filetype_id, type' + else: + raise QiitaDBColumnError("Unknown key. Pass either 'type' or " + "'filetype_id'.") + sql = 'SELECT {} FROM qiita.filetype'.format(cols) + TRN.add(sql) + return dict(TRN.execute_fetchindex()) def get_filepath_types(key='filepath_type'): @@ -172,16 +174,17 @@ def get_filepath_types(key='filepath_type'): - If `key` is "filepath_type_id", dict is of the form {filepath_type_id: filepath_type} """ - con = SQLConnectionHandler() - if key == 'filepath_type': - cols = 'filepath_type, filepath_type_id' - elif key == 'filepath_type_id': - cols = 'filepath_type_id, filepath_type' - else: - raise QiitaDBColumnError("Unknown key. Pass either 'filepath_type' or " - "'filepath_type_id'.") - sql = 'select {} from qiita.filepath_type'.format(cols) - return dict(con.execute_fetchall(sql)) + with TRN: + if key == 'filepath_type': + cols = 'filepath_type, filepath_type_id' + elif key == 'filepath_type_id': + cols = 'filepath_type_id, filepath_type' + else: + raise QiitaDBColumnError("Unknown key. Pass either 'filepath_type'" + " or 'filepath_type_id'.") + sql = 'SELECT {} FROM qiita.filepath_type'.format(cols) + TRN.add(sql) + return dict(TRN.execute_fetchindex()) def get_data_types(key='data_type'): @@ -200,16 +203,17 @@ def get_data_types(key='data_type'): - If `key` is "data_type_id", dict is of the form {data_type_id: data_type} """ - con = SQLConnectionHandler() - if key == 'data_type': - cols = 'data_type, data_type_id' - elif key == 'data_type_id': - cols = 'data_type_id, data_type' - else: - raise QiitaDBColumnError("Unknown key. Pass either 'data_type_id' or " - "'data_type'.") - sql = 'select {} from qiita.data_type'.format(cols) - return dict(con.execute_fetchall(sql)) + with TRN: + if key == 'data_type': + cols = 'data_type, data_type_id' + elif key == 'data_type_id': + cols = 'data_type_id, data_type' + else: + raise QiitaDBColumnError("Unknown key. 
Pass either 'data_type_id' " + "or 'data_type'.") + sql = 'SELECT {} FROM qiita.data_type'.format(cols) + TRN.add(sql) + return dict(TRN.execute_fetchindex()) def create_rand_string(length, punct=True): @@ -262,13 +266,11 @@ def hash_password(password, hashedpw=None): return output -def check_required_columns(conn_handler, keys, table): +def check_required_columns(keys, table): """Makes sure all required columns in database table are in keys Parameters ---------- - conn_handler: SQLConnectionHandler object - Previously opened connection to the database keys: iterable Holds the keys in the dictionary table: str @@ -281,27 +283,27 @@ def check_required_columns(conn_handler, keys, table): RuntimeError Unable to get columns from database """ - sql = ("SELECT is_nullable, column_name, column_default " - "FROM information_schema.columns " - "WHERE table_name = %s") - cols = conn_handler.execute_fetchall(sql, (table, )) - # Test needed because a user with certain permissions can query without - # error but be unable to get the column names - if len(cols) == 0: - raise RuntimeError("Unable to fetch column names for table %s" % table) - required = set(x[1] for x in cols if x[0] == 'NO' and x[2] is None) - if len(required.difference(keys)) > 0: - raise QiitaDBColumnError("Required keys missing: %s" % - required.difference(keys)) - - -def check_table_cols(conn_handler, keys, table): + with TRN: + sql = """SELECT is_nullable, column_name, column_default + FROM information_schema.columns WHERE table_name = %s""" + TRN.add(sql, [table]) + cols = TRN.execute_fetchindex() + # Test needed because a user with certain permissions can query without + # error but be unable to get the column names + if len(cols) == 0: + raise RuntimeError("Unable to fetch column names for table %s" + % table) + required = set(x[1] for x in cols if x[0] == 'NO' and x[2] is None) + if len(required.difference(keys)) > 0: + raise QiitaDBColumnError("Required keys missing: %s" % + required.difference(keys)) + + +def check_table_cols(keys, table): """Makes sure all keys correspond to column headers in a table Parameters ---------- - conn_handler: SQLConnectionHandler object - Previously opened connection to the database keys: iterable Holds the keys in the dictionary table: str @@ -314,16 +316,19 @@ def check_table_cols(conn_handler, keys, table): RuntimeError Unable to get columns from database """ - sql = ("SELECT column_name FROM information_schema.columns WHERE " - "table_name = %s") - cols = [x[0] for x in conn_handler.execute_fetchall(sql, (table, ))] - # Test needed because a user with certain permissions can query without - # error but be unable to get the column names - if len(cols) == 0: - raise RuntimeError("Unable to fetch column names for table %s" % table) - if len(set(keys).difference(cols)) > 0: - raise QiitaDBColumnError("Non-database keys found: %s" % - set(keys).difference(cols)) + with TRN: + sql = """SELECT column_name FROM information_schema.columns + WHERE table_name = %s""" + TRN.add(sql, [table]) + cols = TRN.execute_fetchflatten() + # Test needed because a user with certain permissions can query without + # error but be unable to get the column names + if len(cols) == 0: + raise RuntimeError("Unable to fetch column names for table %s" + % table) + if len(set(keys).difference(cols)) > 0: + raise QiitaDBColumnError("Non-database keys found: %s" % + set(keys).difference(cols)) def get_table_cols(table): @@ -339,11 +344,11 @@ def get_table_cols(table): list of str The column headers of `table` """ - conn_handler 
= SQLConnectionHandler() - headers = conn_handler.execute_fetchall( - "SELECT column_name FROM information_schema.columns WHERE " - "table_name=%s AND table_schema='qiita'", (table, )) - return [h[0] for h in headers] + with TRN: + sql = """SELECT column_name FROM information_schema.columns + WHERE table_name=%s AND table_schema='qiita'""" + TRN.add(sql, [table]) + return TRN.execute_fetchflatten() def get_table_cols_w_type(table): @@ -359,31 +364,37 @@ def get_table_cols_w_type(table): list of tuples of (str, str) The column headers and data type of `table` """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchall( - "SELECT column_name, data_type FROM information_schema.columns WHERE " - "table_name=%s", (table,)) + with TRN: + sql = """SELECT column_name, data_type FROM information_schema.columns + WHERE table_name=%s""" + TRN.add(sql, [table]) + return TRN.execute_fetchindex() -def exists_table(table, conn_handler): - r"""Checks if `table` exists on the database connected through - `conn_handler` +def exists_table(table): + r"""Checks if `table` exists on the database Parameters ---------- table : str The table name to check if exists - conn_handler : SQLConnectionHandler - The connection handler object connected to the DB + + Returns + ------- + bool + Whether `table` exists on the database or not """ - return conn_handler.execute_fetchone( - "SELECT exists(SELECT * FROM information_schema.tables WHERE " - "table_name=%s)", (table,))[0] + with TRN: + sql = """SELECT exists( + SELECT * FROM information_schema.tables + WHERE table_name=%s)""" + TRN.add(sql, [table]) + return TRN.execute_fetchlast() -def exists_dynamic_table(table, prefix, suffix, conn_handler): - r"""Checks if the dynamic `table` exists on the database connected through - `conn_handler`, and its name starts with prefix and ends with suffix +def exists_dynamic_table(table, prefix, suffix): + r"""Checks if the dynamic `table` exists on the database, and its name + starts with prefix and ends with suffix Parameters ---------- @@ -393,11 +404,15 @@ def exists_dynamic_table(table, prefix, suffix, conn_handler): The table name prefix suffix : str The table name suffix - conn_handler : SQLConnectionHandler - The connection handler object connected to the DB + + Returns + ------- + bool + Whether `table` exists on the database or not and its name + starts with prefix and ends with suffix """ return (table.startswith(prefix) and table.endswith(suffix) and - exists_table(table, conn_handler)) + exists_table(table)) def get_db_files_base_dir(): @@ -408,10 +423,9 @@ def get_db_files_base_dir(): str The path to the base directory of all db files """ - conn_handler = SQLConnectionHandler() - - return conn_handler.execute_fetchone( - "SELECT base_data_dir FROM settings")[0] + with TRN: + TRN.add("SELECT base_data_dir FROM settings") + return TRN.execute_fetchlast() def get_work_base_dir(): @@ -422,10 +436,9 @@ def get_work_base_dir(): str The path to the base directory of all db files """ - conn_handler = SQLConnectionHandler() - - return conn_handler.execute_fetchone( - "SELECT base_work_dir FROM settings")[0] + with TRN: + TRN.add("SELECT base_work_dir FROM settings") + return TRN.execute_fetchlast() def compute_checksum(path): @@ -549,20 +562,19 @@ def get_mountpoint(mount_type, retrieve_all=False): list List of tuple, where: [(id_mountpoint, filepath_of_mountpoint)] """ - conn_handler = SQLConnectionHandler() - - if retrieve_all: - result = conn_handler.execute_fetchall( - "SELECT data_directory_id, 
mountpoint, subdirectory FROM " - "qiita.data_directory WHERE data_type='%s' ORDER BY active DESC" - % mount_type) - else: - result = [conn_handler.execute_fetchone( - "SELECT data_directory_id, mountpoint, subdirectory FROM " - "qiita.data_directory WHERE data_type='%s' and active=true" - % mount_type)] - basedir = get_db_files_base_dir() - return [(d, join(basedir, m, s)) for d, m, s in result] + with TRN: + if retrieve_all: + sql = """SELECT data_directory_id, mountpoint, subdirectory + FROM qiita.data_directory + WHERE data_type=%s ORDER BY active DESC""" + else: + sql = """SELECT data_directory_id, mountpoint, subdirectory + FROM qiita.data_directory + WHERE data_type=%s AND active=true""" + TRN.add(sql, [mount_type]) + result = TRN.execute_fetchindex() + basedir = get_db_files_base_dir() + return [(d, join(basedir, m, s)) for d, m, s in result] def get_mountpoint_path_by_id(mount_id): @@ -578,18 +590,20 @@ def get_mountpoint_path_by_id(mount_id): str The mountpoint path """ - conn_handler = SQLConnectionHandler() - mountpoint, subdirectory = conn_handler.execute_fetchone( - """SELECT mountpoint, subdirectory FROM qiita.data_directory - WHERE data_directory_id=%s""", (mount_id,)) - return join(get_db_files_base_dir(), mountpoint, subdirectory) - - -def insert_filepaths(filepaths, obj_id, table, filepath_table, conn_handler, - move_files=True, queue=None): - r"""Inserts `filepaths` in the DB connected with `conn_handler`. Since - the files live outside the database, the directory in which the files - lives is controlled by the database, so it copies the filepaths from + with TRN: + sql = """SELECT mountpoint, subdirectory FROM qiita.data_directory + WHERE data_directory_id=%s""" + TRN.add(sql, [mount_id]) + mountpoint, subdirectory = TRN.execute_fetchindex()[0] + return join(get_db_files_base_dir(), mountpoint, subdirectory) + + +def insert_filepaths(filepaths, obj_id, table, filepath_table, + move_files=True): + r"""Inserts `filepaths` in the database. + + Since the files live outside the database, the directory in which the files + live is controlled by the database, so it moves the filepaths from + their original locations to the controlled directory. Parameters ---------- @@ -603,107 +617,107 @@ def insert_filepaths(filepaths, obj_id, table, filepath_table, conn_handler, Table that holds the file data. filepath_table : str Table that holds the filepath information - conn_handler : SQLConnectionHandler - The connection handler object connected to the DB move_files : bool, optional Whether or not to copy from the given filepaths to the db filepaths default: True - queue : str, optional - The queue to add this transaction to. Default return list of ids Returns ------- - list or None - List of the filepath_id in the database for each added filepath if - queue not specified, or no return value if queue specified + list of int + List of the filepath_id in the database for each added filepath """ - new_filepaths = filepaths - - dd_id, mp = get_mountpoint(table)[0] - base_fp = join(get_db_files_base_dir(), mp) - - if move_files: - # Generate the new fileapths.
Format: DataId_OriginalName - # Keeping the original name is useful for checking if the RawData - alrady exists on the DB - db_path = partial(join, base_fp) - new_filepaths = [ - (db_path("%s_%s" % (obj_id, basename(path))), id) - for path, id in filepaths] - # Move the original files to the controlled DB directory - for old_fp, new_fp in zip(filepaths, new_filepaths): - move(old_fp[0], new_fp[0]) - - def str_to_id(x): - return (x if isinstance(x, (int, long)) - else convert_to_id(x, "filepath_type")) - paths_w_checksum = [(relpath(path, base_fp), str_to_id(id), - compute_checksum(path)) - for path, id in new_filepaths] - # Create the list of SQL values to add - values = ["('%s', %s, '%s', %s, %s)" % (scrub_data(path), pid, - checksum, 1, dd_id) for path, pid, checksum in - paths_w_checksum] - # Insert all the filepaths at once and get the filepath_id back - sql = ("INSERT INTO qiita.{0} (filepath, filepath_type_id, checksum, " - "checksum_algorithm_id, data_directory_id) VALUES {1} RETURNING" - " filepath_id".format(filepath_table, ', '.join(values))) - if queue is not None: - # Drop the sql into the given queue - conn_handler.add_to_queue(queue, sql, None) - else: - ids = conn_handler.execute_fetchall(sql) - - # we will receive a list of lists with a single element on it - # (the id), transform it to a list of ids - return [id[0] for id in ids] + with TRN: + new_filepaths = filepaths + + dd_id, mp = get_mountpoint(table)[0] + base_fp = join(get_db_files_base_dir(), mp) + + if move_files: + # Generate the new filepaths. Format: DataId_OriginalName + # Keeping the original name is useful for checking if the RawData + # already exists on the DB + db_path = partial(join, base_fp) + new_filepaths = [ + (db_path("%s_%s" % (obj_id, basename(path))), id_) + for path, id_ in filepaths] + # Move the original files to the controlled DB directory + for old_fp, new_fp in zip(filepaths, new_filepaths): + move(old_fp[0], new_fp[0]) + # In case the transaction executes a rollback, we need to + # make sure the files have not been moved + TRN.add_post_rollback_func(move, new_fp, old_fp) + + def str_to_id(x): + return (x if isinstance(x, (int, long)) + else convert_to_id(x, "filepath_type")) + paths_w_checksum = [(relpath(path, base_fp), str_to_id(id_), - wait + compute_checksum(path)) + for path, id_ in new_filepaths] + # Create the list of SQL values to add + values = [[path, pid, checksum, 1, dd_id] + for path, pid, checksum in paths_w_checksum] + # Insert all the filepaths at once and get the filepath_id back + sql = """INSERT INTO qiita.{0} + (filepath, filepath_type_id, checksum, + checksum_algorithm_id, data_directory_id) + VALUES (%s, %s, %s, %s, %s) + RETURNING filepath_id""".format(filepath_table) + idx = TRN.index + TRN.add(sql, values, many=True) + # Since we added the query with many=True, we've added len(values) + # queries to the transaction, so the ids are in the results from + # idx onward + return list(chain.from_iterable( + chain.from_iterable(TRN.execute()[idx:]))) def purge_filepaths(): r"""Goes over the filepath table and remove all the filepaths that are not - used in any place""" - conn_handler = SQLConnectionHandler() - - # Get all the (table, column) pairs that reference to the filepath table - # Code adapted from http://stackoverflow.com/q/5347050/3746629 - table_cols_pairs = conn_handler.execute_fetchall( - """SELECT R.TABLE_NAME, R.column_name - FROM INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE u - INNER JOIN INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS FK - ON U.CONSTRAINT_CATALOG =
FK.UNIQUE_CONSTRAINT_CATALOG - AND U.CONSTRAINT_SCHEMA = FK.UNIQUE_CONSTRAINT_SCHEMA - AND U.CONSTRAINT_NAME = FK.UNIQUE_CONSTRAINT_NAME - INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE R - ON R.CONSTRAINT_CATALOG = FK.CONSTRAINT_CATALOG - AND R.CONSTRAINT_SCHEMA = FK.CONSTRAINT_SCHEMA - AND R.CONSTRAINT_NAME = FK.CONSTRAINT_NAME - WHERE U.COLUMN_NAME = 'filepath_id' - AND U.TABLE_SCHEMA = 'qiita' - AND U.TABLE_NAME = 'filepath'""") - - union_str = " UNION ".join( - ["SELECT %s FROM qiita.%s WHERE %s IS NOT NULL" % (col, table, col) - for table, col in table_cols_pairs]) - # Get all the filepaths from the filepath table that are not - # referenced from any place in the database - fps = conn_handler.execute_fetchall( - """SELECT filepath_id, filepath, filepath_type, data_directory_id - FROM qiita.filepath FP JOIN qiita.filepath_type FPT - ON FP.filepath_type_id = FPT.filepath_type_id - WHERE filepath_id NOT IN (%s)""" % union_str) - - # We can now go over and remove all the filepaths - for fp_id, fp, fp_type, dd_id in fps: - conn_handler.execute("DELETE FROM qiita.filepath WHERE filepath_id=%s", - (fp_id,)) - - # Remove the data - fp = join(get_mountpoint_path_by_id(dd_id), fp) - if exists(fp): - if fp_type is 'directory': - rmtree(fp) - else: - remove(fp) + used in any place + """ + with TRN: + # Get all the (table, column) pairs that reference to the filepath + # table. Adapted from http://stackoverflow.com/q/5347050/3746629 + sql = """SELECT R.TABLE_NAME, R.column_name + FROM INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE u + INNER JOIN INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS FK + ON U.CONSTRAINT_CATALOG = FK.UNIQUE_CONSTRAINT_CATALOG + AND U.CONSTRAINT_SCHEMA = FK.UNIQUE_CONSTRAINT_SCHEMA + AND U.CONSTRAINT_NAME = FK.UNIQUE_CONSTRAINT_NAME + INNER JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE R + ON R.CONSTRAINT_CATALOG = FK.CONSTRAINT_CATALOG + AND R.CONSTRAINT_SCHEMA = FK.CONSTRAINT_SCHEMA + AND R.CONSTRAINT_NAME = FK.CONSTRAINT_NAME + WHERE U.COLUMN_NAME = 'filepath_id' + AND U.TABLE_SCHEMA = 'qiita' + AND U.TABLE_NAME = 'filepath'""" + TRN.add(sql) + + union_str = " UNION ".join( + ["SELECT %s FROM qiita.%s WHERE %s IS NOT NULL" % (col, table, col) + for table, col in TRN.execute_fetchindex()]) + # Get all the filepaths from the filepath table that are not + # referenced from any place in the database + sql = """SELECT filepath_id, filepath, filepath_type, data_directory_id + FROM qiita.filepath FP JOIN qiita.filepath_type FPT + ON FP.filepath_type_id = FPT.filepath_type_id + WHERE filepath_id NOT IN (%s)""" % union_str + TRN.add(sql) + + # We can now go over and remove all the filepaths + sql = "DELETE FROM qiita.filepath WHERE filepath_id=%s" + for fp_id, fp, fp_type, dd_id in TRN.execute_fetchindex(): + TRN.add(sql, [fp_id]) + + # Remove the data + fp = join(get_mountpoint_path_by_id(dd_id), fp) + if exists(fp): + if fp_type == 'directory': + func = rmtree + else: + func = remove + TRN.add_post_commit_func(func, fp) + + TRN.execute() def move_filepaths_to_upload_folder(study_id, filepaths): @@ -717,19 +731,23 @@ def move_filepaths_to_upload_folder(study_id, filepaths): filepaths : list List of filepaths to move to the upload folder """ - conn_handler = SQLConnectionHandler() - uploads_fp = join(get_mountpoint("uploads")[0][1], str(study_id)) + with TRN: + uploads_fp = join(get_mountpoint("uploads")[0][1], str(study_id)) + path_builder = partial(join, uploads_fp) + + # We can now go over and remove all the filepaths + sql = """DELETE FROM qiita.filepath WHERE filepath_id=%s""" + for
fp_id, fp, _ in filepaths: + TRN.add(sql, [fp_id]) - # We can now go over and remove all the filepaths - for fp_id, fp, _ in filepaths: - conn_handler.execute("DELETE FROM qiita.filepath WHERE filepath_id=%s", - (fp_id,)) + # removing id from the raw data filename + filename = basename(fp).split('_', 1)[1] + destination = path_builder(filename) - # removing id from the raw data filename - filename = basename(fp).split('_', 1)[1] - destination = join(uploads_fp, filename) + TRN.add_post_rollback_func(move, destination, fp) + move(fp, destination) - move(fp, destination) + TRN.execute() def get_filepath_id(table, fp): @@ -742,24 +760,31 @@ def get_filepath_id(table, fp): fp : str The filepath + Returns + ------- + int + The filepath id for the given filepath + Raises ------ QiitaDBError If fp is not stored in the DB. """ - conn_handler = SQLConnectionHandler() - _, mp = get_mountpoint(table)[0] - base_fp = join(get_db_files_base_dir(), mp) + with TRN: + _, mp = get_mountpoint(table)[0] + base_fp = join(get_db_files_base_dir(), mp) - fp_id = conn_handler.execute_fetchone( - "SELECT filepath_id FROM qiita.filepath WHERE filepath=%s", - (relpath(fp, base_fp),)) + sql = "SELECT filepath_id FROM qiita.filepath WHERE filepath=%s" + TRN.add(sql, [relpath(fp, base_fp)]) + fp_id = TRN.execute_fetchindex() - # check if the query has actually returned something - if not fp_id: - raise QiitaDBError("Filepath not stored in the database") + # check if the query has actually returned something + if not fp_id: + raise QiitaDBError("Filepath not stored in the database") - return fp_id[0] + # If there was a result it was a single row and a single value, + # hence access to [0][0] + return fp_id[0][0] def filepath_id_to_rel_path(filepath_id): @@ -768,16 +793,16 @@ def filepath_id_to_rel_path(filepath_id): Returns ------- str + The relative path for the given filepath id """ - conn = SQLConnectionHandler() - - sql = """SELECT dd.mountpoint, dd.subdirectory, fp.filepath - FROM qiita.filepath fp JOIN qiita.data_directory dd - ON fp.data_directory_id = dd.data_directory_id - WHERE fp.filepath_id = %s""" - - result = join(*conn.execute_fetchone(sql, [filepath_id])) - return result + with TRN: + sql = """SELECT mountpoint, subdirectory, filepath + FROM qiita.filepath + JOIN qiita.data_directory USING (data_directory_id) + WHERE filepath_id = %s""" + TRN.add(sql, [filepath_id]) + # It should be only one row + return join(*TRN.execute_fetchindex()[0]) def filepath_ids_to_rel_paths(filepath_ids): @@ -792,22 +817,17 @@ def filepath_ids_to_rel_paths(filepath_ids): dict where keys are ints and values are str {filepath_id: relative_path} """ - conn = SQLConnectionHandler() - - sql = """SELECT fp.filepath_id, dd.mountpoint, dd.subdirectory, fp.filepath - FROM qiita.filepath fp JOIN qiita.data_directory dd - ON fp.data_directory_id = dd.data_directory_id - WHERE fp.filepath_id in ({})""".format( - ', '.join([str(fpid) for fpid in filepath_ids])) - - if filepath_ids: - result = {row[0]: join(*row[1:]) - for row in conn.execute_fetchall(sql)} - - return result - else: + if not filepath_ids: return {} + with TRN: + sql = """SELECT filepath_id, mountpoint, subdirectory, filepath + FROM qiita.filepath + JOIN qiita.data_directory USING (data_directory_id) + WHERE filepath_id IN %s""" + TRN.add(sql, [tuple(filepath_ids)]) + return {row[0]: join(*row[1:]) for row in TRN.execute_fetchindex()} + def convert_to_id(value, table, text_col=None): """Converts a string value to its corresponding table identifier @@ -832,12 +852,17 @@ def
convert_to_id(value, table, text_col=None): The passed string has no associated id """ text_col = table if text_col is None else text_col - conn_handler = SQLConnectionHandler() - sql = "SELECT {0}_id FROM qiita.{0} WHERE {1} = %s".format(table, text_col) - _id = conn_handler.execute_fetchone(sql, (value, )) - if _id is None: - raise QiitaDBLookupError("%s not valid for table %s" % (value, table)) - return _id[0] + with TRN: + sql = "SELECT {0}_id FROM qiita.{0} WHERE {1} = %s".format( + table, text_col) + TRN.add(sql, [value]) + _id = TRN.execute_fetchindex() + if not _id: + raise QiitaDBLookupError("%s not valid for table %s" + % (value, table)) + # If there was a result it was a single row and a single value, + # hence access to [0][0] + return _id[0][0] def convert_from_id(value, table): @@ -860,13 +885,16 @@ def convert_from_id(value, table): QiitaDBLookupError The passed id has no associated string """ - conn_handler = SQLConnectionHandler() - string = conn_handler.execute_fetchone( - "SELECT {0} FROM qiita.{0} WHERE {0}_id = %s".format(table), - (value, )) - if string is None: - raise QiitaDBLookupError("%s not valid for table %s" % (value, table)) - return string[0] + with TRN: + sql = "SELECT {0} FROM qiita.{0} WHERE {0}_id = %s".format(table) + TRN.add(sql, [value]) + string = TRN.execute_fetchindex() + if not string: + raise QiitaDBLookupError("%s not valid for table %s" + % (value, table)) + # If there was a result it was a single row and a single value, + # hence access to [0][0] + return string[0][0] def get_count(table): @@ -881,9 +909,10 @@ def get_count(table): ------- int """ - conn = SQLConnectionHandler() - sql = "SELECT count(1) FROM %s" % table - return conn.execute_fetchone(sql)[0] + with TRN: + sql = "SELECT count(1) FROM %s" % table + TRN.add(sql) + return TRN.execute_fetchlast() def check_count(table, exp_count): @@ -911,10 +940,16 @@ def get_preprocessed_params_tables(): ------- list or str """ - sql = ("SELECT * FROM information_schema.tables WHERE table_schema = " - "'qiita' AND SUBSTR(table_name, 1, 13) = 'preprocessed_'") - conn = SQLConnectionHandler() - return [x[2] for x in conn.execute_fetchall(sql)] + with TRN: + sql = """SELECT table_name FROM information_schema.tables + WHERE table_schema = 'qiita' + AND SUBSTR(table_name, 1, 13) = 'preprocessed_' + AND table_name NOT IN ('preprocessed_data', + 'preprocessed_filepath', + 'preprocessed_processed_data') + ORDER BY table_name""" + TRN.add(sql) + return TRN.execute_fetchflatten() def get_processed_params_tables(): @@ -924,11 +959,13 @@ def get_processed_params_tables(): ------- list of str """ - sql = ("SELECT * FROM information_schema.tables WHERE table_schema = " - "'qiita' AND SUBSTR(table_name, 1, 17) = 'processed_params_'") - - conn = SQLConnectionHandler() - return sorted([x[2] for x in conn.execute_fetchall(sql)]) + with TRN: + sql = """SELECT table_name FROM information_schema.tables + WHERE table_schema = 'qiita' + AND SUBSTR(table_name, 1, 17) = 'processed_params_' + ORDER BY table_name""" + TRN.add(sql) + return TRN.execute_fetchflatten() def get_lat_longs(): @@ -939,20 +976,20 @@ def get_lat_longs(): list of [float, float] The latitude and longitude for each sample in the database """ - conn = SQLConnectionHandler() - sql = """SELECT DISTINCT table_name - FROM information_schema.columns - WHERE SUBSTR(table_name, 1, 7) = 'sample_' - AND table_schema = 'qiita' - AND column_name IN ('latitude', 'longitude');""" - tables_gen = (t[0] for t in conn.execute_fetchall(sql)) + with TRN: + sql =
"""SELECT DISTINCT table_name + FROM information_schema.columns + WHERE SUBSTR(table_name, 1, 7) = 'sample_' + AND table_schema = 'qiita' + AND column_name IN ('latitude', 'longitude');""" + TRN.add(sql) - sql = "SELECT latitude, longitude FROM qiita.{0}" - result = [] - for table in tables_gen: - result.extend(conn.execute_fetchall(sql.format(table))) + sql = "SELECT latitude, longitude FROM qiita.{0}" + idx = TRN.index + for table in TRN.execute_fetchflatten(): + TRN.add(sql.format(table)) - return result + return list(chain.from_iterable(TRN.execute()[idx:])) def get_environmental_packages(): @@ -965,9 +1002,9 @@ def get_environmental_packages(): environmental package name and the second string is the table where the metadata for the environmental package is stored """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchall( - "SELECT * FROM qiita.environmental_package") + with TRN: + TRN.add("SELECT * FROM qiita.environmental_package") + return TRN.execute_fetchindex() def get_timeseries_types(): @@ -979,31 +1016,10 @@ def get_timeseries_types(): The available timeseries types. Each timeseries type is defined by the tuple (timeseries_id, timeseries_type, intervention_type) """ - conn_handler = SQLConnectionHandler() - return conn_handler.execute_fetchall( - "SELECT * FROM qiita.timeseries_type ORDER BY timeseries_type_id") - - -def find_repeated(values): - """Find repeated elements in the inputed list - - Parameters - ---------- - values : list - List of elements to find duplicates in - - Returns - ------- - set - Repeated elements in ``values`` - """ - seen, repeated = set(), set() - for value in values: - if value in seen: - repeated.add(value) - else: - seen.add(value) - return repeated + with TRN: + sql = "SELECT * FROM qiita.timeseries_type ORDER BY timeseries_type_id" + TRN.add(sql) + return TRN.execute_fetchindex() def check_access_to_analysis_result(user_id, requested_path): @@ -1023,31 +1039,30 @@ def check_access_to_analysis_result(user_id, requested_path): list of int The filepath IDs associated with the requested path """ - conn = SQLConnectionHandler() - - # Get all filepath ids associated with analyses that the user has - # access to where the filepath is the base_requested_fp from above. - # There should typically be only one matching filepath ID, but for safety - # we allow for the possibility of multiple. - sql = """select fp.filepath_id - from qiita.analysis_job aj join ( - select analysis_id from qiita.analysis A - join qiita.analysis_status stat - on A.analysis_status_id = stat.analysis_status_id - where stat.analysis_status_id = 6 - UNION - select analysis_id from qiita.analysis_users - where email = %s - UNION - select analysis_id from qiita.analysis where email = %s - ) ids on aj.analysis_id = ids.analysis_id - join qiita.job_results_filepath jrfp on - aj.job_id = jrfp.job_id - join qiita.filepath fp on jrfp.filepath_id = fp.filepath_id - where fp.filepath = %s""" - - return [row[0] for row in conn.execute_fetchall( - sql, [user_id, user_id, requested_path])] + with TRN: + # Get all filepath ids associated with analyses that the user has + # access to where the filepath is the base_requested_fp from above. + # There should typically be only one matching filepath ID, but for + # safety we allow for the possibility of multiple. 
+ sql = """SELECT fp.filepath_id + FROM qiita.analysis_job aj JOIN ( + SELECT analysis_id FROM qiita.analysis A + JOIN qiita.analysis_status stat + ON A.analysis_status_id = stat.analysis_status_id + WHERE stat.analysis_status_id = 6 + UNION + SELECT analysis_id FROM qiita.analysis_users + WHERE email = %s + UNION + SELECT analysis_id FROM qiita.analysis WHERE email = %s + ) ids ON aj.analysis_id = ids.analysis_id + JOIN qiita.job_results_filepath jrfp ON + aj.job_id = jrfp.job_id + JOIN qiita.filepath fp ON jrfp.filepath_id = fp.filepath_id + WHERE fp.filepath = %s""" + TRN.add(sql, [user_id, user_id, requested_path]) + + return TRN.execute_fetchflatten() def infer_status(statuses): diff --git a/qiita_pet/handlers/analysis_handlers.py b/qiita_pet/handlers/analysis_handlers.py index b08fba82f..f769c927f 100644 --- a/qiita_pet/handlers/analysis_handlers.py +++ b/qiita_pet/handlers/analysis_handlers.py @@ -32,6 +32,7 @@ from qiita_db.exceptions import QiitaDBUnknownIDError from qiita_db.study import Study from qiita_db.logger import LogEntry +from qiita_core.util import execute_as_transaction SELECT_SAMPLES = 2 SELECT_COMMANDS = 3 @@ -59,6 +60,7 @@ def check_analysis_access(user, analysis): class SelectCommandsHandler(BaseHandler): """Select commands to be executed""" @authenticated + @execute_as_transaction def get(self): analysis_id = int(self.get_argument('aid')) analysis = Analysis(analysis_id) @@ -71,6 +73,7 @@ def get(self): commands=commands, data_types=data_types, aid=analysis.id) @authenticated + @execute_as_transaction def post(self): name = self.get_argument('name') desc = self.get_argument('description') @@ -86,6 +89,7 @@ def post(self): class AnalysisWaitHandler(BaseHandler): @authenticated + @execute_as_transaction def get(self, analysis_id): analysis_id = int(analysis_id) try: @@ -100,6 +104,7 @@ def get(self, analysis_id): group_id=group_id, aname=analysis.name) @authenticated + @execute_as_transaction def post(self, analysis_id): analysis_id = int(analysis_id) rarefaction_depth = self.get_argument('rarefaction-depth') @@ -132,6 +137,7 @@ def post(self, analysis_id): class AnalysisResultsHandler(BaseHandler): @authenticated + @execute_as_transaction def get(self, analysis_id): analysis_id = int(analysis_id.split("/")[0]) analysis = Analysis(analysis_id) @@ -159,6 +165,7 @@ def get(self, analysis_id): basefolder=get_db_files_base_dir()) @authenticated + @execute_as_transaction def post(self, analysis_id): analysis_id = int(analysis_id.split("/")[0]) analysis_id_sent = int(self.get_argument('analysis_id')) @@ -192,6 +199,7 @@ def post(self, analysis_id): class ShowAnalysesHandler(BaseHandler): """Shows the user's analyses""" @authenticated + @execute_as_transaction def get(self): message = self.get_argument('message', '') level = self.get_argument('level', '') @@ -205,6 +213,7 @@ def get(self): class ResultsHandler(StaticFileHandler, BaseHandler): + @execute_as_transaction def validate_absolute_path(self, root, absolute_path): """Overrides StaticFileHandler's method to include authentication """ @@ -243,6 +252,7 @@ def validate_absolute_path(self, root, absolute_path): class SelectedSamplesHandler(BaseHandler): @authenticated + @execute_as_transaction def get(self): # Format sel_data to get study IDs for the processed data sel_data = defaultdict(dict) @@ -260,6 +270,7 @@ def get(self): class AnalysisSummaryAJAX(BaseHandler): @authenticated + @execute_as_transaction def get(self): info = Analysis(self.current_user.default_analysis).summary_data() self.write(dumps(info)) diff 
--git a/qiita_pet/handlers/auth_handlers.py b/qiita_pet/handlers/auth_handlers.py index 8b868480e..e7afd6b5d 100644 --- a/qiita_pet/handlers/auth_handlers.py +++ b/qiita_pet/handlers/auth_handlers.py @@ -7,7 +7,7 @@ from qiita_pet.handlers.base_handlers import BaseHandler from qiita_core.qiita_settings import qiita_config -from qiita_core.util import send_email +from qiita_core.util import send_email, execute_as_transaction from qiita_core.exceptions import (IncorrectPasswordError, IncorrectEmailError, UnverifiedEmailError) from qiita_db.user import User @@ -25,6 +25,7 @@ def get(self): error_message = "" self.render("create_user.html", error=error_message) + @execute_as_transaction def post(self): username = self.get_argument("email", "").strip().lower() password = self.get_argument("newpass", "") @@ -75,6 +76,7 @@ class AuthLoginHandler(BaseHandler): def get(self): self.redirect("/") + @execute_as_transaction def post(self): if r_client.get('maintenance') is not None: raise HTTPError(503, "Site is down for maintenance") diff --git a/qiita_pet/handlers/compute.py b/qiita_pet/handlers/compute.py index 5d241b8a0..392447ef4 100644 --- a/qiita_pet/handlers/compute.py +++ b/qiita_pet/handlers/compute.py @@ -9,7 +9,7 @@ from qiita_ware.context import submit from qiita_ware.dispatchable import (add_files_to_raw_data, unlink_all_files, create_raw_data) - +from qiita_core.util import execute_as_transaction from qiita_db.study import Study from qiita_db.exceptions import QiitaDBUnknownIDError from qiita_db.util import get_mountpoint @@ -32,6 +32,7 @@ def get(self, job_id): class CreateRawData(BaseHandler): @authenticated + @execute_as_transaction def post(self): pt_id = self.get_argument('prep_template_id') raw_data_filetype = self.get_argument('filetype') @@ -79,6 +80,7 @@ def _split(x): class AddFilesToRawData(BaseHandler): @authenticated + @execute_as_transaction def post(self): # vars to add files to raw data @@ -131,6 +133,7 @@ def _split(x): class UnlinkAllFiles(BaseHandler): @authenticated + @execute_as_transaction def post(self): # vars to remove all files from a raw data study_id = self.get_argument('study_id', None) diff --git a/qiita_pet/handlers/download.py b/qiita_pet/handlers/download.py index ac0472881..1a113e928 100644 --- a/qiita_pet/handlers/download.py +++ b/qiita_pet/handlers/download.py @@ -6,10 +6,12 @@ from qiita_pet.exceptions import QiitaPetAuthorizationError from qiita_db.util import filepath_id_to_rel_path from qiita_db.meta_util import get_accessible_filepath_ids +from qiita_core.util import execute_as_transaction class DownloadHandler(BaseHandler): @authenticated + @execute_as_transaction def get(self, filepath_id): filepath_id = int(filepath_id) # Check access to file diff --git a/qiita_pet/handlers/logger_handlers.py b/qiita_pet/handlers/logger_handlers.py index b5f700a1c..2c6fddc27 100644 --- a/qiita_pet/handlers/logger_handlers.py +++ b/qiita_pet/handlers/logger_handlers.py @@ -4,6 +4,7 @@ from .base_handlers import BaseHandler from qiita_db.logger import LogEntry +from qiita_core.util import execute_as_transaction from tornado.web import HTTPError @@ -14,12 +15,14 @@ def check_access(self): "to view error page" % self.current_user) @authenticated + @execute_as_transaction def get(self): self.check_access() logentries = LogEntry.newest_records() self.render("error_log.html", logentries=logentries) @authenticated + @execute_as_transaction def post(self): self.check_access() numentries = int(self.get_argument("numrecords")) diff --git a/qiita_pet/handlers/portal.py 
b/qiita_pet/handlers/portal.py index 69056bcd2..4f3df8237 100644 --- a/qiita_pet/handlers/portal.py +++ b/qiita_pet/handlers/portal.py @@ -11,6 +11,7 @@ from tornado.web import authenticated, HTTPError +from qiita_core.util import execute_as_transaction from qiita_db.study import Study from qiita_db.portal import Portal from qiita_db.exceptions import QiitaDBError @@ -25,6 +26,7 @@ def check_admin(self): raise HTTPError(403, "%s does not have access to portal editing!" % self.current_user.id) + @execute_as_transaction def get_info(self, portal="QIITA"): # Add the portals and, optionally, checkbox to the information studies = Portal(portal).get_studies() @@ -43,6 +45,7 @@ def get_info(self, portal="QIITA"): class StudyPortalHandler(PortalEditBase): @authenticated + @execute_as_transaction def get(self): self.check_admin() info = self.get_info() @@ -53,6 +56,7 @@ def get(self): portals=portals, submit_url="/admin/portals/studies/") @authenticated + @execute_as_transaction def post(self): self.check_admin() portal = self.get_argument('portal') @@ -78,6 +82,7 @@ def post(self): class StudyPortalAJAXHandler(PortalEditBase): @authenticated + @execute_as_transaction def get(self): self.check_admin() portal = self.get_argument('view-portal') diff --git a/qiita_pet/handlers/preprocessing_handlers.py b/qiita_pet/handlers/preprocessing_handlers.py index 328c441ec..0cf8f262b 100644 --- a/qiita_pet/handlers/preprocessing_handlers.py +++ b/qiita_pet/handlers/preprocessing_handlers.py @@ -7,10 +7,12 @@ Preprocessed454Params) from qiita_db.metadata_template import PrepTemplate from qiita_ware.context import submit +from qiita_core.util import execute_as_transaction class PreprocessHandler(BaseHandler): @authenticated + @execute_as_transaction def post(self): study_id = int(self.get_argument('study_id')) prep_template_id = int(self.get_argument('prep_template_id')) diff --git a/qiita_pet/handlers/stats.py b/qiita_pet/handlers/stats.py index 384089b58..d52e88a2e 100644 --- a/qiita_pet/handlers/stats.py +++ b/qiita_pet/handlers/stats.py @@ -5,6 +5,7 @@ from moi import r_client from tornado.gen import coroutine, Task +from qiita_core.util import execute_as_transaction from qiita_db.util import get_count from qiita_db.study import Study from qiita_db.util import get_lat_longs @@ -12,6 +13,7 @@ class StatsHandler(BaseHandler): + @execute_as_transaction def _get_stats(self, callback): # check if the key exists in redis lats = r_client.lrange('stats:sample_lats', 0, -1) @@ -53,6 +55,7 @@ def _get_stats(self, callback): callback([num_studies, num_samples, num_users, lat_longs]) @coroutine + @execute_as_transaction def get(self): num_studies, num_samples, num_users, lat_longs = \ yield Task(self._get_stats) diff --git a/qiita_pet/handlers/study_handlers/description_handlers.py b/qiita_pet/handlers/study_handlers/description_handlers.py index ed300152a..e1df527c2 100644 --- a/qiita_pet/handlers/study_handlers/description_handlers.py +++ b/qiita_pet/handlers/study_handlers/description_handlers.py @@ -17,6 +17,7 @@ from tornado.gen import coroutine, Task from pandas.parser import CParserError +from qiita_core.util import execute_as_transaction from qiita_core.qiita_settings import qiita_config from qiita_db.study import Study from qiita_db.data import RawData, PreprocessedData, ProcessedData @@ -82,7 +83,7 @@ def _to_int(value): class StudyDescriptionHandler(BaseHandler): - + @execute_as_transaction def _get_study_and_check_access(self, study_id): """Checks if the current user has access to the study @@ -119,6 +120,7 
@@ def _get_study_and_check_access(self, study_id): return study, user, full_access + @execute_as_transaction def _process_investigation_type(self, inv_type, user_def_type, new_type): """Return the investigation_type and add it to the ontology if needed @@ -147,6 +149,7 @@ def _process_investigation_type(self, inv_type, user_def_type, new_type): inv_type = user_def_type return inv_type + @execute_as_transaction def process_sample_template(self, study, user, callback): """Process a sample template from the POST method @@ -212,6 +215,7 @@ def process_sample_template(self, study, user, callback): callback((msg, msg_level, None, None, None)) + @execute_as_transaction def update_sample_template(self, study, user, callback): """Update a sample template from the POST method @@ -268,6 +272,7 @@ def update_sample_template(self, study, user, callback): msg_level = "danger" callback((msg, msg_level, None, None, None)) + @execute_as_transaction def add_to_sample_template(self, study, user, callback): """Process a sample template from the POST method @@ -325,6 +330,7 @@ def add_to_sample_template(self, study, user, callback): callback((msg, msg_level, None, None, None)) + @execute_as_transaction def add_raw_data(self, study, user, callback): """Adds an existing raw data to the study @@ -356,6 +362,7 @@ def add_raw_data(self, study, user, callback): callback((msg, msg_level, 'prep_template_tab', pt_id, None)) + @execute_as_transaction def add_prep_template(self, study, user, callback): """Adds a prep template to the system @@ -424,6 +431,7 @@ def add_prep_template(self, study, user, callback): callback((msg, msg_level, 'prep_template_tab', pt_id, None)) + @execute_as_transaction def update_prep_template(self, study, user, callback): """Update a prep template from the POST method @@ -483,6 +491,7 @@ def update_prep_template(self, study, user, callback): callback((msg, msg_level, 'prep_template_tab', pt_id, None)) + @execute_as_transaction def make_public(self, study, user, callback): """Makes the current study public @@ -503,6 +512,7 @@ def make_public(self, study, user, callback): msg_level = "success" callback((msg, msg_level, "processed_data_tab", pd_id, None)) + @execute_as_transaction def approve_study(self, study, user, callback): """Approves the current study if and only if the current user is admin @@ -528,6 +538,7 @@ def approve_study(self, study, user, callback): msg_level = "danger" callback((msg, msg_level, "processed_data_tab", pd_id, None)) + @execute_as_transaction def request_approval(self, study, user, callback): """Changes the status of the current study to "awaiting_approval" @@ -548,6 +559,7 @@ def request_approval(self, study, user, callback): msg_level = "success" callback((msg, msg_level, "processed_data_tab", pd_id, None)) + @execute_as_transaction def make_sandbox(self, study, user, callback): """Reverts the current study to the 'sandbox' status @@ -568,6 +580,7 @@ def make_sandbox(self, study, user, callback): msg_level = "success" callback((msg, msg_level, "processed_data_tab", pd_id, None)) + @execute_as_transaction def update_investigation_type(self, study, user, callback): """Updates the investigation type of a prep template @@ -636,6 +649,7 @@ def unspecified_action(self, study, user, callback): msg_level = 'danger' callback((msg, msg_level, 'study_information_tab', None, None)) + @execute_as_transaction def remove_add_study_template(self, raw_data, study_id, fp_rsp, data_type, is_mapping_file): """Replace prep templates, raw data, and sample template with a new one @@ 
diff --git a/qiita_pet/handlers/study_handlers/ebi_handlers.py b/qiita_pet/handlers/study_handlers/ebi_handlers.py
index 3e7a6f795..95112a09f 100644
--- a/qiita_pet/handlers/study_handlers/ebi_handlers.py
+++ b/qiita_pet/handlers/study_handlers/ebi_handlers.py
@@ -19,9 +19,11 @@
 from qiita_db.study import Study
 from qiita_db.exceptions import QiitaDBUnknownIDError
 from qiita_pet.handlers.base_handlers import BaseHandler
+from qiita_core.util import execute_as_transaction
 
 
 class EBISubmitHandler(BaseHandler):
+    @execute_as_transaction
     def display_template(self, preprocessed_data_id, msg, msg_level):
         """Simple function to avoid duplication of code"""
         preprocessed_data_id = int(preprocessed_data_id)
@@ -93,6 +95,7 @@ def get(self, preprocessed_data_id):
         self.display_template(preprocessed_data_id, "", "")
 
     @authenticated
+    @execute_as_transaction
     def post(self, preprocessed_data_id):
         user = self.current_user
         # make sure user is admin and can therefore actually submit to EBI
diff --git a/qiita_pet/handlers/study_handlers/edit_handlers.py b/qiita_pet/handlers/study_handlers/edit_handlers.py
index bee813f90..0d7eb84b7 100644
--- a/qiita_pet/handlers/study_handlers/edit_handlers.py
+++ b/qiita_pet/handlers/study_handlers/edit_handlers.py
@@ -16,6 +16,7 @@
 from qiita_db.exceptions import QiitaDBUnknownIDError
 from qiita_pet.handlers.base_handlers import BaseHandler
 from qiita_pet.handlers.util import check_access
+from qiita_core.util import execute_as_transaction
 
 
 class StudyEditorForm(Form):
@@ -58,6 +59,7 @@ class StudyEditorForm(Form):
 
     lab_person = SelectField('Lab Person', coerce=lambda x: x)
 
+    @execute_as_transaction
     def __init__(self, study=None, **kwargs):
         super(StudyEditorForm, self).__init__(**kwargs)
 
@@ -111,6 +113,7 @@ class StudyEditorExtendedForm(StudyEditorForm):
                                 [validators.Required()])
     timeseries = SelectField('Event-Based Data', coerce=lambda x: x)
 
+    @execute_as_transaction
     def __init__(self, study=None, **kwargs):
         super(StudyEditorExtendedForm, self).__init__(study=study, **kwargs)
 
@@ -138,6 +141,7 @@ def __init__(self, study=None, **kwargs):
 
 
 class StudyEditHandler(BaseHandler):
+    @execute_as_transaction
     def _check_study_exists_and_user_access(self, study_id):
         try:
             study = Study(int(study_id))
@@ -175,6 +179,7 @@ def _get_study_person_id(self, index, new_people_info):
         return index
 
     @authenticated
+    @execute_as_transaction
     def get(self, study_id=None):
         study = None
         form_factory = StudyEditorExtendedForm
@@ -192,6 +197,7 @@ def get(self, study_id=None):
                     creation_form=creation_form, study=study)
 
     @authenticated
+    @execute_as_transaction
     def post(self, study=None):
         the_study = None
         form_factory = StudyEditorExtendedForm
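Decorating the form constructors matters because StudyEditorForm.__init__ and StudyEditorExtendedForm.__init__ hit the database to build their SelectField choices. A rough sketch of that shape, assuming an iter() classmethod on StudyPerson as in other qiita_db objects (the field and query here are illustrative, not the real form code):

    from wtforms import Form, SelectField

    from qiita_core.util import execute_as_transaction
    from qiita_db.study import StudyPerson


    class ExampleForm(Form):
        lab_person = SelectField('Lab Person', coerce=lambda x: x)

        @execute_as_transaction
        def __init__(self, **kwargs):
            super(ExampleForm, self).__init__(**kwargs)
            # One transaction now covers all the choice-building queries.
            self.lab_person.choices = [(p.id, p.name)
                                       for p in StudyPerson.iter()]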
diff --git a/qiita_pet/handlers/study_handlers/listing_handlers.py b/qiita_pet/handlers/study_handlers/listing_handlers.py
index a31f6c369..94aeeb35e 100644
--- a/qiita_pet/handlers/study_handlers/listing_handlers.py
+++ b/qiita_pet/handlers/study_handlers/listing_handlers.py
@@ -23,11 +23,12 @@
 from qiita_db.util import get_table_cols
 from qiita_db.data import ProcessedData
 from qiita_core.exceptions import IncompetentQiitaDeveloperError
-
+from qiita_core.util import execute_as_transaction
 from qiita_pet.handlers.base_handlers import BaseHandler
 from qiita_pet.handlers.util import study_person_linkifier, pubmed_linkifier
 
 
+@execute_as_transaction
 def _get_shared_links_for_study(study):
     shared = []
     for person in study.shared_with:
@@ -44,6 +45,7 @@ def _get_shared_links_for_study(study):
     return ", ".join(shared)
 
 
+@execute_as_transaction
 def _build_single_study_info(study, info, study_proc, proc_samples):
     """Clean up and add to the study info for HTML purposes
 
@@ -91,6 +93,7 @@ def _build_single_study_info(study, info, study_proc, proc_samples):
     return info
 
 
+@execute_as_transaction
 def _build_single_proc_data_info(proc_data_id, data_type, samples):
     """Build the proc data info list for the child row in datatable
 
@@ -119,6 +122,7 @@ def _build_single_proc_data_info(proc_data_id, data_type, samples):
     return proc_info
 
 
+@execute_as_transaction
 def _build_study_info(user, study_proc=None, proc_samples=None):
     """Builds list of dicts for studies table, with all HTML formatted
 
@@ -190,6 +194,7 @@ def _build_study_info(user, study_proc=None, proc_samples=None):
     return infolist
 
 
+@execute_as_transaction
 def _check_owner(user, study):
     """make sure user is the owner of the study requested"""
     if not user.id == study.owner:
@@ -200,6 +205,7 @@ def _check_owner(user, study):
 class ListStudiesHandler(BaseHandler):
     @authenticated
     @coroutine
+    @execute_as_transaction
     def get(self, message="", msg_level=None):
         all_emails_except_current = yield Task(self._get_all_emails)
         all_emails_except_current.remove(self.current_user.id)
@@ -217,6 +223,7 @@ def _get_all_emails(self, callback):
 
 class StudyApprovalList(BaseHandler):
     @authenticated
+    @execute_as_transaction
     def get(self):
         user = self.current_user
         if user.level != 'admin':
@@ -233,21 +240,25 @@ def get(self):
 
 
 class ShareStudyAJAX(BaseHandler):
+    @execute_as_transaction
     def _get_shared_for_study(self, study, callback):
         shared_links = _get_shared_links_for_study(study)
         users = study.shared_with
         callback((users, shared_links))
 
+    @execute_as_transaction
     def _share(self, study, user, callback):
         user = User(user)
         callback(study.share(user))
 
+    @execute_as_transaction
     def _unshare(self, study, user, callback):
         user = User(user)
         callback(study.unshare(user))
 
     @authenticated
     @coroutine
+    @execute_as_transaction
     def get(self):
         study_id = int(self.get_argument('study_id'))
         study = Study(study_id)
@@ -268,6 +279,7 @@ def get(self):
 
 class SearchStudiesAJAX(BaseHandler):
     @authenticated
+    @execute_as_transaction
     def get(self, ignore):
         user = self.get_argument('user')
         query = self.get_argument('query')
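Module-level helpers such as _build_study_info are decorated even though they are usually called from handlers that are already decorated. This is safe if, as in qiita_db, TRN is a nesting context manager: re-entering it inside an open transaction adds a nesting level instead of starting a second transaction, and the commit happens when the outermost block exits. Schematically (illustrative functions, assuming that nesting behavior):

    from qiita_core.util import execute_as_transaction


    @execute_as_transaction
    def _helper():
        # When a transaction is already open, this joins it.
        return 42


    @execute_as_transaction
    def _handler_like_caller():
        # Opens the outermost transaction; _helper() nests inside it and
        # nothing is committed until this function returns.
        return _helper()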
diff --git a/qiita_pet/handlers/study_handlers/metadata_summary_handlers.py b/qiita_pet/handlers/study_handlers/metadata_summary_handlers.py
index 7d05dd056..c9353832f 100644
--- a/qiita_pet/handlers/study_handlers/metadata_summary_handlers.py
+++ b/qiita_pet/handlers/study_handlers/metadata_summary_handlers.py
@@ -14,6 +14,7 @@
 from qiita_db.metadata_template import SampleTemplate, PrepTemplate
 from qiita_db.exceptions import QiitaDBUnknownIDError
 from qiita_pet.handlers.base_handlers import BaseHandler
+from qiita_core.util import execute_as_transaction
 
 
 class MetadataSummaryHandler(BaseHandler):
@@ -48,6 +49,7 @@ def _get_template(self, constructor, template_id):
         return template
 
     @authenticated
+    @execute_as_transaction
     def get(self, arguments):
         study_id = int(self.get_argument('study_id'))
 
diff --git a/qiita_pet/handlers/study_handlers/vamps_handlers.py b/qiita_pet/handlers/study_handlers/vamps_handlers.py
index 297e107b7..9028661b5 100644
--- a/qiita_pet/handlers/study_handlers/vamps_handlers.py
+++ b/qiita_pet/handlers/study_handlers/vamps_handlers.py
@@ -17,9 +17,11 @@
 from qiita_db.study import Study
 from qiita_db.exceptions import QiitaDBUnknownIDError
 from qiita_pet.handlers.base_handlers import BaseHandler
+from qiita_core.util import execute_as_transaction
 
 
 class VAMPSHandler(BaseHandler):
+    @execute_as_transaction
     def display_template(self, preprocessed_data_id, msg, msg_level):
         """Simple function to avoid duplication of code"""
         preprocessed_data_id = int(preprocessed_data_id)
@@ -68,6 +70,7 @@ def get(self, preprocessed_data_id):
         self.display_template(preprocessed_data_id, "", "")
 
     @authenticated
+    @execute_as_transaction
     def post(self, preprocessed_data_id):
         # make sure user is admin and can therefore actually submit to VAMPS
         if self.current_user.level != 'admin':
diff --git a/qiita_pet/handlers/upload.py b/qiita_pet/handlers/upload.py
index e7049fc48..b12bfc237 100644
--- a/qiita_pet/handlers/upload.py
+++ b/qiita_pet/handlers/upload.py
@@ -9,7 +9,7 @@
 from .base_handlers import BaseHandler
 
 from qiita_core.qiita_settings import qiita_config
-
+from qiita_core.util import execute_as_transaction
 from qiita_db.util import (get_files_from_uploads_folders, get_mountpoint,
                            move_upload_files_to_trash)
 from qiita_db.study import Study
@@ -18,6 +18,7 @@
 
 class StudyUploadFileHandler(BaseHandler):
     @authenticated
+    @execute_as_transaction
     def display_template(self, study_id, msg):
         """Simple function to avoid duplication of code"""
         study_id = int(study_id)
@@ -34,6 +35,7 @@ def display_template(self, study_id, msg):
                     files=get_files_from_uploads_folders(str(study_id)))
 
     @authenticated
+    @execute_as_transaction
     def get(self, study_id):
         try:
             study = Study(int(study_id))
@@ -44,6 +46,7 @@ def get(self, study_id):
         self.display_template(study_id, "")
 
     @authenticated
+    @execute_as_transaction
     def post(self, study_id):
         try:
             study = Study(int(study_id))
@@ -84,6 +87,7 @@ def validate_file_extension(self, filename):
                          (self.current_user, str(filename)))
 
     @authenticated
+    @execute_as_transaction
     def post(self):
         resumable_identifier = self.get_argument('resumableIdentifier')
         resumable_filename = self.get_argument('resumableFilename')
@@ -126,6 +130,7 @@ def post(self):
         self.set_status(200)
 
     @authenticated
+    @execute_as_transaction
     def get(self):
         """ this is the first point of entry
         into the upload service
diff --git a/qiita_pet/handlers/user_handlers.py b/qiita_pet/handlers/user_handlers.py
index ca3c5c7ce..428b0dff0 100644
--- a/qiita_pet/handlers/user_handlers.py
+++ b/qiita_pet/handlers/user_handlers.py
@@ -6,7 +6,7 @@
 from qiita_db.user import User
 from qiita_db.logger import LogEntry
 from qiita_db.exceptions import QiitaDBUnknownIDError
-from qiita_core.util import send_email
+from qiita_core.util import send_email, execute_as_transaction
 from qiita_core.qiita_settings import qiita_config
 
 
@@ -26,6 +26,7 @@ def get(self):
         self.render("user_profile.html", profile=profile, msg="", passmsg="")
 
     @authenticated
+    @execute_as_transaction
     def post(self):
         passmsg = ""
         msg = ""
@@ -75,6 +76,7 @@ class ForgotPasswordHandler(BaseHandler):
     def get(self):
         self.render("lost_pass.html", user=None, message="", level="")
 
+    @execute_as_transaction
     def post(self):
         message = ""
         level = ""
@@ -117,6 +119,7 @@ def get(self, code):
         self.render("change_lost_pass.html", user=None, message="",
                     level="", code=code)
 
+    @execute_as_transaction
     def post(self, code):
         message = ""
         level = ""
diff --git a/qiita_pet/handlers/util.py b/qiita_pet/handlers/util.py
index b8b99e499..764a05c32 100644
--- a/qiita_pet/handlers/util.py
+++ b/qiita_pet/handlers/util.py
@@ -11,8 +11,10 @@
 from tornado.web import HTTPError
 
 from qiita_pet.util import linkify
+from qiita_core.util import execute_as_transaction
 
 
+@execute_as_transaction
 def check_access(user, study, no_public=False, raise_error=False):
     """make sure user has access to the study requested"""
     if not study.has_access(user, no_public):
diff --git a/qiita_pet/handlers/websocket_handlers.py b/qiita_pet/handlers/websocket_handlers.py
index ebeab09c4..380199bee 100644
--- a/qiita_pet/handlers/websocket_handlers.py
+++ b/qiita_pet/handlers/websocket_handlers.py
@@ -12,6 +12,7 @@
 from moi import r_client
 from qiita_pet.handlers.base_handlers import BaseHandler
 from qiita_db.analysis import Analysis
+from qiita_core.util import execute_as_transaction
 
 
 class MessageHandler(WebSocketHandler):
@@ -79,6 +80,7 @@ def on_close(self):
 class SelectedSocketHandler(WebSocketHandler, BaseHandler):
     """Websocket for removing samples on default analysis display page"""
     @authenticated
+    @execute_as_transaction
     def on_message(self, msg):
         # When the websocket receives a message from the javascript client,
         # parse into JSON
@@ -100,6 +102,7 @@ def on_message(self, msg):
 class SelectSamplesHandler(WebSocketHandler, BaseHandler):
     """Websocket for selecting and deselecting samples on list studies page"""
     @authenticated
+    @execute_as_transaction
     def on_message(self, msg):
         """Selects samples on a message from the user
 
diff --git a/qiita_pet/test/test_auth_handlers.py b/qiita_pet/test/test_auth_handlers.py
index 9b7e1d5ea..1a72fea9c 100644
--- a/qiita_pet/test/test_auth_handlers.py
+++ b/qiita_pet/test/test_auth_handlers.py
@@ -21,6 +21,8 @@ def test_post(self):
 
 
 class TestAuthVerifyHandler(TestHandlerBase):
+    database = True
+
     def test_get(self):
         response = self.get('/auth/verify/SOMETHINGHERE?email=test%40foo.bar')
         self.assertEqual(response.code, 500)
diff --git a/qiita_pet/test/test_study_handlers.py b/qiita_pet/test/test_study_handlers.py
index ae678ee71..120f4a245 100644
--- a/qiita_pet/test/test_study_handlers.py
+++ b/qiita_pet/test/test_study_handlers.py
@@ -453,7 +453,7 @@ def test_delete_sample_template(self):
 
         # checking that the action was sent
         self.assertIn("Sample template can not be erased because there are "
-                      "raw datas", response.body)
+                      "prep templates", response.body)
 
     def test_delete_raw_data(self):
         response = self.post('/study/description/1',
@@ -464,7 +464,7 @@ def test_delete_raw_data(self):
 
         # checking that the action was sent
         self.assertIn("Couldn't remove raw data 1: Raw data (1) can't be "
-                      "remove because it has linked files", response.body)
+                      "removed because it has linked files", response.body)
 
     def test_delete_prep_template(self):
         response = self.post('/study/description/1',
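The new `database = True` class attribute is how a handler test opts into a database reset; as the following diff shows, TestHandlerBase.setUp now delegates both the production-safety checks and the rebuild to clean_test_environment(). A sketch of the pattern (the test case and URL are hypothetical):

    from qiita_pet.test.tornado_test_base import TestHandlerBase


    class TestSomeDBBackedHandler(TestHandlerBase):
        database = True  # setUp() cleans the test environment first

        def test_get(self):
            # Runs against a freshly reset test database.
            response = self.get('/some/db/backed/url')
            self.assertEqual(response.code, 200)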
diff --git a/qiita_pet/test/tornado_test_base.py b/qiita_pet/test/tornado_test_base.py
index 8c99a5c76..f071b2510 100644
--- a/qiita_pet/test/tornado_test_base.py
+++ b/qiita_pet/test/tornado_test_base.py
@@ -5,17 +5,14 @@
 from urllib.parse import urlencode
 
 from tornado.testing import AsyncHTTPTestCase
-from qiita_core.qiita_settings import qiita_config
 from qiita_pet.webserver import Application
 from qiita_pet.handlers.base_handlers import BaseHandler
-from qiita_db.sql_connection import SQLConnectionHandler
-from qiita_db.environment_manager import drop_and_rebuild_tst_database
+from qiita_db.environment_manager import clean_test_environment
 from qiita_db.user import User
 
 
 class TestHandlerBase(AsyncHTTPTestCase):
     database = False
-    conn_handler = SQLConnectionHandler()
     app = Application()
 
     def get_app(self):
@@ -25,18 +22,7 @@ def get_app(self):
 
     def setUp(self):
         if self.database:
-            # First, we check that we are not in a production environment
-            # It is possible that we are connecting to a production database
-            test_db = self.conn_handler.execute_fetchone(
-                "SELECT test FROM settings")[0]
-            # Or the loaded config file belongs to a production environment
-            if not qiita_config.test_environment or not test_db:
-                raise RuntimeError("Working in a production environment. Not "
-                                   "executing the tests to keep the production"
-                                   " database safe.")
-
-            # Drop the schema and rebuild the test database
-            drop_and_rebuild_tst_database(self.conn_handler)
+            clean_test_environment()
 
         super(TestHandlerBase, self).setUp()
 
diff --git a/qiita_pet/uimodules/prep_template_tab.py b/qiita_pet/uimodules/prep_template_tab.py
index 8ba3a6f7c..11d2660f3 100644
--- a/qiita_pet/uimodules/prep_template_tab.py
+++ b/qiita_pet/uimodules/prep_template_tab.py
@@ -24,6 +24,7 @@
 from qiita_pet.util import STATUS_STYLER
 from qiita_pet.handlers.util import download_link_or_path
 from .base_uimodule import BaseUIModule
+from qiita_core.util import execute_as_transaction
 
 
 filepath_types = [k.split('_', 1)[1].replace('_', ' ')
@@ -36,6 +37,7 @@
     per_sample_FASTQ=['forward seqs', 'reverse seqs'])
 
 
+@execute_as_transaction
 def _get_accessible_raw_data(user):
     """Retrieves a tuple of raw_data_id and one study title for that raw_data
 
@@ -50,6 +52,7 @@ def _get_accessible_raw_data(user):
     return d
 
 
+@execute_as_transaction
 def _template_generator(study, full_access):
     """Generates tuples of prep template information
 
@@ -75,6 +78,7 @@ def _template_generator(study, full_access):
 
 
 class PrepTemplateTab(BaseUIModule):
+    @execute_as_transaction
     def render(self, study, full_access):
         files = [f for _, f in get_files_from_uploads_folders(str(study.id))
                  if f.endswith(('txt', 'tsv'))]
@@ -105,6 +109,7 @@ def render(self, study, full_access):
 
 
 class PrepTemplateInfoTab(BaseUIModule):
+    @execute_as_transaction
     def render(self, study, prep_template, full_access, ena_terms,
                user_defined_terms):
         user = self.current_user
@@ -247,6 +252,7 @@ def render(self, study, prep_template, full_access, ena_terms,
 
 
 class RawDataInfoDiv(BaseUIModule):
+    @execute_as_transaction
     def render(self, raw_data_id, prep_template, study, files):
         rd = RawData(raw_data_id)
         raw_data_files = [(basename(fp), fp_type[4:])
@@ -282,6 +288,7 @@ def render(self, raw_data_id, prep_template, study, files):
 
 
 class EditInvestigationType(BaseUIModule):
+    @execute_as_transaction
     def render(self, ena_terms, user_defined_terms, prep_id, inv_type, ppd_id):
         return self.render_string(
             "study_description_templates/edit_investigation_type.html",
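The uimodule render methods decorated here, and in the three uimodule files that follow, each read several qiita_db objects while building a tab, so the decorator lets all of those reads share a single transaction. Minimal shape (the module name and template path are hypothetical):

    from qiita_core.util import execute_as_transaction
    from qiita_pet.uimodules.base_uimodule import BaseUIModule


    class ExampleTab(BaseUIModule):
        @execute_as_transaction
        def render(self, study):
            # Every DB read made while rendering runs in one transaction.
            return self.render_string(
                "study_description_templates/example_tab.html", study=study)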
diff --git a/qiita_pet/uimodules/preprocessed_data_tab.py b/qiita_pet/uimodules/preprocessed_data_tab.py
index 7973524c8..d80803038 100644
--- a/qiita_pet/uimodules/preprocessed_data_tab.py
+++ b/qiita_pet/uimodules/preprocessed_data_tab.py
@@ -6,6 +6,7 @@
 # The full license is in the file LICENSE, distributed with this software.
 # -----------------------------------------------------------------------------
 
+from qiita_core.util import execute_as_transaction
 from qiita_db.data import PreprocessedData
 from qiita_db.metadata_template import PrepTemplate
 from qiita_db.ontology import Ontology
@@ -16,6 +17,7 @@
 
 
 class PreprocessedDataTab(BaseUIModule):
+    @execute_as_transaction
     def render(self, study, full_access):
         ppd_gen = (PreprocessedData(ppd_id)
                    for ppd_id in study.preprocessed_data())
@@ -29,6 +31,7 @@ def render(self, study, full_access):
 
 
 class PreprocessedDataInfoTab(BaseUIModule):
+    @execute_as_transaction
     def render(self, study_id, preprocessed_data):
         user = self.current_user
         ppd_id = preprocessed_data.id
diff --git a/qiita_pet/uimodules/processed_data_tab.py b/qiita_pet/uimodules/processed_data_tab.py
index e545d9f9b..4e6e467be 100644
--- a/qiita_pet/uimodules/processed_data_tab.py
+++ b/qiita_pet/uimodules/processed_data_tab.py
@@ -6,6 +6,7 @@
 # The full license is in the file LICENSE, distributed with this software.
 # -----------------------------------------------------------------------------
 
+from qiita_core.util import execute_as_transaction
 from qiita_core.qiita_settings import qiita_config
 from qiita_db.data import ProcessedData
 from qiita_pet.util import STATUS_STYLER
@@ -13,6 +14,7 @@
 
 
 class ProcessedDataTab(BaseUIModule):
+    @execute_as_transaction
     def render(self, study, full_access, allow_approval, approval_deny_msg):
         pd_gen = (ProcessedData(pd_id)
                   for pd_id in sorted(study.processed_data()))
@@ -28,6 +30,7 @@ def render(self, study, full_access, allow_approval, approval_deny_msg):
 
 
 class ProcessedDataInfoTab(BaseUIModule):
+    @execute_as_transaction
     def render(self, study_id, processed_data, allow_approval,
                approval_deny_msg):
         user = self.current_user
diff --git a/qiita_pet/uimodules/study_information_tab.py b/qiita_pet/uimodules/study_information_tab.py
index b2e1c154e..7a21a46b2 100644
--- a/qiita_pet/uimodules/study_information_tab.py
+++ b/qiita_pet/uimodules/study_information_tab.py
@@ -11,6 +11,7 @@
 
 from future.utils import viewitems
 
+from qiita_core.util import execute_as_transaction
 from qiita_db.util import get_files_from_uploads_folders, get_data_types
 from qiita_db.study import StudyPerson
 from qiita_db.metadata_template import SampleTemplate
@@ -26,6 +27,7 @@
 
 
 class StudyInformationTab(BaseUIModule):
+    @execute_as_transaction
     def render(self, study):
         study_info = study.info
         id = study.id
diff --git a/qiita_pet/util.py b/qiita_pet/util.py
index 361993094..83e0e3631 100644
--- a/qiita_pet/util.py
+++ b/qiita_pet/util.py
@@ -23,6 +23,7 @@
 # -----------------------------------------------------------------------------
 
 from future.utils import viewitems
+from qiita_core.util import execute_as_transaction
 
 from qiita_db.reference import Reference
 
@@ -67,6 +68,7 @@ def clean_str(item):
     return str(item).replace(" ", "_").replace(":", "")
 
 
+@execute_as_transaction
 def generate_param_str(param):
     """Generate an html string with the parameter values