From c8f0b088f4e2d109546f206b4c94523ab64751f0 Mon Sep 17 00:00:00 2001
From: Daniel Smilkov
Date: Sun, 8 Dec 2019 09:45:27 -0500
Subject: [PATCH 1/3] save

---
 tfjs-backend-wasm/src/kernels/Slice.ts  |  2 +-
 tfjs-backend-wasm/src/kernels/Unpack.ts | 58 +++++++++++++++++++
 tfjs-backend-wasm/src/setup_test.ts     |  1 +
 tfjs-core/src/ops/array_ops.ts          | 74 +++++++++++--------------
 4 files changed, 93 insertions(+), 42 deletions(-)
 create mode 100644 tfjs-backend-wasm/src/kernels/Unpack.ts

diff --git a/tfjs-backend-wasm/src/kernels/Slice.ts b/tfjs-backend-wasm/src/kernels/Slice.ts
index 91ed047adf5..df91fdc65bb 100644
--- a/tfjs-backend-wasm/src/kernels/Slice.ts
+++ b/tfjs-backend-wasm/src/kernels/Slice.ts
@@ -29,7 +29,7 @@ interface SliceAttrs extends NamedAttrMap {
   size: number[];
 }
 
-function slice(
+export function slice(
     args: {inputs: SliceInputs, attrs: SliceAttrs, backend: BackendWasm}) {
   const {inputs: {x}, attrs: {begin, size}, backend} = args;
   const isContinous = slice_util.isSliceContinous(x.shape, begin, size);
diff --git a/tfjs-backend-wasm/src/kernels/Unpack.ts b/tfjs-backend-wasm/src/kernels/Unpack.ts
new file mode 100644
index 00000000000..25d6659e01c
--- /dev/null
+++ b/tfjs-backend-wasm/src/kernels/Unpack.ts
@@ -0,0 +1,58 @@
+/**
+ * @license
+ * Copyright 2019 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core';
+import {BackendWasm} from '../backend_wasm';
+import {slice} from './Slice';
+
+interface UnpackInputs extends NamedTensorInfoMap {
+  x: TensorInfo;
+}
+
+interface UnpackAttrs extends NamedAttrMap {
+  axis: number;
+}
+
+function unpack(
+    args: {inputs: UnpackInputs, backend: BackendWasm, attrs: UnpackAttrs}):
+    TensorInfo[] {
+  const {inputs: {x}, backend, attrs: {axis}} = args;
+  const numOutputs = x.shape[axis];
+  const rank = x.shape.length;
+  const outShape: number[] = new Array(rank - 1);
+  let outIndex = 0;
+  for (let i = 0; i < rank; i++) {
+    if (i !== axis) {
+      outShape[outIndex++] = x.shape[i];
+    }
+  }
+  const outs: TensorInfo[] = new Array(numOutputs);
+  const begin = new Array(rank).fill(0);
+  const size = x.shape.slice();
+  size[axis] = 1;
+  for (let i = 0; i < outs.length; i++) {
+    begin[axis] = i;
+    outs[i] = slice({inputs: {x}, attrs: {begin, size}, backend});
+  }
+  return outs.map(({dataId, dtype}) => ({dataId, dtype, shape: outShape}));
+}
+
+registerKernel({
+  kernelName: 'Unpack',
+  backendName: 'wasm',
+  kernelFunc: unpack,
+});
diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts
index 86c9f34faf5..19b28c8cec6 100644
--- a/tfjs-backend-wasm/src/setup_test.ts
+++ b/tfjs-backend-wasm/src/setup_test.ts
@@ -193,6 +193,7 @@ const TEST_FILTERS: TestFilter[] = [
   {include: 'addN'},
   {include: 'nonMaxSuppression'},
   {include: 'argmax', excludes: ['gradient']},
+  {include: 'unstack'},
 ];
 
 const customInclude = (testName: string) => {
diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts
index 58a21f9cd4c..1b976b726d1 100644
--- a/tfjs-core/src/ops/array_ops.ts
+++ b/tfjs-core/src/ops/array_ops.ts
@@ -26,70 +26,60 @@ import {op} from './operation';
 import {MPRandGauss, RandGamma, UniformRandom} from './rand';
 import {zeros, zerosLike} from './tensor_ops';
 
-/** Broadcast an array to a compatible shape NumPy-style.
- *
- * The tensor's shape is compared to the broadcast shape from end to beginning.
- * Ones are prepended to the tensor's shape until is has the same length as
- * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
- * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
- * the input tensor is tiled N times along that axis (using tf.tile).
- *
- * @param input The tensor that is to be broadcasted.
- * @param shape The input is to be broadcast to this shape.
+/**
+ * Broadcast an array to a compatible shape NumPy-style.
+ *
+ * The tensor's shape is compared to the broadcast shape from end to beginning.
+ * Ones are prepended to the tensor's shape until is has the same length as
+ * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
+ * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
+ * the input tensor is tiled N times along that axis (using tf.tile).
+ *
+ * @param input The tensor that is to be broadcasted.
+ * @param shape The input is to be broadcast to this shape.
  */
 /** @doc {heading: 'Tensors', subheading: 'Transformations'} */
 function broadcastTo_<R extends Rank>(
-    x: Tensor|TensorLike, shape: ShapeMap[R]
-): Tensor<R>
-{
+    x: Tensor|TensorLike, shape: ShapeMap[R]): Tensor<R> {
   let input = convertToTensor(x, 'broadcastTo', 'x');
   const xShape = input.shape;
 
-  if( shape.some(d => !(d > 0) || d%1 !== 0) ) {
+  if (shape.some(d => !(d > 0) || d % 1 !== 0)) {
     throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
   }
 
-  if( shape.length < input.rank ) {
-    throw new Error(
-      `broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`
-    );
+  if (shape.length < input.rank) {
+    throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${
+        input.rank}.`);
   }
 
-  if( shape.length > input.rank )
-  {
-    const newShape = input.shape.slice();
-    while(newShape.length < shape.length ) {
-      newShape.unshift(1);
+  if (shape.length > input.rank) {
+    const newShape = input.shape.slice();
+    while (newShape.length < shape.length) {
+      newShape.unshift(1);
     }
     input = input.reshape(newShape);
   }
 
   const reps: number[] = Array.from(shape);
-  for( let i=shape.length-1; i >= 0; i-- )
-  {
-    if( input.shape[i] === shape[i] ) {
+  for (let i = shape.length - 1; i >= 0; i--) {
+    if (input.shape[i] === shape[i]) {
       reps[i] = 1;
-    }
-    else if( input.shape[i] !== 1 ) {
+    } else if (input.shape[i] !== 1) {
       throw new Error(
-        `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`
-      );
+          `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
     }
   }
+  const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);
 
-  const axes = reps.map( ( n,i) => n > 1 ? i : -1 ).filter( i => i >= 0 );
-
-  if( axes.length === 0 ) {
+  if (axes.length === 0) {
     return input.clone() as Tensor<R>;
   }
 
   return ENGINE.runKernelFunc(
-    backend => backend.tile(input,reps),
-    {input},
-    (dy: Tensor) => ({
-      input: () => dy.sum(axes,/*keepDims=*/true)
-    })
-  ) as Tensor<R>;
+      backend => backend.tile(input, reps), {input},
+      (dy: Tensor) =>
+          ({input: () => dy.sum(axes, /*keepDims=*/true)})) as Tensor<R>;
 }
 
 /**
@@ -917,9 +907,11 @@ function unstack_(x: Tensor|TensorLike, axis = 0): Tensor[] {
     axis += $x.shape.length;
  }
   const grad = (dy: Tensor[]) => {
-    return {$x: () => stack(dy, axis)};
+    return {x: () => stack(dy, axis)};
   };
-  return ENGINE.runKernelFunc(backend => backend.unstack($x, axis), {$x}, grad);
+  const attrs = {axis};
+  return ENGINE.runKernelFunc(
+      backend => backend.unstack($x, axis), {x: $x}, grad, 'Unpack', attrs);
 }
 
 /**

From d206f4c9573136543287c4cfea2a55c36723d92b Mon Sep 17 00:00:00 2001
From: Daniel Smilkov
Date: Mon, 9 Dec 2019 11:35:56 -0500
Subject: [PATCH 2/3] save

---
 tfjs-backend-wasm/src/kernels/all_kernels.ts | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts
index 4d6dea9a2a3..097ae3f77af 100644
--- a/tfjs-backend-wasm/src/kernels/all_kernels.ts
+++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts
@@ -51,3 +51,4 @@ import './Slice';
 import './Square';
 import './Sub';
 import './Transpose';
+import './Unpack';

From 7c10b8372bed0411c8a5f90a84a35d73738b4578 Mon Sep 17 00:00:00 2001
From: Daniel Smilkov
Date: Mon, 9 Dec 2019 13:26:37 -0500
Subject: [PATCH 3/3] save

---
 tfjs-backend-wasm/scripts/test-bundle-size.js | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tfjs-backend-wasm/scripts/test-bundle-size.js b/tfjs-backend-wasm/scripts/test-bundle-size.js
index da511c2174a..0d9f2727a42 100755
--- a/tfjs-backend-wasm/scripts/test-bundle-size.js
+++ b/tfjs-backend-wasm/scripts/test-bundle-size.js
@@ -36,7 +36,9 @@ exec(
 shell.cd(dirName);
 shell.cd(wasmDirName);
 
-exec(`yarn && ./scripts/build-ci.sh && yarn rollup -c`, {silent: false});
+exec(
+    `yarn && yarn build-core && ./scripts/build-ci.sh && yarn rollup -c`,
+    {silent: false});
 
 const masterMinBundleSize = getFileSizeBytes(bundleFilename);
 const masterWasmSize = getFileSizeBytes(wasmFileName);
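
Note (not part of the patch series): a minimal sketch of how the new 'Unpack' wasm kernel is reached from user code via tf.unstack(). It assumes the published '@tensorflow/tfjs-core' and '@tensorflow/tfjs-backend-wasm' packages and omits bundler/wasm-path configuration; the tensor values and shapes are illustrative only.

import * as tf from '@tensorflow/tfjs-core';
// Importing the wasm backend package registers the 'wasm' backend and its
// kernels (including the Unpack kernel added above) as a side effect.
import '@tensorflow/tfjs-backend-wasm';

async function run() {
  await tf.setBackend('wasm');
  await tf.ready();

  // Unstacking a 2x3 tensor along axis 0 yields two tensors of shape [3].
  // With the wasm backend active, unstack() dispatches to the 'Unpack'
  // kernel, which slices the input once per output along the given axis.
  const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
  const rows = tf.unstack(x, 0);
  rows.forEach(row => row.print());  // [1, 2, 3] then [4, 5, 6]
}

run();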