Record per-kernel timing and extra info in the result of Engine.profile().

engine.ts: KernelInfo gains `kernelTimeMs` and `extraInfo`; kernels are run
through Profiler.profileKernel when either DEBUG or profiling is on (logging
stays DEBUG-only); profile() awaits both fields for every recorded kernel
before returning. engine_test.ts: the profile tests check the new fields.

diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts
index 07d8f56d5c6..6d6c309df1a 100644
--- a/tfjs-core/src/engine.ts
+++ b/tfjs-core/src/engine.ts
@@ -58,6 +58,8 @@ type KernelInfo = {
   totalTensorsSnapshot: number;
   inputShapes: number[][];
   outputShapes: number[][];
+  kernelTimeMs: number | {error: string} | Promise<number|{error: string}>;
+  extraInfo: string | Promise<string>;
 };

 export type ProfileInfo = {
@@ -625,15 +627,17 @@ export class Engine implements TensorTracker, DataMover {
     }

     // Stop recording to a tape when running a kernel.
+    let kernelProfile: KernelProfile;
     this.scopedRun(
         () => this.state.kernelDepth++, () => this.state.kernelDepth--, () => {
-          if (!this.ENV.getBool('DEBUG')) {
+          if (!this.ENV.getBool('DEBUG') && !this.state.profiling) {
             outputs = kernelFunc();
           } else {
-            let kernelProfile: KernelProfile;
             kernelProfile = this.profiler.profileKernel(
                 kernelName, inputs, () => kernelFunc());
-            this.profiler.logKernelProfile(kernelProfile);
+            if (this.ENV.getBool('DEBUG')) {
+              this.profiler.logKernelProfile(kernelProfile);
+            }
             outputs = kernelProfile.outputs;
           }
         });
@@ -652,7 +656,9 @@ export class Engine implements TensorTracker, DataMover {
         totalTensorsSnapshot: this.state.numTensors,
         inputShapes: Object.keys(inputs).map(
             key => inputs[key] != null ? inputs[key].shape : null),
-        outputShapes: outputs.map(item => item.shape)
+        outputShapes: outputs.map(item => item.shape),
+        kernelTimeMs: kernelProfile.timeMs,
+        extraInfo: kernelProfile.extraInfo
       });
     }
     return (Array.isArray(out) ? outputs : outputs[0]) as T;
@@ -878,6 +884,10 @@ export class Engine implements TensorTracker, DataMover {
     this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
     this.state.activeProfile.newTensors =
         this.state.numTensors - startNumTensors;
+    for (const kernel of this.state.activeProfile.kernels) {
+      kernel.kernelTimeMs = await kernel.kernelTimeMs;
+      kernel.extraInfo = await kernel.extraInfo;
+    }
     return this.state.activeProfile;
   }

diff --git a/tfjs-core/src/engine_test.ts b/tfjs-core/src/engine_test.ts
index 937c69485d6..c35369f28fa 100644
--- a/tfjs-core/src/engine_test.ts
+++ b/tfjs-core/src/engine_test.ts
@@ -387,26 +387,40 @@ describeWithFlags('profile', ALL_ENVS, () => {
     expect(profile.peakBytes).toBe(24);
     expect(profile.newTensors).toBe(1);
     expectArraysClose(await result.data(), [1, 2, 3]);
-    expect(profile.kernels).toEqual([
-      {
-        'name': 'Square',
-        'bytesAdded': 12,
-        'totalBytesSnapshot': 24,
-        'tensorsAdded': 1,
-        'totalTensorsSnapshot': 2,
-        'inputShapes': [[3]],
-        'outputShapes': [[3]]
-      },
-      {
-        'name': 'Square',
-        'bytesAdded': 12,
-        'totalBytesSnapshot': 24,
-        'tensorsAdded': 1,
-        'totalTensorsSnapshot': 2,
-        'inputShapes': [[3]],
-        'outputShapes': [[3]]
-      }
-    ]);
+    expect(profile.kernels.length).toBe(2);
+
+    // Test the types for `kernelTimeMs` and `extraInfo` to confirm the promises
+    // are resolved.
+    expect(typeof profile.kernels[0].kernelTimeMs).toBe('number');
+    expect(typeof profile.kernels[0].extraInfo).toBe('string');
+    expect(typeof profile.kernels[1].kernelTimeMs).toBe('number');
+    expect(typeof profile.kernels[1].extraInfo).toBe('string');
+
+    // The specific values of `kernelTimeMs` and `extraInfo` are tested in the
+    // tests of Profiler.profileKernel, so their values are not tested here.
+    expect(profile.kernels[0]).toEqual({
+      'name': 'Square',
+      'bytesAdded': 12,
+      'totalBytesSnapshot': 24,
+      'tensorsAdded': 1,
+      'totalTensorsSnapshot': 2,
+      'inputShapes': [[3]],
+      'outputShapes': [[3]],
+      'kernelTimeMs': profile.kernels[0].kernelTimeMs,
+      'extraInfo': profile.kernels[0].extraInfo
+    });
+
+    expect(profile.kernels[1]).toEqual({
+      'name': 'Square',
+      'bytesAdded': 12,
+      'totalBytesSnapshot': 24,
+      'tensorsAdded': 1,
+      'totalTensorsSnapshot': 2,
+      'inputShapes': [[3]],
+      'outputShapes': [[3]],
+      'kernelTimeMs': profile.kernels[1].kernelTimeMs,
+      'extraInfo': profile.kernels[1].extraInfo
+    });
   });

   it('squaring without disposing', async () => {
@@ -422,15 +436,20 @@ describeWithFlags('profile', ALL_ENVS, () => {
     expect(profile.peakBytes).toBe(24);
     expect(profile.newTensors).toBe(2);
     expectArraysClose(await result.data(), [1, 4, 9]);
-    expect(profile.kernels).toEqual([{
+    expect(profile.kernels.length).toBe(1);
+    expect(typeof profile.kernels[0].kernelTimeMs).toBe('number');
+    expect(typeof profile.kernels[0].extraInfo).toBe('string');
+    expect(profile.kernels[0]).toEqual({
       'name': 'Square',
       'bytesAdded': 12,
       'totalBytesSnapshot': 24,
       'tensorsAdded': 1,
       'totalTensorsSnapshot': 2,
       'inputShapes': [[3]],
-      'outputShapes': [[3]]
-    }]);
+      'outputShapes': [[3]],
+      'kernelTimeMs': profile.kernels[0].kernelTimeMs,
+      'extraInfo': profile.kernels[0].extraInfo
+    });
   });

   it('squaring in async query', async () => {
@@ -448,15 +467,20 @@ describeWithFlags('profile', ALL_ENVS, () => {
     expect(profile.peakBytes).toBe(24);
     expect(profile.newTensors).toBe(1);
     expectArraysClose(await result.data(), [1, 2, 3]);
-    expect(profile.kernels).toEqual([{
+    expect(profile.kernels.length).toBe(1);
+    expect(typeof profile.kernels[0].kernelTimeMs).toBe('number');
+    expect(typeof profile.kernels[0].extraInfo).toBe('string');
+    expect(profile.kernels[0]).toEqual({
       'name': 'Square',
       'bytesAdded': 12,
       'totalBytesSnapshot': 24,
       'tensorsAdded': 1,
       'totalTensorsSnapshot': 2,
       'inputShapes': [[3]],
-      'outputShapes': [[3]]
-    }]);
+      'outputShapes': [[3]],
+      'kernelTimeMs': profile.kernels[0].kernelTimeMs,
+      'extraInfo': profile.kernels[0].extraInfo
+    });
   });
 });
