diff --git a/tfjs-core/src/backends/backend.ts b/tfjs-core/src/backends/backend.ts index 8d01d9df27b..7f5a14c4328 100644 --- a/tfjs-core/src/backends/backend.ts +++ b/tfjs-core/src/backends/backend.ts @@ -141,9 +141,7 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { throw new Error('Not yet implemented'); } stridedSlice( - x: T, begin: number[], end: number[], strides: number[], - beginMask: number, endMask: number, ellipsisMask: number, - newAxisMask: number, shrinkAxisMask: number): T { + x: T, begin: number[], end: number[], strides: number[]): T { throw new Error('Not yet implemented'); } unstack(x: Tensor, axis: number): Tensor[] { diff --git a/tfjs-core/src/backends/cpu/backend_cpu.ts b/tfjs-core/src/backends/cpu/backend_cpu.ts index 85a1a8135c6..f3da154b2db 100644 --- a/tfjs-core/src/backends/cpu/backend_cpu.ts +++ b/tfjs-core/src/backends/cpu/backend_cpu.ts @@ -33,7 +33,7 @@ import * as ops from '../../ops/ops'; import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops'; import * as scatter_nd_util from '../../ops/scatter_nd_util'; import * as selu_util from '../../ops/selu_util'; -import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util'; +import {computeFlatOffset, computeOutShape, isSliceContinous} from '../../ops/slice_util'; import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor'; import {BackendValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; @@ -313,34 +313,28 @@ export class MathBackendCPU implements KernelBackend { } stridedSlice( - x: T, begin: number[], end: number[], strides: number[], - beginMask: number, endMask: number, ellipsisMask: number, - newAxisMask: number, shrinkAxisMask: number): T { + x: T, begin: number[], end: number[], strides: number[]): T { this.assertNotComplex(x, 'stridedSlice'); - const [beginIndex, size, shrinkAxis] = getStridedSlicedInfo( - x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask); + const outShape = computeOutShape(begin, end, strides); - const shape = size.filter((v, index) => shrinkAxis.indexOf(index) === -1); - - if (shape.some(axis => axis === 0)) { - return ops.tensor([], shape) as T; + if (outShape.some(axis => axis === 0)) { + return ops.tensor([], outShape) as T; } - const buffer = ops.buffer(size, x.dtype); + const buffer = ops.buffer(outShape, x.dtype); const xBuf = this.bufferSync(x); for (let i = 0; i < buffer.size; i++) { const loc = buffer.indexToLoc(i); const newLoc: number[] = new Array(loc.length); for (let j = 0; j < newLoc.length; j++) { - newLoc[j] = loc[j] * strides[j] + beginIndex[j]; + newLoc[j] = loc[j] * strides[j] + begin[j]; } buffer.set(xBuf.get(...newLoc), ...loc); } - return buffer.toTensor().reshape(shape) as T; + return buffer.toTensor() as T; } diag(x: Tensor): Tensor { diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index 77eb28c04df..bde0e5322d8 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -34,7 +34,7 @@ import * as gather_nd_util from '../../ops/gather_nd_util'; import * as reduce_util from '../../ops/reduce_util'; import * as scatter_nd_util from '../../ops/scatter_nd_util'; import * as segment_util from '../../ops/segment_util'; -import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util'; +import * as slice_util from '../../ops/slice_util'; import {softmax} from '../../ops/softmax'; import {range, scalar, tensor} from '../../ops/tensor_ops'; import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor'; @@ -729,7 +729,7 @@ export class MathBackendWebGL implements KernelBackend { return tensor([], size, x.dtype) as T; } const {isPacked} = this.texData.get(x.dataId); - const isContinous = isSliceContinous(x.shape, begin, size); + const isContinous = slice_util.isSliceContinous(x.shape, begin, size); if (isPacked || !isContinous) { const program = ENV.getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? new SlicePackedProgram(size) : @@ -749,7 +749,7 @@ export class MathBackendWebGL implements KernelBackend { Object.assign(newTexData, xTexData); newTexData.shape = size; newTexData.dtype = x.dtype; - let flatOffset = computeFlatOffset(begin, x.strides); + let flatOffset = slice_util.computeFlatOffset(begin, x.strides); if (xTexData.slice) { // We are slicing an already sliced tensor, so we have to accumulate // the offset. @@ -769,26 +769,18 @@ export class MathBackendWebGL implements KernelBackend { } stridedSlice( - x: T, begin: number[], end: number[], strides: number[], - beginMask: number, endMask: number, ellipsisMask: number, - newAxisMask: number, shrinkAxisMask: number): T { + x: T, begin: number[], end: number[], strides: number[]): T { if (this.shouldExecuteOnCPU([x])) { - return this.cpuBackend.stridedSlice( - x, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, - shrinkAxisMask); + return this.cpuBackend.stridedSlice(x, begin, end, strides); } - const [beginIndex, size, shrinkAxis] = getStridedSlicedInfo( - x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask); + const outShape = slice_util.computeOutShape(begin, end, strides); - const shape = size.filter((v, index) => shrinkAxis.indexOf(index) === -1); - if (shape.some(axis => axis === 0)) { - return tensor([], shape) as T; + if (outShape.some(axis => axis === 0)) { + return tensor([], outShape) as T; } - const program = - new StridedSliceProgram(beginIndex, strides, size, shrinkAxis); + const program = new StridedSliceProgram(begin, strides, outShape); return this.compileAndRun(program, [x]); } diff --git a/tfjs-core/src/backends/webgl/strided_slice_gpu.ts b/tfjs-core/src/backends/webgl/strided_slice_gpu.ts index a341afbb5a9..192a7072a22 100644 --- a/tfjs-core/src/backends/webgl/strided_slice_gpu.ts +++ b/tfjs-core/src/backends/webgl/strided_slice_gpu.ts @@ -24,13 +24,11 @@ export class StridedSliceProgram implements GPGPUProgram { userCode: string; constructor( - begin: number[], strides: number[], size: number[], - shrinkAxis: number[]) { - const shape = size.filter((v, index) => shrinkAxis.indexOf(index) === -1); - this.outputShape = shape; + begin: number[], strides: number[], size: number[]) { + this.outputShape = size; const rank = size.length; const inputDtype = getCoordsDataType(size.length); - const dtype = getCoordsDataType(shape.length); + const dtype = getCoordsDataType(size.length); let newCoords = ''; if (rank === 1) { @@ -39,14 +37,10 @@ export class StridedSliceProgram implements GPGPUProgram { let outputAxis = 0; newCoords = size.map((_, i) => { - if (shrinkAxis.indexOf(i) === -1) { - outputAxis++; - return shape.length === 1 ? - `coords * strides[${i}] + begin[${i}]` : - `coords[${outputAxis - 1}] * strides[${i}] + begin[${i}]`; - } else { - return `begin[${i}]`; - } + outputAxis++; + return size.length === 1 ? + `coords * strides[${i}] + begin[${i}]` : + `coords[${outputAxis - 1}] * strides[${i}] + begin[${i}]`; }) .join(','); } diff --git a/tfjs-core/src/ops/slice_util.ts b/tfjs-core/src/ops/slice_util.ts index b5e1a2adca1..bdc2bdf864b 100644 --- a/tfjs-core/src/ops/slice_util.ts +++ b/tfjs-core/src/ops/slice_util.ts @@ -38,49 +38,28 @@ export function assertParamsValid( } } -/** - * Calculate the start index and output tensor shape for strided slice op. - * @returns array of [startIndex, size, shrinkAxis] - */ -export function getStridedSlicedInfo( - shape: number[], begin: number[], end: number[], strides: number[], - beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0, - shrinkAxisMask = 0): [number[], number[], number[]] { - if (ellipsisMask !== 0) { - throw new Error('ellipsis mask is not yet supported'); - } - if (newAxisMask !== 0) { - throw new Error('new axis mask is not yet supported'); - } - // Note that the axis orders are reversed for runtime ops, so the indices, - // strides and masks must be as well too. - const startIndex: number[] = []; - const endIndex: number[] = []; - const shrinkAxis: number[] = []; - for (let i = 0; i < shape.length; i++) { - startIndex[i] = startForAxis(beginMask, begin, strides, shape, i); - endIndex[i] = stopForAxis(endMask, end, strides, shape, i); - // When shrinking an axis, use startIndex + 1 for endIndex. - // Check the axis bit from right of shrinkAxisMask - if (shrinkAxisMask & 1 << i) { - endIndex[i] = startIndex[i] + 1; - shrinkAxis.push(i); +/** Converts a binary mask to an array of axes. Used in stridedSlice(). */ +export function maskToAxes(mask: number): number[] { + const axes = []; + let axis = 0; + while (mask > 0) { + if (mask & 1) { + axes.push(axis); } + mask /= 2; + axis++; } + return axes; +} - let size = new Array(shape.length).fill(0); - size = size.map((d, i) => { - let count = 0; - const stride = strides[i] || 1; - for (let start = startIndex[i]; - !(stride > 0 ? start >= endIndex[i] : start <= endIndex[i]); - start += stride) { - count += 1; - } - return count; - }); - - return [startIndex, size, shrinkAxis]; +/** Computes the output shape given the strided slice params. */ +export function computeOutShape( + begin: number[], end: number[], strides: number[]): number[] { + const size = []; + for (let axis = 0; axis < begin.length; axis++) { + size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]); + } + return size; } export function startForAxis( diff --git a/tfjs-core/src/ops/strided_slice.ts b/tfjs-core/src/ops/strided_slice.ts index 4bb37c38a68..40bddb48194 100644 --- a/tfjs-core/src/ops/strided_slice.ts +++ b/tfjs-core/src/ops/strided_slice.ts @@ -19,9 +19,10 @@ import {ENGINE} from '../engine'; import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; + import {op} from './operation'; import {slice} from './slice'; -import {getStridedSlicedInfo} from './slice_util'; +import {computeOutShape, maskToAxes, startForAxis, stopForAxis} from './slice_util'; /** * Extracts a strided slice of a tensor. @@ -56,30 +57,53 @@ import {getStridedSlicedInfo} from './slice_util'; */ /** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */ function stridedSlice_( - x: Tensor|TensorLike, begin: number[], end: number[], strides: number[], + x: Tensor|TensorLike, begin: number[], end: number[], strides?: number[], beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0, shrinkAxisMask = 0): Tensor { + if (strides == null) { + strides = new Array(begin.length); + } if (ellipsisMask !== 0) { throw new Error('ellipsis mask is not yet supported'); } - if (newAxisMask !== 0) { - throw new Error('new axis mask is not yet supported'); + let $x = convertToTensor(x, 'x', 'stridedSlice'); + + // Expand the dims of x based on the newAxisMask. + const expandAxes = maskToAxes(newAxisMask); + const newShape = $x.shape.slice(); + expandAxes.forEach(axis => { + begin[axis] = 0; + end[axis] = 1; + newShape.splice(axis, 0, 1); + }); + $x = $x.reshape(newShape); + + // Normalize the start, end and strides. + for (let axis = 0; axis < $x.rank; axis++) { + begin[axis] = startForAxis(beginMask, begin, strides, $x.shape, axis); + end[axis] = stopForAxis(endMask, end, strides, $x.shape, axis); + strides[axis] = strides[axis] || 1; } - const $x = convertToTensor(x, 'x', 'stridedSlice'); + + const shrinkAxes = maskToAxes(shrinkAxisMask); + // Adjust the ends based on the shrink mask. + shrinkAxes.forEach(axis => { + end[axis] = begin[axis] + 1; + strides[axis] = 1; + }); + + // Figure out the output shape. + const size = computeOutShape(begin, end, strides); + // Remove the axes based on shrinkMask. + const outShape = size.filter((_, axis) => shrinkAxes.indexOf(axis) === -1); + const nonStrided = strides.every(v => v === 1); if (nonStrided) { - const [beginIndex, size, shrinkAxis] = getStridedSlicedInfo( - $x.shape, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask); - const outShape = - size.filter((_, index) => shrinkAxis.indexOf(index) === -1); - return slice($x, beginIndex, size).reshape(outShape); + return slice($x, begin, size).reshape(outShape); } - return ENGINE.runKernel( - backend => backend.stridedSlice( - $x, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask), - {$x}); + const res = ENGINE.runKernel( + backend => backend.stridedSlice($x, begin, end, strides), {$x}); + return res.reshape(outShape); } export const stridedSlice = op({stridedSlice_}); diff --git a/tfjs-core/src/ops/strided_slice_test.ts b/tfjs-core/src/ops/strided_slice_test.ts index 7ef7ea96320..84e39598c0d 100644 --- a/tfjs-core/src/ops/strided_slice_test.ts +++ b/tfjs-core/src/ops/strided_slice_test.ts @@ -20,9 +20,53 @@ import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose} from '../test_util'; describeWithFlags('stridedSlice', ALL_ENVS, () => { - it('stridedSlice should fail if new axis mask is set', () => { - const tensor = tf.tensor1d([0, 1, 2, 3]); - expect(() => tf.stridedSlice(tensor, [0], [3], [2], 0, 0, 0, 1)).toThrow(); + it('stridedSlice with first axis being new', async () => { + // Python slice code: t[tf.newaxis,0:3] + const t = tf.tensor1d([0, 1, 2, 3]); + const begin = [0, 0]; + const end = [1, 3]; + const strides = [1, 2]; + const beginMask = 0; + const endMask = 0; + const ellipsisMask = 0; + const newAxisMask = 1; + + const output = tf.stridedSlice( + t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask); + expect(output.shape).toEqual([1, 2]); + expectArraysClose(await output.data(), [0, 2]); + }); + + it('strided slice with several new axes', () => { + // Python slice code: t[1:2,tf.newaxis,0:3,tf.newaxis,2:5] + const t = tf.zeros([2, 3, 4, 5]); + const begin = [1, 0, 0, 0, 2]; + const end = [2, 1, 3, 1, 5]; + const strides: number[] = null; + const beginMask = 0; + const endMask = 0; + const ellipsisMask = 0; + const newAxisMask = 0b1010; + const output = tf.stridedSlice( + t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask); + expect(output.shape).toEqual([1, 1, 3, 1, 2, 5]); + }); + + it('strided slice with new axes and shrink axes', () => { + // Python slice code: t[1:2,tf.newaxis,1,tf.newaxis,2,2:5] + const t = tf.zeros([2, 3, 4, 5]); + const begin = [1, 0, 1, 0, 2, 2]; + const end = [2, 1, 2, 1, 3, 5]; + const strides: number[] = null; + const beginMask = 0; + const endMask = 0; + const ellipsisMask = 0; + const newAxisMask = 0b1010; + const shrinkAxisMask = 0b10100; + const output = tf.stridedSlice( + t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, + shrinkAxisMask); + expect(output.shape).toEqual([1, 1, 1, 3]); }); it('stridedSlice should fail if ellipsis mask is set', () => { diff --git a/tfjs-core/src/tensor_test.ts b/tfjs-core/src/tensor_test.ts index 5f92b68ae7c..54dfe13cf4c 100644 --- a/tfjs-core/src/tensor_test.ts +++ b/tfjs-core/src/tensor_test.ts @@ -1461,6 +1461,14 @@ describeWithFlags('tensor', ALL_ENVS, () => { expect(res.shape).toEqual([0, 0]); }); + it('squeeze can take an empty list of axis', () => { + const a = tf.zeros([2, 1, 3, 1, 4]); + const axes: number[] = []; + // Empty axes list means all possible axes. + const res = tf.squeeze(a, axes); + expect(res.shape).toEqual([2, 3, 4]); + }); + it('squeeze a complex64 tensor', async () => { const a = tf.complex([[4], [1], [5]], [[2], [3], [6]]); const b = a.squeeze(); diff --git a/tfjs-core/src/util.ts b/tfjs-core/src/util.ts index e36cbd0e5b9..2553f49c32d 100644 --- a/tfjs-core/src/util.ts +++ b/tfjs-core/src/util.ts @@ -349,7 +349,10 @@ export function squeezeShape(shape: number[], axis?: number[]): {newShape: number[], keptDims: number[]} { const newShape: number[] = []; const keptDims: number[] = []; - const axes = axis == null ? null : parseAxisParam(axis, shape).sort(); + const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0; + const axes = (axis == null || isEmptyArray) ? + null : + parseAxisParam(axis, shape).sort(); let j = 0; for (let i = 0; i < shape.length; ++i) { if (axes != null) {