Errors in checks and test_checks #6684

Closed
Dhruvanshu-Joshi opened this issue Apr 22, 2023 · 9 comments

Comments

@Dhruvanshu-Joshi
Contributor

Description

In PR #6599, the import of factorized_joint_logprob was changed from logprob/joint_logp to logprob/basic in most of the tests that use it. This looks like a simple one-line PR, but there is more to it: there are some errors in pymc/logprob/checks.py and tests/logprob/test_checks.py.

pymc/logprob/checks.py, lines 105 to 150 at commit a59c9cd:

class MeasurableCheckAndRaise(CheckAndRaise):
    """A placeholder used to specify a log-likelihood for an assert sub-graph."""


MeasurableVariable.register(MeasurableCheckAndRaise)


@_logprob.register(MeasurableCheckAndRaise)
def logprob_assert(op, values, inner_rv, *assertion, **kwargs):
    (value,) = values
    # transfer assertion from rv to value
    value = op(assertion, value)
    return _logprob_helper(inner_rv, value)


@node_rewriter([CheckAndRaise])
def find_measurable_asserts(fgraph, node) -> Optional[List[MeasurableCheckAndRaise]]:
    r"""Finds `AssertOp`\s for which a `logprob` can be computed."""

    if isinstance(node.op, MeasurableCheckAndRaise):
        return None  # pragma: no cover

    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

    if rv_map_feature is None:
        return None  # pragma: no cover

    rv = node.outputs[0]

    base_rv, *conds = node.inputs

    if not (
        base_rv.owner
        and isinstance(base_rv.owner.op, MeasurableVariable)
        and base_rv not in rv_map_feature.rv_values
    ):
        return None  # pragma: no cover

    exception_type = ExceptionType()
    new_op = MeasurableCheckAndRaise(exc_type=exception_type)
    # Make base_var unmeasurable
    unmeasurable_base_rv = ignore_logprob(base_rv)
    new_rv = new_op.make_node(unmeasurable_base_rv, *conds).default_output()
    new_rv.name = rv.name

    return [new_rv]

Here, an instance of the class ExceptionType, rather than an exception class, is passed as the exc_type parameter in the call to MeasurableCheckAndRaise on line 144.
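A minimal sketch of why the instance is rejected, assuming (as I read pytensor's raise_op.py) that the CheckAndRaise constructor validates exc_type with issubclass(exc_type, Exception). ToyExceptionType and ToyCheckAndRaise are hypothetical stand-ins, not the PyTensor classes; to my understanding, the real ExceptionType is a PyTensor graph type, not an exception class.

```python
class ToyExceptionType:
    """Hypothetical stand-in for ExceptionType: a plain class, not an Exception subclass."""


class ToyCheckAndRaise:
    """Hypothetical stand-in that mirrors the exc_type validation assumed above."""

    def __init__(self, exc_type, msg=""):
        # issubclass() only accepts classes, so passing an instance raises TypeError here
        if not issubclass(exc_type, Exception):
            raise ValueError("`exc_type` must be an Exception subclass")
        self.exc_type = exc_type
        self.msg = msg


def try_build(exc_type):
    """Return the name of the error raised while building the op, or "ok"."""
    try:
        ToyCheckAndRaise(exc_type=exc_type)
    except (TypeError, ValueError) as err:
        return type(err).__name__
    return "ok"


print(try_build(ToyExceptionType()))  # an instance: issubclass() itself fails
print(try_build(ToyExceptionType))    # a class, but not an Exception subclass
print(try_build(AssertionError))      # an exception class is accepted
```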

The problem can be solved by passing an exception class instead, for example a CustomException class modeled on how the PyTensor tests handle it:

class CustomException(ValueError):
    """A custom user-created exception to throw."""

class MeasurableCheckAndRaise(CheckAndRaise):
    """A placeholder used to specify a log-likelihood for an assert sub-graph."""


MeasurableVariable.register(MeasurableCheckAndRaise)


@_logprob.register(MeasurableCheckAndRaise)
def logprob_assert(op, values, inner_rv, *assertion, **kwargs):
    (value,) = values
    # transfer assertion from rv to value
    value = op(assertion, value)
    return _logprob_helper(inner_rv, value)


@node_rewriter([CheckAndRaise])
def find_measurable_asserts(fgraph, node) -> Optional[List[MeasurableCheckAndRaise]]:
    r"""Finds `AssertOp`\s for which a `logprob` can be computed."""

    if isinstance(node.op, MeasurableCheckAndRaise):
        return None  # pragma: no cover

    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

    if rv_map_feature is None:
        return None  # pragma: no cover

    rv = node.outputs[0]

    base_rv, *conds = node.inputs

    if not (
        base_rv.owner
        and isinstance(base_rv.owner.op, MeasurableVariable)
        and base_rv not in rv_map_feature.rv_values
    ):
        return None  # pragma: no cover

    exception_type = ExceptionType()
    new_op = MeasurableCheckAndRaise(exc_type=CustomException)
    # Make base_var unmeasurable
    unmeasurable_base_rv = ignore_logprob(base_rv)
    new_rv = new_op.make_node(unmeasurable_base_rv, *conds).default_output()
    new_rv.name = rv.name

    return [new_rv]

I can handle this and open a PR for it. However, I need help with the issue in tests/logprob/test_checks.py.

@welcome

welcome bot commented Apr 22, 2023

🎉 Welcome to PyMC! 🎉 We're really excited to have your input into the project! 💖

If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.

@Dhruvanshu-Joshi
Contributor Author

def test_assert_logprob():
    rv = pt.random.normal()
    assert_op = Assert("Test assert")
    # Example: Add assert that rv must be positive
    assert_rv = assert_op(rv > 0, rv)
    assert_rv.name = "assert_rv"

    assert_vv = assert_rv.clone()
    assert_logp = factorized_joint_logprob({assert_rv: assert_vv})[assert_vv]

    # Check valid value is correct and doesn't raise
    # Since here the value to the rv satisfies the condition, no error is raised.
    valid_value = 3.0
    with pytest.raises(AssertionError, match="Test assert"):
        assert_logp.eval({assert_vv: valid_value})

    # Check invalid value
    # Since here the value to the rv is negative, an exception is raised as the condition is not met
    with pytest.raises(AssertionError, match="Test assert"):
        assert_logp.eval({assert_vv: -5.0})

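Here, an obvious error is in the call to assert_op, where the boolean condition precedes rv; Assert expects the value first and the condition(s) after it. In addition, the check on the valid value is wrapped in pytest.raises, even though it should not raise.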

With the test corrected as follows, pytest reports that no AssertionError is raised for the negative value:

def test_assert_logprob():
    rv = pt.random.normal()
    assert_op = Assert("Test assert")
    # Example: Add assert that rv must be positive
    assert_rv = assert_op(rv, rv > 0)
    assert_rv.name = "assert_rv"

    assert_vv = assert_rv.clone()
    assert_logp = factorized_joint_logprob({assert_rv: assert_vv})[assert_vv]

    # Check valid value is correct and doesn't raise
    # Since here the value to the rv satisfies the condition, no error is raised.
    valid_value = 3.0
    assert_logp.eval({assert_vv: valid_value})

    # Check invalid value
    # Since here the value to the rv is negative, an exception is raised as the condition is not met
    with pytest.raises(AssertionError, match="Test assert"):
        assert_logp.eval({assert_vv: -5.0})

============================= test session starts ==============================
platform linux -- Python 3.10.9, pytest-7.3.1, pluggy-1.0.0 -- /home/deimos/pymc/pymc-dev/bin/python
cachedir: .pytest_cache
rootdir: /home/deimos/Desktop/pymc
configfile: pyproject.toml
plugins: cov-4.0.0
collected 2 items                                                              

tests/logprob/test_checks.py::test_specify_shape_logprob PASSED          [ 50%]
tests/logprob/test_checks.py::test_assert_logprob FAILED                 [100%]

=================================== FAILURES ===================================
_____________________________ test_assert_logprob ______________________________

    def test_assert_logprob():
        rv = pt.random.normal()
        assert_op = Assert("Test assert")
        # Example: Add assert that rv must be positive
        assert_rv = assert_op(rv, rv>0)
        assert_rv.name = "assert_rv"
    
        assert_vv = assert_rv.clone()
        assert_logp = factorized_joint_logprob({assert_rv: assert_vv})[assert_vv]
    
        # Check valid value is correct and doesn't raise
        # Since here the value to the rv satisfies the condition, no error is raised.
        valid_value = 3.0
        assert_logp.eval({assert_vv: valid_value})
    
        # Check invalid value
        # Since here the value to the rv is negative, an exception is raised as the condition is not met
>       with pytest.raises(AssertionError, match="Test assert"):
E       Failed: DID NOT RAISE <class 'AssertionError'>

tests/logprob/test_checks.py:97: Failed
=============================== warnings summary ===============================
../../pymc/pymc-dev/lib/python3.10/site-packages/pkg_resources/__init__.py:121
  /home/deimos/pymc/pymc-dev/lib/python3.10/site-packages/pkg_resources/__init__.py:121: DeprecationWarning: pkg_resources is deprecated as an API
    warnings.warn("pkg_resources is deprecated as an API", DeprecationWarning)

../../pymc/pymc-dev/lib/python3.10/site-packages/pkg_resources/__init__.py:2870
  /home/deimos/pymc/pymc-dev/lib/python3.10/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('mpl_toolkits')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
    declare_namespace(pkg)

../../pymc/pymc-dev/lib/python3.10/site-packages/pkg_resources/__init__.py:2870
../../pymc/pymc-dev/lib/python3.10/site-packages/pkg_resources/__init__.py:2870
../../pymc/pymc-dev/lib/python3.10/site-packages/pkg_resources/__init__.py:2870
../../pymc/pymc-dev/lib/python3.10/site-packages/pkg_resources/__init__.py:2870
  /home/deimos/pymc/pymc-dev/lib/python3.10/site-packages/pkg_resources/__init__.py:2870: DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('sphinxcontrib')`.
  Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
    declare_namespace(pkg)

tests/logprob/test_checks.py::test_assert_logprob
  /home/deimos/Desktop/pymc/pymc/logprob/basic.py:255: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: [normal_rv{0, (0, 0), floatX, False}.out].
  This variables is a clone and does not match the original one on identity.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED tests/logprob/test_checks.py::test_assert_logprob - Failed: DID NOT RAISE <class 'AssertionError'>
=================== 1 failed, 1 passed, 7 warnings in 2.17s ====================

@ricardoV94
Member

ricardoV94 commented Apr 25, 2023

If I understand correctly, there's a bug in MeasurableCheckAndRaise, but it is not due to #6599? It seems to have been present since the original #6538 PR, no?

Is there something you couldn't quite figure out about the solution yet? The Exception type should be the same as the one in the original AssertOp. It should be easy to retrieve it from the Op: https://github.com/pymc-devs/pytensor/blob/9fd3af71a9cf455f408dbb32afe4c175d9128432/pytensor/raise_op.py#L46

@Dhruvanshu-Joshi
Contributor Author

Yes, you are correct. #6599 does not update the import of factorized_joint_logprob here, which causes test_specify_shape_logprob to fail. In test_assert_logprob, no AssertionError is raised for negative values, which causes that test to fail. I will give it another look and get back to you.

@Dhruvanshu-Joshi
Contributor Author

I would like to suggest some changes to MeasurableCheckAndRaise that differ from #6538:

class CustomException(AssertionError):
    """A custom user-created exception to throw."""

class MeasurableCheckAndRaise(CheckAndRaise):
    """A placeholder used to specify a log-likelihood for an assert sub-graph."""


MeasurableVariable.register(MeasurableCheckAndRaise)


@_logprob.register(MeasurableCheckAndRaise)
def logprob_assert(op, values, inner_rv, *assertion, **kwargs):
    (value,) = values
    # transfer assertion from rv to value
    value = op(assertion, value)
    return _logprob_helper(inner_rv, value)


@node_rewriter([CheckAndRaise])
def find_measurable_asserts(fgraph, node) -> Optional[List[MeasurableCheckAndRaise]]:
    r"""Finds `AssertOp`\s for which a `logprob` can be computed."""

    if isinstance(node.op, MeasurableCheckAndRaise):
        return None  # pragma: no cover

    rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)

    if rv_map_feature is None:
        return None  # pragma: no cover

    rv = node.outputs[0]

    base_rv, *conds = node.inputs

    if not (
        base_rv.owner
        and isinstance(base_rv.owner.op, MeasurableVariable)
        and base_rv not in rv_map_feature.rv_values
    ):
        return None  # pragma: no cover

    exception_type = ExceptionType()
    new_op = MeasurableCheckAndRaise(exc_type=CustomException)
    # Make base_var unmeasurable
    unmeasurable_base_rv = ignore_logprob(base_rv)
    new_rv = new_op.make_node(unmeasurable_base_rv, *conds).default_output()
    new_rv.name = rv.name

    return [new_rv]


measurable_ir_rewrites_db.register(
    "find_measurable_asserts",
    find_measurable_asserts,
    "basic",
    "assert",
)

Even with these changes, test_assert_logprob in test_checks does not produce the expected AssertionError for negative values. However, the AssertionError is raised when 0 is passed instead of -5 here:

with pytest.raises(AssertionError, match="Test assert"):
    assert_logp.eval({assert_vv: -5.0})

My guess is that there is a problem with assert_op itself, but I have not been able to figure out exactly what it is.

@ricardoV94
Member

You don't need the custom exception; the measurable Op should use the same exception that is already present in the original Op. The ExceptionType() call that is still there doesn't make sense.

@Dhruvanshu-Joshi
Contributor Author

Yeah, the ExceptionType() call is redundant and does nothing. I tried replacing CustomException with node.op.exc_type, and that works. What concerns me is that, with assert_rv = assert_op(rv, rv > 0), assert_logp.eval({assert_vv: 0}) raises the expected AssertionError but assert_logp.eval({assert_vv: -5.0}) does not. Does this mean something is wrong with logprob_assert itself?
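One possible explanation, offered only as a hypothesis: logprob_assert calls op(assertion, value), while the Op is called as op(value, *conds). If so, the tuple of conditions lands in the value slot and value itself becomes the only condition, so any non-zero value passes and only 0 raises. The sketch uses a hypothetical pure-Python ToyCheckAndRaise (not PyTensor's class) with that call convention. fixed_call is likewise illustrative: in the real graph, the condition would also have to be re-expressed in terms of the value variable instead of the original rv.

```python
class ToyCheckAndRaise:
    """Hypothetical stand-in with the call convention op(value, *conds)."""

    def __init__(self, exc_type, msg=""):
        self.exc_type = exc_type
        self.msg = msg

    def __call__(self, value, *conds):
        # Raise unless every condition is truthy; otherwise return the value unchanged.
        if not all(conds):
            raise self.exc_type(self.msg)
        return value


op = ToyCheckAndRaise(AssertionError, "Test assert")
# Stand-in for the rv-based conditions; in the buggy call they are never checked.
assertion = (True,)


def buggy_call(value):
    # Mirrors `value = op(assertion, value)`: conditions go in the value slot,
    # and `value` becomes the only condition that is checked.
    return op(assertion, value)


def fixed_call(value):
    # Hypothetical corrected pattern: value first, condition expressed on the value.
    return op(value, value > 0)


def raises(fn, arg):
    """Return True if fn(arg) raises AssertionError."""
    try:
        fn(arg)
    except AssertionError:
        return True
    return False


print(raises(buggy_call, -5.0))  # False: -5.0 is truthy, so nothing fails
print(raises(buggy_call, 0.0))   # True: 0.0 is falsy
print(raises(fixed_call, -5.0))  # True: the condition value > 0 fails
print(raises(fixed_call, 3.0))   # False: the condition holds
```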

@ricardoV94
Copy link
Member

That's strange. Do you want to open a PR with the relevant tests / changes so that I can have a look?

@Dhruvanshu-Joshi
Contributor Author

yeah sure 👍
