Skip to content

Commit

Permalink
Add eager option to watch calls (#351)
Browse files Browse the repository at this point in the history
* Add eager option to watch calls

* Renamed to queued, changed default and added test
  • Loading branch information
philippjfr committed Oct 1, 2019
1 parent f6542cc commit bd4bd1a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 24 deletions.
43 changes: 22 additions & 21 deletions param/parameterized.py
Expand Up @@ -88,7 +88,7 @@ def logging_level(level):


@contextmanager
def batch_watch(parameterized, run=True):
def batch_watch(parameterized, enable=True, run=True):
"""
Context manager to batch watcher events on a parameterized object.
The context manager will queue any events triggered by setting a
Expand All @@ -97,7 +97,7 @@ def batch_watch(parameterized, run=True):
queued events are not dispatched and should be processed manually.
"""
BATCH_WATCH = parameterized.param._BATCH_WATCH
parameterized.param._BATCH_WATCH = True
parameterized.param._BATCH_WATCH = enable or parameterized.param._BATCH_WATCH
try:
yield
finally:
Expand All @@ -112,7 +112,7 @@ def edit_constant(parameterized):
Temporarily set parameters on Parameterized object to constant=False
to allow editing them.
"""
params = parameterized.objects('existing').values()
params = parameterized.param.objects('existing').values()
constants = [p.constant for p in params]
for p in params:
p.constant = False
Expand Down Expand Up @@ -479,7 +479,7 @@ def _m_caller(self,n):
PInfo = namedtuple("PInfo","inst cls name pobj what")
MInfo = namedtuple("MInfo","inst cls name method")
Event = namedtuple("Event","what name obj cls old new type")
Watcher = namedtuple("Watcher","inst cls fn mode onlychanged parameter_names what")
Watcher = namedtuple("Watcher","inst cls fn mode onlychanged parameter_names what queued")

class ParameterMetaclass(type):
"""
Expand Down Expand Up @@ -1449,13 +1449,13 @@ def _call_watcher(self_, watcher, event):
self_._events.append(event)
if watcher not in self_._watchers:
self_._watchers.append(watcher)
elif watcher.mode == 'args':
with batch_watch(self_.self_or_cls, run=False):
watcher.fn(self_._update_event_type(watcher, event, self_.self_or_cls.param._TRIGGER))
else:
with batch_watch(self_.self_or_cls, run=False):
event = self_._update_event_type(watcher, event, self_.self_or_cls.param._TRIGGER)
watcher.fn(**{event.name: event.new})
event = self_._update_event_type(watcher, event, self_.self_or_cls.param._TRIGGER)
with batch_watch(self_.self_or_cls, enable=watcher.queued, run=False):
if watcher.mode == 'args':
watcher.fn(event)
else:
watcher.fn(**{event.name: event.new})


def _batch_call_watchers(self_):
Expand All @@ -1475,7 +1475,7 @@ def _batch_call_watchers(self_):
self_.self_or_cls.param._TRIGGER)
for name in watcher.parameter_names
if (name, watcher.what) in event_dict]
with batch_watch(self_.self_or_cls, run=False):
with batch_watch(self_.self_or_cls, enable=watcher.queued, run=False):
if watcher.mode == 'args':
watcher.fn(*events)
else:
Expand Down Expand Up @@ -1718,11 +1718,11 @@ def _watch(self_, action, watcher, what='value', operation='add'): #'add' | 'rem
watchers[what] = []
getattr(watchers[what], action)(watcher)

def watch(self_,fn,parameter_names, what='value', onlychanged=True):
def watch(self_,fn,parameter_names, what='value', onlychanged=True, queued=False):
parameter_names = tuple(parameter_names) if isinstance(parameter_names, list) else (parameter_names,)
watcher = Watcher(inst=self_.self, cls=self_.cls, fn=fn, mode='args',
onlychanged=onlychanged, parameter_names=parameter_names,
what=what)
what=what, queued=queued)
self_._watch('append', watcher, what)
return watcher

Expand All @@ -1736,16 +1736,16 @@ def unwatch(self_,watcher):
self_.warning('No such watcher {watcher} to remove.'.format(watcher=watcher))


def watch_values(self_,fn,parameter_names,what='value', onlychanged=True):
def watch_values(self_, fn, parameter_names, what='value', onlychanged=True, queued=False):
parameter_names = tuple(parameter_names) if isinstance(parameter_names, list) else (parameter_names,)
watcher = Watcher(inst=self_.self, cls=self_.cls, fn=fn,
mode='kwargs', onlychanged=onlychanged,
parameter_names=parameter_names, what='value')
parameter_names=parameter_names, what='value',
queued=queued)
self_._watch('append', watcher, what)
return watcher



# Instance methods


Expand Down Expand Up @@ -1903,10 +1903,11 @@ def __init__(mcs,name,bases,dict_):
# everything else access from here rather than from method
# object
for n,dinfo in dependers:
if dinfo.get('watch', False):
_watch.append(n)
watch = dinfo.get('watch', False)
if watch:
_watch.append((n, watch == 'queued'))

mcs.param._depends = {'watch':_watch}
mcs.param._depends = {'watch': _watch}

if docstring_signature:
mcs.__class_docstring_signature()
Expand Down Expand Up @@ -2343,14 +2344,14 @@ def __init__(self,**params):
for cls in classlist(self.__class__):
if not issubclass(cls, Parameterized):
continue
for n in cls.param._depends['watch']:
for n, queued in cls.param._depends['watch']:
# TODO: should improve this - will happen for every
# instantiation of Parameterized with watched deps. Will
# probably store expanded deps on class - see metaclass
# 'dependers'.
for p in self.param.params_depended_on(n):
# TODO: can't remember why not just pass m (rather than _m_caller) here
(p.inst or p.cls).param.watch(_m_caller(self,n),p.name,p.what)
(p.inst or p.cls).param.watch(_m_caller(self, n), p.name, p.what, queued=queued)

self.initialized=True

Expand Down
23 changes: 20 additions & 3 deletions tests/API1/testwatch.py
Expand Up @@ -44,11 +44,16 @@ class SimpleWatchSubclass(SimpleWatchExample):

class WatchMethodExample(SimpleWatchSubclass):

@param.depends('a', watch=True)
@param.depends('a', watch='queued')
def _clip_a(self):
if self.a > 3:
self.a = 3

@param.depends('b', watch=True)
def _clip_b(self):
if self.b > 10:
self.b = 10

@param.depends('b', watch=True)
def _set_c(self):
self.c = self.b*2
Expand Down Expand Up @@ -451,18 +456,30 @@ def test_dependent_params(self):
obj.b = 3
self.assertEqual(obj.c, 6)

def test_multiple_watcher_dispatch(self):
def test_multiple_watcher_dispatch_queued(self):
obj = WatchMethodExample()
obj2 = SimpleWatchExample()

def link(event):
obj2.a = event.new

obj.param.watch(link, 'a')
obj.param.watch(link, 'a', queued=True)
obj.a = 4
self.assertEqual(obj.a, 3)
self.assertEqual(obj2.a, 3)

def test_multiple_watcher_dispatch(self):
obj = WatchMethodExample()
obj2 = SimpleWatchExample()

def link(event):
obj2.b = event.new

obj.param.watch(link, 'b')
obj.b = 11
self.assertEqual(obj.b, 10)
self.assertEqual(obj2.b, 11)

def test_multiple_watcher_dispatch_on_param_attribute(self):
obj = WatchMethodExample()
accumulator = Accumulator()
Expand Down

0 comments on commit bd4bd1a

Please sign in to comment.