Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/squeeze dims #1779

Merged
merged 11 commits into from
May 22, 2024
Merged

Feat/squeeze dims #1779

merged 11 commits into from
May 22, 2024

Conversation

agelas
Copy link
Contributor

@agelas agelas commented May 19, 2024

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

#1780

Changes

Adds a new squeeze_dims function that's more spec compliant with ONNX and PyTorch. Specifically, users can pass in multiple dimensions, negative indices, or no dimensions at all, in which case the function will squeeze all singleton dimensions. The burn-import crate has also been updated to use this.

Testing

Added unit tests to burn-tensor

Copy link

codecov bot commented May 19, 2024

Codecov Report

Attention: Patch coverage is 95.76271% with 5 lines in your changes are missing coverage. Please review.

Project coverage is 86.41%. Comparing base (9c5b07c) to head (9dbb294).
Report is 2 commits behind head on main.

Files Patch % Lines
crates/burn-tensor/src/tests/ops/squeeze.rs 89.13% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1779      +/-   ##
==========================================
+ Coverage   84.91%   86.41%   +1.49%     
==========================================
  Files         756      737      -19     
  Lines       87403    85977    -1426     
==========================================
+ Hits        74218    74295      +77     
+ Misses      13185    11682    -1503     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

if dim_indices.contains(&index) && dim_size == 1 {
check!(TensorCheck::squeeze::<D2>(index, &current_dims));
continue;
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nathanielsimard this check doesn't do much of anything right now since the ==1 basically means there's no chance of it erroring out. The squeeze() function will panic if you ask it to squeeze a dimension that isn't 1. For this, I'm not sure if you want to keep the same behavior as squeeze() or not. It looks like the ONNX spec says it will raise an error if you try squeezing a non-1 dimension. But PyTorch would just leave that dimension unchanged.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion it should behave like our squeeze, so panic on non-1 dimensions

Copy link
Member

@louisfd louisfd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @agelas
Thanks a lot, it's very well documented and tested. I think we should panic on non-1 dimensions; that would be the only thing to change.

if dim_indices.contains(&index) && dim_size == 1 {
check!(TensorCheck::squeeze::<D2>(index, &current_dims));
continue;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion it should behave like our squeeze, so panic on non-1 dimensions

}

/// Test to make sure the function doesn't do anything if a non-singleton dimension is squeezed
#[test]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following my other comment, this test should panic

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I'll make the change.

@agelas agelas requested a review from louisfd May 21, 2024 19:53
Copy link
Member

@louisfd louisfd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@louisfd louisfd merged commit 81ecd14 into tracel-ai:main May 22, 2024
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants