Skip to content

Commit

Permalink
Completed sweeping refactor of the library so that excluded headers a…
Browse files Browse the repository at this point in the history
…nd overloaded fields 'just work' and everything is more legible and straight forward. For #34.
  • Loading branch information
palewire committed Nov 12, 2016
1 parent f924679 commit 29afb60
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 101 deletions.
114 changes: 40 additions & 74 deletions postgres_copy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import csv
from collections import OrderedDict
from django.db import connections, router
from django.contrib.humanize.templatetags.humanize import intcomma
__version__ = '0.1.0'


class CopyMapping(object):
Expand All @@ -25,14 +28,19 @@ def __init__(
):
# Set the required arguments
self.model = model
self.mapping = mapping
self.mapping = OrderedDict(mapping)
self.csv_path = csv_path
if not os.path.exists(self.csv_path):
raise ValueError("csv_path does not exist")

# Line up the CSV file
if os.path.exists(csv_path):
self.csv_path = csv_path
# Hook in the other optional settings
self.delimiter = delimiter
self.null = null
self.encoding = encoding
if static_mapping is not None:
self.static_mapping = OrderedDict(static_mapping)
else:
raise ValueError("csv_path does not exist")
self.headers = self.get_headers()
self.static_mapping = {}

# Line up the database connection
if using is not None:
Expand All @@ -45,48 +53,18 @@ def __init__(
self.backend = self.conn.ops
self.temp_table_name = "temp_%s" % self.model._meta.db_table

# Hook in the other optional settings
self.delimiter = delimiter
self.null = null
self.encoding = encoding
if static_mapping is not None:
self.static_mapping = OrderedDict(static_mapping)
else:
self.static_mapping = {}
# Pull the CSV headers
self.headers = self.get_headers()

# Make sure the mapping is legit
# Make sure the everything is legit
self.validate_mapping()

# Identify any CSV headers that have been excluded from the mapping
self.excluded_headers = [h for h in self.headers if h not in self.mapping.keys()]

#
# Connect the headers from the CSV with the fields on the model
#

# self.field_header_crosswalk = []
# # Flip around the mapping so the CSV headers are the keys and database model fields the values
# inverse_mapping = {v: k for k, v in self.mapping.items()}
#
# # Loop through the CSV headers ...
# for h in headers:
# # Check if they are in the mapping
# try:
# f_name = inverse_mapping[h]
# except KeyError:
# # If the CSV field is not included on the model map,
# # that's okay, it just means the user has decided not
# # to load that column.
# self.excluded_headers.append(h)
# pass
# self.field_header_crosswalk.append((f, h))

def get_field(self, name):
"""
Returns any fields on the database model matching the provided name.
"""
try:
[f for f in self.model._meta.fields if f.name == map_field][0]
return [f for f in self.model._meta.fields if f.name == name][0]
except IndexError:
return None

Expand All @@ -104,12 +82,12 @@ def validate_mapping(self):
# Make sure all the model fields in the mapping actually exist
for map_field in self.mapping.keys():
if not self.get_field(map_field):
raise ValueError("Model does not include %s field" % f_name)
raise ValueError("Model does not include %s field" % map_field)

# Make sure any static mapping columns exist
for static_field in self.static_mapping.keys():
if not self.get_field(static_field):
raise ValueError("Model does not include %s field" % f_name)
raise ValueError("Model does not include %s field" % static_field)

def save(self, silent=False, stream=sys.stdout):
"""
Expand Down Expand Up @@ -177,24 +155,10 @@ def prep_create(self):
field_list = []

# Loop through all the fields and CSV headers together
for field_name, header in self.mapping.items():

# Pull the field object from the model
field = self.get_field(field_name)
for header in self.headers:

# Format the SQL create statement
string = '"%s" %s' % (header, field.db_type(self.conn))

# If the field has an override, use that
if hasattr(field, 'copy_template'):
string = '"%s" %s' % (header, field.copy_type)

# If the model has a more-specific override, use that
template_method = 'copy_%s_template' % field.name
if hasattr(self.model, template_method):
method = getattr(self.model(), template_method)
if hasattr(method, 'copy_type'):
string = '"%s" %s' % (header, method.copy_type)
string = '"%s" text' % header

# Add the string to the list
field_list.append(string)
Expand All @@ -220,7 +184,7 @@ def prep_copy(self):
'db_table': self.temp_table_name,
'extra_options': '',
'header_list': ", ".join([
'"%s"' % h for f, h in self.field_header_crosswalk
'"%s"' % h for h in self.headers
])
}
if self.delimiter:
Expand Down Expand Up @@ -249,36 +213,38 @@ def prep_insert(self):
temp_table=self.temp_table_name,
)

#
# The model fields to be inserted into
#

model_fields = []
for field_name, header in self.mapping.items():
# Any of the headers excluded from the mapping need to be skipped
if header in self.excluded_headers:
continue
field = self.get_field(field_name)
model_fields.append('"%s"' % field.get_attname_column()[1])

for k in self.static_mapping.keys():
model_fields.append('"%s"' % k)

print self.field_header_crosswalk
print model_fields
options['model_fields'] = ", ".join(model_fields)

temp_fields = []
for field, header in self.field_header_crosswalk:
#
# The temp fields to SELECT from
#

temp_fields = []
for field_name, header in self.mapping.items():
# Pull the field object from the model
field = self.get_field(field_name)
field_type = field.db_type(self.conn)

# Any of the headers excluded from the mapping need to be skipped
if header in self.excluded_headers:
continue
# Format the SQL
string = 'cast("%s" as %s)' % (header, field_type)

# Otherwise go ahead and format the SQL
string = '"%s"' % header

# The template overrides too
# Apply a datatype template override, if it exists
if hasattr(field, 'copy_template'):
string = field.copy_template % dict(name=header)

# Apply a field specific template override, if it exists
template_method = 'copy_%s_template' % field.name
if hasattr(self.model, template_method):
template = getattr(self.model(), template_method)()
Expand All @@ -294,5 +260,5 @@ def prep_insert(self):
# Join it all together
options['temp_fields'] = ", ".join(temp_fields)

print sql % options
# Pass it out
return sql % options
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def run(self):

setup(
name='django-postgres-copy',
version='0.0.8',
version='0.1.0',
description="Quickly load comma-delimited data into a Django model \
using PostgreSQL's COPY command",
author='Ben Welsh',
Expand Down
51 changes: 25 additions & 26 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ def test_static_values(self):
ExtendedMockObject,
self.name_path,
dict(name='NAME', number='NUMBER', dt='DATE'),
encoding='UTF-8',
static_mapping={'static_val':1,'static_string':'test'}
)
c.save()
Expand Down Expand Up @@ -245,28 +244,28 @@ def test_save_foreign_key(self):
date(2012, 1, 1)
)

# def test_overload_save(self):
# c = CopyMapping(
# OverloadMockObject,
# self.name_path,
# dict(name='NAME', lower_name='NAME', upper_name='NAME', number='NUMBER', dt='DATE'),
# )
# c.save()
# self.assertEqual(OverloadMockObject.objects.count(), 3)
# self.assertEqual(OverloadMockObject.objects.get(name='ben').number, 1)
# self.assertEqual(OverloadMockObject.objects.get(lower_name='ben').number, 1)
# self.assertEqual(OverloadMockObject.objects.get(upper_name='BEN').number, 1)
# self.assertEqual(
# OverloadMockObject.objects.get(name='ben').dt,
# date(2012, 1, 1)
# )
# omo = OverloadMockObject.objects.first()
# self.assertEqual(omo.name.lower(), omo.lower_name)
#
# def test_missing_overload_field(self):
# with self.assertRaises(ValueError):
# c = CopyMapping(
# OverloadMockObject,
# self.name_path,
# dict(name='NAME', number='NUMBER', dt='DATE', missing='NAME'),
# )
def test_overload_save(self):
c = CopyMapping(
OverloadMockObject,
self.name_path,
dict(name='NAME', lower_name='NAME', upper_name='NAME', number='NUMBER', dt='DATE'),
)
c.save()
self.assertEqual(OverloadMockObject.objects.count(), 3)
self.assertEqual(OverloadMockObject.objects.get(name='ben').number, 1)
self.assertEqual(OverloadMockObject.objects.get(lower_name='ben').number, 1)
self.assertEqual(OverloadMockObject.objects.get(upper_name='BEN').number, 1)
self.assertEqual(
OverloadMockObject.objects.get(name='ben').dt,
date(2012, 1, 1)
)
omo = OverloadMockObject.objects.first()
self.assertEqual(omo.name.lower(), omo.lower_name)

def test_missing_overload_field(self):
with self.assertRaises(ValueError):
c = CopyMapping(
OverloadMockObject,
self.name_path,
dict(name='NAME', number='NUMBER', dt='DATE', missing='NAME'),
)

0 comments on commit 29afb60

Please sign in to comment.