diff --git a/.travis.yml b/.travis.yml index 0293824..3be998f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -26,7 +26,7 @@ script: < data/test-input.csv | docker-compose run psql_client - coverage run - -a --source=csv2sql -m csv2sql data tbl < data/test-input.csv + -a --source=csv2sql -m csv2sql data -r tbl < data/test-input.csv | docker-compose run psql_client after_success: - coveralls diff --git a/csv2sql/main.py b/csv2sql/main.py index 709b71b..06ee9b5 100644 --- a/csv2sql/main.py +++ b/csv2sql/main.py @@ -79,9 +79,11 @@ def _dump_schema(args, in_file=None): ) -def _dump_data(args, in_file=None): +def _dump_data(args, in_file=None, rebuild=None): if not in_file: in_file = args.in_file + if rebuild is None: + rebuild = args.rebuild # Skip the header. reader = csv.reader(in_file, delimiter=args.delimiter) @@ -92,6 +94,7 @@ def _dump_data(args, in_file=None): args.table_name, reader, args.null, + rebuild, ) @@ -100,7 +103,7 @@ def _dump_all(args): _dump_schema(args, in_file=file_iterator) file_iterator.rewind() frozen_file_iterator = file_iterator.freeze() - _dump_data(args, in_file=frozen_file_iterator) + _dump_data(args, in_file=frozen_file_iterator, rebuild=False) def parse_args(arguments): @@ -146,7 +149,7 @@ def parse_args(arguments): schema_factory = argparse.ArgumentParser(add_help=False) schema_factory.add_argument( '-r', '--rebuild', action='store_true', - help='Rebuild the table by "DROP TABLE IF EXISTS".') + help='Rebuild the table by a query such as "DROP TABLE IF EXISTS".') schema_factory.add_argument( '--lines-for-inference', metavar='NUM', help=('Num lines to identify column types.' @@ -154,6 +157,12 @@ def parse_args(arguments): ' used to identify them. [default: 0]'), type=int, default=0) + # insertion_factory. + insertion_factory = argparse.ArgumentParser(add_help=False) + insertion_factory.add_argument( + '-r', '--rebuild', action='store_true', + help='Rebuild the table by a query such as "TRUNCATE TABLE".') + # pattern_readable. pattern_readable = argparse.ArgumentParser(add_help=False) pattern_readable.add_argument( @@ -164,8 +173,9 @@ def parse_args(arguments): schema_dumper = [ readable, writable, query_engine_dependent, csv_readable, query_factory, schema_factory, pattern_readable] - query_dumper = [ - readable, writable, query_engine_dependent, csv_readable, query_factory] + insertion_dumper = [ + readable, writable, query_engine_dependent, csv_readable, + query_factory, insertion_factory] pattern_dumper = [writable, query_engine_dependent, pattern_readable] # Main. @@ -184,7 +194,7 @@ def parse_args(arguments): 'schema', help='Schema queries.', parents=schema_dumper, ).set_defaults(command=_dump_schema) subparsers.add_parser( - 'data', help='Data-insertion queries.', parents=query_dumper, + 'data', help='Data-insertion queries.', parents=insertion_dumper, ).set_defaults(command=_dump_data) subparsers.add_parser( 'pattern', help='Type-inference patterns.', parents=pattern_dumper, diff --git a/csv2sql/queryengines/postgresql.py b/csv2sql/queryengines/postgresql.py index e3f7a5b..c218215 100644 --- a/csv2sql/queryengines/postgresql.py +++ b/csv2sql/queryengines/postgresql.py @@ -98,8 +98,16 @@ def write_schema_statement(out_stream, table_name, column_types, rebuild=False): out_stream.write(_LINE_TERMINATOR) -def write_insert_statement(out_stream, table_name, reader, null_value): - """Write the insert query into `out_stream`.""" +def write_insert_statement( + out_stream, table_name, reader, null_value, rebuild=False): + """Write the insert query into `out_stream`. + When `rebuild` is true, it prepends the query + 'TRUNCATE TABLE `table_name`. + """ + if rebuild: + out_stream.write('TRUNCATE TABLE {0};'.format(table_name)) + out_stream.write(_LINE_TERMINATOR) + out_stream.write( 'COPY {0} FROM STDIN WITH NULL \'{1}\' CSV;'.format( table_name, diff --git a/csv2sql/tests/test_main.py b/csv2sql/tests/test_main.py index 46d7d77..535c235 100644 --- a/csv2sql/tests/test_main.py +++ b/csv2sql/tests/test_main.py @@ -28,7 +28,7 @@ def test_schema_dumper(self, command_name): @parameterized.expand([ ('data',), ]) - def test_query_dumper(self, command_name): + def test_insertion_dumper(self, command_name): arguments = [command_name, 'table-name'] actual = parse_args(arguments) eq_(actual.table_name, 'table-name') @@ -36,5 +36,6 @@ def test_query_dumper(self, command_name): ok_(hasattr(actual, 'out_file')) ok_(hasattr(actual, 'null')) ok_(hasattr(actual, 'delimiter')) + ok_(hasattr(actual, 'rebuild')) ok_(hasattr(actual, 'command')) ok_(hasattr(actual, 'query_engine'))