84 changes: 71 additions & 13 deletions tfjs-core/src/engine.ts
@@ -535,24 +535,15 @@ export class Engine implements TensorTracker, DataMover {
runKernelFunc<T extends Tensor|Tensor[], I extends NamedTensorMap>(
forwardFunc: ForwardFunc<T>, inputs: I,
backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]},
kernelName?: string, attrs?: NamedAttrMap, inputsToSave: Tensor[] = [],
outputsToSave: boolean[] = []): T {
kernelName?: string, attrs?: NamedAttrMap, inputsToSave?: Tensor[],
outputsToSave?: boolean[]): T {
let outputs: Tensor[];
let saved: Tensor[] = [];
const isTapeOn = this.isTapeOn();
if (kernelName == null) {
kernelName =
this.state.activeScope != null ? this.state.activeScope.name : '';
}
const saveFunc: GradSaveFunc = (tensors) => {
// Do not save unless we are recording to the tape. Otherwise it would
// cause a mem leak since we would never run backprop, which disposes
// the kept tensors.
if (!isTapeOn) {
return;
}
saved = tensors.map(tensor => this.keep(this.clone(tensor)));
};

const startingBytecount = this.state.numBytes;
const startingNumTensors = this.state.numTensors;
@@ -575,12 +566,40 @@ export class Engine implements TensorTracker, DataMover {
const outTensors = outInfos.map(
({dataId, shape, dtype}) =>
this.makeTensorFromDataId(dataId, shape, dtype));
const outsToSave = outTensors.filter((_, i) => outputsToSave[i]);

// Save the inputs and outputs.
saveFunc((inputsToSave || []).slice().concat(outsToSave));
// Do not save unless we are recording to the tape. Otherwise it would
// cause a mem leak since we would never run backprop, which disposes
// the kept tensors.
if (isTapeOn) {
let tensorsToSave =
this.getTensorsForGradient(kernelName, inputs, outTensors);
if (tensorsToSave == null) {
// Fallback for ops that call runKernelFunc and pass in
// inputsToSave and outputsToSave. Currently this is the set of ops
// with kernel support in the WASM backend. Once those ops and
// respective gradients are modularised we can remove this path.
if (outputsToSave == null) {
outputsToSave = [];
}
const outsToSave = outTensors.filter((_, i) => outputsToSave[i]);
tensorsToSave = (inputsToSave || []).slice().concat(outsToSave);
}
saved = this.saveTensorsForBackwardMode(tensorsToSave);
}
return outTensors;
};
} else {
const saveFunc: GradSaveFunc = (tensors) => {
// Do not save unless we are recording to the tape. Otherwise it would
// cause a mem leak since we would never run backprop, which disposes
// the kept tensors.
if (!isTapeOn) {
return;
}
saved = tensors.map(tensor => this.keep(this.clone(tensor)));
};

kernelFunc = () => {
const numDataIdsBefore = this.backend.numDataIds();
out = this.tidy(() => forwardFunc(this.backend, saveFunc));
@@ -622,6 +641,45 @@ export class Engine implements TensorTracker, DataMover {
return (Array.isArray(out) ? outputs : outputs[0]) as T;
}

/**
* Saves tensors used in forward mode for use in backward mode.
*
* @param tensors the list of tensors to save.
*/
private saveTensorsForBackwardMode(tensors: Tensor[]): Tensor[] {
const saved = tensors.map(tensor => this.keep(this.clone(tensor)));
return saved;
}

/**
* Returns a list of tensors to save for a given gradient calculation.
*
* Returns null if there is no registered gradient for this kernel in the
* gradient registry.
*
* @param kernelName name of kernel to look up gradient for.
* @param inputs a map of input tensors.
* @param outputs an array of output tensors from forward mode of kernel.
*/
private getTensorsForGradient(
kernelName: string, inputs: NamedTensorMap,
outputs: Tensor[]): Tensor[]|null {
const gradConfig = getGradient(kernelName);
if (gradConfig != null) {
const inputsToSave: string[] = gradConfig.inputsToSave || [];
const outputsToSave: boolean[] = gradConfig.outputsToSave || [];

const inputTensorsToSave: Tensor[] =
inputsToSave.map((inputName) => inputs[inputName]);
const outputTensorsToSave: Tensor[] =
outputs.filter((_, i) => outputsToSave[i]);
return inputTensorsToSave.concat(outputTensorsToSave);
}
// TODO(yassogba) throw exception here once all runKernelFunc calls with
// inputsToSave/outputsToSave are removed
return null;
}

/**
* Internal method used by public APIs for tensor creation. Makes a new
* tensor with the provided shape, dtype and values. It always
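For orientation, here is a minimal standalone sketch of the lookup that the new getTensorsForGradient path performs. The FakeTensor and FakeGradConfig types and the local registry map are made up for illustration and are not part of the tfjs-core API; the real method consults the global gradient registry via getGradient and, when no gradient is registered, callers fall back to the legacy inputsToSave/outputsToSave arguments of runKernelFunc.

// Illustrative only: simplified stand-ins for Tensor and GradConfig.
interface FakeTensor { id: number; }
interface FakeGradConfig {
  kernelName: string;
  inputsToSave?: string[];
  outputsToSave?: boolean[];
}

// A toy registry keyed by kernel name, mirroring the global gradient registry.
const gradRegistry = new Map<string, FakeGradConfig>();
gradRegistry.set('SquaredDifference',
    {kernelName: 'SquaredDifference', inputsToSave: ['a', 'b']});

function tensorsForGradient(
    kernelName: string, inputs: {[name: string]: FakeTensor},
    outputs: FakeTensor[]): FakeTensor[]|null {
  const config = gradRegistry.get(kernelName);
  if (config == null) {
    // No modular gradient registered: the engine falls back to the legacy
    // inputsToSave/outputsToSave arguments passed to runKernelFunc.
    return null;
  }
  const inputNames = config.inputsToSave || [];
  const outputFlags = config.outputsToSave || [];
  // Saved tensors are the named inputs plus the flagged outputs.
  return inputNames.map(name => inputs[name])
      .concat(outputs.filter((_, i) => outputFlags[i]));
}

// Only `a` and `b` are kept for the tape; the output is not saved.
console.log(tensorsForGradient(
    'SquaredDifference', {a: {id: 0}, b: {id: 1}}, [{id: 2}]));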
1 change: 1 addition & 0 deletions tfjs-core/src/gradients/Square_grad.ts
@@ -21,6 +21,7 @@ import {Tensor} from '../tensor';

export const squareGradConfig: GradConfig = {
kernelName: Square,
inputsToSave: ['x'],
gradFunc: (dy: Tensor, saved: Tensor[]) => {
const [x] = saved;
return {x: () => dy.mul(x.toFloat().mul(2))};
1 change: 1 addition & 0 deletions tfjs-core/src/gradients/SquaredDifference_grad.ts
@@ -23,6 +23,7 @@ import {Tensor} from '../tensor';

export const squaredDifferenceGradConfig: GradConfig = {
kernelName: SquaredDifference,
inputsToSave: ['a', 'b'],
gradFunc: (dy: Tensor, saved: Tensor[]) => {
const [a, b] = saved;
const two = scalar(2);
2 changes: 2 additions & 0 deletions tfjs-core/src/kernel_registry.ts
@@ -59,6 +59,8 @@ export interface KernelConfig {
/** Config object for registering a gradient in the global registry. */
export interface GradConfig {
kernelName: string;
inputsToSave?: string[];
outputsToSave?: boolean[];
gradFunc: GradFunc;
}

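As a sketch of how the two new optional fields might be used, a hypothetical gradient config (not part of this diff) for an Exp kernel could save only the forward output, since d(exp(x))/dx = exp(x). Imports are assumed to resolve as in the existing *_grad.ts files if this lived under tfjs-core/src/gradients.

import {GradConfig} from '../kernel_registry';
import {Tensor} from '../tensor';

// Hypothetical example: the Exp gradient needs only the forward output
// y = exp(x), so outputsToSave marks the single output for the tape.
export const expGradConfigSketch: GradConfig = {
  kernelName: 'Exp',
  outputsToSave: [true],
  gradFunc: (dy: Tensor, saved: Tensor[]) => {
    const [y] = saved;            // saved forward output, exp(x)
    return {x: () => dy.mul(y)};  // dy * exp(x)
  }
};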
48 changes: 48 additions & 0 deletions tfjs-core/src/kernel_registry_test.ts
@@ -186,6 +186,7 @@ describeWithFlags('gradient registry', ALL_ENVS, () => {

tf.registerGradient({
kernelName,
inputsToSave: ['x'],
gradFunc: (dy: tf.Tensor, saved) => {
// Make sure saved input (x) was passed to the gradient function.
expect(saved[0].dataId).toEqual(x.dataId);
@@ -210,6 +211,53 @@ describeWithFlags('gradient registry', ALL_ENVS, () => {
tf.unregisterGradient(kernelName);
});

it('register a kernel with gradient that specifies outputsToSave and call it',
async () => {
let kernelWasCalled = false;
let gradientWasCalled = false;
const kernelName = 'MyKernel';

const forwardReturnDataId = {};
tf.registerKernel({
kernelName,
backendName: tf.getBackend(),
kernelFunc: () => {
kernelWasCalled = true;
return {
dtype: 'float32',
shape: [3, 3],
dataId: forwardReturnDataId
};
}
});

tf.registerGradient({
kernelName,
outputsToSave: [true],
gradFunc: (dy: tf.Tensor, saved) => {
// Make sure saved output was passed to the gradient function.
expect(saved[0].dataId).toEqual(forwardReturnDataId);
// Make sure dy matches the shape of the output.
expect(dy.shape).toEqual([3, 3]);
gradientWasCalled = true;
return {x: () => tf.fill([2, 2], 3)};
},
});

const gradFunc = tf.grad(
x => tf.engine().runKernel(
kernelName, {x}, {} /* attrs */
) as tf.Tensor);
const x = tf.zeros([2, 2]);
const dx = gradFunc(x);
expect(kernelWasCalled).toBe(true);
expect(gradientWasCalled).toBe(true);
expect(dx.dtype).toBe('float32');
expect(dx.shape).toEqual([2, 2]);
tf.unregisterKernel(kernelName, tf.getBackend());
tf.unregisterGradient(kernelName);
});

it('errors when running non-existent gradient', () => {
const kernelName = 'MyKernel';
const x = tf.zeros([2, 2]);