Skip to content

Commit

Permalink
Merge pull request #92 from sciris/flatten-dict
Browse files Browse the repository at this point in the history
Update flattendict function
  • Loading branch information
cliffckerr committed Mar 25, 2020
2 parents 5d2959f + 02b8df8 commit 7146f29
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 38 deletions.
71 changes: 34 additions & 37 deletions sciris/sc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1879,48 +1879,45 @@ def mergenested(dict1, dict2, die=False, verbose=False, _path=None):
func.__doc__ = docstring


def flattendict(inputdict=None, basekey=None, subkeys=None, complist=None, keylist=None, limit=100):
'''
A function for flattening out a recursive dictionary, with an optional list of sub-keys (ignored if non-existent).
The flattened out structure is returned as complist. Values can be an object or a list of objects.
All keys (including basekey) within the recursion are returned as keylist.
Specifically, this function is intended for dictionaries of the form...
inputdict[key1][sub_key[0]] = [a, key2, b]
inputdict[key1][sub_key[1]] = [c, d]
inputdict[key2][sub_key[0]] = e
inputdict[key2][sub_key[1]] = [e, f, g]
...which, for this specific example, will output list...
[a, e, e, f, g, h, b, c, d]
There is a max-depth of limit for the recursion.
'''
def flattendict(input_dict: dict, sep: str = None, _prefix=None) -> dict:
"""
Flatten nested dictionary
if limit<1:
errormsg = 'ERROR: A recursion limit has been reached when flattening a dictionary, stopping at key "%s".' % basekey
raise Exception(errormsg)
Example:
if complist is None: complist = []
if keylist is None: keylist = []
keylist.append(basekey)
>>> flattendict({'a':{'b':1,'c':{'d':2,'e':3}}})
{('a', 'b'): 1, ('a', 'c', 'd'): 2, ('a', 'c', 'e'): 3}
>>> flattendict({'a':{'b':1,'c':{'d':2,'e':3}}}, sep='_')
{'a_b': 1, 'a_c_d': 2, 'a_c_e': 3}
if subkeys is None: inputlist = inputdict[basekey]
else:
inputlist = []
for sub_key in subkeys:
if sub_key in inputdict[basekey]:
val = inputdict[basekey][sub_key]
if isinstance(val, list):
inputlist += val
else:
inputlist.append(val) # Handle unlisted objects.
Args:
d: Input dictionary potentially containing dicts as values
sep: Concatenate keys using string separator. If ``None`` the returned dictionary will have tuples as keys
_prefix: Internal argument for recursively accumulating the nested keys
Returns:
A flat dictionary where no values are dicts
"""

output_dict = {}
for k, v in input_dict.items():
if sep is None:
if _prefix is None:
k2 = (k,)
else:
k2 = _prefix + (k,)
else:
if _prefix is None:
k2 = k
else:
k2 = _prefix + sep + k

for comp in inputlist:
if comp in inputdict.keys():
flattendict(inputdict=inputdict, basekey=comp, subkeys=subkeys, complist=complist, keylist=keylist, limit=limit-1)
if isinstance(v, dict):
output_dict.update(flattendict(input_dict[k], sep=sep, _prefix=k2))
else:
complist.append(comp)
return complist, keylist
output_dict[k2] = v

return output_dict


##############################################################################
Expand Down
2 changes: 1 addition & 1 deletion sciris/sc_version.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__all__ = ['__version__', '__versiondate__', '__license__']

__version__ = '0.16.2'
__version__ = '0.16.3'
__versiondate__ = '2020-03-24'
__license__ = 'Sciris %s (%s) -- (c) Sciris.org' % (__version__, __versiondate__)
8 changes: 8 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def test_printing():
return string


def test_flattendict():
# Simple integration test to make sure the function runs without raising an error
sc.flattendict({'a': {'b': 1, 'c': {'d': 2, 'e': 3}}})
flat = sc.flattendict({'a': {'b': 1, 'c': {'d': 2, 'e': 3}}}, sep='_')
return flat


def test_profile():
sc.heading('Test profiling functions')

Expand Down Expand Up @@ -160,6 +167,7 @@ def test_readdate():

test_colorize()
string = test_printing()
flat = test_flattendict()
foo = test_profile()
myobj = test_prepr()
uid = test_uuid()
Expand Down

0 comments on commit 7146f29

Please sign in to comment.