From f320ef61d06b26cb31df76542629dbd9999ec021 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Mon, 15 Jun 2020 13:02:47 -0400 Subject: [PATCH 01/14] init --- tfjs-core/src/ops/array_ops.ts | 176 +--------- tfjs-core/src/ops/array_ops_test.ts | 479 -------------------------- tfjs-core/src/ops/expand_dims.ts | 60 ++++ tfjs-core/src/ops/expand_dims_test.ts | 151 ++++++++ tfjs-core/src/ops/ops.ts | 5 + tfjs-core/src/ops/reshape.ts | 66 ++++ tfjs-core/src/ops/squeeze.ts | 45 +++ tfjs-core/src/ops/stack.ts | 74 ++++ tfjs-core/src/ops/stack_test.ts | 123 +++++++ tfjs-core/src/ops/unstack.ts | 58 ++++ tfjs-core/src/ops/unstack_test.ts | 265 ++++++++++++++ 11 files changed, 848 insertions(+), 654 deletions(-) create mode 100644 tfjs-core/src/ops/expand_dims.ts create mode 100644 tfjs-core/src/ops/expand_dims_test.ts create mode 100644 tfjs-core/src/ops/reshape.ts create mode 100644 tfjs-core/src/ops/squeeze.ts create mode 100644 tfjs-core/src/ops/stack.ts create mode 100644 tfjs-core/src/ops/stack_test.ts create mode 100644 tfjs-core/src/ops/unstack.ts create mode 100644 tfjs-core/src/ops/unstack_test.ts diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts index 5c6dc23e328..6f7f396ebf8 100644 --- a/tfjs-core/src/ops/array_ops.ts +++ b/tfjs-core/src/ops/array_ops.ts @@ -17,73 +17,12 @@ import {ENGINE} from '../engine'; import {Tensor, TensorBuffer} from '../tensor'; -import {convertToTensor, convertToTensorArray} from '../tensor_util_env'; +import {convertToTensor} from '../tensor_util_env'; import {DataType, DataTypeMap, Rank, ShapeMap, TensorLike} from '../types'; import * as util from '../util'; -import {concat} from './concat'; import {op} from './operation'; -/** - * Reshapes a `tf.Tensor` to a given shape. - * - * Given an input tensor, returns a new tensor with the same values as the - * input tensor with shape `shape`. - * - * If one component of shape is the special value -1, the size of that - * dimension is computed so that the total size remains constant. In - * particular, a shape of [-1] flattens into 1-D. At most one component of - * shape can be -1. - * - * If shape is 1-D or higher, then the operation returns a tensor with shape - * shape filled with the values of tensor. In this case, the number of - * elements implied by shape must be the same as the number of elements in - * tensor. - * - * ```js - * const x = tf.tensor1d([1, 2, 3, 4]); - * x.reshape([2, 2]).print(); - * ``` - * - * @param x The input tensor to be reshaped. - * @param shape An array of integers defining the output tensor shape. - */ -/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ -function reshape_( - x: Tensor|TensorLike, shape: ShapeMap[R2]): Tensor { - const $x = convertToTensor(x, 'x', 'reshape', null); - shape = util.inferFromImplicitShape(shape, $x.size) as ShapeMap[R2]; - util.assert( - $x.size === util.sizeFromShape(shape), - () => 'new shape and old shape must have the same number of elements.'); - - const grad = (dy: Tensor) => { - return {x: () => dy.reshape($x.shape)}; - }; - const attrs = {shape}; - return ENGINE.runKernelFunc( - backend => backend.reshape($x, shape), {x: $x}, grad, 'Reshape', attrs); -} - -/** - * Removes dimensions of size 1 from the shape of a `tf.Tensor`. - * - * ```js - * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]); - * x.squeeze().print(); - * ``` - * - * @param x The input tensor to be squeezed. - * @param axis An optional list of numbers. If specified, only - * squeezes the dimensions listed. The dimension index starts at 0. It - * is an error to squeeze a dimension that is not 1. - */ -/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ -function squeeze_(x: Tensor|TensorLike, axis?: number[]): T { - const $x = convertToTensor(x, 'x', 'squeeze'); - return reshape($x, util.squeezeShape($x.shape, axis).newShape) as T; -} - /** * Casts a `tf.Tensor` to a new dtype. * @@ -115,114 +54,6 @@ function cast_(x: T|TensorLike, dtype: DataType): T { backend => backend.cast($x, dtype), {x: $x}, grad, 'Cast', attrs); } -/** - * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`. - * - * ```js - * const a = tf.tensor1d([1, 2]); - * const b = tf.tensor1d([3, 4]); - * const c = tf.tensor1d([5, 6]); - * tf.stack([a, b, c]).print(); - * ``` - * - * @param tensors A list of tensor objects with the same shape and dtype. - * @param axis The axis to stack along. Defaults to 0 (the first dim). - */ -/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ -function stack_( - tensors: Array, axis = 0): Tensor { - const $tensors = convertToTensorArray(tensors, 'tensors', 'stack'); - - util.assert( - $tensors.length >= 1, () => 'Pass at least one tensor to tf.stack'); - if ($tensors.length === 1) { - return $tensors[0].expandDims(axis); - } - const rank = $tensors[0].rank; - const shape = $tensors[0].shape; - const dtype = $tensors[0].dtype; - - util.assert(axis <= rank, () => 'Axis must be <= rank of the tensor'); - - $tensors.forEach(t => { - util.assertShapesMatch( - shape, t.shape, - 'All tensors passed to stack must have matching shapes'); - }); - - $tensors.forEach(t => { - util.assert( - dtype === t.dtype, - () => 'All tensors passed to stack must have matching dtypes'); - }); - const expandedTensors = $tensors.map(t => t.expandDims(axis)); - return concat(expandedTensors, axis); -} - -/** - * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s. - * - * ```js - * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]); - * - * tf.unstack(a).forEach(tensor => tensor.print()); - * ``` - * - * @param x A tensor object. - * @param axis The axis to unstack along. Defaults to 0 (the first dim). - */ -/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ -function unstack_(x: Tensor|TensorLike, axis = 0): Tensor[] { - axis = axis || 0; - const $x = convertToTensor(x, 'x', 'unstack'); - util.assert( - axis >= -$x.shape.length && axis < $x.shape.length, - () => - `Axis = ${axis} is not in [-${$x.shape.length}, ${$x.shape.length})`); - if (axis < 0) { - axis += $x.shape.length; - } - const grad = (dy: Tensor[]) => { - return {x: () => stack(dy, axis)}; - }; - const attrs = {axis}; - return ENGINE.runKernelFunc( - backend => backend.unstack($x, axis), {x: $x}, grad, 'Unpack', attrs); -} - -/** - * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension - * into the tensor's shape. - * - * ```js - * const x = tf.tensor1d([1, 2, 3, 4]); - * const axis = 1; - * x.expandDims(axis).print(); - * ``` - * - * @param x The input tensor whose dimensions to be expanded. - * @param axis The dimension index at which to insert shape of `1`. Defaults - * to 0 (the first dimension). - */ -/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ -function expandDims_( - x: Tensor|TensorLike, axis = 0): Tensor { - const parseAs: DataType = null; - const $x = convertToTensor(x, 'x', 'expandDims', parseAs); - - util.assert(axis <= $x.rank, () => 'Axis must be <= rank of the tensor'); - const newShape = $x.shape.slice(); - if (axis < 0) { - // Negative value is counted from the tail of rank. - util.assert( - -($x.rank + 1) <= axis, - () => `Axis must be in the interval [${- ($x.rank + 1)}, ${$x.rank}]`); - axis = $x.rank + axis + 1; - } - newShape.splice(axis, 0, 1); - return reshape($x, newShape as ShapeMap[R2]); -} - /** * Computes the difference between two lists of numbers. * @@ -344,9 +175,4 @@ export { }; export const cast = op({cast_}); -export const expandDims = op({expandDims_}); -export const reshape = op({reshape_}); -export const squeeze = op({squeeze_}); -export const stack = op({stack_}); -export const unstack = op({unstack_}); export const setdiff1dAsync = setdiff1dAsync_; diff --git a/tfjs-core/src/ops/array_ops_test.ts b/tfjs-core/src/ops/array_ops_test.ts index be263e49b44..65dcf71467e 100644 --- a/tfjs-core/src/ops/array_ops_test.ts +++ b/tfjs-core/src/ops/array_ops_test.ts @@ -2161,354 +2161,6 @@ describeWithFlags('fill', ALL_ENVS, () => { }); }); -describeWithFlags('stack', ALL_ENVS, () => { - it('scalars 3, 5 and 7', async () => { - const a = tf.scalar(3); - const b = tf.scalar(5); - const c = tf.scalar(7); - const res = tf.stack([a, b, c]); - expect(res.shape).toEqual([3]); - expectArraysClose(await res.data(), [3, 5, 7]); - }); - - it('scalars 3, 5 and 7 along axis=1 throws error', () => { - const a = tf.scalar(3); - const b = tf.scalar(5); - const c = tf.scalar(7); - const f = () => tf.stack([a, b, c], 1); - expect(f).toThrowError(); - }); - - it('non matching shapes throws error', () => { - const a = tf.scalar(3); - const b = tf.tensor1d([5]); - const f = () => tf.stack([a, b]); - expect(f).toThrowError(); - }); - - it('non matching dtypes throws error', () => { - const a = tf.scalar(3); - const b = tf.scalar(5, 'bool'); - const f = () => tf.stack([a, b]); - expect(f).toThrowError(); - }); - - it('2d but axis=3 throws error', () => { - const a = tf.zeros([2, 2]); - const b = tf.zeros([2, 2]); - const f = () => tf.stack([a, b], 3 /* axis */); - expect(f).toThrowError(); - }); - - it('[1,2], [3,4] and [5,6], axis=0', async () => { - const a = tf.tensor1d([1, 2]); - const b = tf.tensor1d([3, 4]); - const c = tf.tensor1d([5, 6]); - const res = tf.stack([a, b, c], 0 /* axis */); - expect(res.shape).toEqual([3, 2]); - expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); - }); - - it('[1,2], [3,4] and [5,6], axis=1', async () => { - const a = tf.tensor1d([1, 2]); - const b = tf.tensor1d([3, 4]); - const c = tf.tensor1d([5, 6]); - const res = tf.stack([a, b, c], 1 /* axis */); - expect(res.shape).toEqual([2, 3]); - expectArraysClose(await res.data(), [1, 3, 5, 2, 4, 6]); - }); - - it('[[1,2],[3,4]] and [[5, 6], [7, 8]], axis=0', async () => { - const a = tf.tensor2d([[1, 2], [3, 4]]); - const b = tf.tensor2d([[5, 6], [7, 8]]); - const res = tf.stack([a, b], 0 /* axis */); - expect(res.shape).toEqual([2, 2, 2]); - expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6, 7, 8]); - }); - - it('[[1,2],[3,4]] and [[5, 6], [7, 8]], axis=2', async () => { - const a = tf.tensor2d([[1, 2], [3, 4]]); - const b = tf.tensor2d([[5, 6], [7, 8]]); - const c = tf.tensor2d([[9, 10], [11, 12]]); - const res = tf.stack([a, b, c], 2 /* axis */); - expect(res.shape).toEqual([2, 2, 3]); - expectArraysClose( - await res.data(), [1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12]); - }); - - it('single tensor', async () => { - const a = tf.tensor2d([[1, 2], [3, 4]]); - const res = tf.stack([a], 2 /* axis */); - expect(res.shape).toEqual([2, 2, 1]); - expectArraysClose(await res.data(), [1, 2, 3, 4]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.stack([{} as tf.Tensor])) - .toThrowError( - /Argument 'tensors\[0\]' passed to 'stack' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const a = [[1, 2], [3, 4]]; - const res = tf.stack([a], 2 /* axis */); - expect(res.shape).toEqual([2, 2, 1]); - expectArraysClose(await res.data(), [1, 2, 3, 4]); - }); - - it('chain api', async () => { - const a = tf.tensor([1, 2]); - const res = a.stack(tf.tensor([3, 4])); - expect(res.shape).toEqual([2, 2]); - expectArraysClose(await res.data(), [1, 2, 3, 4]); - }); -}); - -describeWithFlags('unstack', ALL_ENVS, () => { - it('unstack by default', async () => { - const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); - const res = tf.unstack(x); - expect(res.length).toEqual(2); - expect(res[0].rank).toEqual(1); - expect(res[0].shape).toEqual([4]); - expectArraysClose(await res[0].data(), [1, 2, 3, 4]); - expect(res[1].rank).toEqual(1); - expect(res[1].shape).toEqual([4]); - expectArraysClose(await res[1].data(), [5, 6, 7, 8]); - }); - - it('chain api', async () => { - const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); - const res = x.unstack(); - expect(res.length).toEqual(2); - expect(res[0].rank).toEqual(1); - expect(res[0].shape).toEqual([4]); - expectArraysClose(await res[0].data(), [1, 2, 3, 4]); - expect(res[1].rank).toEqual(1); - expect(res[1].shape).toEqual([4]); - expectArraysClose(await res[1].data(), [5, 6, 7, 8]); - }); - - it('unstack with negative integer axis', async () => { - const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); - - let res = tf.unstack(x, -1); - expect(res.length).toEqual(4); - expect(res[0].rank).toEqual(1); - expect(res[0].shape).toEqual([2]); - expectArraysClose(await res[0].data(), [1, 5]); - expect(res[1].rank).toEqual(1); - expect(res[1].shape).toEqual([2]); - expectArraysClose(await res[1].data(), [2, 6]); - expect(res[2].rank).toEqual(1); - expect(res[2].shape).toEqual([2]); - expectArraysClose(await res[2].data(), [3, 7]); - expect(res[3].rank).toEqual(1); - expect(res[3].shape).toEqual([2]); - expectArraysClose(await res[3].data(), [4, 8]); - - res = tf.unstack(x, -2); - expect(res.length).toEqual(2); - expect(res[0].rank).toEqual(1); - expect(res[0].shape).toEqual([4]); - expectArraysClose(await res[0].data(), [1, 2, 3, 4]); - expect(res[1].rank).toEqual(1); - expect(res[1].shape).toEqual([4]); - expectArraysClose(await res[1].data(), [5, 6, 7, 8]); - }); - - it('unstack into 3 tensors', async () => { - const x = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]); - const res = tf.unstack(x, 0); - expect(res.length).toEqual(3); - expect(res[0].rank).toEqual(1); - expect(res[0].shape).toEqual([2]); - expectArraysClose(await res[0].data(), [1, 2]); - expect(res[1].rank).toEqual(1); - expect(res[1].shape).toEqual([2]); - expectArraysClose(await res[1].data(), [3, 4]); - expect(res[2].rank).toEqual(1); - expect(res[2].shape).toEqual([2]); - expectArraysClose(await res[2].data(), [5, 6]); - }); - - it('unstack by axis=1', async () => { - const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); - const res = tf.unstack(x, 1); - expect(res.length).toEqual(4); - expect(res[0].rank).toEqual(1); - expect(res[0].shape).toEqual([2]); - expectArraysClose(await res[0].data(), [1, 5]); - expect(res[1].rank).toEqual(1); - expect(res[1].shape).toEqual([2]); - expectArraysClose(await res[1].data(), [2, 6]); - expect(res[2].rank).toEqual(1); - expect(res[2].shape).toEqual([2]); - expectArraysClose(await res[2].data(), [3, 7]); - expect(res[3].rank).toEqual(1); - expect(res[3].shape).toEqual([2]); - expectArraysClose(await res[3].data(), [4, 8]); - }); - - it('unstack rank 3 tensor', async () => { - const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); - const res = tf.unstack(x); - expect(res.length).toEqual(2); - expect(res[0].rank).toEqual(2); - expect(res[0].shape).toEqual([2, 2]); - expectArraysClose(await res[0].data(), [1, 2, 3, 4]); - expect(res[1].rank).toEqual(2); - expect(res[1].shape).toEqual([2, 2]); - expectArraysClose(await res[1].data(), [5, 6, 7, 8]); - }); - - it('unstack rank 3 tensor with axis=1', async () => { - const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); - const res = tf.unstack(x, 1); - expect(res.length).toEqual(2); - expect(res[0].rank).toEqual(2); - expect(res[0].shape).toEqual([2, 2]); - expectArraysClose(await res[0].data(), [1, 2, 5, 6]); - expect(res[1].rank).toEqual(2); - expect(res[1].shape).toEqual([2, 2]); - expectArraysClose(await res[1].data(), [3, 4, 7, 8]); - }); - - it('unstack rank 3 tensor with axis=2', async () => { - const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); - const res = tf.unstack(x, 2); - expect(res.length).toEqual(2); - expect(res[0].rank).toEqual(2); - expect(res[0].shape).toEqual([2, 2]); - expectArraysClose(await res[0].data(), [1, 3, 5, 7]); - expect(res[1].rank).toEqual(2); - expect(res[1].shape).toEqual([2, 2]); - expectArraysClose(await res[1].data(), [2, 4, 6, 8]); - }); - - it('unstack rank 4 tensor', async () => { - const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); - const res = tf.unstack(x); - expect(res.length).toEqual(2); - expect(res[0].rank).toEqual(3); - expect(res[0].shape).toEqual([2, 2, 1]); - expectArraysClose(await res[0].data(), [1, 2, 3, 4]); - expect(res[1].rank).toEqual(3); - expect(res[1].shape).toEqual([2, 2, 1]); - expectArraysClose(await res[1].data(), [5, 6, 7, 8]); - }); - - it('unstack rank 4 tensor with axis=1', async () => { - const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); - const res = tf.unstack(x, 1); - expect(res.length).toEqual(2); - expect(res[0].rank).toEqual(3); - expect(res[0].shape).toEqual([2, 2, 1]); - expectArraysClose(await res[0].data(), [1, 2, 5, 6]); - expect(res[1].rank).toEqual(3); - expect(res[1].shape).toEqual([2, 2, 1]); - expectArraysClose(await res[1].data(), [3, 4, 7, 8]); - }); - - it('unstack rank 4 tensor with axis=2', async () => { - const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); - const res = tf.unstack(x, 2); - expect(res.length).toEqual(2); - expect(res[0].rank).toEqual(3); - expect(res[0].shape).toEqual([2, 2, 1]); - expectArraysClose(await res[0].data(), [1, 3, 5, 7]); - expect(res[1].rank).toEqual(3); - expect(res[1].shape).toEqual([2, 2, 1]); - expectArraysClose(await res[1].data(), [2, 4, 6, 8]); - }); - - it('unstack rank 4 tensor with axis=3', async () => { - const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); - const res = tf.unstack(x, 3); - expect(res.length).toEqual(1); - expect(res[0].rank).toEqual(3); - expect(res[0].shape).toEqual([2, 2, 2]); - expectArraysClose(await res[0].data(), [1, 2, 3, 4, 5, 6, 7, 8]); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.unstack({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'unstack' must be a Tensor/); - }); - - it('throws when passed an invalid axis', () => { - expect(() => { - const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); - tf.unstack(x, 3); - }).toThrowError('Axis = 3 is not in [-2, 2)'); - expect(() => { - const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); - tf.unstack(x, 3); - }).toThrowError('Axis = 3 is not in [-3, 3)'); - expect(() => { - const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); - tf.unstack(x, 5); - }).toThrowError('Axis = 5 is not in [-4, 4)'); - }); - - it('accepts a tensor-like object', async () => { - const x = [[1, 2, 3, 4], [5, 6, 7, 8]]; - const res = tf.unstack(x); - expect(res.length).toEqual(2); - expect(res[0].rank).toEqual(1); - expect(res[0].shape).toEqual([4]); - expectArraysClose(await res[0].data(), [1, 2, 3, 4]); - expect(res[1].rank).toEqual(1); - expect(res[1].shape).toEqual([4]); - expectArraysClose(await res[1].data(), [5, 6, 7, 8]); - }); - - it('grad of unstack axis=0', async () => { - const x = tf.tensor([[1, 2, 3], [4, 5, 6]]); - const dx1 = tf.grad(x => tf.unstack(x)[0])(x); - expect(dx1.shape).toEqual([2, 3]); - expect(dx1.dtype).toBe('float32'); - expectArraysClose(await dx1.data(), [1, 1, 1, 0, 0, 0]); - - const dx2 = tf.grad(x => tf.unstack(x)[1])(x); - expect(dx2.shape).toEqual([2, 3]); - expect(dx2.dtype).toBe('float32'); - expectArraysClose(await dx2.data(), [0, 0, 0, 1, 1, 1]); - }); - - it('gradient with clones', async () => { - const x = tf.tensor([[1, 2, 3], [4, 5, 6]]); - const dx1 = tf.grad(x => tf.unstack(x.clone())[0].clone())(x); - expect(dx1.shape).toEqual([2, 3]); - expect(dx1.dtype).toBe('float32'); - expectArraysClose(await dx1.data(), [1, 1, 1, 0, 0, 0]); - - const dx2 = tf.grad(x => tf.unstack(x.clone())[1].clone())(x); - expect(dx2.shape).toEqual([2, 3]); - expect(dx2.dtype).toBe('float32'); - expectArraysClose(await dx2.data(), [0, 0, 0, 1, 1, 1]); - }); - - it('grad of unstack axis=1', async () => { - const x = tf.tensor([[1, 2, 3], [4, 5, 6]]); - const axis = 1; - const dx1 = tf.grad(x => tf.unstack(x, axis)[0])(x); - expect(dx1.shape).toEqual([2, 3]); - expect(dx1.dtype).toBe('float32'); - expectArraysClose(await dx1.data(), [1, 0, 0, 1, 0, 0]); - - const dx2 = tf.grad(x => tf.unstack(x, axis)[1])(x); - expect(dx2.shape).toEqual([2, 3]); - expect(dx2.dtype).toBe('float32'); - expectArraysClose(await dx2.data(), [0, 1, 0, 0, 1, 0]); - - const dx3 = tf.grad(x => tf.unstack(x, axis)[2])(x); - expect(dx3.shape).toEqual([2, 3]); - expect(dx3.dtype).toBe('float32'); - expectArraysClose(await dx3.data(), [0, 0, 1, 0, 0, 1]); - }); -}); - describeWithFlags('split', ALL_ENVS, () => { it('split by number', async () => { const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); @@ -2629,137 +2281,6 @@ describeWithFlags('split', ALL_ENVS, () => { }); }); -describeWithFlags('expandDims', ALL_ENVS, () => { - it('scalar, default axis is 0', async () => { - const res = tf.scalar(1).expandDims(); - expect(res.shape).toEqual([1]); - expectArraysClose(await res.data(), [1]); - }); - - it('scalar, axis is out of bounds throws error', () => { - const f = () => tf.scalar(1).expandDims(1); - expect(f).toThrowError(); - }); - - it('1d, axis=-3', () => { - expect(() => { - tf.tensor1d([1, 2, 3]).expandDims(-3); - }).toThrowError('Axis must be in the interval [-2, 1]'); - }); - - it('1d, axis=-2', async () => { - const res = tf.tensor1d([1, 2, 3]).expandDims(-2 /* axis */); - expect(res.shape).toEqual([1, 3]); - expectArraysClose(await res.data(), [1, 2, 3]); - }); - - it('1d, axis=-1', async () => { - const res = tf.tensor1d([1, 2, 3]).expandDims(-1 /* axis */); - expect(res.shape).toEqual([3, 1]); - expectArraysClose(await res.data(), [1, 2, 3]); - }); - - it('1d, axis=0', async () => { - const res = tf.tensor1d([1, 2, 3]).expandDims(0 /* axis */); - expect(res.shape).toEqual([1, 3]); - expectArraysClose(await res.data(), [1, 2, 3]); - }); - - it('1d, axis=1', async () => { - const res = tf.tensor1d([1, 2, 3]).expandDims(1 /* axis */); - expect(res.shape).toEqual([3, 1]); - expectArraysClose(await res.data(), [1, 2, 3]); - }); - - it('2d, axis=-4', () => { - expect(() => { - tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(-4 /* axis */); - }).toThrowError('Axis must be in the interval [-3, 2]'); - }); - - it('2d, axis=-3', async () => { - const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(-3 /* axis */); - expect(res.shape).toEqual([1, 3, 2]); - expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); - }); - - it('2d, axis=-2', async () => { - const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(-2 /* axis */); - expect(res.shape).toEqual([3, 1, 2]); - expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); - }); - - it('2d, axis=-1', async () => { - const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(-1 /* axis */); - expect(res.shape).toEqual([3, 2, 1]); - expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); - }); - - it('2d, axis=0', async () => { - const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(0 /* axis */); - expect(res.shape).toEqual([1, 3, 2]); - expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); - }); - - it('2d, axis=1', async () => { - const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(1 /* axis */); - expect(res.shape).toEqual([3, 1, 2]); - expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); - }); - - it('2d, axis=2', async () => { - const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(2 /* axis */); - expect(res.shape).toEqual([3, 2, 1]); - expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); - }); - - it('4d, axis=0', async () => { - const res = tf.tensor4d([[[[4]]]]).expandDims(); - expect(res.shape).toEqual([1, 1, 1, 1, 1]); - expectArraysClose(await res.data(), [4]); - }); - - it('1d string tensor', async () => { - const t = tf.tensor(['hello', 'world']); - const res = t.expandDims(); - expect(res.shape).toEqual([1, 2]); - expectArraysClose(await res.data(), ['hello', 'world']); - }); - - it('2d string tensor, axis=1', async () => { - const t = tf.tensor([['a', 'b'], ['c', 'd']]); - const res = t.expandDims(1); - expect(res.shape).toEqual([2, 1, 2]); - expectArraysClose(await res.data(), ['a', 'b', 'c', 'd']); - }); - - it('throws when passed a non-tensor', () => { - expect(() => tf.expandDims({} as tf.Tensor)) - .toThrowError(/Argument 'x' passed to 'expandDims' must be a Tensor/); - }); - - it('accepts a tensor-like object', async () => { - const res = tf.expandDims(7); - expect(res.shape).toEqual([1]); - expectArraysClose(await res.data(), [7]); - }); - - it('works with 0 in shape', async () => { - const a = tf.tensor2d([], [0, 3]); - const res = a.expandDims(); - expect(res.shape).toEqual([1, 0, 3]); - expectArraysClose(await res.data(), []); - - const res2 = a.expandDims(1); - expect(res2.shape).toEqual([0, 1, 3]); - expectArraysClose(await res2.data(), []); - - const res3 = a.expandDims(2); - expect(res3.shape).toEqual([0, 3, 1]); - expectArraysClose(await res3.data(), []); - }); -}); - describeWithFlags('setdiff1dAsync', ALL_ENVS, () => { it('1d int32 tensor', async () => { const x = tf.tensor1d([1, 2, 3, 4], 'int32'); diff --git a/tfjs-core/src/ops/expand_dims.ts b/tfjs-core/src/ops/expand_dims.ts new file mode 100644 index 00000000000..a78e40a70e4 --- /dev/null +++ b/tfjs-core/src/ops/expand_dims.ts @@ -0,0 +1,60 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Tensor, TensorBuffer} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {DataType, DataTypeMap, Rank, ShapeMap, TensorLike} from '../types'; +import * as util from '../util'; + +import {op} from './operation'; +import {reshape} from './reshape'; + +/** + * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension + * into the tensor's shape. + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * const axis = 1; + * x.expandDims(axis).print(); + * ``` + * + * @param x The input tensor whose dimensions to be expanded. + * @param axis The dimension index at which to insert shape of `1`. Defaults + * to 0 (the first dimension). + */ +/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ +function expandDims_( + x: Tensor|TensorLike, axis = 0): Tensor { + const parseAs: DataType = null; + const $x = convertToTensor(x, 'x', 'expandDims', parseAs); + + util.assert(axis <= $x.rank, () => 'Axis must be <= rank of the tensor'); + const newShape = $x.shape.slice(); + if (axis < 0) { + // Negative value is counted from the tail of rank. + util.assert( + -($x.rank + 1) <= axis, + () => `Axis must be in the interval [${- ($x.rank + 1)}, ${$x.rank}]`); + axis = $x.rank + axis + 1; + } + newShape.splice(axis, 0, 1); + return reshape($x, newShape as ShapeMap[R2]); +} + +export const expandDims = op({expandDims_}); diff --git a/tfjs-core/src/ops/expand_dims_test.ts b/tfjs-core/src/ops/expand_dims_test.ts new file mode 100644 index 00000000000..8b405e98f06 --- /dev/null +++ b/tfjs-core/src/ops/expand_dims_test.ts @@ -0,0 +1,151 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('expandDims', ALL_ENVS, () => { + it('scalar, default axis is 0', async () => { + const res = tf.scalar(1).expandDims(); + expect(res.shape).toEqual([1]); + expectArraysClose(await res.data(), [1]); + }); + + it('scalar, axis is out of bounds throws error', () => { + const f = () => tf.scalar(1).expandDims(1); + expect(f).toThrowError(); + }); + + it('1d, axis=-3', () => { + expect(() => { + tf.tensor1d([1, 2, 3]).expandDims(-3); + }).toThrowError('Axis must be in the interval [-2, 1]'); + }); + + it('1d, axis=-2', async () => { + const res = tf.tensor1d([1, 2, 3]).expandDims(-2 /* axis */); + expect(res.shape).toEqual([1, 3]); + expectArraysClose(await res.data(), [1, 2, 3]); + }); + + it('1d, axis=-1', async () => { + const res = tf.tensor1d([1, 2, 3]).expandDims(-1 /* axis */); + expect(res.shape).toEqual([3, 1]); + expectArraysClose(await res.data(), [1, 2, 3]); + }); + + it('1d, axis=0', async () => { + const res = tf.tensor1d([1, 2, 3]).expandDims(0 /* axis */); + expect(res.shape).toEqual([1, 3]); + expectArraysClose(await res.data(), [1, 2, 3]); + }); + + it('1d, axis=1', async () => { + const res = tf.tensor1d([1, 2, 3]).expandDims(1 /* axis */); + expect(res.shape).toEqual([3, 1]); + expectArraysClose(await res.data(), [1, 2, 3]); + }); + + it('2d, axis=-4', () => { + expect(() => { + tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(-4 /* axis */); + }).toThrowError('Axis must be in the interval [-3, 2]'); + }); + + it('2d, axis=-3', async () => { + const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(-3 /* axis */); + expect(res.shape).toEqual([1, 3, 2]); + expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); + }); + + it('2d, axis=-2', async () => { + const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(-2 /* axis */); + expect(res.shape).toEqual([3, 1, 2]); + expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); + }); + + it('2d, axis=-1', async () => { + const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(-1 /* axis */); + expect(res.shape).toEqual([3, 2, 1]); + expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); + }); + + it('2d, axis=0', async () => { + const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(0 /* axis */); + expect(res.shape).toEqual([1, 3, 2]); + expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); + }); + + it('2d, axis=1', async () => { + const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(1 /* axis */); + expect(res.shape).toEqual([3, 1, 2]); + expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); + }); + + it('2d, axis=2', async () => { + const res = tf.tensor2d([[1, 2], [3, 4], [5, 6]]).expandDims(2 /* axis */); + expect(res.shape).toEqual([3, 2, 1]); + expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); + }); + + it('4d, axis=0', async () => { + const res = tf.tensor4d([[[[4]]]]).expandDims(); + expect(res.shape).toEqual([1, 1, 1, 1, 1]); + expectArraysClose(await res.data(), [4]); + }); + + it('1d string tensor', async () => { + const t = tf.tensor(['hello', 'world']); + const res = t.expandDims(); + expect(res.shape).toEqual([1, 2]); + expectArraysClose(await res.data(), ['hello', 'world']); + }); + + it('2d string tensor, axis=1', async () => { + const t = tf.tensor([['a', 'b'], ['c', 'd']]); + const res = t.expandDims(1); + expect(res.shape).toEqual([2, 1, 2]); + expectArraysClose(await res.data(), ['a', 'b', 'c', 'd']); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.expandDims({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'expandDims' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const res = tf.expandDims(7); + expect(res.shape).toEqual([1]); + expectArraysClose(await res.data(), [7]); + }); + + it('works with 0 in shape', async () => { + const a = tf.tensor2d([], [0, 3]); + const res = a.expandDims(); + expect(res.shape).toEqual([1, 0, 3]); + expectArraysClose(await res.data(), []); + + const res2 = a.expandDims(1); + expect(res2.shape).toEqual([0, 1, 3]); + expectArraysClose(await res2.data(), []); + + const res3 = a.expandDims(2); + expect(res3.shape).toEqual([0, 3, 1]); + expectArraysClose(await res3.data(), []); + }); +}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 9797cd5e548..4e6ef94dd4b 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -48,6 +48,7 @@ export {divNoNan} from './div_no_nan'; export {dot} from './dot'; export {elu} from './elu'; export {equal} from './equal'; +export {expandDims} from './expand_dims'; export {eye} from './eye'; export {fill} from './fill'; export {floorDiv} from './floorDiv'; @@ -86,15 +87,19 @@ export {randomUniform} from './random_uniform'; export {real} from './real'; export {relu} from './relu'; export {relu6} from './relu6'; +export {reshape} from './reshape'; export {selu} from './selu'; export {separableConv2d} from './separable_conv2d'; export {spaceToBatchND} from './space_to_batch_nd'; export {split} from './split'; export {square} from './square'; export {squaredDifference} from './squared_difference'; +export {squeeze} from './squeeze'; +export {stack} from './stack'; export {sub} from './sub'; export {tile} from './tile'; export {truncatedNormal} from './truncated_normal'; +export {unstack} from './unstack'; export * from './boolean_mask'; export * from './reverse'; diff --git a/tfjs-core/src/ops/reshape.ts b/tfjs-core/src/ops/reshape.ts new file mode 100644 index 00000000000..2245ee62f4e --- /dev/null +++ b/tfjs-core/src/ops/reshape.ts @@ -0,0 +1,66 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Tensor, TensorBuffer} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {DataType, DataTypeMap, Rank, ShapeMap, TensorLike} from '../types'; +import * as util from '../util'; + +import {op} from './operation'; + +/** + * Reshapes a `tf.Tensor` to a given shape. + * + * Given an input tensor, returns a new tensor with the same values as the + * input tensor with shape `shape`. + * + * If one component of shape is the special value -1, the size of that + * dimension is computed so that the total size remains constant. In + * particular, a shape of [-1] flattens into 1-D. At most one component of + * shape can be -1. + * + * If shape is 1-D or higher, then the operation returns a tensor with shape + * shape filled with the values of tensor. In this case, the number of + * elements implied by shape must be the same as the number of elements in + * tensor. + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * x.reshape([2, 2]).print(); + * ``` + * + * @param x The input tensor to be reshaped. + * @param shape An array of integers defining the output tensor shape. + */ +/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ +function reshape_( + x: Tensor|TensorLike, shape: ShapeMap[R2]): Tensor { + const $x = convertToTensor(x, 'x', 'reshape', null); + shape = util.inferFromImplicitShape(shape, $x.size) as ShapeMap[R2]; + util.assert( + $x.size === util.sizeFromShape(shape), + () => 'new shape and old shape must have the same number of elements.'); + + const grad = (dy: Tensor) => { + return {x: () => reshape(dy, $x.shape)}; + }; + const attrs = {shape}; + return ENGINE.runKernelFunc( + backend => backend.reshape($x, shape), {x: $x}, grad, 'Reshape', attrs); +} +export const reshape = op({reshape_}); diff --git a/tfjs-core/src/ops/squeeze.ts b/tfjs-core/src/ops/squeeze.ts new file mode 100644 index 00000000000..88f993f2633 --- /dev/null +++ b/tfjs-core/src/ops/squeeze.ts @@ -0,0 +1,45 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tensor} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {op} from './operation'; +import {reshape} from './reshape'; + +/** + * Removes dimensions of size 1 from the shape of a `tf.Tensor`. + * + * ```js + * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]); + * x.squeeze().print(); + * ``` + * + * @param x The input tensor to be squeezed. + * @param axis An optional list of numbers. If specified, only + * squeezes the dimensions listed. The dimension index starts at 0. It + * is an error to squeeze a dimension that is not 1. + */ +/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ +function squeeze_(x: Tensor|TensorLike, axis?: number[]): T { + const $x = convertToTensor(x, 'x', 'squeeze'); + return reshape($x, util.squeezeShape($x.shape, axis).newShape) as T; +} + +export const squeeze = op({squeeze_}); diff --git a/tfjs-core/src/ops/stack.ts b/tfjs-core/src/ops/stack.ts new file mode 100644 index 00000000000..84eee459000 --- /dev/null +++ b/tfjs-core/src/ops/stack.ts @@ -0,0 +1,74 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Tensor, TensorBuffer} from '../tensor'; +import {convertToTensor, convertToTensorArray} from '../tensor_util_env'; +import {DataType, DataTypeMap, Rank, ShapeMap, TensorLike} from '../types'; +import * as util from '../util'; + +import {concat} from './concat'; +import {expandDims} from './expand_dims'; +import {op} from './operation'; + +/** + * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`. + * + * ```js + * const a = tf.tensor1d([1, 2]); + * const b = tf.tensor1d([3, 4]); + * const c = tf.tensor1d([5, 6]); + * tf.stack([a, b, c]).print(); + * ``` + * + * @param tensors A list of tensor objects with the same shape and dtype. + * @param axis The axis to stack along. Defaults to 0 (the first dim). + */ +/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ +function stack_( + tensors: Array, axis = 0): Tensor { + const $tensors = convertToTensorArray(tensors, 'tensors', 'stack'); + + util.assert( + $tensors.length >= 1, () => 'Pass at least one tensor to tf.stack'); + + if ($tensors.length === 1) { + return expandDims($tensors[0], axis); + } + + const rank = $tensors[0].rank; + const shape = $tensors[0].shape; + const dtype = $tensors[0].dtype; + + util.assert(axis <= rank, () => 'Axis must be <= rank of the tensor'); + + $tensors.forEach(t => { + util.assertShapesMatch( + shape, t.shape, + 'All tensors passed to stack must have matching shapes'); + }); + + $tensors.forEach(t => { + util.assert( + dtype === t.dtype, + () => 'All tensors passed to stack must have matching dtypes'); + }); + const expandedTensors = $tensors.map(t => expandDims(t, axis)); + return concat(expandedTensors, axis); +} + +export const stack = op({stack_}); diff --git a/tfjs-core/src/ops/stack_test.ts b/tfjs-core/src/ops/stack_test.ts new file mode 100644 index 00000000000..f8c4041e2fe --- /dev/null +++ b/tfjs-core/src/ops/stack_test.ts @@ -0,0 +1,123 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('stack', ALL_ENVS, () => { + it('scalars 3, 5 and 7', async () => { + const a = tf.scalar(3); + const b = tf.scalar(5); + const c = tf.scalar(7); + const res = tf.stack([a, b, c]); + expect(res.shape).toEqual([3]); + expectArraysClose(await res.data(), [3, 5, 7]); + }); + + it('scalars 3, 5 and 7 along axis=1 throws error', () => { + const a = tf.scalar(3); + const b = tf.scalar(5); + const c = tf.scalar(7); + const f = () => tf.stack([a, b, c], 1); + expect(f).toThrowError(); + }); + + it('non matching shapes throws error', () => { + const a = tf.scalar(3); + const b = tf.tensor1d([5]); + const f = () => tf.stack([a, b]); + expect(f).toThrowError(); + }); + + it('non matching dtypes throws error', () => { + const a = tf.scalar(3); + const b = tf.scalar(5, 'bool'); + const f = () => tf.stack([a, b]); + expect(f).toThrowError(); + }); + + it('2d but axis=3 throws error', () => { + const a = tf.zeros([2, 2]); + const b = tf.zeros([2, 2]); + const f = () => tf.stack([a, b], 3 /* axis */); + expect(f).toThrowError(); + }); + + it('[1,2], [3,4] and [5,6], axis=0', async () => { + const a = tf.tensor1d([1, 2]); + const b = tf.tensor1d([3, 4]); + const c = tf.tensor1d([5, 6]); + const res = tf.stack([a, b, c], 0 /* axis */); + expect(res.shape).toEqual([3, 2]); + expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6]); + }); + + it('[1,2], [3,4] and [5,6], axis=1', async () => { + const a = tf.tensor1d([1, 2]); + const b = tf.tensor1d([3, 4]); + const c = tf.tensor1d([5, 6]); + const res = tf.stack([a, b, c], 1 /* axis */); + expect(res.shape).toEqual([2, 3]); + expectArraysClose(await res.data(), [1, 3, 5, 2, 4, 6]); + }); + + it('[[1,2],[3,4]] and [[5, 6], [7, 8]], axis=0', async () => { + const a = tf.tensor2d([[1, 2], [3, 4]]); + const b = tf.tensor2d([[5, 6], [7, 8]]); + const res = tf.stack([a, b], 0 /* axis */); + expect(res.shape).toEqual([2, 2, 2]); + expectArraysClose(await res.data(), [1, 2, 3, 4, 5, 6, 7, 8]); + }); + + it('[[1,2],[3,4]] and [[5, 6], [7, 8]], axis=2', async () => { + const a = tf.tensor2d([[1, 2], [3, 4]]); + const b = tf.tensor2d([[5, 6], [7, 8]]); + const c = tf.tensor2d([[9, 10], [11, 12]]); + const res = tf.stack([a, b, c], 2 /* axis */); + expect(res.shape).toEqual([2, 2, 3]); + expectArraysClose( + await res.data(), [1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12]); + }); + + it('single tensor', async () => { + const a = tf.tensor2d([[1, 2], [3, 4]]); + const res = tf.stack([a], 2 /* axis */); + expect(res.shape).toEqual([2, 2, 1]); + expectArraysClose(await res.data(), [1, 2, 3, 4]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.stack([{} as tf.Tensor])) + .toThrowError( + /Argument 'tensors\[0\]' passed to 'stack' must be a Tensor/); + }); + + it('accepts a tensor-like object', async () => { + const a = [[1, 2], [3, 4]]; + const res = tf.stack([a], 2 /* axis */); + expect(res.shape).toEqual([2, 2, 1]); + expectArraysClose(await res.data(), [1, 2, 3, 4]); + }); + + it('chain api', async () => { + const a = tf.tensor([1, 2]); + const res = a.stack(tf.tensor([3, 4])); + expect(res.shape).toEqual([2, 2]); + expectArraysClose(await res.data(), [1, 2, 3, 4]); + }); +}); diff --git a/tfjs-core/src/ops/unstack.ts b/tfjs-core/src/ops/unstack.ts new file mode 100644 index 00000000000..307663df4e2 --- /dev/null +++ b/tfjs-core/src/ops/unstack.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Tensor, TensorBuffer} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; +import {DataType, DataTypeMap, Rank, ShapeMap, TensorLike} from '../types'; +import * as util from '../util'; + +import {op} from './operation'; +import {stack} from './stack'; + +/** + * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s. + * + * ```js + * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * tf.unstack(a).forEach(tensor => tensor.print()); + * ``` + * + * @param x A tensor object. + * @param axis The axis to unstack along. Defaults to 0 (the first dim). + */ +/** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ +function unstack_(x: Tensor|TensorLike, axis = 0): Tensor[] { + axis = axis || 0; + const $x = convertToTensor(x, 'x', 'unstack'); + util.assert( + axis >= -$x.shape.length && axis < $x.shape.length, + () => + `Axis = ${axis} is not in [-${$x.shape.length}, ${$x.shape.length})`); + if (axis < 0) { + axis += $x.shape.length; + } + const grad = (dy: Tensor[]) => { + return {x: () => stack(dy, axis)}; + }; + const attrs = {axis}; + return ENGINE.runKernelFunc( + backend => backend.unstack($x, axis), {x: $x}, grad, 'Unpack', attrs); +} + +export const unstack = op({unstack_}); diff --git a/tfjs-core/src/ops/unstack_test.ts b/tfjs-core/src/ops/unstack_test.ts new file mode 100644 index 00000000000..9f50aefbd5a --- /dev/null +++ b/tfjs-core/src/ops/unstack_test.ts @@ -0,0 +1,265 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('unstack', ALL_ENVS, () => { + it('unstack by default', async () => { + const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const res = tf.unstack(x); + expect(res.length).toEqual(2); + expect(res[0].rank).toEqual(1); + expect(res[0].shape).toEqual([4]); + expectArraysClose(await res[0].data(), [1, 2, 3, 4]); + expect(res[1].rank).toEqual(1); + expect(res[1].shape).toEqual([4]); + expectArraysClose(await res[1].data(), [5, 6, 7, 8]); + }); + + it('chain api', async () => { + const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const res = x.unstack(); + expect(res.length).toEqual(2); + expect(res[0].rank).toEqual(1); + expect(res[0].shape).toEqual([4]); + expectArraysClose(await res[0].data(), [1, 2, 3, 4]); + expect(res[1].rank).toEqual(1); + expect(res[1].shape).toEqual([4]); + expectArraysClose(await res[1].data(), [5, 6, 7, 8]); + }); + + it('unstack with negative integer axis', async () => { + const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + + let res = tf.unstack(x, -1); + expect(res.length).toEqual(4); + expect(res[0].rank).toEqual(1); + expect(res[0].shape).toEqual([2]); + expectArraysClose(await res[0].data(), [1, 5]); + expect(res[1].rank).toEqual(1); + expect(res[1].shape).toEqual([2]); + expectArraysClose(await res[1].data(), [2, 6]); + expect(res[2].rank).toEqual(1); + expect(res[2].shape).toEqual([2]); + expectArraysClose(await res[2].data(), [3, 7]); + expect(res[3].rank).toEqual(1); + expect(res[3].shape).toEqual([2]); + expectArraysClose(await res[3].data(), [4, 8]); + + res = tf.unstack(x, -2); + expect(res.length).toEqual(2); + expect(res[0].rank).toEqual(1); + expect(res[0].shape).toEqual([4]); + expectArraysClose(await res[0].data(), [1, 2, 3, 4]); + expect(res[1].rank).toEqual(1); + expect(res[1].shape).toEqual([4]); + expectArraysClose(await res[1].data(), [5, 6, 7, 8]); + }); + + it('unstack into 3 tensors', async () => { + const x = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]); + const res = tf.unstack(x, 0); + expect(res.length).toEqual(3); + expect(res[0].rank).toEqual(1); + expect(res[0].shape).toEqual([2]); + expectArraysClose(await res[0].data(), [1, 2]); + expect(res[1].rank).toEqual(1); + expect(res[1].shape).toEqual([2]); + expectArraysClose(await res[1].data(), [3, 4]); + expect(res[2].rank).toEqual(1); + expect(res[2].shape).toEqual([2]); + expectArraysClose(await res[2].data(), [5, 6]); + }); + + it('unstack by axis=1', async () => { + const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + const res = tf.unstack(x, 1); + expect(res.length).toEqual(4); + expect(res[0].rank).toEqual(1); + expect(res[0].shape).toEqual([2]); + expectArraysClose(await res[0].data(), [1, 5]); + expect(res[1].rank).toEqual(1); + expect(res[1].shape).toEqual([2]); + expectArraysClose(await res[1].data(), [2, 6]); + expect(res[2].rank).toEqual(1); + expect(res[2].shape).toEqual([2]); + expectArraysClose(await res[2].data(), [3, 7]); + expect(res[3].rank).toEqual(1); + expect(res[3].shape).toEqual([2]); + expectArraysClose(await res[3].data(), [4, 8]); + }); + + it('unstack rank 3 tensor', async () => { + const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); + const res = tf.unstack(x); + expect(res.length).toEqual(2); + expect(res[0].rank).toEqual(2); + expect(res[0].shape).toEqual([2, 2]); + expectArraysClose(await res[0].data(), [1, 2, 3, 4]); + expect(res[1].rank).toEqual(2); + expect(res[1].shape).toEqual([2, 2]); + expectArraysClose(await res[1].data(), [5, 6, 7, 8]); + }); + + it('unstack rank 3 tensor with axis=1', async () => { + const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); + const res = tf.unstack(x, 1); + expect(res.length).toEqual(2); + expect(res[0].rank).toEqual(2); + expect(res[0].shape).toEqual([2, 2]); + expectArraysClose(await res[0].data(), [1, 2, 5, 6]); + expect(res[1].rank).toEqual(2); + expect(res[1].shape).toEqual([2, 2]); + expectArraysClose(await res[1].data(), [3, 4, 7, 8]); + }); + + it('unstack rank 3 tensor with axis=2', async () => { + const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); + const res = tf.unstack(x, 2); + expect(res.length).toEqual(2); + expect(res[0].rank).toEqual(2); + expect(res[0].shape).toEqual([2, 2]); + expectArraysClose(await res[0].data(), [1, 3, 5, 7]); + expect(res[1].rank).toEqual(2); + expect(res[1].shape).toEqual([2, 2]); + expectArraysClose(await res[1].data(), [2, 4, 6, 8]); + }); + + it('unstack rank 4 tensor', async () => { + const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); + const res = tf.unstack(x); + expect(res.length).toEqual(2); + expect(res[0].rank).toEqual(3); + expect(res[0].shape).toEqual([2, 2, 1]); + expectArraysClose(await res[0].data(), [1, 2, 3, 4]); + expect(res[1].rank).toEqual(3); + expect(res[1].shape).toEqual([2, 2, 1]); + expectArraysClose(await res[1].data(), [5, 6, 7, 8]); + }); + + it('unstack rank 4 tensor with axis=1', async () => { + const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); + const res = tf.unstack(x, 1); + expect(res.length).toEqual(2); + expect(res[0].rank).toEqual(3); + expect(res[0].shape).toEqual([2, 2, 1]); + expectArraysClose(await res[0].data(), [1, 2, 5, 6]); + expect(res[1].rank).toEqual(3); + expect(res[1].shape).toEqual([2, 2, 1]); + expectArraysClose(await res[1].data(), [3, 4, 7, 8]); + }); + + it('unstack rank 4 tensor with axis=2', async () => { + const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); + const res = tf.unstack(x, 2); + expect(res.length).toEqual(2); + expect(res[0].rank).toEqual(3); + expect(res[0].shape).toEqual([2, 2, 1]); + expectArraysClose(await res[0].data(), [1, 3, 5, 7]); + expect(res[1].rank).toEqual(3); + expect(res[1].shape).toEqual([2, 2, 1]); + expectArraysClose(await res[1].data(), [2, 4, 6, 8]); + }); + + it('unstack rank 4 tensor with axis=3', async () => { + const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); + const res = tf.unstack(x, 3); + expect(res.length).toEqual(1); + expect(res[0].rank).toEqual(3); + expect(res[0].shape).toEqual([2, 2, 2]); + expectArraysClose(await res[0].data(), [1, 2, 3, 4, 5, 6, 7, 8]); + }); + + it('throws when passed a non-tensor', () => { + expect(() => tf.unstack({} as tf.Tensor)) + .toThrowError(/Argument 'x' passed to 'unstack' must be a Tensor/); + }); + + it('throws when passed an invalid axis', () => { + expect(() => { + const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + tf.unstack(x, 3); + }).toThrowError('Axis = 3 is not in [-2, 2)'); + expect(() => { + const x = tf.tensor3d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]); + tf.unstack(x, 3); + }).toThrowError('Axis = 3 is not in [-3, 3)'); + expect(() => { + const x = tf.tensor4d([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2, 1]); + tf.unstack(x, 5); + }).toThrowError('Axis = 5 is not in [-4, 4)'); + }); + + it('accepts a tensor-like object', async () => { + const x = [[1, 2, 3, 4], [5, 6, 7, 8]]; + const res = tf.unstack(x); + expect(res.length).toEqual(2); + expect(res[0].rank).toEqual(1); + expect(res[0].shape).toEqual([4]); + expectArraysClose(await res[0].data(), [1, 2, 3, 4]); + expect(res[1].rank).toEqual(1); + expect(res[1].shape).toEqual([4]); + expectArraysClose(await res[1].data(), [5, 6, 7, 8]); + }); + + it('grad of unstack axis=0', async () => { + const x = tf.tensor([[1, 2, 3], [4, 5, 6]]); + const dx1 = tf.grad(x => tf.unstack(x)[0])(x); + expect(dx1.shape).toEqual([2, 3]); + expect(dx1.dtype).toBe('float32'); + expectArraysClose(await dx1.data(), [1, 1, 1, 0, 0, 0]); + + const dx2 = tf.grad(x => tf.unstack(x)[1])(x); + expect(dx2.shape).toEqual([2, 3]); + expect(dx2.dtype).toBe('float32'); + expectArraysClose(await dx2.data(), [0, 0, 0, 1, 1, 1]); + }); + + it('gradient with clones', async () => { + const x = tf.tensor([[1, 2, 3], [4, 5, 6]]); + const dx1 = tf.grad(x => tf.unstack(x.clone())[0].clone())(x); + expect(dx1.shape).toEqual([2, 3]); + expect(dx1.dtype).toBe('float32'); + expectArraysClose(await dx1.data(), [1, 1, 1, 0, 0, 0]); + + const dx2 = tf.grad(x => tf.unstack(x.clone())[1].clone())(x); + expect(dx2.shape).toEqual([2, 3]); + expect(dx2.dtype).toBe('float32'); + expectArraysClose(await dx2.data(), [0, 0, 0, 1, 1, 1]); + }); + + it('grad of unstack axis=1', async () => { + const x = tf.tensor([[1, 2, 3], [4, 5, 6]]); + const axis = 1; + const dx1 = tf.grad(x => tf.unstack(x, axis)[0])(x); + expect(dx1.shape).toEqual([2, 3]); + expect(dx1.dtype).toBe('float32'); + expectArraysClose(await dx1.data(), [1, 0, 0, 1, 0, 0]); + + const dx2 = tf.grad(x => tf.unstack(x, axis)[1])(x); + expect(dx2.shape).toEqual([2, 3]); + expect(dx2.dtype).toBe('float32'); + expectArraysClose(await dx2.data(), [0, 1, 0, 0, 1, 0]); + + const dx3 = tf.grad(x => tf.unstack(x, axis)[2])(x); + expect(dx3.shape).toEqual([2, 3]); + expect(dx3.dtype).toBe('float32'); + expectArraysClose(await dx3.data(), [0, 0, 1, 0, 0, 1]); + }); +}); From 0aa9fe6c5a940489ba5c44272ba5d6ddfaa9a56a Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 16 Jun 2020 09:58:06 -0400 Subject: [PATCH 02/14] update imports --- tfjs-core/src/gradients/Atan2_grad.ts | 2 +- tfjs-core/src/gradients/Div_grad.ts | 2 +- tfjs-core/src/gradients/FusedBatchNorm_grad.ts | 2 +- tfjs-core/src/gradients/Mod_grad.ts | 2 +- tfjs-core/src/gradients/Multiply_grad.ts | 3 ++- tfjs-core/src/gradients/Pow_grad.ts | 3 ++- tfjs-core/src/gradients/Prelu_grad.ts | 2 +- tfjs-core/src/gradients/Sub_grad.ts | 2 +- tfjs-core/src/ops/avg_pool.ts | 3 ++- tfjs-core/src/ops/avg_pool_3d.ts | 4 +++- tfjs-core/src/ops/avg_pool_3d_backprop.ts | 2 +- tfjs-core/src/ops/avg_pool_backprop.ts | 2 +- tfjs-core/src/ops/batchnorm.ts | 3 ++- tfjs-core/src/ops/broadcast_to.ts | 3 ++- tfjs-core/src/ops/conv1d.ts | 3 ++- tfjs-core/src/ops/conv2d.ts | 3 ++- tfjs-core/src/ops/conv2d_backprop_filter.ts | 2 +- tfjs-core/src/ops/conv2d_backprop_input.ts | 2 +- tfjs-core/src/ops/conv3d.ts | 2 +- tfjs-core/src/ops/conv3d_backprop_filter.ts | 2 +- tfjs-core/src/ops/conv3d_backprop_input.ts | 2 +- tfjs-core/src/ops/depthwise_conv2d.ts | 3 ++- .../ops/depthwise_conv2d_native_backprop_filter.ts | 2 +- .../ops/depthwise_conv2d_native_backprop_input.ts | 2 +- tfjs-core/src/ops/diag.ts | 2 +- tfjs-core/src/ops/dot.ts | 3 ++- tfjs-core/src/ops/expand_dims.ts | 5 ++--- tfjs-core/src/ops/eye.ts | 3 ++- tfjs-core/src/ops/linalg_ops.ts | 5 ++++- tfjs-core/src/ops/local_response_normalization.ts | 3 ++- tfjs-core/src/ops/mat_mul.ts | 3 ++- tfjs-core/src/ops/max.ts | 4 +++- tfjs-core/src/ops/max_pool.ts | 2 +- tfjs-core/src/ops/max_pool_3d.ts | 3 ++- tfjs-core/src/ops/max_pool_3d_backprop.ts | 2 +- tfjs-core/src/ops/one_hot.ts | 3 ++- tfjs-core/src/ops/outer_product.ts | 3 ++- tfjs-core/src/ops/pool.ts | 3 ++- tfjs-core/src/ops/reshape.ts | 4 ++-- tfjs-core/src/ops/segment_ops.ts | 3 ++- tfjs-core/src/ops/stack.ts | 13 +++++++++---- tfjs-core/src/ops/unstack.ts | 4 ++-- tfjs-core/src/tests.ts | 3 +++ 43 files changed, 80 insertions(+), 49 deletions(-) diff --git a/tfjs-core/src/gradients/Atan2_grad.ts b/tfjs-core/src/gradients/Atan2_grad.ts index 272ecad8c6a..8319811b3cf 100644 --- a/tfjs-core/src/gradients/Atan2_grad.ts +++ b/tfjs-core/src/gradients/Atan2_grad.ts @@ -18,11 +18,11 @@ import {Atan2} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; import {add} from '../ops/add'; -import {reshape} from '../ops/array_ops'; import {assertAndGetBroadcastShape, getReductionAxes} from '../ops/broadcast_util'; import {div} from '../ops/div'; import {mul} from '../ops/mul'; import {sum} from '../ops/reduction_ops'; +import {reshape} from '../ops/reshape'; import {square} from '../ops/square'; import {neg} from '../ops/unary_ops'; import {Tensor} from '../tensor'; diff --git a/tfjs-core/src/gradients/Div_grad.ts b/tfjs-core/src/gradients/Div_grad.ts index 4497a434a28..a4a247cfbbd 100644 --- a/tfjs-core/src/gradients/Div_grad.ts +++ b/tfjs-core/src/gradients/Div_grad.ts @@ -17,11 +17,11 @@ import {Div} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; -import {reshape} from '../ops/array_ops'; import * as broadcast_util from '../ops/broadcast_util'; import {div} from '../ops/div'; import {mul} from '../ops/mul'; import {sum} from '../ops/reduction_ops'; +import {reshape} from '../ops/reshape'; import {square} from '../ops/square'; import {neg} from '../ops/unary_ops'; import {Tensor} from '../tensor'; diff --git a/tfjs-core/src/gradients/FusedBatchNorm_grad.ts b/tfjs-core/src/gradients/FusedBatchNorm_grad.ts index 10bb1ed7f13..d71720a510e 100644 --- a/tfjs-core/src/gradients/FusedBatchNorm_grad.ts +++ b/tfjs-core/src/gradients/FusedBatchNorm_grad.ts @@ -17,10 +17,10 @@ import {FusedBatchNorm, FusedBatchNormAttrs} from '../kernel_names'; import {GradConfig, NamedAttrMap} from '../kernel_registry'; import {add} from '../ops/add'; -import {reshape} from '../ops/array_ops'; import {getReductionAxes} from '../ops/broadcast_util'; import {mul} from '../ops/mul'; import {sum} from '../ops/reduction_ops'; +import {reshape} from '../ops/reshape'; import {sub} from '../ops/sub'; import {scalar} from '../ops/tensor_ops'; import {tile} from '../ops/tile'; diff --git a/tfjs-core/src/gradients/Mod_grad.ts b/tfjs-core/src/gradients/Mod_grad.ts index a09de1869f0..c8a0400e54a 100644 --- a/tfjs-core/src/gradients/Mod_grad.ts +++ b/tfjs-core/src/gradients/Mod_grad.ts @@ -17,11 +17,11 @@ import {Mod} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; -import {reshape} from '../ops/array_ops'; import {assertAndGetBroadcastShape, getReductionAxes} from '../ops/broadcast_util'; import {div} from '../ops/div'; import {mul} from '../ops/mul'; import {sum} from '../ops/reduction_ops'; +import {reshape} from '../ops/reshape'; import {floor, neg} from '../ops/unary_ops'; import {Tensor} from '../tensor'; diff --git a/tfjs-core/src/gradients/Multiply_grad.ts b/tfjs-core/src/gradients/Multiply_grad.ts index c84d39547ce..d24c60a6ab2 100644 --- a/tfjs-core/src/gradients/Multiply_grad.ts +++ b/tfjs-core/src/gradients/Multiply_grad.ts @@ -17,10 +17,11 @@ import {Multiply} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; -import {cast, reshape} from '../ops/array_ops'; +import {cast} from '../ops/array_ops'; import {assertAndGetBroadcastShape, getReductionAxes} from '../ops/broadcast_util'; import {mul} from '../ops/mul'; import {sum} from '../ops/reduction_ops'; +import {reshape} from '../ops/reshape'; import {Tensor} from '../tensor'; export const multiplyGradConfig: GradConfig = { diff --git a/tfjs-core/src/gradients/Pow_grad.ts b/tfjs-core/src/gradients/Pow_grad.ts index 806f00cf8ff..15394a248c4 100644 --- a/tfjs-core/src/gradients/Pow_grad.ts +++ b/tfjs-core/src/gradients/Pow_grad.ts @@ -16,13 +16,14 @@ */ import {Pow} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; -import {cast, reshape} from '../ops/array_ops'; +import {cast} from '../ops/array_ops'; import * as broadcast_util from '../ops/broadcast_util'; import {greater} from '../ops/greater'; import {where} from '../ops/logical_ops'; import {mul} from '../ops/mul'; import {pow} from '../ops/pow'; import {sum} from '../ops/reduction_ops'; +import {reshape} from '../ops/reshape'; import {sub} from '../ops/sub'; import {scalar, zerosLike} from '../ops/tensor_ops'; import {log} from '../ops/unary_ops'; diff --git a/tfjs-core/src/gradients/Prelu_grad.ts b/tfjs-core/src/gradients/Prelu_grad.ts index 452ee27150f..f98c1138d6b 100644 --- a/tfjs-core/src/gradients/Prelu_grad.ts +++ b/tfjs-core/src/gradients/Prelu_grad.ts @@ -16,12 +16,12 @@ */ import {Prelu} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; -import {reshape} from '../ops/array_ops'; import {getReductionAxes} from '../ops/broadcast_util'; import {greater} from '../ops/greater'; import {where} from '../ops/logical_ops'; import {mul} from '../ops/mul'; import {sum} from '../ops/reduction_ops'; +import {reshape} from '../ops/reshape'; import {zerosLike} from '../ops/tensor_ops'; import {Tensor} from '../tensor'; diff --git a/tfjs-core/src/gradients/Sub_grad.ts b/tfjs-core/src/gradients/Sub_grad.ts index 38b4ba5f335..3f56535eeec 100644 --- a/tfjs-core/src/gradients/Sub_grad.ts +++ b/tfjs-core/src/gradients/Sub_grad.ts @@ -16,9 +16,9 @@ */ import {Sub} from '../kernel_names'; import {GradConfig} from '../kernel_registry'; -import {reshape} from '../ops/array_ops'; import * as broadcast_util from '../ops/broadcast_util'; import {sum} from '../ops/reduction_ops'; +import {reshape} from '../ops/reshape'; import {neg} from '../ops/unary_ops'; import {Tensor} from '../tensor'; diff --git a/tfjs-core/src/ops/avg_pool.ts b/tfjs-core/src/ops/avg_pool.ts index e0690d58b2f..ebd664de989 100644 --- a/tfjs-core/src/ops/avg_pool.ts +++ b/tfjs-core/src/ops/avg_pool.ts @@ -24,9 +24,10 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {cast, reshape} from './array_ops'; +import {cast} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; /** * Computes the 2D average pooling of an image. diff --git a/tfjs-core/src/ops/avg_pool_3d.ts b/tfjs-core/src/ops/avg_pool_3d.ts index ec4d2fdf39c..2ee62a9232e 100644 --- a/tfjs-core/src/ops/avg_pool_3d.ts +++ b/tfjs-core/src/ops/avg_pool_3d.ts @@ -25,9 +25,11 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {cast, reshape} from './array_ops'; +import {cast} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Computes the 3D average pooling. diff --git a/tfjs-core/src/ops/avg_pool_3d_backprop.ts b/tfjs-core/src/ops/avg_pool_3d_backprop.ts index 59bddb81bb4..b3aa1fac388 100644 --- a/tfjs-core/src/ops/avg_pool_3d_backprop.ts +++ b/tfjs-core/src/ops/avg_pool_3d_backprop.ts @@ -25,9 +25,9 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; /** * Computes the backprop of a 3d avg pool. diff --git a/tfjs-core/src/ops/avg_pool_backprop.ts b/tfjs-core/src/ops/avg_pool_backprop.ts index 37433d9fc4a..f82a7bcf717 100644 --- a/tfjs-core/src/ops/avg_pool_backprop.ts +++ b/tfjs-core/src/ops/avg_pool_backprop.ts @@ -24,9 +24,9 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; /** * Computes the backprop of an 2D avg pool. diff --git a/tfjs-core/src/ops/batchnorm.ts b/tfjs-core/src/ops/batchnorm.ts index 51cb376c2e9..c7bb1fe2216 100644 --- a/tfjs-core/src/ops/batchnorm.ts +++ b/tfjs-core/src/ops/batchnorm.ts @@ -24,9 +24,10 @@ import {convertToTensor} from '../tensor_util_env'; import {Rank, TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import {xAs4D} from './batchnorm_util'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Batch normalization. diff --git a/tfjs-core/src/ops/broadcast_to.ts b/tfjs-core/src/ops/broadcast_to.ts index 195ec00ac4c..508dc66e77f 100644 --- a/tfjs-core/src/ops/broadcast_to.ts +++ b/tfjs-core/src/ops/broadcast_to.ts @@ -24,9 +24,10 @@ import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {Rank, ShapeMap, TensorLike} from '../types'; -import {reshape} from './array_ops'; import {clone} from './clone'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Broadcast an array to a compatible shape NumPy-style. diff --git a/tfjs-core/src/ops/conv1d.ts b/tfjs-core/src/ops/conv1d.ts index eb6746eaede..175f53b4761 100644 --- a/tfjs-core/src/ops/conv1d.ts +++ b/tfjs-core/src/ops/conv1d.ts @@ -19,10 +19,11 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import {conv2d} from './conv2d'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Computes a 1D convolution over the input x. diff --git a/tfjs-core/src/ops/conv2d.ts b/tfjs-core/src/ops/conv2d.ts index 569bd84a6ad..2299705e7ce 100644 --- a/tfjs-core/src/ops/conv2d.ts +++ b/tfjs-core/src/ops/conv2d.ts @@ -23,9 +23,10 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Computes a 2D convolution over the input x. diff --git a/tfjs-core/src/ops/conv2d_backprop_filter.ts b/tfjs-core/src/ops/conv2d_backprop_filter.ts index d4cf4ad931b..f314263e6b2 100644 --- a/tfjs-core/src/ops/conv2d_backprop_filter.ts +++ b/tfjs-core/src/ops/conv2d_backprop_filter.ts @@ -21,9 +21,9 @@ import {Tensor, Tensor3D, Tensor4D} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; /** * Computes the derivative of the filter of a 2D convolution. diff --git a/tfjs-core/src/ops/conv2d_backprop_input.ts b/tfjs-core/src/ops/conv2d_backprop_input.ts index 1e50ccba95c..fe1e59c21ae 100644 --- a/tfjs-core/src/ops/conv2d_backprop_input.ts +++ b/tfjs-core/src/ops/conv2d_backprop_input.ts @@ -21,9 +21,9 @@ import {Tensor, Tensor3D, Tensor4D} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; /** * Computes the derivative of the input of a 2D convolution. diff --git a/tfjs-core/src/ops/conv3d.ts b/tfjs-core/src/ops/conv3d.ts index 7aaf863fdbf..4671076bb2a 100644 --- a/tfjs-core/src/ops/conv3d.ts +++ b/tfjs-core/src/ops/conv3d.ts @@ -23,10 +23,10 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {eitherStridesOrDilationsAreOne} from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; /** * Computes a 3D convolution over the input x. diff --git a/tfjs-core/src/ops/conv3d_backprop_filter.ts b/tfjs-core/src/ops/conv3d_backprop_filter.ts index 53dd8e1bfa5..502e2455b7a 100644 --- a/tfjs-core/src/ops/conv3d_backprop_filter.ts +++ b/tfjs-core/src/ops/conv3d_backprop_filter.ts @@ -21,9 +21,9 @@ import {Tensor, Tensor4D, Tensor5D} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; /** * Computes the derivative of the filter of a 3D convolution. diff --git a/tfjs-core/src/ops/conv3d_backprop_input.ts b/tfjs-core/src/ops/conv3d_backprop_input.ts index aadaa97301f..345106c25cc 100644 --- a/tfjs-core/src/ops/conv3d_backprop_input.ts +++ b/tfjs-core/src/ops/conv3d_backprop_input.ts @@ -21,9 +21,9 @@ import {Tensor, Tensor4D, Tensor5D} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; /** * Computes the derivative of the input of a 3D convolution. diff --git a/tfjs-core/src/ops/depthwise_conv2d.ts b/tfjs-core/src/ops/depthwise_conv2d.ts index 164f97e203d..1164d42cef4 100644 --- a/tfjs-core/src/ops/depthwise_conv2d.ts +++ b/tfjs-core/src/ops/depthwise_conv2d.ts @@ -23,9 +23,10 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Depthwise 2D convolution. diff --git a/tfjs-core/src/ops/depthwise_conv2d_native_backprop_filter.ts b/tfjs-core/src/ops/depthwise_conv2d_native_backprop_filter.ts index 1b238568c19..ba817aea4ff 100644 --- a/tfjs-core/src/ops/depthwise_conv2d_native_backprop_filter.ts +++ b/tfjs-core/src/ops/depthwise_conv2d_native_backprop_filter.ts @@ -19,9 +19,9 @@ import {DepthwiseConv2dNativeBackpropFilter, DepthwiseConv2dNativeBackpropFilter import {Tensor, Tensor3D, Tensor4D} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; function depthwiseConv2dNativeBackpropFilter_( x: T, dy: T, filterShape: [number, number, number, number], diff --git a/tfjs-core/src/ops/depthwise_conv2d_native_backprop_input.ts b/tfjs-core/src/ops/depthwise_conv2d_native_backprop_input.ts index c88eb7564be..3f227f4b670 100644 --- a/tfjs-core/src/ops/depthwise_conv2d_native_backprop_input.ts +++ b/tfjs-core/src/ops/depthwise_conv2d_native_backprop_input.ts @@ -19,9 +19,9 @@ import {DepthwiseConv2dNativeBackpropInput, DepthwiseConv2dNativeBackpropInputIn import {Tensor, Tensor3D, Tensor4D} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; function depthwiseConv2dNativeBackpropInput_( xShape: [number, number, number, number]|[number, number, number], dy: T, diff --git a/tfjs-core/src/ops/diag.ts b/tfjs-core/src/ops/diag.ts index c7d442510a5..27657535036 100644 --- a/tfjs-core/src/ops/diag.ts +++ b/tfjs-core/src/ops/diag.ts @@ -21,8 +21,8 @@ import {Tensor} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; -import {reshape} from './array_ops'; import {op} from './operation'; +import {reshape} from './reshape'; /** * Returns a diagonal tensor with a given diagonal values. diff --git a/tfjs-core/src/ops/dot.ts b/tfjs-core/src/ops/dot.ts index 162b98d2aab..da29ae675b9 100644 --- a/tfjs-core/src/ops/dot.ts +++ b/tfjs-core/src/ops/dot.ts @@ -20,9 +20,10 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import {matMul} from './mat_mul'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Computes the dot product of two matrices and/or vectors, `t1` and `t2`. diff --git a/tfjs-core/src/ops/expand_dims.ts b/tfjs-core/src/ops/expand_dims.ts index a78e40a70e4..06ccaa12920 100644 --- a/tfjs-core/src/ops/expand_dims.ts +++ b/tfjs-core/src/ops/expand_dims.ts @@ -15,10 +15,9 @@ * ============================================================================= */ -import {ENGINE} from '../engine'; -import {Tensor, TensorBuffer} from '../tensor'; +import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; -import {DataType, DataTypeMap, Rank, ShapeMap, TensorLike} from '../types'; +import {DataType, Rank, ShapeMap, TensorLike} from '../types'; import * as util from '../util'; import {op} from './operation'; diff --git a/tfjs-core/src/ops/eye.ts b/tfjs-core/src/ops/eye.ts index bdfb7047d7a..1021df589dd 100644 --- a/tfjs-core/src/ops/eye.ts +++ b/tfjs-core/src/ops/eye.ts @@ -18,7 +18,8 @@ import {Tensor2D} from '../tensor'; import {DataType} from '../types'; -import {buffer, expandDims} from './array_ops'; +import {buffer} from './array_ops'; +import {expandDims} from './expand_dims'; import {op} from './operation'; import {tile} from './tile'; diff --git a/tfjs-core/src/ops/linalg_ops.ts b/tfjs-core/src/ops/linalg_ops.ts index 4f10e56fb33..9574333aed1 100644 --- a/tfjs-core/src/ops/linalg_ops.ts +++ b/tfjs-core/src/ops/linalg_ops.ts @@ -26,15 +26,18 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {assert} from '../util'; -import {squeeze, stack, unstack} from './array_ops'; import {eye} from './eye'; import {logicalAnd, where} from './logical_ops'; import {norm} from './norm'; import {op} from './operation'; import {sum} from './reduction_ops'; import {split} from './split'; +import {squeeze} from './squeeze'; +import {stack} from './stack'; import {sub} from './sub'; import {range, scalar, tensor2d, zeros} from './tensor_ops'; +import {unstack} from './unstack'; + /** * Copy a tensor setting everything outside a central band in each innermost diff --git a/tfjs-core/src/ops/local_response_normalization.ts b/tfjs-core/src/ops/local_response_normalization.ts index 048a160dd48..14e2212ad36 100644 --- a/tfjs-core/src/ops/local_response_normalization.ts +++ b/tfjs-core/src/ops/local_response_normalization.ts @@ -24,8 +24,9 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Normalizes the activation of a local neighborhood across or within diff --git a/tfjs-core/src/ops/mat_mul.ts b/tfjs-core/src/ops/mat_mul.ts index b0d834345db..3867f196015 100644 --- a/tfjs-core/src/ops/mat_mul.ts +++ b/tfjs-core/src/ops/mat_mul.ts @@ -24,8 +24,9 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Computes the dot product of two matrices, A * B. These must be matrices. diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index 8a26f957ef0..630de7b8276 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -24,11 +24,13 @@ import {GradSaveFunc, NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; + import * as axis_util from './axis_util'; import {op} from './operation'; +import {reshape} from './reshape'; import {transpose} from './transpose'; + /** * Computes the maximum of elements across dimensions of a `tf.Tensor`. * diff --git a/tfjs-core/src/ops/max_pool.ts b/tfjs-core/src/ops/max_pool.ts index b3f82f60c76..0416d36f175 100644 --- a/tfjs-core/src/ops/max_pool.ts +++ b/tfjs-core/src/ops/max_pool.ts @@ -24,9 +24,9 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; /** * Computes the 2D max pooling of an image. diff --git a/tfjs-core/src/ops/max_pool_3d.ts b/tfjs-core/src/ops/max_pool_3d.ts index 5df4ecf1b1b..4c4cd5aa0dc 100644 --- a/tfjs-core/src/ops/max_pool_3d.ts +++ b/tfjs-core/src/ops/max_pool_3d.ts @@ -25,9 +25,10 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Computes the 3D max pooling. diff --git a/tfjs-core/src/ops/max_pool_3d_backprop.ts b/tfjs-core/src/ops/max_pool_3d_backprop.ts index 5574d7ca8e0..ec0fe210fce 100644 --- a/tfjs-core/src/ops/max_pool_3d_backprop.ts +++ b/tfjs-core/src/ops/max_pool_3d_backprop.ts @@ -24,9 +24,9 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import * as conv_util from './conv_util'; import {op} from './operation'; +import {reshape} from './reshape'; /** * Computes the backprop of a 3d max pool. diff --git a/tfjs-core/src/ops/one_hot.ts b/tfjs-core/src/ops/one_hot.ts index 277bd3e1b03..991b420a32c 100644 --- a/tfjs-core/src/ops/one_hot.ts +++ b/tfjs-core/src/ops/one_hot.ts @@ -23,8 +23,9 @@ import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; -import {reshape} from './array_ops'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take diff --git a/tfjs-core/src/ops/outer_product.ts b/tfjs-core/src/ops/outer_product.ts index b350d2820e5..32f08c0a8f4 100644 --- a/tfjs-core/src/ops/outer_product.ts +++ b/tfjs-core/src/ops/outer_product.ts @@ -19,9 +19,10 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import {matMul} from './mat_mul'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Computes the outer product of two vectors, `v1` and `v2`. diff --git a/tfjs-core/src/ops/pool.ts b/tfjs-core/src/ops/pool.ts index 45c2819c4fd..f8a16e85785 100644 --- a/tfjs-core/src/ops/pool.ts +++ b/tfjs-core/src/ops/pool.ts @@ -20,14 +20,15 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import {avgPool} from './avg_pool'; import {batchToSpaceND} from './batch_to_space_nd'; import * as conv_util from './conv_util'; import {maxPool} from './max_pool'; import {op} from './operation'; +import {reshape} from './reshape'; import {spaceToBatchND} from './space_to_batch_nd'; + /** * Performs an N-D pooling operation * diff --git a/tfjs-core/src/ops/reshape.ts b/tfjs-core/src/ops/reshape.ts index 2245ee62f4e..b958b817c80 100644 --- a/tfjs-core/src/ops/reshape.ts +++ b/tfjs-core/src/ops/reshape.ts @@ -16,9 +16,9 @@ */ import {ENGINE} from '../engine'; -import {Tensor, TensorBuffer} from '../tensor'; +import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; -import {DataType, DataTypeMap, Rank, ShapeMap, TensorLike} from '../types'; +import {Rank, ShapeMap, TensorLike} from '../types'; import * as util from '../util'; import {op} from './operation'; diff --git a/tfjs-core/src/ops/segment_ops.ts b/tfjs-core/src/ops/segment_ops.ts index 33032a62882..44ca25d6d24 100644 --- a/tfjs-core/src/ops/segment_ops.ts +++ b/tfjs-core/src/ops/segment_ops.ts @@ -21,8 +21,8 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {assert, isInt, parseAxisParam} from '../util'; -import {expandDims} from './array_ops'; import {getUndoAxesPermutation} from './axis_util'; +import {expandDims} from './expand_dims'; import {greaterEqual} from './greater_equal'; import {logicalAnd, where} from './logical_ops'; import {maximum} from './maximum'; @@ -30,6 +30,7 @@ import {op} from './operation'; import {collectGatherOpShapeInfo} from './segment_util'; import {ones, scalar, zerosLike} from './tensor_ops'; + /** * Computes the sum along segments of a `tf.Tensor`. * diff --git a/tfjs-core/src/ops/stack.ts b/tfjs-core/src/ops/stack.ts index 84eee459000..b54070cc1f8 100644 --- a/tfjs-core/src/ops/stack.ts +++ b/tfjs-core/src/ops/stack.ts @@ -15,10 +15,9 @@ * ============================================================================= */ -import {ENGINE} from '../engine'; -import {Tensor, TensorBuffer} from '../tensor'; -import {convertToTensor, convertToTensorArray} from '../tensor_util_env'; -import {DataType, DataTypeMap, Rank, ShapeMap, TensorLike} from '../types'; +import {Tensor} from '../tensor'; +import {convertToTensorArray} from '../tensor_util_env'; +import {TensorLike} from '../types'; import * as util from '../util'; import {concat} from './concat'; @@ -68,6 +67,12 @@ function stack_( () => 'All tensors passed to stack must have matching dtypes'); }); const expandedTensors = $tensors.map(t => expandDims(t, axis)); + // Stack exists in the TensorFlow C++ API + // (https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/stack) but not + // in + // https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/core/ops/ops.pbtxt. + // Therefore we are treating it like a high-level op rather than + // creating a dedicated stack kernel. return concat(expandedTensors, axis); } diff --git a/tfjs-core/src/ops/unstack.ts b/tfjs-core/src/ops/unstack.ts index 307663df4e2..9eddcdc8420 100644 --- a/tfjs-core/src/ops/unstack.ts +++ b/tfjs-core/src/ops/unstack.ts @@ -16,9 +16,9 @@ */ import {ENGINE} from '../engine'; -import {Tensor, TensorBuffer} from '../tensor'; +import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; -import {DataType, DataTypeMap, Rank, ShapeMap, TensorLike} from '../types'; +import {TensorLike} from '../types'; import * as util from '../util'; import {op} from './operation'; diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index 47510b52c88..59bb3cc7cc6 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -72,6 +72,7 @@ import './ops/dropout_test'; import './ops/dropout_util_test'; import './ops/elu_test'; import './ops/equal_test'; +import './ops/expand_dims_test'; import './ops/eye_test'; import './ops/fused_test'; import './ops/gather_nd_test'; @@ -118,6 +119,7 @@ import './ops/softmax_test'; import './ops/space_to_batch_nd_test'; import './ops/sparse_to_dense_test'; import './ops/spectral_ops_test'; +import './ops/stack_test'; import './ops/strided_slice_test'; import './ops/sub_test'; import './ops/tile_test'; @@ -125,6 +127,7 @@ import './ops/topk_test'; import './ops/transpose_test'; import './ops/truncated_normal_test'; import './ops/unary_ops_test'; +import './ops/unstack_test'; import './optimizers/adadelta_optimizer_test'; import './optimizers/adagrad_optimizer_test'; import './optimizers/adam_optimizer_test'; From cc89c6354a4baff522fcb50b395b7d710239e5c4 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 16 Jun 2020 10:27:19 -0400 Subject: [PATCH 03/14] add unpack --- tfjs-backend-wasm/src/kernels/Unpack.ts | 36 ++++++++++++------------- tfjs-core/src/kernel_names.ts | 6 +++++ tfjs-core/src/ops/unstack.ts | 11 +++++--- 3 files changed, 31 insertions(+), 22 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/Unpack.ts b/tfjs-backend-wasm/src/kernels/Unpack.ts index 25d6659e01c..1645c76e227 100644 --- a/tfjs-backend-wasm/src/kernels/Unpack.ts +++ b/tfjs-backend-wasm/src/kernels/Unpack.ts @@ -15,44 +15,42 @@ * ============================================================================= */ -import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; -import {BackendWasm} from '../backend_wasm'; -import {slice} from './Slice'; +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo, Unpack, UnpackAttrs, UnpackInputs} from '@tensorflow/tfjs-core'; -interface UnpackInputs extends NamedTensorInfoMap { - x: TensorInfo; -} +import {BackendWasm} from '../backend_wasm'; -interface UnpackAttrs extends NamedAttrMap { - axis: number; -} +import {slice} from './Slice'; -function unpack( - args: {inputs: UnpackInputs, backend: BackendWasm, attrs: UnpackAttrs}): - TensorInfo[] { - const {inputs: {x}, backend, attrs: {axis}} = args; - const numOutputs = x.shape[axis]; - const rank = x.shape.length; +function unpack(args: { + inputs: NamedTensorInfoMap, + backend: BackendWasm, + attrs: NamedAttrMap +}): TensorInfo[] { + const {inputs, backend, attrs} = args; + const {value} = inputs as {} as UnpackInputs; + const {axis} = attrs as {} as UnpackAttrs; + const numOutputs = value.shape[axis]; + const rank = value.shape.length; const outShape: number[] = new Array(rank - 1); let outIndex = 0; for (let i = 0; i < rank; i++) { if (i !== axis) { - outShape[outIndex++] = x.shape[i]; + outShape[outIndex++] = value.shape[i]; } } const outs: TensorInfo[] = new Array(numOutputs); const begin = new Array(rank).fill(0); - const size = x.shape.slice(); + const size = value.shape.slice(); size[axis] = 1; for (let i = 0; i < outs.length; i++) { begin[axis] = i; - outs[i] = slice({inputs: {x}, attrs: {begin, size}, backend}); + outs[i] = slice({inputs: {x: value}, attrs: {begin, size}, backend}); } return outs.map(({dataId, dtype}) => ({dataId, dtype, shape: outShape})); } registerKernel({ - kernelName: 'Unpack', + kernelName: Unpack, backendName: 'wasm', kernelFunc: unpack, }); diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 5e95ecc3465..6cab056f1d6 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -446,6 +446,12 @@ export interface TransposeAttrs { perm: number[]; } +export const Unpack = 'Unpack'; +export type UnpackInputs = Pick; +export interface UnpackAttrs { + axis: number; +} + /** * TensorFlow.js-only kernels */ diff --git a/tfjs-core/src/ops/unstack.ts b/tfjs-core/src/ops/unstack.ts index 9eddcdc8420..2ae462f8349 100644 --- a/tfjs-core/src/ops/unstack.ts +++ b/tfjs-core/src/ops/unstack.ts @@ -16,7 +16,10 @@ */ import {ENGINE} from '../engine'; +import {Unpack, UnpackAttrs, UnpackInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; @@ -48,11 +51,13 @@ function unstack_(x: Tensor|TensorLike, axis = 0): Tensor[] { axis += $x.shape.length; } const grad = (dy: Tensor[]) => { - return {x: () => stack(dy, axis)}; + return {value: () => stack(dy, axis)}; }; - const attrs = {axis}; + const inputs: UnpackInputs = {value: $x}; + const attrs: UnpackAttrs = {axis}; return ENGINE.runKernelFunc( - backend => backend.unstack($x, axis), {x: $x}, grad, 'Unpack', attrs); + backend => backend.unstack($x, axis), inputs as {} as NamedTensorMap, + grad, Unpack, attrs as {} as NamedAttrMap); } export const unstack = op({unstack_}); From 556361bc8afb07805c2871380d207cfa7dfadb54 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Tue, 16 Jun 2020 11:34:36 -0400 Subject: [PATCH 04/14] remove from chain --- .../src/public/chained_ops/expand_dims.ts | 30 +++++++++++++++++ .../chained_ops/register_all_chained_ops.ts | 3 ++ .../register_all_chained_ops_test.ts | 5 ++- tfjs-core/src/public/chained_ops/squeeze.ts | 30 +++++++++++++++++ tfjs-core/src/public/chained_ops/unstack.ts | 30 +++++++++++++++++ tfjs-core/src/tensor.ts | 32 ------------------- 6 files changed, 97 insertions(+), 33 deletions(-) create mode 100644 tfjs-core/src/public/chained_ops/expand_dims.ts create mode 100644 tfjs-core/src/public/chained_ops/squeeze.ts create mode 100644 tfjs-core/src/public/chained_ops/unstack.ts diff --git a/tfjs-core/src/public/chained_ops/expand_dims.ts b/tfjs-core/src/public/chained_ops/expand_dims.ts new file mode 100644 index 00000000000..7c062876e60 --- /dev/null +++ b/tfjs-core/src/public/chained_ops/expand_dims.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {expandDims} from '../../ops/expand_dims'; +import {Tensor} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + expandDims(axis?: number): T; + } +} + +Tensor.prototype.expandDims = function(axis?: number): T { + this.throwIfDisposed(); + return expandDims(this, axis) as T; +}; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index 5b1a5ee6b6a..e8af56bf3a3 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -33,6 +33,7 @@ import './div_no_nan'; import './dot'; import './elu'; import './equal'; +import './expand_dims'; import './floorDiv'; import './greater'; import './greater_equal'; @@ -61,7 +62,9 @@ import './selu'; import './separable_conv2d'; import './split'; import './squared_difference'; +import './squeeze'; import './space_to_batch_nd'; import './sub'; import './tile'; import './transpose'; +import './unstack'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index df4e211c6b5..629a2f7d5a2 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -43,6 +43,7 @@ const CHAINED_OPS = [ 'dot', 'elu', 'equal', + 'expandDims', 'floorDiv', 'greater', 'greaterEqual', @@ -72,9 +73,11 @@ const CHAINED_OPS = [ 'spaceToBatchND', 'split', 'square', + 'squeeze', 'sub', 'tile', - 'transpose' + 'transpose', + 'unstack' ]; describeWithFlags('chained ops', ALL_ENVS, () => { diff --git a/tfjs-core/src/public/chained_ops/squeeze.ts b/tfjs-core/src/public/chained_ops/squeeze.ts new file mode 100644 index 00000000000..09f3a2ac70f --- /dev/null +++ b/tfjs-core/src/public/chained_ops/squeeze.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {squeeze} from '../../ops/squeeze'; +import {Tensor} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + squeeze(axis?: number[]): T; + } +} + +Tensor.prototype.squeeze = function(axis?: number[]): T { + this.throwIfDisposed(); + return squeeze(this, axis); +}; diff --git a/tfjs-core/src/public/chained_ops/unstack.ts b/tfjs-core/src/public/chained_ops/unstack.ts new file mode 100644 index 00000000000..c98285b2e22 --- /dev/null +++ b/tfjs-core/src/public/chained_ops/unstack.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {unstack} from '../../ops/unstack'; +import {Tensor} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + unstack(axis?: number): T[]; + } +} + +Tensor.prototype.unstack = function(axis?: number): T[] { + this.throwIfDisposed(); + return unstack(this, axis) as T[]; +}; diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index 91ec3c58cdd..20ff672c294 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -175,8 +175,6 @@ export interface OpHandler { values?: DataTypeMap[D]): TensorBuffer; print(x: T, verbose: boolean): void; reshape(x: Tensor, shape: ShapeMap[R2]): Tensor; - expandDims(x: Tensor, axis: number): Tensor; - squeeze(x: Tensor, axis?: number[]): T; clone(x: T): T; gather(x: T, indices: Tensor|TensorLike, axis: number): T; norm( @@ -186,7 +184,6 @@ export interface OpHandler { x: T, begin: number|number[], size?: number|number[]): T; reverse(x: T, axis?: number|number[]): T; stack(tensors: Array, axis: number): Tensor; - unstack(value: T, axis: number): Tensor[]; all(x: Tensor, axis: number|number[], keepDims: boolean): T; any(x: Tensor, axis: number|number[], keepDims: boolean): T; logSumExp( @@ -624,32 +621,6 @@ export class Tensor { return this.reshape(x.shape) as T; } - /** - * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension - * into the tensor's shape. See `tf.expandDims` for details. - * - * @param axis The dimension index at which to insert shape of 1. Defaults to - * 0 (the first dimension). - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - expandDims(axis = 0): Tensor { - return opHandler.expandDims(this, axis); - } - - /** - * Returns a `tf.Tensor` with dimensions of size 1 removed from the shape. - * See `tf.squeeze` for more details. - * - * @param axis A list of numbers. If specified, only squeezes the - * dimensions listed. The dimension index starts at 0. It is an error to - * squeeze a dimension that is not 1. - */ - /** @doc {heading: 'Tensors', subheading: 'Classes'} */ - squeeze(axis?: number[]): T { - this.throwIfDisposed(); - return opHandler.squeeze(this, axis); - } - /** Returns a copy of the tensor. See `tf.clone` for details. */ /** @doc {heading: 'Tensors', subheading: 'Classes'} */ clone(this: T): T { @@ -691,9 +662,6 @@ export class Tensor { stack(x: Tensor, axis = 0): Tensor { return opHandler.stack([this, x], axis); } - unstack(axis = 0): Tensor[] { - return opHandler.unstack(this, axis); - } // Reduction ops. all(axis: number|number[] = null, keepDims = false): T { this.throwIfDisposed(); From 6d2d67ff8840ccfab815da3343ae41ebbdbc87e9 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 24 Jun 2020 10:31:18 -0400 Subject: [PATCH 05/14] cealn --- tfjs-core/src/ops/band_part.ts | 5 ++++- tfjs-core/src/ops/dilation2d.ts | 3 ++- tfjs-core/src/ops/gram_schmidt.ts | 4 +++- tfjs-core/src/ops/qr.ts | 5 ++++- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/tfjs-core/src/ops/band_part.ts b/tfjs-core/src/ops/band_part.ts index f52e9cdcd1c..b93c444de4c 100644 --- a/tfjs-core/src/ops/band_part.ts +++ b/tfjs-core/src/ops/band_part.ts @@ -20,15 +20,18 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {assert} from '../util'; -import {reshape, stack, unstack} from './array_ops'; import {greaterEqual} from './greater_equal'; import {lessEqual} from './less_equal'; import {logicalAnd} from './logical_and'; import {op} from './operation'; +import {reshape} from './reshape'; +import {stack} from './stack'; import {sub} from './sub'; import {range, scalar, zeros} from './tensor_ops'; +import {unstack} from './unstack'; import {where} from './where'; + /** * Copy a tensor setting everything outside a central band in each innermost * matrix to zero. diff --git a/tfjs-core/src/ops/dilation2d.ts b/tfjs-core/src/ops/dilation2d.ts index 794962d6371..b15ab6e91e2 100644 --- a/tfjs-core/src/ops/dilation2d.ts +++ b/tfjs-core/src/ops/dilation2d.ts @@ -24,8 +24,9 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; -import {reshape} from './array_ops'; import {op} from './operation'; +import {reshape} from './reshape'; + /** * Computes the grayscale dilation over the input `x`. diff --git a/tfjs-core/src/ops/gram_schmidt.ts b/tfjs-core/src/ops/gram_schmidt.ts index 599a85c373c..62567e322b1 100644 --- a/tfjs-core/src/ops/gram_schmidt.ts +++ b/tfjs-core/src/ops/gram_schmidt.ts @@ -19,15 +19,17 @@ import {ENGINE} from '../engine'; import {Tensor1D, Tensor2D} from '../tensor'; import {assert} from '../util'; -import {squeeze, stack} from './array_ops'; import {div} from './div'; import {mul} from './mul'; import {norm} from './norm'; import {op} from './operation'; import {sum} from './reduction_ops'; import {split} from './split'; +import {squeeze} from './squeeze'; +import {stack} from './stack'; import {sub} from './sub'; + /** * Gram-Schmidt orthogonalization. * diff --git a/tfjs-core/src/ops/qr.ts b/tfjs-core/src/ops/qr.ts index 74658f3e037..92bf6656975 100644 --- a/tfjs-core/src/ops/qr.ts +++ b/tfjs-core/src/ops/qr.ts @@ -19,7 +19,6 @@ import {dispose} from '../globals'; import {Tensor, Tensor2D} from '../tensor'; import {assert} from '../util'; -import {reshape, stack, unstack} from './array_ops'; import {clone} from './clone'; import {concat} from './concat'; import {div} from './div'; @@ -29,13 +28,17 @@ import {matMul} from './mat_mul'; import {mul} from './mul'; import {norm} from './norm'; import {op} from './operation'; +import {reshape} from './reshape'; import {slice} from './slice'; +import {stack} from './stack'; import {sub} from './sub'; import {tensor2d} from './tensor_ops'; import {transpose} from './transpose'; import {neg} from './unary_ops'; +import {unstack} from './unstack'; import {where} from './where'; + /** * Compute QR decomposition of m-by-n matrix using Householder transformation. * From 1f7330da71c6d06d9866488cb7b8f1152db8b803 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 24 Jun 2020 10:51:20 -0400 Subject: [PATCH 06/14] disable warn --- tfjs-backend-webgl/src/backend_webgl.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tfjs-backend-webgl/src/backend_webgl.ts b/tfjs-backend-webgl/src/backend_webgl.ts index fe848478d54..64e409b4a84 100644 --- a/tfjs-backend-webgl/src/backend_webgl.ts +++ b/tfjs-backend-webgl/src/backend_webgl.ts @@ -637,7 +637,8 @@ export class MathBackendWebGL extends KernelBackend { inputs: TensorInfo[], sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD): boolean { const cpuBackend = this.getCPUBackend(); - if (!this.warnedAboutCPUBackend && cpuBackend == null) { + if (!this.warnedAboutCPUBackend && cpuBackend == null && + !env().getBool('IS_TEST')) { console.warn( 'Your application contains ops that are small enough to be ' + 'executed on the CPU backend, however the CPU backend cannot ' + From 59b18e9e412df89940fe23ab954e8fbdcca3d666 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 24 Jun 2020 11:31:28 -0400 Subject: [PATCH 07/14] register --- tfjs-backend-wasm/src/kernels/Reshape.ts | 26 +++++++++------------ tfjs-core/src/gradients/Reshape_grad.ts | 29 ++++++++++++++++++++++++ tfjs-core/src/kernel_names.ts | 6 +++++ tfjs-core/src/ops/reshape.ts | 21 ++++++++++++----- tfjs-core/src/register_all_gradients.ts | 2 ++ 5 files changed, 63 insertions(+), 21 deletions(-) create mode 100644 tfjs-core/src/gradients/Reshape_grad.ts diff --git a/tfjs-backend-wasm/src/kernels/Reshape.ts b/tfjs-backend-wasm/src/kernels/Reshape.ts index 36c6fab9c54..96f2261bf00 100644 --- a/tfjs-backend-wasm/src/kernels/Reshape.ts +++ b/tfjs-backend-wasm/src/kernels/Reshape.ts @@ -15,27 +15,23 @@ * ============================================================================= */ -import {NamedAttrMap, NamedTensorInfoMap, registerKernel} from '@tensorflow/tfjs-core'; -import {TensorInfo} from '@tensorflow/tfjs-core'; +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, Reshape, ReshapeAttrs, ReshapeInputs} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; -interface ReshapeInputs extends NamedTensorInfoMap { - x: TensorInfo; -} - -interface ReshapeAttrs extends NamedAttrMap { - shape: number[]; -} - -export function reshape( - args: {inputs: ReshapeInputs, attrs: ReshapeAttrs, backend: BackendWasm}) { - const {inputs: {x}, attrs: {shape}} = args; - return {dataId: x.dataId, shape, dtype: x.dtype}; +export function reshape(args: { + inputs: NamedTensorInfoMap, + attrs: NamedAttrMap, + backend: BackendWasm +}) { + const {inputs, attrs} = args; + const {tensor} = inputs as {} as ReshapeInputs; + const {shape} = attrs as {} as ReshapeAttrs; + return {dataId: tensor.dataId, shape, dtype: tensor.dtype}; } registerKernel({ - kernelName: 'Reshape', + kernelName: Reshape, backendName: 'wasm', kernelFunc: reshape, }); diff --git a/tfjs-core/src/gradients/Reshape_grad.ts b/tfjs-core/src/gradients/Reshape_grad.ts new file mode 100644 index 00000000000..dc0ddd93637 --- /dev/null +++ b/tfjs-core/src/gradients/Reshape_grad.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {Reshape} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {reshape} from '../ops/reshape'; +import {Tensor} from '../tensor'; + +export const reshapeGradConfig: GradConfig = { + kernelName: Reshape, + inputsToSave: ['tensor'], + gradFunc: (dy: Tensor, saved: Tensor[]) => { + const [tensor] = saved; + return {tensor: () => reshape(dy, tensor.shape)}; + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index a3da0c044f8..9ec50b60ea5 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -416,6 +416,12 @@ export type RealInputs = Pick; export const Relu = 'Relu'; export type ReluInputs = Pick; +export const Reshape = 'Reshape'; +export type ReshapeInputs = Pick; +export interface ReshapeAttrs { + shape: number[]; +} + export const ResizeNearestNeighbor = 'ResizeNearestNeighbor'; export type ResizeNearestNeighborInputs = Pick; export interface ResizeNearestNeighborAttrs { diff --git a/tfjs-core/src/ops/reshape.ts b/tfjs-core/src/ops/reshape.ts index b958b817c80..d519e996247 100644 --- a/tfjs-core/src/ops/reshape.ts +++ b/tfjs-core/src/ops/reshape.ts @@ -15,14 +15,19 @@ * ============================================================================= */ -import {ENGINE} from '../engine'; +import {KernelBackend} from '../backends/backend'; +import {ENGINE, ForwardFunc} from '../engine'; +import {Reshape, ReshapeAttrs, ReshapeInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; import {Tensor} from '../tensor'; +import {GradSaveFunc, NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {Rank, ShapeMap, TensorLike} from '../types'; import * as util from '../util'; import {op} from './operation'; + /** * Reshapes a `tf.Tensor` to a given shape. * @@ -56,11 +61,15 @@ function reshape_( $x.size === util.sizeFromShape(shape), () => 'new shape and old shape must have the same number of elements.'); - const grad = (dy: Tensor) => { - return {x: () => reshape(dy, $x.shape)}; - }; - const attrs = {shape}; + const inputs: ReshapeInputs = {tensor: $x}; + const attrs: ReshapeAttrs = {shape}; + const forward: ForwardFunc> = + (backend: KernelBackend, save: GradSaveFunc) => { + save([$x]); + return backend.reshape($x, shape); + }; return ENGINE.runKernelFunc( - backend => backend.reshape($x, shape), {x: $x}, grad, 'Reshape', attrs); + forward, inputs as {} as NamedTensorMap, null /* grad */, Reshape, + attrs as {} as NamedAttrMap); } export const reshape = op({reshape_}); diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 563fa8d81d7..11794bb9c1a 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -49,6 +49,7 @@ import {powGradConfig} from './gradients/Pow_grad'; import {preluGradConfig} from './gradients/Prelu_grad'; import {relu6GradConfig} from './gradients/Relu6_grad'; import {reluGradConfig} from './gradients/Relu_grad'; +import {reshapeGradConfig} from './gradients/Reshape_grad'; import {resizeBilinearGradConfig} from './gradients/ResizeBilinear_grad'; import {resizeNearestNeighborGradConfig} from './gradients/ResizeNearestNeighbor_grad'; import {selectV2PoolGradConfig} from './gradients/SelectV2_grad'; @@ -104,6 +105,7 @@ const gradConfigs: GradConfig[] = [ powGradConfig, preluGradConfig, reluGradConfig, + reshapeGradConfig, resizeBilinearGradConfig, resizeNearestNeighborGradConfig, relu6GradConfig, From 3ba416b9fa858e86edc786e1f58479e7943bf53f Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 24 Jun 2020 11:35:40 -0400 Subject: [PATCH 08/14] lint --- tfjs-core/src/ops/avg_pool_3d.ts | 1 - tfjs-core/src/ops/band_part.ts | 1 - tfjs-core/src/ops/batchnorm.ts | 1 - tfjs-core/src/ops/broadcast_to.ts | 1 - tfjs-core/src/ops/conv1d.ts | 1 - tfjs-core/src/ops/conv2d.ts | 1 - tfjs-core/src/ops/depthwise_conv2d.ts | 1 - tfjs-core/src/ops/dilation2d.ts | 1 - tfjs-core/src/ops/dot.ts | 1 - tfjs-core/src/ops/gram_schmidt.ts | 1 - tfjs-core/src/ops/local_response_normalization.ts | 1 - tfjs-core/src/ops/mat_mul.ts | 1 - tfjs-core/src/ops/max.ts | 1 - tfjs-core/src/ops/max_pool_3d.ts | 1 - tfjs-core/src/ops/one_hot.ts | 1 - tfjs-core/src/ops/outer_product.ts | 1 - tfjs-core/src/ops/pool.ts | 1 - tfjs-core/src/ops/qr.ts | 1 - tfjs-core/src/ops/reshape.ts | 1 - tfjs-core/src/ops/segment_ops.ts | 1 - 20 files changed, 20 deletions(-) diff --git a/tfjs-core/src/ops/avg_pool_3d.ts b/tfjs-core/src/ops/avg_pool_3d.ts index 2ee62a9232e..24d592aef7f 100644 --- a/tfjs-core/src/ops/avg_pool_3d.ts +++ b/tfjs-core/src/ops/avg_pool_3d.ts @@ -30,7 +30,6 @@ import * as conv_util from './conv_util'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Computes the 3D average pooling. * diff --git a/tfjs-core/src/ops/band_part.ts b/tfjs-core/src/ops/band_part.ts index b93c444de4c..87e48ce380a 100644 --- a/tfjs-core/src/ops/band_part.ts +++ b/tfjs-core/src/ops/band_part.ts @@ -31,7 +31,6 @@ import {range, scalar, zeros} from './tensor_ops'; import {unstack} from './unstack'; import {where} from './where'; - /** * Copy a tensor setting everything outside a central band in each innermost * matrix to zero. diff --git a/tfjs-core/src/ops/batchnorm.ts b/tfjs-core/src/ops/batchnorm.ts index c7bb1fe2216..642230f47b1 100644 --- a/tfjs-core/src/ops/batchnorm.ts +++ b/tfjs-core/src/ops/batchnorm.ts @@ -28,7 +28,6 @@ import {xAs4D} from './batchnorm_util'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Batch normalization. * diff --git a/tfjs-core/src/ops/broadcast_to.ts b/tfjs-core/src/ops/broadcast_to.ts index 508dc66e77f..4fd1e13ecb8 100644 --- a/tfjs-core/src/ops/broadcast_to.ts +++ b/tfjs-core/src/ops/broadcast_to.ts @@ -28,7 +28,6 @@ import {clone} from './clone'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Broadcast an array to a compatible shape NumPy-style. * diff --git a/tfjs-core/src/ops/conv1d.ts b/tfjs-core/src/ops/conv1d.ts index 175f53b4761..9cc8f39b543 100644 --- a/tfjs-core/src/ops/conv1d.ts +++ b/tfjs-core/src/ops/conv1d.ts @@ -24,7 +24,6 @@ import * as conv_util from './conv_util'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Computes a 1D convolution over the input x. * diff --git a/tfjs-core/src/ops/conv2d.ts b/tfjs-core/src/ops/conv2d.ts index 2299705e7ce..b3b64ceaca1 100644 --- a/tfjs-core/src/ops/conv2d.ts +++ b/tfjs-core/src/ops/conv2d.ts @@ -27,7 +27,6 @@ import * as conv_util from './conv_util'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Computes a 2D convolution over the input x. * diff --git a/tfjs-core/src/ops/depthwise_conv2d.ts b/tfjs-core/src/ops/depthwise_conv2d.ts index 1164d42cef4..71272b274ec 100644 --- a/tfjs-core/src/ops/depthwise_conv2d.ts +++ b/tfjs-core/src/ops/depthwise_conv2d.ts @@ -27,7 +27,6 @@ import * as conv_util from './conv_util'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Depthwise 2D convolution. * diff --git a/tfjs-core/src/ops/dilation2d.ts b/tfjs-core/src/ops/dilation2d.ts index b15ab6e91e2..52782da7239 100644 --- a/tfjs-core/src/ops/dilation2d.ts +++ b/tfjs-core/src/ops/dilation2d.ts @@ -27,7 +27,6 @@ import * as util from '../util'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Computes the grayscale dilation over the input `x`. * diff --git a/tfjs-core/src/ops/dot.ts b/tfjs-core/src/ops/dot.ts index da29ae675b9..e1ecc5ce566 100644 --- a/tfjs-core/src/ops/dot.ts +++ b/tfjs-core/src/ops/dot.ts @@ -24,7 +24,6 @@ import {matMul} from './mat_mul'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Computes the dot product of two matrices and/or vectors, `t1` and `t2`. * diff --git a/tfjs-core/src/ops/gram_schmidt.ts b/tfjs-core/src/ops/gram_schmidt.ts index 62567e322b1..a9cb2aea873 100644 --- a/tfjs-core/src/ops/gram_schmidt.ts +++ b/tfjs-core/src/ops/gram_schmidt.ts @@ -29,7 +29,6 @@ import {squeeze} from './squeeze'; import {stack} from './stack'; import {sub} from './sub'; - /** * Gram-Schmidt orthogonalization. * diff --git a/tfjs-core/src/ops/local_response_normalization.ts b/tfjs-core/src/ops/local_response_normalization.ts index 14e2212ad36..8a697aa4385 100644 --- a/tfjs-core/src/ops/local_response_normalization.ts +++ b/tfjs-core/src/ops/local_response_normalization.ts @@ -27,7 +27,6 @@ import * as util from '../util'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Normalizes the activation of a local neighborhood across or within * channels. diff --git a/tfjs-core/src/ops/mat_mul.ts b/tfjs-core/src/ops/mat_mul.ts index 3867f196015..5029c3273e8 100644 --- a/tfjs-core/src/ops/mat_mul.ts +++ b/tfjs-core/src/ops/mat_mul.ts @@ -27,7 +27,6 @@ import * as util from '../util'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Computes the dot product of two matrices, A * B. These must be matrices. * diff --git a/tfjs-core/src/ops/max.ts b/tfjs-core/src/ops/max.ts index 630de7b8276..45a93a34415 100644 --- a/tfjs-core/src/ops/max.ts +++ b/tfjs-core/src/ops/max.ts @@ -30,7 +30,6 @@ import {op} from './operation'; import {reshape} from './reshape'; import {transpose} from './transpose'; - /** * Computes the maximum of elements across dimensions of a `tf.Tensor`. * diff --git a/tfjs-core/src/ops/max_pool_3d.ts b/tfjs-core/src/ops/max_pool_3d.ts index 4c4cd5aa0dc..ccae28bdebe 100644 --- a/tfjs-core/src/ops/max_pool_3d.ts +++ b/tfjs-core/src/ops/max_pool_3d.ts @@ -29,7 +29,6 @@ import * as conv_util from './conv_util'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Computes the 3D max pooling. * diff --git a/tfjs-core/src/ops/one_hot.ts b/tfjs-core/src/ops/one_hot.ts index 782ad0981a3..1dde610e910 100644 --- a/tfjs-core/src/ops/one_hot.ts +++ b/tfjs-core/src/ops/one_hot.ts @@ -26,7 +26,6 @@ import {TensorLike} from '../types'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take * value `onValue` (defaults to 1), while all other locations take value diff --git a/tfjs-core/src/ops/outer_product.ts b/tfjs-core/src/ops/outer_product.ts index 32f08c0a8f4..1966f4342af 100644 --- a/tfjs-core/src/ops/outer_product.ts +++ b/tfjs-core/src/ops/outer_product.ts @@ -23,7 +23,6 @@ import {matMul} from './mat_mul'; import {op} from './operation'; import {reshape} from './reshape'; - /** * Computes the outer product of two vectors, `v1` and `v2`. * diff --git a/tfjs-core/src/ops/pool.ts b/tfjs-core/src/ops/pool.ts index f8a16e85785..8932a2364fd 100644 --- a/tfjs-core/src/ops/pool.ts +++ b/tfjs-core/src/ops/pool.ts @@ -28,7 +28,6 @@ import {op} from './operation'; import {reshape} from './reshape'; import {spaceToBatchND} from './space_to_batch_nd'; - /** * Performs an N-D pooling operation * diff --git a/tfjs-core/src/ops/qr.ts b/tfjs-core/src/ops/qr.ts index 92bf6656975..5aeabfb4ef2 100644 --- a/tfjs-core/src/ops/qr.ts +++ b/tfjs-core/src/ops/qr.ts @@ -38,7 +38,6 @@ import {neg} from './unary_ops'; import {unstack} from './unstack'; import {where} from './where'; - /** * Compute QR decomposition of m-by-n matrix using Householder transformation. * diff --git a/tfjs-core/src/ops/reshape.ts b/tfjs-core/src/ops/reshape.ts index d519e996247..f6db7332b21 100644 --- a/tfjs-core/src/ops/reshape.ts +++ b/tfjs-core/src/ops/reshape.ts @@ -27,7 +27,6 @@ import * as util from '../util'; import {op} from './operation'; - /** * Reshapes a `tf.Tensor` to a given shape. * diff --git a/tfjs-core/src/ops/segment_ops.ts b/tfjs-core/src/ops/segment_ops.ts index 3d86baa7fe0..0ff2f94a1d1 100644 --- a/tfjs-core/src/ops/segment_ops.ts +++ b/tfjs-core/src/ops/segment_ops.ts @@ -31,7 +31,6 @@ import {collectGatherOpShapeInfo} from './segment_util'; import {ones, scalar, zerosLike} from './tensor_ops'; import {where} from './where'; - /** * Computes the sum along segments of a `tf.Tensor`. * From 74982c6c1c207910a3cb1af57df83783036d259e Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 24 Jun 2020 11:42:06 -0400 Subject: [PATCH 09/14] add grad --- tfjs-core/src/gradients/Unpack_grad.ts | 29 +++++++++++++++++++++++++ tfjs-core/src/ops/unstack.ts | 13 +++++------ tfjs-core/src/register_all_gradients.ts | 4 +++- 3 files changed, 38 insertions(+), 8 deletions(-) create mode 100644 tfjs-core/src/gradients/Unpack_grad.ts diff --git a/tfjs-core/src/gradients/Unpack_grad.ts b/tfjs-core/src/gradients/Unpack_grad.ts new file mode 100644 index 00000000000..44079fd6069 --- /dev/null +++ b/tfjs-core/src/gradients/Unpack_grad.ts @@ -0,0 +1,29 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {Unpack, UnpackAttrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import {stack} from '../ops/stack'; +import {Tensor} from '../tensor'; + +export const unpackGradConfig: GradConfig = { + kernelName: Unpack, + gradFunc: (dy: Tensor[], saved: Tensor[], attrs: NamedAttrMap) => { + const unpackAttrs: UnpackAttrs = attrs as {} as UnpackAttrs; + const {axis} = unpackAttrs; + return {value: () => stack(dy, axis)}; + } +}; diff --git a/tfjs-core/src/ops/unstack.ts b/tfjs-core/src/ops/unstack.ts index 2ae462f8349..9697e476248 100644 --- a/tfjs-core/src/ops/unstack.ts +++ b/tfjs-core/src/ops/unstack.ts @@ -15,7 +15,8 @@ * ============================================================================= */ -import {ENGINE} from '../engine'; +import {KernelBackend} from '../backends/backend'; +import {ENGINE, ForwardFunc} from '../engine'; import {Unpack, UnpackAttrs, UnpackInputs} from '../kernel_names'; import {NamedAttrMap} from '../kernel_registry'; import {Tensor} from '../tensor'; @@ -25,7 +26,6 @@ import {TensorLike} from '../types'; import * as util from '../util'; import {op} from './operation'; -import {stack} from './stack'; /** * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s. @@ -50,14 +50,13 @@ function unstack_(x: Tensor|TensorLike, axis = 0): Tensor[] { if (axis < 0) { axis += $x.shape.length; } - const grad = (dy: Tensor[]) => { - return {value: () => stack(dy, axis)}; - }; const inputs: UnpackInputs = {value: $x}; const attrs: UnpackAttrs = {axis}; + const forward: ForwardFunc = (backend: KernelBackend) => + backend.unstack($x, axis); return ENGINE.runKernelFunc( - backend => backend.unstack($x, axis), inputs as {} as NamedTensorMap, - grad, Unpack, attrs as {} as NamedAttrMap); + forward, inputs as {} as NamedTensorMap, null /* grad */, Unpack, + attrs as {} as NamedAttrMap); } export const unstack = op({unstack_}); diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 11794bb9c1a..0b0fd0898c3 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -61,6 +61,7 @@ import {squaredDifferenceGradConfig} from './gradients/SquaredDifference_grad'; import {subGradConfig} from './gradients/Sub_grad'; import {tileGradConfig} from './gradients/Tile_grad'; import {transposeGradConfig} from './gradients/Transpose_grad'; +import {unpackGradConfig} from './gradients/Unpack_grad'; import {GradConfig} from './kernel_registry'; import {registerGradient} from './kernel_registry'; @@ -115,9 +116,10 @@ const gradConfigs: GradConfig[] = [ splitVGradConfig, squareGradConfig, squaredDifferenceGradConfig, + subGradConfig, tileGradConfig, transposeGradConfig, - subGradConfig + unpackGradConfig ]; for (const gradientConfig of gradConfigs) { From febaacc63c63a502feb216fb879d40be16d83bba Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 24 Jun 2020 13:16:50 -0400 Subject: [PATCH 10/14] stack --- .../chained_ops/register_all_chained_ops.ts | 1 + .../register_all_chained_ops_test.ts | 1 + tfjs-core/src/public/chained_ops/stack.ts | 31 +++++++++++++++++++ tfjs-core/src/tensor.ts | 4 --- 4 files changed, 33 insertions(+), 4 deletions(-) create mode 100644 tfjs-core/src/public/chained_ops/stack.ts diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index 96d71d5f070..1a5dbaf4e86 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -69,6 +69,7 @@ import './split'; import './squared_difference'; import './squeeze'; import './space_to_batch_nd'; +import './stack'; import './sub'; import './tile'; import './transpose'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index 31150e4c45a..5b6b04648df 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -79,6 +79,7 @@ const CHAINED_OPS = [ 'split', 'square', 'squeeze', + 'stack', 'sub', 'tile', 'transpose', diff --git a/tfjs-core/src/public/chained_ops/stack.ts b/tfjs-core/src/public/chained_ops/stack.ts new file mode 100644 index 00000000000..8ed52b8c8c7 --- /dev/null +++ b/tfjs-core/src/public/chained_ops/stack.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {stack} from '../../ops/stack'; +import {Tensor} from '../../tensor'; +import {Rank} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + stack(x: Tensor, axis?: number): T; + } +} + +Tensor.prototype.stack = function( + x: Tensor, axis?: number): T { + this.throwIfDisposed(); + return stack([this, x], axis) as T; +}; diff --git a/tfjs-core/src/tensor.ts b/tfjs-core/src/tensor.ts index 303e8fa72be..ecf17924f64 100644 --- a/tfjs-core/src/tensor.ts +++ b/tfjs-core/src/tensor.ts @@ -183,7 +183,6 @@ export interface OpHandler { slice>( x: T, begin: number|number[], size?: number|number[]): T; reverse(x: T, axis?: number|number[]): T; - stack(tensors: Array, axis: number): Tensor; all(x: Tensor, axis: number|number[], keepDims: boolean): T; any(x: Tensor, axis: number|number[], keepDims: boolean): T; logSumExp( @@ -653,9 +652,6 @@ export class Tensor { this.throwIfDisposed(); return opHandler.reverse(this, axis); } - stack(x: Tensor, axis = 0): Tensor { - return opHandler.stack([this, x], axis); - } // Reduction ops. all(axis: number|number[] = null, keepDims = false): T { this.throwIfDisposed(); From 03e48dfac6d8ed8432ccdfc9dc10cea66b1d611d Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 24 Jun 2020 13:25:13 -0400 Subject: [PATCH 11/14] build --- tfjs-core/src/ops/log_sum_exp.ts | 2 +- tfjs-core/src/ops/moments.ts | 3 ++- tfjs-core/src/ops/norm.ts | 2 +- tfjs-core/src/ops/reverse.ts | 2 +- tfjs-core/src/ops/softmax_cross_entropy.ts | 3 ++- 5 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tfjs-core/src/ops/log_sum_exp.ts b/tfjs-core/src/ops/log_sum_exp.ts index 360ece4a96a..940d86d6f69 100644 --- a/tfjs-core/src/ops/log_sum_exp.ts +++ b/tfjs-core/src/ops/log_sum_exp.ts @@ -21,11 +21,11 @@ import {TensorLike} from '../types'; import {parseAxisParam} from '../util'; import {add} from './add'; -import {reshape} from './array_ops'; import {expandShapeToKeepDim} from './axis_util'; import {max} from './max'; import {op} from './operation'; import {sum} from './reduction_ops'; +import {reshape} from './reshape'; import {sub} from './sub'; import {exp, log} from './unary_ops'; diff --git a/tfjs-core/src/ops/moments.ts b/tfjs-core/src/ops/moments.ts index 8fd0bef8138..fe567d8959e 100644 --- a/tfjs-core/src/ops/moments.ts +++ b/tfjs-core/src/ops/moments.ts @@ -20,10 +20,11 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {parseAxisParam} from '../util'; -import {cast, reshape} from './array_ops'; +import {cast} from './array_ops'; import {expandShapeToKeepDim} from './axis_util'; import {op} from './operation'; import {mean} from './reduction_ops'; +import {reshape} from './reshape'; import {square} from './square'; import {sub} from './sub'; diff --git a/tfjs-core/src/ops/norm.ts b/tfjs-core/src/ops/norm.ts index 5e8e9b557b5..d49bdfad478 100644 --- a/tfjs-core/src/ops/norm.ts +++ b/tfjs-core/src/ops/norm.ts @@ -20,12 +20,12 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {parseAxisParam} from '../util'; -import {reshape} from './array_ops'; import * as axis_util from './axis_util'; import {max} from './max'; import {op} from './operation'; import {pow} from './pow'; import {min, sum} from './reduction_ops'; +import {reshape} from './reshape'; import {square} from './square'; import {scalar} from './tensor_ops'; import {abs, sqrt} from './unary_ops'; diff --git a/tfjs-core/src/ops/reverse.ts b/tfjs-core/src/ops/reverse.ts index 875d8978e88..c6c54a6c526 100644 --- a/tfjs-core/src/ops/reverse.ts +++ b/tfjs-core/src/ops/reverse.ts @@ -24,9 +24,9 @@ import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {parseAxisParam} from '../util'; -import {reshape} from './array_ops'; import {clone} from './clone'; import {op} from './operation'; +import {reshape} from './reshape'; /** * Reverses a `tf.Tensor` along a specified axis. diff --git a/tfjs-core/src/ops/softmax_cross_entropy.ts b/tfjs-core/src/ops/softmax_cross_entropy.ts index 51c4f9613aa..062e9d71e30 100644 --- a/tfjs-core/src/ops/softmax_cross_entropy.ts +++ b/tfjs-core/src/ops/softmax_cross_entropy.ts @@ -22,7 +22,7 @@ import {TensorLike} from '../types'; import {assertShapesMatch} from '../util'; import {add} from './add'; -import {cast, reshape} from './array_ops'; +import {cast} from './array_ops'; import {expandShapeToKeepDim} from './axis_util'; import {computeWeightedLoss} from './compute_weighted_loss'; import {div} from './div'; @@ -31,6 +31,7 @@ import {Reduction} from './loss_ops_utils'; import {mul} from './mul'; import {op} from './operation'; import {sum} from './reduction_ops'; +import {reshape} from './reshape'; import {sub} from './sub'; import {scalar} from './tensor_ops'; import {exp, neg} from './unary_ops'; From 7d6abc8a090c086b4b44fbecadf28fae982de3b9 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 25 Jun 2020 09:49:48 -0400 Subject: [PATCH 12/14] fix --- tfjs-core/src/ops/expand_dims.ts | 7 +++---- tfjs-core/src/ops/eye.ts | 12 ++++++------ tfjs-core/src/public/chained_ops/expand_dims.ts | 2 +- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tfjs-core/src/ops/expand_dims.ts b/tfjs-core/src/ops/expand_dims.ts index 06ccaa12920..811f3c7100e 100644 --- a/tfjs-core/src/ops/expand_dims.ts +++ b/tfjs-core/src/ops/expand_dims.ts @@ -17,7 +17,7 @@ import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; -import {DataType, Rank, ShapeMap, TensorLike} from '../types'; +import {DataType, TensorLike} from '../types'; import * as util from '../util'; import {op} from './operation'; @@ -38,8 +38,7 @@ import {reshape} from './reshape'; * to 0 (the first dimension). */ /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ -function expandDims_( - x: Tensor|TensorLike, axis = 0): Tensor { +function expandDims_(x: Tensor|TensorLike, axis = 0): T { const parseAs: DataType = null; const $x = convertToTensor(x, 'x', 'expandDims', parseAs); @@ -53,7 +52,7 @@ function expandDims_( axis = $x.rank + axis + 1; } newShape.splice(axis, 0, 1); - return reshape($x, newShape as ShapeMap[R2]); + return reshape($x, newShape) as T; } export const expandDims = op({expandDims_}); diff --git a/tfjs-core/src/ops/eye.ts b/tfjs-core/src/ops/eye.ts index 1021df589dd..fb160a9f0a4 100644 --- a/tfjs-core/src/ops/eye.ts +++ b/tfjs-core/src/ops/eye.ts @@ -57,15 +57,15 @@ function eye_( return out; } else { if (batchShape.length === 1) { - return tile(expandDims(out, 0), [batchShape[0], 1, 1]); + return tile(expandDims(out, 0), [batchShape[0], 1, 1]) as Tensor2D; } else if (batchShape.length === 2) { return tile( - expandDims(expandDims(out, 0), 0), - [batchShape[0], batchShape[1], 1, 1]); + expandDims(expandDims(out, 0), 0), + [batchShape[0], batchShape[1], 1, 1]) as Tensor2D; } else if (batchShape.length === 3) { - return tile( - expandDims(expandDims(expandDims(out, 0), 0), 0), - [batchShape[0], batchShape[1], batchShape[2], 1, 1]); + return tile(expandDims(expandDims(expandDims(out, 0), 0), 0), [ + batchShape[0], batchShape[1], batchShape[2], 1, 1 + ]) as Tensor2D; } else { throw new Error( `eye() currently supports only 1D and 2D ` + diff --git a/tfjs-core/src/public/chained_ops/expand_dims.ts b/tfjs-core/src/public/chained_ops/expand_dims.ts index 7c062876e60..36c652d2322 100644 --- a/tfjs-core/src/public/chained_ops/expand_dims.ts +++ b/tfjs-core/src/public/chained_ops/expand_dims.ts @@ -26,5 +26,5 @@ declare module '../../tensor' { Tensor.prototype.expandDims = function(axis?: number): T { this.throwIfDisposed(); - return expandDims(this, axis) as T; + return expandDims(this, axis); }; From c90fb328181f81849b088491c2681837081a77d2 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 25 Jun 2020 10:05:25 -0400 Subject: [PATCH 13/14] fix --- tfjs-core/src/ops/squeeze.ts | 4 ++-- tfjs-core/src/ops/stack.ts | 4 +--- tfjs-core/src/ops/unstack.ts | 1 - tfjs-core/src/public/chained_ops/stack.ts | 7 ++++--- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tfjs-core/src/ops/squeeze.ts b/tfjs-core/src/ops/squeeze.ts index 88f993f2633..7f02a931c54 100644 --- a/tfjs-core/src/ops/squeeze.ts +++ b/tfjs-core/src/ops/squeeze.ts @@ -18,7 +18,7 @@ import {Tensor} from '../tensor'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; -import * as util from '../util'; +import {squeezeShape} from '../util'; import {op} from './operation'; import {reshape} from './reshape'; @@ -39,7 +39,7 @@ import {reshape} from './reshape'; /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ function squeeze_(x: Tensor|TensorLike, axis?: number[]): T { const $x = convertToTensor(x, 'x', 'squeeze'); - return reshape($x, util.squeezeShape($x.shape, axis).newShape) as T; + return reshape($x, squeezeShape($x.shape, axis).newShape) as T; } export const squeeze = op({squeeze_}); diff --git a/tfjs-core/src/ops/stack.ts b/tfjs-core/src/ops/stack.ts index b54070cc1f8..812900e7e0b 100644 --- a/tfjs-core/src/ops/stack.ts +++ b/tfjs-core/src/ops/stack.ts @@ -59,13 +59,11 @@ function stack_( util.assertShapesMatch( shape, t.shape, 'All tensors passed to stack must have matching shapes'); - }); - - $tensors.forEach(t => { util.assert( dtype === t.dtype, () => 'All tensors passed to stack must have matching dtypes'); }); + const expandedTensors = $tensors.map(t => expandDims(t, axis)); // Stack exists in the TensorFlow C++ API // (https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/stack) but not diff --git a/tfjs-core/src/ops/unstack.ts b/tfjs-core/src/ops/unstack.ts index 9697e476248..7d5706fc977 100644 --- a/tfjs-core/src/ops/unstack.ts +++ b/tfjs-core/src/ops/unstack.ts @@ -41,7 +41,6 @@ import {op} from './operation'; */ /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ function unstack_(x: Tensor|TensorLike, axis = 0): Tensor[] { - axis = axis || 0; const $x = convertToTensor(x, 'x', 'unstack'); util.assert( axis >= -$x.shape.length && axis < $x.shape.length, diff --git a/tfjs-core/src/public/chained_ops/stack.ts b/tfjs-core/src/public/chained_ops/stack.ts index 8ed52b8c8c7..84f22bad6e1 100644 --- a/tfjs-core/src/public/chained_ops/stack.ts +++ b/tfjs-core/src/public/chained_ops/stack.ts @@ -20,12 +20,13 @@ import {Rank} from '../../types'; declare module '../../tensor' { interface Tensor { - stack(x: Tensor, axis?: number): T; + stack(x: Tensor|Tensor[], axis?: number): T; } } Tensor.prototype.stack = function( - x: Tensor, axis?: number): T { + x: Tensor|Tensor[], axis?: number): T { this.throwIfDisposed(); - return stack([this, x], axis) as T; + const tensorsToBeStacked = x instanceof Tensor ? [this, x] : [this, ...x]; + return stack(tensorsToBeStacked, axis) as T; }; From 675853d377a8c87d9f8a93f04d02b4f72ebdc310 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 25 Jun 2020 12:47:37 -0400 Subject: [PATCH 14/14] save