Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tensordot implementation #607

Merged
merged 1 commit into from
Jan 20, 2024
Merged

Conversation

lucianopaz
Copy link
Contributor

@lucianopaz lucianopaz commented Jan 18, 2024

Description

This PR discards the past implementation of tensordot and uses the one from numpy. The past implementation had many problems:

  • It had two completely independent branches of execution so code was not thoroughly testes.
  • It was recursive. If you passed in axes as a sequence of sequences, it tried to do transpose the dimensions and then call tensordot again using axes as an integer. This made things harder to maintain.
  • It had bugs as shown by BUG: tensordot is broken #606.
  • It lost the static shape information.

Taking the implementation from numpy works well. It handles all the cases of axes, it has a single execution branch, and preserves the static shape information. The only downside is that it does not have the logic to implement batched_tensordot. From my point of view, this is a good thing. batched_dot and batched_tensordot were poor-men implementations of what we can now do using Blockwise and better vectorization. I think that those two could be deprecated altogether and eventually removed.

I still haven't added tests for this PR because I want to see the current test suite coverage report. The coverage report showed that there was a test that needed to be adapted, but most of tensordot was already covered by the existing test suite. I added special tests for #606, static shape and runtime shape validation.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@lucianopaz lucianopaz added the bug Something isn't working label Jan 18, 2024
@lucianopaz lucianopaz marked this pull request as ready for review January 18, 2024 14:58
@lucianopaz lucianopaz force-pushed the tensordot branch 2 times, most recently from 1e2e7ae to 15bc7ce Compare January 18, 2024 21:10
@codecov-commenter
Copy link

codecov-commenter commented Jan 18, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (f799219) 80.87% compared to head (b7402d6) 80.86%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #607      +/-   ##
==========================================
- Coverage   80.87%   80.86%   -0.01%     
==========================================
  Files         162      162              
  Lines       46680    46743      +63     
  Branches    11408    11419      +11     
==========================================
+ Hits        37751    37800      +49     
- Misses       6699     6705       +6     
- Partials     2230     2238       +8     
Files Coverage Δ
pytensor/tensor/math.py 89.49% <100.00%> (-0.57%) ⬇️

... and 1 file with indirect coverage changes

pytensor/tensor/math.py Outdated Show resolved Hide resolved
tests/tensor/test_math.py Outdated Show resolved Hide resolved
tests/tensor/test_math.py Outdated Show resolved Hide resolved
tests/tensor/test_math.py Outdated Show resolved Hide resolved
tests/tensor/test_math.py Outdated Show resolved Hide resolved
tests/tensor/test_math.py Outdated Show resolved Hide resolved
pytensor/tensor/math.py Outdated Show resolved Hide resolved
pytensor/tensor/math.py Outdated Show resolved Hide resolved
pytensor/tensor/math.py Outdated Show resolved Hide resolved
pytensor/tensor/math.py Outdated Show resolved Hide resolved
@lucianopaz lucianopaz force-pushed the tensordot branch 2 times, most recently from 437be34 to 5f1fc1d Compare January 19, 2024 10:36
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw this looks much better than the previous implementation!

We should open an issue to deprecate the old tensordot as dot utility

pytensor/tensor/math.py Outdated Show resolved Hide resolved
@lucianopaz lucianopaz force-pushed the tensordot branch 2 times, most recently from 1f7321f to 1fe21db Compare January 19, 2024 11:13
@lucianopaz
Copy link
Contributor Author

Btw this looks much better than the previous implementation!

We should open an issue to deprecate the old tensordot as dot utility

I'll open an issue for it

pytensor/tensor/math.py Outdated Show resolved Hide resolved
pytensor/tensor/math.py Outdated Show resolved Hide resolved
@ricardoV94
Copy link
Member

Some small docstring tweak and a possible nitpick in the test (feel free to ignore) and this is good from my side

pytensor/tensor/math.py Outdated Show resolved Hide resolved
@ricardoV94 ricardoV94 merged commit e3fb498 into pymc-devs:main Jan 20, 2024
52 checks passed
@ricardoV94
Copy link
Member

Thanks @lucianopaz !

@lucianopaz lucianopaz deleted the tensordot branch January 20, 2024 07:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: tensordot is broken
4 participants