-
Notifications
You must be signed in to change notification settings - Fork 955
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 2 of 2 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @caisq and @nsthorat)
src/ops/reduction_ops.ts, line 223 at r1 (raw file):
return { $x: () => { const dx = dy.mul(xOrig.equal(y).cast(dy.dtype));
dy.dtype is always float32
src/ops/reduction_ops_test.ts, line 200 at r1 (raw file):
[[[0, 20], [20, 20]], [[-10, -30], [-10, -30]]] ]); const dy = tf.tensor4d([[[[-1]], [[-1]]], [[[-1]], [[-1]]]]);
small nit, can you make the values of dy be -1, -2, -3, -4 etc here and above? this will ensure that the first value isn't being placed everywhere
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @caisq)
src/ops/reduction_ops.ts, line 223 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
dy.dtype is always float32
It seems to be less magical and potentially more maintainable to use dy.dtype
here. Also, I don't think that call is particularly expensive.
src/ops/reduction_ops_test.ts, line 200 at r1 (raw file):
Previously, nsthorat (Nikhil Thorat) wrote…
small nit, can you make the values of dy be -1, -2, -3, -4 etc here and above? this will ensure that the first value isn't being placed everywhere
Done.
Fixes tensorflow/tfjs#677
This change is