Skip to content

Commit

Permalink
Fix vectorize_graph bug when replacements were provided for only some outputs of a node
Browse files Browse the repository at this point in the history

The provided replacement for an output could be silently ignored and overwritten by the new output of the vectorized node.

The changes also avoid vectorizing multiple-output nodes when none of the unreplaced outputs are needed.
  • Loading branch information
ricardoV94 committed Jan 7, 2024
1 parent c4ae6e3 commit 0ebc83b
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 8 deletions.
7 changes: 4 additions & 3 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1439,15 +1439,16 @@ def io_toposort(
order = []
while todo:
cur = todo.pop()
# We suppose that all outputs are always computed
if cur.outputs[0] in computed:
if all(out in computed for out in cur.outputs):
continue
if all(i in computed or i.owner is None for i in cur.inputs):
computed.update(cur.outputs)
order.append(cur)
else:
todo.append(cur)
todo.extend(i.owner for i in cur.inputs if i.owner)
todo.extend(
i.owner for i in cur.inputs if (i.owner and i not in computed)
)
return order

compute_deps = None
Expand Down
5 changes: 5 additions & 0 deletions pytensor/graph/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,11 @@ def vectorize_graph(
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs):
if output in vect_vars:
# This can happen when some outputs of a multi-output node are given a replacement,
# while some of the remaining outputs are still needed in the graph.
# We make sure we don't overwrite the provided replacement with the newly vectorized output
continue
vect_vars[output] = vect_output

seq_vect_outputs = [vect_vars[out] for out in seq_outputs]
Expand Down
41 changes: 40 additions & 1 deletion tests/graph/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.variable import TensorVariable
from tests.graph.utils import MyInnerGraphOp
from tests.graph.utils import MyInnerGraphOp, op_multiple_outputs


class MyType(Type):
Expand Down Expand Up @@ -287,6 +287,45 @@ def test_outputs_clients(self):
all = io_toposort([], o0.outputs)
assert all == [o0]

def test_multi_output_nodes(self):
l0, r0 = op_multiple_outputs(shared(0.0))
l1, r1 = op_multiple_outputs(shared(0.0))

v0 = r0 + 1
v1 = pt.exp(v0)
out = r1 * v1

# When either r0 or r1 is provided as an input, the respective node shouldn't be part of the toposort
assert set(io_toposort([], [out])) == {
r0.owner,
r1.owner,
v0.owner,
v1.owner,
out.owner,
}
assert set(io_toposort([r0], [out])) == {
r1.owner,
v0.owner,
v1.owner,
out.owner,
}
assert set(io_toposort([r1], [out])) == {
r0.owner,
v0.owner,
v1.owner,
out.owner,
}
assert set(io_toposort([r0, r1], [out])) == {v0.owner, v1.owner, out.owner}

# When l0 and/or l1 are provided, we still need to compute the respective nodes
assert set(io_toposort([l0, l1], [out])) == {
r0.owner,
r1.owner,
v0.owner,
v1.owner,
out.owner,
}


class TestEval:
def setup_method(self):
Expand Down
71 changes: 67 additions & 4 deletions tests/graph/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@
import pytensor.tensor as pt
from pytensor import config, function, shared
from pytensor.graph.basic import equal_computations, graph_inputs
from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
from pytensor.graph.replace import (
clone_replace,
graph_replace,
vectorize_graph,
vectorize_node,
)
from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable
from tests.graph.utils import MyOp, MyVariable, op_multiple_outputs


class TestCloneReplace:
Expand Down Expand Up @@ -227,8 +232,6 @@ def test_graph_replace_disconnected(self):


class TestVectorizeGraph:
# TODO: Add tests with multiple outputs, constants, and other singleton types

def test_basic(self):
x = pt.vector("x")
y = pt.exp(x) / pt.sum(pt.exp(x))
Expand Down Expand Up @@ -260,3 +263,63 @@ def test_multiple_outputs(self):
new_y1_res, new_y2_res = fn(new_x_test)
np.testing.assert_allclose(new_y1_res, [0, 3, 6])
np.testing.assert_allclose(new_y2_res, [2, 5, 8])

def test_multi_output_node(self):
x = pt.scalar("x")
node = op_multiple_outputs.make_node(x)
y1, y2 = node.outputs
out = pt.add(y1, y2)

new_x = pt.vector("new_x")
new_y1 = pt.vector("new_y1")
new_y2 = pt.vector("new_y2")

# Cases where either x or both of y1 and y2 are given replacements
new_out = vectorize_graph(out, {x: new_x})
expected_new_out = pt.add(*vectorize_node(node, new_x).outputs)
assert equal_computations([new_out], [expected_new_out])

new_out = vectorize_graph(out, {y1: new_y1, y2: new_y2})
expected_new_out = pt.add(new_y1, new_y2)
assert equal_computations([new_out], [expected_new_out])

new_out = vectorize_graph(out, {x: new_x, y1: new_y1, y2: new_y2})
expected_new_out = pt.add(new_y1, new_y2)
assert equal_computations([new_out], [expected_new_out])

# Special case where x is given a replacement as well as only one of y1 and y2
# The graph combines the replaced variable with the other vectorized output
new_out = vectorize_graph(out, {x: new_x, y1: new_y1})
expected_new_out = pt.add(new_y1, vectorize_node(node, new_x).outputs[1])
assert equal_computations([new_out], [expected_new_out])

def test_multi_output_node_random_variable(self):
"""This is a regression test for #569.
Functionally, it covers the same case as `test_multiple_output_node`
"""

# RandomVariables have two outputs, a hidden RNG and the visible draws
beta0 = pt.random.normal(name="beta0")
beta1 = pt.random.normal(name="beta1")

out1 = beta0 + 1
out2 = beta1 * pt.exp(out1)

# We replace the second output of each RandomVariable
new_beta0 = pt.tensor("new_beta0", shape=(3,))
new_beta1 = pt.tensor("new_beta1", shape=(3,))

new_outs = vectorize_graph(
[out1, out2],
replace={
beta0: new_beta0,
beta1: new_beta1,
},
)

expected_new_outs = [
new_beta0 + 1,
new_beta1 * pt.exp(new_beta0 + 1),
]
assert equal_computations(new_outs, expected_new_outs)

0 comments on commit 0ebc83b

Please sign in to comment.