diff --git a/docs/index.rst b/docs/index.rst index 996d2c1..205ffba 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -167,6 +167,10 @@ Keyword Arguments for every row in the database by providing a dictionary with the name of the columns as keys and the static inputs as values. + +``ignore_headers`` A list of headers from your csv that don't have + equivalent fields in your model. These columns will + be ignored. ===================== ===================================================== diff --git a/postgres_copy/__init__.py b/postgres_copy/__init__.py index 2dbb025..afff43c 100644 --- a/postgres_copy/__init__.py +++ b/postgres_copy/__init__.py @@ -17,6 +17,7 @@ def __init__( model, csv_path, mapping, + ignore_headers=None, using=None, delimiter=',', null=None, @@ -37,6 +38,10 @@ def __init__( if self.conn.vendor != 'postgresql': raise TypeError("Only PostgreSQL backends supported") self.backend = self.conn.ops + if ignore_headers is None: + self.ignore_headers = [] + else: + self.ignore_headers = ignore_headers self.delimiter = delimiter self.null = null self.encoding = encoding @@ -48,13 +53,17 @@ def __init__( # Connect the headers from the CSV with the fields on the model self.field_header_crosswalk = [] inverse_mapping = {v: k for k, v in self.mapping.items()} + for ignore in self.ignore_headers: + inverse_mapping[ignore] = ignore.lower() for h in self.get_headers(): try: f_name = inverse_mapping[h] except KeyError: raise ValueError("Map does not include %s field" % h) try: - f = [f for f in self.model._meta.fields if f.name == f_name][0] + if f_name not in [ih.lower() for ih in self.ignore_headers]: + f = [f for f in self.model._meta.fields + if f.name == f_name][0] except IndexError: raise ValueError("Model does not include %s field" % f_name) self.field_header_crosswalk.append((f, h)) @@ -204,6 +213,8 @@ def prep_insert(self): model_fields = [] for field, header in self.field_header_crosswalk: + if header in self.ignore_headers: + continue model_fields.append('"%s"' % field.get_attname_column()[1]) for k in self.static_mapping.keys(): @@ -213,6 +224,8 @@ def prep_insert(self): temp_fields = [] for field, header in self.field_header_crosswalk: + if header in self.ignore_headers: + continue string = '"%s"' % header if hasattr(field, 'copy_template'): string = field.copy_template % dict(name=header) diff --git a/tests/models.py b/tests/models.py index 8cd5065..a7e5d65 100644 --- a/tests/models.py +++ b/tests/models.py @@ -29,3 +29,16 @@ class Meta: def copy_name_template(self): return 'upper("%(name)s")' copy_name_template.copy_type = 'text' + + +class LimitedMockObject(models.Model): + name = models.CharField(max_length=500) + dt = models.DateField(null=True) + + class Meta: + app_label = 'tests' + + def copy_name_template(self): + return 'upper("%(name)s")' + copy_name_template.copy_type = 'text' + diff --git a/tests/tests.py b/tests/tests.py index 2e2a800..6c2afba 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,6 +1,6 @@ import os from datetime import date -from .models import MockObject, ExtendedMockObject +from .models import MockObject, ExtendedMockObject, LimitedMockObject from postgres_copy import CopyMapping from django.test import TestCase @@ -18,6 +18,7 @@ def setUp(self): def tearDown(self): MockObject.objects.all().delete() ExtendedMockObject.objects.all().delete() + LimitedMockObject.objects.all().delete() def test_bad_call(self): with self.assertRaises(TypeError): @@ -57,6 +58,17 @@ def test_bad_field(self): dict(name1='NAME', number='NUMBER', dt='DATE'), ) + def test_limited_fields(self): + try: + CopyMapping( + LimitedMockObject, + self.name_path, + dict(name='NAME', dt='DATE'), + ignore_headers=['NUMBER'] + ) + except ValueError: + self.fail("Failed trying to ignore fields") + def test_simple_save(self): c = CopyMapping( MockObject, @@ -71,6 +83,20 @@ def test_simple_save(self): date(2012, 1, 1) ) + def test_limited_save(self): + c = CopyMapping( + LimitedMockObject, + self.name_path, + dict(name='NAME', dt='DATE'), + ignore_headers=['NUMBER'] + ) + c.save() + self.assertEqual(LimitedMockObject.objects.count(), 3) + self.assertEqual( + LimitedMockObject.objects.get(name='BEN').dt, + date(2012, 1, 1) + ) + def test_save_foreign_key(self): c = CopyMapping( MockObject,