-
Notifications
You must be signed in to change notification settings - Fork 955
added support for shrink axis mask for StridedSlice op #1201
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reviewed 5 of 7 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @pyu10055 and @dsmilkov)
src/ops/slice_util.ts, line 57 at r1 (raw file):
startIndex[i] = startForAxis(beginMask, begin, strides, shape, i); endIndex[i] = stopForAxis(endMask, end, strides, shape, i); // When shrinking an axis, user startIndex + 1 for endIndex.
typo: user => use ?
src/ops/strided_slice.ts, line 58 at r1 (raw file):
function stridedSlice_<T extends Tensor>( x: T|TensorLike, begin: number[], end: number[], strides: number[], beginMask = 0, endMask = 0, shrinkAxisMask = 0): T {
To protect from breaking the API in the future when we add support for the other masks, add ellipsisMask
and newAxisMask
before shrinkAxisMark
and default them to 0. Then throw an error if their values are other than zero saying 'newAxisMark != 0 is not yet supported'. Likewise for elipsisMask
src/ops/strided_slice_test.ts, line 26 at r1 (raw file):
const tensor = tf.tensor1d([0, 1, 2, 3]); const output = tf.stridedSlice(tensor, [0], [3], [2]); expect(output.shape).toEqual([2]);
the merge is strange here, since you dropped 3 unit tests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 tiny comment - add the elipsisMask and newAxisMask to the backend.ts as well so we can easily enable those for the nodejs backend later
…w/tfjs-core into strided_slice_shrink_axis_mask
…w/tfjs-core into strided_slice_shrink_axis_mask
Description
This is required for the SSD model, which has StridedSlice op with shrinkAxisMask set.
For repository owners only:
Please remember to apply all applicable tags to your pull request.
Tags: FEATURE, BREAKING, BUG, PERF, DEV, DOC, SECURITY
For more info see: https://github.com/tensorflow/tfjs/blob/master/DEVELOPMENT.md
This change is