ENH #4163 Fix bug in index + parse date interaction, added test case for problem
mangecoeur committed Feb 6, 2014
1 parent f156e81 commit 32b493a
Showing 2 changed files with 120 additions and 80 deletions.
159 changes: 87 additions & 72 deletions pandas/io/sql.py
@@ -23,7 +23,7 @@ class DatabaseError(IOError):


#------------------------------------------------------------------------------
# Helper execution functions
# Helper functions

def _convert_params(sql, params):
"""convert sql and params args to DBAPI2.0 compliant format"""
@@ -33,6 +33,47 @@ def _convert_params(sql, params):
return args


def _safe_col_name(col_name):
#TODO: probably want to forbid database reserved names, such as "database"
return col_name.strip().replace(' ', '_')


def _handle_date_column(col, format=None):
if isinstance(format, dict):
return to_datetime(col, **format)
else:
if format in ['D', 's', 'ms', 'us', 'ns']:
return to_datetime(col, coerce=True, unit=format)
elif issubclass(col.dtype.type, np.floating) or issubclass(col.dtype.type, np.integer):
# parse dates as timestamp
format = 's' if format is None else format
return to_datetime(col, coerce=True, unit=format)
else:
return to_datetime(col, coerce=True, format=format)
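
A minimal sketch of what the three branches above amount to (sample values are invented; `coerce=True` in the helper is the pandas 0.13-era spelling of what later became `errors='coerce'`):

    import pandas as pd

    # Integer epochs take the unit branch: to_datetime(col, coerce=True, unit='s')
    pd.to_datetime(pd.Series([1357084800]), unit='s')    # 2013-01-02 00:00:00

    # Strings take the format branch
    pd.to_datetime(pd.Series(['2013-01-02']), format='%Y-%m-%d')

    # A dict is unpacked straight into to_datetime as keyword arguments
    pd.to_datetime(pd.Series(['02/01/2013']), dayfirst=True)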


def _parse_date_columns(data_frame, parse_dates):
""" Force non-datetime columns to be read as such.
Supports both string formatted and integer timestamp columns
"""
# handle non-list entries for parse_dates gracefully
if parse_dates is True or parse_dates is None or parse_dates is False:
parse_dates = []

if not hasattr(parse_dates, '__iter__'):
parse_dates = [parse_dates]

for col_name in parse_dates:
df_col = data_frame[col_name]
try:
fmt = parse_dates[col_name]
except TypeError:
fmt = None
data_frame[col_name] = _handle_date_column(df_col, format=fmt)

return data_frame
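
A sketch of the shapes parse_dates accepts here (column names are illustrative). Note the trick in the loop above: indexing a list with a string column name raises TypeError, so fmt falls back to None, while indexing a dict yields the per-column format:

    _parse_date_columns(df, ['DateCol'])                       # default parsing
    _parse_date_columns(df, 'DateCol')                         # bare name, wrapped in a list
    _parse_date_columns(df, {'DateCol': '%Y-%m-%d %H:%M:%S'})  # explicit strftime format
    _parse_date_columns(df, {'IntDateCol': 's'})               # integer epochs, unit seconds
    _parse_date_columns(df, True)                              # treated as "no columns"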


def execute(sql, con, cur=None, params=None, flavor='sqlite'):
"""
Execute the given SQL query using the provided connection object.
@@ -44,7 +85,7 @@ def execute(sql, con, cur=None, params=None, flavor='sqlite'):
con: SQLAlchemy engine or DBAPI2 connection (legacy mode)
Using SQLAlchemy makes it possible to use any DB supported by that
library.
If a DBAPI2 object is given, a supported SQL flavor must also be provided
If a DBAPI2 object, a supported SQL flavor must also be provided
cur: deprecated, cursor is obtained from connection
params: list or tuple, optional
List of parameters to pass to execute method.
@@ -283,9 +324,11 @@ def pandasSQL_builder(con, flavor=None, meta=None):
return PandasSQLAlchemy(con, meta=meta)
else:
warnings.warn(
"Not an SQLAlchemy engine, attempting to use as legacy DBAPI connection")
"""Not an SQLAlchemy engine,
attempting to use as legacy DBAPI connection""")
if flavor is None:
raise ValueError("""PandasSQL must be created with an SQLAlchemy engine
raise ValueError(
"""PandasSQL must be created with an SQLAlchemy engine
or a DBAPI2 connection and SQL flavour""")
else:
return PandasSQLLegacy(con, flavor)
@@ -298,36 +341,16 @@ def pandasSQL_builder(con, flavor=None, meta=None):
return PandasSQLLegacy(con, flavor)
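
The builder's two entry modes, sketched with in-memory databases (illustrative only):

    import sqlite3
    from sqlalchemy import create_engine

    engine = create_engine('sqlite:///:memory:')
    pandasSQL_builder(engine)                   # -> PandasSQLAlchemy

    conn = sqlite3.connect(':memory:')
    pandasSQL_builder(conn, flavor='sqlite')    # -> PandasSQLLegacy, with a warning
    pandasSQL_builder(conn)                     # -> ValueError: a flavor is required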


def _safe_col_name(col_name):
return col_name.strip().replace(' ', '_')


def _parse_date_column(col, format=None):
if isinstance(format, dict):
return to_datetime(col, **format)
else:
if format in ['D', 's', 'ms', 'us', 'ns']:
return to_datetime(col, coerce=True, unit=format)
elif issubclass(col.dtype.type, np.floating) or issubclass(col.dtype.type, np.integer):
# parse dates as timestamp
format = 's' if format is None else format
return to_datetime(col, coerce=True, unit=format)
else:
return to_datetime(col, coerce=True, format=format)


def _frame_from_data_and_columns(data, columns, index_col=None,
coerce_float=True):
df = DataFrame.from_records(
data, columns=columns, coerce_float=coerce_float)
if index_col is not None:
df.set_index(index_col, inplace=True)
return df


class PandasSQLTable(PandasObject):

def __init__(self, name, pandas_sql_engine, frame=None, index=True, if_exists='fail', prefix='pandas'):
""" For mapping Pandas tables to SQL tables.
Uses fact that table is reflected by SQLAlchemy to
do better type conversions.
Also holds various flags needed to avoid having to
pass them between functions all the time.
"""
# TODO: support for multiIndex
def __init__(self, name, pandas_sql_engine, frame=None, index=True,
if_exists='fail', prefix='pandas'):
self.name = name
self.pd_sql = pandas_sql_engine
self.prefix = prefix
@@ -400,13 +423,15 @@ def read(self, coerce_float=True, parse_dates=None, columns=None):
data = result.fetchall()
column_names = result.keys()

self.frame = _frame_from_data_and_columns(data, column_names,
index_col=self.index,
coerce_float=coerce_float)
self.frame = DataFrame.from_records(
data, columns=column_names, coerce_float=coerce_float)

self._harmonize_columns(parse_dates=parse_dates)

# Assume that if the index was in prefix_index format, we gave it a name
if self.index is not None:
self.frame.set_index(self.index, inplace=True)

# Assume if the index in prefix_index format, we gave it a name
# and should return it nameless
if self.index == self.prefix + '_index':
self.frame.index.name = None
@@ -442,13 +467,14 @@ def _create_table_statement(self):
return Table(self.name, self.pd_sql.meta, *columns)

def _harmonize_columns(self, parse_dates=None):
""" Make a data_frame's column type align with an sql_table column types
""" Make a data_frame's column type align with an sql_table
column types
Need to work around limited NA value support.
Floats are always fine, ints must always
be floats if there are Null values.
Booleans are hard because converting bool column with None replaces
all Nones with false. Therefore only convert bool if there are no NA
values.
all Nones with false. Therefore only convert bool if there are no
NA values.
Datetimes should already be converted
to np.datetime if supported, but here we also force conversion
if required
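
The NA constraints this docstring describes are plain pandas/NumPy behaviour; a quick sketch (not part of the patch):

    import pandas as pd

    pd.Series([1, 2, None]).dtype             # float64: NaN forces ints to float
    pd.Series([True, None]).astype(bool)[1]   # False: the None is silently lost,
                                              # hence bool columns are only cast
                                              # when they contain no NA values
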
@@ -469,7 +495,7 @@ def _harmonize_columns(self, parse_dates=None):

if col_type is datetime or col_type is date:
if not issubclass(df_col.dtype.type, np.datetime64):
self.frame[col_name] = _parse_date_column(df_col)
self.frame[col_name] = _handle_date_column(df_col)

elif col_type is float:
# floats support NA, can always convert!
@@ -486,7 +512,7 @@ def _harmonize_columns(self, parse_dates=None):
fmt = parse_dates[col_name]
except TypeError:
fmt = None
self.frame[col_name] = _parse_date_column(
self.frame[col_name] = _handle_date_column(
df_col, format=fmt)

except KeyError:
@@ -543,27 +569,6 @@ def to_sql(self, *args, **kwargs):
raise ValueError(
"PandasSQL must be created with an SQLAlchemy engine or connection+sql flavor")

def _parse_date_columns(self, data_frame, parse_dates):
""" Force non-datetime columns to be read as such.
Supports both string formatted and integer timestamp columns
"""
# handle non-list entries for parse_dates gracefully
if parse_dates is True or parse_dates is None or parse_dates is False:
parse_dates = []

if not hasattr(parse_dates, '__iter__'):
parse_dates = [parse_dates]

for col_name in parse_dates:
df_col = data_frame[col_name]
try:
fmt = parse_dates[col_name]
except TypeError:
fmt = None
data_frame[col_name] = _parse_date_column(df_col, format=fmt)

return data_frame


class PandasSQLAlchemy(PandasSQL):

@@ -593,17 +598,23 @@ def uquery(self, *args, **kwargs):
result = self.execute(*args, **kwargs)
return result.rowcount

def read_sql(self, sql, index_col=None, coerce_float=True, parse_dates=None, params=None):
def read_sql(self, sql, index_col=None, coerce_float=True,
parse_dates=None, params=None):
args = _convert_params(sql, params)

result = self.execute(*args)
data = result.fetchall()
columns = result.keys()

data_frame = _frame_from_data_and_columns(data, columns,
index_col=index_col,
coerce_float=coerce_float)
data_frame = DataFrame.from_records(
data, columns=columns, coerce_float=coerce_float)

_parse_date_columns(data_frame, parse_dates)

if index_col is not None:
data_frame.set_index(index_col, inplace=True)

return self._parse_date_columns(data_frame, parse_dates)
return data_frame
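
This reordering is the heart of the fix: dates are parsed while the column is still reachable as a column, and only then is it promoted to the index. Previously _frame_from_data_and_columns set the index first, so a column named in both parse_dates and index_col could never be converted. Roughly (data and column names hypothetical):

    # Sketch of the old, broken order:
    df = DataFrame.from_records(data, columns=columns)
    df.set_index('DateCol', inplace=True)
    df['DateCol']    # KeyError: 'DateCol' is now the index, so
                     # parse_dates=['DateCol'] never reached it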

def to_sql(self, frame, name, if_exists='fail', index=True):
table = PandasSQLTable(
@@ -818,10 +829,14 @@ def read_sql(self, sql, index_col=None, coerce_float=True, params=None,
data = self._fetchall_as_list(cursor)
cursor.close()

data_frame = _frame_from_data_and_columns(data, columns,
index_col=index_col,
coerce_float=coerce_float)
return self._parse_date_columns(data_frame, parse_dates=parse_dates)
data_frame = DataFrame.from_records(
data, columns=columns, coerce_float=coerce_float)

_parse_date_columns(data_frame, parse_dates)

if index_col is not None:
data_frame.set_index(index_col, inplace=True)
return data_frame

def _fetchall_as_list(self, cur):
result = cur.fetchall()
41 changes: 33 additions & 8 deletions pandas/io/tests/test_sql.py
@@ -215,7 +215,7 @@ def _roundtrip(self):
result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip')

result.set_index('pandas_index', inplace=True)
#result.index.astype(int)
# result.index.astype(int)

result.index.name = None

@@ -327,7 +327,9 @@ def test_roundtrip(self):
sql.to_sql(self.test_frame1, 'test_frame_roundtrip',
con=self.conn, flavor='sqlite')
result = sql.read_sql(
'SELECT * FROM test_frame_roundtrip', con=self.conn, flavor='sqlite')
'SELECT * FROM test_frame_roundtrip',
con=self.conn,
flavor='sqlite')

# HACK!
result.index = self.test_frame1.index
@@ -355,28 +357,51 @@ def test_date_parsing(self):
df = sql.read_sql(
"SELECT * FROM types_test_data", self.conn, flavor='sqlite')
self.assertFalse(
issubclass(df.DateCol.dtype.type, np.datetime64), "DateCol loaded with incorrect type")
issubclass(df.DateCol.dtype.type, np.datetime64),
"DateCol loaded with incorrect type")

df = sql.read_sql("SELECT * FROM types_test_data",
self.conn, flavor='sqlite', parse_dates=['DateCol'])
self.assertTrue(
issubclass(df.DateCol.dtype.type, np.datetime64), "DateCol loaded with incorrect type")
issubclass(df.DateCol.dtype.type, np.datetime64),
"DateCol loaded with incorrect type")

df = sql.read_sql("SELECT * FROM types_test_data", self.conn,
flavor='sqlite', parse_dates={'DateCol': '%Y-%m-%d %H:%M:%S'})
flavor='sqlite',
parse_dates={'DateCol': '%Y-%m-%d %H:%M:%S'})
self.assertTrue(
issubclass(df.DateCol.dtype.type, np.datetime64), "DateCol loaded with incorrect type")
issubclass(df.DateCol.dtype.type, np.datetime64),
"DateCol loaded with incorrect type")

df = sql.read_sql("SELECT * FROM types_test_data",
self.conn, flavor='sqlite', parse_dates=['IntDateCol'])
self.conn, flavor='sqlite',
parse_dates=['IntDateCol'])

self.assertTrue(issubclass(df.IntDateCol.dtype.type, np.datetime64),
"IntDateCol loaded with incorrect type")

df = sql.read_sql("SELECT * FROM types_test_data",
self.conn, flavor='sqlite', parse_dates={'IntDateCol': 's'})
self.conn, flavor='sqlite',
parse_dates={'IntDateCol': 's'})

self.assertTrue(issubclass(df.IntDateCol.dtype.type, np.datetime64),
"IntDateCol loaded with incorrect type")

def test_date_and_index(self):
""" Test case where same column appears in parse_date and index_col"""

df = sql.read_sql("SELECT * FROM types_test_data",
self.conn, flavor='sqlite',
parse_dates=['DateCol', 'IntDateCol'],
index_col='DateCol')
self.assertTrue(
issubclass(df.index.dtype.type, np.datetime64),
"DateCol loaded with incorrect type")

self.assertTrue(
issubclass(df.IntDateCol.dtype.type, np.datetime64),
"IntDateCol loaded with incorrect type")


class TestSQLAlchemy(PandasSQLTest):

