-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Modularise atan2, floorDiv, mod and mul #3407
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @annxingyuan, @lina128, and @tafsiri)
tfjs-core/src/gradients/Atan2_grad.ts, line 42 at r1 (raw file):
const reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { res = res.sum(reduceAxes);
avoid chained API
tfjs-core/src/gradients/Atan2_grad.ts, line 44 at r1 (raw file):
res = res.sum(reduceAxes); } return res.reshape(a.shape);
avoid chained API
tfjs-core/src/gradients/FloorDiv_grad.ts, line 31 at r1 (raw file):
const derA = () => { const res = dy.div(b.toFloat());
avoid chained API
tfjs-core/src/gradients/FloorDiv_grad.ts, line 34 at r1 (raw file):
const reduceAxes = getReductionAxes(a.shape, outShape); if (reduceAxes.length > 0) { return res.sum(reduceAxes).reshape(a.shape);
avoid chained API
tfjs-core/src/gradients/FloorDiv_grad.ts, line 39 at r1 (raw file):
}; const derB = () => { let res = dy.mul(a.toFloat());
avoid chained API
tfjs-core/src/gradients/FloorDiv_grad.ts, line 42 at r1 (raw file):
const reduceAxes = getReductionAxes(b.shape, outShape); if (reduceAxes.length > 0) { res = res.sum(reduceAxes).reshape(b.shape);
avoid chained API
tfjs-core/src/gradients/FloorDiv_grad.ts, line 44 at r1 (raw file):
res = res.sum(reduceAxes).reshape(b.shape); } const tmp = b.square();
avoid chained API
tfjs-core/src/gradients/FloorDiv_grad.ts, line 45 at r1 (raw file):
} const tmp = b.square(); return res.div(tmp.toFloat()).neg();
avoid chained API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 to Ann's comments.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @lina128 and @tafsiri)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: complete! 2 of 1 approvals obtained (waiting on @annxingyuan)
tfjs-core/src/gradients/Atan2_grad.ts, line 42 at r1 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
avoid chained API
Done
tfjs-core/src/gradients/Atan2_grad.ts, line 44 at r1 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
avoid chained API
Done
tfjs-core/src/gradients/FloorDiv_grad.ts, line 31 at r1 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
avoid chained API
See note in PR description. No tests for floorDiv grad, so i'm not going to modify this code in this PR.
tfjs-core/src/gradients/FloorDiv_grad.ts, line 34 at r1 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
avoid chained API
Same as above
tfjs-core/src/gradients/FloorDiv_grad.ts, line 39 at r1 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
avoid chained API
Same as above
tfjs-core/src/gradients/FloorDiv_grad.ts, line 42 at r1 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
avoid chained API
Same as above
tfjs-core/src/gradients/FloorDiv_grad.ts, line 44 at r1 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
avoid chained API
Same as above
tfjs-core/src/gradients/FloorDiv_grad.ts, line 45 at r1 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
avoid chained API
Same as above
@tafsiri oops - sorry I missed your note in the PR description :) |
@annxingyuan i did have to rename Mul.[ts|cc] to Multiply.[ts|cc], however this is a good thing as we did set the convention that the filename should match the kernel name. |
This modularizes the remaining binary ops, atan2, floorDiv, mod and mul.
One note about floorDiv: in doing this I noticed we had lost test coverage, i restored an old test for forward pass, however gradient calc didn't have a test. So I didn't remove the chaining api from FloorDiv grad. A followup PR will add tests for the gradient and then we can more safely make edits to that code.
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is