diff --git a/tfjs-core/src/backends/webgl/backend_webgl.ts b/tfjs-core/src/backends/webgl/backend_webgl.ts index 128a42feee3..492f3ac6d59 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl.ts @@ -472,12 +472,12 @@ export class MathBackendWebGL implements KernelBackend { } let buffer = null; + let tmpDownloadTarget: TensorHandle; + if (dtype !== 'complex64' && ENV.get('WEBGL_BUFFER_SUPPORTED')) { // Possibly copy the texture into a buffer before inserting a fence. - const tmpTarget = this.decode(dataId); - - dataId = tmpTarget.dataId; - const tmpData = this.texData.get(tmpTarget.dataId); + tmpDownloadTarget = this.decode(dataId); + const tmpData = this.texData.get(tmpDownloadTarget.dataId); buffer = this.gpgpu.createBufferFromTexture( tmpData.texture, ...tex_util.getDenseTexShape(shape)); @@ -503,9 +503,10 @@ export class MathBackendWebGL implements KernelBackend { vals = this.getValuesFromTexture(dataId); } else { const size = util.sizeFromShape(shape); - vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size); - this.disposeData(dataId); + } + if (tmpDownloadTarget != null) { + this.disposeData(tmpDownloadTarget.dataId); } const dTypeVals = this.convertAndCacheOnCPU(dataId, vals); @@ -690,6 +691,14 @@ export class MathBackendWebGL implements KernelBackend { return this.texData.get(dataId).texture; } + /** + * Returns internal information for the specific data bucket. Used in unit + * tests. + */ + getDataInfo(dataId: DataId): TextureData { + return this.texData.get(dataId); + } + private getCPUBackend(): KernelBackend|null { if (!ENV.getBool('WEBGL_CPU_FORWARD')) { return null; diff --git a/tfjs-core/src/backends/webgl/backend_webgl_test.ts b/tfjs-core/src/backends/webgl/backend_webgl_test.ts index 9033ce93bd7..f12d5a5c270 100644 --- a/tfjs-core/src/backends/webgl/backend_webgl_test.ts +++ b/tfjs-core/src/backends/webgl/backend_webgl_test.ts @@ -506,6 +506,50 @@ describeWithFlags('time webgl', WEBGL_ENVS, () => { }); }); +describeWithFlags('caching on cpu', WEBGL_ENVS, () => { + beforeAll(() => { + tf.ENV.set('WEBGL_CPU_FORWARD', false); + }); + + it('caches on cpu after async read', async () => { + const backend = new MathBackendWebGL(); + tf.registerBackend('cache-on-cpu', () => backend); + tf.setBackend('cache-on-cpu'); + + const t = tf.square(2); + const info = backend.getDataInfo(t.dataId); + + // Make sure the tensor is on the GPU. + expect(info.values == null).toBe(true); + + await t.data(); + + // Make sure the tensor is cached on CPU. + expect(info.values).not.toBe(null); + + tf.removeBackend('cache-on-cpu'); + }); + + it('caches on cpu after sync read', () => { + const backend = new MathBackendWebGL(); + tf.registerBackend('cache-on-cpu', () => backend); + tf.setBackend('cache-on-cpu'); + + const t = tf.square(2); + const info = backend.getDataInfo(t.dataId); + + // Make sure the tensor is on the GPU. + expect(info.values == null).toBe(true); + + t.dataSync(); + + // Make sure the tensor is cached on CPU. + expect(info.values).not.toBe(null); + + tf.removeBackend('cache-on-cpu'); + }); +}); + describe('WebGL backend has sync init', () => { it('can do matmul without waiting for ready', async () => { tf.registerBackend('my-webgl', () => { diff --git a/tfjs-core/src/ops/image_ops.ts b/tfjs-core/src/ops/image_ops.ts index ba4ee7dceba..27e5eacb8f4 100644 --- a/tfjs-core/src/ops/image_ops.ts +++ b/tfjs-core/src/ops/image_ops.ts @@ -191,8 +191,8 @@ async function nonMaxSuppressionAsync_( iouThreshold = inputs.iouThreshold; scoreThreshold = inputs.scoreThreshold; - const boxesVals = await $boxes.data(); - const scoresVals = await $scores.data(); + const [boxesVals, scoresVals] = + await Promise.all([$boxes.data(), $scores.data()]); const res = nonMaxSuppressionImpl( boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); if ($boxes !== boxes) {