diff --git a/tfjs-core/src/profiler.ts b/tfjs-core/src/profiler.ts index eb04520f082..76ea7dfbd93 100644 --- a/tfjs-core/src/profiler.ts +++ b/tfjs-core/src/profiler.ts @@ -110,12 +110,14 @@ export class Logger { for (const name in inputs) { const input = inputs[name]; - // The input might be a non-tensor (e.g HTMLImageElement), in which case - // we claim the output shape as input shape. - const inputShape = input.shape || result.shape; - const inputRank = inputShape.length; - inputShapesDescription += - `${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `; + if (input != null) { + // The input might be a non-tensor (e.g HTMLImageElement), in which case + // we claim the output shape as input shape. + const inputShape = input.shape || result.shape; + const inputRank = inputShape.length; + inputShapesDescription += + `${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `; + } } console.log( diff --git a/tfjs-core/src/profiler_test.ts b/tfjs-core/src/profiler_test.ts index f1e180c5a3f..2b8648d8dd2 100644 --- a/tfjs-core/src/profiler_test.ts +++ b/tfjs-core/src/profiler_test.ts @@ -17,7 +17,7 @@ import {BackendTimer, BackendTimingInfo} from './backends/backend'; import * as tf from './index'; -import {describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util'; +import {ALL_ENVS, describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util'; import {checkComputationForErrors, KernelProfile, Logger, Profiler} from './profiler'; import {Tensor} from './tensor'; import {NamedTensorMap} from './tensor_types'; @@ -217,3 +217,28 @@ describe('profiler.checkComputationForErrors', () => { .toBe(false); }); }); + +describeWithFlags('profiler.Logger', ALL_ENVS, () => { + it('skips logging for undefined input node in input tensor map', () => { + const kernelName = 'FusedConv2D'; + const vals = new Float32Array(1); + const outputs = tf.tensor1d([1]); + const timeMs = 10; + const inputs: NamedTensorMap = { + 'x': tf.tensor1d([1]), + 'filter': tf.tensor1d([1]), + 'bias': tf.tensor1d([1]), + 'preluActivationWeights': undefined + }; + const extraInfo = ''; + const logger = new Logger(); + spyOn(console, 'log'); + const consoleLogSpy = console.log as jasmine.Spy; + + logger.logKernelProfile( + kernelName, outputs, vals, timeMs, inputs, extraInfo); + + expect(consoleLogSpy.calls.first().args) + .not.toContain('preluActivationWeights'); + }); +});