diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts index b2d26a1be8b..17e5f58c81c 100644 --- a/tfjs-core/src/engine.ts +++ b/tfjs-core/src/engine.ts @@ -635,7 +635,8 @@ export class Engine implements TensorTracker, DataMover { totalBytesSnapshot: this.state.numBytes, tensorsAdded: this.state.numTensors - startingNumTensors, totalTensorsSnapshot: this.state.numTensors, - inputShapes: Object.keys(inputs).map(key => inputs[key].shape), + inputShapes: Object.keys(inputs).map( + key => inputs[key] != null ? inputs[key].shape : null), outputShapes: outputs.map(item => item.shape) }); } diff --git a/tfjs-core/src/profiler_test.ts b/tfjs-core/src/profiler_test.ts index 7d49cfedcaa..e75105c32fb 100644 --- a/tfjs-core/src/profiler_test.ts +++ b/tfjs-core/src/profiler_test.ts @@ -82,10 +82,11 @@ describeWithFlags('profiler.Profiler', SYNC_BACKEND_ENVS, () => { }, delayMs * 2); }); - it('profiles nested kernel', doneFn => { + it('profiles nested kernel with optional inputs', doneFn => { const delayMs = 5; const queryTimeMs = 10; - const inputs = {'x': tf.tensor1d([1])}; + const inputs: {'x': tf.Tensor, + 'bias': null} = {'x': tf.tensor1d([1]), 'bias': null}; const extraInfo = ''; const timer = new TestBackendTimer(delayMs, queryTimeMs, extraInfo); const logger = new TestLogger();