diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 6b7f667ae8..117f398fb0 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -1023,18 +1023,15 @@ def local_useless_switch(fgraph, node): # if left is right -> left if equivalent_up_to_constant_casting(left, right): - if left.type.broadcastable == out_bcast: - out_dtype = node.outputs[0].type.dtype - if left.type.dtype != out_dtype: - left = cast(left, out_dtype) - copy_stack_trace(node.outputs + left, left) - # When not casting, the other inputs of the switch aren't needed in the traceback - return [left] + if left.type.broadcastable != out_bcast: + left, _ = broadcast_arrays(left, cond) - else: - ret = broadcast_arrays(left, cond)[0] - copy_stack_trace(node.outputs + left, ret) - return [ret] + out_dtype = node.outputs[0].type.dtype + if left.type.dtype != out_dtype: + left = cast(left, out_dtype) + + copy_stack_trace(node.outputs + node.inputs, left) + return [left] # This case happens with scan. # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 5e0366aa4f..bacbc540c5 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1089,6 +1089,25 @@ def test_broadcasting_3(self): assert isinstance(f.maker.fgraph.outputs[0].owner.op, Alloc) assert not any(node.op == pt.switch for node in f.maker.fgraph.toposort()) + def test_broadcasting_different_dtype(self): + cond = vector("x", dtype="bool") + float32_branch = as_tensor(np.array([0], dtype="float32")) + float64_branch = as_tensor(np.array([0], dtype="float64")) + + out = pt.switch(cond, float32_branch, float64_branch) + expected_out = pt.alloc(float64_branch, cond.shape) + + rewritten_out = rewrite_graph( + out, include=("canonicalize", "stabilize", "specialize") + ) + assert equal_computations([rewritten_out], [expected_out]) + + out = pt.switch(cond, float64_branch, float32_branch) + rewritten_out = rewrite_graph( + out, include=("canonicalize", "stabilize", "specialize") + ) + assert equal_computations([rewritten_out], [expected_out]) + class TestLocalMergeSwitchSameCond: @pytest.mark.parametrize(