Equality of Ops with InnerGraph #1606

@ricardoV94

Description

Scan and Composite/ScalarLoop are our only Ops with inner graphs that support equality / merging. Equality / hashing is currently done by using the generated c_code (or asking what the c_module hash key would be) and by traversing the graph with equal_computations.

This approach is not great: it conflates C-code generation with regular Op matching, and it is slow and not extensible to non-C Ops in the inner graph. It would also be good to have a means of comparing fgraph equality that doesn't require traversing the graph in the worst-case scenario where two Ops are indeed identical (which happens in the case of Scan).

One alternative (but we may want to think of others) is to use hash-consing for inner graphs: if you ever create the same graph twice you get back the same variable, so equality / hashing can be done simply in terms of identity. If two fgraphs have the same output variables they are identical, otherwise they are not.
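As a toy illustration of the idea (not PyTensor code, just a sketch of the interning behaviour we'd want):

```python
# Toy hash-consing: building the same expression twice returns the same object,
# so whole-graph equality reduces to an identity check.
_cons_cache: dict = {}


def cons(op, *inputs):
    # Children are already interned, so their identities form a valid cache key
    # and each node is interned in O(1) at construction time.
    key = (op, tuple(id(inp) for inp in inputs))
    return _cons_cache.setdefault(key, (op, inputs))


x = cons("x")
g1 = cons("add", cons("exp", x), cons(1.0))
g2 = cons("add", cons("exp", x), cons(1.0))
assert g1 is g2  # identical graphs are literally the same object
```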

This shouldn't be more expensive for us, as we already clone the graphs every time we create an inner-graph Op. For Composite we actually do it 3 times over the regular lifetime of the Op.

Once at init here:

inputs, outputs = clone(inputs, outputs)

Then once more to run MergeOptimizer in _cleanup_graph:

self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)

def _cleanup_graph(self, inputs, outputs):
    # TODO: We could convert to TensorVariable, optimize graph,
    # and then convert back to ScalarVariable.
    # This would introduce rewrites like `log(1 + x) -> log1p`.
    fgraph = FunctionGraph(copy(inputs), copy(outputs))
    # Validate node types
    for node in fgraph.apply_nodes:
        if not isinstance(node.op, ScalarOp):
            raise TypeError(
                f"The fgraph of {self.__class__.__name__} must be exclusively "
                "composed of scalar operations."
            )
    # Run MergeOptimization to avoid duplicated nodes
    MergeOptimizer().rewrite(fgraph)
    inputs, outputs = fgraph.inputs, fgraph.outputs

And finally we clone once more whenever we hash / compare equality and have to create the c_code_template, since the fgraph property ALWAYS clones the inner variables:

@property
def fgraph(self):
    if hasattr(self, "_fgraph"):
        return self._fgraph
    # fgraph cannot be a property of the base class because it messes up with C caching.
    # We also need a `FunctionGraph(clone=True)` (default) according to an old comment
    fgraph = FunctionGraph(self.inputs, self.outputs)
    self._fgraph = fgraph
    return self._fgraph

@property
def c_code_template(self):
    from pytensor.link.c.interface import CLinkerType

    if hasattr(self, "_c_code"):
        return self._c_code
    fg = self.fgraph

def __eq__(self, other):
    if self is other:
        return True
    if (
        type(self) is not type(other)
        or self.nin != other.nin
        or self.nout != other.nout
    ):
        return False
    # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
    # object to generate the same `_c_code`?
    return self.c_code_template == other.c_code_template


def __hash__(self):
    # Note that in general, the configparser settings at the time
    # of code generation (__init__) affect the semantics of this Op.
    # This function assumes that all relevant info about the configparser
    # is embodied in _c_code. So the _c_code, rather than self.fgraph,
    # is the signature of the semantics of this Op.
    # _c_code is preserved through unpickling, so the Op will not change
    # semantics when it is reloaded with different configparser
    # settings.
    #
    # TODO FIXME: Doesn't the above just mean that we should be including
    # the relevant "configparser settings" here? Also, why should we even
    # care about the exact form of the generated C code when comparing
    # `Op`s? All this smells of leaky concerns and interfaces.
    return hash((type(self), self.nin, self.nout, self.c_code_template))

Creating the fgraph once with hash-consing wouldn't be any more expensive. It would automatically achieve what MergeOptimizer does (duplicated nodes are never created, by definition), and would allow O(1) hashing / equality that is dissociated from the C backend (plus the cost of creating the fgraph, but as mentioned above, that cost is already incurred anyway).
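For illustration, with hash-consed inner outputs the equality / hashing of an inner-graph Op could be as simple as the following sketch (method bodies only; this is not the current implementation):

```python
# Sketch: assumes self.fgraph.outputs are hash-consed, i.e. structurally
# identical inner graphs are represented by the very same variable objects.
def __eq__(self, other):
    if self is other:
        return True
    return (
        type(self) is type(other)
        and self.nin == other.nin
        and self.nout == other.nout
        # Identity comparison of interned outputs: no graph traversal and
        # no C code generation involved.
        and all(s is o for s, o in zip(self.fgraph.outputs, other.fgraph.outputs))
    )


def __hash__(self):
    # The id() of an interned variable is stable for as long as it is alive, so
    # it can stand in for a structural hash of the inner graph (a real
    # implementation would need to re-intern after unpickling).
    return hash((type(self), self.nin, self.nout, *map(id, self.fgraph.outputs)))
```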

This functionality would probably be useful elsewhere; for instance, it could be used for pymc-devs/pymc-extras#277 with some tweaks.

It would address #1114 and fix #1601

The hash-consed fgraphs should be immutable, which is easy to achieve by using Apply nodes that have tuples as inputs/outputs instead of lists. The inner graph of an Op should not be mutable anyway, since Ops are supposed to be frozen and deterministic. This is the issue reported in #1601.
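For example, a minimal sketch of what an immutable Apply could look like (illustrative, not an existing class):

```python
# Sketch of an "immutable Apply": inputs/outputs stored as tuples so the inner
# graph of a frozen Op cannot be mutated in place by rewrites.
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True, eq=False)  # eq=False keeps identity-based __eq__/__hash__
class FrozenApply:
    op: Any
    inputs: tuple
    outputs: tuple

    def __post_init__(self):
        # Coerce to tuples so there is no .append or item assignment to abuse.
        object.__setattr__(self, "inputs", tuple(self.inputs))
        object.__setattr__(self, "outputs", tuple(self.outputs))
```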

If needed, during compilation there could be a specific rewrite that clones the fgraph into a mutable one, rewrites it, and creates a new version of the Op with the new FunctionGraph. Right now this is awkwardly done/cached at make_thunk time, or in the new backends at dispatch time.
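Such a rewrite could look roughly like this (a sketch only; here MergeOptimizer stands in for whatever inner-graph rewrites we would actually want to run, and the Composite constructor call is illustrative rather than a concrete API proposal):

```python
# Sketch: thaw the immutable inner graph into a mutable clone, rewrite it, and
# replace the node with a new Op built from the rewritten graph.
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import MergeOptimizer, node_rewriter
from pytensor.scalar.basic import Composite


@node_rewriter([Composite])
def rewrite_inner_graph(fgraph, node):
    op = node.op
    # Mutable clone that inner-graph rewrites are allowed to modify
    inner_fgraph = FunctionGraph(op.inputs, op.outputs, clone=True)
    # Stand-in for the inner-graph rewrites to apply at compile time
    MergeOptimizer().rewrite(inner_fgraph)
    new_op = Composite(inner_fgraph.inputs, inner_fgraph.outputs)
    return new_op.make_node(*node.inputs).outputs
```

It would then be registered in the compilation pipeline like any other rewrite, instead of being triggered implicitly at make_thunk / dispatch time.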

Making this an explicit rewrite would also allow users to control compilation of inner graphs without the awkward mode in Scan or the kwargs in OpFromGraph.


Again, this may not be the best solution, but hash-consing of Apply nodes / Constants can be implemented with weakref dictionaries, plus using nominal variables for the inner inputs (which are already hash-consed). There was exploratory work in Aesara: aesara-devs/aesara#1165.

(Note: I think that Aesara PR takes the idea too far; having mutable graphs seems rather reasonable during rewrites, and there's no obvious proposal for how to avoid too much copying with immutable nodes.)
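For a rough idea of the mechanics, the interning cache could be shaped like this (illustrative only; a real implementation would also need to intern Constants and be careful about cache keys outliving the weakly-referenced values):

```python
# Sketch of hash-consing Apply nodes with a weak-value cache, so unused graphs
# can still be garbage collected. Inner inputs would be NominalVariables, which
# are already interned, so input identity is a meaningful cache key.
import weakref

_apply_cache = weakref.WeakValueDictionary()


def cons_apply(op, inputs):
    key = (op, tuple(id(inp) for inp in inputs))
    try:
        return _apply_cache[key]
    except KeyError:
        node = op.make_node(*inputs)
        _apply_cache[key] = node
        return node
```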
