From d186c03e8d95fcba36971d8e3d04292263202ac3 Mon Sep 17 00:00:00 2001 From: Jonathan Sundqvist Date: Sat, 3 Feb 2018 22:18:33 +0100 Subject: [PATCH] Support a file like object as a csv_path --- postgres_copy/copy_from.py | 26 ++++++++++++++++++-------- tests/tests.py | 14 ++++++++++++++ 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/postgres_copy/copy_from.py b/postgres_copy/copy_from.py index 5a1712a..798cd7c 100644 --- a/postgres_copy/copy_from.py +++ b/postgres_copy/copy_from.py @@ -21,7 +21,7 @@ class CopyMapping(object): def __init__( self, model, - csv_path, + csv, mapping, using=None, delimiter=',', @@ -34,9 +34,14 @@ def __init__( ): # Set the required arguments self.model = model - self.csv_path = csv_path - if not os.path.exists(self.csv_path): - raise ValueError("csv_path does not exist") + self.csv_path = None + self.csv_file = None + if hasattr(csv, 'read'): + self.csv_file = csv + else: + self.csv_path = csv + if not os.path.exists(self.csv_path): + raise ValueError("csv_path does not exist") # Hook in the other optional settings self.quote_character = quote_character @@ -127,10 +132,15 @@ def get_headers(self): """ Returns the column headers from the csv as a list. """ - logger.debug("Retrieving headers from {}".format(self.csv_path)) - with open(self.csv_path, 'rU') as infile: - csv_reader = csv.reader(infile, delimiter=self.delimiter) + if self.csv_path: + logger.debug("Retrieving headers from {}".format(self.csv_path)) + with open(self.csv_path, 'rU') as infile: + csv_reader = csv.reader(infile, delimiter=self.delimiter) + headers = next(csv_reader) + else: + csv_reader = csv.reader(self.csv_file, delimiter=self.delimiter) headers = next(csv_reader) + self.csv_file.seek(0) return headers def validate_mapping(self): @@ -252,7 +262,7 @@ def copy(self, cursor): logger.debug("Running COPY command") copy_sql = self.prep_copy() logger.debug(copy_sql) - fp = open(self.csv_path, 'r') + fp = open(self.csv_path, 'r') if self.csv_path else self.csv_file cursor.copy_expert(copy_sql, fp) # Run post-copy hook diff --git a/tests/tests.py b/tests/tests.py index dc8d680..4bfc58f 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -234,6 +234,20 @@ def test_limited_fields(self): dict(name='NAME', dt='DATE'), ) + def test_simple_save_with_fileobject(self): + f = open(self.name_path, 'r') + MockObject.objects.from_csv( + f, + dict(name='NAME', number='NUMBER', dt='DATE') + ) + self.assertEqual(MockObject.objects.count(), 3) + self.assertEqual(MockObject.objects.get(name='BEN').number, 1) + self.assertEqual( + MockObject.objects.get(name='BEN').dt, + date(2012, 1, 1) + ) + print(MockObject.objects.all()) + def test_simple_save(self): MockObject.objects.from_csv( self.name_path,