Modularize argMax, argMin, min, mean and sum #3517
Conversation
Modularizes the rest of the reduction ops and makes the max kernels handle the keepDims attr across backends. It also fixes a bug in the wasm max kernel where, if an internal transpose was used, the input tensor would be mutated after the op completed. I added a test for that to max and min, but a general thing to watch out for is calling …
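The keepDims handling mentioned above determines whether reduced axes are dropped or kept as size-1 dimensions. A minimal sketch of that shape logic in plain TypeScript (an illustrative helper, not the actual tfjs kernel code):

```typescript
// Sketch of how a reduction's output shape changes with keepDims.
// Hypothetical helper for illustration; not the tfjs implementation.
function reducedShape(
    shape: number[], axes: number[], keepDims: boolean): number[] {
  const out: number[] = [];
  for (let i = 0; i < shape.length; i++) {
    if (axes.includes(i)) {
      // Reduced axis: keep as size 1 only when keepDims is set.
      if (keepDims) out.push(1);
    } else {
      out.push(shape[i]);
    }
  }
  return out;
}

console.log(reducedShape([2, 3, 4], [1], false));  // [2, 4]
console.log(reducedShape([2, 3, 4], [1], true));   // [2, 1, 4]
```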
annxingyuan
left a comment
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @annxingyuan, @lina128, and @tafsiri)
tfjs-core/src/kernel_names.ts, line 386 at r4 (raw file):
export const Mean = 'Mean';
export type MeanInputs = Pick<NamedTensorInfoMap, 'x'>;
export interface MeanAttrs {
Maybe would make sense to have ReductionInputs and ReductionAttrs (similar to BinaryInputs) given that they share the same interface.
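For illustration, the shared types suggested here might look like the following. The names `ReductionInputs` and `ReductionAttrs` come from the comment above; the surrounding type scaffolding is a simplified assumption, not the actual tfjs-core definitions:

```typescript
// Simplified stand-ins for tfjs-core types, for illustration only.
interface TensorInfo { shape: number[]; dtype: string; }
type NamedTensorInfoMap = {[name: string]: TensorInfo};

// Hypothetical shared reduction types, mirroring how BinaryInputs is
// shared across binary ops.
type ReductionInputs = Pick<NamedTensorInfoMap, 'x'>;
interface ReductionAttrs {
  axis: number|number[];
  keepDims: boolean;
}

const inputs: ReductionInputs = {x: {shape: [2, 3], dtype: 'float32'}};
const attrs: ReductionAttrs = {axis: 1, keepDims: false};
console.log(attrs.keepDims);  // false
```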
tfjs-core/src/ops/arg_max.ts, line 62 at r4 (raw file):
const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank);
if (permutedAxes != null) {
  $x = $x.transpose(permutedAxes);
remove chained op
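The permutation in this hunk moves the reduced axes to the innermost positions so the backend can reduce over contiguous data. A sketch of that logic (a simplified assumption about what `axis_util.getAxesPermutation` does, not the actual tfjs code):

```typescript
// Returns a permutation that moves the given axes to the end of the
// tensor's rank, or null if they are already innermost. Sketch only;
// the real axis_util helper may differ in detail.
function getAxesPermutation(axes: number[], rank: number): number[]|null {
  const innermost =
      axes.every((axis, i) => axis === rank - axes.length + i);
  if (innermost) return null;  // no transpose needed
  const outer: number[] = [];
  for (let i = 0; i < rank; i++) {
    if (!axes.includes(i)) outer.push(i);
  }
  return outer.concat(axes);
}

console.log(getAxesPermutation([0], 3));  // [1, 2, 0]
console.log(getAxesPermutation([2], 3));  // null
```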
tfjs-core/src/ops/arg_max.ts, line 66 at r4 (raw file):
}
const res = backend.argMax($x, axes[0]);
return res;
nit: return backend.argMax
tfjs-core/src/ops/arg_min.ts, line 66 at r4 (raw file):
const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank);
if (permutedAxes != null) {
  $x = $x.transpose(permutedAxes);
remove chained op
tfjs-core/src/ops/mean.ts, line 80 at r4 (raw file):
const value = sum(res, axis, keepDims);
const gradFunc = (dy: Tensor) => {
Would we want to eventually modularize custom grad functions as well?
tfjs-core/src/ops/sum.ts, line 62 at r4 (raw file):
let $x = convertToTensor(x, 'x', 'sum');
if ($x.dtype === 'bool') {
  $x = $x.toInt();
Sorry if we already discussed this - did we consider whether casting should happen at the kernel level?
tfjs-core/src/ops/sum.ts, line 73 at r4 (raw file):
let permutedX = $x;
if (permutation != null) {
  permutedX = $x.transpose(permutation);
remove chained op
tfjs-core/src/ops/sum.ts, line 79 at r4 (raw file):
if (keepDims) {
  const newShape = expandShapeToKeepDim(value.shape, axes);
  value = value.reshape(newShape);
remove chained
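The keepDims reshape in this hunk re-inserts size-1 dimensions at the reduced axes. A sketch of what `expandShapeToKeepDim` plausibly does (an assumption for illustration, not the actual tfjs helper):

```typescript
// Re-inserts size-1 dimensions at the reduced axes so the reduced
// result can be reshaped to "keep" those dims. Illustrative sketch.
function expandShapeToKeepDim(shape: number[], axes: number[]): number[] {
  const expanded = shape.slice();
  const sorted = axes.slice().sort((a, b) => a - b);
  for (const axis of sorted) {
    expanded.splice(axis, 0, 1);  // insert a size-1 dim at each axis
  }
  return expanded;
}

// sum over axis 1 of a [2, 3, 4] tensor yields shape [2, 4];
// with keepDims the result is reshaped back to [2, 1, 4].
console.log(expandShapeToKeepDim([2, 4], [1]));  // [2, 1, 4]
```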
tfjs-core/src/public/chained_ops/arg_max.ts, line 29 at r4 (raw file):
Tensor.prototype.argMax = function<T extends Tensor>(axis?: number): T {
  this.throwIfDisposed();
  // tslint:disable-next-line: no-unnecessary-type-assertion
could you include a comment about why the assertion is nevertheless needed? curious why min seems to be the exception in that it doesn't need an assertion?
tafsiri
left a comment
Reviewed 52 of 52 files at r1, 3 of 3 files at r2, 17 of 17 files at r3.
Reviewable status: complete! 1 of 1 approvals obtained (waiting on @annxingyuan and @lina128)
tfjs-core/src/kernel_names.ts, line 386 at r4 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
Maybe would make sense to have ReductionInputs and ReductionAttrs (similar to BinaryInputs) given that they share the same interface.
Argmin and argmax are different but are also reduction ops. I don't personally think it's worth it.
tfjs-core/src/ops/arg_max.ts, line 62 at r4 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
remove chained op
As per discussion in chat, I'm ignoring chained ops in forwardFuncs since we plan to delete these.
tfjs-core/src/ops/arg_max.ts, line 66 at r4 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
nit: return backend.argMax
done
tfjs-core/src/ops/arg_min.ts, line 66 at r4 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
remove chained op
same as above
tfjs-core/src/ops/mean.ts, line 80 at r4 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
Would we want to eventually modularize custom grad functions as well?
In this case there is no kernel for mean; we just customize the gradient, for efficiency, from what would automatically be computed. So there isn't anything to register a gradient on. However, it raises the question of how someone would override mean; I think the solution would be for them to copy this code and implement their custom gradient.
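The customized gradient described here follows from mean being sum divided by the element count: each input element contributes 1/n to the output, so the upstream gradient is spread evenly. A plain-number sketch of that reasoning (not tfjs API code):

```typescript
// mean(x) = sum(x) / n, so d(mean)/d(x_i) = 1/n for every element.
// Plain-number sketch of the custom-gradient logic; tfjs operates on
// tensors and broadcasts dy / n across the input shape instead.
function mean(xs: number[]): number {
  return xs.reduce((a, b) => a + b, 0) / xs.length;
}

function meanGrad(dy: number, n: number): number[] {
  // Each of the n input elements receives an equal share of dy.
  return Array.from({length: n}, () => dy / n);
}

console.log(mean([1, 2, 3, 6]));  // 3
console.log(meanGrad(1, 4));      // [0.25, 0.25, 0.25, 0.25]
```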
tfjs-core/src/ops/sum.ts, line 62 at r4 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
Sorry if we already discussed this - did we consider whether casting should happen at the kernel level?
In this case the kernel (as described in ops.pbtxt) does not support boolean tensors; since the cast is a TensorFlow.js convenience, I think we need to leave it in the op.
tfjs-core/src/ops/sum.ts, line 73 at r4 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
remove chained op
same as above
tfjs-core/src/ops/sum.ts, line 79 at r4 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
remove chained
same as above (leaving these in forwardFunc)
tfjs-core/src/public/chained_ops/arg_max.ts, line 29 at r4 (raw file):
Previously, annxingyuan (Ann Yuan) wrote…
could you include a comment about why the assertion is nevertheless needed? curious why min seems to be the exception in that it doesn't need an assertion?
I removed the assertion (must have been left over from something else). Thanks!
Also move norm and max tests into their own files.
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.