diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index d0bd850ca8..1a1d7a50fc 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -104,6 +104,7 @@ jobs:
             tests/distributions/test_truncated.py
             tests/logprob/test_abstract.py
             tests/logprob/test_basic.py
+            tests/logprob/test_binary.py
             tests/logprob/test_censoring.py
             tests/logprob/test_composite_logprob.py
             tests/logprob/test_cumsum.py
diff --git a/pymc/logprob/__init__.py b/pymc/logprob/__init__.py
index f6ae51408c..7992efac8b 100644
--- a/pymc/logprob/__init__.py
+++ b/pymc/logprob/__init__.py
@@ -38,6 +38,7 @@
 
 # isort: off
 # Add rewrites to the DBs
+import pymc.logprob.binary
 import pymc.logprob.censoring
 import pymc.logprob.cumsum
 import pymc.logprob.checks
diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py
new file mode 100644
index 0000000000..a35673b454
--- /dev/null
+++ b/pymc/logprob/binary.py
@@ -0,0 +1,134 @@
+# Copyright 2023 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional
+
+import numpy as np
+import pytensor.tensor as pt
+
+from pytensor.graph.basic import Node
+from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.rewriting.basic import node_rewriter
+from pytensor.scalar.basic import GT, LT
+from pytensor.tensor.math import gt, lt
+
+from pymc.logprob.abstract import (
+    MeasurableElemwise,
+    MeasurableVariable,
+    _logcdf_helper,
+    _logprob,
+    _logprob_helper,
+)
+from pymc.logprob.rewriting import measurable_ir_rewrites_db
+from pymc.logprob.utils import check_potential_measurability, ignore_logprob
+
+
+class MeasurableComparison(MeasurableElemwise):
+    """A placeholder used to specify a log-likelihood for a binary comparison RV sub-graph."""
+
+    valid_scalar_types = (GT, LT)
+
+
+@node_rewriter(tracks=[gt, lt])
+def find_measurable_comparisons(
+    fgraph: FunctionGraph, node: Node
+) -> Optional[List[MeasurableComparison]]:
+    """Replace a ``gt``/``lt`` node on a measurable variable by a `MeasurableComparison`.
+
+    The rewrite only fires when the first input is the output of a
+    `MeasurableVariable` op that is not in ``rv_map_feature.rv_values`` and
+    `check_potential_measurability` accepts the second input. A measurable
+    variable passed as the second input (e.g. ``pt.lt(const, rv)``) is not handled.
+
+    Returns ``None`` when the node is left unchanged.
+    """
+    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
+    if rv_map_feature is None:
+        return None  # pragma: no cover
+
+    if isinstance(node.op, MeasurableComparison):
+        return None  # pragma: no cover
+
+    (compared_var,) = node.outputs
+    base_var, const = node.inputs
+
+    if not (
+        base_var.owner
+        and isinstance(base_var.owner.op, MeasurableVariable)
+        and base_var not in rv_map_feature.rv_values
+    ):
+        return None
+
+    # Skip the rewrite when `const` is itself potentially measurable
+    if not check_potential_measurability((const,), rv_map_feature):
+        return None
+
+    # Make base_var unmeasurable
+    unmeasurable_base_var = ignore_logprob(base_var)
+
+    compared_op = MeasurableComparison(node.op.scalar_op)
+    compared_rv = compared_op.make_node(unmeasurable_base_var, const).default_output()
+    compared_rv.name = compared_var.name
+    return [compared_rv]
+
+
+measurable_ir_rewrites_db.register(
+    "find_measurable_comparisons",
+    find_measurable_comparisons,
+    "basic",
+    "comparison",
+)
+
+
+@_logprob.register(MeasurableComparison)
+def comparison_logprob(op, values, base_rv, operand, **kwargs):
+    """Log-probability of the boolean result of comparing ``base_rv`` with ``operand``.
+
+    For ``GT`` a true value has probability ``P(base_rv > operand)``, computed
+    as ``log1mexp`` of the log-CDF; for ``LT`` it is ``P(base_rv < operand)``.
+    When ``base_rv`` has an integer dtype, ``P(base_rv == operand)`` can be
+    non-zero, so the ``LT`` case uses the CDF at ``operand - 1`` for true
+    values and adds the PMF back to the CCDF for false values.
+
+    Raises
+    ------
+    TypeError
+        If ``op.scalar_op`` is neither ``GT`` nor ``LT``.
+    """
+    (value,) = values
+
+    base_rv_op = base_rv.owner.op
+
+    logcdf = _logcdf_helper(base_rv, operand, **kwargs)
+    logccdf = pt.log1mexp(logcdf)
+
+    condn_exp = pt.eq(value, np.array(True))
+
+    if isinstance(op.scalar_op, GT):
+        logprob = pt.switch(condn_exp, logccdf, logcdf)
+    elif isinstance(op.scalar_op, LT):
+        # Unsigned dtypes ("uint8", ...) are discrete too and need the same correction
+        if base_rv.dtype.startswith(("int", "uint")):
+            logpmf = _logprob_helper(base_rv, operand, **kwargs)
+            logcdf_lt_true = _logcdf_helper(base_rv, operand - 1, **kwargs)
+            logprob = pt.switch(condn_exp, logcdf_lt_true, pt.logaddexp(logccdf, logpmf))
+        else:
+            logprob = pt.switch(condn_exp, logcdf, logccdf)
+    else:
+        raise TypeError(f"Unsupported scalar_op {op.scalar_op}")
+
+    if base_rv_op.name:
+        logprob.name = f"{base_rv_op}_logprob"
+        logcdf.name = f"{base_rv_op}_logcdf"
+
+    return logprob
diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index 008be0b731..13532c1d06 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -85,7 +85,7 @@
     _logprob_helper,
 )
 from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
-from pymc.logprob.utils import ignore_logprob, walk_model
+from pymc.logprob.utils import check_potential_measurability, ignore_logprob
 
 
 class TransformedVariable(Op):
@@ -573,19 +573,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
     # Check that other inputs are not potentially measurable, in which case this rewrite
     # would be invalid
     other_inputs = tuple(inp for inp in node.inputs if inp is not measurable_input)
-    if any(
-        ancestor_node
-        for ancestor_node in walk_model(
-            other_inputs,
-            walk_past_rvs=False,
-            stop_at_vars=set(rv_map_feature.rv_values),
-        )
-        if (
-            ancestor_node.owner
-            and isinstance(ancestor_node.owner.op, MeasurableVariable)
-            and ancestor_node not in rv_map_feature.rv_values
-        )
-    ):
+
+    if not check_potential_measurability(other_inputs, rv_map_feature):
         return None
 
     # Make base_measure outputs unmeasurable
diff --git
a/pymc/logprob/utils.py b/pymc/logprob/utils.py
index 18f9b803e7..c44e88a500 100644
--- a/pymc/logprob/utils.py
+++ b/pymc/logprob/utils.py
@@ -210,6 +210,33 @@ def indices_from_subtensor(idx_list, indices):
     )
 
 
+def check_potential_measurability(inputs: Tuple[TensorVariable, ...], rv_map_feature) -> bool:
+    """Check that ``inputs`` have no ancestor that is measurable but not valued.
+
+    The ancestors of ``inputs`` are visited with `walk_model`, using
+    ``walk_past_rvs=False`` and stopping at ``rv_map_feature.rv_values``.
+
+    Returns
+    -------
+    bool
+        ``False`` if a visited variable is the output of a `MeasurableVariable`
+        op and is not in ``rv_map_feature.rv_values``, ``True`` otherwise.
+        Callers use a falsy result to skip their rewrite.
+    """
+    # Test the condition itself rather than the truthiness of the variable
+    has_unvalued_measurable = any(
+        ancestor_node.owner
+        and isinstance(ancestor_node.owner.op, MeasurableVariable)
+        and ancestor_node not in rv_map_feature.rv_values
+        for ancestor_node in walk_model(
+            inputs,
+            walk_past_rvs=False,
+            stop_at_vars=set(rv_map_feature.rv_values),
+        )
+    )
+    return not has_unvalued_measurable
+
+
 class ParameterValueError(ValueError):
     """Exception for invalid parameters values in logprob graphs"""
 
diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py
index 2ad9c8a6f9..72f7013007 100644
--- a/scripts/run_mypy.py
+++ b/scripts/run_mypy.py
@@ -29,6 +29,7 @@
 pymc/distributions/timeseries.py
 pymc/distributions/truncated.py
 pymc/initial_point.py
+pymc/logprob/binary.py
 pymc/logprob/censoring.py
 pymc/logprob/basic.py
 pymc/logprob/mixture.py
diff --git a/tests/logprob/test_binary.py b/tests/logprob/test_binary.py
new file mode 100644
index 0000000000..c2dc692659
--- /dev/null
+++ b/tests/logprob/test_binary.py
@@ -0,0 +1,106 @@
+# Copyright 2023 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytensor
+import pytensor.tensor as pt
+import pytest
+import scipy.stats as st
+
+from pytensor import function
+
+from pymc import logp
+from pymc.logprob import factorized_joint_logprob
+from pymc.testing import assert_no_rvs
+
+
+@pytest.mark.parametrize(
+    "comparison_op, exp_logp_true, exp_logp_false",
+    [
+        (pt.lt, st.norm(0, 1).logcdf, st.norm(0, 1).logsf),
+        (pt.gt, st.norm(0, 1).logsf, st.norm(0, 1).logcdf),
+    ],
+)
+def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
+    """Comparing a continuous RV with a constant gives the log-CDF/log-SF at it."""
+    x_rv = pt.random.normal(0, 1)
+    comp_x_rv = comparison_op(x_rv, 0.5)
+
+    comp_x_vv = comp_x_rv.clone()
+
+    logprob = logp(comp_x_rv, comp_x_vv)
+    assert_no_rvs(logprob)
+
+    logp_fn = pytensor.function([comp_x_vv], logprob)
+
+    assert np.isclose(logp_fn(0), exp_logp_false(0.5))
+    assert np.isclose(logp_fn(1), exp_logp_true(0.5))
+
+
+@pytest.mark.parametrize(
+    "comparison_op, exp_logp_true, exp_logp_false",
+    [
+        (
+            pt.lt,
+            lambda x: st.poisson(2).logcdf(x - 1),
+            lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
+        ),
+        (
+            pt.gt,
+            st.poisson(2).logsf,
+            st.poisson(2).logcdf,
+        ),
+    ],
+)
+def test_discrete_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
+    """For a discrete RV, ``<`` must exclude the probability mass at the threshold."""
+    x_rv = pt.random.poisson(2)
+    comp_x_rv = comparison_op(x_rv, 3)
+
+    comp_x_vv = comp_x_rv.clone()
+
+    logprob = logp(comp_x_rv, comp_x_vv)
+    assert_no_rvs(logprob)
+
+    logp_fn = pytensor.function([comp_x_vv], logprob)
+
+    assert np.isclose(logp_fn(1), exp_logp_true(3))
+    assert np.isclose(logp_fn(0), exp_logp_false(3))
+
+
+def test_potentially_measurable_operand():
+    """The operand may be another RV as long as that RV is given a value."""
+    x_rv = pt.random.normal(2)
+    z_rv = pt.random.normal(x_rv)
+    y_rv = pt.lt(x_rv, z_rv)
+
+    y_vv = y_rv.clone()
+    z_vv = z_rv.clone()
+
+    logprob = factorized_joint_logprob({z_rv: z_vv, y_rv: y_vv})[y_vv]
+    assert_no_rvs(logprob)
+
+    fn = function([z_vv, y_vv], logprob)
+    z_vv_test = 0.5
+    y_vv_test = True
+    np.testing.assert_array_almost_equal(
+        fn(z_vv_test, y_vv_test),
+        st.norm(2, 1).logcdf(z_vv_test),
+    )
+
+    # Here z_rv gets no value, so no logprob can be derived for the comparison
+    with pytest.raises(
+        NotImplementedError,
+        match="Logprob method not implemented",
+    ):
+        logp(y_rv, y_vv).eval({y_vv: y_vv_test})