New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement MarginRankingLoss as native function and add reduce=True arg to it #5346
Conversation
why doesn't this also have a backward function? |
I used #5080 as a reference for this PR, and a backwards wasn't implemented there as well. Should I also add a backwards? |
The fact that it is moved from python to cpp might be enough to justify ignoring backward. You should either implement it or do a speed benchmark as #5080 . |
|
@ezyang Yeah I have an update, was going to run benchmarks first. |
benchmarks run with inputs of size (1000), 5000 times: forward (old) [0.8390149101614952, 0.8477208158001304, 0.859797858633101] forward (new) [0.7037791842594743, 0.7032447448000312, 0.6909415749832988] |
f43af50
to
81e173c
Compare
@pytorchbot retest this please |
Interesting that double backward slows down, but this is fine given the single backward improvements. |
@pytorchbot retest this please |
9a46f6f
to
ed36fcf
Compare
@pytorchbot retest this please |
1 similar comment
@pytorchbot retest this please |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good in general. But please fix the conflict :)
@@ -38,6 +38,9 @@ | |||
|
|||
- func: chunk(Tensor self, int64_t chunks, int64_t dim=0) -> TensorList | |||
|
|||
- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, double margin, bool size_average, bool reduce) -> Tensor |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Test Plan
test/run_test.sh
Added unit tests for MarginRankingLoss.