diff --git a/tfjs-core/src/backends/backend.ts b/tfjs-core/src/backends/backend.ts index fd3dbbaab3d..aeb9ab3897a 100644 --- a/tfjs-core/src/backends/backend.ts +++ b/tfjs-core/src/backends/backend.ts @@ -25,7 +25,7 @@ export const EPSILON_FLOAT16 = 1e-4; // Required information for all backends. export interface BackendTimingInfo { - kernelMs: number; + kernelMs: number|{error: string}; getExtraProfileInfo?(): string; // a field for additional timing information // e.g. packing / unpacking for WebGL backend } diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index 2bce838b3aa..afa45013a53 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -520,18 +520,27 @@ export class MathBackendWebGL extends KernelBackend { this.programTimersStack = null; } - const kernelMs = await Promise.all(flattenedActiveTimerQueries); - const res: WebGLTimingInfo = { uploadWaitMs: this.uploadWaitMs, downloadWaitMs: this.downloadWaitMs, - kernelMs: util.sum(kernelMs), - getExtraProfileInfo: () => - kernelMs.map((d, i) => ({name: flattenedActiveTimerNames[i], ms: d})) - .map(d => `${d.name}: ${d.ms}`) - .join(', '), + kernelMs: null, wallMs: null // will be filled by the engine }; + + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) { + const kernelMs = await Promise.all(flattenedActiveTimerQueries); + + res['kernelMs'] = util.sum(kernelMs); + res['getExtraProfileInfo'] = () => + kernelMs.map((d, i) => ({name: flattenedActiveTimerNames[i], ms: d})) + .map(d => `${d.name}: ${d.ms}`) + .join(', '); + } else { + res['kernelMs'] = { + error: 'WebGL query timers are not supported in this environment.' + }; + } + this.uploadWaitMs = 0; this.downloadWaitMs = 0; return res; @@ -542,14 +551,14 @@ export class MathBackendWebGL extends KernelBackend { } private startTimer(): WebGLQuery|CPUTimerQuery { - if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) { return this.gpgpu.beginQuery(); } return {startMs: util.now(), endMs: null}; } private endTimer(query: WebGLQuery|CPUTimerQuery): WebGLQuery|CPUTimerQuery { - if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) { this.gpgpu.endQuery(); return query; } @@ -558,7 +567,7 @@ export class MathBackendWebGL extends KernelBackend { } private async getQueryTime(query: WebGLQuery|CPUTimerQuery): Promise { - if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0) { return this.gpgpu.waitForQueryAndGetTime(query as WebGLQuery); } const timerQuery = query as CPUTimerQuery; diff --git a/tfjs-core/src/globals.ts b/tfjs-core/src/globals.ts index 1e6f7122c7d..dd44d50a336 100644 --- a/tfjs-core/src/globals.ts +++ b/tfjs-core/src/globals.ts @@ -239,7 +239,9 @@ export function keep(result: T): T { * The result is an object with the following properties: * * - `wallMs`: Wall execution time. - * - `kernelMs`: Kernel execution time, ignoring data transfer. + * - `kernelMs`: Kernel execution time, ignoring data transfer. If using the + * WebGL backend and the query timer extension is not available, this will + * return an error object. * - On `WebGL` The following additional properties exist: * - `uploadWaitMs`: CPU blocking time on texture uploads. * - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels). diff --git a/tfjs-core/src/profiler.ts b/tfjs-core/src/profiler.ts index 156336e97ee..e837459d754 100644 --- a/tfjs-core/src/profiler.ts +++ b/tfjs-core/src/profiler.ts @@ -77,9 +77,11 @@ export function checkComputationForErrors( export class Logger { logKernelProfile( - name: string, result: Tensor, vals: TypedArray, timeMs: number, - inputs: NamedTensorMap, extraInfo?: string) { - const time = util.rightPad(`${timeMs}ms`, 9); + name: string, result: Tensor, vals: TypedArray, + timeMs: number|{error: string}, inputs: NamedTensorMap, + extraInfo?: string) { + const time = typeof timeMs === 'number' ? util.rightPad(`${timeMs}ms`, 9) : + timeMs['error']; const paddedName = util.rightPad(name, 25); const rank = result.rank; const size = result.size; diff --git a/tfjs-core/src/profiler_test.ts b/tfjs-core/src/profiler_test.ts index dddae0d0664..0f73753ff87 100644 --- a/tfjs-core/src/profiler_test.ts +++ b/tfjs-core/src/profiler_test.ts @@ -16,6 +16,7 @@ */ import {BackendTimer, BackendTimingInfo} from './backends/backend'; +import {WEBGL_ENVS} from './backends/webgl/backend_webgl_test_registry'; import * as tf from './index'; import {describeWithFlags, SYNC_BACKEND_ENVS} from './jasmine_util'; import {checkComputationForErrors, Logger, Profiler} from './profiler'; @@ -130,6 +131,34 @@ describeWithFlags('profiler.Profiler', SYNC_BACKEND_ENVS, () => { }); }); +describeWithFlags('profiling WebGL', WEBGL_ENVS, () => { + it('If query timer extension is unavailable, profiler ' + + 'reports an error for kernelMs.', + doneFn => { + const savedQueryReliableValue = + tf.env().get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE'); + tf.env().set('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', false); + + const logger = new TestLogger(); + const profiler = new Profiler(tf.backend(), logger); + + spyOn(logger, 'logKernelProfile').and.callThrough(); + const logKernelProfileSpy = logger.logKernelProfile as jasmine.Spy; + + profiler.profileKernel( + 'MatMul', {'x': tf.tensor1d([1])}, () => [tf.scalar(1)]); + + setTimeout(() => { + expect(logKernelProfileSpy.calls.first().args[3]['error']) + .toBeDefined(); + tf.env().set( + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', + savedQueryReliableValue); + doneFn(); + }, 0); + }); +}); + describe('profiler.checkComputationForErrors', () => { beforeAll(() => { // Silence warnings.