diff --git a/postgres_copy/copy_from.py b/postgres_copy/copy_from.py index 7b9847f..3524bf3 100644 --- a/postgres_copy/copy_from.py +++ b/postgres_copy/copy_from.py @@ -21,7 +21,7 @@ class CopyMapping(object): def __init__( self, model, - csv_path, + csv, mapping, using=None, delimiter=',', @@ -34,9 +34,14 @@ def __init__( ): # Set the required arguments self.model = model - self.csv_path = csv_path - if not os.path.exists(self.csv_path): - raise ValueError("csv_path does not exist") + self.csv_path = None + self.csv_file = None + if hasattr(csv, 'read'): + self.csv_file = csv + else: + self.csv_path = csv + if not os.path.exists(self.csv_path): + raise ValueError("csv_path does not exist") # Hook in the other optional settings self.quote_character = quote_character @@ -129,10 +134,15 @@ def get_headers(self): """ Returns the column headers from the csv as a list. """ - logger.debug("Retrieving headers from {}".format(self.csv_path)) - with open(self.csv_path, 'rU') as infile: - csv_reader = csv.reader(infile, delimiter=self.delimiter) + if self.csv_path: + logger.debug("Retrieving headers from {}".format(self.csv_path)) + with open(self.csv_path, 'rU') as infile: + csv_reader = csv.reader(infile, delimiter=self.delimiter) + headers = next(csv_reader) + else: + csv_reader = csv.reader(self.csv_file, delimiter=self.delimiter) headers = next(csv_reader) + self.csv_file.seek(0) return headers def validate_mapping(self): @@ -254,7 +264,7 @@ def copy(self, cursor): logger.debug("Running COPY command") copy_sql = self.prep_copy() logger.debug(copy_sql) - fp = open(self.csv_path, 'r') + fp = open(self.csv_path, 'r') if self.csv_path else self.csv_file cursor.copy_expert(copy_sql, fp) # Run post-copy hook diff --git a/tests/tests.py b/tests/tests.py index 7c07850..b194faf 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -234,6 +234,20 @@ def test_limited_fields(self): dict(name='NAME', dt='DATE'), ) + def test_simple_save_with_fileobject(self): + f = open(self.name_path, 'r') + MockObject.objects.from_csv( + f, + dict(name='NAME', number='NUMBER', dt='DATE') + ) + self.assertEqual(MockObject.objects.count(), 3) + self.assertEqual(MockObject.objects.get(name='BEN').number, 1) + self.assertEqual( + MockObject.objects.get(name='BEN').dt, + date(2012, 1, 1) + ) + print(MockObject.objects.all()) + def test_simple_save(self): insert_count = MockObject.objects.from_csv( self.name_path,