Skip to content

Commit

Permalink
fix #24: pass on batch_size
Browse files Browse the repository at this point in the history
  • Loading branch information
chfw committed Dec 20, 2016
1 parent 37bbf17 commit 438a87e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 52 deletions.
67 changes: 27 additions & 40 deletions pyexcel_io/database/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
:copyright: (c) 2014-2016 by Onni Software Ltd.
:license: New BSD License, see LICENSE for more details
"""
import logging

from pyexcel_io.book import BookReader, BookWriter
from pyexcel_io.sheet import SheetWriter
from pyexcel_io.utils import is_empty_array, swap_empty_string_for_none
Expand All @@ -15,6 +17,8 @@
from ._common import TableExportAdapter, TableExporter
from ._common import TableImporter, TableImportAdapter

log = logging.getLogger(__name__)


class DjangoModelReader(QuerysetsReader):
"""Read from django model
Expand All @@ -33,66 +37,46 @@ def __init__(self, model, export_columns=None, **keywords):


class DjangoModelWriter(SheetWriter):
def __init__(self, model, batch_size=None):
self.batch_size = batch_size
self.mymodel = None
self.column_names = None
self.mapdict = None
self.initializer = None

self.mymodel, self.column_names, self.mapdict, self.initializer = model
if self.initializer is None:
self.initializer = lambda row: row
if isinstance(self.mapdict, list):
self.column_names = self.mapdict
self.mapdict = None
elif isinstance(self.mapdict, dict):
self.column_names = [self.mapdict[name]
for name in self.column_names]
self.objs = []
def __init__(self, adapter, batch_size=None):
self.__batch_size = batch_size
self.__model = adapter.model
self.__column_names = adapter.column_names
self.__mapdict = adapter.column_name_mapping_dict
self.__initializer = adapter.row_initializer
self.__objs = []

def write_row(self, array):
if is_empty_array(array):
print(constants.MESSAGE_EMPTY_ARRAY)
else:
new_array = swap_empty_string_for_none(array)
model_to_be_created = new_array
if self.initializer is not None:
model_to_be_created = self.initializer(new_array)
if self.__initializer is not None:
model_to_be_created = self.__initializer(new_array)
if model_to_be_created:
self.objs.append(self.mymodel(**dict(
zip(self.column_names, model_to_be_created)
self.__objs.append(self.__model(**dict(
zip(self.__column_names, model_to_be_created)
)))
# else
# skip the row

def close(self):
try:
self.mymodel.objects.bulk_create(self.objs,
batch_size=self.batch_size)
self.__model.objects.bulk_create(self.__objs,
batch_size=self.__batch_size)
except Exception as e:
print(constants.MESSAGE_DB_EXCEPTION)
print(e)
for object in self.objs:
log.info(constants.MESSAGE_DB_EXCEPTION)
log.info(e)
for object in self.__objs:
try:
object.save()
except Exception as e2:
print(constants.MESSAGE_IGNORE_ROW)
print(e2)
print(object)
log.info(constants.MESSAGE_IGNORE_ROW)
log.info(e2)
log.info(object)
continue


class DjangoModelWriterNew(DjangoModelWriter):
def __init__(self, adapter, batch_size=None):
self.batch_size = batch_size
self.mymodel = adapter.model
self.column_names = adapter.column_names
self.mapdict = adapter.column_name_mapping_dict
self.initializer = adapter.row_initializer
self.objs = []


class DjangoModelExportAdapter(TableExportAdapter):
pass

Expand Down Expand Up @@ -133,12 +117,15 @@ class DjangoBookWriter(BookWriter):

def open_content(self, file_content, **keywords):
self.importer = file_content
self._keywords = keywords

def create_sheet(self, sheet_name):
sheet_writer = None
model = self.importer.get(sheet_name)
if model:
sheet_writer = DjangoModelWriterNew(model)
sheet_writer = DjangoModelWriter(
model,
batch_size=self._keywords.get('batch_size', None))
return sheet_writer


Expand Down
43 changes: 31 additions & 12 deletions tests/test_django_book.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,16 @@ def setUp(self):

def test_sheet_save_to_django_model(self):
model = FakeExceptionDjangoModel()
writer = DjangoModelWriter([model, self.data[0], None, None])
adapter = DjangoModelImportAdapter(model)
adapter.column_names = self.data[0]
writer = DjangoModelWriter(adapter)
writer.write_array(self.data[1:])
writer.close()
# now raise excpetion
model = FakeExceptionDjangoModel(raiseException=True)
writer = DjangoModelWriter([model, self.data[0], None, None])
adapter = DjangoModelImportAdapter(model)
adapter.column_names = self.data[0]
writer = DjangoModelWriter(adapter)
writer.write_array(self.data[1:])
writer.close()

Expand All @@ -126,10 +130,12 @@ def setUp(self):

def test_sheet_save_to_django_model(self):
model = FakeDjangoModel()
writer = DjangoModelWriter([model, self.data[0], None, None])
adapter = DjangoModelImportAdapter(model)
adapter.column_names = self.data[0]
writer = DjangoModelWriter(adapter)
writer.write_array(self.data[1:])
writer.close()
assert model.objects.objs == self.result
eq_(model.objects.objs, self.result)

def test_sheet_save_to_django_model_with_empty_array(self):
model = FakeDjangoModel()
Expand All @@ -139,7 +145,9 @@ def test_sheet_save_to_django_model_with_empty_array(self):
[1, 2, 3],
[4, 5, 6]
]
writer = DjangoModelWriter([model, data[0], None, None])
adapter = DjangoModelImportAdapter(model)
adapter.column_names = self.data[0]
writer = DjangoModelWriter(adapter)
writer.write_array(data[1:])
writer.close()
assert model.objects.objs == self.result
Expand All @@ -150,7 +158,10 @@ def test_sheet_save_to_django_model_3(self):
def wrapper(row):
row[0] = row[0] + 1
return row
writer = DjangoModelWriter([model, self.data[0], None, wrapper])
adapter = DjangoModelImportAdapter(model)
adapter.column_names = self.data[0]
adapter.row_initializer = wrapper
writer = DjangoModelWriter(adapter)
writer.write_array(self.data[1:])
writer.close()
assert model.objects.objs == [
Expand All @@ -166,7 +177,10 @@ def wrapper(row):
return None
else:
return row
writer = DjangoModelWriter([model, self.data[0], None, wrapper])
adapter = DjangoModelImportAdapter(model)
adapter.column_names = self.data[0]
adapter.row_initializer = wrapper
writer = DjangoModelWriter(adapter)
writer.write_array(self.data[1:])
writer.close()
assert model.objects.objs == [
Expand Down Expand Up @@ -217,12 +231,14 @@ def test_mapping_array(self):
[1, 2, 3],
[4, 5, 6]
]
mapdict = ["X", "Y", "Z"]
model = FakeDjangoModel()
writer = DjangoModelWriter([model, data2[0], mapdict, None])
adapter = DjangoModelImportAdapter(model)
adapter.column_names = data2[0]
adapter.column_name_mapping_dict = ["X", "Y", "Z"]
writer = DjangoModelWriter(adapter)
writer.write_array(data2[1:])
writer.close()
assert model.objects.objs == self.result
eq_(model.objects.objs, self.result)

def test_mapping_dict(self):
data2 = [
Expand All @@ -236,10 +252,13 @@ def test_mapping_dict(self):
"B": "Y"
}
model = FakeDjangoModel()
writer = DjangoModelWriter([model, data2[0], mapdict, None])
adapter = DjangoModelImportAdapter(model)
adapter.column_names = data2[0]
adapter.column_name_mapping_dict = mapdict
writer = DjangoModelWriter(adapter)
writer.write_array(data2[1:])
writer.close()
assert model.objects.objs == self.result
eq_(model.objects.objs, self.result)

def test_empty_model(self):
model = FakeDjangoModel()
Expand Down

0 comments on commit 438a87e

Please sign in to comment.