diff --git a/postgres_copy/__init__.py b/postgres_copy/__init__.py index 8d46bfd..20bc904 100644 --- a/postgres_copy/__init__.py +++ b/postgres_copy/__init__.py @@ -202,13 +202,13 @@ def prep_insert(self): ) model_fields = [] + for field, header in self.field_header_crosswalk: - if field.db_column: - model_fields.append('"%s"' % field.db_column) - else: - model_fields.append('"%s"' % field.name) + model_fields.append('"%s"' % field.get_attname_column()[1]) + for k in self.static_mapping.keys(): model_fields.append('"%s"' % k) + options['model_fields'] = ", ".join(model_fields) temp_fields = [] diff --git a/tests/data/foreignkeys.csv b/tests/data/foreignkeys.csv new file mode 100644 index 0000000..b3dc8ad --- /dev/null +++ b/tests/data/foreignkeys.csv @@ -0,0 +1,4 @@ +NAME,NUMBER,DATE,PARENT +ben,1,2012-01-01,4 +joe,2,2012-01-02,5 +jane,3,2012-01-03,6 diff --git a/tests/models.py b/tests/models.py index 20faba6..cecface 100644 --- a/tests/models.py +++ b/tests/models.py @@ -6,6 +6,7 @@ class MockObject(models.Model): name = models.CharField(max_length=500) number = MyIntegerField(null=True, db_column='num') dt = models.DateField(null=True) + parent = models.ForeignKey('MockObject', null=True) class Meta: app_label = 'tests' diff --git a/tests/tests.py b/tests/tests.py index 969ae92..2e2a800 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -10,6 +10,7 @@ class PostgresCopyTest(TestCase): def setUp(self): self.data_dir = os.path.join(os.path.dirname(__file__), 'data') self.name_path = os.path.join(self.data_dir, 'names.csv') + self.foreign_path = os.path.join(self.data_dir, 'foreignkeys.csv') self.pipe_path = os.path.join(self.data_dir, 'pipes.csv') self.null_path = os.path.join(self.data_dir, 'nulls.csv') self.backwards_path = os.path.join(self.data_dir, 'backwards.csv') @@ -70,6 +71,21 @@ def test_simple_save(self): date(2012, 1, 1) ) + def test_save_foreign_key(self): + c = CopyMapping( + MockObject, + self.foreign_path, + dict(name='NAME', number='NUMBER', dt='DATE', parent='PARENT') + ) + + c.save() + self.assertEqual(MockObject.objects.count(), 3) + self.assertEqual(MockObject.objects.get(name='BEN').parent_id, 4) + self.assertEqual( + MockObject.objects.get(name='BEN').dt, + date(2012, 1, 1) + ) + def test_silent_save(self): c = CopyMapping( MockObject, @@ -186,3 +202,18 @@ def test_bad_static_values(self): static_mapping={'static_bad':1,} ) c.save() + + def test_save_foreign_key(self): + c = CopyMapping( + MockObject, + self.foreign_path, + dict(name='NAME', number='NUMBER', dt='DATE', parent='PARENT') + ) + + c.save() + self.assertEqual(MockObject.objects.count(), 3) + self.assertEqual(MockObject.objects.get(name='BEN').parent_id, 4) + self.assertEqual( + MockObject.objects.get(name='BEN').dt, + date(2012, 1, 1) + )