diff --git a/tfjs-backend-wasm/rollup.config.js b/tfjs-backend-wasm/rollup.config.js index 39a470dd579..6642d3a1caf 100644 --- a/tfjs-backend-wasm/rollup.config.js +++ b/tfjs-backend-wasm/rollup.config.js @@ -44,7 +44,7 @@ function config({plugins = [], output = {}}) { typescript({ tsconfigOverride: {compilerOptions: {module: 'ES2015'}}, }), - node(), + node({preferBuiltins: true}), // Polyfill require() from dependencies. commonjs({ ignore: ['crypto', 'node-fetch', 'util'], @@ -55,10 +55,10 @@ function config({plugins = [], output = {}}) { output: { banner: PREAMBLE, sourcemap: true, - globals: {'@tensorflow/tfjs-core': 'tf'}, + globals: {'@tensorflow/tfjs-core': 'tf', 'fs': 'fs', 'path': 'path'}, ...output, }, - external: ['crypto', '@tensorflow/tfjs-core'], + external: ['crypto', '@tensorflow/tfjs-core', 'fs', 'path'], onwarn: warning => { let {code} = warning; if (code === 'CIRCULAR_DEPENDENCY' || code === 'CIRCULAR' || diff --git a/tfjs-backend-wasm/src/kernels/Slice.ts b/tfjs-backend-wasm/src/kernels/Slice.ts index 91ed047adf5..df91fdc65bb 100644 --- a/tfjs-backend-wasm/src/kernels/Slice.ts +++ b/tfjs-backend-wasm/src/kernels/Slice.ts @@ -29,7 +29,7 @@ interface SliceAttrs extends NamedAttrMap { size: number[]; } -function slice( +export function slice( args: {inputs: SliceInputs, attrs: SliceAttrs, backend: BackendWasm}) { const {inputs: {x}, attrs: {begin, size}, backend} = args; const isContinous = slice_util.isSliceContinous(x.shape, begin, size); diff --git a/tfjs-backend-wasm/src/kernels/Unpack.ts b/tfjs-backend-wasm/src/kernels/Unpack.ts new file mode 100644 index 00000000000..25d6659e01c --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/Unpack.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {NamedAttrMap, NamedTensorInfoMap, registerKernel, TensorInfo} from '@tensorflow/tfjs-core'; +import {BackendWasm} from '../backend_wasm'; +import {slice} from './Slice'; + +interface UnpackInputs extends NamedTensorInfoMap { + x: TensorInfo; +} + +interface UnpackAttrs extends NamedAttrMap { + axis: number; +} + +function unpack( + args: {inputs: UnpackInputs, backend: BackendWasm, attrs: UnpackAttrs}): + TensorInfo[] { + const {inputs: {x}, backend, attrs: {axis}} = args; + const numOutputs = x.shape[axis]; + const rank = x.shape.length; + const outShape: number[] = new Array(rank - 1); + let outIndex = 0; + for (let i = 0; i < rank; i++) { + if (i !== axis) { + outShape[outIndex++] = x.shape[i]; + } + } + const outs: TensorInfo[] = new Array(numOutputs); + const begin = new Array(rank).fill(0); + const size = x.shape.slice(); + size[axis] = 1; + for (let i = 0; i < outs.length; i++) { + begin[axis] = i; + outs[i] = slice({inputs: {x}, attrs: {begin, size}, backend}); + } + return outs.map(({dataId, dtype}) => ({dataId, dtype, shape: outShape})); +} + +registerKernel({ + kernelName: 'Unpack', + backendName: 'wasm', + kernelFunc: unpack, +}); diff --git a/tfjs-backend-wasm/src/kernels/all_kernels.ts b/tfjs-backend-wasm/src/kernels/all_kernels.ts index 51819c27127..3360c73762b 100644 --- a/tfjs-backend-wasm/src/kernels/all_kernels.ts +++ b/tfjs-backend-wasm/src/kernels/all_kernels.ts @@ -52,3 +52,4 @@ import './Slice'; import './Square'; import './Sub'; 
import './Transpose'; +import './Unpack'; diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index bcc0d59483e..8a4da9c7301 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -194,6 +194,7 @@ const TEST_FILTERS: TestFilter[] = [ {include: 'nonMaxSuppression'}, {include: 'argmax', excludes: ['gradient']}, {include: 'exp '}, + {include: 'unstack'}, ]; const customInclude = (testName: string) => { diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts index 58a21f9cd4c..1b976b726d1 100644 --- a/tfjs-core/src/ops/array_ops.ts +++ b/tfjs-core/src/ops/array_ops.ts @@ -26,70 +26,60 @@ import {op} from './operation'; import {MPRandGauss, RandGamma, UniformRandom} from './rand'; import {zeros, zerosLike} from './tensor_ops'; -/** Broadcast an array to a compatible shape NumPy-style. - * - * The tensor's shape is compared to the broadcast shape from end to beginning. - * Ones are prepended to the tensor's shape until is has the same length as - * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is - * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then - * the input tensor is tiled N times along that axis (using tf.tile). - * - * @param input The tensor that is to be broadcasted. - * @param shape The input is to be broadcast to this shape. +/** + * Broadcast an array to a compatible shape NumPy-style. + * + * The tensor's shape is compared to the broadcast shape from end to beginning. + * Ones are prepended to the tensor's shape until it has the same length as + * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is + * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then + * the input tensor is tiled N times along that axis (using tf.tile). + * + * @param input The tensor that is to be broadcasted. + * @param shape The input is to be broadcast to this shape. 
*/ /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ function broadcastTo_( - x: Tensor|TensorLike, shape: ShapeMap[R] -): Tensor -{ + x: Tensor|TensorLike, shape: ShapeMap[R]): Tensor { let input = convertToTensor(x, 'broadcastTo', 'x'); const xShape = input.shape; - if( shape.some(d => !(d > 0) || d%1 !== 0) ) { + if (shape.some(d => !(d > 0) || d % 1 !== 0)) { throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`); } - if( shape.length < input.rank ) { - throw new Error( - `broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.` - ); + if (shape.length < input.rank) { + throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${ + input.rank}.`); } - if( shape.length > input.rank ) - { - const newShape = input.shape.slice(); - while(newShape.length < shape.length ) { - newShape.unshift(1); + if (shape.length > input.rank) { + const newShape = input.shape.slice(); + while (newShape.length < shape.length) { + newShape.unshift(1); } input = input.reshape(newShape); } const reps: number[] = Array.from(shape); - for( let i=shape.length-1; i >= 0; i-- ) - { - if( input.shape[i] === shape[i] ) { + for (let i = shape.length - 1; i >= 0; i--) { + if (input.shape[i] === shape[i]) { reps[i] = 1; - } - else if( input.shape[i] !== 1 ) { + } else if (input.shape[i] !== 1) { throw new Error( - `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].` - ); + `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`); } } + const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0); - const axes = reps.map( ( n,i) => n > 1 ? 
i : -1 ).filter( i => i >= 0 ); - - if( axes.length === 0 ) { + if (axes.length === 0) { return input.clone() as Tensor; } return ENGINE.runKernelFunc( - backend => backend.tile(input,reps), - {input}, - (dy: Tensor) => ({ - input: () => dy.sum(axes,/*keepDims=*/true) - }) - ) as Tensor; + backend => backend.tile(input, reps), {input}, + (dy: Tensor) => + ({input: () => dy.sum(axes, /*keepDims=*/true)})) as Tensor; } /** @@ -917,9 +907,11 @@ function unstack_(x: Tensor|TensorLike, axis = 0): Tensor[] { axis += $x.shape.length; } const grad = (dy: Tensor[]) => { - return {$x: () => stack(dy, axis)}; + return {x: () => stack(dy, axis)}; }; - return ENGINE.runKernelFunc(backend => backend.unstack($x, axis), {$x}, grad); + const attrs = {axis}; + return ENGINE.runKernelFunc( + backend => backend.unstack($x, axis), {x: $x}, grad, 'Unpack', attrs); } /**