diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts
index 16beb1a82d6..3c518d9be37 100644
--- a/tfjs-core/src/engine.ts
+++ b/tfjs-core/src/engine.ts
@@ -535,8 +535,8 @@ export class Engine implements TensorTracker, DataMover {
   runKernelFunc<T extends Tensor|Tensor[], I extends NamedTensorMap>(
       forwardFunc: ForwardFunc<T>, inputs: I,
       backwardsFunc?: (dy: T, saved: Tensor[]) => {[P in keyof I]: () => I[P]},
-      kernelName?: string, attrs?: NamedAttrMap, inputsToSave: Tensor[] = [],
-      outputsToSave: boolean[] = []): T {
+      kernelName?: string, attrs?: NamedAttrMap, inputsToSave?: Tensor[],
+      outputsToSave?: boolean[]): T {
     let outputs: Tensor[];
     let saved: Tensor[] = [];
     const isTapeOn = this.isTapeOn();
@@ -544,15 +544,6 @@ export class Engine implements TensorTracker, DataMover {
       kernelName =
           this.state.activeScope != null ? this.state.activeScope.name : '';
     }
-    const saveFunc: GradSaveFunc = (tensors) => {
-      // Do not save unless we are recording to the tape. Otherwise it would
-      // cause a mem leak since we would never run backprop, which disposes
-      // the kept tensors.
-      if (!isTapeOn) {
-        return;
-      }
-      saved = tensors.map(tensor => this.keep(this.clone(tensor)));
-    };
     const startingBytecount = this.state.numBytes;
     const startingNumTensors = this.state.numTensors;
 
@@ -575,12 +566,40 @@ export class Engine implements TensorTracker, DataMover {
         const outTensors = outInfos.map(
             ({dataId, shape, dtype}) =>
                 this.makeTensorFromDataId(dataId, shape, dtype));
-        const outsToSave = outTensors.filter((_, i) => outputsToSave[i]);
+
         // Save the inputs and outputs.
-        saveFunc((inputsToSave || []).slice().concat(outsToSave));
+        // Do not save unless we are recording to the tape. Otherwise it would
+        // cause a mem leak since we would never run backprop, which disposes
+        // the kept tensors.
+        if (isTapeOn) {
+          let tensorsToSave =
+              this.getTensorsForGradient(kernelName, inputs, outTensors);
+          if (tensorsToSave == null) {
+            // Fallback for ops that call runKernelFunc and pass in
+            // inputsToSave and outputsToSave. Currently this is the set of ops
+            // with kernel support in the WASM backend. Once those ops and
+            // respective gradients are modularised we can remove this path.
+            if (outputsToSave == null) {
+              outputsToSave = [];
+            }
+            const outsToSave = outTensors.filter((_, i) => outputsToSave[i]);
+            tensorsToSave = (inputsToSave || []).slice().concat(outsToSave);
+          }
+          saved = this.saveTensorsForBackwardMode(tensorsToSave);
+        }
         return outTensors;
       };
     } else {
+      const saveFunc: GradSaveFunc = (tensors) => {
+        // Do not save unless we are recording to the tape. Otherwise it would
+        // cause a mem leak since we would never run backprop, which disposes
+        // the kept tensors.
+        if (!isTapeOn) {
+          return;
+        }
+        saved = tensors.map(tensor => this.keep(this.clone(tensor)));
+      };
+
       kernelFunc = () => {
         const numDataIdsBefore = this.backend.numDataIds();
         out = this.tidy(() => forwardFunc(this.backend, saveFunc));
@@ -622,6 +641,45 @@ export class Engine implements TensorTracker, DataMover {
     return (Array.isArray(out) ? outputs : outputs[0]) as T;
   }
 
+  /**
+   * Saves tensors used in forward mode for use in backward mode.
+   *
+   * @param tensors the list of tensors to save.
+   */
+  private saveTensorsForBackwardMode(tensors: Tensor[]): Tensor[] {
+    const saved = tensors.map(tensor => this.keep(this.clone(tensor)));
+    return saved;
+  }
+
+  /**
+   * Returns a list of tensors to save for a given gradient calculation.
+   *
+   * Returns null if there is no registered gradient for this kernel in the
+   * gradient registry.
+   *
+   * @param kernelName name of kernel to look up gradient for.
+   * @param inputs a map of input tensors.
+   * @param outputs an array of output tensors from forward mode of kernel.
+   */
+  private getTensorsForGradient(
+      kernelName: string, inputs: NamedTensorMap,
+      outputs: Tensor[]): Tensor[]|null {
+    const gradConfig = getGradient(kernelName);
+    if (gradConfig != null) {
+      const inputsToSave: string[] = gradConfig.inputsToSave || [];
+      const outputsToSave: boolean[] = gradConfig.outputsToSave || [];
+
+      const inputTensorsToSave: Tensor[] =
+          inputsToSave.map((inputName) => inputs[inputName]);
+      const outputTensorsToSave: Tensor[] =
+          outputs.filter((_, i) => outputsToSave[i]);
+      return inputTensorsToSave.concat(outputTensorsToSave);
+    }
+    // TODO(yassogba) throw exception here once all runKernelFunc calls with
+    // inputsToSave/outputsToSave are removed
+    return null;
+  }
+
   /**
    * Internal method used by public APIs for tensor creation. Makes a new
    * tensor with the provided shape, dtype and values. It always
diff --git a/tfjs-core/src/gradients/Square_grad.ts b/tfjs-core/src/gradients/Square_grad.ts
index 21e6c071adc..6aad3076e84 100644
--- a/tfjs-core/src/gradients/Square_grad.ts
+++ b/tfjs-core/src/gradients/Square_grad.ts
@@ -21,6 +21,7 @@ import {Tensor} from '../tensor';
 
 export const squareGradConfig: GradConfig = {
   kernelName: Square,
+  inputsToSave: ['x'],
   gradFunc: (dy: Tensor, saved: Tensor[]) => {
     const [x] = saved;
     return {x: () => dy.mul(x.toFloat().mul(2))};
diff --git a/tfjs-core/src/gradients/SquaredDifference_grad.ts b/tfjs-core/src/gradients/SquaredDifference_grad.ts
index 0d0e668d1ce..f34ede8d2d3 100644
--- a/tfjs-core/src/gradients/SquaredDifference_grad.ts
+++ b/tfjs-core/src/gradients/SquaredDifference_grad.ts
@@ -23,6 +23,7 @@ import {Tensor} from '../tensor';
 
 export const squaredDifferenceGradConfig: GradConfig = {
   kernelName: SquaredDifference,
+  inputsToSave: ['a', 'b'],
   gradFunc: (dy: Tensor, saved: Tensor[]) => {
     const [a, b] = saved;
     const two = scalar(2);
diff --git a/tfjs-core/src/kernel_registry.ts b/tfjs-core/src/kernel_registry.ts
index 5cd8beb487f..3997636ca3d 100644
--- a/tfjs-core/src/kernel_registry.ts
+++ b/tfjs-core/src/kernel_registry.ts
@@ -59,6 +59,8 @@ export interface KernelConfig {
 
 /** Config object for registering a gradient in the global registry. */
 export interface GradConfig {
   kernelName: string;
+  inputsToSave?: string[];
+  outputsToSave?: boolean[];
   gradFunc: GradFunc;
 }
diff --git a/tfjs-core/src/kernel_registry_test.ts b/tfjs-core/src/kernel_registry_test.ts
index efdfc63604d..b7afde14e8e 100644
--- a/tfjs-core/src/kernel_registry_test.ts
+++ b/tfjs-core/src/kernel_registry_test.ts
@@ -186,6 +186,7 @@ describeWithFlags('gradient registry', ALL_ENVS, () => {
 
     tf.registerGradient({
       kernelName,
+      inputsToSave: ['x'],
      gradFunc: (dy: tf.Tensor, saved) => {
        // Make sure saved input (x) was passed to the gradient function.
        expect(saved[0].dataId).toEqual(x.dataId);
@@ -210,6 +211,53 @@ describeWithFlags('gradient registry', ALL_ENVS, () => {
     tf.unregisterGradient(kernelName);
   });
 
+  it('register a kernel with gradient that specifies outputsToSave and call it',
+     async () => {
+       let kernelWasCalled = false;
+       let gradientWasCalled = false;
+       const kernelName = 'MyKernel';
+
+       const forwardReturnDataId = {};
+       tf.registerKernel({
+         kernelName,
+         backendName: tf.getBackend(),
+         kernelFunc: () => {
+           kernelWasCalled = true;
+           return {
+             dtype: 'float32',
+             shape: [3, 3],
+             dataId: forwardReturnDataId
+           };
+         }
+       });
+
+       tf.registerGradient({
+         kernelName,
+         outputsToSave: [true],
+         gradFunc: (dy: tf.Tensor, saved) => {
+           // Make sure saved output was passed to the gradient function.
+           expect(saved[0].dataId).toEqual(forwardReturnDataId);
+           // Make sure dy matches the shape of the output.
+           expect(dy.shape).toEqual([3, 3]);
+           gradientWasCalled = true;
+           return {x: () => tf.fill([2, 2], 3)};
+         },
+       });
+
+       const gradFunc = tf.grad(
+           x => tf.engine().runKernel(
+                    kernelName, {x}, {} /* attrs */
+                    ) as tf.Tensor);
+       const x = tf.zeros([2, 2]);
+       const dx = gradFunc(x);
+       expect(kernelWasCalled).toBe(true);
+       expect(gradientWasCalled).toBe(true);
+       expect(dx.dtype).toBe('float32');
+       expect(dx.shape).toEqual([2, 2]);
+       tf.unregisterKernel(kernelName, tf.getBackend());
+       tf.unregisterGradient(kernelName);
+     });
+
   it('errors when running non-existent gradient', () => {
     const kernelName = 'MyKernel';
     const x = tf.zeros([2, 2]);