Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to remove circular references in data objects and add test #54930

Merged
merged 18 commits into from May 4, 2020
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 57 additions & 0 deletions salt/utils/data.py
Expand Up @@ -170,6 +170,45 @@ def compare_lists(old=None, new=None):
return ret


def _remove_circular_refs(ob, _seen=None):
    '''
    Return a copy of ``ob`` with every circular reference replaced by None.

    Dicts, lists, tuples (namedtuples included), sets and frozensets are
    walked recursively and rebuilt; any other object is returned unchanged.
    A container that is merely referenced more than once, without being
    nested inside itself, is kept.

    Adapted from an answer by Martijn Pieters:
    https://stackoverflow.com/questions/44777369/
    remove-circular-references-in-dicts-lists-tuples/44777477#44777477

    :param ob: dict, list, tuple, set, or frozenset
        Standard python object
    :param set _seen:
        ids of the objects currently being walked (the ancestors of ``ob``).
        Internal to the recursion; callers should leave it unset.
    :returns:
        Cleaned Python object
    :rtype:
        type(ob) for lists, tuples, sets and frozensets; a plain dict for
        any dict (or dict subclass); ``ob`` itself for everything else
    '''
    if _seen is None:
        _seen = set()
    ob_id = id(ob)
    if ob_id in _seen:
        # ``ob`` is one of its own ancestors: a circular reference.
        # Alert the user and drop the reference so processing can continue.
        # This is not an exception handler, so use log.error() rather than
        # log.exception(), which would attach an empty traceback.
        log.error(
            'Caught a circular reference in data structure below. '
            'Cleaning and continuing execution.\n%r\n',
            ob,
        )
        return None
    _seen.add(ob_id)
    res = ob
    if isinstance(ob, dict):
        res = {
            _remove_circular_refs(k, _seen): _remove_circular_refs(v, _seen)
            for k, v in ob.items()}
    elif isinstance(ob, (list, tuple, set, frozenset)):
        cleaned = (_remove_circular_refs(v, _seen) for v in ob)
        if isinstance(ob, tuple) and hasattr(ob, '_make'):
            # namedtuple constructors take one positional argument per
            # field, so they cannot be fed a single iterable; _make() can.
            res = type(ob)._make(cleaned)
        else:
            res = type(ob)(cleaned)
    # remove id again; only *nested* references count
    _seen.remove(ob_id)
    return res


def decode(data, encoding=None, errors='strict', keep=False,
normalize=False, preserve_dict_class=False, preserve_tuples=False,
to_str=False):
Expand Down Expand Up @@ -200,6 +239,9 @@ def decode(data, encoding=None, errors='strict', keep=False,
for the base character, and one for the breve mark). Normalizing allows for
a more reliable test case.
'''
# Clean data object before decoding to avoid circular references
data = _remove_circular_refs(data)

_decode_func = salt.utils.stringutils.to_unicode \
if not to_str \
else salt.utils.stringutils.to_str
Expand Down Expand Up @@ -235,6 +277,9 @@ def decode_dict(data, encoding=None, errors='strict', keep=False,
Decode all string values to Unicode. Optionally use to_str=True to ensure
strings are str types and not unicode on Python 2.
'''
# Clean data object before decoding to avoid circular references
data = _remove_circular_refs(data)

_decode_func = salt.utils.stringutils.to_unicode \
if not to_str \
else salt.utils.stringutils.to_str
Expand Down Expand Up @@ -294,6 +339,9 @@ def decode_list(data, encoding=None, errors='strict', keep=False,
Decode all string values to Unicode. Optionally use to_str=True to ensure
strings are str types and not unicode on Python 2.
'''
# Clean data object before decoding to avoid circular references
data = _remove_circular_refs(data)

_decode_func = salt.utils.stringutils.to_unicode \
if not to_str \
else salt.utils.stringutils.to_str
Expand Down Expand Up @@ -350,6 +398,9 @@ def encode(data, encoding=None, errors='strict', keep=False,
can be useful for cases where the data passed to this function is likely to
contain binary blobs.
'''
# Clean data object before encoding to avoid circular references
data = _remove_circular_refs(data)

if isinstance(data, Mapping):
return encode_dict(data, encoding, errors, keep,
preserve_dict_class, preserve_tuples)
Expand Down Expand Up @@ -381,6 +432,9 @@ def encode_dict(data, encoding=None, errors='strict', keep=False,
'''
Encode all string values to bytes
'''
# Clean data object before encoding to avoid circular references
data = _remove_circular_refs(data)

ret = data.__class__() if preserve_dict_class else {}
for key, value in six.iteritems(data):
if isinstance(key, tuple):
Expand Down Expand Up @@ -434,6 +488,9 @@ def encode_list(data, encoding=None, errors='strict', keep=False,
'''
Encode all string values to bytes
'''
# Clean data object before encoding to avoid circular references
data = _remove_circular_refs(data)

ret = []
for item in data:
if isinstance(item, list):
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/utils/test_data.py
Expand Up @@ -353,6 +353,25 @@ def test_decode(self):
BYTES,
keep=False)

def test_circular_refs_dicts(self):
    '''
    A dict stored as one of its own values has that value replaced by None.
    '''
    looped = {"key": "value", "type": "test1"}
    looped["self"] = looped
    expected = {'key': 'value', 'type': 'test1', 'self': None}
    cleaned = salt.utils.data._remove_circular_refs(ob=looped)
    self.assertDictEqual(cleaned, expected)

def test_circular_refs_lists(self):
    '''
    A dict reached again through a nested list and tuple is replaced by
    None, while the list and tuple around it are kept.
    '''
    outer = {'foo': []}
    outer['foo'].append((outer,))
    expected = {'foo': [(None,)]}
    cleaned = salt.utils.data._remove_circular_refs(ob=outer)
    self.assertDictEqual(cleaned, expected)

def test_circular_refs_tuple(self):
    '''
    Values that repeat across keys, but are not nested inside themselves,
    come back unchanged.
    '''
    repeated = {'foo': 'string 1', 'bar': 'string 1', 'ham': 1, 'spam': 1}
    expected = {'foo': 'string 1', 'bar': 'string 1', 'ham': 1, 'spam': 1}
    cleaned = salt.utils.data._remove_circular_refs(ob=repeated)
    self.assertDictEqual(cleaned, expected)

def test_decode_to_str(self):
'''
Companion to test_decode, they should both be kept up-to-date with one
Expand Down