Elementwise pow op #1133
Conversation
See my comment below.
burn-tch/src/ops/base.rs (outdated)
tensor,
exponent,
|lhs, rhs| lhs.f_pow_tensor_(rhs).unwrap(),
|lhs, rhs| rhs.f_pow_tensor_(lhs).unwrap(),
I don't think these two are equivalent; we should use the read-only version in this case.
I think I resolved it, but just to clarify: you mean switching to f_pow (which takes &self rather than &mut self), right?
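For anyone following along, here is a toy sketch of the distinction being discussed. The names and types below are illustrative, not the real tch-rs API: the point is only that an in-place variant (taking `&mut self`) overwrites one operand, while a read-only variant (taking `&self`) returns a fresh value, so swapping the arguments of the in-place form is not a symmetric operation.

```rust
// Toy tensor to illustrate in-place vs read-only pow.
// Hypothetical names; the real tch-rs signatures differ.
#[derive(Clone, Debug, PartialEq)]
struct Tensor(Vec<f64>);

impl Tensor {
    // In-place: takes &mut self and overwrites self with self^exp.
    fn pow_inplace(&mut self, exp: &Tensor) {
        for (x, e) in self.0.iter_mut().zip(&exp.0) {
            *x = x.powf(*e);
        }
    }

    // Read-only: takes &self and returns a new tensor, leaving both
    // operands untouched.
    fn pow(&self, exp: &Tensor) -> Tensor {
        Tensor(self.0.iter().zip(&exp.0).map(|(x, e)| x.powf(*e)).collect())
    }
}

fn main() {
    let base = Tensor(vec![2.0, 3.0]);
    let exp = Tensor(vec![2.0, 2.0]);

    // Read-only: `base` is still usable afterwards.
    let out = base.pow(&exp);
    assert_eq!(out.0, vec![4.0, 9.0]);
    assert_eq!(base.0, vec![2.0, 3.0]);

    // In-place: the old values of `base2` are destroyed.
    let mut base2 = base.clone();
    base2.pow_inplace(&exp);
    assert_eq!(base2.0, vec![4.0, 9.0]);
}
```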
@@ -1589,6 +1589,50 @@ where
/// For calculating abs of the elements of a tensor, users should prefer the [Tensor::abs](Tensor::abs) function,
/// which is more high-level and designed for public use.
fn abs<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D>;
In this file, you should also add the user-facing API operations, where the static dispatch is done. These aren't inside a trait, but the implementation uses the Kind to dispatch to the right implementation.
Example with powf_scalar:
/// Applies the element-wise powf operation with a scalar.
///
/// `y = x ^ s`
pub fn powf_scalar<E: ElementConversion>(self, other: E) -> Self {
    Self::new(K::powf_scalar(self.primitive, other))
}
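To make the dispatch pattern concrete, here is a simplified, self-contained sketch. The `TensorKind` trait, `Float` marker, and `Primitive` type below are stand-ins for Burn's real types, not its actual code: the public method on `Tensor` just forwards to the kind's associated function, so the right per-kind implementation is chosen statically.

```rust
// Simplified sketch of static dispatch through a kind trait.
// All names here are illustrative, not Burn's real API.
trait TensorKind {
    type Primitive;
    fn powf_scalar(primitive: Self::Primitive, scalar: f64) -> Self::Primitive;
}

// A "float kind" whose primitive is just a Vec<f64> for this sketch.
struct Float;

impl TensorKind for Float {
    type Primitive = Vec<f64>;
    fn powf_scalar(primitive: Vec<f64>, scalar: f64) -> Vec<f64> {
        primitive.into_iter().map(|x| x.powf(scalar)).collect()
    }
}

struct Tensor<K: TensorKind> {
    primitive: K::Primitive,
}

impl<K: TensorKind> Tensor<K> {
    fn new(primitive: K::Primitive) -> Self {
        Self { primitive }
    }

    /// Applies the element-wise powf operation with a scalar: `y = x ^ s`.
    /// Statically dispatches to the kind's implementation.
    fn powf_scalar(self, scalar: f64) -> Self {
        Self::new(K::powf_scalar(self.primitive, scalar))
    }
}

fn main() {
    let t: Tensor<Float> = Tensor::new(vec![2.0, 3.0]);
    let out = t.powf_scalar(2.0);
    assert_eq!(out.primitive, vec![4.0, 9.0]);
}
```

Because the dispatch is resolved at compile time through the generic parameter, no trait object or runtime branching is involved.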
Thanks a lot for this PR. Just some minor comments, but should be good to go pretty soon. Tagging @louisfd to review as well.
I double-checked the backward pass and it's well done!
Codecov Report

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1133      +/-   ##
==========================================
- Coverage   86.01%   85.96%    -0.06%
==========================================
  Files         522      522
  Lines       58674    59179      +505
==========================================
+ Hits        50469    50873      +404
- Misses       8205     8306      +101
This is still in a draft state, I'll fill out the checklist once it's further along
Pull Request Template
Checklist
The run-checks all script has been executed.
Related Issues/PRs
Provide links to relevant issues and dependent PRs.
Changes
In-progress addition of the pow function. It assumes that when rhs has a different type than lhs/the output, the type conversion is applied to rhs prior to the operation.
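The assumption above can be sketched as follows. This is a hypothetical, simplified illustration (plain slices instead of real tensors): the exponent is cast to the base's element type first, and only then is the element-wise pow applied.

```rust
// Hypothetical sketch: when the exponent (rhs) is an int tensor but
// the base (lhs) and output are float, convert rhs to float *before*
// the element-wise pow, matching the assumption in the PR description.
fn pow_mixed(base: &[f64], exponent_int: &[i64]) -> Vec<f64> {
    base.iter()
        .zip(exponent_int.iter().map(|&e| e as f64)) // convert rhs first
        .map(|(x, e)| x.powf(e))
        .collect()
}

fn main() {
    assert_eq!(pow_mixed(&[2.0, 3.0], &[3, 2]), vec![8.0, 9.0]);
}
```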
Testing
I've created the Python files, but I haven't yet generated the ONNX models. If someone beats me to it, note there is a slight difference in spec between PyTorch and ONNX regarding pow: PyTorch seems to always output a float tensor, so pow_int needs to convert the output back to int after the op.
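A minimal sketch of that post-op conversion, again with hypothetical names and plain slices rather than real tensors: the computation runs in float (as the upstream kernel is described to do), and the result is cast back to ints afterwards.

```rust
// Hypothetical sketch of pow_int when the underlying kernel always
// yields a float tensor: compute in float, then convert the output
// back to int post-op.
fn pow_int(base: &[i64], exponent: &[i64]) -> Vec<i64> {
    base.iter()
        .zip(exponent)
        .map(|(&x, &e)| (x as f64).powf(e as f64)) // float kernel
        .map(|y| y.round() as i64)                 // convert output post-op
        .collect()
}

fn main() {
    assert_eq!(pow_int(&[2, 3], &[3, 2]), vec![8, 9]);
}
```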