Merge pull request #8973 from artemyk/notnulldtype_sql
ENH: Infer dtype from non-nulls when pushing to SQL
jorisvandenbossche committed Dec 8, 2014
2 parents 406c84d + ffc5097 commit 67ec0a8
Showing 3 changed files with 118 additions and 54 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.15.2.txt
@@ -115,6 +115,7 @@ Enhancements
- ``to_datetime`` gains an ``exact`` keyword to allow a format to not require an exact match for a provided format string (if it is ``False``). ``exact`` defaults to ``True`` (meaning that exact matching is still the default) (:issue:`8904`)
- Added ``axvlines`` boolean option to the ``parallel_coordinates`` plot function; it determines whether vertical lines are drawn and defaults to ``True``
- Added ability to read table footers to read_html (:issue:`8552`)
- ``to_sql`` now infers datatypes of non-NA values for columns that contain NA values and have dtype ``object`` (:issue:`8778`).
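
A minimal usage sketch of the behavior described in this entry (assumes SQLAlchemy with an in-memory SQLite engine; the table name is illustrative):

import sqlalchemy
import pandas as pd

engine = sqlalchemy.create_engine('sqlite://')   # in-memory SQLite

# Object-dtype columns holding typed values plus NA previously fell back to TEXT;
# the column type is now inferred from the non-NA values instead.
df = pd.DataFrame({'Int': pd.Series([1, None], dtype='object'),
                   'Float': pd.Series([1.1, None])})
df.to_sql('notnull_dtype_example', engine)

meta = sqlalchemy.MetaData()
meta.reflect(bind=engine)
print(meta.tables['notnull_dtype_example'].columns['Int'].type)   # an integer type, not TEXT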

.. _whatsnew_0152.performance:

91 changes: 54 additions & 37 deletions pandas/io/sql.py
@@ -885,37 +885,56 @@ def _harmonize_columns(self, parse_dates=None):
except KeyError:
pass # this column not in results

def _get_notnull_col_dtype(self, col):
"""
Infer datatype of the Series col. In case the dtype of col is 'object'
and it contains NA values, this infers the datatype of the not-NA
values. Needed for inserting typed data containing NULLs, GH8778.
"""
col_for_inference = col
if col.dtype == 'object':
notnulldata = col[~isnull(col)]
if len(notnulldata):
col_for_inference = notnulldata

return lib.infer_dtype(col_for_inference)

def _sqlalchemy_type(self, col):
from sqlalchemy.types import (BigInteger, Float, Text, Boolean,
DateTime, Date, Time)

dtype = self.dtype or {}
if col.name in dtype:
return self.dtype[col.name]

if com.is_datetime64_dtype(col):
col_type = self._get_notnull_col_dtype(col)

from sqlalchemy.types import (BigInteger, Float, Text, Boolean,
DateTime, Date, Time)

if col_type == 'datetime64':
try:
tz = col.tzinfo
return DateTime(timezone=True)
except:
return DateTime
if com.is_timedelta64_dtype(col):
if col_type == 'timedelta64':
warnings.warn("the 'timedelta' type is not supported, and will be "
"written as integer values (ns frequency) to the "
"database.", UserWarning)
return BigInteger
elif com.is_float_dtype(col):
elif col_type == 'floating':
return Float
elif com.is_integer_dtype(col):
elif col_type == 'integer':
# TODO: Refine integer size.
return BigInteger
elif com.is_bool_dtype(col):
elif col_type == 'boolean':
return Boolean
inferred = lib.infer_dtype(com._ensure_object(col))
if inferred == 'date':
elif col_type == 'date':
return Date
if inferred == 'time':
elif col_type == 'time':
return Time
elif col_type == 'complex':
raise ValueError('Complex datatypes not supported')

return Text

def _numpy_type(self, sqltype):
@@ -1187,15 +1206,15 @@ def _create_sql_schema(self, frame, table_name, keys=None):
# SQLAlchemy installed
# SQL type convertions for each DB
_SQL_TYPES = {
'text': {
'string': {
'mysql': 'VARCHAR (63)',
'sqlite': 'TEXT',
},
'float': {
'floating': {
'mysql': 'FLOAT',
'sqlite': 'REAL',
},
'int': {
'integer': {
'mysql': 'BIGINT',
'sqlite': 'INTEGER',
},
@@ -1211,12 +1230,13 @@ def _create_sql_schema(self, frame, table_name, keys=None):
'mysql': 'TIME',
'sqlite': 'TIME',
},
'bool': {
'boolean': {
'mysql': 'BOOLEAN',
'sqlite': 'INTEGER',
}
}


# SQL enquote and wildcard symbols
_SQL_SYMB = {
'mysql': {
@@ -1291,8 +1311,8 @@ def _create_table_setup(self):
br_l = _SQL_SYMB[flv]['br_l'] # left val quote char
br_r = _SQL_SYMB[flv]['br_r'] # right val quote char

create_tbl_stmts = [(br_l + '%s' + br_r + ' %s') % (cname, ctype)
for cname, ctype, _ in column_names_and_types]
create_tbl_stmts = [(br_l + '%s' + br_r + ' %s') % (cname, col_type)
for cname, col_type, _ in column_names_and_types]
if self.keys is not None and len(self.keys):
cnames_br = ",".join([br_l + c + br_r for c in self.keys])
create_tbl_stmts.append(
@@ -1317,30 +1337,27 @@ def _sql_type_name(self, col):
dtype = self.dtype or {}
if col.name in dtype:
return dtype[col.name]
pytype = col.dtype.type
pytype_name = "text"
if issubclass(pytype, np.floating):
pytype_name = "float"
elif com.is_timedelta64_dtype(pytype):

col_type = self._get_notnull_col_dtype(col)
if col_type == 'timedelta64':
warnings.warn("the 'timedelta' type is not supported, and will be "
"written as integer values (ns frequency) to the "
"database.", UserWarning)
pytype_name = "int"
elif issubclass(pytype, np.integer):
pytype_name = "int"
elif issubclass(pytype, np.datetime64) or pytype is datetime:
# Caution: np.datetime64 is also a subclass of np.number.
pytype_name = "datetime"
elif issubclass(pytype, np.bool_):
pytype_name = "bool"
elif issubclass(pytype, np.object):
pytype = lib.infer_dtype(com._ensure_object(col))
if pytype == "date":
pytype_name = "date"
elif pytype == "time":
pytype_name = "time"

return _SQL_TYPES[pytype_name][self.pd_sql.flavor]
col_type = "integer"

elif col_type == "datetime64":
col_type = "datetime"

elif col_type == "empty":
col_type = "string"

elif col_type == "complex":
raise ValueError('Complex datatypes not supported')

if col_type not in _SQL_TYPES:
col_type = "string"

return _SQL_TYPES[col_type][self.pd_sql.flavor]


class SQLiteDatabase(PandasSQL):
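
A rough illustration of the inference performed by the new _get_notnull_col_dtype helper above, written against the public pandas.api.types.infer_dtype rather than the internal lib module used in the patch:

import pandas as pd
from pandas.api.types import infer_dtype

col = pd.Series([1, None], dtype='object')

# Restricting inference to the not-null values recovers 'integer', which the
# patched _sqlalchemy_type / _sql_type_name can then map to BIGINT or INTEGER
# instead of falling back to TEXT.
print(infer_dtype(col.dropna()))   # 'integer'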
80 changes: 63 additions & 17 deletions pandas/io/tests/test_sql.py
@@ -560,6 +560,11 @@ def test_timedelta(self):
result = sql.read_sql_query('SELECT * FROM test_timedelta', self.conn)
tm.assert_series_equal(result['foo'], df['foo'].astype('int64'))

def test_complex(self):
df = DataFrame({'a':[1+1j, 2j]})
# Complex data type should raise error
self.assertRaises(ValueError, df.to_sql, 'test_complex', self.conn)

def test_to_sql_index_label(self):
temp_frame = DataFrame({'col1': range(4)})

@@ -1175,19 +1180,38 @@ def test_dtype(self):
(0.9, None)]
df = DataFrame(data, columns=cols)
df.to_sql('dtype_test', self.conn)
df.to_sql('dtype_test2', self.conn, dtype={'B': sqlalchemy.Boolean})
df.to_sql('dtype_test2', self.conn, dtype={'B': sqlalchemy.TEXT})
meta = sqlalchemy.schema.MetaData(bind=self.conn)
meta.reflect()
sqltype = meta.tables['dtype_test2'].columns['B'].type
self.assertTrue(isinstance(sqltype, sqlalchemy.TEXT))
self.assertRaises(ValueError, df.to_sql,
'error', self.conn, dtype={'B': str})

def test_notnull_dtype(self):
cols = {'Bool': Series([True,None]),
'Date': Series([datetime(2012, 5, 1), None]),
'Int' : Series([1, None], dtype='object'),
'Float': Series([1.1, None])
}
df = DataFrame(cols)

tbl = 'notnull_dtype_test'
df.to_sql(tbl, self.conn)
returned_df = sql.read_sql_table(tbl, self.conn)
meta = sqlalchemy.schema.MetaData(bind=self.conn)
meta.reflect()
self.assertTrue(isinstance(meta.tables['dtype_test'].columns['B'].type,
sqltypes.TEXT))
if self.flavor == 'mysql':
my_type = sqltypes.Integer
else:
my_type = sqltypes.Boolean
self.assertTrue(isinstance(meta.tables['dtype_test2'].columns['B'].type,
my_type))
self.assertRaises(ValueError, df.to_sql,
'error', self.conn, dtype={'B': bool})

col_dict = meta.tables[tbl].columns

self.assertTrue(isinstance(col_dict['Bool'].type, my_type))
self.assertTrue(isinstance(col_dict['Date'].type, sqltypes.DateTime))
self.assertTrue(isinstance(col_dict['Int'].type, sqltypes.Integer))
self.assertTrue(isinstance(col_dict['Float'].type, sqltypes.Float))


class TestSQLiteAlchemy(_TestSQLAlchemy):
@@ -1507,6 +1531,13 @@ def test_to_sql_save_index(self):
def test_transactions(self):
self._transaction_test()

def _get_sqlite_column_type(self, table, column):
recs = self.conn.execute('PRAGMA table_info(%s)' % table)
for cid, name, ctype, not_null, default, pk in recs:
if name == column:
return ctype
raise ValueError('Table %s, column %s not found' % (table, column))

def test_dtype(self):
if self.flavor == 'mysql':
raise nose.SkipTest('Not applicable to MySQL legacy')
@@ -1515,20 +1546,35 @@ def test_dtype(self):
(0.9, None)]
df = DataFrame(data, columns=cols)
df.to_sql('dtype_test', self.conn)
df.to_sql('dtype_test2', self.conn, dtype={'B': 'bool'})
df.to_sql('dtype_test2', self.conn, dtype={'B': 'STRING'})

def get_column_type(table, column):
recs = self.conn.execute('PRAGMA table_info(%s)' % table)
for cid, name, ctype, not_null, default, pk in recs:
if name == column:
return ctype
raise ValueError('Table %s, column %s not found' % (table, column))

self.assertEqual(get_column_type('dtype_test', 'B'), 'TEXT')
self.assertEqual(get_column_type('dtype_test2', 'B'), 'bool')
# sqlite stores Boolean values as INTEGER
self.assertEqual(self._get_sqlite_column_type('dtype_test', 'B'), 'INTEGER')

self.assertEqual(self._get_sqlite_column_type('dtype_test2', 'B'), 'STRING')
self.assertRaises(ValueError, df.to_sql,
'error', self.conn, dtype={'B': bool})

def test_notnull_dtype(self):
if self.flavor == 'mysql':
raise nose.SkipTest('Not applicable to MySQL legacy')

cols = {'Bool': Series([True,None]),
'Date': Series([datetime(2012, 5, 1), None]),
'Int' : Series([1, None], dtype='object'),
'Float': Series([1.1, None])
}
df = DataFrame(cols)

tbl = 'notnull_dtype_test'
df.to_sql(tbl, self.conn)

self.assertEqual(self._get_sqlite_column_type(tbl, 'Bool'), 'INTEGER')
self.assertEqual(self._get_sqlite_column_type(tbl, 'Date'), 'TIMESTAMP')
self.assertEqual(self._get_sqlite_column_type(tbl, 'Int'), 'INTEGER')
self.assertEqual(self._get_sqlite_column_type(tbl, 'Float'), 'REAL')


class TestMySQLLegacy(TestSQLiteFallback):
"""
Test the legacy mode against a MySQL database.
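
For the fallback (non-SQLAlchemy) mode exercised by the tests above, a similar check can be sketched with the standard-library sqlite3 module (table and column names are illustrative; exact declared types may vary by version):

import sqlite3
from datetime import datetime
import pandas as pd

conn = sqlite3.connect(':memory:')
df = pd.DataFrame({'Int': pd.Series([1, None], dtype='object'),
                   'Date': pd.Series([datetime(2012, 5, 1), None])})
df.to_sql('notnull_dtype_example', conn)

# Read the declared column types back from SQLite's schema metadata,
# as the _get_sqlite_column_type test helper does.
for cid, name, ctype, notnull, default, pk in conn.execute(
        'PRAGMA table_info(notnull_dtype_example)'):
    print(name, ctype)   # expect Int -> INTEGER, Date -> TIMESTAMP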
