Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overload column #34

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 29 additions & 24 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -143,31 +143,36 @@ Argument Description
string field names for the CSV header.
================= =========================================================

===================== =====================================================
======================= =====================================================
Keyword Arguments
===================== =====================================================
``delimiter`` The character that separates values in the data file.
By default it is ",". This must be a single one-byte
character.

``null`` Specifies the string that represents a null value.
The default is an unquoted empty string. This must
be a single one-byte character.

``encoding`` Specifies the character set encoding of the strings
in the CSV data source. For example, ``'latin-1'``,
``'utf-8'``, and ``'cp437'`` are all valid encoding
parameters.

``using`` Sets the database to use when importing data.
Default is None, which will use the ``'default'``
database.

``static_mapping`` Set model attributes not in the CSV the same
for every row in the database by providing a dictionary
with the name of the columns as keys and the static
inputs as values.
===================== =====================================================
======================= =====================================================
``delimiter`` The character that separates values in the data file.
By default it is ",". This must be a single one-byte
character.

``null`` Specifies the string that represents a null value.
The default is an unquoted empty string. This must
be a single one-byte character.

``encoding`` Specifies the character set encoding of the strings
in the CSV data source. For example, ``'latin-1'``,
``'utf-8'``, and ``'cp437'`` are all valid encoding
parameters.

``using`` Sets the database to use when importing data.
Default is None, which will use the ``'default'``
database.

``static_mapping`` Set model attributes not in the CSV the same
for every row in the database by providing a dictionary
with the name of the columns as keys and the static
inputs as values.

``overloaded_mapping`` Reuse a mapped column for a different model field.
This is useful when you want to have both the
original value as well as a modified form, generally
using a `copy_template` to transform the second value
======================= =====================================================


``save()`` keyword arguments
Expand Down
74 changes: 52 additions & 22 deletions postgres_copy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
import csv
import os
import sys
import csv
from django.db import connections, router
from django.contrib.humanize.templatetags.humanize import intcomma
from collections import OrderedDict

from django.contrib.humanize.templatetags.humanize import intcomma
from django.db import connections, router


class CopyMapping(object):
"""
Maps comma-delimited data file to a Django model
and loads it into PostgreSQL databases using its
COPY command.
"""

def __init__(
self,
model,
csv_path,
mapping,
using=None,
delimiter=',',
null=None,
encoding=None,
static_mapping=None
self,
model,
csv_path,
mapping,
using=None,
delimiter=',',
null=None,
encoding=None,
static_mapping=None,
overloaded_mapping=None
):
self.model = model
self.mapping = mapping
Expand All @@ -44,6 +47,10 @@ def __init__(
self.static_mapping = OrderedDict(static_mapping)
else:
self.static_mapping = {}
if overloaded_mapping is not None:
self.overloaded_mapping = overloaded_mapping
else:
self.overloaded_mapping = {}

# Connect the headers from the CSV with the fields on the model
self.field_header_crosswalk = []
Expand All @@ -58,13 +65,21 @@ def __init__(
except IndexError:
raise ValueError("Model does not include %s field" % f_name)
self.field_header_crosswalk.append((f, h))

# Validate that the static mapping columns exist
for f_name in self.static_mapping.keys():
try:
[s for s in self.model._meta.fields if s.name == f_name][0]
except IndexError:
raise ValueError("Model does not include %s field" % f_name)
# Validate Overloaded headers and fields
self.overloaded_crosswalk = []
for k, v in self.overloaded_mapping.items():
try:
o = [o for o in self.model._meta.fields if o.name == k][0]
self.overloaded_crosswalk.append((o, v))
except IndexError:
raise ValueError("Model does not include overload %s field"
% v)

self.temp_table_name = "temp_%s" % self.model._meta.db_table

Expand Down Expand Up @@ -173,7 +188,7 @@ def prep_copy(self):
'extra_options': '',
'header_list': ", ".join([
'"%s"' % h for f, h in self.field_header_crosswalk
])
])
}
if self.delimiter:
options['extra_options'] += " DELIMITER '%s'" % self.delimiter
Expand Down Expand Up @@ -209,19 +224,34 @@ def prep_insert(self):
for k in self.static_mapping.keys():
model_fields.append('"%s"' % k)

for field, value in self.overloaded_crosswalk:
model_fields.append('"%s"' % field.get_attname_column()[1])

options['model_fields'] = ", ".join(model_fields)

temp_fields = []
for field, header in self.field_header_crosswalk:
string = '"%s"' % header
if hasattr(field, 'copy_template'):
string = field.copy_template % dict(name=header)
template_method = 'copy_%s_template' % field.name
if hasattr(self.model, template_method):
template = getattr(self.model(), template_method)()
string = template % dict(name=header)
temp_fields.append(string)
temp_fields.append(self._generate_insert_temp_fields(
field, header)
)

for v in self.static_mapping.values():
temp_fields.append("'%s'" % v)

for field, value in self.overloaded_crosswalk:
temp_fields.append(self._generate_insert_temp_fields(
field, value)
)
options['temp_fields'] = ", ".join(temp_fields)

return sql % options

def _generate_insert_temp_fields(self, concrete, column):
string = '"%s"' % column
if hasattr(concrete, 'copy_template'):
string = concrete.copy_template % dict(name=column)
template_method = 'copy_%s_template' % concrete.name
if hasattr(self.model, template_method):
template = getattr(self.model(), template_method)()
string = template % dict(name=column)
return string
20 changes: 20 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,23 @@ class Meta:
def copy_name_template(self):
return 'upper("%(name)s")'
copy_name_template.copy_type = 'text'


class OverloadMockObject(models.Model):
name = models.CharField(max_length=500)
lower_name = models.CharField(max_length=500)
number = MyIntegerField(null=True, db_column='num')
dt = models.DateField(null=True)
parent = models.ForeignKey('MockObject', null=True, default=None)

class Meta:
app_label = 'tests'

def copy_name_template(self):
return 'upper("%(name)s")'
copy_name_template.copy_type = 'text'

def copy_lower_name_template(self):
return 'lower("%(name)s")'
copy_name_template.copy_type = 'text'

30 changes: 29 additions & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from datetime import date
from .models import MockObject, ExtendedMockObject
from .models import MockObject, ExtendedMockObject, OverloadMockObject
from postgres_copy import CopyMapping
from django.test import TestCase

Expand All @@ -18,6 +18,7 @@ def setUp(self):
def tearDown(self):
MockObject.objects.all().delete()
ExtendedMockObject.objects.all().delete()
OverloadMockObject.objects.all().delete()

def test_bad_call(self):
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -217,3 +218,30 @@ def test_save_foreign_key(self):
MockObject.objects.get(name='BEN').dt,
date(2012, 1, 1)
)

def test_overload_save(self):
c = CopyMapping(
OverloadMockObject,
self.name_path,
dict(name='NAME', number='NUMBER', dt='DATE'),
overloaded_mapping=dict(lower_name='NAME')
)
c.save()
self.assertEqual(OverloadMockObject.objects.count(), 3)
self.assertEqual(OverloadMockObject.objects.get(name='BEN').number, 1)
self.assertEqual(OverloadMockObject.objects.get(lower_name='ben').number, 1)
self.assertEqual(
OverloadMockObject.objects.get(name='BEN').dt,
date(2012, 1, 1)
)
omo = OverloadMockObject.objects.first()
self.assertEqual(omo.name.lower(), omo.lower_name)

def test_missing_overload_field(self):
with self.assertRaises(ValueError):
c = CopyMapping(
OverloadMockObject,
self.name_path,
dict(name='NAME', number='NUMBER', dt='DATE'),
overloaded_mapping=dict(missing='NAME')
)