Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tfjs-core/src/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ export const EPSILON_FLOAT16 = 1e-4;

// Required information for all backends.
export interface BackendTimingInfo {
kernelMs: number;
kernelMs: number|{error: string};
getExtraProfileInfo?(): string; // a field for additional timing information
// e.g. packing / unpacking for WebGL backend
}
Expand Down
29 changes: 19 additions & 10 deletions tfjs-core/src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -520,18 +520,27 @@ export class MathBackendWebGL extends KernelBackend {
this.programTimersStack = null;
}

const kernelMs = await Promise.all(flattenedActiveTimerQueries);

const res: WebGLTimingInfo = {
uploadWaitMs: this.uploadWaitMs,
downloadWaitMs: this.downloadWaitMs,
kernelMs: util.sum(kernelMs),
getExtraProfileInfo: () =>
kernelMs.map((d, i) => ({name: flattenedActiveTimerNames[i], ms: d}))
.map(d => `${d.name}: ${d.ms}`)
.join(', '),
kernelMs: null,
wallMs: null // will be filled by the engine
};

if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
const kernelMs = await Promise.all(flattenedActiveTimerQueries);

res['kernelMs'] = util.sum(kernelMs);
res['getExtraProfileInfo'] = () =>
kernelMs.map((d, i) => ({name: flattenedActiveTimerNames[i], ms: d}))
.map(d => `${d.name}: ${d.ms}`)
.join(', ');
} else {
res['kernelMs'] = {
error: 'WebGL query timers are not supported in this environment.'
};
}

this.uploadWaitMs = 0;
this.downloadWaitMs = 0;
return res;
Expand All @@ -542,14 +551,14 @@ export class MathBackendWebGL extends KernelBackend {
}

private startTimer(): WebGLQuery|CPUTimerQuery {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
return this.gpgpu.beginQuery();
}
return {startMs: util.now(), endMs: null};
}

private endTimer(query: WebGLQuery|CPUTimerQuery): WebGLQuery|CPUTimerQuery {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
this.gpgpu.endQuery();
return query;
}
Expand All @@ -558,7 +567,7 @@ export class MathBackendWebGL extends KernelBackend {
}

private async getQueryTime(query: WebGLQuery|CPUTimerQuery): Promise<number> {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) {
return this.gpgpu.waitForQueryAndGetTime(query as WebGLQuery);
}
const timerQuery = query as CPUTimerQuery;
Expand Down
4 changes: 3 additions & 1 deletion tfjs-core/src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,9 @@ export function keep<T extends Tensor>(result: T): T {
* The result is an object with the following properties:
*
* - `wallMs`: Wall execution time.
* - `kernelMs`: Kernel execution time, ignoring data transfer.
* - `kernelMs`: Kernel execution time, ignoring data transfer. If using the
* WebGL backend and the query timer extension is not available, this will
* return an error object.
* - On `WebGL` The following additional properties exist:
* - `uploadWaitMs`: CPU blocking time on texture uploads.
* - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels).
Expand Down
8 changes: 5 additions & 3 deletions tfjs-core/src/profiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,11 @@ export function checkComputationForErrors<D extends DataType>(

export class Logger {
logKernelProfile(
name: string, result: Tensor, vals: TypedArray, timeMs: number,
inputs: NamedTensorMap, extraInfo?: string) {
const time = util.rightPad(`${timeMs}ms`, 9);
name: string, result: Tensor, vals: TypedArray,
timeMs: number|{error: string}, inputs: NamedTensorMap,
extraInfo?: string) {
const time = typeof timeMs === 'number' ? util.rightPad(`${timeMs}ms`, 9) :
timeMs['error'];
const paddedName = util.rightPad(name, 25);
const rank = result.rank;
const size = result.size;
Expand Down
29 changes: 29 additions & 0 deletions tfjs-core/src/profiler_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/

import {BackendTimer, BackendTimingInfo} from './backends/backend';
import {WEBGL_ENVS} from './backends/webgl/backend_webgl_test_registry';
import * as tf from './index';
import {describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util';
import {checkComputationForErrors, Logger, Profiler} from './profiler';
Expand Down Expand Up @@ -130,6 +131,34 @@ describeWithFlags('profiler.Profiler', SYNC_BACKEND_ENVS, () => {
});
});

describeWithFlags('profiling WebGL', WEBGL_ENVS, () => {
it('If query timer extension is unavailable, profiler ' +
'reports an error for kernelMs.',
doneFn => {
const savedQueryReliableValue =
tf.env().get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE');
tf.env().set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', false);

const logger = new TestLogger();
const profiler = new Profiler(tf.backend(), logger);

spyOn(logger, 'logKernelProfile').and.callThrough();
const logKernelProfileSpy = logger.logKernelProfile as jasmine.Spy;

profiler.profileKernel(
'MatMul', {'x': tf.tensor1d([1])}, () => [tf.scalar(1)]);

setTimeout(() => {
expect(logKernelProfileSpy.calls.first().args[3]['error'])
.toBeDefined();
tf.env().set(
'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE',
savedQueryReliableValue);
doneFn();
}, 0);
});
});

describe('profiler.checkComputationForErrors', () => {
beforeAll(() => {
// Silence warnings.
Expand Down