Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modularise atan2, floorDiv, mod and mul #3407

Merged
merged 15 commits into from
Jun 9, 2020
Merged

Modularise atan2, floorDiv, mod and mul #3407

merged 15 commits into from
Jun 9, 2020

Conversation

tafsiri
Copy link
Contributor

@tafsiri tafsiri commented Jun 8, 2020

This modularizes the remaining binary ops, atan2, floorDiv, mod and mul.

One note about floorDiv: in doing this I noticed we had lost test coverage, i restored an old test for forward pass, however gradient calc didn't have a test. So I didn't remove the chaining api from FloorDiv grad. A followup PR will add tests for the gradient and then we can more safely make edits to that code.

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.


This change is Reviewable

Copy link
Collaborator

@annxingyuan annxingyuan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @annxingyuan, @lina128, and @tafsiri)


tfjs-core/src/gradients/Atan2_grad.ts, line 42 at r1 (raw file):

      const reduceAxes = getReductionAxes(a.shape, outShape);
      if (reduceAxes.length > 0) {
        res = res.sum(reduceAxes);

avoid chained API


tfjs-core/src/gradients/Atan2_grad.ts, line 44 at r1 (raw file):

        res = res.sum(reduceAxes);
      }
      return res.reshape(a.shape);

avoid chained API


tfjs-core/src/gradients/FloorDiv_grad.ts, line 31 at r1 (raw file):

    const derA = () => {
      const res = dy.div(b.toFloat());

avoid chained API


tfjs-core/src/gradients/FloorDiv_grad.ts, line 34 at r1 (raw file):

      const reduceAxes = getReductionAxes(a.shape, outShape);
      if (reduceAxes.length > 0) {
        return res.sum(reduceAxes).reshape(a.shape);

avoid chained API


tfjs-core/src/gradients/FloorDiv_grad.ts, line 39 at r1 (raw file):

    };
    const derB = () => {
      let res = dy.mul(a.toFloat());

avoid chained API


tfjs-core/src/gradients/FloorDiv_grad.ts, line 42 at r1 (raw file):

      const reduceAxes = getReductionAxes(b.shape, outShape);
      if (reduceAxes.length > 0) {
        res = res.sum(reduceAxes).reshape(b.shape);

avoid chained API


tfjs-core/src/gradients/FloorDiv_grad.ts, line 44 at r1 (raw file):

        res = res.sum(reduceAxes).reshape(b.shape);
      }
      const tmp = b.square();

avoid chained API


tfjs-core/src/gradients/FloorDiv_grad.ts, line 45 at r1 (raw file):

      }
      const tmp = b.square();
      return res.div(tmp.toFloat()).neg();

avoid chained API

Copy link
Collaborator

@lina128 lina128 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to Ann's comments.

Reviewable status: :shipit: complete! 2 of 1 approvals obtained (waiting on @lina128 and @tafsiri)

Copy link
Contributor Author

@tafsiri tafsiri left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: :shipit: complete! 2 of 1 approvals obtained (waiting on @annxingyuan)


tfjs-core/src/gradients/Atan2_grad.ts, line 42 at r1 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

avoid chained API

Done


tfjs-core/src/gradients/Atan2_grad.ts, line 44 at r1 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

avoid chained API

Done


tfjs-core/src/gradients/FloorDiv_grad.ts, line 31 at r1 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

avoid chained API

See note in PR description. No tests for floorDiv grad, so i'm not going to modify this code in this PR.


tfjs-core/src/gradients/FloorDiv_grad.ts, line 34 at r1 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

avoid chained API

Same as above


tfjs-core/src/gradients/FloorDiv_grad.ts, line 39 at r1 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

avoid chained API

Same as above


tfjs-core/src/gradients/FloorDiv_grad.ts, line 42 at r1 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

avoid chained API

Same as above


tfjs-core/src/gradients/FloorDiv_grad.ts, line 44 at r1 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

avoid chained API

Same as above


tfjs-core/src/gradients/FloorDiv_grad.ts, line 45 at r1 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

avoid chained API

Same as above

@annxingyuan
Copy link
Collaborator

@tafsiri oops - sorry I missed your note in the PR description :)

@tafsiri
Copy link
Contributor Author

tafsiri commented Jun 9, 2020

@annxingyuan i did have to rename Mul.[ts|cc] to Multiply.[ts|cc], however this is a good thing as we did set the convention that the filename should match the kernel name.

@tafsiri tafsiri merged commit 05fb9c1 into master Jun 9, 2020
@tafsiri tafsiri deleted the mod-binary-ops branch June 9, 2020 16:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants