Commit

Allow defining an OpFromGraph from constant and shared inputs.
Also adds a strict flag
ricardoV94 committed Mar 22, 2024
1 parent 339aab4 commit 97317a5
Showing 2 changed files with 57 additions and 37 deletions.
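As a quick orientation before the diff, here is a hedged sketch of the usage this commit enables, mirroring the new tests added below (the variable names such as ofg are illustrative): a constant can now be listed as an explicit input to OpFromGraph.

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.tensor.basic import constant

x = pt.dscalar("x")
y = constant(1.0, name="y")  # a constant listed as an explicit input
ofg = OpFromGraph([x, y], [x + y])
out = ofg(x, y)
assert out.eval({x: 5}) == 6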
61 changes: 31 additions & 30 deletions pytensor/compile/builders.py
@@ -92,38 +92,29 @@ def construct_nominal_fgraph(
dict[Variable, Variable],
]:
"""Construct an inner-`FunctionGraph` with ordered nominal inputs."""
dummy_inputs = []
for n, inp in enumerate(inputs):
if (
not isinstance(inp, Variable)
or isinstance(inp, Constant)
or isinstance(inp, SharedVariable)
):
raise TypeError(
f"Inputs and outputs must be non-Constant/shared Variable instances; got {inp}"
)

dummy_inputs.append(inp.type())
implicit_shared_inputs = []

dummy_shared_inputs = []
shared_inputs = []
dummy_inputs = [inp.type() for inp in inputs]
dummy_implicit_shared_inputs = []
for var in graph_inputs(outputs, inputs):
if var in inputs:
continue
if isinstance(var, SharedVariable):
# To correctly support shared variables the inner-graph should
# not see them; otherwise, there will be problems with
# gradients.
# That's why we collect the shared variables and replace them
# with dummies.
shared_inputs.append(var)
dummy_shared_inputs.append(var.type())
elif var not in inputs and not isinstance(var, Constant):
raise MissingInputError(f"OpFromGraph is missing an input: {var}")

replacements = dict(zip(inputs + shared_inputs, dummy_inputs + dummy_shared_inputs))
# We allow shared inputs to be added automatically to the graph
implicit_shared_inputs.append(var)
dummy_implicit_shared_inputs.append(var.type())
elif not isinstance(var, Constant):
raise MissingInputError(f"NominalGraph is missing an input: {var}")

replacements = dict(
zip(
inputs + implicit_shared_inputs, dummy_inputs + dummy_implicit_shared_inputs
)
)

new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + shared_inputs,
inputs=inputs + implicit_shared_inputs,
replace=replacements,
copy_inputs_over=False,
)
@@ -133,7 +124,7 @@
(clone_d, update_d, update_expr, new_shared_inputs),
) = new

assert len(local_inputs) == len(inputs) + len(shared_inputs)
assert len(local_inputs) == len(inputs) + len(implicit_shared_inputs)
assert len(local_outputs) == len(outputs)
assert not update_d
assert not update_expr
@@ -155,7 +146,7 @@
fgraph.clients.pop(inp, None)
fgraph.add_input(nom_inp)

return fgraph, shared_inputs, update_d, update_expr
return fgraph, implicit_shared_inputs, update_d, update_expr


class OpFromGraph(Op, HasInnerGraph):
@@ -177,8 +168,6 @@ class OpFromGraph(Op, HasInnerGraph):
- grad() make it support DisconnectedType and the new interface
- add support for NullType and DisconnectedType when R_op supports them
- check how it works with updates.
- add test with constant as input or inside the inner graph.
- Add support for the GPU? Probably just need an opt to remove transfer
- Add support to pickle this Op.
- Add support/test with random generator
- Add optimization to removing unused inputs/outputs
@@ -310,11 +299,13 @@ def __init__(
self,
inputs: list[Variable],
outputs: list[Variable],
*,
inline: bool = False,
lop_overrides: str = "default",
grad_overrides: str = "default",
rop_overrides: str = "default",
connection_pattern: Optional[list[list[bool]]] = None,
strict: bool = False,
name: Optional[str] = None,
**kwargs,
):
@@ -399,6 +390,10 @@ def __init__(
must be equal to number of outputs. connection_pattern If not
``None``, this will be used as the connection_pattern for this
:class:`Op`.
strict: bool, default False
If True, an error is raised when any variables needed to compute the inner graph
are not provided as explicit inputs. This can only happen for graphs with
shared variables.
name
A name for debugging purposes.
kwargs
@@ -424,6 +419,12 @@ def __init__(
inputs, outputs
)

if strict and self.shared_inputs:
raise ValueError(
"All variables needed to compute inner-graph must be provided as inputs under strict=True. "
f"The inner-graph implicitly depends on the following shared variables {self.shared_inputs}"
)

self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
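Before the test diff, a short hedged sketch of the new strict flag documented above (again mirroring the tests that follow; the names are illustrative): with strict=True an inner graph that implicitly pulls in a shared variable is rejected, while listing the shared variable explicitly is accepted.

import pytensor.tensor as pt
from pytensor import shared
from pytensor.compile.builders import OpFromGraph

x = pt.dscalar("x")
y = shared(1.0, name="y")

# Implicit dependency on the shared variable `y` is rejected under strict=True.
try:
    OpFromGraph([x], [x + y], strict=True)
except ValueError as err:
    print(err)

# Listing the shared variable explicitly satisfies the strict check.
ofg = OpFromGraph([x, y], [x + y], strict=True)
out = ofg(x, y)
assert out.eval({x: 5}) == 6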
33 changes: 26 additions & 7 deletions tests/compile/test_builders.py
@@ -15,7 +15,7 @@
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.basic import constant
from pytensor.tensor.math import dot, exp, sigmoid
from pytensor.tensor.math import round as pt_round
from pytensor.tensor.math import sum as pt_sum
@@ -43,12 +43,6 @@ def test_valid_input(self):
with pytest.raises(TypeError):
OpFromGraph([1], [1])

with pytest.raises(TypeError):
OpFromGraph([x, as_tensor(1)], [x])

with pytest.raises(TypeError):
OpFromGraph([shared(1)], [1])

with pytest.raises(NotImplementedError):
OpFromGraph([x], [x], updates={})

@@ -559,6 +553,31 @@ def test_outputs_consistency(self):
# The original `op.fgraph` outputs should stay the same, though
assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x])

def test_explicit_input_from_constant(self):
x = pt.dscalar("x")
y = constant(1.0, name="y")
test_ofg = OpFromGraph([x, y], [x + y])

out = test_ofg(x, y)
assert out.eval({x: 5}) == 6

def test_explicit_input_from_shared(self):
x = pt.dscalar("x")
y = shared(1.0, name="y")

with pytest.raises(
ValueError,
match=r"The inner-graph implicitly depends on the following shared variables \[y\]",
):
OpFromGraph([x], [x + y], strict=True)

test_ofg = OpFromGraph([x, y], [x + y], strict=True)

out = test_ofg(x, y)
assert out.eval({x: 5}) == 6
y.set_value(2.0)
assert out.eval({x: 6}) == 8


@config.change_flags(floatX="float64")
def test_debugprint():
