From 26d06643f3493ab1ec12f1c37fff2c98defca555 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sat, 29 Apr 2023 13:45:51 +0530 Subject: [PATCH] Rectification in MeasurableCheckAndRaise --- pymc/logprob/checks.py | 157 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 pymc/logprob/checks.py diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py new file mode 100644 index 0000000000..90c6a44b8e --- /dev/null +++ b/pymc/logprob/checks.py @@ -0,0 +1,157 @@ +# Copyright 2023 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License +# +# Copyright (c) 2021-2022 aesara-devs +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from typing import List, Optional + +import pytensor.tensor as pt + +from pytensor.graph.rewriting.basic import node_rewriter +from pytensor.raise_op import CheckAndRaise +from pytensor.tensor.shape import SpecifyShape + +from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper +from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db +from pymc.logprob.utils import ignore_logprob + + +class MeasurableSpecifyShape(SpecifyShape): + """A placeholder used to specify a log-likelihood for a specify-shape sub-graph.""" + + +MeasurableVariable.register(MeasurableSpecifyShape) + + +@_logprob.register(MeasurableSpecifyShape) +def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs): + (value,) = values + # transfer specify_shape from rv to value + value = pt.specify_shape(value, shapes) + return _logprob_helper(inner_rv, value) + + +@node_rewriter([SpecifyShape]) +def find_measurable_specify_shapes(fgraph, node) -> Optional[List[MeasurableSpecifyShape]]: + r"""Finds `SpecifyShapeOp`\s for which a `logprob` can be computed.""" + + if isinstance(node.op, MeasurableSpecifyShape): + return None # pragma: no cover + + rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) + + if rv_map_feature is None: + return None # pragma: no cover + + rv = node.outputs[0] + + base_rv, *shape = node.inputs + + if not ( + base_rv.owner + and isinstance(base_rv.owner.op, MeasurableVariable) + and base_rv not in rv_map_feature.rv_values + ): + return None # pragma: no cover + + new_op = MeasurableSpecifyShape() + # Make base_var unmeasurable + unmeasurable_base_rv = ignore_logprob(base_rv) + new_rv = new_op.make_node(unmeasurable_base_rv, *shape).default_output() + new_rv.name = rv.name + + return [new_rv] + + +measurable_ir_rewrites_db.register( + "find_measurable_specify_shapes", + find_measurable_specify_shapes, + "basic", + "specify_shape", +) + + +class MeasurableCheckAndRaise(CheckAndRaise): + """A placeholder used to specify a log-likelihood for an assert sub-graph.""" + + +MeasurableVariable.register(MeasurableCheckAndRaise) + + +@_logprob.register(MeasurableCheckAndRaise) +def logprob_assert(op, values, inner_rv, *assertion, **kwargs): + (value,) = values + # transfer assertion from rv to value + value = op(assertion, value) + return _logprob_helper(inner_rv, value) + + +@node_rewriter([CheckAndRaise]) +def find_measurable_asserts(fgraph, node) -> Optional[List[MeasurableCheckAndRaise]]: + r"""Finds `AssertOp`\s for which a `logprob` can be computed.""" + + if isinstance(node.op, MeasurableCheckAndRaise): + return None # pragma: no cover + + rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) + + if rv_map_feature is None: + return None # pragma: no cover + + rv = node.outputs[0] + + base_rv, *conds = node.inputs + + if not ( + base_rv.owner + and isinstance(base_rv.owner.op, MeasurableVariable) + and base_rv not in rv_map_feature.rv_values + ): + return None # pragma: no cover + + new_op = MeasurableCheckAndRaise(exc_type=node.op.exc_type) + # Make base_var unmeasurable + unmeasurable_base_rv = ignore_logprob(base_rv) + new_rv = new_op.make_node(unmeasurable_base_rv, *conds).default_output() + new_rv.name = rv.name + + return [new_rv] + + +measurable_ir_rewrites_db.register( + "find_measurable_asserts", + find_measurable_asserts, + "basic", + "assert", +)