Skip to content

Commit 06b5ec8

Browse files
authored
[core] Account for optional inputs in profiler.Logger (#3769)
BUG
1 parent 188718f commit 06b5ec8

File tree

2 files changed

+34
-7
lines changed

2 files changed

+34
-7
lines changed

tfjs-core/src/profiler.ts

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,14 @@ export class Logger {
110110

111111
for (const name in inputs) {
112112
const input = inputs[name];
113-
// The input might be a non-tensor (e.g HTMLImageElement), in which case
114-
// we claim the output shape as input shape.
115-
const inputShape = input.shape || result.shape;
116-
const inputRank = inputShape.length;
117-
inputShapesDescription +=
118-
`${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `;
113+
if (input != null) {
114+
// The input might be a non-tensor (e.g HTMLImageElement), in which case
115+
// we claim the output shape as input shape.
116+
const inputShape = input.shape || result.shape;
117+
const inputRank = inputShape.length;
118+
inputShapesDescription +=
119+
`${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `;
120+
}
119121
}
120122

121123
console.log(

tfjs-core/src/profiler_test.ts

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import {BackendTimer, BackendTimingInfo} from './backends/backend';
1919
import * as tf from './index';
20-
import {describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util';
20+
import {ALL_ENVS, describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util';
2121
import {checkComputationForErrors, KernelProfile, Logger, Profiler} from './profiler';
2222
import {Tensor} from './tensor';
2323
import {NamedTensorMap} from './tensor_types';
@@ -217,3 +217,28 @@ describe('profiler.checkComputationForErrors', () => {
217217
.toBe(false);
218218
});
219219
});
220+
221+
describeWithFlags('profiler.Logger', ALL_ENVS, () => {
222+
it('skips logging for undefined input node in input tensor map', () => {
223+
const kernelName = 'FusedConv2D';
224+
const vals = new Float32Array(1);
225+
const outputs = tf.tensor1d([1]);
226+
const timeMs = 10;
227+
const inputs: NamedTensorMap = {
228+
'x': tf.tensor1d([1]),
229+
'filter': tf.tensor1d([1]),
230+
'bias': tf.tensor1d([1]),
231+
'preluActivationWeights': undefined
232+
};
233+
const extraInfo = '';
234+
const logger = new Logger();
235+
spyOn(console, 'log');
236+
const consoleLogSpy = console.log as jasmine.Spy;
237+
238+
logger.logKernelProfile(
239+
kernelName, outputs, vals, timeMs, inputs, extraInfo);
240+
241+
expect(consoleLogSpy.calls.first().args)
242+
.not.toContain('preluActivationWeights');
243+
});
244+
});

0 commit comments

Comments
 (0)