From 0ea6483ef0c312be5b2c2d2d494fffe5bf925ea0 Mon Sep 17 00:00:00 2001 From: Dan Allan Date: Thu, 11 Jul 2013 10:01:37 -0400 Subject: [PATCH] ENH #4163 Use SQLAlchemy for DB abstraction TST Import sqlalchemy on Travis. DOC add docstrings to read sql ENH read_sql connects via Connection, Engine, file path, or :memory: string CLN Separate legacy code into new file, and fallback so that all old tests pass. TST to use sqlachemy syntax in tests CLN sql into classes, legacy passes FIX few engine vs con calls CLN pep8 cleanup add postgres support for pandas.io.sql.get_schema WIP: cleaup of sql io module - imported correct SQLALCHEMY type, delete redundant PandasSQLWithCon TODO: renamed _engine_read_table, need to think of a better name. TODO: clean up get_conneciton function ENH: cleanup of SQL io TODO: check that legacy mode works TODO: run tests correctly enabled coerce_float option Cleanup and bug-fixing mainly on legacy mode sql. IMPORTANT - changed legacy to require connection rather than cursor. This is still not yet finalized. TODO: tests and doc Added Test coverage for basic functionality using in-memory SQLite database Simplified API by automatically distinguishing between engine and connection. Added warnings --- pandas/io/sql_legacy.py | 332 +++++++++++++++++++ pandas/io/tests/data/iris.csv | 151 +++++++++ pandas/io/tests/test_sql_legacy.py | 497 +++++++++++++++++++++++++++++ 3 files changed, 980 insertions(+) create mode 100644 pandas/io/sql_legacy.py create mode 100644 pandas/io/tests/data/iris.csv create mode 100644 pandas/io/tests/test_sql_legacy.py diff --git a/pandas/io/sql_legacy.py b/pandas/io/sql_legacy.py new file mode 100644 index 0000000000000..a8a5d968dd02d --- /dev/null +++ b/pandas/io/sql_legacy.py @@ -0,0 +1,332 @@ +""" +Collection of query wrappers / abstractions to both facilitate data +retrieval and to reduce dependency on DB-specific API. +""" +from datetime import datetime, date + +import numpy as np +import traceback + +from pandas.core.datetools import format as date_format +from pandas.core.api import DataFrame, isnull + +#------------------------------------------------------------------------------ +# Helper execution function + + +def execute(sql, con, retry=True, cur=None, params=None): + """ + Execute the given SQL query using the provided connection object. + + Parameters + ---------- + sql: string + Query to be executed + con: database connection instance + Database connection. Must implement PEP249 (Database API v2.0). + retry: bool + Not currently implemented + cur: database cursor, optional + Must implement PEP249 (Datbase API v2.0). If cursor is not provided, + one will be obtained from the database connection. + params: list or tuple, optional + List of parameters to pass to execute method. + + Returns + ------- + Cursor object + """ + try: + if cur is None: + cur = con.cursor() + + if params is None: + cur.execute(sql) + else: + cur.execute(sql, params) + return cur + except Exception: + try: + con.rollback() + except Exception: # pragma: no cover + pass + + print ('Error on sql %s' % sql) + raise + + +def _safe_fetch(cur): + try: + result = cur.fetchall() + if not isinstance(result, list): + result = list(result) + return result + except Exception, e: # pragma: no cover + excName = e.__class__.__name__ + if excName == 'OperationalError': + return [] + + +def tquery(sql, con=None, cur=None, retry=True): + """ + Returns list of tuples corresponding to each row in given sql + query. + + If only one column selected, then plain list is returned. + + Parameters + ---------- + sql: string + SQL query to be executed + con: SQLConnection or DB API 2.0-compliant connection + cur: DB API 2.0 cursor + + Provide a specific connection or a specific cursor if you are executing a + lot of sequential statements and want to commit outside. + """ + cur = execute(sql, con, cur=cur) + result = _safe_fetch(cur) + + if con is not None: + try: + cur.close() + con.commit() + except Exception as e: + excName = e.__class__.__name__ + if excName == 'OperationalError': # pragma: no cover + print ('Failed to commit, may need to restart interpreter') + else: + raise + + traceback.print_exc() + if retry: + return tquery(sql, con=con, retry=False) + + if result and len(result[0]) == 1: + # python 3 compat + result = list(list(zip(*result))[0]) + elif result is None: # pragma: no cover + result = [] + + return result + + +def uquery(sql, con=None, cur=None, retry=True, params=None): + """ + Does the same thing as tquery, but instead of returning results, it + returns the number of rows affected. Good for update queries. + """ + cur = execute(sql, con, cur=cur, retry=retry, params=params) + + result = cur.rowcount + try: + con.commit() + except Exception as e: + excName = e.__class__.__name__ + if excName != 'OperationalError': + raise + + traceback.print_exc() + if retry: + print ('Looks like your connection failed, reconnecting...') + return uquery(sql, con, retry=False) + return result + + +def read_frame(sql, con, index_col=None, coerce_float=True, params=None): + """ + Returns a DataFrame corresponding to the result set of the query + string. + + Optionally provide an index_col parameter to use one of the + columns as the index. Otherwise will be 0 to len(results) - 1. + + Parameters + ---------- + sql: string + SQL query to be executed + con: DB connection object, optional + index_col: string, optional + column name to use for the returned DataFrame object. + coerce_float : boolean, default True + Attempt to convert values to non-string, non-numeric objects (like + decimal.Decimal) to floating point, useful for SQL result sets + params: list or tuple, optional + List of parameters to pass to execute method. + """ + cur = execute(sql, con, params=params) + rows = _safe_fetch(cur) + columns = [col_desc[0] for col_desc in cur.description] + + cur.close() + con.commit() + + result = DataFrame.from_records(rows, columns=columns, + coerce_float=coerce_float) + + if index_col is not None: + result = result.set_index(index_col) + + return result + +frame_query = read_frame +read_sql = read_frame + + +def write_frame(frame, name, con, flavor='sqlite', if_exists='fail', **kwargs): + """ + Write records stored in a DataFrame to a SQL database. + + Parameters + ---------- + frame: DataFrame + name: name of SQL table + con: an open SQL database connection object + flavor: {'sqlite', 'mysql', 'oracle'}, default 'sqlite' + if_exists: {'fail', 'replace', 'append'}, default 'fail' + fail: If table exists, do nothing. + replace: If table exists, drop it, recreate it, and insert data. + append: If table exists, insert data. Create if does not exist. + """ + + if 'append' in kwargs: + import warnings + warnings.warn("append is deprecated, use if_exists instead", + FutureWarning) + if kwargs['append']: + if_exists='append' + else: + if_exists='fail' + exists = table_exists(name, con, flavor) + if if_exists == 'fail' and exists: + raise ValueError, "Table '%s' already exists." % name + + #create or drop-recreate if necessary + create = None + if exists and if_exists == 'replace': + create = "DROP TABLE %s" % name + elif not exists: + create = get_schema(frame, name, flavor) + + if create is not None: + cur = con.cursor() + cur.execute(create) + cur.close() + + cur = con.cursor() + # Replace spaces in DataFrame column names with _. + safe_names = [s.replace(' ', '_').strip() for s in frame.columns] + flavor_picker = {'sqlite' : _write_sqlite, + 'mysql' : _write_mysql} + + func = flavor_picker.get(flavor, None) + if func is None: + raise NotImplementedError + func(frame, name, safe_names, cur) + cur.close() + con.commit() + + +def _write_sqlite(frame, table, names, cur): + bracketed_names = ['[' + column + ']' for column in names] + col_names = ','.join(bracketed_names) + wildcards = ','.join(['?'] * len(names)) + insert_query = 'INSERT INTO %s (%s) VALUES (%s)' % ( + table, col_names, wildcards) + # pandas types are badly handled if there is only 1 column ( Issue #3628 ) + if not len(frame.columns )==1 : + data = [tuple(x) for x in frame.values] + else : + data = [tuple(x) for x in frame.values.tolist()] + cur.executemany(insert_query, data) + + +def _write_mysql(frame, table, names, cur): + bracketed_names = ['`' + column + '`' for column in names] + col_names = ','.join(bracketed_names) + wildcards = ','.join([r'%s'] * len(names)) + insert_query = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, col_names, wildcards) + data = [tuple(x) for x in frame.values] + cur.executemany(insert_query, data) + + +def table_exists(name, con, flavor): + flavor_map = { + 'sqlite': ("SELECT name FROM sqlite_master " + "WHERE type='table' AND name='%s';") % name, + 'mysql' : "SHOW TABLES LIKE '%s'" % name} + query = flavor_map.get(flavor, None) + if query is None: + raise NotImplementedError + return len(tquery(query, con)) > 0 + + +def get_sqltype(pytype, flavor): + sqltype = {'mysql': 'VARCHAR (63)', + 'sqlite': 'TEXT'} + + if issubclass(pytype, np.floating): + sqltype['mysql'] = 'FLOAT' + sqltype['sqlite'] = 'REAL' + + if issubclass(pytype, np.integer): + #TODO: Refine integer size. + sqltype['mysql'] = 'BIGINT' + sqltype['sqlite'] = 'INTEGER' + + if issubclass(pytype, np.datetime64) or pytype is datetime: + # Caution: np.datetime64 is also a subclass of np.number. + sqltype['mysql'] = 'DATETIME' + sqltype['sqlite'] = 'TIMESTAMP' + + if pytype is datetime.date: + sqltype['mysql'] = 'DATE' + sqltype['sqlite'] = 'TIMESTAMP' + + if issubclass(pytype, np.bool_): + sqltype['sqlite'] = 'INTEGER' + + return sqltype[flavor] + + +def get_schema(frame, name, flavor, keys=None): + "Return a CREATE TABLE statement to suit the contents of a DataFrame." + lookup_type = lambda dtype: get_sqltype(dtype.type, flavor) + # Replace spaces in DataFrame column names with _. + safe_columns = [s.replace(' ', '_').strip() for s in frame.dtypes.index] + column_types = zip(safe_columns, map(lookup_type, frame.dtypes)) + if flavor == 'sqlite': + columns = ',\n '.join('[%s] %s' % x for x in column_types) + else: + columns = ',\n '.join('`%s` %s' % x for x in column_types) + + keystr = '' + if keys is not None: + if isinstance(keys, basestring): + keys = (keys,) + keystr = ', PRIMARY KEY (%s)' % ','.join(keys) + template = """CREATE TABLE %(name)s ( + %(columns)s + %(keystr)s + );""" + create_statement = template % {'name': name, 'columns': columns, + 'keystr': keystr} + return create_statement + + +def sequence2dict(seq): + """Helper function for cx_Oracle. + + For each element in the sequence, creates a dictionary item equal + to the element and keyed by the position of the item in the list. + >>> sequence2dict(("Matt", 1)) + {'1': 'Matt', '2': 1} + + Source: + http://www.gingerandjohn.com/archives/2004/02/26/cx_oracle-executemany-example/ + """ + d = {} + for k,v in zip(range(1, 1 + len(seq)), seq): + d[str(k)] = v + return d diff --git a/pandas/io/tests/data/iris.csv b/pandas/io/tests/data/iris.csv new file mode 100644 index 0000000000000..c19b9c3688515 --- /dev/null +++ b/pandas/io/tests/data/iris.csv @@ -0,0 +1,151 @@ +SepalLength,SepalWidth,PetalLength,PetalWidth,Name +5.1,3.5,1.4,0.2,Iris-setosa +4.9,3.0,1.4,0.2,Iris-setosa +4.7,3.2,1.3,0.2,Iris-setosa +4.6,3.1,1.5,0.2,Iris-setosa +5.0,3.6,1.4,0.2,Iris-setosa +5.4,3.9,1.7,0.4,Iris-setosa +4.6,3.4,1.4,0.3,Iris-setosa +5.0,3.4,1.5,0.2,Iris-setosa +4.4,2.9,1.4,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +5.4,3.7,1.5,0.2,Iris-setosa +4.8,3.4,1.6,0.2,Iris-setosa +4.8,3.0,1.4,0.1,Iris-setosa +4.3,3.0,1.1,0.1,Iris-setosa +5.8,4.0,1.2,0.2,Iris-setosa +5.7,4.4,1.5,0.4,Iris-setosa +5.4,3.9,1.3,0.4,Iris-setosa +5.1,3.5,1.4,0.3,Iris-setosa +5.7,3.8,1.7,0.3,Iris-setosa +5.1,3.8,1.5,0.3,Iris-setosa +5.4,3.4,1.7,0.2,Iris-setosa +5.1,3.7,1.5,0.4,Iris-setosa +4.6,3.6,1.0,0.2,Iris-setosa +5.1,3.3,1.7,0.5,Iris-setosa +4.8,3.4,1.9,0.2,Iris-setosa +5.0,3.0,1.6,0.2,Iris-setosa +5.0,3.4,1.6,0.4,Iris-setosa +5.2,3.5,1.5,0.2,Iris-setosa +5.2,3.4,1.4,0.2,Iris-setosa +4.7,3.2,1.6,0.2,Iris-setosa +4.8,3.1,1.6,0.2,Iris-setosa +5.4,3.4,1.5,0.4,Iris-setosa +5.2,4.1,1.5,0.1,Iris-setosa +5.5,4.2,1.4,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +5.0,3.2,1.2,0.2,Iris-setosa +5.5,3.5,1.3,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +4.4,3.0,1.3,0.2,Iris-setosa +5.1,3.4,1.5,0.2,Iris-setosa +5.0,3.5,1.3,0.3,Iris-setosa +4.5,2.3,1.3,0.3,Iris-setosa +4.4,3.2,1.3,0.2,Iris-setosa +5.0,3.5,1.6,0.6,Iris-setosa +5.1,3.8,1.9,0.4,Iris-setosa +4.8,3.0,1.4,0.3,Iris-setosa +5.1,3.8,1.6,0.2,Iris-setosa +4.6,3.2,1.4,0.2,Iris-setosa +5.3,3.7,1.5,0.2,Iris-setosa +5.0,3.3,1.4,0.2,Iris-setosa +7.0,3.2,4.7,1.4,Iris-versicolor +6.4,3.2,4.5,1.5,Iris-versicolor +6.9,3.1,4.9,1.5,Iris-versicolor +5.5,2.3,4.0,1.3,Iris-versicolor +6.5,2.8,4.6,1.5,Iris-versicolor +5.7,2.8,4.5,1.3,Iris-versicolor +6.3,3.3,4.7,1.6,Iris-versicolor +4.9,2.4,3.3,1.0,Iris-versicolor +6.6,2.9,4.6,1.3,Iris-versicolor +5.2,2.7,3.9,1.4,Iris-versicolor +5.0,2.0,3.5,1.0,Iris-versicolor +5.9,3.0,4.2,1.5,Iris-versicolor +6.0,2.2,4.0,1.0,Iris-versicolor +6.1,2.9,4.7,1.4,Iris-versicolor +5.6,2.9,3.6,1.3,Iris-versicolor +6.7,3.1,4.4,1.4,Iris-versicolor +5.6,3.0,4.5,1.5,Iris-versicolor +5.8,2.7,4.1,1.0,Iris-versicolor +6.2,2.2,4.5,1.5,Iris-versicolor +5.6,2.5,3.9,1.1,Iris-versicolor +5.9,3.2,4.8,1.8,Iris-versicolor +6.1,2.8,4.0,1.3,Iris-versicolor +6.3,2.5,4.9,1.5,Iris-versicolor +6.1,2.8,4.7,1.2,Iris-versicolor +6.4,2.9,4.3,1.3,Iris-versicolor +6.6,3.0,4.4,1.4,Iris-versicolor +6.8,2.8,4.8,1.4,Iris-versicolor +6.7,3.0,5.0,1.7,Iris-versicolor +6.0,2.9,4.5,1.5,Iris-versicolor +5.7,2.6,3.5,1.0,Iris-versicolor +5.5,2.4,3.8,1.1,Iris-versicolor +5.5,2.4,3.7,1.0,Iris-versicolor +5.8,2.7,3.9,1.2,Iris-versicolor +6.0,2.7,5.1,1.6,Iris-versicolor +5.4,3.0,4.5,1.5,Iris-versicolor +6.0,3.4,4.5,1.6,Iris-versicolor +6.7,3.1,4.7,1.5,Iris-versicolor +6.3,2.3,4.4,1.3,Iris-versicolor +5.6,3.0,4.1,1.3,Iris-versicolor +5.5,2.5,4.0,1.3,Iris-versicolor +5.5,2.6,4.4,1.2,Iris-versicolor +6.1,3.0,4.6,1.4,Iris-versicolor +5.8,2.6,4.0,1.2,Iris-versicolor +5.0,2.3,3.3,1.0,Iris-versicolor +5.6,2.7,4.2,1.3,Iris-versicolor +5.7,3.0,4.2,1.2,Iris-versicolor +5.7,2.9,4.2,1.3,Iris-versicolor +6.2,2.9,4.3,1.3,Iris-versicolor +5.1,2.5,3.0,1.1,Iris-versicolor +5.7,2.8,4.1,1.3,Iris-versicolor +6.3,3.3,6.0,2.5,Iris-virginica +5.8,2.7,5.1,1.9,Iris-virginica +7.1,3.0,5.9,2.1,Iris-virginica +6.3,2.9,5.6,1.8,Iris-virginica +6.5,3.0,5.8,2.2,Iris-virginica +7.6,3.0,6.6,2.1,Iris-virginica +4.9,2.5,4.5,1.7,Iris-virginica +7.3,2.9,6.3,1.8,Iris-virginica +6.7,2.5,5.8,1.8,Iris-virginica +7.2,3.6,6.1,2.5,Iris-virginica +6.5,3.2,5.1,2.0,Iris-virginica +6.4,2.7,5.3,1.9,Iris-virginica +6.8,3.0,5.5,2.1,Iris-virginica +5.7,2.5,5.0,2.0,Iris-virginica +5.8,2.8,5.1,2.4,Iris-virginica +6.4,3.2,5.3,2.3,Iris-virginica +6.5,3.0,5.5,1.8,Iris-virginica +7.7,3.8,6.7,2.2,Iris-virginica +7.7,2.6,6.9,2.3,Iris-virginica +6.0,2.2,5.0,1.5,Iris-virginica +6.9,3.2,5.7,2.3,Iris-virginica +5.6,2.8,4.9,2.0,Iris-virginica +7.7,2.8,6.7,2.0,Iris-virginica +6.3,2.7,4.9,1.8,Iris-virginica +6.7,3.3,5.7,2.1,Iris-virginica +7.2,3.2,6.0,1.8,Iris-virginica +6.2,2.8,4.8,1.8,Iris-virginica +6.1,3.0,4.9,1.8,Iris-virginica +6.4,2.8,5.6,2.1,Iris-virginica +7.2,3.0,5.8,1.6,Iris-virginica +7.4,2.8,6.1,1.9,Iris-virginica +7.9,3.8,6.4,2.0,Iris-virginica +6.4,2.8,5.6,2.2,Iris-virginica +6.3,2.8,5.1,1.5,Iris-virginica +6.1,2.6,5.6,1.4,Iris-virginica +7.7,3.0,6.1,2.3,Iris-virginica +6.3,3.4,5.6,2.4,Iris-virginica +6.4,3.1,5.5,1.8,Iris-virginica +6.0,3.0,4.8,1.8,Iris-virginica +6.9,3.1,5.4,2.1,Iris-virginica +6.7,3.1,5.6,2.4,Iris-virginica +6.9,3.1,5.1,2.3,Iris-virginica +5.8,2.7,5.1,1.9,Iris-virginica +6.8,3.2,5.9,2.3,Iris-virginica +6.7,3.3,5.7,2.5,Iris-virginica +6.7,3.0,5.2,2.3,Iris-virginica +6.3,2.5,5.0,1.9,Iris-virginica +6.5,3.0,5.2,2.0,Iris-virginica +6.2,3.4,5.4,2.3,Iris-virginica +5.9,3.0,5.1,1.8,Iris-virginica \ No newline at end of file diff --git a/pandas/io/tests/test_sql_legacy.py b/pandas/io/tests/test_sql_legacy.py new file mode 100644 index 0000000000000..3c6e992097d30 --- /dev/null +++ b/pandas/io/tests/test_sql_legacy.py @@ -0,0 +1,497 @@ +from __future__ import with_statement +from pandas.compat import StringIO +import unittest +import sqlite3 +import sys + +import warnings + +import nose + +import numpy as np + +from pandas.core.datetools import format as date_format +from pandas.core.api import DataFrame, isnull +from pandas.compat import StringIO, range, lrange +import pandas.compat as compat + +import pandas.io.sql as sql +from pandas.io.sql import DatabaseError +import pandas.util.testing as tm +from pandas import Series, Index, DataFrame +from datetime import datetime + +_formatters = { + datetime: lambda dt: "'%s'" % date_format(dt), + str: lambda x: "'%s'" % x, + np.str_: lambda x: "'%s'" % x, + compat.text_type: lambda x: "'%s'" % x, + compat.binary_type: lambda x: "'%s'" % x, + float: lambda x: "%.8f" % x, + int: lambda x: "%s" % x, + type(None): lambda x: "NULL", + np.float64: lambda x: "%.10f" % x, + bool: lambda x: "'%s'" % x, +} + +def format_query(sql, *args): + """ + + """ + processed_args = [] + for arg in args: + if isinstance(arg, float) and isnull(arg): + arg = None + + formatter = _formatters[type(arg)] + processed_args.append(formatter(arg)) + + return sql % tuple(processed_args) + +def _skip_if_no_MySQLdb(): + try: + import MySQLdb + except ImportError: + raise nose.SkipTest('MySQLdb not installed, skipping') + +class TestSQLite(unittest.TestCase): + + def setUp(self): + self.db = sqlite3.connect(':memory:') + + def test_basic(self): + frame = tm.makeTimeDataFrame() + self._check_roundtrip(frame) + + def test_write_row_by_row(self): + frame = tm.makeTimeDataFrame() + frame.ix[0, 0] = np.nan + create_sql = sql.get_schema(frame, 'test', 'sqlite') + cur = self.db.cursor() + cur.execute(create_sql) + + cur = self.db.cursor() + + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + for idx, row in frame.iterrows(): + fmt_sql = format_query(ins, *row) + sql.tquery(fmt_sql, cur=cur) + + self.db.commit() + + result = sql.read_frame("select * from test", con=self.db) + result.index = frame.index + tm.assert_frame_equal(result, frame) + + def test_execute(self): + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'sqlite') + cur = self.db.cursor() + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (?, ?, ?, ?)" + + row = frame.ix[0] + sql.execute(ins, self.db, params=tuple(row)) + self.db.commit() + + result = sql.read_frame("select * from test", self.db) + result.index = frame.index[:1] + tm.assert_frame_equal(result, frame[:1]) + + def test_schema(self): + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'sqlite') + lines = create_sql.splitlines() + for l in lines: + tokens = l.split(' ') + if len(tokens) == 2 and tokens[0] == 'A': + self.assert_(tokens[1] == 'DATETIME') + + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'sqlite', keys=['A', 'B'],) + lines = create_sql.splitlines() + self.assert_('PRIMARY KEY (A,B)' in create_sql) + cur = self.db.cursor() + cur.execute(create_sql) + + def test_execute_fail(self): + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + cur = self.db.cursor() + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) + + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.execute, + 'INSERT INTO test VALUES("foo", "bar", 7)', + self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_execute_closed_connection(self): + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a, b) + ); + """ + cur = self.db.cursor() + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + self.db.close() + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.tquery, "select * from test", + con=self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_na_roundtrip(self): + pass + + def _check_roundtrip(self, frame): + sql.write_frame(frame, name='test_table', con=self.db) + result = sql.read_frame("select * from test_table", self.db) + + # HACK! Change this once indexes are handled properly. + result.index = frame.index + + expected = frame + tm.assert_frame_equal(result, expected) + + frame['txt'] = ['a'] * len(frame) + frame2 = frame.copy() + frame2['Idx'] = Index(lrange(len(frame2))) + 10 + sql.write_frame(frame2, name='test_table2', con=self.db) + result = sql.read_frame("select * from test_table2", self.db, + index_col='Idx') + expected = frame.copy() + expected.index = Index(lrange(len(frame2))) + 10 + expected.index.name = 'Idx' + print(expected.index.names) + print(result.index.names) + tm.assert_frame_equal(expected, result) + + def test_tquery(self): + frame = tm.makeTimeDataFrame() + sql.write_frame(frame, name='test_table', con=self.db) + result = sql.tquery("select A from test_table", self.db) + expected = frame.A + result = Series(result, frame.index) + tm.assert_series_equal(result, expected) + + try: + sys.stdout = StringIO() + self.assertRaises(DatabaseError, sql.tquery, + 'select * from blah', con=self.db) + + self.assertRaises(DatabaseError, sql.tquery, + 'select * from blah', con=self.db, retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_uquery(self): + frame = tm.makeTimeDataFrame() + sql.write_frame(frame, name='test_table', con=self.db) + stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' + self.assertEqual(sql.uquery(stmt, con=self.db), 1) + + try: + sys.stdout = StringIO() + + self.assertRaises(DatabaseError, sql.tquery, + 'insert into blah values (1)', con=self.db) + + self.assertRaises(DatabaseError, sql.tquery, + 'insert into blah values (1)', con=self.db, + retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_keyword_as_column_names(self): + ''' + ''' + df = DataFrame({'From':np.ones(5)}) + sql.write_frame(df, con = self.db, name = 'testkeywords') + + def test_onecolumn_of_integer(self): + ''' + GH 3628 + a column_of_integers dataframe should transfer well to sql + ''' + mono_df=DataFrame([1 , 2], columns=['c0']) + sql.write_frame(mono_df, con = self.db, name = 'mono_df') + # computing the sum via sql + con_x=self.db + the_sum=sum([my_c0[0] for my_c0 in con_x.execute("select * from mono_df")]) + # it should not fail, and gives 3 ( Issue #3628 ) + self.assertEqual(the_sum , 3) + + result = sql.read_frame("select * from mono_df",con_x) + tm.assert_frame_equal(result,mono_df) + + +class TestMySQL(unittest.TestCase): + + def setUp(self): + _skip_if_no_MySQLdb() + import MySQLdb + try: + # Try Travis defaults. + # No real user should allow root access with a blank password. + self.db = MySQLdb.connect(host='localhost', user='root', passwd='', + db='pandas_nosetest') + except: + pass + else: + return + try: + self.db = MySQLdb.connect(read_default_group='pandas') + except MySQLdb.ProgrammingError as e: + raise nose.SkipTest( + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") + except MySQLdb.Error as e: + raise nose.SkipTest( + "Cannot connect to database. " + "Create a group of connection parameters under the heading " + "[pandas] in your system's mysql default file, " + "typically located at ~/.my.cnf or /etc/.my.cnf. ") + + def test_basic(self): + _skip_if_no_MySQLdb() + frame = tm.makeTimeDataFrame() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "For more robust support.*") + self._check_roundtrip(frame) + + def test_write_row_by_row(self): + _skip_if_no_MySQLdb() + frame = tm.makeTimeDataFrame() + frame.ix[0, 0] = np.nan + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = sql.get_schema(frame, 'test', 'mysql') + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + for idx, row in frame.iterrows(): + fmt_sql = format_query(ins, *row) + sql.tquery(fmt_sql, cur=cur) + + self.db.commit() + + result = sql.read_frame("select * from test", con=self.db) + result.index = frame.index + tm.assert_frame_equal(result, frame) + + def test_execute(self): + _skip_if_no_MySQLdb() + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = sql.get_schema(frame, 'test', 'mysql') + cur = self.db.cursor() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unknown table.*") + cur.execute(drop_sql) + cur.execute(create_sql) + ins = "INSERT INTO test VALUES (%s, %s, %s, %s)" + + row = frame.ix[0] + sql.execute(ins, self.db, params=tuple(row)) + self.db.commit() + + result = sql.read_frame("select * from test", self.db) + result.index = frame.index[:1] + tm.assert_frame_equal(result, frame[:1]) + + def test_schema(self): + _skip_if_no_MySQLdb() + frame = tm.makeTimeDataFrame() + create_sql = sql.get_schema(frame, 'test', 'mysql') + lines = create_sql.splitlines() + for l in lines: + tokens = l.split(' ') + if len(tokens) == 2 and tokens[0] == 'A': + self.assert_(tokens[1] == 'DATETIME') + + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = sql.get_schema(frame, 'test', 'mysql', keys=['A', 'B'],) + lines = create_sql.splitlines() + self.assert_('PRIMARY KEY (A,B)' in create_sql) + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + + def test_execute_fail(self): + _skip_if_no_MySQLdb() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a(5), b(5)) + ); + """ + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + sql.execute('INSERT INTO test VALUES("foo", "baz", 2.567)', self.db) + + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.execute, + 'INSERT INTO test VALUES("foo", "bar", 7)', + self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_execute_closed_connection(self): + _skip_if_no_MySQLdb() + drop_sql = "DROP TABLE IF EXISTS test" + create_sql = """ + CREATE TABLE test + ( + a TEXT, + b TEXT, + c REAL, + PRIMARY KEY (a(5), b(5)) + ); + """ + cur = self.db.cursor() + cur.execute(drop_sql) + cur.execute(create_sql) + + sql.execute('INSERT INTO test VALUES("foo", "bar", 1.234)', self.db) + self.db.close() + try: + sys.stdout = StringIO() + self.assertRaises(Exception, sql.tquery, "select * from test", + con=self.db) + finally: + sys.stdout = sys.__stdout__ + + def test_na_roundtrip(self): + _skip_if_no_MySQLdb() + pass + + def _check_roundtrip(self, frame): + _skip_if_no_MySQLdb() + drop_sql = "DROP TABLE IF EXISTS test_table" + cur = self.db.cursor() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unknown table.*") + cur.execute(drop_sql) + sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + result = sql.read_frame("select * from test_table", self.db) + + # HACK! Change this once indexes are handled properly. + result.index = frame.index + result.index.name = frame.index.name + + expected = frame + tm.assert_frame_equal(result, expected) + + frame['txt'] = ['a'] * len(frame) + frame2 = frame.copy() + index = Index(lrange(len(frame2))) + 10 + frame2['Idx'] = index + drop_sql = "DROP TABLE IF EXISTS test_table2" + cur = self.db.cursor() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "Unknown table.*") + cur.execute(drop_sql) + sql.write_frame(frame2, name='test_table2', con=self.db, flavor='mysql') + result = sql.read_frame("select * from test_table2", self.db, + index_col='Idx') + expected = frame.copy() + + # HACK! Change this once indexes are handled properly. + expected.index = index + expected.index.names = result.index.names + tm.assert_frame_equal(expected, result) + + def test_tquery(self): + try: + import MySQLdb + except ImportError: + raise nose.SkipTest + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test_table" + cur = self.db.cursor() + cur.execute(drop_sql) + sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + result = sql.tquery("select A from test_table", self.db) + expected = frame.A + result = Series(result, frame.index) + tm.assert_series_equal(result, expected) + + try: + sys.stdout = StringIO() + self.assertRaises(DatabaseError, sql.tquery, + 'select * from blah', con=self.db) + + self.assertRaises(DatabaseError, sql.tquery, + 'select * from blah', con=self.db, retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_uquery(self): + try: + import MySQLdb + except ImportError: + raise nose.SkipTest + frame = tm.makeTimeDataFrame() + drop_sql = "DROP TABLE IF EXISTS test_table" + cur = self.db.cursor() + cur.execute(drop_sql) + sql.write_frame(frame, name='test_table', con=self.db, flavor='mysql') + stmt = 'INSERT INTO test_table VALUES(2.314, -123.1, 1.234, 2.3)' + self.assertEqual(sql.uquery(stmt, con=self.db), 1) + + try: + sys.stdout = StringIO() + + self.assertRaises(DatabaseError, sql.tquery, + 'insert into blah values (1)', con=self.db) + + self.assertRaises(DatabaseError, sql.tquery, + 'insert into blah values (1)', con=self.db, + retry=True) + finally: + sys.stdout = sys.__stdout__ + + def test_keyword_as_column_names(self): + ''' + ''' + _skip_if_no_MySQLdb() + df = DataFrame({'From':np.ones(5)}) + sql.write_frame(df, name='testkeywords', con=self.db, + if_exists='replace', flavor='mysql') + +if __name__ == '__main__': + # unittest.main() + # nose.runmodule(argv=[__file__,'-vvs','-x', '--pdb-failure'], + # exit=False) + nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'], + exit=False)