diff --git a/postgres_copy/__init__.py b/postgres_copy/__init__.py index 76520af..41d14cd 100644 --- a/postgres_copy/__init__.py +++ b/postgres_copy/__init__.py @@ -89,10 +89,84 @@ def validate_mapping(self): if not self.get_field(static_field): raise ValueError("Model does not include %s field" % static_field) + def create(self, cursor): + """ + Generate and run create sql for the temp table. + Runs a DROP on same prior to CREATE to avoid collisions. + + cursor: + A cursor object on the db + """ + self.drop(cursor) + create_sql = self.prep_create() + cursor.execute(create_sql) + + def pre_copy(self, cursor): + pass + + def copy(self, cursor): + """ + Generate and run the COPY command to copy data from csv to temp table. + Calls `self.pre_copy(cursor)` and `self.post_copy(cursor)` respectively + before and after running copy + + cursor: + A cursor object on the db + """ + self.pre_copy(cursor) + copy_sql = self.prep_copy() + fp = open(self.csv_path, 'r') + cursor.copy_expert(copy_sql, fp) + self.post_copy(cursor) + + def post_copy(self, cursor): + pass + + def pre_insert(self, cursor): + pass + + def insert(self, cursor): + """ + Generate and run the INSERT command to move data from the temp table + to the concrete table. + Calls `self.pre_copy(cursor)` and `self.post_copy(cursor)` respectively + before and after running copy + + returns: the count of rows inserted + + cursor: + A cursor object on the db + """ + self.pre_insert(cursor) + insert_sql = self.prep_insert() + cursor.execute(insert_sql) + insert_count = cursor.rowcount + self.post_insert(cursor) + + return insert_count + + def post_insert(self, cursor): + pass + + def drop(self, cursor): + """ + Generate and run the DROP command for the temp table. + + cursor: + A cursor object on the db + """ + drop_sql = self.prep_drop() + cursor.execute(drop_sql) + def save(self, silent=False, stream=sys.stdout): """ Saves the contents of the CSV file to the database. + Override this method and use 'self.create(cursor)`, + `self.copy(cursor)`, `self.insert(cursor)`, and `self.drop(cursor)` + if you need functionality other than the default create/copy/insert/drop + workflow. + silent: By default, non-fatal error notifications are printed to stdout, but this keyword may be set to disable these notifications. @@ -106,22 +180,11 @@ def save(self, silent=False, stream=sys.stdout): stream.write("Loading CSV to %s\n" % self.model.__name__) # Connect to the database - cursor = self.conn.cursor() - - # Create all of the raw SQL - drop_sql = self.prep_drop() - create_sql = self.prep_create() - copy_sql = self.prep_copy() - insert_sql = self.prep_insert() - - # Run all of the raw SQL - cursor.execute(drop_sql) - cursor.execute(create_sql) - fp = open(self.csv_path, 'r') - cursor.copy_expert(copy_sql, fp) - cursor.execute(insert_sql) - insert_count = cursor.rowcount - cursor.execute(drop_sql) + with self.conn.cursor() as c: + self.create(c) + self.copy(c) + insert_count = self.insert(c) + self.drop(c) if not silent: stream.write( diff --git a/tests/models.py b/tests/models.py index 627b656..19dd8d1 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,5 +1,6 @@ from django.db import models from .fields import MyIntegerField +from postgres_copy import CopyMapping class MockObject(models.Model): @@ -60,3 +61,17 @@ def copy_upper_name_template(self): def copy_lower_name_template(self): return 'lower("%(name)s")' copy_lower_name_template.copy_type = 'text' + + +class HookedCopyMapping(CopyMapping): + def pre_copy(self, cursor): + self.ran_pre_copy = True + + def post_copy(self, cursor): + self.ran_post_copy = True + + def pre_insert(self, cursor): + self.ran_pre_insert = True + + def post_insert(self, cursor): + self.ran_post_insert = True diff --git a/tests/tests.py b/tests/tests.py index 8881b22..ebcfdbf 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,11 +1,12 @@ import os from datetime import date + from .models import ( MockObject, ExtendedMockObject, LimitedMockObject, - OverloadMockObject -) + OverloadMockObject, + HookedCopyMapping) from postgres_copy import CopyMapping from django.test import TestCase @@ -269,3 +270,60 @@ def test_missing_overload_field(self): self.name_path, dict(name='NAME', number='NUMBER', dt='DATE', missing='NAME'), ) + + + def test_save_steps(self): + c = CopyMapping( + MockObject, + self.name_path, + dict(name='NAME', number='NUMBER', dt='DATE'), + ) + cursor = c.conn.cursor() + + c.create(cursor) + cursor.execute("""SELECT count(*) FROM %s;""" % c.temp_table_name) + self.assertEquals(cursor.fetchone()[0], 0) + cursor.execute("""SELECT count(*) FROM %s;""" % c.model._meta.db_table) + self.assertEquals(cursor.fetchone()[0], 0) + + c.copy(cursor) + cursor.execute("""SELECT count(*) FROM %s;""" % c.temp_table_name) + self.assertEquals(cursor.fetchone()[0], 3) + cursor.execute("""SELECT count(*) FROM %s;""" % c.model._meta.db_table) + self.assertEquals(cursor.fetchone()[0], 0) + + c.insert(cursor) + cursor.execute("""SELECT count(*) FROM %s;""" % c.model._meta.db_table) + self.assertEquals(cursor.fetchone()[0], 3) + + c.drop(cursor) + self.assertEquals(cursor.statusmessage, 'DROP TABLE') + cursor.close() + + def test_hooks(self): + c = HookedCopyMapping( + MockObject, + self.name_path, + dict(name='NAME', number='NUMBER', dt='DATE'), + ) + cursor = c.conn.cursor() + + c.create(cursor) + self.assertRaises(AttributeError, lambda: c.ran_pre_copy) + self.assertRaises(AttributeError, lambda: c.ran_post_copy) + self.assertRaises(AttributeError, lambda: c.ran_pre_insert) + self.assertRaises(AttributeError, lambda: c.ran_post_insert) + c.copy(cursor) + self.assertTrue(c.ran_pre_copy) + self.assertTrue(c.ran_post_copy) + self.assertRaises(AttributeError, lambda: c.ran_pre_insert) + self.assertRaises(AttributeError, lambda: c.ran_post_insert) + + c.insert(cursor) + self.assertTrue(c.ran_pre_copy) + self.assertTrue(c.ran_post_copy) + self.assertTrue(c.ran_pre_insert) + self.assertTrue(c.ran_post_insert) + + c.drop(cursor) + cursor.close()