Skip to content

Commit

Permalink
Prevent local_sum_make_vector from introducing internal float64 (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tvwenger committed Mar 8, 2024
1 parent d175203 commit 4ee3588
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 26 deletions.
7 changes: 6 additions & 1 deletion pytensor/tensor/rewriting/basic.py
Expand Up @@ -28,7 +28,7 @@
import numpy as np

import pytensor.scalar.basic as ps
from pytensor import compile
from pytensor import compile, config
from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, Variable
Expand Down Expand Up @@ -941,6 +941,11 @@ def local_sum_make_vector(fgraph, node):
elements = array.owner.inputs
acc_dtype = node.op.acc_dtype
out_dtype = node.op.dtype

# Skip rewrite if it would add unnecessary float64 to the graph
if acc_dtype == "float64" and out_dtype != "float64" and config.floatX != "float64":
return

if len(elements) == 0:
element_sum = zeros(dtype=out_dtype, shape=())
elif len(elements) == 1:
Expand Down
57 changes: 32 additions & 25 deletions tests/tensor/rewriting/test_basic.py
Expand Up @@ -12,7 +12,7 @@
from pytensor.compile.mode import get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
from pytensor.configdefaults import config
from pytensor.graph.basic import equal_computations, vars_between
from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
Expand All @@ -26,12 +26,12 @@
ScalarFromTensor,
Split,
TensorFromScalar,
cast,
join,
tile,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import (
Sum,
add,
bitwise_and,
bitwise_or,
Expand Down Expand Up @@ -1298,41 +1298,48 @@ def test_local_join_make_vector():


def test_local_sum_make_vector():
    """Check `local_sum_make_vector` rewrites ``MakeVector(...).sum()`` as expected.

    NOTE(review): this is the post-commit (added-side) text of the diff; the
    rendered page had interleaved the deleted pre-commit lines, which are
    removed here.
    """
    # To check that rewrite is applied, we must enforce dtype to
    # allow rewrite to occur even if floatX != "float64"
    a, b, c = scalars("abc")
    mv = MakeVector(config.floatX)
    output = mv(a, b, c).sum(dtype="float64")
    rewrite_output = rewrite_graph(output)
    # The rewrite replaces the vector-then-sum with an upcast elementwise add.
    expected_output = cast(
        add(*[cast(value, "float64") for value in [a, b, c]]), dtype="float64"
    )
    assert equal_computations([expected_output], [rewrite_output])

    # Empty axes should return input vector since no sum is applied
    a, b, c = scalars("abc")
    mv = MakeVector(config.floatX)
    output = mv(a, b, c).sum(axis=[])
    rewrite_output = rewrite_graph(output)
    expected_output = mv(a, b, c)
    assert equal_computations([expected_output], [rewrite_output])

    # Empty input should return 0
    mv = MakeVector(config.floatX)
    output = mv().sum()
    rewrite_output = rewrite_graph(output)
    expected_output = pt.as_tensor(0, dtype=config.floatX)
    assert equal_computations([expected_output], [rewrite_output])

    # Single element input should return element value
    a = scalars("a")
    mv = MakeVector(config.floatX)
    output = mv(a).sum()
    rewrite_output = rewrite_graph(output)
    expected_output = cast(a, config.floatX)
    assert equal_computations([expected_output], [rewrite_output])

    # This is a regression test for #653. Ensure that rewrite is NOT
    # applied when user requests float32
    with config.change_flags(floatX="float32", warn_float64="raise"):
        a, b, c = scalars("abc")
        mv = MakeVector(config.floatX)
        output = mv(a, b, c).sum()
        rewrite_output = rewrite_graph(output)
        # Graph must be unchanged: the rewrite would otherwise introduce an
        # internal float64 accumulator, tripping `warn_float64="raise"`.
        assert equal_computations([output], [rewrite_output])


@pytest.mark.parametrize(
Expand Down

0 comments on commit 4ee3588

Please sign in to comment.