diff --git a/demos/benchmarks/conv_benchmarks.ts b/demos/benchmarks/conv_benchmarks.ts index e3beeff75e..870caaccd6 100644 --- a/demos/benchmarks/conv_benchmarks.ts +++ b/demos/benchmarks/conv_benchmarks.ts @@ -80,8 +80,8 @@ export class ConvGPUBenchmark extends ConvBenchmark { gpgpu.dispose(); }; - if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) { - gpgpu.runBenchmark(benchmark).then((timeElapsed: number) => { + if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) { + gpgpu.runQuery(benchmark).then((timeElapsed: number) => { delayedCleanup(); resolve(timeElapsed); }); diff --git a/demos/benchmarks/conv_transposed_benchmarks.ts b/demos/benchmarks/conv_transposed_benchmarks.ts index 9bdb4c34e1..bb5791d622 100644 --- a/demos/benchmarks/conv_transposed_benchmarks.ts +++ b/demos/benchmarks/conv_transposed_benchmarks.ts @@ -81,8 +81,8 @@ export class ConvTransposedGPUBenchmark extends ConvTransposedBenchmark { gpgpu.dispose(); }; - if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) { - gpgpu.runBenchmark(benchmark).then((timeElapsed: number) => { + if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) { + gpgpu.runQuery(benchmark).then((timeElapsed: number) => { delayedCleanup(); resolve(timeElapsed); }); diff --git a/demos/benchmarks/logsumexp_benchmarks.ts b/demos/benchmarks/logsumexp_benchmarks.ts index afed074153..a9add49e86 100644 --- a/demos/benchmarks/logsumexp_benchmarks.ts +++ b/demos/benchmarks/logsumexp_benchmarks.ts @@ -66,8 +66,8 @@ export class LogSumExpGPUBenchmark extends BenchmarkTest { gpgpu.dispose(); }; - if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) { - gpgpu.runBenchmark(benchmark).then((timeElapsed: number) => { + if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) { + gpgpu.runQuery(benchmark).then((timeElapsed: number) => { delayedCleanup(); resolve(timeElapsed); }); diff --git a/demos/benchmarks/matmul_benchmarks.ts b/demos/benchmarks/matmul_benchmarks.ts index 019223c845..60f61b61d2 100644 --- a/demos/benchmarks/matmul_benchmarks.ts +++ b/demos/benchmarks/matmul_benchmarks.ts @@ -80,8 +80,8 @@ export class MatmulGPUBenchmark extends BenchmarkTest { gpgpu.dispose(); }; - if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) { - gpgpu.runBenchmark(benchmark).then((timeElapsed: number) => { + if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) { + gpgpu.runQuery(benchmark).then((timeElapsed: number) => { delayedCleanup(); resolve(timeElapsed); }); diff --git a/demos/benchmarks/pool_benchmarks.ts b/demos/benchmarks/pool_benchmarks.ts index a1322ab2d7..2ddda60430 100644 --- a/demos/benchmarks/pool_benchmarks.ts +++ b/demos/benchmarks/pool_benchmarks.ts @@ -96,8 +96,8 @@ export class PoolGPUBenchmark extends PoolBenchmark { gpgpu.dispose(); }; - if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) { - gpgpu.runBenchmark(benchmark).then((timeElapsed: number) => { + if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) { + gpgpu.runQuery(benchmark).then((timeElapsed: number) => { delayedCleanup(); resolve(timeElapsed); }); diff --git a/demos/benchmarks/reduction_ops_benchmark.ts b/demos/benchmarks/reduction_ops_benchmark.ts index fce53d82a4..9f71c98101 100644 --- a/demos/benchmarks/reduction_ops_benchmark.ts +++ b/demos/benchmarks/reduction_ops_benchmark.ts @@ -75,8 +75,8 @@ export class ReductionOpsGPUBenchmark extends ReductionOpsBenchmark { math.dispose(); }; - if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) { - math.getGPGPUContext().runBenchmark(benchmark).then( + if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) { + math.getGPGPUContext().runQuery(benchmark).then( (timeElapsed: number) => { delayedCleanup(); resolve(timeElapsed); diff --git a/demos/benchmarks/unary_ops_benchmark.ts b/demos/benchmarks/unary_ops_benchmark.ts index 727a507564..ea0c6a7859 100644 --- a/demos/benchmarks/unary_ops_benchmark.ts +++ b/demos/benchmarks/unary_ops_benchmark.ts @@ -103,8 +103,8 @@ export class UnaryOpsGPUBenchmark extends UnaryOpsBenchmark { math.dispose(); }; - if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER')) { - math.getGPGPUContext().runBenchmark(benchmark).then( + if (ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) { + math.getGPGPUContext().runQuery(benchmark).then( (timeElapsed: number) => { delayedCleanup(); resolve(timeElapsed); diff --git a/src/environment.ts b/src/environment.ts index 3cdd505560..09332026fb 100644 --- a/src/environment.ts +++ b/src/environment.ts @@ -24,12 +24,18 @@ export enum Type { } export interface Features { - 'WEBGL_DISJOINT_QUERY_TIMER'?: boolean; + // Whether the disjoint_query_timer extension is an available extension. + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED'?: boolean; + // Whether the timer object from the disjoint_query_timer extension gives + // timing information that is reliable. + 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE'?: boolean; + // 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0. 'WEBGL_VERSION'?: number; } export const URL_PROPERTIES: URLProperty[] = [ - {name: 'WEBGL_DISJOINT_QUERY_TIMER', type: Type.BOOLEAN}, + {name: 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED', type: Type.BOOLEAN}, + {name: 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', type: Type.BOOLEAN}, {name: 'WEBGL_VERSION', type: Type.NUMBER} ]; @@ -38,9 +44,21 @@ export interface URLProperty { type: Type; } -function isWebGL2Enabled() { +function getWebGLRenderingContext(webGLVersion: number): WebGLRenderingContext { + if (webGLVersion === 0) { + throw new Error('Cannot get WebGL rendering context, WebGL is disabled.'); + } + const tempCanvas = document.createElement('canvas'); - const gl = tempCanvas.getContext('webgl2') as WebGLRenderingContext; + if (webGLVersion === 1) { + return (tempCanvas.getContext('webgl') || + tempCanvas.getContext('experimental-webgl')) as + WebGLRenderingContext; + } + return tempCanvas.getContext('webgl2') as WebGLRenderingContext; +} + +function loseContext(gl: WebGLRenderingContext) { if (gl != null) { const loseContextExtension = gl.getExtension('WEBGL_lose_context'); if (loseContextExtension == null) { @@ -48,40 +66,29 @@ function isWebGL2Enabled() { 'Extension WEBGL_lose_context not supported on this browser.'); } loseContextExtension.loseContext(); - return true; } - return false; } -function isWebGL1Enabled() { - const tempCanvas = document.createElement('canvas'); - const gl = - (tempCanvas.getContext('webgl') || - tempCanvas.getContext('experimental-webgl')) as WebGLRenderingContext; +function isWebGLVersionEnabled(webGLVersion: 1|2) { + const gl = getWebGLRenderingContext(webGLVersion); if (gl != null) { - const loseContextExtension = gl.getExtension('WEBGL_lose_context'); - if (loseContextExtension == null) { - throw new Error( - 'Extension WEBGL_lose_context not supported on this browser.'); - } - loseContextExtension.loseContext(); + loseContext(gl); return true; } return false; } -function evaluateFeature(feature: K): Features[K] { - if (feature === 'WEBGL_DISJOINT_QUERY_TIMER') { - return !device_util.isMobile(); - } else if (feature === 'WEBGL_VERSION') { - if (isWebGL2Enabled()) { - return 2; - } else if (isWebGL1Enabled()) { - return 1; - } - return 0; +function isWebGLDisjointQueryTimerEnabled(webGLVersion: number) { + const gl = getWebGLRenderingContext(webGLVersion); + + const extensionName = webGLVersion === 1 ? 'EXT_disjoint_timer_query' : + 'EXT_disjoint_timer_query_webgl2'; + const ext = gl.getExtension(extensionName); + const isExtEnabled = ext != null; + if (gl != null) { + loseContext(gl); } - throw new Error(`Unknown feature ${feature}.`); + return isExtEnabled; } export class Environment { @@ -98,10 +105,33 @@ export class Environment { return this.features[feature]; } - this.features[feature] = evaluateFeature(feature); + this.features[feature] = this.evaluateFeature(feature); return this.features[feature]; } + + private evaluateFeature(feature: K): Features[K] { + if (feature === 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED') { + const webGLVersion = this.get('WEBGL_VERSION'); + + if (webGLVersion === 0) { + return false; + } + + return isWebGLDisjointQueryTimerEnabled(webGLVersion); + } else if (feature === 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') { + return this.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED') && + !device_util.isMobile(); + } else if (feature === 'WEBGL_VERSION') { + if (isWebGLVersionEnabled(2)) { + return 2; + } else if (isWebGLVersionEnabled(1)) { + return 1; + } + return 0; + } + throw new Error(`Unknown feature ${feature}.`); + } } // Expects flags from URL in the format ?dljsflags=FLAG1:1,FLAG2:true. diff --git a/src/environment_test.ts b/src/environment_test.ts index b66b00b368..8a53ce0212 100644 --- a/src/environment_test.ts +++ b/src/environment_test.ts @@ -15,23 +15,101 @@ * ============================================================================= */ import * as device_util from './device_util'; -import {Environment} from './environment'; +import {Environment, Features} from './environment'; -describe('disjoint query timer', () => { - it('mobile', () => { +describe('disjoint query timer enabled', () => { + it('no webgl', () => { + const features: Features = {'WEBGL_VERSION': 0}; + + const env = new Environment(features); + + expect(env.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED')).toBe(false); + }); + + it('webgl 1', () => { + const features: Features = {'WEBGL_VERSION': 1}; + + spyOn(document, 'createElement').and.returnValue({ + getContext: (context: string) => { + if (context === 'webgl' || context === 'experimental-webgl') { + return { + getExtension: (extensionName: string) => { + if (extensionName === 'EXT_disjoint_timer_query') { + return {}; + } else if (extensionName === 'WEBGL_lose_context') { + return {loseContext: () => {}}; + } + return null; + } + }; + } + return null; + } + }); + + const env = new Environment(features); + + expect(env.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED')).toBe(true); + }); + + it('webgl 2', () => { + const features: Features = {'WEBGL_VERSION': 2}; + + spyOn(document, 'createElement').and.returnValue({ + getContext: (context: string) => { + if (context === 'webgl2') { + return { + getExtension: (extensionName: string) => { + if (extensionName === 'EXT_disjoint_timer_query_webgl2') { + return {}; + } else if (extensionName === 'WEBGL_lose_context') { + return {loseContext: () => {}}; + } + return null; + } + }; + } + return null; + } + }); + + const env = new Environment(features); + + expect(env.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED')).toBe(true); + }); + + +}); +describe('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', () => { + it('disjoint query timer disabled', () => { + const features: + Features = {'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED': false}; + + const env = new Environment(features); + + expect(env.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) + .toBe(false); + }); + + it('disjoint query timer enabled, mobile', () => { + const features: + Features = {'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED': true}; spyOn(device_util, 'isMobile').and.returnValue(true); - const env = new Environment(); + const env = new Environment(features); - expect(env.get('WEBGL_DISJOINT_QUERY_TIMER')).toBe(false); + expect(env.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')) + .toBe(false); }); - it('not mobile', () => { + it('disjoint query timer enabled, not mobile', () => { + const features: + Features = {'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED': true}; spyOn(device_util, 'isMobile').and.returnValue(false); - const env = new Environment(); + const env = new Environment(features); - expect(env.get('WEBGL_DISJOINT_QUERY_TIMER')).toBe(true); + expect(env.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE')).toBe(true); }); }); diff --git a/src/math/ndarray.ts b/src/math/ndarray.ts index f9b140dc81..508b4f5c55 100644 --- a/src/math/ndarray.ts +++ b/src/math/ndarray.ts @@ -15,6 +15,7 @@ * ============================================================================= */ +import {ENV} from '../environment'; import * as util from '../util'; import {GPGPUContext} from './webgl/gpgpu_context'; @@ -243,6 +244,27 @@ export class NDArray { return this.data.values; } + getValuesAsync(): Promise { + return new Promise((resolve, reject) => { + if (this.data.values != null) { + resolve(this.data.values); + return; + } + + if (!ENV.get('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_ENABLED')) { + resolve(this.getValues()); + return; + } + + // Construct an empty query. We're just interested in getting a callback + // when the GPU command queue has executed until this point in time. + const queryFn = () => {}; + GPGPU.runQuery(queryFn).then(() => { + resolve(this.getValues()); + }); + }); + } + private uploadToGPU(preferredTexShape?: [number, number]) { throwIfGPUNotInitialized(); this.data.textureShapeRC = webgl_util.getTextureShapeFromLogicalShape( diff --git a/src/math/ndarray_test.ts b/src/math/ndarray_test.ts index 5ca7ce01e7..ff65602dca 100644 --- a/src/math/ndarray_test.ts +++ b/src/math/ndarray_test.ts @@ -127,7 +127,7 @@ describe('NDArray', () => { expect(a.get(1, 2)).toBe(10); }); - it('NDArray CPU --> GPU', () => { + it('NDArray getValues CPU --> GPU', () => { const a = Array2D.new([3, 2], [1, 2, 3, 4, 5, 6]); expect(a.inGPU()).toBe(false); @@ -143,7 +143,7 @@ describe('NDArray', () => { a.dispose(); }); - it('NDArray GPU --> CPU', () => { + it('NDArray getValues GPU --> CPU', () => { const texture = textureManager.acquireTexture([3, 2]); gpgpu.uploadMatrixToTexture( texture, 3, 2, new Float32Array([1, 2, 3, 4, 5, 6])); @@ -155,6 +155,40 @@ describe('NDArray', () => { expect(a.inGPU()).toBe(false); }); + it('NDArray getValuesAsync CPU --> GPU', (doneFn) => { + const a = Array2D.new([3, 2], [1, 2, 3, 4, 5, 6]); + + expect(a.inGPU()).toBe(false); + + a.getValuesAsync().then(values => { + expect(values).toEqual(new Float32Array([1, 2, 3, 4, 5, 6])); + + expect(a.inGPU()).toBe(false); + + // Upload to GPU. + expect(a.getTexture() != null).toBe(true); + + expect(a.inGPU()).toBe(true); + a.dispose(); + doneFn(); + }); + }); + + it('NDArray getValuesAsync GPU --> CPU', (doneFn) => { + const texture = textureManager.acquireTexture([3, 2]); + gpgpu.uploadMatrixToTexture( + texture, 3, 2, new Float32Array([1, 2, 3, 4, 5, 6])); + + const a = new Array2D([3, 2], {texture, textureShapeRC: [3, 2]}); + expect(a.inGPU()).toBe(true); + + a.getValuesAsync().then(values => { + expect(values).toEqual(new Float32Array([1, 2, 3, 4, 5, 6])); + expect(a.inGPU()).toBe(false); + doneFn(); + }); + }); + it('Scalar basic methods', () => { const a = Scalar.new(5); expect(a.get()).toBe(5); diff --git a/src/math/webgl/gpgpu_context.ts b/src/math/webgl/gpgpu_context.ts index cbffb9d098..66fa1a8601 100644 --- a/src/math/webgl/gpgpu_context.ts +++ b/src/math/webgl/gpgpu_context.ts @@ -259,14 +259,20 @@ export class GPGPUContext { webgl_util.callAndCheck(this.gl, () => this.gl.finish()); } - public runBenchmark(benchmark: () => void): Promise { + /** + * Executes a query function which contains GL commands and resolves when + * the command buffer has finished executing the query. + * @param queryFn The query function containing GL commands to execute. + * @return a promise that resolves with the ellapsed time in milliseconds. + */ + public runQuery(queryFn: () => void): Promise { if (ENV.get('WEBGL_VERSION') === 2) { - return this.runBenchmarkWebGL2(benchmark); + return this.runQueryWebGL2(queryFn); } - return this.runBenchmarkWebGL1(benchmark); + return this.runQueryWebGL1(queryFn); } - private runBenchmarkWebGL2(benchmark: () => void): Promise { + private runQueryWebGL2(benchmark: () => void): Promise { const ext = webgl_util.getExtensionOrThrow( this.gl, 'EXT_disjoint_timer_query_webgl2'); // tslint:disable-next-line:no-any @@ -296,13 +302,13 @@ export class GPGPUContext { }; const getTimeElapsed = () => { - const timeElapsed = + const timeElapsedNanos = // tslint:disable-next-line:no-any (this.gl as any) // tslint:disable-next-line:no-any .getQueryParameter(query, (this.gl as any).QUERY_RESULT); // Return milliseconds. - resolve(timeElapsed / 1000000); + resolve(timeElapsedNanos / 1000000); }; const resolveWithWarning = () => { @@ -310,14 +316,11 @@ export class GPGPUContext { resolve(-1); }; - const maxBackoffMs = 2048; - util.tryWithBackoff(queryGPU, maxBackoffMs) - .then(getTimeElapsed) - .catch(resolveWithWarning); + util.repeatedTry(queryGPU).then(getTimeElapsed).catch(resolveWithWarning); }); } - private runBenchmarkWebGL1(benchmark: () => void): Promise { + private runQueryWebGL1(benchmark: () => void): Promise { const ext = webgl_util.getExtensionOrThrow( // tslint:disable-next-line:no-any this.gl, 'EXT_disjoint_timer_query') as any; @@ -340,9 +343,10 @@ export class GPGPUContext { }; const getTimeElapsed = () => { - const timeElapsed = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT); + const timeElapsedNanos = + ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT); // Return milliseconds. - resolve(timeElapsed / 1000000); + resolve(timeElapsedNanos / 1000000); }; const resolveWithWarning = () => { @@ -350,10 +354,7 @@ export class GPGPUContext { resolve(-1); }; - const maxBackoffMs = 2048; - util.tryWithBackoff(queryGPU, maxBackoffMs) - .then(getTimeElapsed) - .catch(resolveWithWarning); + util.repeatedTry(queryGPU).then(getTimeElapsed).catch(resolveWithWarning); }); } diff --git a/src/util.ts b/src/util.ts index b4ed92e01c..ea3bca537d 100644 --- a/src/util.ts +++ b/src/util.ts @@ -222,8 +222,9 @@ export function rightPad(a: string, size: number): string { return a + ' '.repeat(size - a.length); } -export function tryWithBackoff( - checkFn: () => boolean, maxBackoffMs: number): Promise { +export function repeatedTry( + checkFn: () => boolean, delayFn = (counter: number) => 0, + maxCounter?: number): Promise { return new Promise((resolve, reject) => { let tryCount = 0; @@ -235,9 +236,9 @@ export function tryWithBackoff( tryCount++; - const nextBackoff = Math.pow(2, tryCount); + const nextBackoff = delayFn(tryCount); - if (nextBackoff >= maxBackoffMs) { + if (maxCounter != null && tryCount >= maxCounter) { reject(); return; } diff --git a/src/util_test.ts b/src/util_test.ts index c37dc2c95a..c1b55e38bd 100644 --- a/src/util_test.ts +++ b/src/util_test.ts @@ -147,7 +147,7 @@ describe('util.getBroadcastedShape', () => { }); }); -describe('util.tryWithBackoff', () => { +describe('util.repeatedTry', () => { it('resolves', (doneFn) => { let counter = 0; const checkFn = () => { @@ -158,14 +158,14 @@ describe('util.tryWithBackoff', () => { return false; }; - util.tryWithBackoff(checkFn, 128).then(doneFn).catch(() => { + util.repeatedTry(checkFn).then(doneFn).catch(() => { throw new Error('Rejected backoff.'); }); }); it('rejects', (doneFn) => { const checkFn = () => false; - util.tryWithBackoff(checkFn, 32) + util.repeatedTry(checkFn, () => 0, 5) .then(() => { throw new Error('Backoff resolved'); })