From 29afb60193a823e493a775ddbb44777b9f19e8a8 Mon Sep 17 00:00:00 2001 From: palewire Date: Sat, 12 Nov 2016 09:09:23 -0800 Subject: [PATCH] Completed sweeping refactor of the library so that excluded headers and overloaded fields 'just work' and everything is more legible and straight forward. For #34. --- postgres_copy/__init__.py | 114 +++++++++++++------------------------- setup.py | 2 +- tests/tests.py | 51 +++++++++-------- 3 files changed, 66 insertions(+), 101 deletions(-) diff --git a/postgres_copy/__init__.py b/postgres_copy/__init__.py index af2f19e..3f0d07f 100644 --- a/postgres_copy/__init__.py +++ b/postgres_copy/__init__.py @@ -1,9 +1,12 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- import os import sys import csv from collections import OrderedDict from django.db import connections, router from django.contrib.humanize.templatetags.humanize import intcomma +__version__ = '0.1.0' class CopyMapping(object): @@ -25,14 +28,19 @@ def __init__( ): # Set the required arguments self.model = model - self.mapping = mapping + self.mapping = OrderedDict(mapping) + self.csv_path = csv_path + if not os.path.exists(self.csv_path): + raise ValueError("csv_path does not exist") - # Line up the CSV file - if os.path.exists(csv_path): - self.csv_path = csv_path + # Hook in the other optional settings + self.delimiter = delimiter + self.null = null + self.encoding = encoding + if static_mapping is not None: + self.static_mapping = OrderedDict(static_mapping) else: - raise ValueError("csv_path does not exist") - self.headers = self.get_headers() + self.static_mapping = {} # Line up the database connection if using is not None: @@ -45,48 +53,18 @@ def __init__( self.backend = self.conn.ops self.temp_table_name = "temp_%s" % self.model._meta.db_table - # Hook in the other optional settings - self.delimiter = delimiter - self.null = null - self.encoding = encoding - if static_mapping is not None: - self.static_mapping = OrderedDict(static_mapping) - else: - self.static_mapping = {} + # Pull the CSV headers + self.headers = self.get_headers() - # Make sure the mapping is legit + # Make sure the everything is legit self.validate_mapping() - # Identify any CSV headers that have been excluded from the mapping - self.excluded_headers = [h for h in self.headers if h not in self.mapping.keys()] - - # - # Connect the headers from the CSV with the fields on the model - # - - # self.field_header_crosswalk = [] - # # Flip around the mapping so the CSV headers are the keys and database model fields the values - # inverse_mapping = {v: k for k, v in self.mapping.items()} - # - # # Loop through the CSV headers ... - # for h in headers: - # # Check if they are in the mapping - # try: - # f_name = inverse_mapping[h] - # except KeyError: - # # If the CSV field is not included on the model map, - # # that's okay, it just means the user has decided not - # # to load that column. - # self.excluded_headers.append(h) - # pass - # self.field_header_crosswalk.append((f, h)) - def get_field(self, name): """ Returns any fields on the database model matching the provided name. """ try: - [f for f in self.model._meta.fields if f.name == map_field][0] + return [f for f in self.model._meta.fields if f.name == name][0] except IndexError: return None @@ -104,12 +82,12 @@ def validate_mapping(self): # Make sure all the model fields in the mapping actually exist for map_field in self.mapping.keys(): if not self.get_field(map_field): - raise ValueError("Model does not include %s field" % f_name) + raise ValueError("Model does not include %s field" % map_field) # Make sure any static mapping columns exist for static_field in self.static_mapping.keys(): if not self.get_field(static_field): - raise ValueError("Model does not include %s field" % f_name) + raise ValueError("Model does not include %s field" % static_field) def save(self, silent=False, stream=sys.stdout): """ @@ -177,24 +155,10 @@ def prep_create(self): field_list = [] # Loop through all the fields and CSV headers together - for field_name, header in self.mapping.items(): - - # Pull the field object from the model - field = self.get_field(field_name) + for header in self.headers: # Format the SQL create statement - string = '"%s" %s' % (header, field.db_type(self.conn)) - - # If the field has an override, use that - if hasattr(field, 'copy_template'): - string = '"%s" %s' % (header, field.copy_type) - - # If the model has a more-specific override, use that - template_method = 'copy_%s_template' % field.name - if hasattr(self.model, template_method): - method = getattr(self.model(), template_method) - if hasattr(method, 'copy_type'): - string = '"%s" %s' % (header, method.copy_type) + string = '"%s" text' % header # Add the string to the list field_list.append(string) @@ -220,7 +184,7 @@ def prep_copy(self): 'db_table': self.temp_table_name, 'extra_options': '', 'header_list': ", ".join([ - '"%s"' % h for f, h in self.field_header_crosswalk + '"%s"' % h for h in self.headers ]) } if self.delimiter: @@ -249,36 +213,38 @@ def prep_insert(self): temp_table=self.temp_table_name, ) + # + # The model fields to be inserted into + # + model_fields = [] for field_name, header in self.mapping.items(): - # Any of the headers excluded from the mapping need to be skipped - if header in self.excluded_headers: - continue + field = self.get_field(field_name) model_fields.append('"%s"' % field.get_attname_column()[1]) for k in self.static_mapping.keys(): model_fields.append('"%s"' % k) - print self.field_header_crosswalk - print model_fields options['model_fields'] = ", ".join(model_fields) - temp_fields = [] - for field, header in self.field_header_crosswalk: + # + # The temp fields to SELECT from + # + temp_fields = [] + for field_name, header in self.mapping.items(): # Pull the field object from the model field = self.get_field(field_name) + field_type = field.db_type(self.conn) - # Any of the headers excluded from the mapping need to be skipped - if header in self.excluded_headers: - continue + # Format the SQL + string = 'cast("%s" as %s)' % (header, field_type) - # Otherwise go ahead and format the SQL - string = '"%s"' % header - - # The template overrides too + # Apply a datatype template override, if it exists if hasattr(field, 'copy_template'): string = field.copy_template % dict(name=header) + + # Apply a field specific template override, if it exists template_method = 'copy_%s_template' % field.name if hasattr(self.model, template_method): template = getattr(self.model(), template_method)() @@ -294,5 +260,5 @@ def prep_insert(self): # Join it all together options['temp_fields'] = ", ".join(temp_fields) - print sql % options + # Pass it out return sql % options diff --git a/setup.py b/setup.py index d3399f0..dc1633e 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ def run(self): setup( name='django-postgres-copy', - version='0.0.8', + version='0.1.0', description="Quickly load comma-delimited data into a Django model \ using PostgreSQL's COPY command", author='Ben Welsh', diff --git a/tests/tests.py b/tests/tests.py index 94bf89f..8881b22 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -206,7 +206,6 @@ def test_static_values(self): ExtendedMockObject, self.name_path, dict(name='NAME', number='NUMBER', dt='DATE'), - encoding='UTF-8', static_mapping={'static_val':1,'static_string':'test'} ) c.save() @@ -245,28 +244,28 @@ def test_save_foreign_key(self): date(2012, 1, 1) ) - # def test_overload_save(self): - # c = CopyMapping( - # OverloadMockObject, - # self.name_path, - # dict(name='NAME', lower_name='NAME', upper_name='NAME', number='NUMBER', dt='DATE'), - # ) - # c.save() - # self.assertEqual(OverloadMockObject.objects.count(), 3) - # self.assertEqual(OverloadMockObject.objects.get(name='ben').number, 1) - # self.assertEqual(OverloadMockObject.objects.get(lower_name='ben').number, 1) - # self.assertEqual(OverloadMockObject.objects.get(upper_name='BEN').number, 1) - # self.assertEqual( - # OverloadMockObject.objects.get(name='ben').dt, - # date(2012, 1, 1) - # ) - # omo = OverloadMockObject.objects.first() - # self.assertEqual(omo.name.lower(), omo.lower_name) - # - # def test_missing_overload_field(self): - # with self.assertRaises(ValueError): - # c = CopyMapping( - # OverloadMockObject, - # self.name_path, - # dict(name='NAME', number='NUMBER', dt='DATE', missing='NAME'), - # ) + def test_overload_save(self): + c = CopyMapping( + OverloadMockObject, + self.name_path, + dict(name='NAME', lower_name='NAME', upper_name='NAME', number='NUMBER', dt='DATE'), + ) + c.save() + self.assertEqual(OverloadMockObject.objects.count(), 3) + self.assertEqual(OverloadMockObject.objects.get(name='ben').number, 1) + self.assertEqual(OverloadMockObject.objects.get(lower_name='ben').number, 1) + self.assertEqual(OverloadMockObject.objects.get(upper_name='BEN').number, 1) + self.assertEqual( + OverloadMockObject.objects.get(name='ben').dt, + date(2012, 1, 1) + ) + omo = OverloadMockObject.objects.first() + self.assertEqual(omo.name.lower(), omo.lower_name) + + def test_missing_overload_field(self): + with self.assertRaises(ValueError): + c = CopyMapping( + OverloadMockObject, + self.name_path, + dict(name='NAME', number='NUMBER', dt='DATE', missing='NAME'), + )