Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make coords and data always mutable #7047

Merged
merged 3 commits into from
Apr 3, 2024

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Dec 4, 2023

Closes #6972

This PR provides a new model transform that freezes RV dims that depend on coords as well as mutable data, for those worried about performance issues or incompatibilities with JAX dynamic shape limitations. I expect most users won't need this. The default C-backend doesn't really exploit static shapes. I believe the simpler API for users is a net win.

Note that JAX dynamic shape stuff is not relevant when using JAX samplers because we already replace any shared variables by constants anyways. It's only relevant when compiling PyTensor functions with mode="JAX"

The big picture here is that we define the most general model first, and later specialize if needed. Going from a model with constant shapes to another with different constant shapes is generally not possible because PyTensor eagerly computes static shape outputs for intermediate nodes, and rebuilding with different constant types is not always supported.

Starting with more general models could be quite helpful for producing predictive models automatically.

Note: If there's resistance, this PR can be narrowed down in scope to just remove the distinction between coords_mutable and coords, but still leave MutableData vs ConstantData

@ricardoV94 ricardoV94 added request discussion major Include in major changes release notes section labels Dec 4, 2023
Copy link

codecov bot commented Dec 4, 2023

Codecov Report

Attention: Patch coverage is 31.74603% with 43 lines in your changes are missing coverage. Please review.

Project coverage is 39.54%. Comparing base (81d31c8) to head (28ded31).

❗ Current head 28ded31 differs from pull request most recent head 55693f4. Consider uploading reports for the commit 55693f4 to get more accurate results

Additional details and impacted files

Impacted file tree graph

@@             Coverage Diff             @@
##             main    #7047       +/-   ##
===========================================
- Coverage   92.30%   39.54%   -52.77%     
===========================================
  Files         100      101        +1     
  Lines       16895    16835       -60     
===========================================
- Hits        15595     6657     -8938     
- Misses       1300    10178     +8878     
Files Coverage Δ
pymc/gp/hsgp_approx.py 23.80% <ø> (-71.82%) ⬇️
pymc/sampling/forward.py 14.47% <ø> (-81.43%) ⬇️
pymc/data.py 40.62% <33.33%> (-48.85%) ⬇️
pymc/model/fgraph.py 25.12% <0.00%> (-72.27%) ⬇️
pymc/model_graph.py 14.28% <20.00%> (-63.11%) ⬇️
pymc/model/core.py 57.62% <63.15%> (-34.61%) ⬇️
pymc/model/transform/conditioning.py 15.83% <17.24%> (-79.92%) ⬇️

... and 85 files with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the coords_always_mutable branch 2 times, most recently from 01deacd to 28ded31 Compare December 4, 2023 13:43
@ricardoV94 ricardoV94 force-pushed the coords_always_mutable branch 2 times, most recently from 4445317 to 546d59e Compare March 27, 2024 14:43
@twiecki
Copy link
Member

twiecki commented Mar 27, 2024

Love it.

@ricardoV94 ricardoV94 merged commit 207821d into pymc-devs:main Apr 3, 2024
21 checks passed
@ricardoV94 ricardoV94 deleted the coords_always_mutable branch April 3, 2024 11:26
@twiecki
Copy link
Member

twiecki commented Apr 3, 2024

Woohoo!

@twiecki
Copy link
Member

twiecki commented Apr 3, 2024

We need to adapt the pymc-examples (and maybe the NBs here too?).

@ricardoV94
Copy link
Member Author

We need to adapt the pymc-examples (and maybe the NBs here too?).

Possibly

@ricardoV94
Copy link
Member Author

It shouldn't fail immediately, just issue a warning.

Only when we remove the kwarg will it fail hard

@lhelleckes
Copy link
Contributor

Hi @ricardoV94!
@michaelosthege and I found out that the latest changes create issues when using pm.ConstantData (or pm.Data) and setting a dtype explicitly. We don't understand why because pytensor.shared has no problem with the dtype argument.

Here is an example:

with pm.Model():
    pm.Data("b", [True, False], dtype=bool)
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "C:\Users\osthege\AppData\Local\mambaforge\envs\dibecs_6.13.0\lib\site-packages\pymc\data.py", line 420, in Data
    x = pytensor.shared(arr, name, **kwargs)
  File "C:\Users\osthege\AppData\Local\mambaforge\envs\dibecs_6.13.0\lib\site-packages\pytensor\compile\sharedvalue.py", line 202, in shared
    var = shared_constructor(
  File "C:\Users\osthege\AppData\Local\mambaforge\envs\dibecs_6.13.0\lib\functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
TypeError: tensor_constructor() got an unexpected keyword argument 'dtype'

Do you have any idea why this is happening? Thanks in advance for your help!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
major Include in major changes release notes section request discussion
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Reassess MutableData and coords_mutable
3 participants