Skip to content

Commit

Permalink
Initial work on defaults.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rowan Cockett committed Feb 15, 2016
1 parent 463b9b6 commit 4776544
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 13 deletions.
63 changes: 50 additions & 13 deletions SimPEG/PropMaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ def fset(self, val):
setattr(self, '_%sMap'%prop.name, val)
return property(fget=fget, fset=fset, doc=prop.doc)

def _getDefaultProperty(self):
prop = self
def fget(self):
return getattr(self, '_%sDefault'%prop.name, None)
def fset(self, val):
if prop.propertyLink is not None:
linkName, linkMap = prop.propertyLink
assert getattr(self, '%sDefault'%linkName, None) is None, 'Cannot set both sides of a linked property.'
assert isinstance(val, np.ndarray) or np.isscalar(val), 'Default must be a scalar or a numpy array.'
setattr(self, '_%sDefault'%prop.name, val)
return property(fget=fget, fset=fset, doc=prop.doc)

def _getIndexProperty(self):
prop = self
def fget(self):
Expand All @@ -47,12 +59,17 @@ def _getProperty(self):
def fget(self):
mapping = getattr(self, '%sMap'%prop.name)
if mapping is None and prop.propertyLink is None:
return prop.defaultVal
return getattr(self, '%sDefault'%prop.name)

if mapping is None and prop.propertyLink is not None:
linkName, linkMapClass = prop.propertyLink
linkMap = linkMapClass(None)
if getattr(self, '%sMap'%linkName, None) is None:
# *
print linkName, getattr(self.propMap, '_%sDefault'%linkName, None)
if getattr(self, '%sMap'%linkName, None) is None and getattr(self.propMap, '_%sDefault'%linkName, None) is not None:
# We have a default
return linkMap * getattr(self, '%sDefault'%linkName, None)
elif getattr(self, '%sMap'%linkName, None) is None:
return prop.defaultVal
m = getattr(self, '%s'%linkName)
return linkMap * m
Expand Down Expand Up @@ -110,6 +127,12 @@ def fget(self):
return getattr(self.propMap, '_%sMap'%prop.name, None)
return property(fget=fget)

def _getModelDefaultProperty(self):
prop = self
def fget(self):
return getattr(self.propMap, '_%sDefault'%prop.name, prop.defaultVal)
return property(fget=fget)



class PropModel(object):
Expand Down Expand Up @@ -150,8 +173,9 @@ def __new__(cls, name, bases, attrs):
for attr in keys:
if isinstance(attrs[attr], Property):
attrs[attr].name = attr
attrs[attr + 'Map' ] = attrs[attr]._getMapProperty()
attrs[attr + 'Index'] = attrs[attr]._getIndexProperty()
attrs[attr + 'Map' ] = attrs[attr]._getMapProperty()
attrs[attr + 'Default'] = attrs[attr]._getDefaultProperty()
attrs[attr + 'Index' ] = attrs[attr]._getIndexProperty()
_properties[attr] = attrs[attr]
attrs.pop(attr)

Expand Down Expand Up @@ -181,11 +205,12 @@ def createPropModelClass(self, name, _properties):
for attr in _properties:
prop = _properties[attr]

attrs[attr ] = prop._getProperty()
attrs[attr + 'Map' ] = prop._getModelMapProperty()
attrs[attr + 'Proj' ] = prop._getModelProjProperty()
attrs[attr + 'Model'] = prop._getModelProperty()
attrs[attr + 'Deriv'] = prop._getModelDerivProperty()
attrs[attr ] = prop._getProperty()
attrs[attr + 'Map' ] = prop._getModelMapProperty()
attrs[attr + 'Default'] = prop._getModelDefaultProperty()
attrs[attr + 'Proj' ] = prop._getModelProjProperty()
attrs[attr + 'Model' ] = prop._getModelProperty()
attrs[attr + 'Deriv' ] = prop._getModelDerivProperty()

return type(name.replace('PropMap', 'PropModel'), (PropModel, ), attrs)

Expand All @@ -198,8 +223,8 @@ def __init__(self, mappings):
PropMap takes a multi parameter model and maps it to the equivalent PropModel
"""
if type(mappings) is dict:
assert np.all([k in ['maps', 'slices'] for k in mappings]), 'Dict must only have properties "maps" and "slices"'
self.setup(mappings['maps'], slices=mappings['slices'])
assert np.all([k in ['maps', 'slices', 'defaults'] for k in mappings]), 'Dict must only have properties "maps", "slices" and "defaults"'
self.setup(mappings['maps'], slices=mappings.get('slices',{}), defaults=mappings.get('defaults',{}))
elif type(mappings) is list:
self.setup(mappings)
elif isinstance(mappings, Maps.IdentityMap):
Expand All @@ -208,7 +233,7 @@ def __init__(self, mappings):
raise Exception('mappings must be a dict, a mapping, or a list of tuples.')


def setup(self, maps, slices=None):
def setup(self, maps, slices=None, defaults=None):
"""
Sets up the maps and slices for the PropertyMap
Expand All @@ -231,6 +256,13 @@ def setup(self, maps, slices=None):
s in self._properties and
(type(slices[s]) in [slice, list] or isinstance(slices[s], np.ndarray))
for s in slices]), 'Slices must be for each property'
if defaults is None:
defaults = dict()
else:
assert np.all([
s in self._properties and
(np.isscalar(defaults[s]) or isinstance(defaults[s], np.ndarray))
for s in defaults]), 'Defaults must be for each property'

self.clearMaps()

Expand All @@ -239,7 +271,12 @@ def setup(self, maps, slices=None):
setattr(self, '%sMap'%name, mapping)
setattr(self, '%sIndex'%name, slices.get(name, slice(nP, nP + mapping.nP)))
nP += mapping.nP
self.nP = nP
self.nP = nP

for key in defaults:
setattr(self, '%sDefault'%key, defaults[key])



@property
def defaultInvProp(self):
Expand Down
14 changes: 14 additions & 0 deletions tests/base/test_PropMaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,22 @@ def test_setup(self):
assert np.all(m.sigma == np.exp(np.r_[1.,2,3]))
assert m.sigmaDeriv is not None

assert m.mu == mu_0

assert m.nP == 3

def test_defaultOverride(self):
expMap = Maps.ExpMap(Mesh.TensorMesh((3,)))
PM = MyReciprocalPropMap({'maps':[('sigma', expMap)], 'defaults':{'mu':mu_0*2}})
self.assertRaises(Exception, MyReciprocalPropMap, {'maps':[('sigma', expMap)], 'defaults':{'mu':mu_0*2, 'mui':5}}) # Cannot set both sides of the default

m = PM(np.r_[1.,2,3])
assert np.all(m.sigmaModel == np.r_[1,2,3])

self.assertEqual(m.mu, mu_0 * 2)
# self.assertEqual(m.mui, 1/(mu_0 * 2))


def test_slices(self):
expMap = Maps.ExpMap(Mesh.TensorMesh((3,)))
PM = MyPropMap({'maps':[('sigma', expMap)], 'slices':{'sigma':[2,1,0]}})
Expand Down

0 comments on commit 4776544

Please sign in to comment.