Skip to content

Commit 7a12f40

Browse files
Nested integrals (#914)
Add support for nested integrals over the same space.
2 parents 02c6205 + 1a3328e commit 7a12f40

File tree

7 files changed

+345
-99
lines changed

7 files changed

+345
-99
lines changed

nutils/SI.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def r(dispatch_func):
346346
@register(function.jump)
347347
@register(function.kronecker)
348348
@register(function.linearize)
349+
@register(function.swap_spaces)
349350
@register(function.opposite)
350351
@register(function.replace_arguments)
351352
@register(function.scatter)

nutils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
'Numerical Utilities for Finite Element Analysis'
22

3-
__version__ = version = '9a62'
3+
__version__ = version = '9a63'
44
version_name = 'jook-sing'

nutils/function.py

Lines changed: 231 additions & 64 deletions
Large diffs are not rendered by default.

nutils/sample.py

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ def __init__(self, spaces: Tuple[str, ...], ndims: int, nelems: int, npoints: in
9191
The number of points.
9292
'''
9393

94+
for i, space in enumerate(spaces):
95+
if space in spaces[i+1:]:
96+
raise ValueError(f'All spaces in a `Sample` must be unique, but space {space} is repeated.')
9497
self.spaces = spaces
9598
self.ndims = ndims
9699
self.nelems = nelems
@@ -479,6 +482,21 @@ def zip(*samples: 'Sample') -> 'Sample':
479482

480483
return _Zip(*samples)
481484

485+
def rename_spaces(self, map: Mapping[str, str], /) -> 'Sample':
486+
'''Return a :class:`Sample` with spaces renamed according to ``map``.
487+
488+
Args
489+
----
490+
map : mapping of :class:`str` to :class:`str`
491+
A mapping of old space to new space.
492+
493+
Returns
494+
-------
495+
renamed : :class:`Sample`
496+
'''
497+
498+
raise NotImplementedError
499+
482500

483501
class _TransformChainsSample(Sample):
484502

@@ -554,6 +572,9 @@ def get_evaluable_indices(self, ielem: evaluable.Array) -> evaluable.Array:
554572
def _bind(self, func: function.Array) -> function.Array:
555573
return _ConcatenatePoints(func, self)
556574

575+
def rename_spaces(self, map: Mapping[str, str], /) -> Sample:
576+
return _DefaultIndex(map.get(self.space, self.space), self.transforms, self.points)
577+
557578

558579
class _CustomIndex(_TransformChainsSample):
559580

@@ -578,6 +599,9 @@ def tri(self) -> numpy.ndarray:
578599
def hull(self) -> numpy.ndarray:
579600
return numpy.take(self._index, self._parent.hull)
580601

602+
def rename_spaces(self, map: Mapping[str, str], /) -> Sample:
603+
return _CustomIndex(self._parent.rename_spaces(map), self._index)
604+
581605

582606
if os.environ.get('NUTILS_TENSORIAL', None) == 'test': # pragma: nocover
583607

@@ -627,7 +651,7 @@ def get_evaluable_weights(self, __ielem: evaluable.Array) -> evaluable.Array:
627651
return evaluable.Zeros((evaluable.constant(0),) * len(self.spaces), dtype=float)
628652

629653
def get_lower_args(self, __ielem: evaluable.Array) -> function.LowerArgs:
630-
return function.LowerArgs((), {}, {})
654+
return function.LowerArgs.empty()
631655

632656
def get_element_tri(self, ielem: int) -> numpy.ndarray:
633657
raise IndexError('index out of range')
@@ -647,6 +671,9 @@ def _bind(self, func: function.Array) -> function.Array:
647671
def basis(self, interpolation: str = 'none') -> function.Array:
648672
return function.zeros((0,), float)
649673

674+
def rename_spaces(self, map: Mapping[str, str], /) -> Sample:
675+
return _Empty(tuple(map.get(space, space) for space in self.spaces), self.ndims)
676+
650677

651678
class _Add(_TensorialSample):
652679

@@ -694,6 +721,9 @@ def _integral(self, func: function.Array) -> function.Array:
694721
def _bind(self, func: function.Array) -> function.Array:
695722
return numpy.concatenate([self._sample1._bind(func), self._sample2._bind(func)])
696723

724+
def rename_spaces(self, map: Mapping[str, str], /) -> Sample:
725+
return _Add(self._sample1.rename_spaces(map), self._sample2.rename_spaces(map))
726+
697727

698728
def _simplex_strip(strip):
699729
# Helper function that creates simplices for an extruded simplex, with
@@ -807,7 +837,7 @@ def get_evaluable_weights(self, __ielem: evaluable.Array) -> evaluable.Array:
807837

808838
def get_lower_args(self, __ielem: evaluable.Array) -> function.LowerArgs:
809839
ielem1, ielem2 = evaluable.divmod(__ielem, self._sample2.nelems)
810-
return self._sample1.get_lower_args(ielem1) | self._sample2.get_lower_args(ielem2)
840+
return self._sample1.get_lower_args(ielem1) * self._sample2.get_lower_args(ielem2)
811841

812842
@property
813843
def _reversed_factors(self):
@@ -900,6 +930,9 @@ def basis(self, interpolation: str = 'none') -> Sample:
900930
assert basis1.ndim == basis2.ndim == 1
901931
return numpy.ravel(basis1[:, None] * basis2[None, :])
902932

933+
def rename_spaces(self, map: Mapping[str, str], /) -> Sample:
934+
return _Mul(self._sample1.rename_spaces(map), self._sample2.rename_spaces(map))
935+
903936

904937
class _Zip(Sample):
905938

@@ -940,15 +973,16 @@ def _getslice(self, ielem):
940973

941974
def get_lower_args(self, __ielem: evaluable.Array) -> function.LowerArgs:
942975
points_shape = evaluable.Take(evaluable.Constant(self._sizes), __ielem),
943-
coordinates = {}
944-
transform_chains = {}
976+
args = function.LowerArgs.empty(points_shape)
945977
for samplei, ielemsi, ilocalsi in zip(self._samples, self._ielems, self._ilocals):
946-
argsi = samplei.get_lower_args(evaluable.Take(evaluable.Constant(ielemsi), __ielem))
947978
slicei = evaluable.Take(evaluable.Constant(ilocalsi), self._getslice(__ielem))
948-
transform_chains.update(argsi.transform_chains)
949-
for space, coords in argsi.coordinates.items():
950-
coordinates[space] = evaluable.Transpose.to_end(evaluable.Take(evaluable._flat(evaluable.Transpose.from_end(coords, 0), ndim=2), slicei), 0)
951-
return function.LowerArgs(points_shape, transform_chains, coordinates)
979+
args += samplei \
980+
.get_lower_args(evaluable.Take(evaluable.Constant(ielemsi), __ielem)) \
981+
.map_coordinates(
982+
points_shape,
983+
lambda coords: evaluable.Transpose.to_end(evaluable.Take(evaluable._flat(evaluable.Transpose.from_end(coords, 0), ndim=2), slicei), 0),
984+
)
985+
return args
952986

953987
def get_evaluable_indices(self, ielem):
954988
return evaluable.Take(evaluable.Constant(self._indices), self._getslice(ielem))
@@ -959,6 +993,9 @@ def get_evaluable_weights(self, ielem):
959993
weights = self._samples[0].get_evaluable_weights(ielem0)
960994
return evaluable._take(evaluable._flat(weights), slice0, axis=0)
961995

996+
def rename_spaces(self, map: Mapping[str, str], /) -> Sample:
997+
return _Zip(*[smpl.rename_spaces(map) for smpl in self._samples])
998+
962999

9631000
class _TakeElements(_TensorialSample):
9641001

@@ -1014,6 +1051,9 @@ def get_element_hull(self, __ielem: int) -> numpy.ndarray:
10141051
def take_elements(self, __indices: numpy.ndarray) -> Sample:
10151052
return self._parent.take_elements(numpy.take(self._indices, __indices))
10161053

1054+
def rename_spaces(self, map: Mapping[str, str], /) -> Sample:
1055+
return _TakeElements(self._parent.rename_spaces(map), self._indices)
1056+
10171057

10181058
class _Integral(function.Array):
10191059

@@ -1023,9 +1063,9 @@ def __init__(self, integrand: function.Array, sample: Sample) -> None:
10231063
super().__init__(shape=integrand.shape, dtype=float if integrand.dtype in (bool, int) else integrand.dtype, spaces=integrand.spaces - frozenset(sample.spaces), arguments=integrand.arguments)
10241064

10251065
def lower(self, args: function.LowerArgs) -> evaluable.Array:
1026-
ielem = evaluable.loop_index('_sample_' + '_'.join(self._sample.spaces), self._sample.nelems)
1066+
ielem = evaluable.loop_index(f'_sample_{len(args.args)}', self._sample.nelems)
10271067
weights = evaluable.astype(self._sample.get_evaluable_weights(ielem), self.dtype)
1028-
integrand = evaluable.astype(self._integrand.lower(args | self._sample.get_lower_args(ielem)), self.dtype)
1068+
integrand = evaluable.astype(self._integrand.lower(args * self._sample.get_lower_args(ielem)), self.dtype)
10291069
elem_integral = evaluable.einsum('B,ABC->AC', weights, integrand, B=weights.ndim, C=self.ndim)
10301070
return evaluable.loop_sum(elem_integral, ielem)
10311071

@@ -1039,8 +1079,8 @@ def __init__(self, func: function.Array, sample: _TransformChainsSample) -> None
10391079

10401080
def lower(self, args: function.LowerArgs) -> evaluable.Array:
10411081
axis = len(args.points_shape)
1042-
ielem = evaluable.loop_index('_sample_' + '_'.join(self._sample.spaces), self._sample.nelems)
1043-
args |= self._sample.get_lower_args(ielem)
1082+
ielem = evaluable.loop_index(f'_sample_{len(args.args)}', self._sample.nelems)
1083+
args *= self._sample.get_lower_args(ielem)
10441084
func = self._func.lower(args)
10451085
func = evaluable.Transpose.to_end(func, *range(axis, len(args.points_shape)))
10461086
for i in range(len(args.points_shape) - axis - 1):
@@ -1073,7 +1113,8 @@ def __init__(self, sample: _TransformChainsSample, interpolation: str) -> None:
10731113
super().__init__(shape=(sample.npoints,), dtype=float, spaces=frozenset({sample.space}), arguments={})
10741114

10751115
def lower(self, args: function.LowerArgs) -> evaluable.Array:
1076-
aligned_space_coords = args.coordinates[self._sample.space]
1116+
arg = args[self._sample.space]
1117+
aligned_space_coords = arg.coordinates
10771118
assert aligned_space_coords.ndim == len(args.points_shape) + 1
10781119
space_coords, where = evaluable.unalign(aligned_space_coords)
10791120
# Reinsert the coordinate axis, the last axis of `aligned_space_coords`, or
@@ -1085,9 +1126,8 @@ def lower(self, args: function.LowerArgs) -> evaluable.Array:
10851126
space_coords = evaluable.Transpose(space_coords, numpy.argsort(where))
10861127
where = tuple(sorted(where))
10871128

1088-
(chain, *_), tip_index = args.transform_chains[self._sample.space]
1089-
index = evaluable.TransformIndex(self._sample.transforms[0], chain, tip_index)
1090-
coords = evaluable.TransformCoords(self._sample.transforms[0], chain, tip_index, space_coords)
1129+
index = evaluable.TransformIndex(self._sample.transforms[0], arg.transforms, arg.index)
1130+
coords = evaluable.TransformCoords(self._sample.transforms[0], arg.transforms, arg.index, space_coords)
10911131
expect = self._sample.points.get_evaluable_coords(index)
10921132
sampled = evaluable.Sampled(coords, expect, self._interpolation)
10931133
indices = self._sample.get_evaluable_indices(index)

nutils/topology.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1388,7 +1388,7 @@ def _lower_args(self, ielem, point):
13881388
ielem1, ielem2 = evaluable.divmod(ielem, len(self.topo2))
13891389
largs1 = self.topo1._lower_args(ielem1, point[:self.topo1.ndims])
13901390
largs2 = self.topo2._lower_args(ielem2, point[self.topo1.ndims:])
1391-
return largs1 | largs2
1391+
return largs1 * largs2
13921392

13931393
def _sample(self, ielems, coords, weights=None):
13941394
ielems1, ielems2 = divmod(ielems, len(self.topo2))

tests/test_function.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -418,15 +418,16 @@ def _check(name, op, n_op, *args):
418418
class Unlower(TestCase):
419419

420420
def test(self):
421-
e = evaluable.Argument('arg', tuple(map(evaluable.constant, (2, 3, 4, 5))), int)
421+
shape = tuple(map(evaluable.constant, (2, 3, 4, 5)))
422+
e = evaluable.Argument('arg', shape, int)
422423
arguments = types.frozendict({'arg': ((2, 3), int)})
423-
f = function._Unlower(e, frozenset(), arguments, function.LowerArgs((2, 3), {}, {}))
424+
f = function._Unlower(e, frozenset(), arguments, function.LowerArgs.empty(shape[:2]))
424425
self.assertEqual(f.shape, (4, 5))
425426
self.assertEqual(f.dtype, int)
426427
self.assertEqual(f.arguments, arguments)
427-
self.assertEqual(f.lower(function.LowerArgs((2, 3), {}, {})), e)
428+
self.assertEqual(f.lower(function.LowerArgs.empty(shape[:2])), e)
428429
with self.assertRaises(ValueError):
429-
f.lower(function.LowerArgs((3, 4), {}, {}))
430+
f.lower(function.LowerArgs.empty(shape[1:3]))
430431

431432

432433
class Custom(TestCase):
@@ -435,10 +436,10 @@ def assertEvalAlmostEqual(self, factual, fdesired, **args):
435436
with self.subTest('0d-points'):
436437
self.assertAllAlmostEqual(evaluable.eval_once(factual.as_evaluable_array, arguments=args), evaluable.eval_once(fdesired.as_evaluable_array, arguments=args))
437438
with self.subTest('1d-points'):
438-
lower_args = function.LowerArgs((evaluable.asarray(5),), {}, {})
439+
lower_args = function.LowerArgs.empty((evaluable.asarray(5),))
439440
self.assertAllAlmostEqual(evaluable.eval_once(factual.lower(lower_args), arguments=args), evaluable.eval_once(fdesired.lower(lower_args), arguments=args))
440441
with self.subTest('2d-points'):
441-
lower_args = function.LowerArgs((evaluable.asarray(5), evaluable.asarray(6)), {}, {})
442+
lower_args = function.LowerArgs.empty((evaluable.asarray(5), evaluable.asarray(6)))
442443
self.assertAllAlmostEqual(evaluable.eval_once(factual.lower(lower_args), arguments=args), evaluable.eval_once(fdesired.lower(lower_args), arguments=args))
443444

444445
def assertMultipy(self, leftval, rightval):
@@ -1573,7 +1574,7 @@ def test_constant(self):
15731574

15741575
def test_lower_spaces(self):
15751576
topo, geom = mesh.rectilinear([3])
1576-
with self.assertRaisesRegex(ValueError, r'cannot lower function with spaces \(.+\) - did you forget integral or sample?'):
1577+
with self.assertRaisesRegex(KeyError, r'no such space: .* - did you forget integral or sample?'):
15771578
function.factor(geom)
15781579

15791580

@@ -1602,3 +1603,21 @@ def test_conflict(self):
16021603
g = x**2 * y2
16031604
with self.assertRaisesRegex(ValueError, "inconsistent shapes for argument 'y'"):
16041605
function.arguments_for(f, g)
1606+
1607+
1608+
class swap_spaces(TestCase):
1609+
1610+
def test_different(self):
1611+
X, x = mesh.line(2, space='X')
1612+
self.assertEqual(X.f_index.spaces, frozenset('X'))
1613+
f = function.swap_spaces(X.f_index, 'X', 'Y')
1614+
self.assertEqual(f.spaces, frozenset('Y'))
1615+
with self.assertRaisesRegex(KeyError, 'no such space'):
1616+
X.sample('gauss', 0).eval(f)
1617+
self.assertEqual(X.sample('gauss', 0).rename_spaces({'X': 'Y'}).eval(f).tolist(), [0, 1])
1618+
1619+
def test_same(self):
1620+
X, x = mesh.line(2, space='X')
1621+
f = function.swap_spaces(X.f_index, 'X', 'X')
1622+
self.assertEqual(f.spaces, frozenset('X'))
1623+
self.assertEqual(X.sample('gauss', 0).eval(f).tolist(), [0, 1])

tests/test_sample.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,14 @@ def test_get_lower_args(self):
6363
self.assertEqual(tuple(n.__index__() for n in actual_shape(dict(ielem=ielem))), desired_shape)
6464
offset = 0
6565
for space, desired_chain, desired_point in zip(self.desired_spaces, desired_chains, desired_points):
66-
(chain, *_), index = args.transform_chains[space]
67-
self.assertEqual(chain[evaluable.eval_once(index, arguments=dict(ielem=ielem)).__index__()], desired_chain)
66+
arg = args[space]
67+
self.assertEqual(arg.transforms[evaluable.eval_once(arg.index, arguments=dict(ielem=ielem)).__index__()], desired_chain)
6868
desired_coords = desired_point.coords
6969
desired_coords = numpy.lib.stride_tricks.as_strided(desired_coords, shape=(*desired_shape, desired_point.ndims,), strides=(0,)*offset+desired_coords.strides[:-1]+(0,)*(len(args.points_shape)-offset-desired_coords.ndim+1)+desired_coords.strides[-1:])
70-
actual_coords = evaluable.eval_once(args.coordinates[space], arguments=dict(ielem=ielem))
70+
actual_coords = evaluable.eval_once(arg.coordinates, arguments=dict(ielem=ielem))
7171
self.assertEqual(actual_coords.shape, desired_coords.shape)
7272
self.assertAllAlmostEqual(actual_coords, desired_coords)
7373
offset += desired_point.coords.ndim - 1
74-
with self.assertRaisesRegex(ValueError, '^Nested'):
75-
args | self.sample.get_lower_args(evaluable.InRange(evaluable.Argument('ielem2', (), int), evaluable.constant(self.desired_nelems)))
7674

7775
@property
7876
def _desired_element_tri(self):
@@ -128,8 +126,8 @@ def test_take_elements_single(self):
128126
self.assertEqual(take.ndims, self.desired_ndims)
129127
args = take.get_lower_args(evaluable.InRange(evaluable.Argument('ielem', (), int), evaluable.constant(1)))
130128
for space, desired_chain in zip(self.desired_spaces, self.desired_transform_chains[ielem]):
131-
(chain, *_), index = args.transform_chains[space]
132-
self.assertEqual(chain[evaluable.eval_once(index, arguments=dict(ielem=0)).__index__()], desired_chain)
129+
arg = args[space]
130+
self.assertEqual(arg.transforms[evaluable.eval_once(arg.index, arguments=dict(ielem=0)).__index__()], desired_chain)
133131

134132
def test_take_elements_empty(self):
135133
take = self.sample.take_elements(numpy.array([], int))
@@ -148,6 +146,11 @@ def test_asfunction(self):
148146
func = self.sample.asfunction(numpy.arange(self.sample.npoints))
149147
self.assertEqual(self.sample.bind(func).eval().tolist(), numpy.arange(self.desired_npoints).tolist())
150148

149+
def test_rename_spaces(self):
150+
assert not any(space.endswith('!') for space in self.desired_spaces)
151+
renamed = self.sample.rename_spaces({space: space + '!' for space in self.desired_spaces})
152+
self.assertEqual(renamed.spaces, tuple(space + '!' for space in self.desired_spaces))
153+
151154

152155
class Empty(TestCase, Common):
153156

@@ -280,8 +283,7 @@ def test_integrate(self):
280283
self.assertAlmostEqual(self.stitched.integrate(function.J(self.geomX)), 5/9) # NOTE: != norm(slope)
281284

282285
def test_nested(self):
283-
with self.assertRaisesRegex(ValueError, 'Nested integrals or samples in the same space: X.*, Y.'):
284-
self.stitched.integral(self.stitched.integral(1)).eval()
286+
self.stitched.integral(self.stitched.integral(1)).eval()
285287
topoZ, geomZ = mesh.line(2, space='Z')
286288
inner = self.stitched.integral((geomZ - self.geomX) * function.J(self.geomY))
287289
outer = topoZ.integral(inner * function.J(geomZ), degree=2)
@@ -295,6 +297,10 @@ def test_triplet(self):
295297
self.assertAllAlmostEqual(geomX, geomY[:, numpy.newaxis] * self.slope)
296298
self.assertAllAlmostEqual(geomY, geomZ / 3)
297299

300+
def test_rename_spaces(self):
301+
renamed = self.stitched.rename_spaces({'X': 'B', 'X0': 'B', 'X1': 'C', 'Y': 'A'})
302+
self.assertEqual(renamed.spaces, ('A', 'B') if 'X' in self.sampleX.spaces else ('A', 'B', 'C'))
303+
298304

299305
Zip(etype='square')
300306
Zip(etype='triangle')
@@ -438,6 +444,10 @@ def test_basis(self):
438444

439445
class Special(TestCase):
440446

447+
def test_init_repeated_spaces(self):
448+
with self.assertRaisesRegex(ValueError, 'space a is repeated.$'):
449+
Sample(('a', 'b', 'a'), 1, 1, 1)
450+
441451
def test_add_different_spaces(self):
442452
class Dummy(Sample):
443453
pass
@@ -576,3 +586,12 @@ def test_empty(self):
576586
empty = function.zeros(shape, float)
577587
array = empty.eval(legacy=False)
578588
self.assertAllEqual(array, numpy.zeros((2, 3)))
589+
590+
def test_reuse_space(self):
591+
X0, x0 = mesh.line(2, space='X')
592+
X1, x1 = mesh.line(3, space='X')
593+
self.assertAllAlmostEqual(
594+
function.eval(X0.integral(X1.integral(function.J(x1), degree=3) * function.J(x0), degree=2)),
595+
6,
596+
places=14,
597+
)

0 commit comments

Comments
 (0)