From d3437eae162dc48326c1e4017eaeb9ea7de84179 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sat, 29 Apr 2023 13:44:35 +0530 Subject: [PATCH] CHange --- pymc/logprob/checks.py | 157 ----------------------------------------- 1 file changed, 157 deletions(-) delete mode 100644 pymc/logprob/checks.py diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py deleted file mode 100644 index d72130cc85..0000000000 --- a/pymc/logprob/checks.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2023 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# MIT License -# -# Copyright (c) 2021-2022 aesara-devs -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -from typing import List, Optional - -import pytensor.tensor as pt - -from pytensor.graph.rewriting.basic import node_rewriter -from pytensor.raise_op import CheckAndRaise, ExceptionType, assert_op -from pytensor.tensor.shape import SpecifyShape - -from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper -from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db -from pymc.logprob.utils import ignore_logprob - - -class MeasurableSpecifyShape(SpecifyShape): - """A placeholder used to specify a log-likelihood for a specify-shape sub-graph.""" - - -MeasurableVariable.register(MeasurableSpecifyShape) - - -@_logprob.register(MeasurableSpecifyShape) -def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs): - (value,) = values - # transfer specify_shape from rv to value - value = pt.specify_shape(value, shapes) - return _logprob_helper(inner_rv, value) - - -@node_rewriter([SpecifyShape]) -def find_measurable_specify_shapes(fgraph, node) -> Optional[List[MeasurableSpecifyShape]]: - r"""Finds `SpecifyShapeOp`\s for which a `logprob` can be computed.""" - - if isinstance(node.op, MeasurableSpecifyShape): - return None # pragma: no cover - - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - - rv = node.outputs[0] - - base_rv, *shape = node.inputs - - if not ( - base_rv.owner - and isinstance(base_rv.owner.op, MeasurableVariable) - and base_rv not in rv_map_feature.rv_values - ): - return None # pragma: no cover - - new_op = MeasurableSpecifyShape() - # Make base_var unmeasurable - unmeasurable_base_rv = ignore_logprob(base_rv) - new_rv = new_op.make_node(unmeasurable_base_rv, *shape).default_output() - new_rv.name = rv.name - - return [new_rv] - - -measurable_ir_rewrites_db.register( - "find_measurable_specify_shapes", - find_measurable_specify_shapes, - "basic", - "specify_shape", -) - - -class MeasurableCheckAndRaise(CheckAndRaise): - """A placeholder used to specify a log-likelihood for an assert sub-graph.""" - - -MeasurableVariable.register(MeasurableCheckAndRaise) - - -@_logprob.register(MeasurableCheckAndRaise) -def logprob_assert(op, values, inner_rv, *assertion, **kwargs): - (value,) = values - # transfer assertion from rv to value - value = op(assertion, value) - return _logprob_helper(inner_rv, value) - - -@node_rewriter([CheckAndRaise]) -def find_measurable_asserts(fgraph, node) -> Optional[List[MeasurableCheckAndRaise]]: - r"""Finds `AssertOp`\s for which a `logprob` can be computed.""" - - if isinstance(node.op, MeasurableCheckAndRaise): - return None # pragma: no cover - - rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) - - if rv_map_feature is None: - return None # pragma: no cover - - rv = node.outputs[0] - - base_rv, *conds = node.inputs - - if not ( - base_rv.owner - and isinstance(base_rv.owner.op, MeasurableVariable) - and base_rv not in rv_map_feature.rv_values - ): - return None # pragma: no cover - - new_op = MeasurableCheckAndRaise(exc_type = node.op.exc_type) - # Make base_var unmeasurable - unmeasurable_base_rv = ignore_logprob(base_rv) - new_rv = new_op.make_node(unmeasurable_base_rv, *conds).default_output() - new_rv.name = rv.name - - return [new_rv] - - -measurable_ir_rewrites_db.register( - "find_measurable_asserts", - find_measurable_asserts, - "basic", - "assert", -)