Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep reference of copied items for cross referencing #579

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 143 additions & 84 deletions model_clone/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,9 @@ def pre_save_duplicate(self, instance): # pylint: disable=R0201
return instance

@transaction.atomic
def make_clone(self, attrs=None, sub_clone=False, using=None):
def make_clone(
self, attrs=None, sub_clone=False, using=None, cloned_references=None
):
"""Creates a clone of the django model instance.

:param attrs: Dictionary of attributes to be replaced on the cloned object.
Expand All @@ -212,6 +214,8 @@ def make_clone(self, attrs=None, sub_clone=False, using=None):
:type using: str
:return: The model instance that has been cloned.
"""
cloned_references = cloned_references or {}

using = using or self._state.db or self.__class__._default_manager.db
attrs = attrs or {}
if not self.pk:
Expand All @@ -224,18 +228,30 @@ def make_clone(self, attrs=None, sub_clone=False, using=None):
duplicate = self # pragma: no cover
duplicate.pk = None # pragma: no cover
else:
duplicate = self._create_copy_of_instance(self, using=using)
duplicate = self._create_copy_of_instance(
self, using=using, cloned_references=cloned_references
)

for name, value in attrs.items():
setattr(duplicate, name, value)

duplicate = self.pre_save_duplicate(duplicate)
duplicate.save(using=using)

duplicate = self.__duplicate_m2o_fields(duplicate, using=using)
duplicate = self.__duplicate_o2o_fields(duplicate, using=using)
duplicate = self.__duplicate_o2m_fields(duplicate, using=using)
duplicate = self.__duplicate_m2m_fields(duplicate, using=using)
cloned_references[self] = duplicate

duplicate = self.__duplicate_m2o_fields(
duplicate, using=using, cloned_references=cloned_references
)
duplicate = self.__duplicate_o2o_fields(
duplicate, using=using, cloned_references=cloned_references
)
duplicate = self.__duplicate_o2m_fields(
duplicate, using=using, cloned_references=cloned_references
)
duplicate = self.__duplicate_m2m_fields(
duplicate, using=using, cloned_references=cloned_references
)

return duplicate

Expand Down Expand Up @@ -283,7 +299,9 @@ def parallel_clone(self, count, attrs=None, batch_size=None, auto_commit=False):
pass

@staticmethod
def _create_copy_of_instance(instance, using=None, force=False, sub_clone=False):
def _create_copy_of_instance(
instance, using=None, force=False, sub_clone=False, cloned_references=None
):
"""Create a copy of a model instance.

:param instance: The instance to be duplicated.
Expand All @@ -298,6 +316,8 @@ def _create_copy_of_instance(instance, using=None, force=False, sub_clone=False)
:rtype: `django.db.models.Model`
"""
cls = instance.__class__
cloned_references = cloned_references or {}

clone_fields = getattr(cls, "_clone_fields", CloneMixin._clone_fields)
clone_excluded_fields = getattr(
cls, "_clone_excluded_fields", CloneMixin._clone_excluded_fields
Expand Down Expand Up @@ -398,14 +418,20 @@ def _create_copy_of_instance(instance, using=None, force=False, sub_clone=False)
elif isinstance(f, models.OneToOneField) and not sub_clone:
sub_instance = getattr(instance, f.name, None) or f.get_default()

if sub_instance is not None:
if sub_instance is not None and not cloned_references.get(
sub_instance
):
sub_instance = CloneMixin._create_copy_of_instance(
sub_instance,
force=True,
sub_clone=True,
cloned_references=cloned_references,
)
sub_instance.save(using=using)
value = sub_instance.pk
elif cloned_references.get(sub_instance):
value = cloned_references.get(sub_instance)

elif all(
[
use_duplicate_suffix_for_non_unique_fields,
Expand All @@ -429,38 +455,48 @@ def _create_copy_of_instance(instance, using=None, force=False, sub_clone=False)

return new_instance

def __duplicate_o2o_fields(self, duplicate, using=None):
def __duplicate_o2o_fields(self, duplicate, using=None, cloned_references=None):
"""Duplicate one to one fields.
:param duplicate: The transient instance that should be duplicated.
:type duplicate: `django.db.models.Model`
:param using: The database alias used to save the created instances.
:type using: str
:return: The duplicate instance with all the one to one fields duplicated.
"""
cloned_references = cloned_references or {}
for f in self._meta.related_objects:
if f.one_to_one:
if any(
[
f.name in self._clone_o2o_fields
and f not in self._meta.concrete_fields,
self._clone_excluded_o2o_fields
and f.name not in self._clone_excluded_o2o_fields
and f not in self._meta.concrete_fields,
]
):
rel_object = getattr(self, f.name, None)
if rel_object:
if f.one_to_one and any(
[
f.name in self._clone_o2o_fields
and f not in self._meta.concrete_fields,
self._clone_excluded_o2o_fields
and f.name not in self._clone_excluded_o2o_fields
and f not in self._meta.concrete_fields,
]
):
rel_object = getattr(self, f.name, None)
if rel_object:
if cloned_references.get(rel_object):
new_rel_object = cloned_references[rel_object]
elif hasattr(rel_object, "make_clone"):
new_rel_object = rel_object.make_clone(
{f.field.name: duplicate},
using=using,
cloned_references=cloned_references,
)
else:
new_rel_object = CloneMixin._create_copy_of_instance(
rel_object,
force=True,
sub_clone=True,
cloned_references=cloned_references,
)
setattr(new_rel_object, f.remote_field.name, duplicate)
new_rel_object.save(using=using)
setattr(new_rel_object, f.remote_field.name, duplicate)
new_rel_object.save(using=using)

return duplicate

def __duplicate_o2m_fields(self, duplicate, using=None):
def __duplicate_o2m_fields(self, duplicate, using=None, cloned_references=None):
"""Duplicate one to many fields.

:param duplicate: The transient instance that should be duplicated.
Expand All @@ -469,46 +505,56 @@ def __duplicate_o2m_fields(self, duplicate, using=None):
:type using: str
:return: The duplicate instance with all the transient one to many duplicated instances.
"""
cloned_references = cloned_references or {}

for f in itertools.chain(
self._meta.related_objects, self._meta.concrete_fields
):
if f.one_to_many:
if any(
[
f.get_accessor_name() in self._clone_m2o_or_o2m_fields,
self._clone_excluded_m2o_or_o2m_fields
and f.get_accessor_name()
not in self._clone_excluded_m2o_or_o2m_fields,
]
):
for item in getattr(self, f.get_accessor_name()).all():
if hasattr(item, "make_clone"):
try:
item.make_clone(
attrs={f.remote_field.name: duplicate},
using=using,
)
except IntegrityError:
item.make_clone(
attrs={f.remote_field.name: duplicate},
sub_clone=True,
using=using,
)
else:
new_item = CloneMixin._create_copy_of_instance(
item,
force=True,
if f.one_to_many and any(
[
f.get_accessor_name() in self._clone_m2o_or_o2m_fields,
self._clone_excluded_m2o_or_o2m_fields
and f.get_accessor_name()
not in self._clone_excluded_m2o_or_o2m_fields,
]
):

for item in getattr(self, f.get_accessor_name()).all():
cloned_reference = cloned_references.get(item)

if cloned_reference:
setattr(cloned_reference, f.remote_field.name, duplicate)
cloned_reference.save()
elif hasattr(item, "make_clone"):
try:
item.make_clone(
attrs={f.remote_field.name: duplicate},
using=using,
cloned_references=cloned_references,
)

except IntegrityError:
item.make_clone(
attrs={f.remote_field.name: duplicate},
sub_clone=True,
using=using,
cloned_references=cloned_references,
)
setattr(new_item, f.remote_field.name, duplicate)
else:
new_item = CloneMixin._create_copy_of_instance(
item,
force=True,
sub_clone=True,
using=using,
cloned_references=cloned_references,
)
setattr(new_item, f.remote_field.name, duplicate)

new_item.save(using=using)
new_item.save(using=using)

return duplicate

def __duplicate_m2o_fields(self, duplicate, using=None):
def __duplicate_m2o_fields(self, duplicate, using=None, cloned_references=None):
"""Duplicate many to one fields.

:param duplicate: The transient instance that should be duplicated.
Expand All @@ -517,30 +563,40 @@ def __duplicate_m2o_fields(self, duplicate, using=None):
:type using: str
:return: The duplicate instance with all the many to one fields duplicated.
"""
cloned_references = cloned_references or {}
for f in self._meta.concrete_fields:
if f.many_to_one:
if any(
if (
f.many_to_one
and any(
[
f.name in self._clone_m2o_or_o2m_fields,
self._clone_excluded_m2o_or_o2m_fields
and f.name not in self._clone_excluded_m2o_or_o2m_fields,
]
):
item = getattr(self, f.name)
if hasattr(item, "make_clone"):
try:
item_clone = item.make_clone(using=using)
except IntegrityError:
item_clone = item.make_clone(sub_clone=True)
else:
item.pk = None # pragma: no cover
item_clone = item.save(using=using) # pragma: no cover
)
and getattr(self, f.name)
):
item = getattr(self, f.name)
if cloned_references.get(item):
item_clone = cloned_references.get(item)
elif hasattr(item, "make_clone"):
try:
item_clone = item.make_clone(
using=using, cloned_references=cloned_references
)
except IntegrityError:
item_clone = item.make_clone(
sub_clone=True, cloned_references=cloned_references
)
else:
item.pk = None # pragma: no cover
item_clone = item.save(using=using) # pragma: no cover

setattr(duplicate, f.name, item_clone)
setattr(duplicate, f.name, item_clone)

return duplicate

def __duplicate_m2m_fields(self, duplicate, using=None):
def __duplicate_m2m_fields(self, duplicate, using=None, cloned_references=None):
"""Duplicate many to many fields.

:param duplicate: The transient instance that should be duplicated.
Expand All @@ -549,29 +605,28 @@ def __duplicate_m2m_fields(self, duplicate, using=None):
:type using: str
:return: The duplicate instance with all the many to many fields duplicated.
"""
fields = set()

for f in self._meta.many_to_many:
cloned_references = cloned_references or {}
fields = {
f
for f in self._meta.many_to_many
if any(
[
f.name in self._clone_m2m_fields,
self._clone_excluded_m2m_fields
and f.name not in self._clone_excluded_m2m_fields,
]
):
fields.add(f)
)
}

for f in self._meta.related_objects:
if f.many_to_many:
if any(
[
f.get_accessor_name() in self._clone_m2m_fields,
self._clone_excluded_m2m_fields
and f.get_accessor_name()
not in self._clone_excluded_m2m_fields,
]
):
fields.add(f)
if f.many_to_many and any(
[
f.get_accessor_name() in self._clone_m2m_fields,
self._clone_excluded_m2m_fields
and f.get_accessor_name() not in self._clone_excluded_m2m_fields,
]
):
fields.add(f)

# Clone many to many fields
for field in fields:
Expand All @@ -586,6 +641,7 @@ def __duplicate_m2m_fields(self, duplicate, using=None):
field_name = field.m2m_field_name()
source = getattr(self, field.attname)
destination = getattr(duplicate, field.attname)

if all(
[
through,
Expand All @@ -594,23 +650,26 @@ def __duplicate_m2m_fields(self, duplicate, using=None):
):
objs = through.objects.filter(**{field_name: self.pk})
for item in objs:
if hasattr(through, "make_clone"):
# TODO: add logic for cross references
if hasattr(item, "make_clone"):
try:
item.make_clone(
attrs={field_name: duplicate},
using=using,
cloned_references=cloned_references,
)
except IntegrityError:
item.make_clone(
attrs={field_name: duplicate},
sub_clone=True,
using=using,
cloned_references=cloned_references,
)
else:
item.pk = None
setattr(item, field_name, duplicate)
item.save(using=using)
else:
destination.set(source.all())
destination.set([cloned_references.get(s) or s for s in source.all()])

return duplicate
Loading