diff --git a/docs/index.rst b/docs/index.rst index 996d2c1..cb5e8ed 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -143,31 +143,36 @@ Argument Description string field names for the CSV header. ================= ========================================================= -===================== ===================================================== +======================= ===================================================== Keyword Arguments -===================== ===================================================== -``delimiter`` The character that separates values in the data file. - By default it is ",". This must be a single one-byte - character. - -``null`` Specifies the string that represents a null value. - The default is an unquoted empty string. This must - be a single one-byte character. - -``encoding`` Specifies the character set encoding of the strings - in the CSV data source. For example, ``'latin-1'``, - ``'utf-8'``, and ``'cp437'`` are all valid encoding - parameters. - -``using`` Sets the database to use when importing data. - Default is None, which will use the ``'default'`` - database. - -``static_mapping`` Set model attributes not in the CSV the same - for every row in the database by providing a dictionary - with the name of the columns as keys and the static - inputs as values. -===================== ===================================================== +======================= ===================================================== +``delimiter`` The character that separates values in the data file. + By default it is ",". This must be a single one-byte + character. + +``null`` Specifies the string that represents a null value. + The default is an unquoted empty string. This must + be a single one-byte character. + +``encoding`` Specifies the character set encoding of the strings + in the CSV data source. For example, ``'latin-1'``, + ``'utf-8'``, and ``'cp437'`` are all valid encoding + parameters. 
+ +``using`` Sets the database to use when importing data. + Default is None, which will use the ``'default'`` + database. + +``static_mapping`` Set model attributes not in the CSV the same + for every row in the database by providing a dictionary + with the name of the columns as keys and the static + inputs as values. + +``overloaded_mapping`` Reuse a mapped column for a different model field. + This is useful when you want to have both the + original value as well as a modified form, generally + using a ``copy_template`` to transform the second value. +======================= ===================================================== ``save()`` keyword arguments diff --git a/postgres_copy/__init__.py b/postgres_copy/__init__.py index 2dbb025..c187cb1 100644 --- a/postgres_copy/__init__.py +++ b/postgres_copy/__init__.py @@ -1,10 +1,11 @@ +import csv import os import sys -import csv -from django.db import connections, router -from django.contrib.humanize.templatetags.humanize import intcomma from collections import OrderedDict +from django.contrib.humanize.templatetags.humanize import intcomma +from django.db import connections, router + class CopyMapping(object): """ @@ -12,16 +13,18 @@ class CopyMapping(object): and loads it into PostgreSQL databases using its COPY command. 
""" + def __init__( - self, - model, - csv_path, - mapping, - using=None, - delimiter=',', - null=None, - encoding=None, - static_mapping=None + self, + model, + csv_path, + mapping, + using=None, + delimiter=',', + null=None, + encoding=None, + static_mapping=None, + overloaded_mapping=None ): self.model = model self.mapping = mapping @@ -44,6 +47,10 @@ def __init__( self.static_mapping = OrderedDict(static_mapping) else: self.static_mapping = {} + if overloaded_mapping is not None: + self.overloaded_mapping = OrderedDict(overloaded_mapping) + else: + self.overloaded_mapping = {} # Connect the headers from the CSV with the fields on the model self.field_header_crosswalk = [] @@ -58,14 +65,24 @@ def __init__( except IndexError: raise ValueError("Model does not include %s field" % f_name) self.field_header_crosswalk.append((f, h)) - # Validate that the static mapping columns exist for f_name in self.static_mapping.keys(): try: [s for s in self.model._meta.fields if s.name == f_name][0] except IndexError: raise ValueError("Model does not include %s field" % f_name) - + # Validate overloaded fields (snapshot: keys are swapped in the loop) + clear_overload_keys = [] + for k, v in list(self.overloaded_mapping.items()): + try: + o = [o for o in self.model._meta.fields if o.name == k][0] + self.overloaded_mapping[o] = v + clear_overload_keys.append(k) + except IndexError: + raise ValueError("Model does not include overload %s field" + % k) + for key in clear_overload_keys: + del self.overloaded_mapping[key] self.temp_table_name = "temp_%s" % self.model._meta.db_table def save(self, silent=False, stream=sys.stdout): @@ -173,7 +190,7 @@ def prep_copy(self): 'extra_options': '', 'header_list': ", ".join([ '"%s"' % h for f, h in self.field_header_crosswalk - ]) + ]) } if self.delimiter: options['extra_options'] += " DELIMITER '%s'" % self.delimiter @@ -209,19 +226,34 @@ def prep_insert(self): for k in self.static_mapping.keys(): model_fields.append('"%s"' % k) + for k in self.overloaded_mapping.keys(): +
model_fields.append('"%s"' % k.get_attname_column()[1]) + options['model_fields'] = ", ".join(model_fields) temp_fields = [] for field, header in self.field_header_crosswalk: - string = '"%s"' % header - if hasattr(field, 'copy_template'): - string = field.copy_template % dict(name=header) - template_method = 'copy_%s_template' % field.name - if hasattr(self.model, template_method): - template = getattr(self.model(), template_method)() - string = template % dict(name=header) - temp_fields.append(string) + temp_fields.append(self._generate_insert_temp_fields( + field, header) + ) + for v in self.static_mapping.values(): temp_fields.append("'%s'" % v) + + for k, v in self.overloaded_mapping.items(): + temp_fields.append(self._generate_insert_temp_fields( + k, v) + ) options['temp_fields'] = ", ".join(temp_fields) + return sql % options + + def _generate_insert_temp_fields(self, concrete, column): + string = '"%s"' % column + if hasattr(concrete, 'copy_template'): + string = concrete.copy_template % dict(name=column) + template_method = 'copy_%s_template' % concrete.name + if hasattr(self.model, template_method): + template = getattr(self.model(), template_method)() + string = template % dict(name=column) + return string diff --git a/tests/models.py b/tests/models.py index 8cd5065..06956b3 100644 --- a/tests/models.py +++ b/tests/models.py @@ -29,3 +29,23 @@ class Meta: def copy_name_template(self): return 'upper("%(name)s")' copy_name_template.copy_type = 'text' + + +class OverloadMockObject(models.Model): + name = models.CharField(max_length=500) + lower_name = models.CharField(max_length=500) + number = MyIntegerField(null=True, db_column='num') + dt = models.DateField(null=True) + parent = models.ForeignKey('MockObject', null=True, default=None) + + class Meta: + app_label = 'tests' + + def copy_name_template(self): + return 'upper("%(name)s")' + copy_name_template.copy_type = 'text' + + def copy_lower_name_template(self): + return 'lower("%(name)s")' + 
copy_lower_name_template.copy_type = 'text' + diff --git a/tests/tests.py b/tests/tests.py index 2e2a800..9f36a93 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,6 +1,6 @@ import os from datetime import date -from .models import MockObject, ExtendedMockObject +from .models import MockObject, ExtendedMockObject, OverloadMockObject from postgres_copy import CopyMapping from django.test import TestCase @@ -18,6 +18,7 @@ def setUp(self): def tearDown(self): MockObject.objects.all().delete() ExtendedMockObject.objects.all().delete() + OverloadMockObject.objects.all().delete() def test_bad_call(self): with self.assertRaises(TypeError): @@ -217,3 +218,30 @@ def test_save_foreign_key(self): MockObject.objects.get(name='BEN').dt, date(2012, 1, 1) ) + + def test_overload_save(self): + c = CopyMapping( + OverloadMockObject, + self.name_path, + dict(name='NAME', number='NUMBER', dt='DATE'), + overloaded_mapping=dict(lower_name='NAME') + ) + c.save() + self.assertEqual(OverloadMockObject.objects.count(), 3) + self.assertEqual(OverloadMockObject.objects.get(name='BEN').number, 1) + self.assertEqual(OverloadMockObject.objects.get(lower_name='ben').number, 1) + self.assertEqual( + OverloadMockObject.objects.get(name='BEN').dt, + date(2012, 1, 1) + ) + omo = OverloadMockObject.objects.first() + self.assertEqual(omo.name.lower(), omo.lower_name) + + def test_missing_overload_field(self): + with self.assertRaises(ValueError): + CopyMapping( + OverloadMockObject, + self.name_path, + dict(name='NAME', number='NUMBER', dt='DATE'), + overloaded_mapping=dict(missing='NAME') + )