Skip to content

Commit

Permalink
Merge pull request #8083 from artemyk/to_sql_create_indexes
Browse files Browse the repository at this point in the history
BUG: When creating table, db indexes should be created from DataFrame indexes
  • Loading branch information
jorisvandenbossche committed Sep 11, 2014
2 parents 77d5f04 + df48524 commit 54678dd
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 18 deletions.
49 changes: 32 additions & 17 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,17 +566,17 @@ def __init__(self, name, pandas_sql_engine, frame=None, index=True,
raise ValueError("Table '%s' already exists." % name)
elif if_exists == 'replace':
self.pd_sql.drop_table(self.name, self.schema)
self.table = self._create_table_statement()
self.table = self._create_table_setup()
self.create()
elif if_exists == 'append':
self.table = self.pd_sql.get_table(self.name, self.schema)
if self.table is None:
self.table = self._create_table_statement()
self.table = self._create_table_setup()
else:
raise ValueError(
"'{0}' is not valid for if_exists".format(if_exists))
else:
self.table = self._create_table_statement()
self.table = self._create_table_setup()
self.create()
else:
# no data provided, read-only mode
Expand Down Expand Up @@ -703,23 +703,25 @@ def _get_column_names_and_types(self, dtype_mapper):
for i, idx_label in enumerate(self.index):
idx_type = dtype_mapper(
self.frame.index.get_level_values(i))
column_names_and_types.append((idx_label, idx_type))
column_names_and_types.append((idx_label, idx_type, True))

column_names_and_types += [
(str(self.frame.columns[i]),
dtype_mapper(self.frame.iloc[:,i]))
dtype_mapper(self.frame.iloc[:,i]),
False)
for i in range(len(self.frame.columns))
]

return column_names_and_types

def _create_table_statement(self):
def _create_table_setup(self):
from sqlalchemy import Table, Column

column_names_and_types = \
self._get_column_names_and_types(self._sqlalchemy_type)

columns = [Column(name, typ)
for name, typ in column_names_and_types]
columns = [Column(name, typ, index=is_index)
for name, typ, is_index in column_names_and_types]

return Table(self.name, self.pd_sql.meta, *columns, schema=self.schema)

Expand Down Expand Up @@ -979,10 +981,12 @@ class PandasSQLTableLegacy(PandasSQLTable):
Instead of a table variable just use the Create Table
statement"""
def sql_schema(self):
return str(self.table)
return str(";\n".join(self.table))

def create(self):
self.pd_sql.execute(self.table)
with self.pd_sql.con:
for stmt in self.table:
self.pd_sql.execute(stmt)

def insert_statement(self):
names = list(map(str, self.frame.columns))
Expand Down Expand Up @@ -1026,14 +1030,17 @@ def insert(self, chunksize=None):
cur.executemany(ins, data_list)
cur.close()

def _create_table_statement(self):
"Return a CREATE TABLE statement to suit the contents of a DataFrame."
def _create_table_setup(self):
"""Return a list of SQL statement that create a table reflecting the
structure of a DataFrame. The first entry will be a CREATE TABLE
statement while the rest will be CREATE INDEX statements
"""

column_names_and_types = \
self._get_column_names_and_types(self._sql_type_name)

pat = re.compile('\s+')
column_names = [col_name for col_name, _ in column_names_and_types]
column_names = [col_name for col_name, _, _ in column_names_and_types]
if any(map(pat.search, column_names)):
warnings.warn(_SAFE_NAMES_WARNING)

Expand All @@ -1044,13 +1051,21 @@ def _create_table_statement(self):

col_template = br_l + '%s' + br_r + ' %s'

columns = ',\n '.join(col_template %
x for x in column_names_and_types)
columns = ',\n '.join(col_template % (cname, ctype)
for cname, ctype, _ in column_names_and_types)
template = """CREATE TABLE %(name)s (
%(columns)s
)"""
create_statement = template % {'name': self.name, 'columns': columns}
return create_statement
create_stmts = [template % {'name': self.name, 'columns': columns}, ]

ix_tpl = "CREATE INDEX ix_{tbl}_{col} ON {tbl} ({br_l}{col}{br_r})"
for cname, _, is_index in column_names_and_types:
if not is_index:
continue
create_stmts.append(ix_tpl.format(tbl=self.name, col=cname,
br_l=br_l, br_r=br_r))

return create_stmts

def _sql_type_name(self, col):
pytype = col.dtype.type
Expand Down
53 changes: 52 additions & 1 deletion pandas/io/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _load_test2_data(self):
E=['1990-11-22', '1991-10-26', '1993-11-26', '1995-12-12']))
df['E'] = to_datetime(df['E'])

self.test_frame3 = df
self.test_frame2 = df

def _load_test3_data(self):
columns = ['index', 'A', 'B']
Expand Down Expand Up @@ -324,6 +324,13 @@ def _execute_sql(self):
row = iris_results.fetchone()
tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa'])

def _to_sql_save_index(self):
    # Write a two-row frame whose index is column 'A', then verify that the
    # only index reported for the resulting table covers exactly ['A'].
    records = [(1,2.1,'line1'), (2,1.5,'line2')]
    frame = DataFrame.from_records(records, columns=['A','B','C'],
                                   index=['A'])
    table_name = 'test_to_sql_saves_index'
    self.pandasSQL.to_sql(frame, table_name)
    indexed_columns = self._get_index_columns(table_name)
    self.assertEqual(indexed_columns, [['A']])


#------------------------------------------------------------------------------
#--- Testing the public API
Expand Down Expand Up @@ -694,6 +701,13 @@ def test_warning_case_insensitive_table_name(self):
# Verify some things
self.assertEqual(len(w), 0, "Warning triggered for writing a table")

def _get_index_columns(self, tbl_name):
    """Return the column lists of the indexes defined on ``tbl_name``.

    Parameters
    ----------
    tbl_name : str
        Name of the table to reflect through the SQLAlchemy inspector
        built from ``self.conn``.

    Returns
    -------
    list of list
        One entry per index, each being that index's ``'column_names'``
        value as reported by ``Inspector.get_indexes``.
    """
    from sqlalchemy.engine import reflection
    insp = reflection.Inspector.from_engine(self.conn)
    # Inspect the table the caller asked for rather than a fixed name, so
    # the helper reports indexes for the table actually under test.
    return [ix['column_names'] for ix in insp.get_indexes(tbl_name)]


class TestSQLLegacyApi(_TestSQLApi):
"""
Expand Down Expand Up @@ -1074,6 +1088,16 @@ def test_nan_string(self):
result = sql.read_sql_query('SELECT * FROM test_nan', self.conn)
tm.assert_frame_equal(result, df)

def _get_index_columns(self, tbl_name):
    # Reflect ``tbl_name`` through an inspector bound to ``self.conn`` and
    # collect the 'column_names' entry of every index it reports.
    from sqlalchemy.engine import reflection
    inspector = reflection.Inspector.from_engine(self.conn)
    return [index_info['column_names']
            for index_info in inspector.get_indexes(tbl_name)]

def test_to_sql_save_index(self):
    # Delegates to the `_to_sql_save_index` helper.
    self._to_sql_save_index()


class TestSQLiteAlchemy(_TestSQLAlchemy):
"""
Expand Down Expand Up @@ -1368,6 +1392,20 @@ def test_datetime_time(self):
# test support for datetime.time
raise nose.SkipTest("datetime.time not supported for sqlite fallback")

def _get_index_columns(self, tbl_name):
ixs = sql.read_sql_query(
"SELECT * FROM sqlite_master WHERE type = 'index' " +
"AND tbl_name = '%s'" % tbl_name, self.conn)
ix_cols = []
for ix_name in ixs.name:
ix_info = sql.read_sql_query(
"PRAGMA index_info(%s)" % ix_name, self.conn)
ix_cols.append(ix_info.name.tolist())
return ix_cols

def test_to_sql_save_index(self):
    # Delegates to the `_to_sql_save_index` helper.
    self._to_sql_save_index()


class TestMySQLLegacy(TestSQLiteLegacy):
"""
Expand Down Expand Up @@ -1424,6 +1462,19 @@ def test_a_deprecation(self):
sql.has_table('test_frame1', self.conn, flavor='mysql'),
'Table not written to DB')

def _get_index_columns(self, tbl_name):
    # Group the Column_name values returned by SHOW INDEX under their
    # Key_name, yielding one list of columns per index.
    index_rows = sql.read_sql_query(
        "SHOW INDEX IN %s" % tbl_name, self.conn)
    columns_by_index = {}
    for key_name, column_name in zip(index_rows.Key_name,
                                     index_rows.Column_name):
        columns_by_index.setdefault(key_name, []).append(column_name)
    return list(columns_by_index.values())

def test_to_sql_save_index(self):
    # Delegates to the `_to_sql_save_index` helper.
    self._to_sql_save_index()


#------------------------------------------------------------------------------
#--- Old tests from 0.13.1 (before refactor using sqlalchemy)
Expand Down

0 comments on commit 54678dd

Please sign in to comment.