Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions tfjs-core/src/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,7 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
throw new Error('Not yet implemented');
}
stridedSlice<T extends Tensor>(
x: T, begin: number[], end: number[], strides: number[],
beginMask: number, endMask: number, ellipsisMask: number,
newAxisMask: number, shrinkAxisMask: number): T {
x: T, begin: number[], end: number[], strides: number[]): T {
throw new Error('Not yet implemented');
}
unstack(x: Tensor, axis: number): Tensor[] {
Expand Down
22 changes: 8 additions & 14 deletions tfjs-core/src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import * as ops from '../../ops/ops';
import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops';
import * as scatter_nd_util from '../../ops/scatter_nd_util';
import * as selu_util from '../../ops/selu_util';
import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util';
import {computeFlatOffset, computeOutShape, isSliceContinous} from '../../ops/slice_util';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';
import {BackendValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types';
import * as util from '../../util';
Expand Down Expand Up @@ -313,34 +313,28 @@ export class MathBackendCPU implements KernelBackend {
}

stridedSlice<T extends Tensor>(
x: T, begin: number[], end: number[], strides: number[],
beginMask: number, endMask: number, ellipsisMask: number,
newAxisMask: number, shrinkAxisMask: number): T {
x: T, begin: number[], end: number[], strides: number[]): T {
this.assertNotComplex(x, 'stridedSlice');

const [beginIndex, size, shrinkAxis] = getStridedSlicedInfo(
x.shape, begin, end, strides, beginMask, endMask, ellipsisMask,
newAxisMask, shrinkAxisMask);
const outShape = computeOutShape(begin, end, strides);

const shape = size.filter((v, index) => shrinkAxis.indexOf(index) === -1);

if (shape.some(axis => axis === 0)) {
return ops.tensor([], shape) as T;
if (outShape.some(axis => axis === 0)) {
return ops.tensor([], outShape) as T;
}

const buffer = ops.buffer(size, x.dtype);
const buffer = ops.buffer(outShape, x.dtype);
const xBuf = this.bufferSync(x);
for (let i = 0; i < buffer.size; i++) {
const loc = buffer.indexToLoc(i);

const newLoc: number[] = new Array(loc.length);
for (let j = 0; j < newLoc.length; j++) {
newLoc[j] = loc[j] * strides[j] + beginIndex[j];
newLoc[j] = loc[j] * strides[j] + begin[j];
}
buffer.set(xBuf.get(...newLoc), ...loc);
}

return buffer.toTensor().reshape(shape) as T;
return buffer.toTensor() as T;
}

diag(x: Tensor): Tensor {
Expand Down
26 changes: 9 additions & 17 deletions tfjs-core/src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import * as gather_nd_util from '../../ops/gather_nd_util';
import * as reduce_util from '../../ops/reduce_util';
import * as scatter_nd_util from '../../ops/scatter_nd_util';
import * as segment_util from '../../ops/segment_util';
import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util';
import * as slice_util from '../../ops/slice_util';
import {softmax} from '../../ops/softmax';
import {range, scalar, tensor} from '../../ops/tensor_ops';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
Expand Down Expand Up @@ -729,7 +729,7 @@ export class MathBackendWebGL implements KernelBackend {
return tensor([], size, x.dtype) as T;
}
const {isPacked} = this.texData.get(x.dataId);
const isContinous = isSliceContinous(x.shape, begin, size);
const isContinous = slice_util.isSliceContinous(x.shape, begin, size);
if (isPacked || !isContinous) {
const program = ENV.getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
new SlicePackedProgram(size) :
Expand All @@ -749,7 +749,7 @@ export class MathBackendWebGL implements KernelBackend {
Object.assign(newTexData, xTexData);
newTexData.shape = size;
newTexData.dtype = x.dtype;
let flatOffset = computeFlatOffset(begin, x.strides);
let flatOffset = slice_util.computeFlatOffset(begin, x.strides);
if (xTexData.slice) {
// We are slicing an already sliced tensor, so we have to accumulate
// the offset.
Expand All @@ -769,26 +769,18 @@ export class MathBackendWebGL implements KernelBackend {
}

stridedSlice<T extends Tensor>(
x: T, begin: number[], end: number[], strides: number[],
beginMask: number, endMask: number, ellipsisMask: number,
newAxisMask: number, shrinkAxisMask: number): T {
x: T, begin: number[], end: number[], strides: number[]): T {
if (this.shouldExecuteOnCPU([x])) {
return this.cpuBackend.stridedSlice(
x, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask,
shrinkAxisMask);
return this.cpuBackend.stridedSlice(x, begin, end, strides);
}

const [beginIndex, size, shrinkAxis] = getStridedSlicedInfo(
x.shape, begin, end, strides, beginMask, endMask, ellipsisMask,
newAxisMask, shrinkAxisMask);
const outShape = slice_util.computeOutShape(begin, end, strides);

const shape = size.filter((v, index) => shrinkAxis.indexOf(index) === -1);
if (shape.some(axis => axis === 0)) {
return tensor([], shape) as T;
if (outShape.some(axis => axis === 0)) {
return tensor([], outShape) as T;
}

const program =
new StridedSliceProgram(beginIndex, strides, size, shrinkAxis);
const program = new StridedSliceProgram(begin, strides, outShape);
return this.compileAndRun(program, [x]);
}

Expand Down
20 changes: 7 additions & 13 deletions tfjs-core/src/backends/webgl/strided_slice_gpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@ export class StridedSliceProgram implements GPGPUProgram {
userCode: string;

constructor(
begin: number[], strides: number[], size: number[],
shrinkAxis: number[]) {
const shape = size.filter((v, index) => shrinkAxis.indexOf(index) === -1);
this.outputShape = shape;
begin: number[], strides: number[], size: number[]) {
this.outputShape = size;
const rank = size.length;
const inputDtype = getCoordsDataType(size.length);
const dtype = getCoordsDataType(shape.length);
const dtype = getCoordsDataType(size.length);

let newCoords = '';
if (rank === 1) {
Expand All @@ -39,14 +37,10 @@ export class StridedSliceProgram implements GPGPUProgram {
let outputAxis = 0;
newCoords =
size.map((_, i) => {
if (shrinkAxis.indexOf(i) === -1) {
outputAxis++;
return shape.length === 1 ?
`coords * strides[${i}] + begin[${i}]` :
`coords[${outputAxis - 1}] * strides[${i}] + begin[${i}]`;
} else {
return `begin[${i}]`;
}
outputAxis++;
return size.length === 1 ?
`coords * strides[${i}] + begin[${i}]` :
`coords[${outputAxis - 1}] * strides[${i}] + begin[${i}]`;
})
.join(',');
}
Expand Down
59 changes: 19 additions & 40 deletions tfjs-core/src/ops/slice_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,49 +38,28 @@ export function assertParamsValid(
}
}

/**
* Calculate the start index and output tensor shape for strided slice op.
* @returns array of [startIndex, size, shrinkAxis]
*/
export function getStridedSlicedInfo(
shape: number[], begin: number[], end: number[], strides: number[],
beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0,
shrinkAxisMask = 0): [number[], number[], number[]] {
if (ellipsisMask !== 0) {
throw new Error('ellipsis mask is not yet supported');
}
if (newAxisMask !== 0) {
throw new Error('new axis mask is not yet supported');
}
// Note that the axis orders are reversed for runtime ops, so the indices,
// strides and masks must be as well too.
const startIndex: number[] = [];
const endIndex: number[] = [];
const shrinkAxis: number[] = [];
for (let i = 0; i < shape.length; i++) {
startIndex[i] = startForAxis(beginMask, begin, strides, shape, i);
endIndex[i] = stopForAxis(endMask, end, strides, shape, i);
// When shrinking an axis, use startIndex + 1 for endIndex.
// Check the axis bit from right of shrinkAxisMask
if (shrinkAxisMask & 1 << i) {
endIndex[i] = startIndex[i] + 1;
shrinkAxis.push(i);
/** Converts a binary mask to an array of axes. Used in stridedSlice(). */
export function maskToAxes(mask: number): number[] {
const axes = [];
let axis = 0;
while (mask > 0) {
if (mask & 1) {
axes.push(axis);
}
mask /= 2;
axis++;
}
return axes;
}

let size = new Array(shape.length).fill(0);
size = size.map((d, i) => {
let count = 0;
const stride = strides[i] || 1;
for (let start = startIndex[i];
!(stride > 0 ? start >= endIndex[i] : start <= endIndex[i]);
start += stride) {
count += 1;
}
return count;
});

return [startIndex, size, shrinkAxis];
/** Computes the output shape given the strided slice params. */
export function computeOutShape(
begin: number[], end: number[], strides: number[]): number[] {
const size = [];
for (let axis = 0; axis < begin.length; axis++) {
size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]);
}
return size;
}

export function startForAxis(
Expand Down
56 changes: 40 additions & 16 deletions tfjs-core/src/ops/strided_slice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ import {ENGINE} from '../engine';
import {Tensor} from '../tensor';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike} from '../types';

import {op} from './operation';
import {slice} from './slice';
import {getStridedSlicedInfo} from './slice_util';
import {computeOutShape, maskToAxes, startForAxis, stopForAxis} from './slice_util';

/**
* Extracts a strided slice of a tensor.
Expand Down Expand Up @@ -56,30 +57,53 @@ import {getStridedSlicedInfo} from './slice_util';
*/
/** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */
function stridedSlice_(
x: Tensor|TensorLike, begin: number[], end: number[], strides: number[],
x: Tensor|TensorLike, begin: number[], end: number[], strides?: number[],
beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0,
shrinkAxisMask = 0): Tensor {
if (strides == null) {
strides = new Array(begin.length);
}
if (ellipsisMask !== 0) {
throw new Error('ellipsis mask is not yet supported');
}
if (newAxisMask !== 0) {
throw new Error('new axis mask is not yet supported');
let $x = convertToTensor(x, 'x', 'stridedSlice');

// Expand the dims of x based on the newAxisMask.
const expandAxes = maskToAxes(newAxisMask);
const newShape = $x.shape.slice();
expandAxes.forEach(axis => {
begin[axis] = 0;
end[axis] = 1;
newShape.splice(axis, 0, 1);
});
$x = $x.reshape(newShape);

// Normalize the start, end and strides.
for (let axis = 0; axis < $x.rank; axis++) {
begin[axis] = startForAxis(beginMask, begin, strides, $x.shape, axis);
end[axis] = stopForAxis(endMask, end, strides, $x.shape, axis);
strides[axis] = strides[axis] || 1;
}
const $x = convertToTensor(x, 'x', 'stridedSlice');

const shrinkAxes = maskToAxes(shrinkAxisMask);
// Adjust the ends based on the shrink mask.
shrinkAxes.forEach(axis => {
end[axis] = begin[axis] + 1;
strides[axis] = 1;
});

// Figure out the output shape.
const size = computeOutShape(begin, end, strides);
// Remove the axes based on shrinkMask.
const outShape = size.filter((_, axis) => shrinkAxes.indexOf(axis) === -1);

const nonStrided = strides.every(v => v === 1);
if (nonStrided) {
const [beginIndex, size, shrinkAxis] = getStridedSlicedInfo(
$x.shape, begin, end, strides, beginMask, endMask, ellipsisMask,
newAxisMask, shrinkAxisMask);
const outShape =
size.filter((_, index) => shrinkAxis.indexOf(index) === -1);
return slice($x, beginIndex, size).reshape(outShape);
return slice($x, begin, size).reshape(outShape);
}
return ENGINE.runKernel(
backend => backend.stridedSlice(
$x, begin, end, strides, beginMask, endMask, ellipsisMask,
newAxisMask, shrinkAxisMask),
{$x});
const res = ENGINE.runKernel(
backend => backend.stridedSlice($x, begin, end, strides), {$x});
return res.reshape(outShape);
}

export const stridedSlice = op({stridedSlice_});
50 changes: 47 additions & 3 deletions tfjs-core/src/ops/strided_slice_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,53 @@ import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
import {expectArraysClose} from '../test_util';

describeWithFlags('stridedSlice', ALL_ENVS, () => {
it('stridedSlice should fail if new axis mask is set', () => {
const tensor = tf.tensor1d([0, 1, 2, 3]);
expect(() => tf.stridedSlice(tensor, [0], [3], [2], 0, 0, 0, 1)).toThrow();
it('stridedSlice with first axis being new', async () => {
// Python slice code: t[tf.newaxis,0:3]
const t = tf.tensor1d([0, 1, 2, 3]);
const begin = [0, 0];
const end = [1, 3];
const strides = [1, 2];
const beginMask = 0;
const endMask = 0;
const ellipsisMask = 0;
const newAxisMask = 1;

const output = tf.stridedSlice(
t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask);
expect(output.shape).toEqual([1, 2]);
expectArraysClose(await output.data(), [0, 2]);
});

it('strided slice with several new axes', () => {
// Python slice code: t[1:2,tf.newaxis,0:3,tf.newaxis,2:5]
const t = tf.zeros([2, 3, 4, 5]);
const begin = [1, 0, 0, 0, 2];
const end = [2, 1, 3, 1, 5];
const strides: number[] = null;
const beginMask = 0;
const endMask = 0;
const ellipsisMask = 0;
const newAxisMask = 0b1010;
const output = tf.stridedSlice(
t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask);
expect(output.shape).toEqual([1, 1, 3, 1, 2, 5]);
});

it('strided slice with new axes and shrink axes', () => {
// Python slice code: t[1:2,tf.newaxis,1,tf.newaxis,2,2:5]
const t = tf.zeros([2, 3, 4, 5]);
const begin = [1, 0, 1, 0, 2, 2];
const end = [2, 1, 2, 1, 3, 5];
const strides: number[] = null;
const beginMask = 0;
const endMask = 0;
const ellipsisMask = 0;
const newAxisMask = 0b1010;
const shrinkAxisMask = 0b10100;
const output = tf.stridedSlice(
t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask,
shrinkAxisMask);
expect(output.shape).toEqual([1, 1, 1, 3]);
});

it('stridedSlice should fail if ellipsis mask is set', () => {
Expand Down
8 changes: 8 additions & 0 deletions tfjs-core/src/tensor_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1461,6 +1461,14 @@ describeWithFlags('tensor', ALL_ENVS, () => {
expect(res.shape).toEqual([0, 0]);
});

it('squeeze can take an empty list of axis', () => {
const a = tf.zeros([2, 1, 3, 1, 4]);
const axes: number[] = [];
// Empty axes list means all possible axes.
const res = tf.squeeze(a, axes);
expect(res.shape).toEqual([2, 3, 4]);
});

it('squeeze a complex64 tensor', async () => {
const a = tf.complex([[4], [1], [5]], [[2], [3], [6]]);
const b = a.squeeze();
Expand Down
5 changes: 4 additions & 1 deletion tfjs-core/src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,10 @@ export function squeezeShape(shape: number[], axis?: number[]):
{newShape: number[], keptDims: number[]} {
const newShape: number[] = [];
const keptDims: number[] = [];
const axes = axis == null ? null : parseAxisParam(axis, shape).sort();
const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
const axes = (axis == null || isEmptyArray) ?
null :
parseAxisParam(axis, shape).sort();
let j = 0;
for (let i = 0; i < shape.length; ++i) {
if (axes != null) {
Expand Down