Skip to content

Commit

Permalink
Merge pull request #361 from scopatz/i360
Browse files Browse the repository at this point in the history
Fixes #360
  • Loading branch information
Elliott Biondo authored and Elliott Biondo committed Mar 8, 2014
2 parents 71ba3ab + e9d9fe2 commit 6ed1949
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 29 deletions.
69 changes: 43 additions & 26 deletions pyne/mesh.py
Expand Up @@ -293,14 +293,17 @@ class IMeshTag(Tag):
features to this process.
"""

def __init__(self, size=1, dtype='f8', mesh=None, name=None, doc=None):
def __init__(self, size=1, dtype='f8', default=0.0, mesh=None, name=None, doc=None):
"""Parameters
----------
size : int, optional
The number of elements of type dtype that this tag stores.
dtype : np.dtype or similar, optional
The data type of this tag from int, float, and byte. See PyTAPS
tags for more details.
default : dtype or None, optional
The default value to fill this tag with upon creation. If None, then
the tag is created empty.
mesh : Mesh, optional
The PyNE mesh to tag.
name : str, optional
Expand All @@ -313,11 +316,17 @@ def __init__(self, size=1, dtype='f8', mesh=None, name=None, doc=None):
if mesh is None or name is None:
self._lazy_args['size'] = size
self._lazy_args['dtype'] = dtype
return
self._lazy_args['default'] = default
return
self.size = size
self.dtype = dtype
self.default = default
try:
self.tag = self.mesh.mesh.getTagHandle(self.name)
except iBase.TagNotFoundError:
self.tag = self.mesh.mesh.createTag(self.name, size, dtype)
if default is not None:
self[:] = default

def __delete__(self, mesh):
super(IMeshTag, self).__delete__(mesh)
Expand All @@ -327,7 +336,7 @@ def __getitem__(self, key):
m = self.mesh.mesh
size = len(self.mesh)
mtag = self.tag
miter = m.iterate(iBase.Type.region, iMesh.Topology.all)
miter = self.mesh.iter_ve()
if isinstance(key, _INTEGRAL_TYPES):
if key >= size:
raise IndexError("key index {0} greater than the size of the "
Expand All @@ -349,32 +358,37 @@ def __getitem__(self, key):
"or fancy index.".format(key))

def __setitem__(self, key, value):
# get value into canonical form
tsize = self.size
value = np.asarray(value, self.tag.type)
value = np.atleast_1d(value) if tsize == 1 else np.atleast_2d(value)
# set up mesh to be iterated over
m = self.mesh.mesh
size = len(self.mesh)
msize = len(self.mesh)
mtag = self.tag
miter = m.iterate(iBase.Type.region, iMesh.Topology.all)
miter = self.mesh.iter_ve()
if isinstance(key, _INTEGRAL_TYPES):
if key >= size:
if key >= msize:
raise IndexError("key index {0} greater than the size of the "
"mesh {1}".format(key, size))
"mesh {1}".format(key, msize))
for i_ve in zip(range(key+1), miter):
pass
mtag[i_ve[1]] = value
elif isinstance(key, slice):
idx = range(*key.indices(size))
if not (isinstance(value, Sequence) and len(value) == len(idx)):
value = [value] * len(idx)
mtag[list(miter)[key]] = value
key = list(miter)[key]
v = np.empty(len(key), self.tag.type) if tsize == 1 else \
np.empty((len(key), tsize), self.tag.type)
v[...] = value
mtag[key] = v
elif isinstance(key, np.ndarray) and key.dtype == np.bool:
if len(key) != size:
if len(key) != msize:
raise KeyError("boolean mask must match the length of the mesh.")
ntrues = key.sum()
if not (isinstance(value, Sequence) and len(value) == ntrues):
value = [value] * ntrues
mtag[[ve for b, ve in zip(key, miter) if b]] = value
key = [ve for b, ve in zip(key, miter) if b]
v = np.empty(len(key), self.tag.type) if tsize == 1 else \
np.empty((len(key), tsize), self.tag.type)
v[...] = value
mtag[key] = v
elif isinstance(key, Iterable):
if not (isinstance(value, Sequence) and len(value) == len(key)):
value = [value] * len(key)
ves = list(miter)
mtag[[ves[i] for i in key]] = value
else:
Expand All @@ -385,7 +399,7 @@ def __delitem__(self, key):
m = self.mesh.mesh
size = len(self.mesh)
mtag = self.tag
miter = m.iterate(iBase.Type.region, iMesh.Topology.all)
miter = self.mesh.iter_ve()
if isinstance(key, _INTEGRAL_TYPES):
if key >= size:
raise IndexError("key index {0} greater than the size of the "
Expand Down Expand Up @@ -510,8 +524,7 @@ class Mesh(object):
def __init__(self, mesh=None, mesh_file=None, structured=False, \
structured_coords=None, structured_set=None,
structured_ordering='xyz', mats=None):
"""
Parameters
"""Parameters
----------
mesh : iMesh instance, optional
mesh_file : str, optional
Expand Down Expand Up @@ -649,11 +662,7 @@ def __init__(self, mesh=None, mesh_file=None, structured=False, \
self.mats = mats

# tag with volume id and ensure mats exist.
if self.structured:
ves = list(self.structured_iterate_hex(self.structured_ordering))
else:
ves = list(self.mesh.iterate(iBase.Type.region, iMesh.Topology.all))

ves = list(self.iter_ve())
tags = self.mesh.getAllTags(ves[0])
tags = set(tag.name for tag in tags)
if 'idx' in tags:
Expand Down Expand Up @@ -713,6 +722,14 @@ def __iter__(self):
iMesh.Topology.all)):
yield i, mats[i], ve

def iter_ve(self):
"""Returns an iterator that yields on the volume elements.
"""
if self.structured:
return self.structured_iterate_hex(self.structured_ordering)
else:
return self.mesh.iterate(iBase.Type.region, iMesh.Topology.all)

def __contains__(self, i):
return i < len(self)

Expand Down
25 changes: 22 additions & 3 deletions tests/test_mesh.py
Expand Up @@ -661,9 +661,7 @@ def test_imeshtag():
}
m = gen_mesh(mats=mats)
m.f = IMeshTag(mesh=m, name='f')
ftag = m.mesh.getTagHandle('f')
ftag[list(m.mesh.iterate(iBase.Type.region, iMesh.Topology.all))] = \
[1.0, 2.0, 3.0, 4.0]
m.f[:] = [1.0, 2.0, 3.0, 4.0]

# Getting tags
assert_equal(m.f[0], 1.0)
Expand Down Expand Up @@ -733,6 +731,16 @@ def test_lazytaginit():
assert_in('cactus', m.tags)
assert_array_equal(m.cactus[0], [42, 43, 44])

x = np.arange(len(m))[:,np.newaxis] * np.array([42, 43, 44])
m.cactus[:] = x
assert_array_equal(m.cactus[2], x[2])

def test_issue360():
a = Mesh(structured=True, structured_coords=[[0,1,2],[0,1],[0,1]])
a.cat = IMeshTag(3, float)
a.cat[:] = [[0.11, 0.22, 0.33],[0.44, 0.55, 0.66]]
a.cat[:] = np.array([[0.11, 0.22, 0.33],[0.44, 0.55, 0.66]])

def test_iter():
mats = {
0: Material({'H1': 1.0, 'K39': 1.0}, density=42.0),
Expand All @@ -748,6 +756,17 @@ def test_iter():
assert_is(mats[i], mat)
assert_equal(j, idx_tag[ve])
j += 1

def test_iter_ve():
mats = {
0: Material({'H1': 1.0, 'K39': 1.0}, density=42.0),
1: Material({'H1': 0.1, 'O16': 1.0}, density=43.0),
2: Material({'He4': 42.0}, density=44.0),
3: Material({'Tm171': 171.0}, density=45.0),
}
m = gen_mesh(mats=mats)
ves1 = set(ve for _, _, ve in m)
ves2 = set(m.iter_ve())


def test_contains():
Expand Down

0 comments on commit 6ed1949

Please sign in to comment.