Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding meta-information for MeasurableOps #6754

Closed

Conversation

Dhruvanshu-Joshi
Copy link
Member

@Dhruvanshu-Joshi Dhruvanshu-Joshi commented Jun 3, 2023

What is this PR about?
This PR aims to solve issue #6360 and is a cotinuation of the PR #6685 by incorporating RV meta information in intermediate MeasurableVariables. The Measurable ops covered are MeasurableComparison, MeasurableClip, MeasurableRound, MeasurableSpecifyShape, MeasurableCheckAndRaise, MeasurableIfElse, MeasurableScan, MeasurableMakeVector, MeasurableJoin, MeasurableDimShuffle, MeasurableTransforms and DiracDelta.
I understand that this PR does not align with the commendable changes made in PR #6746. However, I just want a review from the maintainers on the changes made here and if it agrees with what they had in mind. I'll make all the necessary changes required for this PR to align with #6746 once it is merged.

Checklist

Major / Breaking Changes

  • ...

New features

  • ...

Bugfixes

  • ...

Documentation

  • ...

Maintenance

  • ...

📚 Documentation preview 📚: https://pymc--6754.org.readthedocs.build/en/6754/

@codecov
Copy link

codecov bot commented Jun 3, 2023

Codecov Report

Merging #6754 (7d2fc53) into main (f67ff8b) will increase coverage by 0.01%.
The diff coverage is 98.63%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6754      +/-   ##
==========================================
+ Coverage   92.02%   92.03%   +0.01%     
==========================================
  Files          95       95              
  Lines       16261    16302      +41     
==========================================
+ Hits        14964    15004      +40     
- Misses       1297     1298       +1     
Impacted Files Coverage Δ
pymc/logprob/abstract.py 93.84% <95.23%> (+0.22%) ⬆️
pymc/logprob/binary.py 96.47% <100.00%> (+0.08%) ⬆️
pymc/logprob/censoring.py 98.90% <100.00%> (+0.02%) ⬆️
pymc/logprob/checks.py 98.00% <100.00%> (+0.08%) ⬆️
pymc/logprob/cumsum.py 100.00% <100.00%> (ø)
pymc/logprob/mixture.py 96.92% <100.00%> (+0.02%) ⬆️
pymc/logprob/scan.py 98.57% <100.00%> (+0.07%) ⬆️
pymc/logprob/tensor.py 79.83% <100.00%> (+0.34%) ⬆️
pymc/logprob/transforms.py 94.69% <100.00%> (+0.01%) ⬆️

@Dhruvanshu-Joshi
Copy link
Member Author

Though all the individual tests in tests/logprob run successfully (except the broken test_checks), the mypy tests fail.
One way to minimize them is updating get_default_measurable_metainfo as:

def get_default_measurable_metainfo(base_op: MeasurableVariable, base_dtype) -> Tuple[Any, Union[Tuple[Any, ...], Any], Union[MeasureType, Any]]:
    if not isinstance(base_op, MeasurableVariable):
        raise TypeError("base_op must be a RandomVariable or MeasurableVariable")

    ndim_supp = base_op.ndim_supp

    supp_axes = getattr(base_op, "supp_axes", None)
    if supp_axes is None:
        supp_axes = tuple(range(-base_op.ndim_supp, 0))

    measure_type = getattr(base_op, "measure_type", None)
    if measure_type is None:
        measure_type = (
            MeasureType.Discrete if base_dtype.dtype.startswith("int") else MeasureType.Continuous
        )

    return ndim_supp, supp_axes, measure_type

However, we still face pymc/logprob/abstract.py:271: error: "MeasurableVariable" has no attribute "ndim_supp".
Is this error valid and if yes, how should I tackle this?
If invaid, the run_mypy.py can be easily changed to ignore this.

@Dhruvanshu-Joshi Dhruvanshu-Joshi changed the title Passed meta-information for MeasurableOps Adding meta-information for MeasurableOps Jun 3, 2023
@ricardoV94
Copy link
Member

ricardoV94 commented Jun 5, 2023

@Dhruvanshu-Joshi I pushed a commit that uses multiple inheritance to incorporate the metainfo in the MeasurableVariable subclasses. This avoids having to re-define the __init__ method for each subclass.

I haven't tested at all, and each rewrite should be checked manually to make sure we are doing the right thing. Special attention should be given to Ops with multiple measurable outputs (Scan, IfElse), as we need to preserve the meta-info for each output.

After checking and cleaning up the code, a good test case would be to remove the current limitation on Dimshuffles of non-pure RVs as mentioned in #6360. This should be done in a separate commit! :)

Feel free to ask any questions about the code I pushed (you can leave comments directly on the changed lines here on Github)

@Dhruvanshu-Joshi
Copy link
Member Author

@ricardoV94 I have made some changes so that all the logprob tests pass.
However, mypy tests still fail. They are related to the type-hints specified in the function get_measurable_meta_info. The input to this function is specified to be an Op but in some cases (for cumsum and checks), an instance of Measurable Variable is passed as input. Also the return type also does not match the expected type-hint.

@Dhruvanshu-Joshi
Copy link
Member Author

Also I am working on solving the merge conflicts by referring the PR #6746 and will update if doing so solves the mypy errors.

@ricardoV94
Copy link
Member

The input to this function is specified to be an Op but in some cases (for cumsum and checks), an instance of Measurable Variable is passed as input.

That means we just need to pass the Op instead in those cases, no?

@Dhruvanshu-Joshi
Copy link
Member Author

The input to this function is specified to be an Op but in some cases (for cumsum and checks), an instance of Measurable Variable is passed as input.

That means we just need to pass the Op instead in those cases, no?

Are we not doing that by passing base_rv.owner.op? I checked it and it is dirichlet_rv{1, (1,), floatX, False} in case of SpecifyShape where it is giving an error while running the mypy tests. Also , in case of cumsum.py , no error is raised after updating the code in accordance with the current format. base_rv.owner.op in this case is bernoulli_rv{0, (0,), int64, False} where the mypy tests pass.

Also, for the error in the return type of get_measurable_meta_info, can we change it from the expected Tuple[int, Tuple[int], MeasureType] to actual Tuple[Union[int, Tuple[int]], Tuple[Union[int, Tuple[int]]], Union[MeasureType, Tuple[MeasureType]]]?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants