From e864c4d639a21c1ef9685969bf68547a536462e7 Mon Sep 17 00:00:00 2001 From: Chaim Kirby Date: Tue, 25 Oct 2016 12:00:52 +0300 Subject: [PATCH 1/7] Code for overloaded mappings --- postgres_copy/__init__.py | 33 +++++++++++++++++++++++++++++--- tests/models.py | 20 ++++++++++++++++++++ tests/tests.py | 40 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 89 insertions(+), 4 deletions(-) diff --git a/postgres_copy/__init__.py b/postgres_copy/__init__.py index 2dbb025..99c0797 100644 --- a/postgres_copy/__init__.py +++ b/postgres_copy/__init__.py @@ -21,7 +21,8 @@ def __init__( delimiter=',', null=None, encoding=None, - static_mapping=None + static_mapping=None, + overloaded_mapping=None ): self.model = model self.mapping = mapping @@ -44,6 +45,10 @@ def __init__( self.static_mapping = OrderedDict(static_mapping) else: self.static_mapping = {} + if overloaded_mapping is not None: + self.overloaded_mapping = overloaded_mapping + else: + self.overloaded_mapping = {} # Connect the headers from the CSV with the fields on the model self.field_header_crosswalk = [] @@ -58,14 +63,21 @@ def __init__( except IndexError: raise ValueError("Model does not include %s field" % f_name) self.field_header_crosswalk.append((f, h)) - # Validate that the static mapping columns exist for f_name in self.static_mapping.keys(): try: [s for s in self.model._meta.fields if s.name == f_name][0] except IndexError: raise ValueError("Model does not include %s field" % f_name) - + # Validate Overloaded headers and fields + for k, v in self.overloaded_mapping.items(): + if v not in inverse_mapping.keys(): + raise ValueError("Overloaded %s field is not in mapping" % k) + try: + f = [f for f in self.model._meta.fields if f.name == k][0] + except IndexError: + raise ValueError("Model does not include overload %s field" + % v) self.temp_table_name = "temp_%s" % self.model._meta.db_table def save(self, silent=False, stream=sys.stdout): @@ -209,6 +221,9 @@ def prep_insert(self): for k in 
self.static_mapping.keys(): model_fields.append('"%s"' % k) + for k in self.overloaded_mapping.keys(): + model_fields.append('"%s"' % k) + options['model_fields'] = ", ".join(model_fields) temp_fields = [] @@ -224,4 +239,16 @@ def prep_insert(self): for v in self.static_mapping.values(): temp_fields.append("'%s'" % v) options['temp_fields'] = ", ".join(temp_fields) + + for k, v in self.overloaded_mapping.items(): + string = '"%s"' % v + if hasattr(k, 'copy_template'): + string = field.copy_template % dict(name=v) + template_method = 'copy_%s_template' % field.name + if hasattr(self.model, template_method): + template = getattr(self.model(), template_method)() + string = template % dict(name=v) + temp_fields.append(string) + options['temp_fields'] = ", ".join(temp_fields) + return sql % options diff --git a/tests/models.py b/tests/models.py index 8cd5065..06956b3 100644 --- a/tests/models.py +++ b/tests/models.py @@ -29,3 +29,23 @@ class Meta: def copy_name_template(self): return 'upper("%(name)s")' copy_name_template.copy_type = 'text' + + +class OverloadMockObject(models.Model): + name = models.CharField(max_length=500) + lower_name = models.CharField(max_length=500) + number = MyIntegerField(null=True, db_column='num') + dt = models.DateField(null=True) + parent = models.ForeignKey('MockObject', null=True, default=None) + + class Meta: + app_label = 'tests' + + def copy_name_template(self): + return 'upper("%(name)s")' + copy_name_template.copy_type = 'text' + + def copy_lower_name_template(self): + return 'lower("%(name)s")' + copy_name_template.copy_type = 'text' + diff --git a/tests/tests.py b/tests/tests.py index 2e2a800..dea2bb8 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,6 +1,6 @@ import os from datetime import date -from .models import MockObject, ExtendedMockObject +from .models import MockObject, ExtendedMockObject, OverloadMockObject from postgres_copy import CopyMapping from django.test import TestCase @@ -18,6 +18,7 @@ def setUp(self): def 
tearDown(self): MockObject.objects.all().delete() ExtendedMockObject.objects.all().delete() + OverloadMockObject.objects.all().delete() def test_bad_call(self): with self.assertRaises(TypeError): @@ -217,3 +218,40 @@ def test_save_foreign_key(self): MockObject.objects.get(name='BEN').dt, date(2012, 1, 1) ) + + def test_overload_save(self): + c = CopyMapping( + OverloadMockObject, + self.name_path, + dict(name='NAME', number='NUMBER', dt='DATE'), + overloaded_mapping=dict(lower_name='NAME') + ) + c.save() + self.assertEqual(OverloadMockObject.objects.count(), 3) + self.assertEqual(OverloadMockObject.objects.get(name='BEN').number, 1) + self.assertEqual(OverloadMockObject.objects.get(lower_name='ben').number, 1) + self.assertEqual( + OverloadMockObject.objects.get(name='BEN').dt, + date(2012, 1, 1) + ) + omo = OverloadMockObject.objects.first() + self.assertEqual(omo.name.lower(), omo.lower_name) + + def test_bad_non_overload(self): + with self.assertRaises(ValueError): + c = CopyMapping( + OverloadMockObject, + self.name_path, + dict(name='NAME', dt='DATE'), + static_mapping=dict(number=12), + overloaded_mapping=dict(number='NUMBER') + ) + + with self.assertRaises(ValueError): + c = CopyMapping( + OverloadMockObject, + self.name_path, + dict(name='NAME', number='NUMBER', dt='DATE'), + overloaded_mapping=dict(missing='NAME') + ) + From e49a49b8d0e40ef1b30911aeb775856dcbf9c455 Mon Sep 17 00:00:00 2001 From: Chaim Kirby Date: Tue, 25 Oct 2016 12:30:08 +0300 Subject: [PATCH 2/7] docs --- docs/index.rst | 53 +++++++++++++++++++++++++++----------------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 996d2c1..cb5e8ed 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -143,31 +143,36 @@ Argument Description string field names for the CSV header. 
================= ========================================================= -===================== ===================================================== +======================= ===================================================== Keyword Arguments -===================== ===================================================== -``delimiter`` The character that separates values in the data file. - By default it is ",". This must be a single one-byte - character. - -``null`` Specifies the string that represents a null value. - The default is an unquoted empty string. This must - be a single one-byte character. - -``encoding`` Specifies the character set encoding of the strings - in the CSV data source. For example, ``'latin-1'``, - ``'utf-8'``, and ``'cp437'`` are all valid encoding - parameters. - -``using`` Sets the database to use when importing data. - Default is None, which will use the ``'default'`` - database. - -``static_mapping`` Set model attributes not in the CSV the same - for every row in the database by providing a dictionary - with the name of the columns as keys and the static - inputs as values. -===================== ===================================================== +======================= ===================================================== +``delimiter`` The character that separates values in the data file. + By default it is ",". This must be a single one-byte + character. + +``null`` Specifies the string that represents a null value. + The default is an unquoted empty string. This must + be a single one-byte character. + +``encoding`` Specifies the character set encoding of the strings + in the CSV data source. For example, ``'latin-1'``, + ``'utf-8'``, and ``'cp437'`` are all valid encoding + parameters. + +``using`` Sets the database to use when importing data. + Default is None, which will use the ``'default'`` + database. 
+ +``static_mapping`` Set model attributes not in the CSV the same + for every row in the database by providing a dictionary + with the name of the columns as keys and the static + inputs as values. + +``overloaded_mapping`` Reuse a mapped column for a different model field. + This is useful when you want to have both the + original value as well as a modified form, generally + using a `copy_template` to transform the second value +======================= ===================================================== ``save()`` keyword arguments From 35729bf3c375b71ca93d68abd05c8540b1cc6eda Mon Sep 17 00:00:00 2001 From: Chaim Kirby Date: Tue, 25 Oct 2016 12:51:20 +0300 Subject: [PATCH 3/7] Fix pyflakes error --- postgres_copy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres_copy/__init__.py b/postgres_copy/__init__.py index 99c0797..cdc474e 100644 --- a/postgres_copy/__init__.py +++ b/postgres_copy/__init__.py @@ -74,7 +74,7 @@ def __init__( if v not in inverse_mapping.keys(): raise ValueError("Overloaded %s field is not in mapping" % k) try: - f = [f for f in self.model._meta.fields if f.name == k][0] + o = [o for o in self.model._meta.fields if o.name == k][0] except IndexError: raise ValueError("Model does not include overload %s field" % v) From 04de918a8a9245f66d8fea57e6574fea9178b0c1 Mon Sep 17 00:00:00 2001 From: Chaim Kirby Date: Tue, 25 Oct 2016 12:59:35 +0300 Subject: [PATCH 4/7] Fix pyflakes error --- postgres_copy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/postgres_copy/__init__.py b/postgres_copy/__init__.py index cdc474e..ab36c46 100644 --- a/postgres_copy/__init__.py +++ b/postgres_copy/__init__.py @@ -74,7 +74,7 @@ def __init__( if v not in inverse_mapping.keys(): raise ValueError("Overloaded %s field is not in mapping" % k) try: - o = [o for o in self.model._meta.fields if o.name == k][0] + [o for o in self.model._meta.fields if o.name == k][0] except IndexError: raise 
ValueError("Model does not include overload %s field" % v) From ed8e0d8afb7daf5730de92654bb8fc68f16617d3 Mon Sep 17 00:00:00 2001 From: Chaim Kirby Date: Tue, 25 Oct 2016 14:34:00 +0300 Subject: [PATCH 5/7] Remove unreachable try/except code and partner test --- postgres_copy/__init__.py | 2 -- tests/tests.py | 12 +----------- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/postgres_copy/__init__.py b/postgres_copy/__init__.py index ab36c46..b2ac31a 100644 --- a/postgres_copy/__init__.py +++ b/postgres_copy/__init__.py @@ -71,8 +71,6 @@ def __init__( raise ValueError("Model does not include %s field" % f_name) # Validate Overloaded headers and fields for k, v in self.overloaded_mapping.items(): - if v not in inverse_mapping.keys(): - raise ValueError("Overloaded %s field is not in mapping" % k) try: [o for o in self.model._meta.fields if o.name == k][0] except IndexError: diff --git a/tests/tests.py b/tests/tests.py index dea2bb8..9f36a93 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -237,16 +237,7 @@ def test_overload_save(self): omo = OverloadMockObject.objects.first() self.assertEqual(omo.name.lower(), omo.lower_name) - def test_bad_non_overload(self): - with self.assertRaises(ValueError): - c = CopyMapping( - OverloadMockObject, - self.name_path, - dict(name='NAME', dt='DATE'), - static_mapping=dict(number=12), - overloaded_mapping=dict(number='NUMBER') - ) - + def test_missing_overload_field(self): with self.assertRaises(ValueError): c = CopyMapping( OverloadMockObject, @@ -254,4 +245,3 @@ def test_bad_non_overload(self): dict(name='NAME', number='NUMBER', dt='DATE'), overloaded_mapping=dict(missing='NAME') ) - From 06496d748a617cc8093b490955e0606a478d8efc Mon Sep 17 00:00:00 2001 From: Chaim Kirby Date: Tue, 25 Oct 2016 14:55:40 +0300 Subject: [PATCH 6/7] Refactor prep_insert with new method to generate temp fields The `copy_template` logic was used for both standard and overloaded mapping fields. Refactor the code into its own method.
--- postgres_copy/__init__.py | 75 +++++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 34 deletions(-) diff --git a/postgres_copy/__init__.py b/postgres_copy/__init__.py index b2ac31a..c84a17d 100644 --- a/postgres_copy/__init__.py +++ b/postgres_copy/__init__.py @@ -1,10 +1,11 @@ +import csv import os import sys -import csv -from django.db import connections, router -from django.contrib.humanize.templatetags.humanize import intcomma from collections import OrderedDict +from django.contrib.humanize.templatetags.humanize import intcomma +from django.db import connections, router + class CopyMapping(object): """ @@ -12,17 +13,18 @@ class CopyMapping(object): and loads it into PostgreSQL databases using its COPY command. """ + def __init__( - self, - model, - csv_path, - mapping, - using=None, - delimiter=',', - null=None, - encoding=None, - static_mapping=None, - overloaded_mapping=None + self, + model, + csv_path, + mapping, + using=None, + delimiter=',', + null=None, + encoding=None, + static_mapping=None, + overloaded_mapping=None ): self.model = model self.mapping = mapping @@ -70,12 +72,17 @@ def __init__( except IndexError: raise ValueError("Model does not include %s field" % f_name) # Validate Overloaded headers and fields + clear_overload_keys = [] for k, v in self.overloaded_mapping.items(): try: - [o for o in self.model._meta.fields if o.name == k][0] + o = [o for o in self.model._meta.fields if o.name == k][0] + self.overloaded_mapping[o] = v + clear_overload_keys.append(k) except IndexError: raise ValueError("Model does not include overload %s field" % v) + for key in clear_overload_keys: + del self.overloaded_mapping[key] self.temp_table_name = "temp_%s" % self.model._meta.db_table def save(self, silent=False, stream=sys.stdout): @@ -182,8 +189,8 @@ def prep_copy(self): 'db_table': self.temp_table_name, 'extra_options': '', 'header_list': ", ".join([ - '"%s"' % h for f, h in self.field_header_crosswalk - ]) + '"%s"' % h for f, h in 
self.field_header_crosswalk + ]) } if self.delimiter: options['extra_options'] += " DELIMITER '%s'" % self.delimiter @@ -220,33 +227,33 @@ def prep_insert(self): model_fields.append('"%s"' % k) for k in self.overloaded_mapping.keys(): - model_fields.append('"%s"' % k) + model_fields.append('"%s"' % k.get_attname_column()[1]) options['model_fields'] = ", ".join(model_fields) temp_fields = [] for field, header in self.field_header_crosswalk: - string = '"%s"' % header - if hasattr(field, 'copy_template'): - string = field.copy_template % dict(name=header) - template_method = 'copy_%s_template' % field.name - if hasattr(self.model, template_method): - template = getattr(self.model(), template_method)() - string = template % dict(name=header) - temp_fields.append(string) + temp_fields.append(self._generate_insert_temp_fields( + field, header) + ) + for v in self.static_mapping.values(): temp_fields.append("'%s'" % v) - options['temp_fields'] = ", ".join(temp_fields) for k, v in self.overloaded_mapping.items(): - string = '"%s"' % v - if hasattr(k, 'copy_template'): - string = field.copy_template % dict(name=v) - template_method = 'copy_%s_template' % field.name - if hasattr(self.model, template_method): - template = getattr(self.model(), template_method)() - string = template % dict(name=v) - temp_fields.append(string) + temp_fields.append(self._generate_insert_temp_fields( + k, v) + ) options['temp_fields'] = ", ".join(temp_fields) return sql % options + + def _generate_insert_temp_fields(self, concrete, column): + string = '"%s"' % column + if hasattr(concrete, 'copy_template'): + string = concrete.copy_template % dict(name=column) + template_method = 'copy_%s_template' % concrete.name + if hasattr(self.model, template_method): + template = getattr(self.model(), template_method)() + string = template % dict(name=column) + return string From 8cdc934e02719e55aa53f75b89647119c4968c03 Mon Sep 17 00:00:00 2001 From: Chaim Kirby Date: Tue, 25 Oct 2016 15:07:23 +0300 
Subject: [PATCH 7/7] Fix pep8 errors --- postgres_copy/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/postgres_copy/__init__.py b/postgres_copy/__init__.py index c84a17d..c187cb1 100644 --- a/postgres_copy/__init__.py +++ b/postgres_copy/__init__.py @@ -189,8 +189,8 @@ def prep_copy(self): 'db_table': self.temp_table_name, 'extra_options': '', 'header_list': ", ".join([ - '"%s"' % h for f, h in self.field_header_crosswalk - ]) + '"%s"' % h for f, h in self.field_header_crosswalk + ]) } if self.delimiter: options['extra_options'] += " DELIMITER '%s'" % self.delimiter