
Commit 4950ae5

save
dsmilkov committed Dec 18, 2019
1 parent 1117fe0 commit 4950ae5
Showing 3 changed files with 54 additions and 57 deletions.
45 changes: 28 additions & 17 deletions tfjs-core/src/engine.ts
@@ -17,15 +17,15 @@

import {BackendTimingInfo, DataMover, KernelBackend} from './backends/backend';
import {Environment, setEnvironmentGlobal} from './environment';
import {getKernel, getKernelsForBackend, NamedAttrMap, TensorInfo} from './kernel_registry';
import {getGradient, getKernel, getKernelsForBackend, NamedAttrMap, TensorInfo} from './kernel_registry';
import {Profiler} from './profiler';
import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode} from './tape';
import {DataId, setTensorTracker, Tensor, TensorTracker, Variable} from './tensor';
import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
import {getTensorsInContainer} from './tensor_util';
import {BackendValues, DataType, DataValues} from './types';
import * as util from './util';
import {bytesFromStringArray, makeOnesTypedArray, makeZerosTypedArray, now, sizeFromShape} from './util';
import {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util';

/**
* A function that computes an output. The save function is for saving tensors
@@ -798,16 +798,32 @@ export class Engine implements TensorTracker, DataMover {

private addTapeNode(
kernelName: string, inputs: NamedTensorMap, outputs: Tensor[],
gradient: (dy: Tensor|Tensor[], saved: Tensor[]) => NamedGradientMap,
gradientsFunc: (dy: Tensor|Tensor[], saved: Tensor[]) => NamedGradientMap,
saved: Tensor[]): void {
const tapeNode: TapeNode = {
id: this.state.nextTapeNodeId++,
kernelName,
inputs,
outputs,
saved,
gradient
};
const tapeNode: TapeNode =
{id: this.state.nextTapeNodeId++, kernelName, inputs, outputs, saved};

const gradConfig = getGradient(kernelName);
if (gradConfig) {
gradientsFunc = gradConfig.gradFunc;
}
if (gradientsFunc != null) {
tapeNode.gradient = (dys: Tensor[]) => {
// TODO(smilkov): To optimize back-prop, pass dys that are not used in
// the backprop graph to the user as null instead of zeros
dys = dys.map((dy, i) => {
if (dy == null) {
const output = outputs[i];
const vals = util.makeZerosTypedArray(output.size, output.dtype);
return this.makeTensor(vals, output.shape, output.dtype);
}
return dy;
});
// Grad functions of ops with single outputs expect a dy, while ops
// with multiple outputs expect dys (array of dy).
return gradientsFunc(dys.length > 1 ? dys : dys[0], saved);
};
}
this.state.activeTape.push(tapeNode);
}

@@ -915,7 +931,7 @@ export class Engine implements TensorTracker, DataMover {
backpropagateGradients(
accumulatedGradientMap, filteredTape,
// Pass the tidy function to avoid circular dep with `tape.ts`.
f => this.tidy(f as ScopeFn<Tensor>), zeros);
f => this.tidy(f as ScopeFn<Tensor>));
const grads = xs.map(x => accumulatedGradientMap[x.id]);

if (this.state.gradientDepth === 0) {
@@ -1053,11 +1069,6 @@ function ones(shape: number[]): Tensor {
return ENGINE.makeTensor(values, shape, 'float32');
}

function zeros(shape: number[]): Tensor {
const values = makeZerosTypedArray(sizeFromShape(shape), 'float32');
return ENGINE.makeTensor(values, shape, 'float32');
}

let GLOBAL: {_tfengine: Engine};
function getGlobalNamespace(): {_tfengine: Engine} {
if (GLOBAL == null) {
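In short, engine.ts now resolves the gradient up front: addTapeNode looks up a registered gradient for the kernel via getGradient and, if one exists, prefers it over the inline gradientsFunc, then stores a wrapper on tapeNode.gradient that accepts an array of dys in which unused entries may be null. Below is a minimal, self-contained sketch of that wrapping logic; FakeTensor, GradMap, zerosLikeOutput, and wrapGradient are illustrative stand-ins, not tfjs-core APIs.

```ts
// Stand-in types for illustration only; the real code uses tfjs-core's
// Tensor, NamedGradientMap, and Engine.makeTensor.
interface FakeTensor {
  shape: number[];
  dtype: string;
  values: Float32Array;
}
type GradMap = {[inputName: string]: () => FakeTensor};
type GradientsFunc =
    (dy: FakeTensor|FakeTensor[], saved: FakeTensor[]) => GradMap;

// Plays the role of makeZerosTypedArray + makeTensor in the real diff.
function zerosLikeOutput(output: FakeTensor): FakeTensor {
  const size = output.shape.reduce((a, b) => a * b, 1);
  return {
    shape: output.shape,
    dtype: output.dtype,
    values: new Float32Array(size)  // zero-filled by construction
  };
}

// Mirrors the wrapper installed on tapeNode.gradient: null dys (outputs that
// do not feed the target) are replaced with zeros, and single-output ops get
// a lone dy while multi-output ops get the whole array.
function wrapGradient(
    outputs: FakeTensor[], gradientsFunc: GradientsFunc,
    saved: FakeTensor[]): (dys: Array<FakeTensor|null>) => GradMap {
  return (dys: Array<FakeTensor|null>) => {
    const filled =
        dys.map((dy, i) => dy != null ? dy : zerosLikeOutput(outputs[i]));
    return gradientsFunc(filled.length > 1 ? filled : filled[0], saved);
  };
}
```

Keeping the zeros allocation behind the engine means tape.ts no longer needs a zeros factory injected from the outside, and it sets up the TODO in the diff of eventually passing null straight through to gradient functions instead of materializing zero tensors.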
21 changes: 7 additions & 14 deletions tfjs-core/src/tape.ts
@@ -15,7 +15,6 @@
* =============================================================================
*/

import {getGradient} from './kernel_registry';
import {Tensor} from './tensor';
import {NamedTensorMap} from './tensor_types';
import * as util from './util';
@@ -26,7 +25,7 @@ export interface TapeNode {
outputs: Tensor[];
inputs: NamedTensorMap;
// Optional params, defined only for ops with gradient impl.
gradient?: (dy: Tensor|Tensor[], saved: Tensor[]) => NamedGradientMap;
gradient?: (dys: Tensor[]) => NamedGradientMap;
saved?: Tensor[];
}

@@ -131,8 +130,7 @@ export function getFilteredNodesXToY(
*/
export function backpropagateGradients(
tensorAccumulatedGradientMap: {[tensorId: number]: Tensor},
filteredTape: TapeNode[], tidy: (f: Function) => Tensor,
zeros: (shape: number[]) => Tensor) {
filteredTape: TapeNode[], tidy: (f: Function) => Tensor) {
// Walk the tape backward and keep a map of Tensor to its gradient.
for (let i = filteredTape.length - 1; i >= 0; i--) {
const node = filteredTape[i];
@@ -144,24 +142,19 @@ export function backpropagateGradients(
dys.push(gradTensor);
} else {
// This particular output is not in the back-propagation subgraph, so it
// does not affect the final output.
// TODO(smilkov): To optimize back-prop, pass dys that are not used in
// the backprop graph to the user as null instead of zeros.
dys.push(zeros(o.shape));
// does not affect the final output, thus we put null for its dy.
dys.push(null);
}
});
const gradConfig = getGradient(node.kernelName);
const gradFunc = gradConfig ? gradConfig.gradFunc : node.gradient;
if (gradFunc == null) {

if (node.gradient == null) {
throw new Error(
`Cannot compute gradient: gradient function not found ` +
`for ${node.kernelName}.`);
}

// Backprop dy through this node and accumulate gradients over the inputs.
// Grad functions of ops with single outputs expect a dy, while ops
// with multiple outputs expect dys (array of dy).
const inputGradients = gradFunc(dys.length > 1 ? dys : dys[0], node.saved);
const inputGradients = node.gradient(dys);

for (const inputName in node.inputs) {
if (!(inputName in inputGradients)) {
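On the tape.ts side, backpropagateGradients no longer imports getGradient or receives a zeros factory; it pushes null for outputs outside the backprop subgraph and delegates entirely to node.gradient(dys). Here is a rough model of the updated loop, again with hypothetical stand-in types (Tensorish, SketchNode) rather than the real Tensor and TapeNode:

```ts
// Hypothetical stand-ins for illustration; the real code uses tfjs-core's
// Tensor and the TapeNode interface updated in this diff.
interface Tensorish { id: number; }
interface SketchNode {
  kernelName: string;
  outputs: Tensorish[];
  inputs: {[name: string]: Tensorish};
  gradient?: (dys: Array<Tensorish|null>) => {[name: string]: () => Tensorish};
}

function backpropagateSketch(
    accumulated: {[tensorId: number]: Tensorish},
    filteredTape: SketchNode[]): void {
  // Walk the tape backward, exactly as before.
  for (let i = filteredTape.length - 1; i >= 0; i--) {
    const node = filteredTape[i];
    // Outputs not on the path to the target now contribute null instead of a
    // freshly allocated zeros tensor; the engine-side wrapper fills them in.
    const dys = node.outputs.map(
        o => accumulated[o.id] != null ? accumulated[o.id] : null);
    if (node.gradient == null) {
      throw new Error(
          `Cannot compute gradient: gradient function not found ` +
          `for ${node.kernelName}.`);
    }
    const inputGradients = node.gradient(dys);
    // Accumulating inputGradients[name]() into `accumulated`, keyed by the
    // corresponding input tensor id, is unchanged by this commit and omitted.
  }
}
```

The single-vs-multiple-output dispatch that previously lived here moves into the engine's wrapper, so the tape only ever deals with an array of dys.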
45 changes: 19 additions & 26 deletions tfjs-core/src/tape_test.ts
@@ -22,10 +22,6 @@ import {zerosLike} from './ops/ops';
import {backpropagateGradients, getFilteredNodesXToY, TapeNode} from './tape';
import {expectArraysClose} from './test_util';

function zeros(shape: number[]): tf.Tensor {
return tf.zeros(shape, 'float32');
}

describeWithFlags('getFilteredNodesXToY', ALL_ENVS, () => {
it('no paths from x to y', () => {
const x = tf.scalar(1);
@@ -270,7 +266,7 @@ describeWithFlags('backpropagateGradients', ALL_ENVS, () => {
expect(
() => backpropagateGradients(
accumulatedGradientsMap, tape,
f => tf.tidy(f as ScopeFn<tf.Tensor>), zeros))
f => tf.tidy(f as ScopeFn<tf.Tensor>)))
.toThrowError();
});

@@ -288,14 +284,13 @@ describeWithFlags('backpropagateGradients', ALL_ENVS, () => {
kernelName: 'node0',
inputs: {x},
outputs: [y],
gradient: (dy: tf.Tensor) => {
return {x: () => dy.add(tf.scalar(1))};
gradient: (dys: tf.Tensor[]) => {
return {x: () => dys[0].add(tf.scalar(1))};
}
}];

backpropagateGradients(
accumulatedGradientsMap, tape, f => tf.tidy(f as ScopeFn<tf.Tensor>),
zeros);
accumulatedGradientsMap, tape, f => tf.tidy(f as ScopeFn<tf.Tensor>));

expectArraysClose(await accumulatedGradientsMap[x.id].data(), [2]);
});
@@ -316,24 +311,23 @@ describeWithFlags('backpropagateGradients', ALL_ENVS, () => {
kernelName: 'node0',
inputs: {x},
outputs: [intermediate],
gradient: (dy: tf.Tensor) => {
return {x: () => dy.add(tf.scalar(1))};
gradient: (dys: tf.Tensor[]) => {
return {x: () => dys[0].add(tf.scalar(1))};
}
},
{
id: 1,
kernelName: 'node1',
inputs: {intermediate},
outputs: [y],
gradient: (dy: tf.Tensor) => {
return {intermediate: () => dy.add(tf.scalar(1))};
gradient: (dys: tf.Tensor[]) => {
return {intermediate: () => dys[0].add(tf.scalar(1))};
}
}
];

backpropagateGradients(
accumulatedGradientsMap, tape, f => tf.tidy(f as ScopeFn<tf.Tensor>),
zeros);
accumulatedGradientsMap, tape, f => tf.tidy(f as ScopeFn<tf.Tensor>));

// dx = dy + 1 + 1
expectArraysClose(await accumulatedGradientsMap[x.id].data(), [3]);
@@ -356,36 +350,35 @@ describeWithFlags('backpropagateGradients', ALL_ENVS, () => {
kernelName: 'node0',
inputs: {x},
outputs: [intermediate1],
gradient: (dy: tf.Tensor) => {
return {x: () => dy.add(tf.scalar(1))};
gradient: (dys: tf.Tensor[]) => {
return {x: () => dys[0].add(tf.scalar(1))};
}
},
{
id: 1,
kernelName: 'node1',
inputs: {x},
outputs: [intermediate2],
gradient: (dy: tf.Tensor) => {
return {x: () => dy.add(tf.scalar(1))};
gradient: (dys: tf.Tensor[]) => {
return {x: () => dys[0].add(tf.scalar(1))};
}
},
{
id: 2,
kernelName: 'node2',
inputs: {intermediate1, intermediate2},
outputs: [y],
gradient: (dy: tf.Tensor) => {
gradient: (dys: tf.Tensor[]) => {
return {
intermediate1: () => dy.add(tf.scalar(1)),
intermediate2: () => dy.add(tf.scalar(1))
intermediate1: () => dys[0].add(tf.scalar(1)),
intermediate2: () => dys[0].add(tf.scalar(1))
};
}
}
];

backpropagateGradients(
accumulatedGradientsMap, tape, f => tf.tidy(f as ScopeFn<tf.Tensor>),
zeros);
accumulatedGradientsMap, tape, f => tf.tidy(f as ScopeFn<tf.Tensor>));

// dx = dy + 1 + 1 + 1 + 1 + 1
expectArraysClose(
@@ -417,8 +410,8 @@ describeWithFlags('backpropagateGradients', ALL_ENVS, () => {
}];

backpropagateGradients(
accumulatedGradientsMap, tape, f => tf.tidy(f as ScopeFn<tf.Tensor>),
zeros);
accumulatedGradientsMap, tape,
f => tf.tidy(f as ScopeFn<tf.Tensor>));
expectArraysClose(await accumulatedGradientsMap[x.id].data(), [0, 5, 0]);
expectArraysClose(await dys[0].data(), [0]);
expectArraysClose(await dys[1].data(), [5]);
