Skip to content

Commit

Permalink
Convert boolean indices to integers when determining broadcast pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Oct 14, 2020
1 parent e96b285 commit ba12674
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
3 changes: 3 additions & 0 deletions tests/tensor/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2236,6 +2236,9 @@ def test_AdvancedBooleanSubtensor(self):
check_topo=False,
)

abs_res = n[~tensor.isinf(n)]
assert abs_res.broadcastable == (False,)


@change_flags(compute_test_value="raise")
def test_basic_shape():
Expand Down
21 changes: 19 additions & 2 deletions theano/tensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2360,7 +2360,7 @@ class BaseAdvancedSubtensor(Op):

__props__ = ()

def make_node(self, x, *index):
def make_node(self, x, *index, is_boolean=False):
x = theano.tensor.as_tensor_variable(x)
index = tuple(map(as_index_variable, index))

Expand All @@ -2372,9 +2372,23 @@ def make_node(self, x, *index):
theano.tensor.tensor(dtype="int64", broadcastable=()) if not bcast else 1
for bcast in x.broadcastable
)

bcast_index = index
if is_boolean:
bcast_index = tuple(
chain.from_iterable(
theano.tensor.basic.nonzero(idx)
if getattr(idx, "ndim", 0) > 0
else (idx,)
for idx in bcast_index
)
)

bcast = [
getattr(i, "value", i) == 1 for i in indexed_result_shape(fake_shape, index)
getattr(i, "value", i) == 1
for i in indexed_result_shape(fake_shape, bcast_index)
]

return gof.Apply(
self,
(x,) + index,
Expand Down Expand Up @@ -2465,6 +2479,9 @@ class AdvancedBooleanSubtensor(BaseAdvancedSubtensor):
"""

def make_node(self, x, *index):
return super().make_node(x, *index, is_boolean=True)

def grad(self, inputs, grads):
(gz,) = grads
x = inputs[0]
Expand Down

0 comments on commit ba12674

Please sign in to comment.