From fb62e174f4b648575f1e25a938c02169e65cc4f6 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Mon, 16 Mar 2020 22:59:47 -0400 Subject: [PATCH 1/6] modularise broadcastTo op --- tfjs-core/src/engine.ts | 17 ++-- tfjs-core/src/gradients/BroadcastTo_grad.ts | 32 +++++++ tfjs-core/src/kernel_names.ts | 6 ++ tfjs-core/src/kernel_registry.ts | 6 +- tfjs-core/src/ops/array_ops.ts | 57 ------------ tfjs-core/src/ops/array_ops_test.ts | 85 ----------------- tfjs-core/src/ops/broadcast_to.ts | 91 +++++++++++++++++++ tfjs-core/src/ops/broadcast_to_test.ts | 76 ++++++++++++++++ tfjs-core/src/ops/ops.ts | 1 + .../src/public/chained_ops/broadcast_to.ts | 31 +++++++ .../chained_ops/register_all_chained_ops.ts | 1 + .../register_all_chained_ops_test.ts | 40 ++++++++ tfjs-core/src/register_all_gradients.ts | 2 + tfjs-core/src/tests.ts | 1 + 14 files changed, 294 insertions(+), 152 deletions(-) create mode 100644 tfjs-core/src/gradients/BroadcastTo_grad.ts create mode 100644 tfjs-core/src/ops/broadcast_to.ts create mode 100644 tfjs-core/src/ops/broadcast_to_test.ts create mode 100644 tfjs-core/src/public/chained_ops/broadcast_to.ts create mode 100644 tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts index 7f491b89bc3..7a53a917826 100644 --- a/tfjs-core/src/engine.ts +++ b/tfjs-core/src/engine.ts @@ -17,9 +17,9 @@ import {BackendTimingInfo, DataMover, KernelBackend} from './backends/backend'; import {Environment, setEnvironmentGlobal} from './environment'; -import {getGradient, getKernel, getKernelsForBackend, NamedAttrMap, TensorInfo} from './kernel_registry'; +import {getGradient, getKernel, getKernelsForBackend, GradFunc, NamedAttrMap, TensorInfo} from './kernel_registry'; import {Profiler} from './profiler'; -import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode} from './tape'; +import {backpropagateGradients, getFilteredNodesXToY, TapeNode} from './tape'; import 
{DataId, setTensorTracker, Tensor, TensorTracker, Variable} from './tensor'; import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types'; import {getTensorsInContainer} from './tensor_util'; @@ -465,7 +465,7 @@ export class Engine implements TensorTracker, DataMover { const inputs = {x}; const grad = (dy: Tensor) => ({x: () => dy.toFloat()}); const saved: Tensor[] = []; - this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved); + this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {}); return y; } @@ -534,7 +534,8 @@ export class Engine implements TensorTracker, DataMover { */ runKernelFunc( forwardFunc: ForwardFunc, inputs: I, - backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]}, + backwardsFunc?: + (dy: T, saved: Tensor[]|TensorInfo[]) => {[P in keyof I]: () => I[P]}, kernelName?: string, attrs?: NamedAttrMap, inputsToSave: Tensor[] = [], outputsToSave: boolean[] = []): T { let outputs: Tensor[]; @@ -604,7 +605,8 @@ export class Engine implements TensorTracker, DataMover { }); if (isTapeOn) { - this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved); + this.addTapeNode( + kernelName, inputs, outputs, backwardsFunc, saved, attrs); } if (this.state.profiling) { @@ -798,8 +800,7 @@ export class Engine implements TensorTracker, DataMover { private addTapeNode( kernelName: string, inputs: NamedTensorMap, outputs: Tensor[], - gradientsFunc: (dy: Tensor|Tensor[], saved: Tensor[]) => NamedGradientMap, - saved: Tensor[]): void { + gradientsFunc: GradFunc, saved: Tensor[], attrs: NamedAttrMap): void { const tapeNode: TapeNode = {id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved}; @@ -821,7 +822,7 @@ export class Engine implements TensorTracker, DataMover { }); // Grad functions of ops with single outputs expect a dy, while ops // with multiple outputs expect dys (array of dy). - return gradientsFunc(dys.length > 1 ? 
dys : dys[0], saved); + return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs); }; } this.state.activeTape.push(tapeNode); diff --git a/tfjs-core/src/gradients/BroadcastTo_grad.ts b/tfjs-core/src/gradients/BroadcastTo_grad.ts new file mode 100644 index 00000000000..2f4c9340bed --- /dev/null +++ b/tfjs-core/src/gradients/BroadcastTo_grad.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {BroadcastTo, BroadCastToAttrs} from '../kernel_names'; +import {GradConfig, NamedAttrMap} from '../kernel_registry'; +import {Tensor} from '../tensor'; + +export const broadcastToGradConfig: GradConfig = { + kernelName: BroadcastTo, + gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { + const broadCastToAttrs: BroadCastToAttrs = + attrs as unknown as BroadCastToAttrs; + const axes = + broadCastToAttrs.reps.map((n, i) => n > 1 ? 
i : -1).filter(i => i >= 0); + const keepDims = true; + return {x: () => dy.sum(axes, keepDims)}; + } +}; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index 2f027e57b1f..bc7e3bdce28 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -37,6 +37,12 @@ export interface NonMaxSuppressionV5Attrs { softNmsSigma: number; } +export const BroadcastTo = 'BroadcastTo'; +export type BroadcastToInputs = Pick; +export interface BroadCastToAttrs { + reps: number[]; +} + /** * TensorFlow.js-only kernels */ diff --git a/tfjs-core/src/kernel_registry.ts b/tfjs-core/src/kernel_registry.ts index 5bc6dbf6d0f..5cd8beb487f 100644 --- a/tfjs-core/src/kernel_registry.ts +++ b/tfjs-core/src/kernel_registry.ts @@ -15,6 +15,7 @@ * ============================================================================= */ +import {NamedGradientMap} from './tape'; import {Tensor} from './tensor'; import {DataType, RecursiveArray} from './types'; @@ -37,8 +38,9 @@ export type KernelFunc = (params: { }) => TensorInfo|TensorInfo[]; /** The function to run when computing a gradient during backprop. */ -export type GradFunc = (dy: Tensor|Tensor[], saved: Tensor[]) => - ({[inputName: string]: () => Tensor}); +export type GradFunc = + (dy: Tensor|Tensor[], saved: Tensor[], attrs: NamedAttrMap) => + NamedGradientMap; /** Function that gets called after the backend initializes. */ export type KernelSetupFunc = (backend: {}) => void; diff --git a/tfjs-core/src/ops/array_ops.ts b/tfjs-core/src/ops/array_ops.ts index 177c5670a97..b19bb6687a3 100644 --- a/tfjs-core/src/ops/array_ops.ts +++ b/tfjs-core/src/ops/array_ops.ts @@ -26,62 +26,6 @@ import {op} from './operation'; import {MPRandGauss, RandGamma, UniformRandom} from './rand'; import {zeros, zerosLike} from './tensor_ops'; -/** - * Broadcast an array to a compatible shape NumPy-style. - * - * The tensor's shape is compared to the broadcast shape from end to beginning. 
- * Ones are prepended to the tensor's shape until is has the same length as - * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is - * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then - * the input tensor is tiled N times along that axis (using tf.tile). - * - * @param input The tensor that is to be broadcasted. - * @param shape The input is to be broadcast to this shape. - */ -/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ -function broadcastTo_( - x: Tensor|TensorLike, shape: ShapeMap[R]): Tensor { - let input = convertToTensor(x, 'broadcastTo', 'x'); - const xShape = input.shape; - - if (shape.some(d => !(d > 0) || d % 1 !== 0)) { - throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`); - } - - if (shape.length < input.rank) { - throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${ - input.rank}.`); - } - - if (shape.length > input.rank) { - const newShape = input.shape.slice(); - while (newShape.length < shape.length) { - newShape.unshift(1); - } - input = input.reshape(newShape); - } - - const reps: number[] = Array.from(shape); - for (let i = shape.length - 1; i >= 0; i--) { - if (input.shape[i] === shape[i]) { - reps[i] = 1; - } else if (input.shape[i] !== 1) { - throw new Error( - `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`); - } - } - const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0); - - if (axes.length === 0) { - return input.clone() as Tensor; - } - - return ENGINE.runKernelFunc( - backend => backend.tile(input, reps), {input}, - (dy: Tensor) => - ({input: () => dy.sum(axes, /*keepDims=*/true)})) as Tensor; -} - /** * Creates a new tensor with the same values and shape as the specified * tensor. 
@@ -1186,7 +1130,6 @@ export { }; export const batchToSpaceND = op({batchToSpaceND_}); -export const broadcastTo = op({broadcastTo_}); export const cast = op({cast_}); export const clone = op({clone_}); export const cumsum = op({cumsum_}); diff --git a/tfjs-core/src/ops/array_ops_test.ts b/tfjs-core/src/ops/array_ops_test.ts index 33103efb735..fe990494af2 100644 --- a/tfjs-core/src/ops/array_ops_test.ts +++ b/tfjs-core/src/ops/array_ops_test.ts @@ -19,95 +19,10 @@ import * as tf from '../index'; import {ALL_ENVS, BROWSER_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util'; import {expectArraysClose, expectArraysEqual, expectPromiseToFail, expectValuesInRange} from '../test_util'; import {TypedArray} from '../types'; -import {Tensor} from '../tensor'; import * as util from '../util'; import {expectArrayInMeanStdRange, jarqueBeraNormalityTest} from './rand_util'; -describeWithFlags('broadcastTo', ALL_ENVS, () => { - it('[] -> [3,2]', async () => { - const a = tf.scalar(4.2); - const A = tf.tensor2d([[4.2, 4.2], - [4.2, 4.2], - [4.2, 4.2]]); - - expectArraysClose( - await A.array(), - await tf.broadcastTo(a,A.shape).array() - ); - - // test gradients - const w = tf.tensor2d([[ 4.7, 4.5], - [-6.1,-6.6], - [-8.1,-3.4]]), - f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(), - h = (a: Tensor) => a.mul(w).mean().asScalar(); - - const df = tf.grad(f), - dh = tf.grad(h); - - expectArraysClose( - await df(a).array(), - await dh(a).array() - ); - }); - - it('[2] -> [3,2]', async () => { - const a = tf.tensor1d( [1,2] ); - const A = tf.tensor2d([[1,2], - [1,2], - [1,2]]); - expectArraysClose( - await A.array(), - await tf.broadcastTo(a,A.shape).array() - ); - - // test gradients - const w = tf.tensor2d([[ 4.7, 4.5], - [-6.1,-6.6], - [-8.1,-3.4]]), - f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(), - h = (a: Tensor) => a.mul(w).mean().asScalar(); - - const df = tf.grad(f), - dh = tf.grad(h); - - expectArraysClose( - await 
df(a).array(), - await dh(a).array() - ); - }); - - it('[3,1] -> [3,2]', async () => { - const a = tf.tensor2d([[1], - [2], - [3]]); - const A = tf.tensor2d([[1,1], - [2,2], - [3,3]]); - - expectArraysClose( - await A.array(), - await tf.broadcastTo(a,A.shape).array() - ); - - // test gradients - const w = tf.tensor2d([[ 4.7, 4.5], - [-6.1,-6.6], - [-8.1,-3.4]]), - f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(), - h = (a: Tensor) => a.mul(w).mean().asScalar(); - - const df = tf.grad(f), - dh = tf.grad(h); - - expectArraysClose( - await df(a).array(), - await dh(a).array() - ); - }); -}); - describeWithFlags('zeros', ALL_ENVS, () => { it('1D default dtype', async () => { const a: tf.Tensor1D = tf.zeros([3]); diff --git a/tfjs-core/src/ops/broadcast_to.ts b/tfjs-core/src/ops/broadcast_to.ts new file mode 100644 index 00000000000..b78afec2d4d --- /dev/null +++ b/tfjs-core/src/ops/broadcast_to.ts @@ -0,0 +1,91 @@ +/** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {KernelBackend} from '../backends/backend'; +import {ENGINE} from '../engine'; +import {BroadcastTo, BroadCastToAttrs, BroadcastToInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {Rank, ShapeMap, TensorLike} from '../types'; + +import {op} from './operation'; + +/** + * Broadcast an array to a compatible shape NumPy-style. + * + * The tensor's shape is compared to the broadcast shape from end to beginning. + * Ones are prepended to the tensor's shape until it has the same length as + * the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is + * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then + * the input tensor is tiled N times along that axis (using tf.tile). + * + * @param input The tensor that is to be broadcasted. + * @param shape The input is to be broadcast to this shape. 
+ */ +/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ +function broadcastTo_( + x: Tensor|TensorLike, shape: ShapeMap[R]): Tensor { + let input = convertToTensor(x, 'broadcastTo', 'x'); + const xShape = input.shape; + + if (shape.some(d => !(d > 0) || d % 1 !== 0)) { + throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`); + } + + if (shape.length < input.rank) { + throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${ + input.rank}.`); + } + + if (shape.length > input.rank) { + const newShape = input.shape.slice(); + while (newShape.length < shape.length) { + newShape.unshift(1); + } + input = input.reshape(newShape); + } + + const reps: number[] = Array.from(shape); + for (let i = shape.length - 1; i >= 0; i--) { + if (input.shape[i] === shape[i]) { + reps[i] = 1; + } else if (input.shape[i] !== 1) { + throw new Error( + `broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`); + } + } + const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0); + + if (axes.length === 0) { + return input.clone() as Tensor; + } + + const forward = (backend: KernelBackend) => backend.tile(input, reps); + const keepDims = true; + const backward = (dy: Tensor) => ({x: () => dy.sum(axes, keepDims)}); + + const inputs: BroadcastToInputs = {x: input}; + const attrs: BroadCastToAttrs = {reps}; + + return ENGINE.runKernelFunc( + forward, inputs as unknown as NamedTensorMap, backward, + BroadcastTo, attrs as unknown as NamedAttrMap) as Tensor; +} + +export const broadcastTo = op({broadcastTo_}); diff --git a/tfjs-core/src/ops/broadcast_to_test.ts b/tfjs-core/src/ops/broadcast_to_test.ts new file mode 100644 index 00000000000..91d9926b1cf --- /dev/null +++ b/tfjs-core/src/ops/broadcast_to_test.ts @@ -0,0 +1,76 @@ +/** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {Tensor} from '../tensor'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('broadcastTo', ALL_ENVS, () => { + it('[] -> [3,2]', async () => { + const a = tf.scalar(4.2); + const A = tf.tensor2d([[4.2, 4.2], [4.2, 4.2], [4.2, 4.2]]); + + expectArraysClose( + await A.array(), await tf.broadcastTo(a, A.shape).array()); + + // test gradients + const w = tf.tensor2d([[4.7, 4.5], [-6.1, -6.6], [-8.1, -3.4]]), + f = (a: Tensor) => + tf.broadcastTo(a, A.shape).mul(w).mean().asScalar(), + h = (a: Tensor) => a.mul(w).mean().asScalar(); + + const df = tf.grad(f), dh = tf.grad(h); + + expectArraysClose(await df(a).array(), await dh(a).array()); + }); + + it('[2] -> [3,2]', async () => { + const a = tf.tensor1d([1, 2]); + const A = tf.tensor2d([[1, 2], [1, 2], [1, 2]]); + expectArraysClose( + await A.array(), await tf.broadcastTo(a, A.shape).array()); + + // test gradients + const w = tf.tensor2d([[4.7, 4.5], [-6.1, -6.6], [-8.1, -3.4]]), + f = (a: Tensor) => + tf.broadcastTo(a, A.shape).mul(w).mean().asScalar(), + h = (a: Tensor) => a.mul(w).mean().asScalar(); + + const df = tf.grad(f), dh = tf.grad(h); + + expectArraysClose(await df(a).array(), await dh(a).array()); + }); + + it('[3,1] -> [3,2]', async () => { + const a = tf.tensor2d([[1], [2], [3]]); + const A = tf.tensor2d([[1, 1], [2, 2], [3, 3]]); + + expectArraysClose( + await A.array(), await 
tf.broadcastTo(a, A.shape).array()); + + // test gradients + const w = tf.tensor2d([[4.7, 4.5], [-6.1, -6.6], [-8.1, -3.4]]), + f = (a: Tensor) => + tf.broadcastTo(a, A.shape).mul(w).mean().asScalar(), + h = (a: Tensor) => a.mul(w).mean().asScalar(); + + const df = tf.grad(f), dh = tf.grad(h); + + expectArraysClose(await df(a).array(), await dh(a).array()); + }); +}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index 035fe7ba0e0..720d02bc6a0 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -16,6 +16,7 @@ */ // Modularized ops. +export {broadcastTo} from './broadcast_to'; export {square} from './square'; export {squaredDifference} from './squared_difference'; diff --git a/tfjs-core/src/public/chained_ops/broadcast_to.ts b/tfjs-core/src/public/chained_ops/broadcast_to.ts new file mode 100644 index 00000000000..49075b7df8a --- /dev/null +++ b/tfjs-core/src/public/chained_ops/broadcast_to.ts @@ -0,0 +1,31 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {broadcastTo} from '../../ops/broadcast_to'; +import {Tensor} from '../../tensor'; +import {Rank, ShapeMap} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + broadcastTo(shape: ShapeMap[R]): Tensor; + } +} + +Tensor.prototype.broadcastTo = function(shape: ShapeMap[R]): + Tensor { + return broadcastTo(this, shape); +}; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index 654c9631f29..3a79a8dd6d9 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -16,3 +16,4 @@ */ import './squared_difference'; +import './broadcast_to'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts new file mode 100644 index 00000000000..0ce556c9aa0 --- /dev/null +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import * as tf from '../../index'; +import {ALL_ENVS, describeWithFlags} from '../../jasmine_util'; + +// Testing for presence of chained op in this file will allow us to more easily +// customize when we want this test to run. Currently it will run by default +// (And Karma will always load the chain augmentor files). But this gives us +// flexibility to change in future. + +const CHAINED_OPS = [ + 'square', + 'broadcastTo', +]; + +describeWithFlags('chained ops', ALL_ENVS, () => { + it('all chained ops should exist on tensor ', async () => { + const tensor = tf.tensor([1, 2, 3]); + for (const opName of CHAINED_OPS) { + //@ts-ignore + expect(typeof tensor[opName]) + .toBe('function', `${opName} chained op not found`); + } + }); +}); diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index 6fad5f8d46a..5438e897224 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -14,6 +14,7 @@ * limitations under the License. 
* ============================================================================= */ +import {broadcastToGradConfig} from './gradients/BroadcastTo_grad'; import {squareGradConfig} from './gradients/Square_grad'; import {squaredDifferenceGradConfig} from './gradients/SquaredDifference_grad'; import {GradConfig} from './kernel_registry'; @@ -23,6 +24,7 @@ import {registerGradient} from './kernel_registry'; const gradConfigs: GradConfig[] = [ squareGradConfig, squaredDifferenceGradConfig, + broadcastToGradConfig, ]; for (const gradientConfig of gradConfigs) { diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index c1dde0528d8..f08d4ded292 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -43,6 +43,7 @@ import './ops/axis_util_test'; import './ops/batchnorm_test'; import './ops/binary_ops_test'; import './ops/boolean_mask_test'; +import './ops/broadcast_to_test'; import './ops/broadcast_util_test'; import './ops/clone_test'; import './ops/compare_ops_test'; From 75d788355150be09ae8ecc187a8d980978954957 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 17 Mar 2020 10:13:24 -0400 Subject: [PATCH 2/6] save --- tfjs-core/src/tests.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index f08d4ded292..dddd4a3eaa1 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -105,6 +105,7 @@ import './optimizers/sgd_optimizer_test'; import './platforms/platform_browser_test'; import './platforms/platform_node_test'; import './profiler_test'; +import './public/chained_ops/register_all_chained_ops_test'; import './serialization_test'; import './tape_test'; import './tensor_test'; From f8a6e7e3c5771ce5b4c10e52d9724c91d5801660 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Tue, 17 Mar 2020 10:32:33 -0400 Subject: [PATCH 3/6] save --- tfjs-core/src/engine.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts index 
7a53a917826..16beb1a82d6 100644 --- a/tfjs-core/src/engine.ts +++ b/tfjs-core/src/engine.ts @@ -534,8 +534,7 @@ export class Engine implements TensorTracker, DataMover { */ runKernelFunc( forwardFunc: ForwardFunc, inputs: I, - backwardsFunc?: - (dy: T, saved: Tensor[]|TensorInfo[]) => {[P in keyof I]: () => I[P]}, + backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]}, kernelName?: string, attrs?: NamedAttrMap, inputsToSave: Tensor[] = [], outputsToSave: boolean[] = []): T { let outputs: Tensor[]; From 6a797fa5244ee227ce5449e06b528a6aac403db5 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 18 Mar 2020 13:35:08 -0400 Subject: [PATCH 4/6] change attrs for broadcastTo --- tfjs-core/src/gradients/BroadcastTo_grad.ts | 21 +++++++++++++++++++-- tfjs-core/src/kernel_names.ts | 3 ++- tfjs-core/src/ops/broadcast_to.ts | 5 +++-- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/tfjs-core/src/gradients/BroadcastTo_grad.ts b/tfjs-core/src/gradients/BroadcastTo_grad.ts index 2f4c9340bed..6cf488fe240 100644 --- a/tfjs-core/src/gradients/BroadcastTo_grad.ts +++ b/tfjs-core/src/gradients/BroadcastTo_grad.ts @@ -24,8 +24,25 @@ export const broadcastToGradConfig: GradConfig = { gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => { const broadCastToAttrs: BroadCastToAttrs = attrs as unknown as BroadCastToAttrs; - const axes = - broadCastToAttrs.reps.map((n, i) => n > 1 ? 
i : -1).filter(i => i >= 0); + + const inputShape = broadCastToAttrs.inputShape; + const outputShape = broadCastToAttrs.shape; + + const reps: number[] = Array.from(outputShape); + for (let i = inputShape.length - 1; i >= 0; i--) { + if (inputShape[i] === outputShape[i]) { + reps[i] = 1; + } else if (inputShape[i] !== 1) { + throw new Error(`broadcastTo(): [${ + inputShape}] cannot be broadcast to [${outputShape}].`); + } + } + const axes: number[] = []; + for (let i = 0; i < reps.length; i++) { + if (reps[i] > 1) { + axes.push(i); + } + } const keepDims = true; return {x: () => dy.sum(axes, keepDims)}; } diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index bc7e3bdce28..4b2ce190ef9 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -40,7 +40,8 @@ export interface NonMaxSuppressionV5Attrs { export const BroadcastTo = 'BroadcastTo'; export type BroadcastToInputs = Pick; export interface BroadCastToAttrs { - reps: number[]; + shape: number[]; + inputShape: number[]; // for gradient } /** diff --git a/tfjs-core/src/ops/broadcast_to.ts b/tfjs-core/src/ops/broadcast_to.ts index b78afec2d4d..cdb231c3489 100644 --- a/tfjs-core/src/ops/broadcast_to.ts +++ b/tfjs-core/src/ops/broadcast_to.ts @@ -61,9 +61,10 @@ function broadcastTo_( input = input.reshape(newShape); } + const inputShape = input.shape; const reps: number[] = Array.from(shape); for (let i = shape.length - 1; i >= 0; i--) { - if (input.shape[i] === shape[i]) { + if (inputShape[i] === shape[i]) { reps[i] = 1; } else if (input.shape[i] !== 1) { throw new Error( @@ -81,7 +82,7 @@ function broadcastTo_( const backward = (dy: Tensor) => ({x: () => dy.sum(axes, keepDims)}); const inputs: BroadcastToInputs = {x: input}; - const attrs: BroadCastToAttrs = {reps}; + const attrs: BroadCastToAttrs = {shape, inputShape}; return ENGINE.runKernelFunc( forward, inputs as unknown as NamedTensorMap, backward, From a9fa1d00da07d26aaecb73ee723b4e3ef1535411 Mon Sep 17 
00:00:00 2001 From: Yannick Assogba Date: Wed, 18 Mar 2020 13:39:10 -0400 Subject: [PATCH 5/6] save --- tfjs-core/src/gradients/BroadcastTo_grad.ts | 2 +- tfjs-core/src/ops/broadcast_to.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/gradients/BroadcastTo_grad.ts b/tfjs-core/src/gradients/BroadcastTo_grad.ts index 6cf488fe240..0098ad5f524 100644 --- a/tfjs-core/src/gradients/BroadcastTo_grad.ts +++ b/tfjs-core/src/gradients/BroadcastTo_grad.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2019 Google Inc. All Rights Reserved. + * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at diff --git a/tfjs-core/src/ops/broadcast_to.ts b/tfjs-core/src/ops/broadcast_to.ts index cdb231c3489..c1904468912 100644 --- a/tfjs-core/src/ops/broadcast_to.ts +++ b/tfjs-core/src/ops/broadcast_to.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google Inc. All Rights Reserved. + * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at From 6cb5970c81cf6aea20a2581383b5898f0b47e3d2 Mon Sep 17 00:00:00 2001 From: Yannick Assogba Date: Wed, 18 Mar 2020 14:33:31 -0400 Subject: [PATCH 6/6] kick build