diff --git a/tfjs-core/src/ops/slice_util.ts b/tfjs-core/src/ops/slice_util.ts index 8d2fb86dfaa..096e2966356 100644 --- a/tfjs-core/src/ops/slice_util.ts +++ b/tfjs-core/src/ops/slice_util.ts @@ -65,9 +65,12 @@ export function computeOutShape( // Creates full selection at the elided dimensions. If the dimension matches // the ellipsis mask, override the current stride value. Otherwise, insert. export function stridesWithElidedDims( - strides: number[], ellipsisInsertionIndex: number, - numElidedAxes: number): number[] { + strides: number[], ellipsisInsertionIndex: number, numElidedAxes: number, + inputShape: number[]): number[] { const newStrides = [...strides]; + for (let i = newStrides.length; i < inputShape.length; i++) { + newStrides.push(1); + } for (let i = 0; i < numElidedAxes; i++) { if (i === 0) { newStrides[ellipsisInsertionIndex] = 1; @@ -81,20 +84,44 @@ export function stridesWithElidedDims( return newStrides; } +function unnormalizeAxis( + ellipsisInsertionIndex: number, numElidedAxes: number, + normalizedAxis: number): number { + if (normalizedAxis <= ellipsisInsertionIndex) { + return normalizedAxis; + } + + return normalizedAxis - (numElidedAxes - 1); +} + +function getElidedAxes(numElidedAxes: number, ellipsisInsertionIndex: number) { + const elidedAxes = []; + for (let i = 0; i < numElidedAxes; i++) { + elidedAxes.push(ellipsisInsertionIndex + i); + } + return elidedAxes; +} + // Creates full selection at the elided dimensions. If the dimension matches // the ellipsis mask, override the current start value. Otherwise, insert. export function startIndicesWithElidedDims( - startIndices: number[], ellipsisInsertionIndex: number, - numElidedAxes: number): number[] { - const newIndices = [...startIndices]; - for (let i = 0; i < numElidedAxes; i++) { - if (i === 0) { - newIndices[ellipsisInsertionIndex] = 0; + beginMask: number, ellipsisInsertionIndex: number, numElidedAxes: number, + originalBegin: number[], inputShape: number[]): number[] { + const newIndices = [...inputShape]; + const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex); + + for (let axis = 0; axis < newIndices.length; axis++) { + if (elidedAxes.indexOf(axis) > -1) { + newIndices[axis] = 0; } else { - newIndices.splice( - ellipsisInsertionIndex, 0 /* num elements to delete */, - 0 /* element to add */); - newIndices.pop(); + const originalAxis = + unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis); + let originalValue = originalBegin[originalAxis]; + if (beginMask & 1 << originalAxis) { + originalValue = 0; + } + + newIndices[axis] = originalValue; } } return newIndices; @@ -103,17 +130,22 @@ export function startIndicesWithElidedDims( // Creates full selection at the elided dimensions. If the dimension matches // the ellipsis mask, override the current stop value. Otherwise, insert. export function stopIndicesWithElidedDims( - stopIndices: number[], ellipsisInsertionIndex: number, - numElidedAxes: number, inputShape: number[]): number[] { - const newIndices = [...stopIndices]; - for (let i = 0; i < numElidedAxes; i++) { - if (i === 0) { - newIndices[ellipsisInsertionIndex] = Number.MAX_SAFE_INTEGER; + endMask: number, ellipsisInsertionIndex: number, numElidedAxes: number, + originalEnd: number[], inputShape: number[]): number[] { + const newIndices = [...inputShape]; + const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex); + + for (let axis = 0; axis < newIndices.length; axis++) { + if (elidedAxes.indexOf(axis) > -1) { + newIndices[axis] = Number.MAX_SAFE_INTEGER; } else { - newIndices.splice( - ellipsisInsertionIndex, 0 /* num elements to delete */, - Number.MAX_SAFE_INTEGER /* element to add */); - newIndices.pop(); + const originalAxis = + unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis); + let originalValue = originalEnd[originalAxis]; + if (endMask & 1 << originalAxis) { + originalValue = Number.MAX_SAFE_INTEGER; + } + newIndices[axis] = originalValue; } } diff --git a/tfjs-core/src/ops/strided_slice.ts b/tfjs-core/src/ops/strided_slice.ts index 69f59ef7835..9ad1b855e60 100644 --- a/tfjs-core/src/ops/strided_slice.ts +++ b/tfjs-core/src/ops/strided_slice.ts @@ -93,24 +93,26 @@ function stridedSlice_( $x = $x.reshape(newShape); // Normalize the start, end and strides. - for (let axis = 0; axis < $x.rank; axis++) { - begin[axis] = - startForAxis(beginMask, begin, strides, $x.shape, axis, ellipsisMask); - end[axis] = - stopForAxis(endMask, end, strides, $x.shape, axis, ellipsisMask); - strides[axis] = stridesForAxis(strides, axis, ellipsisMask); - } - if (ellipsisAxes.length && numInterpolatedAxes > 0) { const fullIndex = ellipsisAxes[0]; - // The ellipsis applies to the masked index as well as any dimensions that - // were interpolated as full selection. + // The ellipsis applies to the masked index as well as any dimensions + // that are interpolated. const numElidedAxes = numInterpolatedAxes + 1; - - begin = startIndicesWithElidedDims(begin, fullIndex, numElidedAxes); - end = stopIndicesWithElidedDims(end, fullIndex, numElidedAxes, $x.shape); - strides = stridesWithElidedDims(strides, fullIndex, numElidedAxes); + begin = startIndicesWithElidedDims( + beginMask, fullIndex, numElidedAxes, begin, $x.shape); + end = stopIndicesWithElidedDims( + endMask, fullIndex, numElidedAxes, end, $x.shape); + strides = + stridesWithElidedDims(strides, fullIndex, numElidedAxes, $x.shape); + } else { + for (let axis = 0; axis < $x.rank; axis++) { + begin[axis] = + startForAxis(beginMask, begin, strides, $x.shape, axis, ellipsisMask); + end[axis] = + stopForAxis(endMask, end, strides, $x.shape, axis, ellipsisMask); + strides[axis] = stridesForAxis(strides, axis, ellipsisMask); + } } const shrinkAxes = maskToAxes(shrinkAxisMask); diff --git a/tfjs-core/src/ops/strided_slice_test.ts b/tfjs-core/src/ops/strided_slice_test.ts index cc077df6efd..49a1dc2754c 100644 --- a/tfjs-core/src/ops/strided_slice_test.ts +++ b/tfjs-core/src/ops/strided_slice_test.ts @@ -45,6 +45,14 @@ describeWithFlags('stridedSlice', ALL_ENVS, () => { expectArraysClose(await output.data(), [5, 6, 7, 8, 9, 10, 11, 11, 11, 11]); }); + it('with ellipsisMask=1, begin / end masks and start / end normalization', + async () => { + const t = tf.randomNormal([1, 6, 2006, 4]); + const output = + tf.stridedSlice(t, [0, 0, 0], [0, 2004, 0], [1, 1, 1], 6, 4, 1); + expect(output.shape).toEqual([1, 6, 2004, 4]); + }); + it('with ellipsisMask=1 and start / end normalization', async () => { const t = tf.tensor3d([ [[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]