Skip to content

Commit

Permalink
Derive logprob of less and greater than comparisons (#6662)
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyas3156 committed Apr 19, 2023
1 parent f2bb88b commit 9b712bf
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 14 deletions.
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ jobs:
tests/distributions/test_truncated.py
tests/logprob/test_abstract.py
tests/logprob/test_basic.py
tests/logprob/test_binary.py
tests/logprob/test_censoring.py
tests/logprob/test_composite_logprob.py
tests/logprob/test_cumsum.py
Expand Down
1 change: 1 addition & 0 deletions pymc/logprob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

# isort: off
# Add rewrites to the DBs
import pymc.logprob.binary
import pymc.logprob.censoring
import pymc.logprob.cumsum
import pymc.logprob.checks
Expand Down
111 changes: 111 additions & 0 deletions pymc/logprob/binary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional

import numpy as np
import pytensor.tensor as pt

from pytensor.graph.basic import Node
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import GT, LT
from pytensor.tensor.math import gt, lt

from pymc.logprob.abstract import (
MeasurableElemwise,
MeasurableVariable,
_logcdf_helper,
_logprob,
_logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import check_potential_measurability, ignore_logprob


class MeasurableComparison(MeasurableElemwise):
"""A placeholder used to specify a log-likelihood for a binary comparison RV sub-graph."""

valid_scalar_types = (GT, LT)


@node_rewriter(tracks=[gt, lt])
def find_measurable_comparisons(
fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableComparison]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, MeasurableComparison):
return None # pragma: no cover

(compared_var,) = node.outputs
base_var, const = node.inputs

if not (
base_var.owner
and isinstance(base_var.owner.op, MeasurableVariable)
and base_var not in rv_map_feature.rv_values
):
return None

# check for potential measurability of const
if not check_potential_measurability((const,), rv_map_feature):
return None

# Make base_var unmeasurable
unmeasurable_base_var = ignore_logprob(base_var)

compared_op = MeasurableComparison(node.op.scalar_op)
compared_rv = compared_op.make_node(unmeasurable_base_var, const).default_output()
compared_rv.name = compared_var.name
return [compared_rv]


measurable_ir_rewrites_db.register(
"find_measurable_comparisons",
find_measurable_comparisons,
"basic",
"comparison",
)


@_logprob.register(MeasurableComparison)
def comparison_logprob(op, values, base_rv, operand, **kwargs):
(value,) = values

base_rv_op = base_rv.owner.op

logcdf = _logcdf_helper(base_rv, operand, **kwargs)
logccdf = pt.log1mexp(logcdf)

condn_exp = pt.eq(value, np.array(True))

if isinstance(op.scalar_op, GT):
logprob = pt.switch(condn_exp, logccdf, logcdf)
elif isinstance(op.scalar_op, LT):
if base_rv.dtype.startswith("int"):
logpmf = _logprob_helper(base_rv, operand, **kwargs)
logcdf_lt_true = _logcdf_helper(base_rv, operand - 1, **kwargs)
logprob = pt.switch(condn_exp, logcdf_lt_true, pt.logaddexp(logccdf, logpmf))
else:
logprob = pt.switch(condn_exp, logcdf, logccdf)
else:
raise TypeError(f"Unsupported scalar_op {op.scalar_op}")

if base_rv_op.name:
logprob.name = f"{base_rv_op}_logprob"
logcdf.name = f"{base_rv_op}_logcdf"

return logprob
17 changes: 3 additions & 14 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
_logprob_helper,
)
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import ignore_logprob, walk_model
from pymc.logprob.utils import check_potential_measurability, ignore_logprob


class TransformedVariable(Op):
Expand Down Expand Up @@ -573,19 +573,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
# Check that other inputs are not potentially measurable, in which case this rewrite
# would be invalid
other_inputs = tuple(inp for inp in node.inputs if inp is not measurable_input)
if any(
ancestor_node
for ancestor_node in walk_model(
other_inputs,
walk_past_rvs=False,
stop_at_vars=set(rv_map_feature.rv_values),
)
if (
ancestor_node.owner
and isinstance(ancestor_node.owner.op, MeasurableVariable)
and ancestor_node not in rv_map_feature.rv_values
)
):

if not check_potential_measurability(other_inputs, rv_map_feature):
return None

# Make base_measure outputs unmeasurable
Expand Down
18 changes: 18 additions & 0 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,24 @@ def indices_from_subtensor(idx_list, indices):
)


def check_potential_measurability(inputs: Tuple[TensorVariable], rv_map_feature):
if any(
ancestor_node
for ancestor_node in walk_model(
inputs,
walk_past_rvs=False,
stop_at_vars=set(rv_map_feature.rv_values),
)
if (
ancestor_node.owner
and isinstance(ancestor_node.owner.op, MeasurableVariable)
and ancestor_node not in rv_map_feature.rv_values
)
):
return None
return True


class ParameterValueError(ValueError):
"""Exception for invalid parameters values in logprob graphs"""

Expand Down
1 change: 1 addition & 0 deletions scripts/run_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
pymc/distributions/timeseries.py
pymc/distributions/truncated.py
pymc/initial_point.py
pymc/logprob/binary.py
pymc/logprob/censoring.py
pymc/logprob/basic.py
pymc/logprob/mixture.py
Expand Down
102 changes: 102 additions & 0 deletions tests/logprob/test_binary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2023 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytensor
import pytensor.tensor as pt
import pytest
import scipy.stats as st

from pytensor import function

from pymc import logp
from pymc.logprob import factorized_joint_logprob
from pymc.testing import assert_no_rvs


@pytest.mark.parametrize(
"comparison_op, exp_logp_true, exp_logp_false",
[
(pt.lt, st.norm(0, 1).logcdf, st.norm(0, 1).logsf),
(pt.gt, st.norm(0, 1).logsf, st.norm(0, 1).logcdf),
],
)
def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
x_rv = pt.random.normal(0, 1)
comp_x_rv = comparison_op(x_rv, 0.5)

comp_x_vv = comp_x_rv.clone()

logprob = logp(comp_x_rv, comp_x_vv)
assert_no_rvs(logprob)

logp_fn = pytensor.function([comp_x_vv], logprob)

assert np.isclose(logp_fn(0), exp_logp_false(0.5))
assert np.isclose(logp_fn(1), exp_logp_true(0.5))


@pytest.mark.parametrize(
"comparison_op, exp_logp_true, exp_logp_false",
[
(
pt.lt,
lambda x: st.poisson(2).logcdf(x - 1),
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
),
(
pt.gt,
st.poisson(2).logsf,
st.poisson(2).logcdf,
),
],
)
def test_discrete_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
x_rv = pt.random.poisson(2)
cens_x_rv = comparison_op(x_rv, 3)

cens_x_vv = cens_x_rv.clone()

logprob = logp(cens_x_rv, cens_x_vv)
assert_no_rvs(logprob)

logp_fn = pytensor.function([cens_x_vv], logprob)

assert np.isclose(logp_fn(1), exp_logp_true(3))
assert np.isclose(logp_fn(0), exp_logp_false(3))


def test_potentially_measurable_operand():
x_rv = pt.random.normal(2)
z_rv = pt.random.normal(x_rv)
y_rv = pt.lt(x_rv, z_rv)

y_vv = y_rv.clone()
z_vv = z_rv.clone()

logprob = factorized_joint_logprob({z_rv: z_vv, y_rv: y_vv})[y_vv]
assert_no_rvs(logprob)

fn = function([z_vv, y_vv], logprob)
z_vv_test = 0.5
y_vv_test = True
np.testing.assert_array_almost_equal(
fn(z_vv_test, y_vv_test),
st.norm(2, 1).logcdf(z_vv_test),
)

with pytest.raises(
NotImplementedError,
match="Logprob method not implemented",
):
logp(y_rv, y_vv).eval({y_vv: y_vv_test})

0 comments on commit 9b712bf

Please sign in to comment.