From 6329b2b2cd27cf1c0e4ca711cd9e4d3d1120d8de Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 30 Jul 2020 11:44:15 -0400 Subject: [PATCH 1/2] add --- tfjs-core/src/engine.ts | 3 ++- tfjs-core/src/profiler_test.ts | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts index b2d26a1be8b..9b7f974a2af 100644 --- a/tfjs-core/src/engine.ts +++ b/tfjs-core/src/engine.ts @@ -635,7 +635,8 @@ export class Engine implements TensorTracker, DataMover { totalBytesSnapshot: this.state.numBytes, tensorsAdded: this.state.numTensors - startingNumTensors, totalTensorsSnapshot: this.state.numTensors, - inputShapes: Object.keys(inputs).map(key => inputs[key].shape), + inputShapes: Object.keys(inputs).map( + key => inputs[key] != null ? inputs[key].shape : []), outputShapes: outputs.map(item => item.shape) }); } diff --git a/tfjs-core/src/profiler_test.ts b/tfjs-core/src/profiler_test.ts index 7d49cfedcaa..e75105c32fb 100644 --- a/tfjs-core/src/profiler_test.ts +++ b/tfjs-core/src/profiler_test.ts @@ -82,10 +82,11 @@ describeWithFlags('profiler.Profiler', SYNC_BACKEND_ENVS, () => { }, delayMs * 2); }); - it('profiles nested kernel', doneFn => { + it('profiles nested kernel with optional inputs', doneFn => { const delayMs = 5; const queryTimeMs = 10; - const inputs = {'x': tf.tensor1d([1])}; + const inputs: {'x': tf.Tensor, + 'bias': null} = {'x': tf.tensor1d([1]), 'bias': null}; const extraInfo = ''; const timer = new TestBackendTimer(delayMs, queryTimeMs, extraInfo); const logger = new TestLogger(); From 84bd010cfbd044b1e668cf62e18d0e31e3d85bdd Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 30 Jul 2020 13:19:27 -0400 Subject: [PATCH 2/2] fix --- tfjs-core/src/engine.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts index 9b7f974a2af..17e5f58c81c 100644 --- a/tfjs-core/src/engine.ts +++ b/tfjs-core/src/engine.ts @@ -636,7 +636,7 @@ export class Engine implements TensorTracker, DataMover { tensorsAdded: this.state.numTensors - startingNumTensors, totalTensorsSnapshot: this.state.numTensors, inputShapes: Object.keys(inputs).map( - key => inputs[key] != null ? inputs[key].shape : []), + key => inputs[key] != null ? inputs[key].shape : null), outputShapes: outputs.map(item => item.shape) }); }