Skip to content

Commit

Permalink
Refactored the application so it is forgiving of CSV headers excluded…
Browse files Browse the repository at this point in the history
… from the model and the CopyMapping map dict. This is an alternative solution to @ckirby's pull request #30 that does not require adding a new kwarg but instead tries to have the code 'just work.'
  • Loading branch information
palewire committed Oct 30, 2016
1 parent d76c469 commit db3d3b2
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
21 changes: 19 additions & 2 deletions postgres_copy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,24 @@ def __init__(
else:
self.static_mapping = {}

# Make sure all of the headers in the mapping actually exist in the CSV file
headers = self.get_headers()
for map_header in self.mapping.values():
if map_header not in headers:
raise ValueError("Header '%s' in the mapping not found in the CSV file" % map_header)

# Connect the headers from the CSV with the fields on the model
self.field_header_crosswalk = []
self.excluded_headers = []
inverse_mapping = {v: k for k, v in self.mapping.items()}
for h in self.get_headers():
for h in headers:
try:
f_name = inverse_mapping[h]
except KeyError:
raise ValueError("Map does not include %s field" % h)
# If the CSV field is not included on the model map, that's okay, it just means
# the user has decided not to load that column.
self.excluded_headers.append(h)
pass
try:
f = [f for f in self.model._meta.fields if f.name == f_name][0]
except IndexError:
Expand Down Expand Up @@ -204,6 +214,9 @@ def prep_insert(self):
model_fields = []

for field, header in self.field_header_crosswalk:
# Any of the headers excluded from the mapping need to be skipped
if header in self.excluded_headers:
continue
model_fields.append('"%s"' % field.get_attname_column()[1])

for k in self.static_mapping.keys():
Expand All @@ -213,6 +226,10 @@ def prep_insert(self):

temp_fields = []
for field, header in self.field_header_crosswalk:
# Any of the headers excluded from the mapping need to be skipped
if header in self.excluded_headers:
continue
# Otherwise go ahead and format the SQL
string = '"%s"' % header
if hasattr(field, 'copy_template'):
string = field.copy_template % dict(name=header)
Expand Down
12 changes: 12 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,15 @@ class Meta:
def copy_name_template(self):
return 'upper("%(name)s")'
copy_name_template.copy_type = 'text'


class LimitedMockObject(models.Model):
name = models.CharField(max_length=500)
dt = models.DateField(null=True)

class Meta:
app_label = 'tests'

def copy_name_template(self):
return 'upper("%(name)s")'
copy_name_template.copy_type = 'text'
23 changes: 22 additions & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from datetime import date
from .models import MockObject, ExtendedMockObject
from .models import MockObject, ExtendedMockObject, LimitedMockObject
from postgres_copy import CopyMapping
from django.test import TestCase

Expand All @@ -18,6 +18,7 @@ def setUp(self):
def tearDown(self):
MockObject.objects.all().delete()
ExtendedMockObject.objects.all().delete()
LimitedMockObject.objects.all().delete()

def test_bad_call(self):
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -57,6 +58,13 @@ def test_bad_field(self):
dict(name1='NAME', number='NUMBER', dt='DATE'),
)

def test_limited_fields(self):
CopyMapping(
LimitedMockObject,
self.name_path,
dict(name='NAME', dt='DATE'),
)

def test_simple_save(self):
c = CopyMapping(
MockObject,
Expand All @@ -71,6 +79,19 @@ def test_simple_save(self):
date(2012, 1, 1)
)

def test_limited_save(self):
c = CopyMapping(
LimitedMockObject,
self.name_path,
dict(name='NAME', dt='DATE')
)
c.save()
self.assertEqual(LimitedMockObject.objects.count(), 3)
self.assertEqual(
LimitedMockObject.objects.get(name='BEN').dt,
date(2012, 1, 1)
)

def test_save_foreign_key(self):
c = CopyMapping(
MockObject,
Expand Down

0 comments on commit db3d3b2

Please sign in to comment.