Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions tfjs-core/src/profiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,14 @@ export class Logger {

for (const name in inputs) {
const input = inputs[name];
// The input might be a non-tensor (e.g HTMLImageElement), in which case
// we claim the output shape as input shape.
const inputShape = input.shape || result.shape;
const inputRank = inputShape.length;
inputShapesDescription +=
`${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `;
if (input != null) {
// The input might be a non-tensor (e.g HTMLImageElement), in which case
// we claim the output shape as input shape.
const inputShape = input.shape || result.shape;
const inputRank = inputShape.length;
inputShapesDescription +=
`${name}: ${inputRank}D ${inputRank > 0 ? inputShape : ''} `;
}
}

console.log(
Expand Down
27 changes: 26 additions & 1 deletion tfjs-core/src/profiler_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import {BackendTimer, BackendTimingInfo} from './backends/backend';
import * as tf from './index';
import {describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util';
import {ALL_ENVS, describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util';
import {checkComputationForErrors, KernelProfile, Logger, Profiler} from './profiler';
import {Tensor} from './tensor';
import {NamedTensorMap} from './tensor_types';
Expand Down Expand Up @@ -217,3 +217,28 @@ describe('profiler.checkComputationForErrors', () => {
.toBe(false);
});
});

describeWithFlags('profiler.Logger', ALL_ENVS, () => {
it('skips logging for undefined input node in input tensor map', () => {
const kernelName = 'FusedConv2D';
const vals = new Float32Array(1);
const outputs = tf.tensor1d([1]);
const timeMs = 10;
const inputs: NamedTensorMap = {
'x': tf.tensor1d([1]),
'filter': tf.tensor1d([1]),
'bias': tf.tensor1d([1]),
'preluActivationWeights': undefined
};
const extraInfo = '';
const logger = new Logger();
spyOn(console, 'log');
const consoleLogSpy = console.log as jasmine.Spy;

logger.logKernelProfile(
kernelName, outputs, vals, timeMs, inputs, extraInfo);

expect(consoleLogSpy.calls.first().args)
.not.toContain('preluActivationWeights');
});
});