Skip to content

Commit

Permalink
Fix bug when broadcasting branches in local_useless_switch rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 28, 2024
1 parent 5a47550 commit ef22377
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
19 changes: 8 additions & 11 deletions pytensor/tensor/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,18 +1023,15 @@ def local_useless_switch(fgraph, node):

# if left is right -> left
if equivalent_up_to_constant_casting(left, right):
if left.type.broadcastable == out_bcast:
out_dtype = node.outputs[0].type.dtype
if left.type.dtype != out_dtype:
left = cast(left, out_dtype)
copy_stack_trace(node.outputs + left, left)
# When not casting, the other inputs of the switch aren't needed in the traceback
return [left]
if left.type.broadcastable != out_bcast:
left, _ = broadcast_arrays(left, cond)

else:
ret = broadcast_arrays(left, cond)[0]
copy_stack_trace(node.outputs + left, ret)
return [ret]
out_dtype = node.outputs[0].type.dtype
if left.type.dtype != out_dtype:
left = cast(left, out_dtype)

copy_stack_trace(node.outputs + node.inputs, left)
return [left]

# This case happens with scan.
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
Expand Down
19 changes: 19 additions & 0 deletions tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,25 @@ def test_broadcasting_3(self):
assert isinstance(f.maker.fgraph.outputs[0].owner.op, Alloc)
assert not any(node.op == pt.switch for node in f.maker.fgraph.toposort())

def test_broadcasting_different_dtype(self):
cond = vector("x", dtype="bool")
float32_branch = as_tensor(np.array([0], dtype="float32"))
float64_branch = as_tensor(np.array([0], dtype="float64"))

out = pt.switch(cond, float32_branch, float64_branch)
expected_out = pt.alloc(float64_branch, cond.shape)

rewritten_out = rewrite_graph(
out, include=("canonicalize", "stabilize", "specialize")
)
assert equal_computations([rewritten_out], [expected_out])

out = pt.switch(cond, float64_branch, float32_branch)
rewritten_out = rewrite_graph(
out, include=("canonicalize", "stabilize", "specialize")
)
assert equal_computations([rewritten_out], [expected_out])


class TestLocalMergeSwitchSameCond:
@pytest.mark.parametrize(
Expand Down

0 comments on commit ef22377

Please sign in to comment.