14 changes: 7 additions & 7 deletions tfjs-core/src/engine.ts
@@ -17,9 +17,9 @@

import {BackendTimingInfo, DataMover, KernelBackend} from './backends/backend';
import {Environment, setEnvironmentGlobal} from './environment';
-import {getGradient, getKernel, getKernelsForBackend, NamedAttrMap, TensorInfo} from './kernel_registry';
+import {getGradient, getKernel, getKernelsForBackend, GradFunc, NamedAttrMap, TensorInfo} from './kernel_registry';
import {Profiler} from './profiler';
-import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode} from './tape';
+import {backpropagateGradients, getFilteredNodesXToY, TapeNode} from './tape';
import {DataId, setTensorTracker, Tensor, TensorTracker, Variable} from './tensor';
import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
import {getTensorsInContainer} from './tensor_util';
@@ -465,7 +465,7 @@ export class Engine implements TensorTracker, DataMover {
const inputs = {x};
const grad = (dy: Tensor) => ({x: () => dy.toFloat()});
const saved: Tensor[] = [];
-this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved);
+this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
return y;
}

@@ -604,7 +604,8 @@ export class Engine implements TensorTracker, DataMover {
});

if (isTapeOn) {
-this.addTapeNode(kernelName, inputs, outputs, backwardsFunc, saved);
+this.addTapeNode(
+    kernelName, inputs, outputs, backwardsFunc, saved, attrs);
}

if (this.state.profiling) {
@@ -798,8 +799,7 @@ export class Engine implements TensorTracker, DataMover {

private addTapeNode(
kernelName: string, inputs: NamedTensorMap, outputs: Tensor[],
-gradientsFunc: (dy: Tensor|Tensor[], saved: Tensor[]) => NamedGradientMap,
-saved: Tensor[]): void {
+gradientsFunc: GradFunc, saved: Tensor[], attrs: NamedAttrMap): void {
const tapeNode: TapeNode =
{id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved};

@@ -821,7 +821,7 @@ export class Engine implements TensorTracker, DataMover {
});
// Grad functions of ops with single outputs expect a dy, while ops
// with multiple outputs expect dys (array of dy).
-return gradientsFunc(dys.length > 1 ? dys : dys[0], saved);
+return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
};
}
this.state.activeTape.push(tapeNode);
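Note on the dispatch above: gradient functions for single-output kernels receive a lone dy, while multi-output kernels receive an array of dys, and with this change every gradient function also receives the kernel's attrs. A minimal sketch of a conforming gradient (hypothetical Square kernel, illustrative only, not part of this diff):

import {GradFunc, NamedAttrMap} from './kernel_registry';
import {Tensor} from './tensor';

// Hypothetical single-output kernel y = x^2: dy arrives as a lone Tensor,
// saved holds the tensors stashed by the forward pass, and attrs is
// forwarded from the kernel invocation (unused here).
const squareGrad: GradFunc =
    (dy: Tensor|Tensor[], saved: Tensor[], attrs: NamedAttrMap) => {
      const [x] = saved;
      return {x: () => (dy as Tensor).mul(x).mul(2)};
    };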
49 changes: 49 additions & 0 deletions tfjs-core/src/gradients/BroadcastTo_grad.ts
@@ -0,0 +1,49 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {BroadcastTo, BroadCastToAttrs} from '../kernel_names';
import {GradConfig, NamedAttrMap} from '../kernel_registry';
import {Tensor} from '../tensor';

export const broadcastToGradConfig: GradConfig = {
kernelName: BroadcastTo,
gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {
const broadCastToAttrs: BroadCastToAttrs =
attrs as unknown as BroadCastToAttrs;

const inputShape = broadCastToAttrs.inputShape;
const outputShape = broadCastToAttrs.shape;

const reps: number[] = Array.from(outputShape);
for (let i = inputShape.length - 1; i >= 0; i--) {
if (inputShape[i] === outputShape[i]) {
reps[i] = 1;
} else if (inputShape[i] !== 1) {
throw new Error(`broadcastTo(): [${
inputShape}] cannot be broadcast to [${outputShape}].`);
}
}
const axes: number[] = [];
for (let i = 0; i < reps.length; i++) {
if (reps[i] > 1) {
axes.push(i);
}
}
const keepDims = true;
return {x: () => dy.sum(axes, keepDims)};
}
};
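A worked instance of the reduction above (a sketch against the public API, not part of this file): broadcasting an input of shape [3,1] to [3,2] tiles axis 1, so reps = [1, 2], axes = [1], and the gradient sums dy back over axis 1 with keepDims to recover the input shape.

import * as tf from '@tensorflow/tfjs-core';

// Upstream gradient for a [3,2] output of broadcastTo.
const dy = tf.ones([3, 2]);
// inputShape = [3,1], outputShape = [3,2] => reps = [1,2] => axes = [1].
const dx = dy.sum([1], true /* keepDims */);
dx.print();  // [[2], [2], [2]], shape [3,1]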
7 changes: 7 additions & 0 deletions tfjs-core/src/kernel_names.ts
@@ -37,6 +37,13 @@ export interface NonMaxSuppressionV5Attrs {
softNmsSigma: number;
}

+export const BroadcastTo = 'BroadcastTo';
+export type BroadcastToInputs = Pick<NamedTensorInfoMap, 'x'>;
+export interface BroadCastToAttrs {
+  shape: number[];
+  inputShape: number[]; // for gradient
+}
+
/**
* TensorFlow.js-only kernels
*/
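Each kernel declared in this file follows the same three-part convention: a string constant for the kernel name, a Pick over NamedTensorInfoMap for the inputs, and an interface for the attrs. A sketch of that pattern for a hypothetical kernel (illustrative only, not added by this PR):

export const Reverse = 'Reverse';  // hypothetical entry for illustration
export type ReverseInputs = Pick<NamedTensorInfoMap, 'x'>;
export interface ReverseAttrs {
  dims: number[];
}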
6 changes: 4 additions & 2 deletions tfjs-core/src/kernel_registry.ts
@@ -15,6 +15,7 @@
* =============================================================================
*/

+import {NamedGradientMap} from './tape';
import {Tensor} from './tensor';
import {DataType, RecursiveArray} from './types';

@@ -37,8 +38,9 @@ export type KernelFunc = (params: {
}) => TensorInfo|TensorInfo[];

/** The function to run when computing a gradient during backprop. */
-export type GradFunc = (dy: Tensor|Tensor[], saved: Tensor[]) =>
-    ({[inputName: string]: () => Tensor});
+export type GradFunc =
+    (dy: Tensor|Tensor[], saved: Tensor[], attrs: NamedAttrMap) =>
+        NamedGradientMap;

/** Function that gets called after the backend initializes. */
export type KernelSetupFunc = (backend: {}) => void;
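With the widened signature, a gradient that depends on kernel attributes can read them from attrs instead of smuggling them through saved tensors. A hedged sketch (hypothetical, simplified Cumsum kernel with a single axis attr; the cast mirrors the one used in BroadcastTo_grad.ts above):

import {GradFunc, NamedAttrMap} from './kernel_registry';
import {Tensor} from './tensor';

// The standard cumsum gradient rule: reverse, cumsum, reverse along `axis`.
const cumsumGrad: GradFunc =
    (dy: Tensor|Tensor[], saved: Tensor[], attrs: NamedAttrMap) => {
      const {axis} = attrs as unknown as {axis: number};
      return {
        x: () => (dy as Tensor).reverse(axis).cumsum(axis).reverse(axis)
      };
    };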
57 changes: 0 additions & 57 deletions tfjs-core/src/ops/array_ops.ts
@@ -26,62 +26,6 @@ import {op} from './operation';
import {MPRandGauss, RandGamma, UniformRandom} from './rand';
import {zeros, zerosLike} from './tensor_ops';

/**
* Broadcast an array to a compatible shape NumPy-style.
*
* The tensor's shape is compared to the broadcast shape from end to beginning.
* Ones are prepended to the tensor's shape until is has the same length as
* the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
* already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
* the input tensor is tiled N times along that axis (using tf.tile).
*
* @param input The tensor that is to be broadcasted.
* @param shape The input is to be broadcast to this shape.
*/
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function broadcastTo_<R extends Rank>(
x: Tensor|TensorLike, shape: ShapeMap[R]): Tensor<R> {
let input = convertToTensor(x, 'broadcastTo', 'x');
const xShape = input.shape;

if (shape.some(d => !(d > 0) || d % 1 !== 0)) {
throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
}

if (shape.length < input.rank) {
throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${
input.rank}.`);
}

if (shape.length > input.rank) {
const newShape = input.shape.slice();
while (newShape.length < shape.length) {
newShape.unshift(1);
}
input = input.reshape(newShape);
}

const reps: number[] = Array.from(shape);
for (let i = shape.length - 1; i >= 0; i--) {
if (input.shape[i] === shape[i]) {
reps[i] = 1;
} else if (input.shape[i] !== 1) {
throw new Error(
`broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
}
}
const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);

if (axes.length === 0) {
return input.clone() as Tensor<R>;
}

return ENGINE.runKernelFunc(
backend => backend.tile(input, reps), {input},
(dy: Tensor) =>
({input: () => dy.sum(axes, /*keepDims=*/true)})) as Tensor<R>;
}

/**
* Creates a new tensor with the same values and shape as the specified
* tensor.
@@ -1186,7 +1130,6 @@ export {
};

export const batchToSpaceND = op({batchToSpaceND_});
-export const broadcastTo = op({broadcastTo_});
export const cast = op({cast_});
export const clone = op({clone_});
export const cumsum = op({cumsum_});
85 changes: 0 additions & 85 deletions tfjs-core/src/ops/array_ops_test.ts
@@ -19,95 +19,10 @@ import * as tf from '../index';
import {ALL_ENVS, BROWSER_ENVS, describeWithFlags, NODE_ENVS} from '../jasmine_util';
import {expectArraysClose, expectArraysEqual, expectPromiseToFail, expectValuesInRange} from '../test_util';
import {TypedArray} from '../types';
-import {Tensor} from '../tensor';
import * as util from '../util';

import {expectArrayInMeanStdRange, jarqueBeraNormalityTest} from './rand_util';

describeWithFlags('broadcastTo', ALL_ENVS, () => {
it('[] -> [3,2]', async () => {
const a = tf.scalar(4.2);
const A = tf.tensor2d([[4.2, 4.2],
[4.2, 4.2],
[4.2, 4.2]]);

expectArraysClose(
await A.array(),
await tf.broadcastTo(a,A.shape).array()
);

// test gradients
const w = tf.tensor2d([[ 4.7, 4.5],
[-6.1,-6.6],
[-8.1,-3.4]]),
f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(),
h = (a: Tensor) => a.mul(w).mean().asScalar();

const df = tf.grad(f),
dh = tf.grad(h);

expectArraysClose(
await df(a).array(),
await dh(a).array()
);
});

it('[2] -> [3,2]', async () => {
const a = tf.tensor1d( [1,2] );
const A = tf.tensor2d([[1,2],
[1,2],
[1,2]]);
expectArraysClose(
await A.array(),
await tf.broadcastTo(a,A.shape).array()
);

// test gradients
const w = tf.tensor2d([[ 4.7, 4.5],
[-6.1,-6.6],
[-8.1,-3.4]]),
f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(),
h = (a: Tensor) => a.mul(w).mean().asScalar();

const df = tf.grad(f),
dh = tf.grad(h);

expectArraysClose(
await df(a).array(),
await dh(a).array()
);
});

it('[3,1] -> [3,2]', async () => {
const a = tf.tensor2d([[1],
[2],
[3]]);
const A = tf.tensor2d([[1,1],
[2,2],
[3,3]]);

expectArraysClose(
await A.array(),
await tf.broadcastTo(a,A.shape).array()
);

// test gradients
const w = tf.tensor2d([[ 4.7, 4.5],
[-6.1,-6.6],
[-8.1,-3.4]]),
f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean().asScalar(),
h = (a: Tensor) => a.mul(w).mean().asScalar();

const df = tf.grad(f),
dh = tf.grad(h);

expectArraysClose(
await df(a).array(),
await dh(a).array()
);
});
});

describeWithFlags('zeros', ALL_ENVS, () => {
it('1D default dtype', async () => {
const a: tf.Tensor1D = tf.zeros([3]);
92 changes: 92 additions & 0 deletions tfjs-core/src/ops/broadcast_to.ts
@@ -0,0 +1,92 @@
/**
* @license
* Copyright 2020 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {KernelBackend} from '../backends/backend';
import {ENGINE} from '../engine';
import {BroadcastTo, BroadCastToAttrs, BroadcastToInputs} from '../kernel_names';
import {NamedAttrMap} from '../kernel_registry';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {convertToTensor} from '../tensor_util_env';
import {Rank, ShapeMap, TensorLike} from '../types';

import {op} from './operation';

/**
* Broadcast an array to a compatible shape NumPy-style.
*
* The tensor's shape is compared to the broadcast shape from end to beginning.
* Ones are prepended to the tensor's shape until it has the same length as
* the broadcast shape. If input.shape[i]==shape[i], the (i+1)-th axis is
* already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then
* the input tensor is tiled N times along that axis (using tf.tile).
*
* @param x The tensor to broadcast.
* @param shape The shape to broadcast the input to.
*/
/** @doc {heading: 'Tensors', subheading: 'Transformations'} */
function broadcastTo_<R extends Rank>(
x: Tensor|TensorLike, shape: ShapeMap[R]): Tensor<R> {
let input = convertToTensor(x, 'broadcastTo', 'x');
const xShape = input.shape;

if (shape.some(d => !(d > 0) || d % 1 !== 0)) {
throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`);
}

if (shape.length < input.rank) {
throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${
input.rank}.`);
}

if (shape.length > input.rank) {
const newShape = input.shape.slice();
while (newShape.length < shape.length) {
newShape.unshift(1);
}
input = input.reshape(newShape);
}

const inputShape = input.shape;
const reps: number[] = Array.from(shape);
for (let i = shape.length - 1; i >= 0; i--) {
if (inputShape[i] === shape[i]) {
reps[i] = 1;
} else if (inputShape[i] !== 1) {
throw new Error(
`broadcastTo(): [${xShape}] cannot be broadcast to [${shape}].`);
}
}
const axes = reps.map((n, i) => n > 1 ? i : -1).filter(i => i >= 0);

if (axes.length === 0) {
return input.clone() as Tensor<R>;
}

const forward = (backend: KernelBackend) => backend.tile(input, reps);
const keepDims = true;
const backward = (dy: Tensor) => ({x: () => dy.sum(axes, keepDims)});

const inputs: BroadcastToInputs = {x: input};
const attrs: BroadCastToAttrs = {shape, inputShape};

return ENGINE.runKernelFunc(
forward, inputs as unknown as NamedTensorMap, backward,
BroadcastTo, attrs as unknown as NamedAttrMap) as Tensor<R>;
}

export const broadcastTo = op({broadcastTo_});
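Call sites are unchanged by the refactor; a quick usage sketch mirroring the shapes from the deleted array_ops tests (assumes the standard '@tensorflow/tfjs-core' entry point):

import * as tf from '@tensorflow/tfjs-core';

const a = tf.tensor2d([[1], [2], [3]]);  // shape [3,1]
const b = tf.broadcastTo(a, [3, 2]);     // [[1,1],[2,2],[3,3]]
b.print();

// Gradients now route through the BroadcastTo kernel's grad config:
const f = (x: tf.Tensor) => tf.broadcastTo(x, [3, 2]).sum();
const df = tf.grad(f);
df(a).print();                           // [[2], [2], [2]]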