Skip to content

Commit

Permalink
Fix merge errors
Browse files Browse the repository at this point in the history
  • Loading branch information
ckirby committed Oct 25, 2016
2 parents 35b96f4 + 2be775d commit 73dc5f3
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 43 deletions.
61 changes: 33 additions & 28 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -143,35 +143,40 @@ Argument Description
string field names for the CSV header.
================= =========================================================

===================== =====================================================
======================= =====================================================
Keyword Arguments
===================== =====================================================
``delimiter`` The character that separates values in the data file.
By default it is ",". This must be a single one-byte
character.

``null`` Specifies the string that represents a null value.
The default is an unquoted empty string. This must
be a single one-byte character.

``encoding`` Specifies the character set encoding of the strings
in the CSV data source. For example, ``'latin-1'``,
``'utf-8'``, and ``'cp437'`` are all valid encoding
parameters.

``using`` Sets the database to use when importing data.
Default is None, which will use the ``'default'``
database.

``static_mapping`` Set model attributes not in the CSV the same
for every row in the database by providing a dictionary
with the name of the columns as keys and the static
inputs as values.

``ignore_headers`` A list of headers from your csv that don't have
equivalent fields in your model. These columns will
be ignored.
===================== =====================================================
======================= =====================================================
``delimiter`` The character that separates values in the data file.
By default it is ",". This must be a single one-byte
character.

``null`` Specifies the string that represents a null value.
The default is an unquoted empty string. This must
be a single one-byte character.

``encoding`` Specifies the character set encoding of the strings
in the CSV data source. For example, ``'latin-1'``,
``'utf-8'``, and ``'cp437'`` are all valid encoding
parameters.

``using`` Sets the database to use when importing data.
Default is None, which will use the ``'default'``
database.

``static_mapping`` Set model attributes not in the CSV the same
for every row in the database by providing a dictionary
with the name of the columns as keys and the static
inputs as values.

``ignore_headers`` A list of headers from your CSV that don't have
equivalent fields in your model. These columns will
be ignored.

``overloaded_mapping`` Reuse a mapped column for a different model field.
Provide a dictionary with the names of the model
fields as keys and the CSV headers to reuse as values.
This is useful when you want to keep both the
original value and a modified form of it, generally
using a ``copy_template`` to transform the second value.
======================= =====================================================


``save()`` keyword arguments
Expand Down
58 changes: 44 additions & 14 deletions postgres_copy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import csv
import os
import sys
import csv
from django.db import connections, router
from django.contrib.humanize.templatetags.humanize import intcomma
from collections import OrderedDict

from django.contrib.humanize.templatetags.humanize import intcomma
from django.db import connections, router


class CopyMapping(object):
"""
Maps comma-delimited data file to a Django model
and loads it into PostgreSQL databases using its
COPY command.
"""

def __init__(
self,
model,
Expand All @@ -22,7 +24,8 @@ def __init__(
delimiter=',',
null=None,
encoding=None,
static_mapping=None
static_mapping=None,
overloaded_mapping=None
):
self.model = model
self.mapping = mapping
Expand All @@ -49,6 +52,10 @@ def __init__(
self.static_mapping = OrderedDict(static_mapping)
else:
self.static_mapping = {}
if overloaded_mapping is not None:
self.overloaded_mapping = overloaded_mapping
else:
self.overloaded_mapping = {}

# Connect the headers from the CSV with the fields on the model
self.field_header_crosswalk = []
Expand All @@ -67,13 +74,21 @@ def __init__(
except IndexError:
raise ValueError("Model does not include %s field" % f_name)
self.field_header_crosswalk.append((f, h))

# Validate that the static mapping columns exist
for f_name in self.static_mapping.keys():
try:
[s for s in self.model._meta.fields if s.name == f_name][0]
except IndexError:
raise ValueError("Model does not include %s field" % f_name)
# Validate that every overloaded field exists on the model.
# ``overloaded_mapping`` is keyed by model field name; each value is the
# source column to reuse. Matches are kept as (field, column) pairs.
self.overloaded_crosswalk = []
for f_name, header in self.overloaded_mapping.items():
    try:
        # Only the lookup can raise IndexError, so keep it alone in the try.
        field = [
            f for f in self.model._meta.fields if f.name == f_name
        ][0]
    except IndexError:
        # Name the model field that was not found (the mapping key),
        # not the column it was mapped to.
        raise ValueError(
            "Model does not include overload %s field" % f_name
        )
    self.overloaded_crosswalk.append((field, header))

self.temp_table_name = "temp_%s" % self.model._meta.db_table

Expand Down Expand Up @@ -182,7 +197,7 @@ def prep_copy(self):
'extra_options': '',
'header_list': ", ".join([
'"%s"' % h for f, h in self.field_header_crosswalk
])
])
}
if self.delimiter:
options['extra_options'] += " DELIMITER '%s'" % self.delimiter
Expand Down Expand Up @@ -220,21 +235,36 @@ def prep_insert(self):
for k in self.static_mapping.keys():
model_fields.append('"%s"' % k)

for field, value in self.overloaded_crosswalk:
model_fields.append('"%s"' % field.get_attname_column()[1])

options['model_fields'] = ", ".join(model_fields)

temp_fields = []
for field, header in self.field_header_crosswalk:
if header in self.ignore_headers:
continue
string = '"%s"' % header
if hasattr(field, 'copy_template'):
string = field.copy_template % dict(name=header)
template_method = 'copy_%s_template' % field.name
if hasattr(self.model, template_method):
template = getattr(self.model(), template_method)()
string = template % dict(name=header)
temp_fields.append(string)
temp_fields.append(self._generate_insert_temp_fields(
field, header)
)

for v in self.static_mapping.values():
temp_fields.append("'%s'" % v)

for field, value in self.overloaded_crosswalk:
temp_fields.append(self._generate_insert_temp_fields(
field, value)
)
options['temp_fields'] = ", ".join(temp_fields)

return sql % options

def _generate_insert_temp_fields(self, concrete, column):
    """
    Build the SQL expression used to read ``column`` during the insert.

    The expression starts as the double-quoted column name. A
    ``copy_template`` attribute on the field replaces it, and a
    ``copy_<field name>_template`` method on the model replaces it again,
    so the model-level template wins when both are present. Templates are
    %-formatted with a ``name`` key holding ``column``.
    """
    substitutions = dict(name=column)
    expression = '"%s"' % column

    # Field-level template.
    if hasattr(concrete, 'copy_template'):
        expression = concrete.copy_template % substitutions

    # Model-level template, looked up by field name; applied last.
    method_name = 'copy_%s_template' % concrete.name
    if hasattr(self.model, method_name):
        model_template = getattr(self.model(), method_name)()
        expression = model_template % substitutions

    return expression
19 changes: 19 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,22 @@ def copy_name_template(self):
return 'upper("%(name)s")'
copy_name_template.copy_type = 'text'


class OverloadMockObject(models.Model):
    """
    Test model with two text columns, ``name`` and ``lower_name``, each
    of which has its own ``copy_<field>_template`` method.
    """
    name = models.CharField(max_length=500)
    lower_name = models.CharField(max_length=500)
    number = MyIntegerField(null=True, db_column='num')
    dt = models.DateField(null=True)
    # on_delete is spelled out: CASCADE was the implicit default, and the
    # argument is mandatory from Django 2.0 onwards.
    parent = models.ForeignKey(
        'MockObject',
        null=True,
        default=None,
        on_delete=models.CASCADE
    )

    class Meta:
        app_label = 'tests'

    def copy_name_template(self):
        """Return the SQL template that upper-cases the source column."""
        return 'upper("%(name)s")'
    copy_name_template.copy_type = 'text'

    def copy_lower_name_template(self):
        """Return the SQL template that lower-cases the source column."""
        return 'lower("%(name)s")'
    # Tag this method itself; the original re-tagged copy_name_template.
    copy_lower_name_template.copy_type = 'text'

31 changes: 30 additions & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from datetime import date
from .models import MockObject, ExtendedMockObject, LimitedMockObject
from .models import MockObject, ExtendedMockObject, LimitedMockObject,\
OverloadMockObject
from postgres_copy import CopyMapping
from django.test import TestCase

Expand All @@ -19,6 +20,7 @@ def tearDown(self):
MockObject.objects.all().delete()
ExtendedMockObject.objects.all().delete()
LimitedMockObject.objects.all().delete()
OverloadMockObject.objects.all().delete()

def test_bad_call(self):
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -243,3 +245,30 @@ def test_save_foreign_key(self):
MockObject.objects.get(name='BEN').dt,
date(2012, 1, 1)
)

def test_overload_save(self):
    """
    Saving with ``overloaded_mapping`` fills ``lower_name`` from the same
    ``NAME`` header that feeds ``name``.
    """
    field_mapping = dict(name='NAME', number='NUMBER', dt='DATE')
    overloads = dict(lower_name='NAME')
    loader = CopyMapping(
        OverloadMockObject,
        self.name_path,
        field_mapping,
        overloaded_mapping=overloads
    )
    loader.save()

    self.assertEqual(OverloadMockObject.objects.count(), 3)

    # The same row must be reachable through both text columns.
    by_upper = OverloadMockObject.objects.get(name='BEN')
    self.assertEqual(by_upper.number, 1)
    by_lower = OverloadMockObject.objects.get(lower_name='ben')
    self.assertEqual(by_lower.number, 1)

    self.assertEqual(
        OverloadMockObject.objects.get(name='BEN').dt,
        date(2012, 1, 1)
    )

    first_row = OverloadMockObject.objects.first()
    self.assertEqual(first_row.name.lower(), first_row.lower_name)

def test_missing_overload_field(self):
    """
    Constructing a CopyMapping whose ``overloaded_mapping`` names a field
    that is not on the model raises ``ValueError``.
    """
    with self.assertRaises(ValueError):
        # The constructor is expected to raise, so its result is not bound.
        CopyMapping(
            OverloadMockObject,
            self.name_path,
            dict(name='NAME', number='NUMBER', dt='DATE'),
            overloaded_mapping=dict(missing='NAME')
        )

0 comments on commit 73dc5f3

Please sign in to comment.