Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Normalize column names when generating SQL #505

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions csvkit/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import datetime

import six
from normality import slugify

from sqlalchemy import Column, MetaData, Table, create_engine
from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, Integer, String, Time
from sqlalchemy import Column, MetaData, Table, create_engine, String, Time
from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, Integer
from sqlalchemy.schema import CreateTable

NoneType = type(None)
Expand All @@ -27,23 +28,22 @@
SQL_INTEGER_MAX = 2147483647
SQL_INTEGER_MIN = -2147483647

def make_column(column, no_constraints=False):
"""
Creates a sqlalchemy column from a csvkit Column.
"""

def make_column(column, no_constraints=False, normalize_columns=False):
""" Creates a sqlalchemy column from a csvkit Column. """
sql_column_kwargs = {}
sql_type_kwargs = {}

column_types = {
bool: Boolean,
#int: Integer, see special case below
# int: Integer, see special case below
float: Float,
datetime.datetime: DateTime,
datetime.date: Date,
datetime.time: Time,
NoneType: String,
six.text_type: String
}
}

if column.type in column_types:
sql_column_type = column_types[column.type]
Expand All @@ -52,7 +52,7 @@ def make_column(column, no_constraints=False):
column_min = min([v for v in column if v is not None])

if column_max > SQL_INTEGER_MAX or column_min < SQL_INTEGER_MIN:
sql_column_type = BigInteger
sql_column_type = BigInteger
else:
sql_column_type = Integer
else:
Expand All @@ -66,28 +66,36 @@ def make_column(column, no_constraints=False):

sql_column_kwargs['nullable'] = column.has_nulls()

return Column(column.name, sql_column_type(**sql_type_kwargs), **sql_column_kwargs)
name = column.name
if normalize_columns:
name = slugify(name, sep='_')

return Column(name, sql_column_type(**sql_type_kwargs),
**sql_column_kwargs)


def get_connection(connection_string):
    """
    Open a SQLAlchemy engine for *connection_string* and return it together
    with a MetaData object bound to that engine.
    """
    engine = create_engine(connection_string)
    return engine, MetaData(engine)

def make_table(csv_table, name='table_name', no_constraints=False, db_schema=None, metadata=None):
"""
Creates a sqlalchemy table from a csvkit Table.
"""

def make_table(csv_table, name='table_name', no_constraints=False,
               db_schema=None, normalize_columns=False, metadata=None):
    """
    Build a sqlalchemy Table from a csvkit Table.

    NOTE(review): the ``name`` argument is currently unused -- the resulting
    table is named after ``csv_table.name``; confirm callers set that
    attribute before relying on ``name``.
    """
    metadata = metadata or MetaData()

    sql_table = Table(csv_table.name, metadata, schema=db_schema)

    # One sqlalchemy Column per csvkit column, optionally slug-normalized.
    for csv_column in csv_table:
        sql_table.append_column(
            make_column(csv_column, no_constraints, normalize_columns))

    return sql_table


def make_create_table_statement(sql_table, dialect=None):
"""
Generates a CREATE TABLE statement for a sqlalchemy table.
Expand All @@ -99,4 +107,3 @@ def make_create_table_statement(sql_table, dialect=None):
sql_dialect = None

return six.text_type(CreateTable(sql_table).compile(dialect=sql_dialect)).strip() + ';'

52 changes: 33 additions & 19 deletions csvkit/table.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python

import datetime
import itertools

Expand All @@ -11,22 +10,27 @@
from csvkit.cli import parse_column_identifiers
from csvkit.headers import make_default_headers


class InvalidType(object):
    """
    Sentinel type for Column initialization, marking "no normal_type was
    supplied" -- None cannot serve that role because it is itself a valid
    column value.
    """


class Column(list):
"""
A normalized data column and inferred annotations (nullable, etc.).
"""
def __init__(self, order, name, l, normal_type=InvalidType, blanks_as_nulls=True, infer_types=True):
def __init__(self, order, name, l, normal_type=InvalidType,
blanks_as_nulls=True, infer_types=True):
"""
Construct a column from a sequence of values.

If normal_type is not InvalidType, inference will be skipped and values assumed to have already been normalized.
If infer_types is False, type inference will be skipped and the type assumed to be unicode.
If normal_type is not InvalidType, inference will be skipped and values
assumed to have already been normalized. If infer_types is False, type
inference will be skipped and the type assumed to be unicode.
"""
if normal_type != InvalidType:
t = normal_type
Expand All @@ -39,7 +43,7 @@ def __init__(self, order, name, l, normal_type=InvalidType, blanks_as_nulls=True

list.__init__(self, data)
self.order = order
self.name = name or '_unnamed' # empty column names don't make sense
self.name = name or '_unnamed' # empty column names don't make sense
self.type = t

def __str__(self):
Expand All @@ -53,7 +57,8 @@ def __unicode__(self):

def __getitem__(self, key):
"""
Return null for keys beyond the range of the column. This allows for columns to be of uneven length and still be merged into rows cleanly.
Return null for keys beyond the range of the column. This allows for
columns to be of uneven length and still be merged into rows cleanly.
"""
l = len(self)

Expand Down Expand Up @@ -88,13 +93,16 @@ def max_length(self):

return l


class Table(list):
"""
A normalized data table and inferred annotations (nullable, etc.).
"""

def __init__(self, columns=[], name='new_table'):
"""
Generic constructor. You should normally use a from_* method to create a Table.
Generic constructor. You should normally use a from_* method to create
a Table.
"""
list.__init__(self, columns)
self.name = name
Expand Down Expand Up @@ -187,14 +195,16 @@ def row(self, i):
return row_data

@classmethod
def from_csv(cls, f, name='from_csv_table', snifflimit=None, column_ids=None, blanks_as_nulls=True, zero_based=False, infer_types=True, no_header_row=False, **kwargs):
"""
Creates a new Table from a file-like object containing CSV data.

Note: the column_ids argument will cause only those columns with a matching identifier
to be parsed, type inferred, etc. However, their order/index property will reflect the
original data (e.g. column 8 will still be "order" 7, even if it's the third column
in the resulting Table.
def from_csv(cls, f, name='from_csv_table', snifflimit=None,
column_ids=None, blanks_as_nulls=True, zero_based=False,
infer_types=True, no_header_row=False, **kwargs):
""" Creates a new Table from a file-like object containing CSV data.

Note: the column_ids argument will cause only those columns with a
matching identifier to be parsed, type inferred, etc. However, their
order/index property will reflect the original data (e.g. column 8
will still be "order" 7, even if it's the third column in the resulting
Table.
"""
# This bit of nonsense is to deal with "files" from stdin,
# which are not seekable and thus must be buffered
Expand All @@ -214,7 +224,8 @@ def from_csv(cls, f, name='from_csv_table', snifflimit=None, column_ids=None, bl
row = next(rows)

headers = make_default_headers(len(row))
column_ids = parse_column_identifiers(column_ids, headers, zero_based)
column_ids = parse_column_identifiers(column_ids, headers,
zero_based)
headers = [headers[c] for c in column_ids]
data_columns = [[] for c in headers]

Expand All @@ -224,7 +235,8 @@ def from_csv(cls, f, name='from_csv_table', snifflimit=None, column_ids=None, bl
headers = next(rows)

if column_ids:
column_ids = parse_column_identifiers(column_ids, headers, zero_based)
column_ids = parse_column_identifiers(column_ids, headers,
zero_based)
headers = [headers[c] for c in column_ids]
else:
column_ids = range(len(headers))
Expand Down Expand Up @@ -254,7 +266,9 @@ def from_csv(cls, f, name='from_csv_table', snifflimit=None, column_ids=None, bl
columns = []

for i, c in enumerate(data_columns):
columns.append(Column(column_ids[i], headers[i], c, blanks_as_nulls=blanks_as_nulls, infer_types=infer_types))
columns.append(Column(column_ids[i], headers[i], c,
blanks_as_nulls=blanks_as_nulls,
infer_types=infer_types))

return Table(columns, name=name)

Expand Down
17 changes: 14 additions & 3 deletions csvkit/utilities/csvsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
import sys

import agate
from normality import slugify

from csvkit import sql
from csvkit import table
from csvkit.cli import CSVKitUtility


class CSVSQL(CSVKitUtility):
description = 'Generate SQL statements for one or more CSV files, create execute those statements directly on a database, and execute one or more SQL queries.'
description = ("Generate SQL statements for one or more CSV files, create "
"execute those statements directly on a database, and "
"execute one or more SQL queries.")
override_flags = ['l', 'f']

def add_arguments(self):
Expand All @@ -21,6 +25,8 @@ def add_arguments(self):
help='Limit CSV dialect sniffing to the specified number of bytes. Specify "0" to disable sniffing entirely.')
self.argparser.add_argument('-i', '--dialect', dest='dialect', choices=sql.DIALECTS,
help='Dialect of SQL to generate. Only valid when --db is not specified.')
self.argparser.add_argument('-n', '--normalize-columns', dest='normalize_columns', action='store_true',
help='Normalize the headers before generating column names.')
self.argparser.add_argument('--db', dest='connection_string',
help='If present, a sqlalchemy connection string to use to directly execute generated SQL on a database.')
self.argparser.add_argument('--query', default=None,
Expand Down Expand Up @@ -115,7 +121,8 @@ def main(self):
table_name,
self.args.no_constraints,
self.args.db_schema,
metadata
self.args.normalize_columns,
metadata,
)

# Create table
Expand All @@ -126,11 +133,14 @@ def main(self):
if do_insert and csv_table.count_rows() > 0:
insert = sql_table.insert()
headers = csv_table.headers()
if self.args.normalize_columns:
headers = [slugify(h, sep='_') for h in headers]
conn.execute(insert, [dict(zip(headers, row)) for row in csv_table.to_rows()])

# Output SQL statements
else:
sql_table = sql.make_table(csv_table, table_name, self.args.no_constraints)
sql_table = sql.make_table(csv_table, table_name,
self.args.no_constraints)
self.output_file.write('%s\n' % sql.make_create_table_statement(sql_table, dialect=self.args.dialect))

if connection_string:
Expand All @@ -156,6 +166,7 @@ def main(self):
trans.commit()
conn.close()


def launch_new_instance():
    """Console-script entry point: instantiate the CSVSQL utility and run it."""
    CSVSQL().main()
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
'sqlalchemy>=0.6.6',
'openpyxl==2.2.6',
'six>=1.6.1',
'normality>=0.2.4',
'python-dateutil==2.2',
'dbf>=0.96.005'
]
Expand Down