Skip to content

Commit

Permalink
update iterobj class
Browse files Browse the repository at this point in the history
  • Loading branch information
cliffckerr committed Nov 1, 2023
1 parent f8d6ab2 commit 426d012
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 59 deletions.
150 changes: 102 additions & 48 deletions sciris/sc_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
##############################################################################


__all__ = ['getnested', 'setnested', 'makenested', 'iternested', 'iterobj',
__all__ = ['getnested', 'setnested', 'makenested', 'iternested', 'IterObj', 'iterobj',
'mergenested', 'flattendict', 'nestedloop']


Expand Down Expand Up @@ -99,20 +99,26 @@ def makenested(nesteddict, keylist=None, value=None, overwrite=False, generator=

def check_iter_type(obj, check_array=False, known=None, known_to_none=True, custom=None):
''' Helper function to determine if an object is a dict, list, or neither -- not for the user '''
if known is not None and isinstance(obj, known):
out = '' if known_to_none else 'known' # Choose how known objects are handled
elif isinstance(obj, dict):
out = 'dict'
elif isinstance(obj, list):
out = 'list'
elif hasattr(obj, '__dict__'):
out = 'object'
elif check_array and isinstance(obj, np.ndarray):
out = 'array'
elif custom is not None:
out = custom(obj)
else:
out = '' # Evaluates to false
out = None
if custom is not None: # Handle custom first, to allow overrides
if custom and not callable(custom): # Ensure custom_type is callable
custom_func = (lambda obj: 'custom' if isinstance(obj, custom) else None)
else:
custom_func = custom
out = custom_func(obj)
if out is None:
if known is not None and isinstance(obj, known):
out = '' if known_to_none else 'known' # Choose how known objects are handled
elif isinstance(obj, dict):
out = 'dict'
elif isinstance(obj, list):
out = 'list'
elif hasattr(obj, '__dict__'):
out = 'object'
elif check_array and isinstance(obj, np.ndarray):
out = 'array'
else:
out = '' # Evaluates to false
return out


Expand Down Expand Up @@ -217,7 +223,7 @@ class IterObj(object):
Object iteration manager
For arguments and usage documentation, see :func:`sc.iterobj() <iterobj>`.
Use this class if you want more control over how the object is iterated over.
Use this class only if you want more control over how the object is iterated over.
Class-specific args:
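custom_type (func/tuple): a custom function for returning a string for a specific object type (should return None by default, so that the built-in dict/list/object checks are used instead); alternatively, a tuple of types (i.e. not callable), in which case matching objects are labeled 'custom'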
Expand All @@ -229,16 +235,41 @@ class IterObj(object):
import sciris as sc
def slowfunc(i):
sc.randsleep(seed=i)
return i**2
# Standard usage
P = sc.Parallel(slowfunc, iterarg=range(10), parallelizer='multiprocess-async')
P.run_async()
P.monitor()
P.finalize()
print(P.times)
# Create a simple class for storing data
class DataObj(sc.prettyobj):
def __init__(self, **kwargs):
self.keys = tuple(kwargs.keys())
self.values = tuple(kwargs.values())
# Create the data
obj1 = DataObj(a=[1,2,3], b=[4,5,6])
obj2 = DataObj(c=[7,8,9], d=[10])
obj = DataObj(obj1=obj1, obj2=obj2)
# Define custom methods for iterating over tuples and the DataObj
def custom_iter(obj):
if isinstance(obj, tuple):
return enumerate(obj)
if isinstance(obj, DataObj):
return [(k,v) for k,v in zip(obj.keys, obj.values)]
# Define custom method for getting data from each
def custom_get(obj, key):
if isinstance(obj, tuple):
return obj[key]
elif isinstance(obj, DataObj):
return obj.values[obj.keys.index(key)]
# Gather all data into one list
all_data = []
def gather_data(obj, all_data=all_data):
if isinstance(obj, list):
all_data += obj
# Run the iteration
io = sc.IterObj(obj, func=gather_data, custom_type=(tuple, DataObj), custom_iter=custom_iter, custom_get=custom_get)
io.iterate()
print(all_data)
| *New in version 3.1.2.*
'''
Expand Down Expand Up @@ -284,33 +315,51 @@ def __init__(self, obj, func=None, inplace=False, copy=False, leaf=False, atomic
self.itertype = check_iter_type(self.obj, known=self.atomic, custom=self.custom_type)

return


def indent(self, string='', space=' '):
''' Print, with output indented successively '''
if self.verbose:
print(space*len(self._trace) + string)
return

def iteritems(self):
''' Return an iterator over items in this object '''
if self.itertype == 'dict':
return self.obj.items()
elif self.itertype == 'list':
return enumerate(self.obj)
elif self.itertype == 'object':
return self.obj.__dict__.items()
elif self.custom_iter:
return self.custom_iter(self.obj)
else:
return {}.items() # Return nothing if not recognized
self.indent(f'Iterating with type "{self.itertype}"')
out = None
if self.custom_iter:
out = self.custom_iter(self.obj)
if out is None:
if self.itertype == 'dict':
out = self.obj.items()
elif self.itertype == 'list':
out = enumerate(self.obj)
elif self.itertype == 'object':
out = self.obj.__dict__.items()
else:
out = {}.items() # Return nothing if not recognized
return out

def getitem(self, key):
''' Get the value for the item '''
self.indent(f'Getting key "{key}"')
if self.itertype in ['dict', 'list']:
return self.obj[key]
elif self.itertype == 'object':
return self.obj.__dict__[key]
elif self.custom_get:
return self.custom_get(self.obj, key)
else:
return None

def setitem(self, key, value):
''' Set the value for the item '''
self.indent(f'Setting key "{key}"')
if self.itertype in ['dict', 'list']:
self.obj[key] = value
elif self.itertype == 'object':
self.obj.__dict__[key] = value
elif self.custom_set:
self.custom_set(self.obj, key, value)
return

def iterate(self):
Expand All @@ -321,16 +370,18 @@ def iterate(self):
trace = self._trace + [key]
newobj = subobj
subitertype = check_iter_type(subobj)
if self.verbose: # pragma: no cover
print(f'Working on {trace}, {self.leaf}, {subitertype}')
self.indent(f'Working on {trace}, leaf={self.leaf}, type={str(subitertype)}')
if not (self.leaf and subitertype):
newobj = self.func(subobj, *self.func_args, **self.func_kw)
if self.inplace:
self.setitem(self.obj, key, newobj)
self.setitem(key, newobj)
else:
self._output[tuple(trace)] = newobj
iterobj(self.getitem(key), self.func, inplace=self.inplace, leaf=self.leaf, atomic=self.atomic, # Run recursively
verbose=self.verbose, _trace=trace, _output=self._output, *self.func_args, **self.func_kw)
io = IterObj(self.getitem(key), self.func, inplace=self.inplace, leaf=self.leaf, # Create a new instance
atomic=self.atomic, verbose=self.verbose, _trace=trace, _output=self._output,
custom_type=self.custom_type, custom_iter=self.custom_iter, custom_get=self.custom_get, custom_set=self.custom_set,
*self.func_args, **self.func_kw)
io.iterate() # Run recursively

if self.inplace:
newobj = self.func(self.obj, *self.func_args, **self.func_kw) # Set at the root level
Expand All @@ -349,6 +400,9 @@ def iterobj(obj, func=None, inplace=False, copy=False, leaf=False, atomic='defau
Can modify an object in-place, or return a value. See also :func:`sc.search() <search>`
for a function to search through complex objects.
By default, lists, dictionaries, and objects are iterated over. For custom iteration
options, see :class:`sc.IterObj() <IterObj>`.
Note: there are three different output possibilities, depending on the keywords:
- ``inplace=False``, ``copy=False`` (default): collate the output of the function into a flat dictionary, with keys corresponding to each node of the project
Expand All @@ -373,23 +427,23 @@ def iterobj(obj, func=None, inplace=False, copy=False, leaf=False, atomic='defau
data = dict(a=dict(x=[1,2,3], y=[4,5,6]), b=dict(foo='string', bar='other_string'))
# Search through an object
def check_type(obj, which):
return isinstance(obj, which)
def check_int(obj):
return isinstance(obj, int)
out = sc.iterobj(data, check_type, which=int)
out = sc.iterobj(data, check_int)
print(out)
# Modify in place -- collapse multiple short lines into one
def collapse(obj):
def collapse(obj, maxlen):
string = str(obj)
if len(string) < 10:
if len(string) < maxlen:
return string
else:
return obj
sc.printjson(data)
sc.iterobj(data, collapse, inplace=True)
sc.iterobj(data, collapse, inplace=True, maxlen=10) # Note passing of keyword argument to function
sc.printjson(data)
| *New in version 3.0.0.*
Expand Down
68 changes: 57 additions & 11 deletions tests/test_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,13 @@ def test_iterobj():
o.a.i1 = [1,2,3]
o.b.i2 = dict(cat=[4,5,6])
data = dict(
a=dict(
x=[1,2,3],
y=[4,5,6]),
b=dict(
foo='string',
bar='other_string'),
c=o,
a = dict(
x = [1,2,3],
y = [4,5,6]),
b = dict(
foo = 'string',
bar = 'other_string'),
c = o,
)

# Search through an object
Expand All @@ -166,20 +166,65 @@ def check_type(obj, which):
print(out)

# Modify in place -- collapse multiple short lines into one
def collapse(obj):
def collapse(obj, maxlen):
string = str(obj)
if len(string) < 10:
if len(string) < maxlen:
return string
else:
return obj

sc.printjson(data)
sc.iterobj(data, collapse, inplace=True)
sc.iterobj(data, collapse, inplace=True, maxlen=10) # Note passing of keyword argument to function
sc.iterobj(data, collapse, inplace=True, maxlen=10) # Note passing of keyword argument to function
sc.printjson(data)
assert data['a']['x'] == '[1, 2, 3]'

return out


def test_iterobj_class():
sc.heading('Testing iterobj class')

# Create a simple class for storing data
class DataObj(sc.prettyobj):
def __init__(self, **kwargs):
self.keys = tuple(kwargs.keys())
self.values = tuple(kwargs.values())

# Create the data
obj1 = DataObj(a=[0,1,2], b=[3,4,5])
obj2 = DataObj(c=[6,7,8], d=[9])
obj = DataObj(obj1=obj1, obj2=obj2)

# Define custom methods for iterating over tuples and the DataObj
def custom_iter(obj):
if isinstance(obj, tuple):
return enumerate(obj)
if isinstance(obj, DataObj):
return [(k,v) for k,v in zip(obj.keys, obj.values)]

# Define custom method for getting data from each
def custom_get(obj, key):
if isinstance(obj, tuple):
return obj[key]
elif isinstance(obj, DataObj):
return obj.values[obj.keys.index(key)]

# Gather all data into one list
all_data = []
def gather_data(obj, all_data=all_data):
if isinstance(obj, list):
all_data += obj

# Run the iteration
io = sc.IterObj(obj, func=gather_data, custom_type=(tuple, DataObj), custom_iter=custom_iter, custom_get=custom_get)
io.iterate()
print(all_data)
assert all_data == list(range(10))

return io


def test_equal():
sc.heading('Testing equal')
out = sc.objdict()
Expand Down Expand Up @@ -245,6 +290,7 @@ def test_equal():
dicts = test_dicts()
search = test_search()
iterobj = test_iterobj()
io_obj = test_iterobj_class()
equal = test_equal()

T.toc('Done.')

0 comments on commit 426d012

Please sign in to comment.