Conversation

@tafsiri
Contributor

@tafsiri tafsiri commented Jun 26, 2020

Also move norm and max tests into their own files.

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.



@tafsiri
Contributor Author

tafsiri commented Jun 29, 2020

Modularizes the rest of the reduction ops and makes the max kernels handle the keepDims attr across backends.

It also fixes a bug in the wasm max kernel: if an internal transpose was used, the input tensor would be modified after the op completed. I added a test for that to max and min, but a general thing to watch out for is calling .set on the TypedArray that backs a model's inputs.
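For illustration, here is a minimal sketch of the hazard described above; the function and parameter names are hypothetical, not the actual tfjs-backend-wasm code:

    // Hypothetical sketch of the bug pattern; not the actual kernel code.
    function maxWithInternalTranspose(
        inputVals: Float32Array,
        transpose: (v: Float32Array) => Float32Array): number {
      // BUG pattern: reusing the input's backing TypedArray as scratch space
      // mutates the op's input, e.g. a model's input tensor:
      //   inputVals.set(transpose(inputVals));
      // Safe pattern: reduce over a separate buffer and leave the input alone.
      const transposed = transpose(inputVals);
      let max = -Infinity;
      for (let i = 0; i < transposed.length; i++) {
        max = Math.max(max, transposed[i]);
      }
      return max;
    }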

@tafsiri tafsiri marked this pull request as ready for review June 29, 2020 15:46
@tafsiri tafsiri requested review from annxingyuan and removed request for annxingyuan June 29, 2020 15:46
@tafsiri tafsiri requested review from annxingyuan and lina128 June 29, 2020 15:46
Contributor

@annxingyuan annxingyuan left a comment


Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @annxingyuan, @lina128, and @tafsiri)


tfjs-core/src/kernel_names.ts, line 386 at r4 (raw file):

export const Mean = 'Mean';
export type MeanInputs = Pick<NamedTensorInfoMap, 'x'>;
export interface MeanAttrs {

Maybe would make sense to have ReductionInputs and ReductionAttrs (similar to BinaryInputs) given that they share the same interface.
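A hedged sketch of what that suggestion might look like, modeled on BinaryInputs; the shared names here are hypothetical, since (as discussed below) the PR keeps per-op types like MeanInputs and MeanAttrs:

    // Hypothetical shared types for the reduction ops; not what landed.
    export type ReductionInputs = Pick<NamedTensorInfoMap, 'x'>;
    export interface ReductionAttrs {
      axis: number|number[];
      keepDims: boolean;
    }

    export const Mean = 'Mean';
    export type MeanInputs = ReductionInputs;
    export type MeanAttrs = ReductionAttrs;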


tfjs-core/src/ops/arg_max.ts, line 62 at r4 (raw file):

    const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank);
    if (permutedAxes != null) {
      $x = $x.transpose(permutedAxes);

remove chained op
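For context, removing the chained op here would mean calling the modular transpose function directly rather than the Tensor.prototype method, roughly as follows (the import path is assumed):

    import {transpose} from './transpose';

    // chained:  $x = $x.transpose(permutedAxes);
    // modular:
    $x = transpose($x, permutedAxes);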


tfjs-core/src/ops/arg_max.ts, line 66 at r4 (raw file):

    }
    const res = backend.argMax($x, axes[0]);
    return res;

nit: return backend.argMax


tfjs-core/src/ops/arg_min.ts, line 66 at r4 (raw file):

    const permutedAxes = axis_util.getAxesPermutation(axes, $x.rank);
    if (permutedAxes != null) {
      $x = $x.transpose(permutedAxes);

remove chained op


tfjs-core/src/ops/mean.ts, line 80 at r4 (raw file):

    const value = sum(res, axis, keepDims);

    const gradFunc = (dy: Tensor) => {

Would we want to eventually modularize custom grad functions as well?


tfjs-core/src/ops/sum.ts, line 62 at r4 (raw file):

  let $x = convertToTensor(x, 'x', 'sum');
  if ($x.dtype === 'bool') {
    $x = $x.toInt();

Sorry if we already discussed this - did we consider whether casting should happen at the kernel level?


tfjs-core/src/ops/sum.ts, line 73 at r4 (raw file):

    let permutedX = $x;
    if (permutation != null) {
      permutedX = $x.transpose(permutation);

remove chained op


tfjs-core/src/ops/sum.ts, line 79 at r4 (raw file):

    if (keepDims) {
      const newShape = expandShapeToKeepDim(value.shape, axes);
      value = value.reshape(newShape);

remove chained


tfjs-core/src/public/chained_ops/arg_max.ts, line 29 at r4 (raw file):

Tensor.prototype.argMax = function<T extends Tensor>(axis?: number): T {
  this.throwIfDisposed();
  // tslint:disable-next-line: no-unnecessary-type-assertion

Could you include a comment about why the assertion is nevertheless needed? Curious why min seems to be the exception in that it doesn't need an assertion.

Contributor Author

@tafsiri tafsiri left a comment


Reviewed 52 of 52 files at r1, 3 of 3 files at r2, 17 of 17 files at r3.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @annxingyuan and @lina128)


tfjs-core/src/kernel_names.ts, line 386 at r4 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

Maybe would make sense to have ReductionInputs and ReductionAttrs (similar to BinaryInputs) given that they share the same interface.

argMin and argMax are also reduction ops but take different attrs, so a shared interface wouldn't cover all of them. I don't personally think it's worth it.


tfjs-core/src/ops/arg_max.ts, line 62 at r4 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

remove chained op

As per our discussion in chat, I'm ignoring chained ops in forwardFuncs since we plan to delete these.


tfjs-core/src/ops/arg_max.ts, line 66 at r4 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

nit: return backend.argMax

Done.


tfjs-core/src/ops/arg_min.ts, line 66 at r4 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

remove chained op

same as above


tfjs-core/src/ops/mean.ts, line 80 at r4 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

Would we want to eventually modularize custom grad functions as well?

In this case there is no kernel for mean; we just customize, for efficiency, the gradient that would otherwise be computed automatically, so there isn't anything to register a gradient on. However, it does raise the question of how someone would override mean. I think the solution would be for them to copy this code and implement their custom gradient.
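A simplified sketch of the pattern described here, with simpleMean as a hypothetical stand-in for the real op and handling only a full reduction: mean is composed from sum via customGrad, so there is no Mean kernel to attach a gradient to.

    import * as tf from '@tensorflow/tfjs-core';

    // Sketch only: mean built on sum with a customized gradient.
    const simpleMean = tf.customGrad((x: tf.Tensor) => {
      const reduceSize = x.size;
      return {
        value: tf.div(tf.sum(x), reduceSize),
        // Each input element contributes 1 / reduceSize to the mean, so the
        // upstream gradient is broadcast and scaled accordingly.
        gradFunc: (dy: tf.Tensor) =>
            tf.mul(dy, tf.fill(x.shape, 1 / reduceSize))
      };
    });

    simpleMean(tf.tensor1d([1, 2, 3, 4])).print();  // 2.5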


tfjs-core/src/ops/sum.ts, line 62 at r4 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

Sorry if we already discussed this - did we consider whether casting should happen at the kernel level?

In this case the kernel (as described in ops.pbtxt) does not support boolean tensors, and since the cast is a TensorFlow.js convenience I think we need to leave it in the op.


tfjs-core/src/ops/sum.ts, line 73 at r4 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

remove chained op

same as above


tfjs-core/src/ops/sum.ts, line 79 at r4 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

remove chained

same as above (leaving these in forwardFunc)


tfjs-core/src/public/chained_ops/arg_max.ts, line 29 at r4 (raw file):

Previously, annxingyuan (Ann Yuan) wrote…

Could you include a comment about why the assertion is nevertheless needed? Curious why min seems to be the exception in that it doesn't need an assertion.

I removed the assertion (must have been left over from something else). Thanks!
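Assuming the modular argMax function is itself generic over T, the chained op presumably ends up in the same shape as the others once the assertion is dropped; a sketch, not the exact committed code:

    Tensor.prototype.argMax = function<T extends Tensor>(axis?: number): T {
      this.throwIfDisposed();
      return argMax(this, axis);
    };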

@tafsiri tafsiri changed the title Mod argMax, argMin, min, mean and sum Modularize argMax, argMin, min, mean and sum Jul 1, 2020
@tafsiri tafsiri merged commit dcd8d6d into master Jul 1, 2020
@tafsiri tafsiri deleted the mod-more-reduction-ops branch July 1, 2020 01:46