Skip to content

Commit

Permalink
Add support for instance parameters (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored and jlstevens committed Feb 28, 2019
1 parent 78c32e7 commit 4ea353a
Show file tree
Hide file tree
Showing 5 changed files with 481 additions and 194 deletions.
79 changes: 37 additions & 42 deletions param/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@

from .parameterized import Parameterized, Parameter, String, \
descendents, ParameterizedFunction, ParamOverrides

from .parameterized import depends, output # noqa: api import
from .parameterized import logging_level # noqa: api import
from .parameterized import shared_parameters # noqa: api import
from .parameterized import (depends, output, logging_level, # noqa: api import
shared_parameters, instance_descriptor)

from collections import OrderedDict
from numbers import Real
Expand Down Expand Up @@ -153,7 +151,7 @@ def param_union(*parameterizeds, **kwargs):
kwargs.popitem()[0]))
d = dict()
for o in parameterizeds:
for k, p in o.param.params().items():
for k in o.param:
if k != 'name':
if k in d and warn:
warnings.warn("overwriting parameter {}".format(k))
Expand Down Expand Up @@ -480,7 +478,7 @@ def __call__(self, val=None, time_type=None):
if time_type and val is None:
raise Exception("Please specify a value for the new time_type.")
if time_type:
type_param = self.param.params('time_type')
type_param = self.param.objects('existing').get('time_type')
type_param.constant = False
self.time_type = time_type
type_param.constant = True
Expand Down Expand Up @@ -604,6 +602,7 @@ def __get__(self,obj,objtype):
return self._produce_value(gen)


@instance_descriptor
def __set__(self,obj,val):
"""
Call the superclass's set and keep this parameter's
Expand Down Expand Up @@ -842,18 +841,18 @@ def _checkBounds(self, val):
if vmax is not None:
if incmax is True:
if not val <= vmax:
raise ValueError("Parameter '%s' must be at most %s"%(self._attrib_name,vmax))
raise ValueError("Parameter '%s' must be at most %s"%(self.name,vmax))
else:
if not val < vmax:
raise ValueError("Parameter '%s' must be less than %s"%(self._attrib_name,vmax))
raise ValueError("Parameter '%s' must be less than %s"%(self.name,vmax))

if vmin is not None:
if incmin is True:
if not val >= vmin:
raise ValueError("Parameter '%s' must be at least %s"%(self._attrib_name,vmin))
raise ValueError("Parameter '%s' must be at least %s"%(self.name,vmin))
else:
if not val > vmin:
raise ValueError("Parameter '%s' must be greater than %s"%(self._attrib_name,vmin))
raise ValueError("Parameter '%s' must be greater than %s"%(self.name,vmin))



Expand All @@ -869,7 +868,7 @@ def _validate(self, val):
return

if not _is_number(val):
raise ValueError("Parameter '%s' only takes numeric values"%(self._attrib_name))
raise ValueError("Parameter '%s' only takes numeric values"%(self.name))

self._checkBounds(val)

Expand Down Expand Up @@ -913,7 +912,7 @@ def _validate(self, val):
return

if not isinstance(val,int):
raise ValueError("Parameter '%s' must be an integer."%self._attrib_name)
raise ValueError("Parameter '%s' must be an integer."%self.name)

self._checkBounds(val)

Expand Down Expand Up @@ -941,16 +940,16 @@ def _validate(self, val):
if self.allow_None:
if not isinstance(val,bool) and val is not None:
raise ValueError("Boolean '%s' only takes a Boolean value or None."
%self._attrib_name)
%self.name)

if val is not True and val is not False and val is not None:
raise ValueError("Boolean '%s' must be True, False, or None."%self._attrib_name)
raise ValueError("Boolean '%s' must be True, False, or None."%self.name)
else:
if not isinstance(val,bool):
raise ValueError("Boolean '%s' only takes a Boolean value."%self._attrib_name)
raise ValueError("Boolean '%s' only takes a Boolean value."%self.name)

if val is not True and val is not False:
raise ValueError("Boolean '%s' must be True or False."%self._attrib_name)
raise ValueError("Boolean '%s' must be True or False."%self.name)
super(Boolean, self)._validate(val)


Expand All @@ -972,7 +971,7 @@ def __init__(self,default=(0,0),length=None,**params):
self.length = len(default)
elif length is None and default is None:
raise ValueError("%s: length must be specified if no default is supplied." %
(self._attrib_name))
(self.name))
else:
self.length = length
self._validate(default)
Expand All @@ -983,12 +982,11 @@ def _validate(self, val):
return

if not isinstance(val,tuple):
raise ValueError("Tuple '%s' only takes a tuple value."%self._attrib_name)
raise ValueError("Tuple '%s' only takes a tuple value."%self.name)

if not len(val)==self.length:
raise ValueError("%s: tuple is not of the correct length (%d instead of %d)." %
(self._attrib_name,len(val),self.length))

(self.name,len(val),self.length))



Expand All @@ -1001,7 +999,7 @@ def _validate(self, val):
for n in val:
if not _is_number(n):
raise ValueError("%s: tuple element is not numeric: %s." %
(self._attrib_name,str(n)))
(self.name,str(n)))



Expand All @@ -1025,7 +1023,7 @@ class Callable(Parameter):

def _validate(self, val):
if not (self.allow_None and val is None) and (not callable(val)):
raise ValueError("Callable '%s' only takes a callable object."%self._attrib_name)
raise ValueError("Callable '%s' only takes a callable object."%self.name)
super(Callable, self)._validate(val)


Expand Down Expand Up @@ -1074,9 +1072,6 @@ class Composite(Parameter):
attributes.
"""

# Note: objtype is same as _owner, but objtype left for backwards
# compatibility (I think it's used in places to detect composite
# parameter)
__slots__=['attribs','objtype']

def __init__(self,attribs=None,**kw):
Expand All @@ -1095,7 +1090,7 @@ def __get__(self,obj,objtype):
return [getattr(obj,a) for a in self.attribs]

def _validate(self, val):
assert len(val) == len(self.attribs),"Compound parameter '%s' got the wrong number of values (needed %d, but got %d)." % (self._attrib_name,len(self.attribs),len(val))
assert len(val) == len(self.attribs),"Compound parameter '%s' got the wrong number of values (needed %d, but got %d)." % (self.name,len(self.attribs),len(val))

def _post_setter(self, obj, val):
if obj is None:
Expand Down Expand Up @@ -1202,7 +1197,7 @@ def _validate(self, val):
# CEBALERT: can be called before __init__ has called
# super's __init__, i.e. before attrib_name has been set.
try:
attrib_name = self._attrib_name
attrib_name = self.name
except AttributeError:
attrib_name = ""

Expand Down Expand Up @@ -1297,7 +1292,7 @@ def _validate(self,val):
if not (isinstance(val,self.class_)) and not (val is None and self.allow_None):
raise ValueError(
"Parameter '%s' value must be an instance of %s, not '%s'" %
(self._attrib_name, class_name, val))
(self.name, class_name, val))
else:
if not (val is None and self.allow_None) and not (issubclass(val,self.class_)):
raise ValueError(
Expand Down Expand Up @@ -1350,27 +1345,27 @@ def _validate(self, val):
return

if not isinstance(val, list):
raise ValueError("List '%s' must be a list."%(self._attrib_name))
raise ValueError("List '%s' must be a list."%(self.name))

if self.bounds is not None:
min_length,max_length = self.bounds
l=len(val)
if min_length is not None and max_length is not None:
if not (min_length <= l <= max_length):
raise ValueError("%s: list length must be between %s and %s (inclusive)"%(self._attrib_name,min_length,max_length))
raise ValueError("%s: list length must be between %s and %s (inclusive)"%(self.name,min_length,max_length))
elif min_length is not None:
if not min_length <= l:
raise ValueError("%s: list length must be at least %s."%(self._attrib_name,min_length))
raise ValueError("%s: list length must be at least %s."%(self.name,min_length))
elif max_length is not None:
if not l <= max_length:
raise ValueError("%s: list length must be at most %s."%(self._attrib_name,max_length))
raise ValueError("%s: list length must be at most %s."%(self.name,max_length))

self._check_type(val)

def _check_type(self,val):
if self.class_ is not None:
for v in val:
assert isinstance(v,self.class_),repr(self._attrib_name)+": "+repr(v)+" is not an instance of " + repr(self.class_) + "."
assert isinstance(v,self.class_),repr(self.name)+": "+repr(v)+" is not an instance of " + repr(self.class_) + "."



Expand All @@ -1386,7 +1381,7 @@ class HookList(List):

def _check_type(self,val):
for v in val:
assert callable(v),repr(self._attrib_name)+": "+repr(v)+" is not callable."
assert callable(v),repr(self.name)+": "+repr(v)+" is not callable."



Expand Down Expand Up @@ -1649,12 +1644,12 @@ def _resolve(self, path):
def _validate(self, val):
if val is None:
if not self.allow_None:
Parameterized(name="%s.%s"%(self._owner.name,self._attrib_name)).warning('None is not allowed')
Parameterized(name="%s.%s"%(self.owner.name,self.name)).warning('None is not allowed')
else:
try:
self._resolve(val)
except IOError as e:
Parameterized(name="%s.%s"%(self._owner.name,self._attrib_name)).warning('%s',e.args[0])
Parameterized(name="%s.%s"%(self.owner.name,self.name)).warning('%s',e.args[0])

def __get__(self, obj, objtype):
"""
Expand Down Expand Up @@ -1813,7 +1808,7 @@ def _validate(self, val):
return

if not isinstance(val, dt_types) and not (self.allow_None and val is None):
raise ValueError("Date '%s' only takes datetime types."%self._attrib_name)
raise ValueError("Date '%s' only takes datetime types."%self.name)

self._checkBounds(val)

Expand All @@ -1832,10 +1827,10 @@ def _validate(self, val):
if (self.allow_None and val is None):
return
if not isinstance(val, String.basestring):
raise ValueError("Color '%s' only takes a string value."%self._attrib_name)
raise ValueError("Color '%s' only takes a string value."%self.name)
if not re.match('^#?(([0-9a-fA-F]{2}){3}|([0-9a-fA-F]){3})$', val):
raise ValueError("Color '%s' only accepts valid RGB hex codes."
% self._attrib_name)
% self.name)



Expand Down Expand Up @@ -1907,7 +1902,7 @@ def _checkBounds(self, val):
too_high = (vmax is not None) and (v > vmax if incmax else v >= vmax)
if too_low or too_high:
raise ValueError("Parameter '%s' %s bound must be in range %s"
% (self._attrib_name, bound, self.rangestr()))
% (self.name, bound, self.rangestr()))


class DateRange(Range):
Expand All @@ -1923,11 +1918,11 @@ def _validate(self, val):

for n in val:
if not isinstance(n, dt_types):
raise ValueError("DateRange '%s' only takes datetime types: %s"%(self._attrib_name,val))
raise ValueError("DateRange '%s' only takes datetime types: %s"%(self.name,val))

start, end = val
if not end >= start:
raise ValueError("DateRange '%s': end date %s is before start date %s."%(self._attrib_name,val[1],val[0]))
raise ValueError("DateRange '%s': end date %s is before start date %s."%(self.name,val[1],val[0]))

# Calling super(DateRange, self)._check(val) would also check
# values are numeric, which is redundant, so just call
Expand Down
2 changes: 1 addition & 1 deletion param/ipython.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_param_info(self, obj, include_super=True):
True, parameters are also collected from the super classes.
"""

params = dict(obj.param.params())
params = dict(obj.param.objects('existing'))
if isinstance(obj,type):
changed = []
val_dict = dict((k,p.default) for (k,p) in params.items())
Expand Down
Loading

0 comments on commit 4ea353a

Please sign in to comment.