Skip to content
76 changes: 54 additions & 22 deletions tfjs-core/src/ops/slice_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,12 @@ export function computeOutShape(
// Creates full selection at the elided dimensions. If the dimension matches
// the ellipsis mask, override the current stride value. Otherwise, insert.
export function stridesWithElidedDims(
strides: number[], ellipsisInsertionIndex: number,
numElidedAxes: number): number[] {
strides: number[], ellipsisInsertionIndex: number, numElidedAxes: number,
inputShape: number[]): number[] {
const newStrides = [...strides];
for (let i = newStrides.length; i < inputShape.length; i++) {
newStrides.push(1);
}
for (let i = 0; i < numElidedAxes; i++) {
if (i === 0) {
newStrides[ellipsisInsertionIndex] = 1;
Expand All @@ -81,20 +84,44 @@ export function stridesWithElidedDims(
return newStrides;
}

function unnormalizeAxis(
ellipsisInsertionIndex: number, numElidedAxes: number,
normalizedAxis: number): number {
if (normalizedAxis <= ellipsisInsertionIndex) {
return normalizedAxis;
}

return normalizedAxis - (numElidedAxes - 1);
}

function getElidedAxes(numElidedAxes: number, ellipsisInsertionIndex: number) {
const elidedAxes = [];
for (let i = 0; i < numElidedAxes; i++) {
elidedAxes.push(ellipsisInsertionIndex + i);
}
return elidedAxes;
}

// Creates full selection at the elided dimensions. If the dimension matches
// the ellipsis mask, override the current start value. Otherwise, insert.
export function startIndicesWithElidedDims(
startIndices: number[], ellipsisInsertionIndex: number,
numElidedAxes: number): number[] {
const newIndices = [...startIndices];
for (let i = 0; i < numElidedAxes; i++) {
if (i === 0) {
newIndices[ellipsisInsertionIndex] = 0;
beginMask: number, ellipsisInsertionIndex: number, numElidedAxes: number,
originalBegin: number[], inputShape: number[]): number[] {
const newIndices = [...inputShape];
const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);

for (let axis = 0; axis < newIndices.length; axis++) {
if (elidedAxes.indexOf(axis) > -1) {
newIndices[axis] = 0;
} else {
newIndices.splice(
ellipsisInsertionIndex, 0 /* num elements to delete */,
0 /* element to add */);
newIndices.pop();
const originalAxis =
unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
let originalValue = originalBegin[originalAxis];
if (beginMask & 1 << originalAxis) {
originalValue = 0;
}

newIndices[axis] = originalValue;
}
}
return newIndices;
Expand All @@ -103,17 +130,22 @@ export function startIndicesWithElidedDims(
// Creates full selection at the elided dimensions. If the dimension matches
// the ellipsis mask, override the current stop value. Otherwise, insert.
export function stopIndicesWithElidedDims(
stopIndices: number[], ellipsisInsertionIndex: number,
numElidedAxes: number, inputShape: number[]): number[] {
const newIndices = [...stopIndices];
for (let i = 0; i < numElidedAxes; i++) {
if (i === 0) {
newIndices[ellipsisInsertionIndex] = Number.MAX_SAFE_INTEGER;
endMask: number, ellipsisInsertionIndex: number, numElidedAxes: number,
originalEnd: number[], inputShape: number[]): number[] {
const newIndices = [...inputShape];
const elidedAxes = getElidedAxes(numElidedAxes, ellipsisInsertionIndex);

for (let axis = 0; axis < newIndices.length; axis++) {
if (elidedAxes.indexOf(axis) > -1) {
newIndices[axis] = Number.MAX_SAFE_INTEGER;
} else {
newIndices.splice(
ellipsisInsertionIndex, 0 /* num elements to delete */,
Number.MAX_SAFE_INTEGER /* element to add */);
newIndices.pop();
const originalAxis =
unnormalizeAxis(ellipsisInsertionIndex, numElidedAxes, axis);
let originalValue = originalEnd[originalAxis];
if (endMask & 1 << originalAxis) {
originalValue = Number.MAX_SAFE_INTEGER;
}
newIndices[axis] = originalValue;
}
}

Expand Down
30 changes: 16 additions & 14 deletions tfjs-core/src/ops/strided_slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,24 +93,26 @@ function stridedSlice_(
$x = $x.reshape(newShape);

// Normalize the start, end and strides.
for (let axis = 0; axis < $x.rank; axis++) {
begin[axis] =
startForAxis(beginMask, begin, strides, $x.shape, axis, ellipsisMask);
end[axis] =
stopForAxis(endMask, end, strides, $x.shape, axis, ellipsisMask);
strides[axis] = stridesForAxis(strides, axis, ellipsisMask);
}

if (ellipsisAxes.length && numInterpolatedAxes > 0) {
const fullIndex = ellipsisAxes[0];

// The ellipsis applies to the masked index as well as any dimensions that
// were interpolated as full selection.
// The ellipsis applies to the masked index as well as any dimensions
// that are interpolated.
const numElidedAxes = numInterpolatedAxes + 1;

begin = startIndicesWithElidedDims(begin, fullIndex, numElidedAxes);
end = stopIndicesWithElidedDims(end, fullIndex, numElidedAxes, $x.shape);
strides = stridesWithElidedDims(strides, fullIndex, numElidedAxes);
begin = startIndicesWithElidedDims(
beginMask, fullIndex, numElidedAxes, begin, $x.shape);
end = stopIndicesWithElidedDims(
endMask, fullIndex, numElidedAxes, end, $x.shape);
strides =
stridesWithElidedDims(strides, fullIndex, numElidedAxes, $x.shape);
} else {
for (let axis = 0; axis < $x.rank; axis++) {
begin[axis] =
startForAxis(beginMask, begin, strides, $x.shape, axis, ellipsisMask);
end[axis] =
stopForAxis(endMask, end, strides, $x.shape, axis, ellipsisMask);
strides[axis] = stridesForAxis(strides, axis, ellipsisMask);
}
}

const shrinkAxes = maskToAxes(shrinkAxisMask);
Expand Down
8 changes: 8 additions & 0 deletions tfjs-core/src/ops/strided_slice_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ describeWithFlags('stridedSlice', ALL_ENVS, () => {
expectArraysClose(await output.data(), [5, 6, 7, 8, 9, 10, 11, 11, 11, 11]);
});

it('with ellipsisMask=1, begin / end masks and start / end normalization',
async () => {
const t = tf.randomNormal([1, 6, 2006, 4]);
const output =
tf.stridedSlice(t, [0, 0, 0], [0, 2004, 0], [1, 1, 1], 6, 4, 1);
expect(output.shape).toEqual([1, 6, 2004, 4]);
});

it('with ellipsisMask=1 and start / end normalization', async () => {
const t = tf.tensor3d([
[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]
Expand Down