From efbcb091d90cbff408a242689b462a9fb5b91b25 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 22 Jun 2020 11:42:38 -0400 Subject: [PATCH 1/5] works --- tfjs-core/src/ops/slice_util.ts | 80 +++++++++++++++++++------ tfjs-core/src/ops/strided_slice.ts | 14 +++-- tfjs-core/src/ops/strided_slice_test.ts | 10 +++- 3 files changed, 80 insertions(+), 24 deletions(-) diff --git a/tfjs-core/src/ops/slice_util.ts b/tfjs-core/src/ops/slice_util.ts index 7545f02a122..094c635094a 100644 --- a/tfjs-core/src/ops/slice_util.ts +++ b/tfjs-core/src/ops/slice_util.ts @@ -81,20 +81,51 @@ export function stridesWithElidedDims( return newStrides; } +function unnormalizeAxis( + ellipsisInsertionIndex: number, numElidedAxes: number, + normalizedAxis: number): number { + if (normalizedAxis <= ellipsisInsertionIndex) { + return normalizedAxis; + } + + return normalizedAxis - (numElidedAxes - 1); +} + +function getElidedAxes(numElidedAxes: number, ellipsisInsertionIndex: number) { + const elidedAxes = []; + for (let i = 0; i < numElidedAxes; i++) { + elidedAxes.push(ellipsisInsertionIndex + i); + } + return elidedAxes; +} + // Creates full selection at the elided dimensions. If the dimension matches // the ellipsis mask, override the current start value. Otherwise, insert. export function startIndicesWithElidedDims( - startIndices: number[], ellipsisInsertionIndex: number, - numElidedAxes: number): number[] { + startIndices: number[], beginMask: number, ellipsisInsertionIndex: number, + numElidedAxes: number, originalBegin: number[]): number[] { const newIndices = [...startIndices]; - for (let i = 0; i < numElidedAxes; i++) { - if (i === 0) { - newIndices[ellipsisInsertionIndex] = 0; + const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex); + + for (let axis = 0; axis < newIndices.length; axis++) { + if (elidedAxes.indexOf(axis) > -1) { + if (axis === ellipsisInsertionIndex) { + newIndices[ellipsisInsertionIndex] = 0; + } else { + newIndices.splice( + ellipsisInsertionIndex, 0 /* num elements to delete */, + 0 /* element to add */); + newIndices.pop(); + } } else { - newIndices.splice( - ellipsisInsertionIndex, 0 /* num elements to delete */, - 0 /* element to add */); - newIndices.pop(); + const originalAxis = + unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis); + let originalValue = originalBegin[originalAxis]; + if (beginMask & 1 << originalAxis) { + originalValue = 0; + } + + newIndices[axis] = originalValue; } } return newIndices; @@ -103,17 +134,30 @@ export function startIndicesWithElidedDims( // Creates full selection at the elided dimensions. If the dimension matches // the ellipsis mask, override the current stop value. Otherwise, insert. export function stopIndicesWithElidedDims( - stopIndices: number[], ellipsisInsertionIndex: number, - numElidedAxes: number, inputShape: number[]): number[] { + stopIndices: number[], endMask: number, ellipsisInsertionIndex: number, + numElidedAxes: number, inputShape: number[], + originalEnd: number[]): number[] { const newIndices = [...stopIndices]; - for (let i = 0; i < numElidedAxes; i++) { - if (i === 0) { - newIndices[ellipsisInsertionIndex] = Number.MAX_SAFE_INTEGER; + const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex); + + for (let axis = 0; axis < newIndices.length; axis++) { + if (elidedAxes.indexOf(axis) > -1) { + if (axis === ellipsisInsertionIndex) { + newIndices[ellipsisInsertionIndex] = Number.MAX_SAFE_INTEGER; + } else { + newIndices.splice( + ellipsisInsertionIndex, 0 /* num elements to delete */, + Number.MAX_SAFE_INTEGER /* element to add */); + newIndices.pop(); + } } else { - newIndices.splice( - ellipsisInsertionIndex, 0 /* num elements to delete */, - Number.MAX_SAFE_INTEGER /* element to add */); - newIndices.pop(); + const originalAxis = + unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis); + let originalValue = originalEnd[originalAxis]; + if (endMask & 1 << originalAxis) { + originalValue = Number.MAX_SAFE_INTEGER; + } + newIndices[axis] = originalValue; } } diff --git a/tfjs-core/src/ops/strided_slice.ts b/tfjs-core/src/ops/strided_slice.ts index e93977143b5..20e53415c90 100644 --- a/tfjs-core/src/ops/strided_slice.ts +++ b/tfjs-core/src/ops/strided_slice.ts @@ -64,6 +64,9 @@ function stridedSlice_( strides = new Array(begin.length); } + const originalBegin = begin.slice(0); + const originalEnd = end.slice(0); + const ellipsisAxes = maskToAxes(ellipsisMask); if (ellipsisAxes.length > 1) { throw new Error('Multiple ellipses in slice is not allowed.'); @@ -104,12 +107,13 @@ function stridedSlice_( if (ellipsisAxes.length && numInterpolatedAxes > 0) { const fullIndex = ellipsisAxes[0]; - // The ellipsis applies to the masked index as well as any dimensions that - // were interpolated as full selection. + // The ellipsis applies to the masked index as well as any dimensions + // that were interpolated as full selection. const numElidedAxes = numInterpolatedAxes + 1; - - begin = startIndicesWithElidedDims(begin, fullIndex, numElidedAxes); - end = stopIndicesWithElidedDims(end, fullIndex, numElidedAxes, $x.shape); + begin = startIndicesWithElidedDims( + begin, beginMask, fullIndex, numElidedAxes, originalBegin); + end = stopIndicesWithElidedDims( + end, endMask, fullIndex, numElidedAxes, $x.shape, originalEnd); strides = stridesWithElidedDims(strides, fullIndex, numElidedAxes); } diff --git a/tfjs-core/src/ops/strided_slice_test.ts b/tfjs-core/src/ops/strided_slice_test.ts index 32e2db12378..e0f5e66f047 100644 --- a/tfjs-core/src/ops/strided_slice_test.ts +++ b/tfjs-core/src/ops/strided_slice_test.ts @@ -45,10 +45,18 @@ describeWithFlags('stridedSlice', ALL_ENVS, () => { expectArraysClose(await output.data(), [5, 6, 7, 8, 9, 10, 11, 11, 11, 11]); }); + it('with ellipsisMask=1, begin / end masks and start / end normalization', + async () => { + const t = tf.randomNormal([1, 6, 2006, 4]); + const output = + tf.stridedSlice(t, [0, 0, 0], [0, 2004, 0], [1, 1, 1], 6, 4, 1); + expect(output.shape).toEqual([1, 6, 2004, 4]); + }); + it('with ellipsisMask=1 and start / end normalization', async () => { const t = tf.tensor3d([ [[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]] - ]); + ]); // [3, 2, 3] const begin = [1, 0]; const end = [2, 1]; const strides = [1, 1]; From e1d0ed6d9bbfca5904baf870a49b6d933cb5883c Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 22 Jun 2020 11:56:51 -0400 Subject: [PATCH 2/5] simplify --- tfjs-core/src/ops/slice_util.ts | 13 ++++++------- tfjs-core/src/ops/strided_slice.ts | 31 +++++++++++++++--------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/tfjs-core/src/ops/slice_util.ts b/tfjs-core/src/ops/slice_util.ts index 094c635094a..457d41b2ab4 100644 --- a/tfjs-core/src/ops/slice_util.ts +++ b/tfjs-core/src/ops/slice_util.ts @@ -102,9 +102,9 @@ function getElidedAxes(numElidedAxes: number, ellipsisInsertionIndex: number) { // Creates full selection at the elided dimensions. If the dimension matches // the ellipsis mask, override the current start value. Otherwise, insert. export function startIndicesWithElidedDims( - startIndices: number[], beginMask: number, ellipsisInsertionIndex: number, - numElidedAxes: number, originalBegin: number[]): number[] { - const newIndices = [...startIndices]; + beginMask: number, ellipsisInsertionIndex: number, numElidedAxes: number, + originalBegin: number[], inputShape: number[]): number[] { + const newIndices = [...inputShape]; const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex); for (let axis = 0; axis < newIndices.length; axis++) { @@ -134,10 +134,9 @@ export function startIndicesWithElidedDims( // Creates full selection at the elided dimensions. If the dimension matches // the ellipsis mask, override the current stop value. Otherwise, insert. export function stopIndicesWithElidedDims( - stopIndices: number[], endMask: number, ellipsisInsertionIndex: number, - numElidedAxes: number, inputShape: number[], - originalEnd: number[]): number[] { - const newIndices = [...stopIndices]; + endMask: number, ellipsisInsertionIndex: number, numElidedAxes: number, + originalEnd: number[], inputShape: number[]): number[] { + const newIndices = [...inputShape]; const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex); for (let axis = 0; axis < newIndices.length; axis++) { diff --git a/tfjs-core/src/ops/strided_slice.ts b/tfjs-core/src/ops/strided_slice.ts index 20e53415c90..ea1bc39f308 100644 --- a/tfjs-core/src/ops/strided_slice.ts +++ b/tfjs-core/src/ops/strided_slice.ts @@ -64,9 +64,6 @@ function stridedSlice_( strides = new Array(begin.length); } - const originalBegin = begin.slice(0); - const originalEnd = end.slice(0); - const ellipsisAxes = maskToAxes(ellipsisMask); if (ellipsisAxes.length > 1) { throw new Error('Multiple ellipses in slice is not allowed.'); @@ -95,26 +92,30 @@ function stridedSlice_( }); $x = $x.reshape(newShape); - // Normalize the start, end and strides. - for (let axis = 0; axis < $x.rank; axis++) { - begin[axis] = - startForAxis(beginMask, begin, strides, $x.shape, axis, ellipsisMask); - end[axis] = - stopForAxis(endMask, end, strides, $x.shape, axis, ellipsisMask); - strides[axis] = stridesForAxis(strides, axis, ellipsisMask); - } - if (ellipsisAxes.length && numInterpolatedAxes > 0) { const fullIndex = ellipsisAxes[0]; // The ellipsis applies to the masked index as well as any dimensions - // that were interpolated as full selection. + // that are interpolated. const numElidedAxes = numInterpolatedAxes + 1; begin = startIndicesWithElidedDims( - begin, beginMask, fullIndex, numElidedAxes, originalBegin); + beginMask, fullIndex, numElidedAxes, begin, $x.shape); end = stopIndicesWithElidedDims( - end, endMask, fullIndex, numElidedAxes, $x.shape, originalEnd); + endMask, fullIndex, numElidedAxes, end, $x.shape); + + for (let axis = 0; axis < $x.rank; axis++) { + strides[axis] = stridesForAxis(strides, axis, ellipsisMask); + } strides = stridesWithElidedDims(strides, fullIndex, numElidedAxes); + } else { + // Normalize the start, end and strides. + for (let axis = 0; axis < $x.rank; axis++) { + begin[axis] = + startForAxis(beginMask, begin, strides, $x.shape, axis, ellipsisMask); + end[axis] = + stopForAxis(endMask, end, strides, $x.shape, axis, ellipsisMask); + strides[axis] = stridesForAxis(strides, axis, ellipsisMask); + } } const shrinkAxes = maskToAxes(shrinkAxisMask); From 3d80d70b9ed028a9ae6a8204266de3e507b5697e Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 22 Jun 2020 12:02:28 -0400 Subject: [PATCH 3/5] clean --- tfjs-core/src/ops/slice_util.ts | 7 +++++-- tfjs-core/src/ops/strided_slice.ts | 9 +++------ tfjs-core/src/ops/strided_slice_test.ts | 8 +++++--- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tfjs-core/src/ops/slice_util.ts b/tfjs-core/src/ops/slice_util.ts index 457d41b2ab4..bc536ff9cd6 100644 --- a/tfjs-core/src/ops/slice_util.ts +++ b/tfjs-core/src/ops/slice_util.ts @@ -65,9 +65,12 @@ export function computeOutShape( // Creates full selection at the elided dimensions. If the dimension matches // the ellipsis mask, override the current stride value. Otherwise, insert. export function stridesWithElidedDims( - strides: number[], ellipsisInsertionIndex: number, - numElidedAxes: number): number[] { + strides: number[], ellipsisInsertionIndex: number, numElidedAxes: number, + inputShape: number[]): number[] { const newStrides = [...strides]; + for (let i = newStrides.length; i < inputShape.length; i++) { + newStrides.push(1); + } for (let i = 0; i < numElidedAxes; i++) { if (i === 0) { newStrides[ellipsisInsertionIndex] = 1; diff --git a/tfjs-core/src/ops/strided_slice.ts b/tfjs-core/src/ops/strided_slice.ts index ea1bc39f308..4d1a5a3911b 100644 --- a/tfjs-core/src/ops/strided_slice.ts +++ b/tfjs-core/src/ops/strided_slice.ts @@ -92,6 +92,7 @@ function stridedSlice_( }); $x = $x.reshape(newShape); + // Normalize the start, end and strides. if (ellipsisAxes.length && numInterpolatedAxes > 0) { const fullIndex = ellipsisAxes[0]; @@ -102,13 +103,9 @@ function stridedSlice_( beginMask, fullIndex, numElidedAxes, begin, $x.shape); end = stopIndicesWithElidedDims( endMask, fullIndex, numElidedAxes, end, $x.shape); - - for (let axis = 0; axis < $x.rank; axis++) { - strides[axis] = stridesForAxis(strides, axis, ellipsisMask); - } - strides = stridesWithElidedDims(strides, fullIndex, numElidedAxes); + strides = + stridesWithElidedDims(strides, fullIndex, numElidedAxes, $x.shape); } else { - // Normalize the start, end and strides. for (let axis = 0; axis < $x.rank; axis++) { begin[axis] = startForAxis(beginMask, begin, strides, $x.shape, axis, ellipsisMask); diff --git a/tfjs-core/src/ops/strided_slice_test.ts b/tfjs-core/src/ops/strided_slice_test.ts index e0f5e66f047..3b128ccfb2f 100644 --- a/tfjs-core/src/ops/strided_slice_test.ts +++ b/tfjs-core/src/ops/strided_slice_test.ts @@ -54,9 +54,11 @@ describeWithFlags('stridedSlice', ALL_ENVS, () => { }); it('with ellipsisMask=1 and start / end normalization', async () => { - const t = tf.tensor3d([ - [[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]] - ]); // [3, 2, 3] + const t = tf.tensor3d( + [ + [[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]] + ], + [3, 2, 3]); const begin = [1, 0]; const end = [2, 1]; const strides = [1, 1]; From 84fc723f63081f6dddd0eafb08793a02d6ea9acf Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 22 Jun 2020 12:04:18 -0400 Subject: [PATCH 4/5] clean --- tfjs-core/src/ops/strided_slice_test.ts | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tfjs-core/src/ops/strided_slice_test.ts b/tfjs-core/src/ops/strided_slice_test.ts index 3b128ccfb2f..6a1d42e217c 100644 --- a/tfjs-core/src/ops/strided_slice_test.ts +++ b/tfjs-core/src/ops/strided_slice_test.ts @@ -54,11 +54,9 @@ describeWithFlags('stridedSlice', ALL_ENVS, () => { }); it('with ellipsisMask=1 and start / end normalization', async () => { - const t = tf.tensor3d( - [ - [[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]] - ], - [3, 2, 3]); + const t = tf.tensor3d([ + [[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]] + ]); const begin = [1, 0]; const end = [2, 1]; const strides = [1, 1]; From fcc27d3fa69f5e6e328157f2d1f95844cee4914f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 23 Jun 2020 08:41:44 -0400 Subject: [PATCH 5/5] clean --- tfjs-core/src/ops/slice_util.ts | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/tfjs-core/src/ops/slice_util.ts b/tfjs-core/src/ops/slice_util.ts index bc536ff9cd6..999a81eecb5 100644 --- a/tfjs-core/src/ops/slice_util.ts +++ b/tfjs-core/src/ops/slice_util.ts @@ -112,14 +112,7 @@ export function startIndicesWithElidedDims( for (let axis = 0; axis < newIndices.length; axis++) { if (elidedAxes.indexOf(axis) > -1) { - if (axis === ellipsisInsertionIndex) { - newIndices[ellipsisInsertionIndex] = 0; - } else { - newIndices.splice( - ellipsisInsertionIndex, 0 /* num elements to delete */, - 0 /* element to add */); - newIndices.pop(); - } + newIndices[axis] = 0; } else { const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis); @@ -144,14 +137,7 @@ export function stopIndicesWithElidedDims( for (let axis = 0; axis < newIndices.length; axis++) { if (elidedAxes.indexOf(axis) > -1) { - if (axis === ellipsisInsertionIndex) { - newIndices[ellipsisInsertionIndex] = Number.MAX_SAFE_INTEGER; - } else { - newIndices.splice( - ellipsisInsertionIndex, 0 /* num elements to delete */, - Number.MAX_SAFE_INTEGER /* element to add */); - newIndices.pop(); - } + newIndices[axis] = Number.MAX_SAFE_INTEGER; } else { const originalAxis = unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);