/
contract.py
637 lines (536 loc) · 25 KB
/
contract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
from __future__ import absolute_import, division, print_function
import itertools
from abc import ABCMeta, abstractmethod
from collections import OrderedDict, defaultdict
import opt_einsum
import torch
from opt_einsum import shared_intermediates
from six import add_metaclass
from six.moves import map
from pyro.distributions.util import broadcast_shape
from pyro.ops.einsum import contract
from pyro.ops.sumproduct import logsumproductexp
@add_metaclass(ABCMeta)
class TensorRing(object):
    """
    Abstract tensor ring class.

    Each tensor ring class has a notion of ``dims`` that can be sum-contracted
    out, and a notion of ``ordinal`` that represents a set of batch dimensions
    that can be broadcasted-up or product-contracted out.

    Implementations should cache intermediate results to be compatible with
    :func:`~opt_einsum.shared_intermediates`.

    :param dict cache: an optional :func:`~opt_einsum.shared_intermediates`
        cache. A fresh empty dict is used when omitted.
    """
    def __init__(self, cache=None):
        self._cache = {} if cache is None else cache

    def _save_tensor(self, tensor):
        """
        Saves a tensor in the cache so that ``id(tensor)`` can be used as a
        key in the cache without risk of the id being recycled.

        :param torch.Tensor tensor: the tensor to keep alive in the cache.
        """
        self._cache['tensor', id(tensor)] = tensor

    @abstractmethod
    def dims(self, term):
        """
        Returns an iterable of nontrivial dims associated with this term.

        Derived classes may use any hashable type for dims.

        :param torch.Tensor term: a tensor known to this ring.
        """
        raise NotImplementedError

    @abstractmethod
    def sumproduct(self, terms, dims):
        """
        Multiply all ``terms`` together, then sum-contract out all ``dims``
        from the result.

        :param list terms: a list of tensors
        :param dims: an iterable of dims to sum out
        :returns: the contracted tensor
        """
        raise NotImplementedError

    @abstractmethod
    def product(self, term, ordinal):
        """
        Product-contract the given ``term`` along any batch dimensions
        present in given ``ordinal``.

        :param torch.Tensor term: the term to contract
        :param frozenset ordinal: an ordinal specifying batch context
        :returns: the contracted tensor
        """
        raise NotImplementedError

    @abstractmethod
    def broadcast(self, tensor, ordinal):
        """
        Broadcast the given ``tensor`` by expanding along any batch dimensions
        present in ``ordinal`` but not in ``tensor``.

        :param torch.Tensor tensor: the term to expand
        :param frozenset ordinal: an ordinal specifying batch context
        :returns: the expanded tensor
        """
        raise NotImplementedError
class UnpackedLogRing(TensorRing):
    """
    Tensor Ring defined by high-dimensional unpacked tensors in log space.

    Tensor values are in log units, so ``sum`` is implemented as ``logsumexp``,
    and ``product`` is implemented as ``sum``.
    Tensor shapes are typically wide with only a few nontrivial dimensions::

        torch.Size((7, 1, 1, 1, 1, 1, 3, 1, 1, 2))

    Dims are negative integers indexing into tensors shapes from the right.
    Ordinals are frozensets of ``CondIndepStackFrame``s.
    """
    def dims(self, term):
        """Returns the negative dims of ``term`` whose size exceeds 1."""
        return [d for d in range(-term.dim(), 0) if term.size(d) > 1]

    def sumproduct(self, terms, dims):
        """
        Adds the log-space ``terms`` together (a product in linear space), then
        logsumexp-reduces the negative ``dims`` down to size 1.
        Results are cached by the multiset of input ids and the set of dims.
        """
        # Key on the sorted multiset of ids rather than a frozenset: a
        # frozenset would conflate [x, x] with [x], whose sums differ.
        key = 'sumproduct', tuple(sorted(id(x) for x in terms)), frozenset(dims)
        if key in self._cache:
            return self._cache[key]
        # Pin the inputs so their ids cannot be recycled while this key is
        # live; product() and broadcast() already do the same.
        for x in terms:
            self._save_tensor(x)

        if dims:
            assert all(dim < 0 for dim in dims)
            shape = list(broadcast_shape(*set(x.shape for x in terms)))
            for dim in dims:
                shape[dim] = 1
            term = logsumproductexp(terms, tuple(shape))
        else:
            term = sum(terms)

        # Aggressively squeeze to improve sharing.
        while term.dim() and term.size(0) == 1:
            term = term.squeeze(0)
        self._save_tensor(term)
        self._cache[key] = term
        return term

    def product(self, term, ordinal):
        """
        Sums (log-space product) ``term`` over each frame's ``dim`` in
        ``ordinal`` where that dim exists with size other than 1, keeping dims.
        """
        for frame in sorted(ordinal, key=lambda f: -f.dim):
            if -frame.dim <= term.dim() and term.size(frame.dim) != 1:
                key = 'product', id(term), frame.dim
                if key in self._cache:
                    term = self._cache[key]
                else:
                    self._save_tensor(term)
                    term = term.sum(frame.dim, keepdim=True)
                    self._cache[key] = term
        return term

    def broadcast(self, term, ordinal):
        """
        Expands ``term`` so that each frame in ``ordinal`` has size
        ``frame.size`` at ``frame.dim``; returns ``term`` unchanged if it
        already has that shape.
        """
        shape = list(term.shape)
        for frame in ordinal:
            # Left-pad with singleton dims so that frame.dim is a valid index.
            shape = [1] * (-frame.dim - len(shape)) + shape
            shape[frame.dim] = frame.size
        shape = torch.Size(shape)
        if term.shape == shape:
            return term
        key = 'broadcast', id(term), shape
        if key in self._cache:
            return self._cache[key]
        self._save_tensor(term)
        term = term.expand(shape)
        self._cache[key] = term
        return term
class PackedLogRing(TensorRing):
    """
    Tensor Ring of packed tensors with named dimensions in log space.

    Tensor values are in log units, so ``sum`` is implemented as ``logsumexp``,
    and ``product`` is implemented as ``sum``.
    Tensor dimensions are packed; to read the dimension names of a tensor,
    call :meth:`dims`, which returns a string of dimension names aligned with
    the tensor's shape.

    Dims are characters (string or unicode).
    Ordinals are frozensets of characters.

    :param list inputs: one string of dim names per operand.
    :param list operands: the input tensors.
    :param dict cache: an optional :func:`~opt_einsum.shared_intermediates`
        cache.
    :raises ValueError: if the same dim name has different sizes in
        different operands.
    """
    def __init__(self, inputs, operands, cache=None):
        super(PackedLogRing, self).__init__(cache=cache)
        self._batch_size = {}
        for dims, term in zip(inputs, operands):
            self._save_tensor(term)
            self._cache['dims', id(term)] = dims
            for dim, size in zip(dims, term.shape):
                old = self._batch_size.setdefault(dim, size)
                if old != size:
                    raise ValueError(u"Dimension size mismatch at dim '{}': {} vs {}"
                                     .format(dim, size, old))

    def dims(self, term):
        """Returns the string of dim names recorded for ``term``."""
        return self._cache['dims', id(term)]

    def sumproduct(self, terms, dims):
        """
        Contracts ``terms`` with a log-space einsum, summing out ``dims``; the
        result's dims are the remaining names in sorted order.
        """
        inputs = [self.dims(term) for term in terms]
        output = ''.join(sorted(set(''.join(inputs)) - set(dims)))
        equation = ','.join(inputs) + '->' + output
        term = contract(equation, *terms, backend='pyro.ops.einsum.torch_log')
        self._save_tensor(term)
        self._cache['dims', id(term)] = output
        return term

    def product(self, term, ordinal):
        """
        Sums (log-space product) ``term`` over every dim in ``ordinal`` that
        it has, removing those dims.
        """
        dims = self.dims(term)
        for dim in sorted(ordinal, reverse=True):
            pos = dims.find(dim)
            if pos == -1:
                continue
            key = 'product', id(term), dim
            if key not in self._cache:
                self._save_tensor(term)
                reduced = term.sum(pos)
                self._cache[key] = reduced
                self._cache['dims', id(reduced)] = dims.replace(dim, '')
            term = self._cache[key]
            # Update the local dims on cache hits as well as misses, so the
            # next dims.find() indexes into the tensor we now hold.
            dims = dims.replace(dim, '')
        return term

    def broadcast(self, term, ordinal):
        """
        Expands ``term`` with new leading dims for every name in ``ordinal``
        that ``term`` lacks, in sorted order.
        """
        dims = self.dims(term)
        missing_dims = ''.join(sorted(set(ordinal) - set(dims)))
        if missing_dims:
            key = 'broadcast', id(term), missing_dims
            if key not in self._cache:
                self._save_tensor(term)
                missing_shape = tuple(self._batch_size[dim] for dim in missing_dims)
                expanded = term.expand(missing_shape + term.shape)
                self._cache[key] = expanded
                self._cache['dims', id(expanded)] = missing_dims + dims
            term = self._cache[key]
        return term
def _partition_terms(ring, terms, dims):
    """
    Given a list of terms and a set of contraction dims, partitions the terms
    up into sets that must be contracted together. By separating these
    components we avoid broadcasting.

    This function should be deterministic and free of side effects.

    :param TensorRing ring: the ring whose :meth:`dims` names each term's dims.
    :param list terms: a list of tensors.
    :param set dims: the contraction dims used to connect terms.
    :returns: a generator of ``(component_terms, component_dims)`` pairs, one
        per connected component containing at least one term.
    """
    # Construct a bipartite graph between terms and the dims in which they
    # are enumerated. This conflates terms and dims (tensors and ints).
    neighbors = OrderedDict([(t, []) for t in terms] + [(d, []) for d in sorted(dims)])
    for term in terms:
        for dim in ring.dims(term):
            # Skip repeated dims (e.g. a diagonal 'aa'): a duplicate edge would
            # make the traversal below pop the same vertex twice (KeyError).
            if dim in dims and dim not in neighbors[term]:
                neighbors[term].append(dim)
                neighbors[dim].append(term)

    # Partition the bipartite graph into connected components for contraction.
    while neighbors:
        root, pending = neighbors.popitem()
        component = OrderedDict([(root, None)])  # used as an OrderedSet
        for vertex in pending:
            component[vertex] = None
        while pending:
            current = pending.pop()
            for neighbor in neighbors.pop(current):
                if neighbor not in component:
                    component[neighbor] = None
                    pending.append(neighbor)

        # Split this connected component into tensors and dims.
        component_terms = [v for v in component if isinstance(v, torch.Tensor)]
        if component_terms:
            component_dims = set(v for v in component if not isinstance(v, torch.Tensor))
            yield component_terms, component_dims
def _contract_component(ring, tensor_tree, sum_dims):
    """
    Contract out ``sum_dims`` in a tree of tensors in-place, via message
    passing. This reduces all tensors down to a single tensor in the greatest
    lower bound iarange context.

    This function should be deterministic.
    This function has side-effects: it modifies ``tensor_tree`` and
    ``sum_dims`` in-place.

    :param TensorRing ring: an algebraic ring defining tensor operations.
    :param OrderedDict tensor_tree: a dictionary mapping ordinals to lists of
        tensors. An ordinal is a frozenset of ``CondIndepStack`` frames.
    :param dict sum_dims: a dictionary mapping tensors to sets of dimensions
        (indexed from the right) that should be summed out.
    :raises NotImplementedError: if a term depends on dims from independent
        (non-nested) iarange contexts, so that no strictly smaller parent
        ordinal exists.
    """
    # First close the set of ordinals under intersection (greatest lower bound),
    # ensuring that the ordinals are arranged in a tree structure.
    target_ordinal = frozenset.intersection(*tensor_tree)
    tensor_tree.setdefault(target_ordinal, [])
    pending = list(tensor_tree)
    while pending:
        t = pending.pop()
        for u in list(tensor_tree):
            tu = t & u
            if tu not in tensor_tree:
                # New intersection ordinals start with no terms.
                tensor_tree[tu] = []
                pending.append(tu)

    # Collect contraction dimensions by ordinal.
    # Each dim is assigned the intersection of the ordinals of all terms
    # that carry it.
    dim_to_ordinal = {}
    for t, terms in tensor_tree.items():
        for term in terms:
            for dim in sum_dims[term]:
                dim_to_ordinal[dim] = dim_to_ordinal.get(dim, t) & t
    dims_tree = defaultdict(set)
    for dim, t in dim_to_ordinal.items():
        dims_tree[t].add(dim)

    # Recursively combine terms in different iarange contexts.
    while any(dims_tree.values()):
        # Process an ordinal with the most frames first (ties broken by
        # tensor_tree iteration order).
        leaf = max(tensor_tree, key=len)
        leaf_terms = tensor_tree.pop(leaf)
        leaf_dims = dims_tree.pop(leaf, set())

        # Split terms at the current ordinal into connected components.
        for terms, dims in _partition_terms(ring, leaf_terms, leaf_dims):

            # Eliminate any enumeration dims via a sumproduct contraction.
            term = ring.sumproduct(terms, dims)
            # Dims still carried by the result, owned by lower ordinals.
            remaining_dims = set.union(*map(sum_dims.pop, terms)) - dims

            # Eliminate extra iarange dims via product contractions.
            if leaf == target_ordinal:
                parent = leaf
            else:
                # The parent is the union of all ordinals that own a
                # remaining dim; it must be strictly smaller than leaf.
                parent = frozenset.union(*(t for t, d in dims_tree.items() if d & remaining_dims))
                if parent == leaf:
                    raise NotImplementedError(
                        "Expected tree-structured iarange nesting, but found "
                        "dependencies on independent iaranges [{}]. "
                        "Try converting one of the iaranges to an irange (but beware "
                        "exponential cost in the size of that irange)"
                        .format(', '.join(getattr(f, 'name', str(f)) for f in leaf)))
            # When leaf == target_ordinal this is empty and product() is a no-op.
            contract_frames = leaf - parent
            term = ring.product(term, contract_frames)
            tensor_tree.setdefault(parent, []).append(term)
            sum_dims[term] = remaining_dims
def contract_tensor_tree(tensor_tree, sum_dims, ring=None, cache=None):
    """
    Contract out ``sum_dims`` in a tree of tensors via message passing.

    This function should be deterministic and free of side effects.

    :param OrderedDict tensor_tree: a dictionary mapping ordinals to lists of
        tensors. An ordinal is a frozenset of ``CondIndepStack`` frames.
    :param dict sum_dims: a dictionary mapping tensors to sets of dimensions
        (indexed from the right) that should be summed out.
    :param TensorRing ring: an algebraic ring defining tensor operations.
        Defaults to an :class:`UnpackedLogRing` built from ``cache``.
    :param dict cache: an optional :func:`~opt_einsum.shared_intermediates`
        cache.
    :returns: A contracted version of ``tensor_tree``, holding one tensor
        per connected component.
    :rtype: OrderedDict
    """
    if ring is None:
        ring = UnpackedLogRing(cache=cache)
    assert isinstance(tensor_tree, OrderedDict)
    assert isinstance(sum_dims, dict)
    assert isinstance(ring, TensorRing)

    ordinals = {term: t for t, terms in tensor_tree.items() for term in terms}
    all_terms = [term for terms in tensor_tree.values() for term in terms]
    # set().union(...) instead of set.union(...): the unbound form raises
    # TypeError when sum_dims is empty or when its values are frozensets.
    all_dims = set().union(*sum_dims.values())
    contracted_tree = OrderedDict()

    # Split this tensor tree into connected components.
    for terms, dims in _partition_terms(ring, all_terms, all_dims):
        component = OrderedDict()
        # Fresh dict so that _contract_component's in-place edits do not
        # leak into the caller's sum_dims.
        component_dims = {}
        for term in terms:
            component.setdefault(ordinals[term], []).append(term)
            component_dims[term] = sum_dims[term]

        # Contract this connected component down to a single tensor.
        _contract_component(ring, component, component_dims)
        assert len(component) == 1
        t, terms = component.popitem()
        assert len(terms) == 1
        contracted_tree.setdefault(t, []).extend(terms)

    return contracted_tree
def contract_to_tensor(tensor_tree, sum_dims, target_ordinal, ring=None, cache=None):
    """
    Contract out ``sum_dims`` in a tree of tensors via message passing,
    reducing all terms to a single tensor in the iarange context given by
    ``target_ordinal``.

    This function should be deterministic and free of side effects.

    :param OrderedDict tensor_tree: a dictionary mapping ordinals to lists of
        tensors. An ordinal is a frozenset of ``CondIndepStack`` frames.
    :param dict sum_dims: a dictionary mapping tensors to sets of dimensions
        (indexed from the right) that should be summed out.
    :param frozenset target_ordinal: the ordinal to which results will be
        product-contracted or broadcast.
    :param TensorRing ring: an algebraic ring defining tensor operations.
        Defaults to an :class:`UnpackedLogRing` built from ``cache``.
    :param dict cache: an optional :func:`~opt_einsum.shared_intermediates`
        cache.
    :returns: a single tensor
    :rtype: torch.Tensor
    """
    ring = UnpackedLogRing(cache=cache) if ring is None else ring
    assert isinstance(tensor_tree, OrderedDict)
    assert isinstance(sum_dims, dict)
    assert isinstance(target_ordinal, frozenset)
    assert isinstance(ring, TensorRing)

    # Step 1: sum-contract every enumeration dim.
    contracted = contract_tensor_tree(tensor_tree, sum_dims, ring=ring)

    # Step 2: product-contract any batch frames outside the target context.
    collected = []
    covered = frozenset()
    for ordinal, terms in contracted.items():
        extra_frames = ordinal - target_ordinal
        if extra_frames:
            terms = [ring.product(t, extra_frames) for t in terms]
            ordinal &= target_ordinal
        collected += terms
        covered |= ordinal
    assert covered <= target_ordinal

    # Step 3: multiply survivors together, then expand to the target context.
    combined = ring.sumproduct(collected, set())
    return ring.broadcast(combined, target_ordinal)
def ubersum(equation, *operands, **kwargs):
    """
    Generalized batched sum-product algorithm via tensor message passing.

    This generalizes :func:`~pyro.ops.einsum.contract` in two ways:

    1.  Multiple outputs are allowed, and intermediate results can be shared.
    2.  Inputs and outputs can be batched along symbols given in ``batch_dims``;
        reductions along ``batch_dims`` are product reductions.

    The best way to understand this function is to try the examples below,
    which show how :func:`ubersum` calls can be implemented as multiple calls
    to :func:`~pyro.ops.einsum.contract` (which is generally more expensive).

    To illustrate multiple outputs, note that the following are equivalent::

        z1, z2, z3 = ubersum('ab,bc->a,b,c', x, y)  # multiple outputs

        backend = 'pyro.ops.einsum.torch_log'
        z1 = contract('ab,bc->a', x, y, backend=backend)
        z2 = contract('ab,bc->b', x, y, backend=backend)
        z3 = contract('ab,bc->c', x, y, backend=backend)

    To illustrate batched inputs, note that the following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('c,abc,acd->bd', w, x, y, batch_dims='a')

        z = contract('c,bc,bc,bc,cd,cd,cd->bd', w, *x, *y, backend=backend)

    When a non-batch dimension `i` always appears with a batch dimension `a`,
    then `i` corresponds to a distinct symbol for each slice of `a`. Thus
    the following are equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('abi,abi->', x, y, batch_dims='a')

        z = contract('bi,bj,bk,bi,bj,bk->', *x, *y, backend=backend)

    When such a non-batched dimension appears in the output, it must be
    accompanied by all of its batch dimensions, e.g. the following are
    equivalent::

        assert len(x) == 3 and len(y) == 3
        z = ubersum('abi,abi->ai', x, y, batch_dims='a')

        z0 = contract('bi,bj,bk,bi,bj,bk->i', *x, *y, backend=backend)
        z1 = contract('bi,bj,bk,bi,bj,bk->j', *x, *y, backend=backend)
        z2 = contract('bi,bj,bk,bi,bj,bk->k', *x, *y, backend=backend)
        z = torch.stack([z0, z1, z2])

    Among all valid inputs, some computations are polynomial in the sizes of
    the input tensors and other computations are exponential in the sizes of
    the input tensors. This function raises :py:class:`NotImplementedError`
    whenever the computation is exponential.

    :param str equation: an einsum equation, optionally with multiple outputs.
    :param torch.Tensor operands: a collection of tensors
    :param str batch_dims: an optional string of batch dims.
    :param dict cache: an optional :func:`~opt_einsum.shared_intermediates`
        cache.
    :return: a tuple of tensors of requested shape, one entry per output.
    :rtype: tuple
    :raises ValueError: if tensor sizes mismatch or an output requests a
        batched dim without that dim's batch dims.
    :raises NotImplementedError: if contraction would have cost exponential in
        the size of any input tensor, if ``equation`` uses ellipsis notation,
        or if ``backend`` is anything other than
        ``'pyro.ops.einsum.torch_log'``.
    """
    # Extract kwargs.
    cache = kwargs.pop('cache', None)
    batch_dims = kwargs.pop('batch_dims', '')
    backend = kwargs.pop('backend', 'pyro.ops.einsum.torch_log')
    if backend != 'pyro.ops.einsum.torch_log':
        raise NotImplementedError(u"ubersum only supports backend='pyro.ops.einsum.torch_log', "
                                  u"but got backend='{}'".format(backend))

    # Parse generalized einsum equation.
    if '.' in equation:
        raise NotImplementedError('ubersum does not yet support ellipsis notation')
    inputs, outputs = equation.split('->')
    inputs = inputs.split(',')
    outputs = outputs.split(',')
    assert len(inputs) == len(operands)
    assert all(isinstance(x, torch.Tensor) for x in operands)
    if len(operands) != len(set(operands)):
        operands = [x[...] for x in operands]  # ensure tensors are unique

    # Construct a tensor tree shared by all outputs. Also record each
    # non-batch dim's batch context: the intersection of the batch contexts
    # of every input in which it appears (same convention as naive_ubersum).
    tensor_tree = OrderedDict()
    max_ordinal = frozenset(batch_dims)
    dim_to_ordinal = {}
    for dims, term in zip(inputs, operands):
        assert len(dims) == term.dim()
        ordinal = frozenset(dims) & max_ordinal
        tensor_tree.setdefault(ordinal, []).append(term)
        for dim in frozenset(dims) - max_ordinal:
            dim_to_ordinal[dim] = dim_to_ordinal.get(dim, ordinal) & ordinal

    # Reject outputs that keep a batched dim but drop one of its batch dims,
    # as documented above, before doing any computation.
    for output in outputs:
        output_dims = frozenset(output)
        for dim in sorted(output_dims - max_ordinal):
            missing_dims = dim_to_ordinal.get(dim, frozenset()) - output_dims
            if missing_dims:
                raise ValueError(u"It is nonsensical to preserve a batched dim without preserving "
                                 u"all of that dim's batch dims, but found '{}' without '{}' in '{}'"
                                 .format(dim, ','.join(sorted(missing_dims)), equation))

    # Compute outputs, sharing intermediate computations.
    results = []
    with shared_intermediates(cache) as cache:
        ring = PackedLogRing(inputs, operands, cache=cache)
        for output in outputs:
            nosum_dims = set(batch_dims + output)
            sum_dims = {term: set(dims) - nosum_dims for dims, term in zip(inputs, operands)}
            target_ordinal = frozenset(output) & max_ordinal
            term = contract_to_tensor(tensor_tree, sum_dims, target_ordinal, ring=ring)
            dims = ring.dims(term)
            if dims != output:
                # The ring returns dims in its own order; permute to match output.
                term = term.permute(*map(dims.index, output))
            results.append(term)
    return tuple(results)
def _select(tensor, dims, indices):
    """
    Index ``tensor`` at ``indices[k]`` along ``dims[k]`` for each pair, in order.

    Pairs are consumed left to right; extra entries in the longer sequence are
    ignored (``zip`` semantics). With no pairs, ``tensor`` itself is returned.
    """
    for dim_and_index in zip(dims, indices):
        tensor = tensor.select(*dim_and_index)
    return tensor
class _DimFlattener(object):
    """
    Object to map batched dims to batches of flat dims.

    Each distinct ``(dim, batch-index tuple)`` pair is assigned a fresh
    einsum symbol on first request, and the same symbol on later requests.

    :param dict dim_to_ordinal: a mapping from contraction dim to the set of
        batch dims over which the contraction dim is batched.
    """
    def __init__(self, dim_to_ordinal):
        plates = {}
        for dim, ordinal in dim_to_ordinal.items():
            plates[dim] = tuple(sorted(ordinal))
        self._plates = plates
        self._symbols = map(opt_einsum.get_symbol, itertools.count())
        self._map = {}

    def __call__(self, dim, indices):
        """
        Converts a batched dim + batch indices to a flattened dim.

        :param str dim: a batched dimension to flatten
        :param dict indices: a mapping from batch dimension to int
        :return: a flattened dim
        :rtype: str
        """
        key = dim, tuple(indices[d] for d in self._plates.get(dim, ()))
        try:
            return self._map[key]
        except KeyError:
            flat_dim = self._map[key] = next(self._symbols)
            return flat_dim
def naive_ubersum(equation, *operands, **kwargs):
    """
    Naive reference implementation of :func:`ubersum`.

    This implementation should never raise ``NotImplementedError``.
    This implementation should agree with :func:`ubersum` whenever
    :func:`ubersum` does not raise ``NotImplementedError``.

    It works by slicing every operand along its batch dims, renaming each
    non-batch dim per batch slice (via :class:`_DimFlattener`), and deferring
    to a single unbatched ``opt_einsum.contract`` call per output slice.

    :param str equation: an einsum equation, optionally with multiple outputs.
    :param torch.Tensor operands: a collection of tensors
    :param str batch_dims: an optional string of batch dims (keyword only).
        Other keyword arguments are accepted but not used.
    :return: a tuple of tensors, one entry per output.
    :rtype: tuple
    :raises ValueError: if tensor sizes mismatch or an output requests a
        batched dim without that dim's batch dims.
    """
    # Parse equation, without loss of generality assuming a single output.
    inputs, outputs = equation.split('->')
    outputs = outputs.split(',')
    if len(outputs) > 1:
        # kwargs (including batch_dims) have not been popped yet, so each
        # recursive call sees the same options.
        return tuple(naive_ubersum(inputs + '->' + output, *operands, **kwargs)[0]
                     for output in outputs)
    output, = outputs
    inputs = inputs.split(',')

    # Split dims into batch dims, contraction dims, and dims to keep.
    batch_dims = set(kwargs.pop('batch_dims', ''))
    if not batch_dims:
        result = opt_einsum.contract(equation, *operands, backend='pyro.ops.einsum.torch_log')
        return (result,)
    output_dims = set(output)

    # Collect sizes of all dimensions.
    sizes = {}
    for input_, operand in zip(inputs, operands):
        for dim, size in zip(input_, operand.shape):
            old = sizes.setdefault(dim, size)
            if old != size:
                raise ValueError(u"Dimension size mismatch at dim '{}': {} vs {}"
                                 .format(dim, size, old))

    # Compute batch context for each non-batch dim, by convention the
    # intersection over all batch contexts of tensors in which the dim appears.
    dim_to_ordinal = {}
    for dims in map(set, inputs):
        ordinal = dims & batch_dims
        for dim in dims - batch_dims:
            dim_to_ordinal[dim] = dim_to_ordinal.get(dim, ordinal) & ordinal
    for dim in output_dims - batch_dims:
        # NOTE(review): an output dim absent from every input raises KeyError
        # here rather than a ValueError.
        missing_dims = dim_to_ordinal[dim] - output_dims
        if missing_dims:
            raise ValueError(u"It is nonsensical to preserve a batched dim without preserving "
                             u"all of that dim's batch dims, but found '{}' without '{}' in '{}'"
                             .format(dim, ','.join(missing_dims), equation))

    # Flatten by replicating along batch dimensions.
    flatten_dim = _DimFlattener(dim_to_ordinal)
    flat_inputs = []
    flat_operands = []
    for input_, operand in zip(inputs, operands):
        local_dims = [d for d in input_ if d in batch_dims]
        # Negative offsets so that successive selects index from the right.
        offsets = [input_.index(d) - len(input_) for d in local_dims]
        for index in itertools.product(*(range(sizes[d]) for d in local_dims)):
            flat_inputs.append(''.join(flatten_dim(d, dict(zip(local_dims, index)))
                                       for d in input_ if d not in batch_dims))
            flat_operands.append(_select(operand, offsets, index))

    # Defer to unbatched einsum.
    result = operands[0].new_empty(torch.Size(sizes[d] for d in output))
    local_dims = [d for d in output if d in batch_dims]
    offsets = [output.index(d) - len(output) for d in local_dims]
    for index in itertools.product(*(range(sizes[d]) for d in local_dims)):
        flat_output = ''.join(flatten_dim(d, dict(zip(local_dims, index)))
                              for d in output if d not in batch_dims)
        flat_equation = ','.join(flat_inputs) + '->' + flat_output
        flat_result = opt_einsum.contract(flat_equation, *flat_operands,
                                          backend='pyro.ops.einsum.torch_log')
        # Write this batch slice of the output in place.
        _select(result, offsets, index).copy_(flat_result)
    return (result,)