Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent local_sum_make_vector from introducing forbidden float64 #659

Merged
merged 9 commits into from
Mar 8, 2024

Conversation

tvwenger
Copy link
Contributor

@tvwenger tvwenger commented Mar 4, 2024

Description

This PR changes the behavior of local_sum_make_vector to skip rewriting whenever the internal accumulator is more precise than both the input/output data and config.floatX. Otherwise, the internal accumulator is added to the graph, which can introduce forbidden float64 to the graph when the user requests config.floatX="float32".

Replaces PR #655 and #656

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@codecov-commenter
Copy link

codecov-commenter commented Mar 4, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 80.82%. Comparing base (d175203) to head (cce1ab0).

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #659   +/-   ##
=======================================
  Coverage   80.82%   80.82%           
=======================================
  Files         162      162           
  Lines       46820    46822    +2     
  Branches    11438    11439    +1     
=======================================
+ Hits        37844    37846    +2     
  Misses       6725     6725           
  Partials     2251     2251           
Files Coverage Δ
pytensor/tensor/rewriting/basic.py 94.07% <100.00%> (+0.02%) ⬆️

@tvwenger
Copy link
Contributor Author

tvwenger commented Mar 4, 2024

@ricardoV94 I implemented your less-restrictive approach and updated the tests. This PR is ready to review/merge

@tvwenger
Copy link
Contributor Author

tvwenger commented Mar 5, 2024

@ricardoV94 new tests implemented and ready for review/merge

@tvwenger
Copy link
Contributor Author

tvwenger commented Mar 6, 2024

@ricardoV94 Just a reminder that this PR is ready for review! Thanks for your help.

@tvwenger tvwenger requested a review from ricardoV94 March 6, 2024 17:17
tests/tensor/rewriting/test_basic.py Outdated Show resolved Hide resolved
@tvwenger tvwenger requested a review from ricardoV94 March 7, 2024 21:11
@tvwenger tvwenger requested a review from ricardoV94 March 7, 2024 21:23
tests/tensor/rewriting/test_basic.py Outdated Show resolved Hide resolved
tests/tensor/rewriting/test_basic.py Outdated Show resolved Hide resolved
@tvwenger tvwenger requested a review from ricardoV94 March 7, 2024 21:45
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for not only fixing the issue but sticking around to also improve the tests!

@ricardoV94 ricardoV94 added bug Something isn't working graph rewriting labels Mar 7, 2024
@ricardoV94 ricardoV94 merged commit 4ee3588 into pymc-devs:main Mar 8, 2024
53 checks passed
@tvwenger tvwenger deleted the rewrite_acc_dtype branch March 8, 2024 19:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working graph rewriting
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants