
Derive logprob for hyperbolic and error transformations #6664

Merged: 15 commits merged into pymc-devs:main on Apr 28, 2023

Conversation

@LukeLB (Contributor) commented Apr 10, 2023

What is this PR about?
I have implemented additional Elemwise transformations as suggested in issue #6631. Specifically, this pull request adds cosh, sinh, tanh, erf, erfc, and erfcx functions. I plan to address the other suggested transformations in a separate pull request, as they require a more significant rewrite of existing functions. However, if it is preferred to include them all in one pull request, I'm happy to do so.

Please note that this is still a work in progress, and I have not yet written any tests for the new Transforms. I would appreciate some guidance on how to design these tests, as it's not clear to me what I should be testing them against.

Also, for the erfcx transform it would be great to double-check that my math is correct: for the backward transform I rewrote a MATLAB function, and for the log Jacobian determinant I used Wolfram Alpha to get the derivative of erfcx.
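For reference, here is a minimal sketch of what one of the simpler transforms can look like, following the pattern of the existing classes in pymc/logprob/transforms.py and the RVTransform interface discussed below (illustrative only, not necessarily the exact code added in this PR):

import pytensor.tensor as pt

from pymc.logprob.transforms import RVTransform


class TanhTransform(RVTransform):
    """Illustrative sketch of an elemwise tanh transform."""

    name = "tanh"

    def forward(self, value, *inputs):
        return pt.tanh(value)

    def backward(self, value, *inputs):
        return pt.arctanh(value)

    def log_jac_det(self, value, *inputs):
        # d/dy arctanh(y) = 1 / (1 - y**2), so the log|det| of the
        # Jacobian of the backward transform is -log(1 - y**2)
        return -pt.log1p(-value**2)

The other transforms follow the same shape; erfcx is the odd one out because its inverse has no closed form, hence the MATLAB-derived backward mentioned above.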
...

Checklist

Major / Breaking Changes

  • New elemwise transforms
  • Cleaned up the if block in find_measureable_transforms() as it was getting quite large

New features

Transforms for:

  • cosh
  • sinh
  • tanh
  • erf
  • erfc
  • erfcx

Bugfixes

  • NA

Documentation

  • Haven't added any in this PR

Maintenance

  • NA

📚 Documentation preview 📚: https://pymc--6664.org.readthedocs.build/en/6664/

codecov bot commented Apr 10, 2023

Codecov Report

Merging #6664 (11d41db) into main (b7764dd) will increase coverage by 0.03%.
The diff coverage is 88.73%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6664      +/-   ##
==========================================
+ Coverage   91.96%   92.00%   +0.03%     
==========================================
  Files          94       95       +1     
  Lines       15927    16101     +174     
==========================================
+ Hits        14647    14813     +166     
- Misses       1280     1288       +8     
Impacted Files               Coverage Δ
pymc/logprob/transforms.py   94.95% <88.73%> (-0.93%) ⬇️

... and 12 files with indirect coverage changes

@ricardoV94 (Member) commented Apr 11, 2023

> Please note that this is still a work in progress, and I have not yet written any tests for the new Transforms. I would appreciate some guidance on how to design these tests, as it's not clear to me what I should be testing them against.

Previously I have just tested that the logp matches equivalent RV forms, such as abs(normal) == halfnormal, but that won't work here AFAICT :)

So then it boils down to:

  1. Testing that the new transforms have the right log_jac_det, which you can do with recourse to

    def check_jacobian_det(

  2. Testing that the logprob derivation is working, something like:

def test_erf_logp():
    base_rv = pt.random.normal(0.5, 1, name="base_rv")  # Something not centered around 0 is usually better
    rv = pt.erf(base_rv)
    vv = rv.clone()

    rv_logp = logp(rv, vv)
    assert_no_rvs(rv_logp)

    transform = ErfTransform()
    expected_logp = logp(rv, transform.backward(vv)) + transform.log_jac_det(vv)

    vv_test = np.array(0.25)  # Arbitrary test value
    np.testing.assert_almost_equal(
        rv_logp.eval({vv: vv_test}),
        expected_logp.eval({vv: vv_test}),
    )

You can probably parametrize and test all new functions with the same test.


Alternatively you can try to hijack

def test_transformed_logprob(at_dist, dist_params, sp_dist, size):

That test currently assumes you are testing only the _default_transform, but you could make it accept a non-default transform. Everything else should work the same?

@LukeLB (Contributor Author) commented Apr 15, 2023

@ricardoV94 I've been working on 2. and the test now runs; however, it's throwing an assertion error. The test is:

@pytest.mark.parametrize("transform", [ErfTransform])
def test_erf_logp(transform):
    base_rv = pt.random.normal(0.5, 1, name="base_rv")  # Something not centered around 0 is usually better
    rv = pt.erf(base_rv)
    vv = rv.clone()
    rv_logp = joint_logprob({rv: vv})

    transform = transform()
    expected_logp = joint_logprob({rv: transform.backward(vv)}) + transform.log_jac_det(vv)

    vv_test = np.array(0.25)  # Arbitrary test value
    np.testing.assert_almost_equal(
        rv_logp.eval({vv: vv_test}),
        expected_logp.eval({vv: vv_test}),
    )

This gives the assertion error:

E       AssertionError: 
E       Arrays are not almost equal to 7 decimals
E       
E       Mismatched elements: 1 / 1 (100%)
E       Max absolute difference: 0.06346299
E       Max relative difference: 0.07601085
E        x: array(-0.898383)
E        y: array(-0.83492)

They're close, but there's still quite a difference. I'm not sure whether this is due to the way I've written the logp in the test or whether the internal transform functions are wrong. Any ideas?

@ricardoV94 (Member) commented:

Oh my example was wrong. You want to compare with the base_rv + jacobian, not rv + jacobian:

This passes locally:

import numpy as np
import pytensor.tensor as pt
from pymc.logprob.basic import logp
from pymc.logprob.transforms import ErfTransform

base_rv = pt.random.normal(0.5, 1, name="base_rv")  # Something not centered around 0 is usually better
rv = pt.erf(base_rv)

vv = rv.clone()
rv_logp = logp(rv, vv)

transform = ErfTransform()
expected_logp = logp(base_rv, transform.backward(vv)) + transform.log_jac_det(vv)

vv_test = np.array(0.25)  # Arbitrary test value
np.testing.assert_almost_equal(
    rv_logp.eval({vv: vv_test}),
    expected_logp.eval({vv: vv_test}),
)
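
Building on this, the same check can be parametrized across several of the new transforms. The sketch below assumes the other transform classes follow the ErfTransform naming pattern (ErfcTransform, SinhTransform, TanhTransform); cosh and erfcx are left out because cosh is not injective and erfcx has no closed-form inverse, so they would need extra handling:

import numpy as np
import pytensor.tensor as pt
import pytest

from pymc.logprob.basic import logp
from pymc.logprob.transforms import ErfcTransform, ErfTransform, SinhTransform, TanhTransform


@pytest.mark.parametrize(
    "pt_func, transform",
    [
        (pt.erf, ErfTransform()),
        (pt.erfc, ErfcTransform()),
        (pt.sinh, SinhTransform()),
        (pt.tanh, TanhTransform()),
    ],
)
def test_elemwise_transform_logp(pt_func, transform):
    base_rv = pt.random.normal(0.5, 1, name="base_rv")
    rv = pt_func(base_rv)

    vv = rv.clone()
    rv_logp = logp(rv, vv)

    expected_logp = logp(base_rv, transform.backward(vv)) + transform.log_jac_det(vv)

    vv_test = np.array(0.25)  # 0.25 lies in the image of erf, erfc, sinh, and tanh
    np.testing.assert_almost_equal(
        rv_logp.eval({vv: vv_test}),
        expected_logp.eval({vv: vv_test}),
    )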

@LukeLB (Contributor Author) commented Apr 18, 2023

Cheers, will make the change!

@LukeLB closed this Apr 19, 2023
@LukeLB reopened this Apr 19, 2023
@LukeLB (Contributor Author) commented Apr 19, 2023

Whoops, clicked the wrong button, I didn't mean to close!

So test 2. now works for all transforms; however, I'm having an issue with check_jacobian_det, which gives an assertion error for all transforms. Does this suggest that the math is wrong?

Note: for test 2. I had to change the test by adding a switch statement, and I edited the switch statement on line 416 in transforms.py to check input_logprob AND jacobian, because of the discrepancy between returning nans vs. -infs: if input_logprob is nan, the switch otherwise also returns nan rather than -inf.

@ricardoV94 (Member) commented Apr 20, 2023

> Whoops, clicked the wrong button, I didn't mean to close!

No problem :)

> So test 2. now works for all transforms; however, I'm having an issue with check_jacobian_det, which gives an assertion error for all transforms. Does this suggest that the math is wrong?

I think there must have been an error in your log_jac_det expression. I tweaked the default implementation in RVTransform so that it works for both elemwise and vector transforms, and (after allowing the test to accept nan) it passes, whereas with your hand-written implementation it did not.

I think it's fine to use the default implementation (for the cases where it works). I didn't try to find what the error was.
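
To make the idea concrete, the generic log_jac_det for a scalar (elemwise) transform can be obtained from the gradient of the backward mapping, roughly like this (a sketch of the approach, not the exact RVTransform code; elemwise_log_jac_det is a hypothetical helper name):

import pytensor.tensor as pt
from pytensor.gradient import grad


def elemwise_log_jac_det(backward_fn, value):
    # Generic log|d backward / d value| for an elemwise transform:
    # summing before differentiating yields the elementwise derivative.
    jac = grad(pt.sum(backward_fn(value)), value)
    return pt.log(pt.abs(jac))


# For the tanh transform this reproduces -log(1 - y**2):
y = pt.vector("y")
tanh_log_jac_det = elemwise_log_jac_det(pt.arctanh, y)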

@ricardoV94 marked this pull request as draft on April 20, 2023 08:26
@LukeLB (Contributor Author) commented Apr 20, 2023

Great, I'll take a look at what you did and see if I can implement it for the other transforms.

@LukeLB (Contributor Author) commented Apr 20, 2023

OK all done, all tests pass.

@LukeLB marked this pull request as ready for review on April 27, 2023 07:48
@ricardoV94 (Member) left a review comment:

Looks good. I don't know why the coverage report shows some of the new transforms as not being covered; just a fluke?

I just have a question about a change below.

@@ -391,7 +419,7 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
     jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0)))

     # The jacobian is used to ensure a value in the supported domain was provided
-    return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
+    return pt.switch(pt.isnan(input_logprob + jacobian), -np.inf, input_logprob + jacobian)
Member:

Can we revert this change?

Contributor Author (@LukeLB):

Potentially, let me check. The reason I did that was that it meant we return -np.inf consistently when input_logprob is nan, which is the case for some of the transforms.

Contributor Author (@LukeLB):

Okay, it seems that isn't the case anymore, and the tests pass with the change reverted.
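
For context, the behavioural difference discussed in this thread can be reproduced with a small standalone example (an illustration, not code from the PR):

import numpy as np
import pytensor.tensor as pt

input_logprob = pt.scalar("input_logprob")
jacobian = pt.scalar("jacobian")

original = pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
proposed = pt.switch(pt.isnan(input_logprob + jacobian), -np.inf, input_logprob + jacobian)

# When the base logprob is nan but the jacobian is finite, the original switch
# propagates the nan, while the proposed variant maps it to -inf.
print(original.eval({input_logprob: np.nan, jacobian: 0.0}))  # nan
print(proposed.eval({input_logprob: np.nan, jacobian: 0.0}))  # -inf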

@ricardoV94 changed the title from "Additional Elemwise Transformations WIP" to "Derive logprob for hyperbolic and error transformations" on Apr 28, 2023
LukeLB and others added 2 commits April 28, 2023 09:36
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
@ricardoV94 merged commit d4bb701 into pymc-devs:main on Apr 28, 2023
21 checks passed
@ricardoV94 (Member) commented:

Awesome work @LukeLB! Looking forward to your next PR :)

@LukeLB (Contributor Author) commented Apr 28, 2023

Thanks @ricardoV94, it's been a pleasure! Thanks for reviewing :)
