diff --git a/src/backends/webgl/backend_webgl.ts b/src/backends/webgl/backend_webgl.ts index 2ff6e00e27..284f8a5ff4 100644 --- a/src/backends/webgl/backend_webgl.ts +++ b/src/backends/webgl/backend_webgl.ts @@ -75,6 +75,8 @@ import {DepthwiseConv2DProgram} from './conv_gpu_depthwise'; import {DepthwiseConvPacked2DProgram} from './conv_packed_gpu_depthwise'; import {CropAndResizeProgram} from './crop_and_resize_gpu'; import {CumSumProgram} from './cumsum_gpu'; +import {DecodeMatrixProgram} from './decode_matrix_gpu'; +import {DecodeMatrixPackedProgram} from './decode_matrix_packed_gpu'; import {DepthToSpaceProgram} from './depth_to_space_gpu'; import {EncodeFloatProgram} from './encode_float_gpu'; import {EncodeMatrixProgram} from './encode_matrix_gpu'; @@ -406,16 +408,7 @@ export class MathBackendWebGL implements KernelBackend { return new Promise(resolve => subscribers.push(resolve)); } const texData = this.texData.get(dataId); - const { - texture, - values, - texShape, - isPacked, - shape, - slice, - dtype, - complexTensors - } = texData; + const {values, shape, slice, dtype, complexTensors} = texData; if (slice != null) { const program = new UnaryOpProgram(shape, unary_op.CLONE); @@ -429,8 +422,6 @@ export class MathBackendWebGL implements KernelBackend { return this.convertAndCacheOnCPU(dataId); } - this.pendingRead.set(dataId, []); - if (!ENV.getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') && ENV.getNumber('WEBGL_VERSION') === 2) { throw new Error( @@ -439,17 +430,20 @@ export class MathBackendWebGL implements KernelBackend { } let buffer = null; - if (dtype !== 'complex64') { + if (dtype !== 'complex64' && ENV.get('WEBGL_BUFFER_SUPPORTED')) { // Possibly copy the texture into a buffer before inserting a fence. - let width = texShape[1]; - let height = texShape[0]; - if (isPacked) { - [width, height] = tex_util.getPackedMatrixTextureShapeWidthHeight( - texShape[0], texShape[1]); - } - if (ENV.get('WEBGL_BUFFER_SUPPORTED')) { - buffer = this.gpgpu.createBufferFromTexture(texture, height, width); - } + const tmpTarget = this.decode(dataId); + + dataId = tmpTarget.dataId; + const tmpData = this.texData.get(tmpTarget.dataId); + + buffer = this.gpgpu.createBufferFromTexture( + tmpData.texture, ...tex_util.getDenseTexShape(shape)); + } + + this.pendingRead.set(dataId, []); + + if (dtype !== 'complex64') { // Create a fence and wait for it to resolve. await this.gpgpu.createAndWaitForFence(); } @@ -466,22 +460,9 @@ export class MathBackendWebGL implements KernelBackend { vals = this.getValuesFromTexture(dataId); } else { const size = util.sizeFromShape(shape); - if (isPacked) { - const batch = webgl_util.getBatchDim(shape); - let rows = 1, cols = 1; - if (shape.length) { - [rows, cols] = webgl_util.getRowsCols(shape); - } - vals = this.gpgpu - .downloadPackedMatrixFromBuffer( - buffer, batch, rows, cols, texShape[0], texShape[1]) - .subarray(0, size); - } else { - vals = this.gpgpu - .downloadFloat32MatrixFromBuffer( - buffer, texShape[0], texShape[1]) - .subarray(0, size); - } + + vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size); + this.disposeData(dataId); } const dTypeVals = this.convertAndCacheOnCPU(dataId, vals); @@ -498,25 +479,19 @@ export class MathBackendWebGL implements KernelBackend { } private getValuesFromTexture(dataId: DataId): Float32Array { - const {shape, dtype, texture, texShape} = this.texData.get(dataId); + const {shape, dtype} = this.texData.get(dataId); const size = util.sizeFromShape(shape); if (ENV.getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) { - if (this.texData.get(dataId).isPacked) { - const batch = webgl_util.getBatchDim(shape); - let rows = 1, cols = 1; - if (shape.length) { - [rows, cols] = webgl_util.getRowsCols(shape); - } - return this.gpgpu - .downloadMatrixFromPackedTexture( - texture, batch, rows, cols, texShape[0], texShape[1]) - .subarray(0, size); - } else { - return this.gpgpu - .downloadFloat32MatrixFromOutputTexture( - texture, texShape[0], texShape[1]) - .subarray(0, size); - } + const tmpTarget = this.decode(dataId); + const tmpData = this.texData.get(tmpTarget.dataId); + const vals = this.gpgpu + .downloadMatrixFromPackedTexture( + tmpData.texture, ...tex_util.getDenseTexShape(shape)) + .subarray(0, size); + + this.disposeData(tmpTarget.dataId); + + return vals; } const tmpTarget = this.makeTensorHandle(shape, 'float32') as TensorHandle & @@ -2317,7 +2292,8 @@ export class MathBackendWebGL implements KernelBackend { private packTensor(input: T|TensorHandle): T { const program = new PackProgram(input.shape); return this.compileAndRun( - program, [input], this.makePackedTensor(input.shape, input.dtype)); + program, [input], this.makePackedTensor(input.shape, input.dtype), null, + true); } private packedReshape(input: Tensor, afterShape: ShapeMap[R]): @@ -2336,11 +2312,39 @@ export class MathBackendWebGL implements KernelBackend { .reshape(afterShape); } + private decode(dataId: DataId): TensorHandle { + const texData = this.texData.get(dataId); + const {isPacked, shape, dtype} = texData; + const shapeAs3D = + webgl_util.getShapeAs3D(shape) as [number, number, number]; + const denseTexShape = tex_util.getDenseTexShape(shape); + + const tmpTarget = this.makeTensorHandle(shape, 'float32') as TensorHandle & + {size: number}; + this.texData.get(tmpTarget.dataId).isPacked = true; + this.texData.get(tmpTarget.dataId).dtype = dtype; + this.texData.get(tmpTarget.dataId).texShape = + denseTexShape.map( + d => d * 2) as [number, number]; // To undo the effect of isPacked + // being set to true. + + let program; + if (isPacked) { + program = new DecodeMatrixPackedProgram(shapeAs3D, denseTexShape); + } else { + program = new DecodeMatrixProgram(shapeAs3D, denseTexShape); + } + + this.compileAndRun( + program, [{shape: shapeAs3D, dtype, dataId}], tmpTarget, null, true); + return tmpTarget; + } + public compileAndRun< K extends {dtype: DataType, size: number, dataId: {}, shape: number[]}>( program: GPGPUProgram, inputs: TensorHandle[], output?: K, - customSetup?: (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => void): - K { + customSetup?: (gpgpu: GPGPUContext, webGLProgram: WebGLProgram) => void, + preventEagerUnpackingOfOutput = false): K { if (output == null) { if (program.usesPackedTextures) { output = this.makePackedTensor(program.outputShape, inputs[0].dtype) as @@ -2447,7 +2451,8 @@ export class MathBackendWebGL implements KernelBackend { } if (!ENV.getBool('WEBGL_LAZILY_UNPACK') && - this.texData.get(output.dataId).isPacked && !program.isPackShader) { + this.texData.get(output.dataId).isPacked && + preventEagerUnpackingOfOutput === false) { return this.unpackTensor(output as {} as Tensor) as {} as K; } return output; @@ -2525,19 +2530,14 @@ export class MathBackendWebGL implements KernelBackend { start = performance.now(); } - const texShape = - webgl_util.getTextureShapeFromLogicalShape(shape, isPacked); - texData.texShape = texShape; + let texShape = texData.texShape; + if (texShape == null) { + texShape = webgl_util.getTextureShapeFromLogicalShape(shape, isPacked); + texData.texShape = texShape; + } if (values != null) { - let shapeAs3D: [number, number, number] = [1, 1, 1]; - const isScalar = - shape.length === 0 || (shape.length === 1 && shape[0] === 1); - if (!isScalar) { - shapeAs3D = [ - webgl_util.getBatchDim(shape), ...webgl_util.getRowsCols(shape) - ] as [number, number, number]; - } + const shapeAs3D = webgl_util.getShapeAs3D(shape); let program; let width = texShape[1], height = texShape[0]; diff --git a/src/backends/webgl/decode_matrix_gpu.ts b/src/backends/webgl/decode_matrix_gpu.ts new file mode 100644 index 0000000000..0f8bf67f70 --- /dev/null +++ b/src/backends/webgl/decode_matrix_gpu.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {getGlslDifferences} from './glsl_version'; +import {GPGPUProgram} from './gpgpu_math'; +import * as shader_util from './shader_compiler_util'; + +export class DecodeMatrixProgram implements GPGPUProgram { + variableNames = ['A']; + userCode: string; + outputShape: [number, number, number]; + + constructor(outputShape: [number, number, number], texShape: [ + number, number + ]) { + const glsl = getGlslDifferences(); + this.outputShape = outputShape; + + this.userCode = ` + ivec3 outCoordsFromFlatIndex(int index) { + ${ + shader_util.getLogicalCoordinatesFromFlatIndex( + ['r', 'c', 'd'], outputShape)} + return ivec3(r, c, d); + } + + void main() { + ivec2 resTexRC = ivec2(resultUV.yx * + vec2(${texShape[0]}, ${texShape[1]})); + int index = 4 * (resTexRC.x * ${texShape[1]} + resTexRC.y); + + vec4 result = vec4(0.); + + for (int i=0; i<4; i++) { + int flatIndex = index + i; + ivec3 rc = outCoordsFromFlatIndex(flatIndex); + result[i] = getA(rc.x, rc.y, rc.z); + } + + ${glsl.output} = result; + } + `; + } +} \ No newline at end of file diff --git a/src/backends/webgl/decode_matrix_packed_gpu.ts b/src/backends/webgl/decode_matrix_packed_gpu.ts new file mode 100644 index 0000000000..3bd3e0a014 --- /dev/null +++ b/src/backends/webgl/decode_matrix_packed_gpu.ts @@ -0,0 +1,59 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {getGlslDifferences} from './glsl_version'; +import {GPGPUProgram} from './gpgpu_math'; +import * as shader_util from './shader_compiler_util'; + +export class DecodeMatrixPackedProgram implements GPGPUProgram { + variableNames = ['A']; + userCode: string; + usesPackedTextures = true; + outputShape: [number, number, number]; + + constructor(outputShape: [number, number, number], texShape: [ + number, number + ]) { + const glsl = getGlslDifferences(); + this.outputShape = outputShape; + + this.userCode = ` + ivec3 outCoordsFromFlatIndex(int index) { + ${ + shader_util.getLogicalCoordinatesFromFlatIndex( + ['r', 'c', 'd'], outputShape)} + return ivec3(r, c, d); + } + + void main() { + ivec2 resTexRC = ivec2(resultUV.yx * + vec2(${texShape[0]}, ${texShape[1]})); + int index = 4 * (resTexRC.x * ${texShape[1]} + resTexRC.y); + + vec4 result = vec4(0.); + + for (int i=0; i<4; i++) { + int flatIndex = index + i; + ivec3 rc = outCoordsFromFlatIndex(flatIndex); + result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z)); + } + + ${glsl.output} = result; + } + `; + } +} \ No newline at end of file diff --git a/src/backends/webgl/gpgpu_context.ts b/src/backends/webgl/gpgpu_context.ts index 2466241c88..a86ad11d95 100644 --- a/src/backends/webgl/gpgpu_context.ts +++ b/src/backends/webgl/gpgpu_context.ts @@ -178,14 +178,6 @@ export class GPGPUContext { this.gl, this.debug, () => this.gl.deleteTexture(texture)); } - public downloadFloat32MatrixFromOutputTexture( - texture: WebGLTexture, rows: number, columns: number): Float32Array { - return this.downloadMatrixDriver( - texture, - () => gpgpu_util.downloadFloat32MatrixFromOutputTexture( - this.gl, this.debug, rows, columns, this.textureConfig)); - } - public downloadByteEncodedFloatMatrixFromOutputTexture( texture: WebGLTexture, rows: number, columns: number): Float32Array { return this.downloadMatrixDriver( @@ -202,10 +194,9 @@ export class GPGPUContext { this.textureConfig); } - public downloadFloat32MatrixFromBuffer( - buffer: WebGLBuffer, rows: number, columns: number): Float32Array { - return gpgpu_util.downloadFloat32MatrixFromBuffer( - this.gl, buffer, rows, columns, this.textureConfig); + public downloadFloat32MatrixFromBuffer(buffer: WebGLBuffer, size: number): + Float32Array { + return gpgpu_util.downloadFloat32MatrixFromBuffer(this.gl, buffer, size); } public createBufferFromTexture( @@ -258,13 +249,12 @@ export class GPGPUContext { } public downloadMatrixFromPackedTexture( - texture: WebGLTexture, batch: number, rows: number, columns: number, - physicalRows: number, physicalCols: number): Float32Array { + texture: WebGLTexture, physicalRows: number, + physicalCols: number): Float32Array { return this.downloadMatrixDriver( texture, () => gpgpu_util.downloadMatrixFromPackedOutputTexture( - this.gl, this.debug, batch, rows, columns, physicalRows, - physicalCols, this.textureConfig)); + this.gl, this.debug, physicalRows, physicalCols)); } private vertexAttrsAreBound = false; diff --git a/src/backends/webgl/gpgpu_context_test.ts b/src/backends/webgl/gpgpu_context_test.ts index b6f5ada37f..d03bae8f32 100644 --- a/src/backends/webgl/gpgpu_context_test.ts +++ b/src/backends/webgl/gpgpu_context_test.ts @@ -17,7 +17,6 @@ import {ENV} from '../../environment'; import {describeWithFlags} from '../../jasmine_util'; -import {expectArraysClose, expectNumbersClose} from '../../test_util'; import {WEBGL_ENVS} from './backend_webgl_test_registry'; import {getGlslDifferences} from './glsl_version'; @@ -29,86 +28,6 @@ const DOWNLOAD_FLOAT_ENVS = { predicate: WEBGL_ENVS.predicate }; -describeWithFlags( - 'GPGPUContext downloadMatrixFromTexture', DOWNLOAD_FLOAT_ENVS, () => { - let gpgpu: GPGPUContext; - let texture: WebGLTexture; - - beforeEach(() => { - gpgpu = new GPGPUContext(); - // Silences debug warnings. - spyOn(console, 'warn'); - ENV.set('DEBUG', true); - texture = gpgpu.createFloat32MatrixTexture(1, 1); - }); - - afterEach(() => { - gpgpu.deleteMatrixTexture(texture); - gpgpu.dispose(); - }); - - it('returns 1x1 matrix that was uploaded', () => { - gpgpu.uploadDenseMatrixToTexture( - texture, 1, 1, new Float32Array([1.234, 0, 0, 0])); - const result = - gpgpu.downloadFloat32MatrixFromOutputTexture(texture, 1, 1); - expectNumbersClose(result[0], 1.234); - }); - - it('returns 2x2 matrix that was uploaded', () => { - const texture2 = gpgpu.createFloat32MatrixTexture(2, 2); - gpgpu.uploadDenseMatrixToTexture( - texture2, 2, 2, - new Float32Array( - [1.234, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0])); - const result = - gpgpu.downloadFloat32MatrixFromOutputTexture(texture2, 2, 2); - expectArraysClose(result, new Float32Array([1.234, 2, 3, 4])); - gpgpu.deleteMatrixTexture(texture2); - }); - - it('uses texture parameter', () => { - const texture2: WebGLTexture = gpgpu.createFloat32MatrixTexture(1, 1); - gpgpu.uploadDenseMatrixToTexture( - texture, 1, 1, new Float32Array([1, 0, 0, 0])); - gpgpu.uploadDenseMatrixToTexture( - texture2, 1, 1, new Float32Array([2, 0, 0, 0])); - const read1 = - gpgpu.downloadFloat32MatrixFromOutputTexture(texture, 1, 1); - const read2 = - gpgpu.downloadFloat32MatrixFromOutputTexture(texture2, 1, 1); - - expectNumbersClose(read1[0], 1); - expectNumbersClose(read2[0], 2); - - gpgpu.deleteMatrixTexture(texture2); - }); - }); - -describeWithFlags( - 'GPGPUContext color texture with float', DOWNLOAD_FLOAT_ENVS, () => { - let gpgpu: GPGPUContext; - let texture: WebGLTexture; - - afterEach(() => { - gpgpu.deleteMatrixTexture(texture); - gpgpu.dispose(); - }); - - it('basic', () => { - gpgpu = new GPGPUContext(); - ENV.set('DEBUG', true); - texture = gpgpu.createFloat32MatrixTexture(1, 1); - - gpgpu.setOutputMatrixTexture(texture, 1, 1); - gpgpu.gl.clearColor(0.123, 0, 0, 0); - gpgpu.gl.clear(gpgpu.gl.COLOR_BUFFER_BIT); - const result = - gpgpu.downloadFloat32MatrixFromOutputTexture(texture, 1, 1); - expectNumbersClose(result[0], 0.123); - }); - }); - describeWithFlags( 'GPGPUContext setOutputMatrixTexture', DOWNLOAD_FLOAT_ENVS, () => { let gpgpu: GPGPUContext; @@ -132,28 +51,6 @@ describeWithFlags( expect(gpgpu.outputTexture).toBe(texture); }); - it('rebinds the output texture to the color buffer target', () => { - const output: WebGLTexture = gpgpu.createFloat32MatrixTexture(1, 1); - gpgpu.uploadDenseMatrixToTexture( - texture, 1, 1, new Float32Array([10, 0, 0, 0])); - gpgpu.setOutputMatrixTexture(output, 1, 1); - const tBeforeClear = - gpgpu.downloadFloat32MatrixFromOutputTexture(texture, 1, 1); - expectNumbersClose(tBeforeClear[0], 10); - gpgpu.gl.clearColor(1, 0, 0, 0); - const tAfterClear = - gpgpu.downloadFloat32MatrixFromOutputTexture(texture, 1, 1); - expectNumbersClose(tAfterClear[0], 10); - gpgpu.deleteMatrixTexture(output); - }); - - it('resets output texture to null if nothing was previously bound', - () => { - expect(gpgpu.outputTexture).toBeNull(); - gpgpu.downloadFloat32MatrixFromOutputTexture(texture, 1, 1); - expect(gpgpu.outputTexture).toBeNull(); - }); - it('sets the gl viewport to the output texture dimensions', () => { const columns = 456; const rows = 123; @@ -163,16 +60,6 @@ describeWithFlags( expect(gpgpu.gl.getParameter(gpgpu.gl.VIEWPORT)).toEqual(expected); gpgpu.deleteMatrixTexture(output); }); - - it('doesn\'t change gl viewport when downloading a non-output tex', - () => { - const output = gpgpu.createFloat32MatrixTexture(128, 128); - gpgpu.setOutputMatrixTexture(output, 128, 128); - gpgpu.downloadFloat32MatrixFromOutputTexture(texture, 1, 1); - const expected = new Int32Array([0, 0, 128, 128]); - expect(gpgpu.gl.getParameter(gpgpu.gl.VIEWPORT)).toEqual(expected); - gpgpu.deleteMatrixTexture(output); - }); }); describeWithFlags( @@ -244,15 +131,6 @@ describeWithFlags( gpgpu.dispose(); }); - it('writes to all pixels by default', () => { - gpgpu.executeProgram(); - const result = - gpgpu.downloadFloat32MatrixFromOutputTexture(output, 4, 4); - const expected = new Float32Array(4 * 4); - expected.fill(2); - expectArraysClose(result, expected); - }); - it('sets the scissor box to the requested parameters', () => { gpgpu.setOutputMatrixWriteRegion(0, 1, 2, 3); const scissorBox = gpgpu.gl.getParameter(gpgpu.gl.SCISSOR_BOX); @@ -261,42 +139,6 @@ describeWithFlags( expect(scissorBox[2]).toEqual(3); expect(scissorBox[3]).toEqual(1); }); - - it('writes only to center 2x2 region of 4x4 texture', () => { - gpgpu.setOutputMatrixWriteRegion(1, 2, 1, 2); - gpgpu.executeProgram(); - const result = - gpgpu.downloadFloat32MatrixFromOutputTexture(output, 4, 4); - const expected = - new Float32Array([0, 0, 0, 0, 0, 2, 2, 0, 0, 2, 2, 0, 0, 0, 0, 0]); - expectArraysClose(result, expected); - }); - - it('preserves data from previous writes outside of write region', () => { - gpgpu.setOutputMatrixWriteRegion(0, 1, 0, 4); // top row - gpgpu.executeProgram(); - gpgpu.setOutputMatrixWriteRegion(3, 1, 0, 4); // bottom row - gpgpu.executeProgram(); - const result = - gpgpu.downloadFloat32MatrixFromOutputTexture(output, 4, 4); - const expected = - new Float32Array([2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2]); - expectArraysClose(result, expected); - }); - - it('writes adjacent cells across multiple calls', () => { - for (let row = 0; row < 4; ++row) { - for (let col = 0; col < 4; ++col) { - gpgpu.setOutputMatrixWriteRegion(row, 1, col, 1); - gpgpu.executeProgram(); - } - } - const result = - gpgpu.downloadFloat32MatrixFromOutputTexture(output, 4, 4); - const expected = new Float32Array(4 * 4); - expected.fill(2); - expectArraysClose(result, expected); - }); }); describeWithFlags('GPGPUContext', DOWNLOAD_FLOAT_ENVS, () => { diff --git a/src/backends/webgl/gpgpu_math.ts b/src/backends/webgl/gpgpu_math.ts index 575c225243..a284da2550 100644 --- a/src/backends/webgl/gpgpu_math.ts +++ b/src/backends/webgl/gpgpu_math.ts @@ -30,9 +30,6 @@ export interface GPGPUProgram { outputShape: number[]; userCode: string; usesPackedTextures?: boolean; - isPackShader?: boolean; // This property is used to single out the packing - // shader so its output does not get eagerly unpacked - // by backend_webgl.compileAndRun. } export interface GPGPUBinary { diff --git a/src/backends/webgl/gpgpu_util.ts b/src/backends/webgl/gpgpu_util.ts index e58c6e9494..2bfb87846b 100644 --- a/src/backends/webgl/gpgpu_util.ts +++ b/src/backends/webgl/gpgpu_util.ts @@ -17,7 +17,6 @@ import {ENV} from '../../environment'; import {PixelData, TypedArray} from '../../types'; -import * as util from '../../util'; import {getGlslDifferences} from './glsl_version'; import * as tex_util from './tex_util'; @@ -281,9 +280,8 @@ export function createBufferFromOutputTexture( // Initialize the buffer to the size of the texture in bytes. const bytesPerFloat = 4; - const bufferSizeBytes = bytesPerFloat * - tex_util.getUnpackedArraySizeFromMatrixSize( - rows * columns, textureConfig.downloadUnpackNumChannels); + const valuesPerTexel = 4; + const bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns; webgl_util.callAndCheck( gl2, debug, @@ -303,47 +301,17 @@ export function createBufferFromOutputTexture( } export function downloadFloat32MatrixFromBuffer( - gl: WebGLRenderingContext, buffer: WebGLBuffer, rows: number, - columns: number, textureConfig: TextureConfig): Float32Array { + gl: WebGLRenderingContext, buffer: WebGLBuffer, + size: number): Float32Array { const gl2 = gl as WebGL2RenderingContext; - const downloadTarget = - new Float32Array(tex_util.getUnpackedArraySizeFromMatrixSize( - rows * columns, textureConfig.downloadUnpackNumChannels)); + const downloadTarget = new Float32Array(size); gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget); gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); - const matrix = new Float32Array(rows * columns); - tex_util.decodeMatrixFromUnpackedArray( - downloadTarget as Float32Array, matrix, - textureConfig.downloadUnpackNumChannels); - - return matrix; -} - -export function downloadFloat32MatrixFromOutputTexture( - gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number, - textureConfig: TextureConfig): Float32Array { - const [w, h] = - tex_util.getUnpackedMatrixTextureShapeWidthHeight(rows, columns); - - const downloadTarget = - new Float32Array(tex_util.getUnpackedArraySizeFromMatrixSize( - rows * columns, textureConfig.downloadUnpackNumChannels)); - - webgl_util.callAndCheck( - gl, debug, - () => gl.readPixels( - 0, 0, w, h, textureConfig.downloadTextureFormat, gl.FLOAT, - downloadTarget)); - - const matrix = new Float32Array(rows * columns); - tex_util.decodeMatrixFromUnpackedArray( - downloadTarget as Float32Array, matrix, - textureConfig.downloadUnpackNumChannels); - return matrix; + return downloadTarget; } export function downloadByteEncodedFloatMatrixFromOutputTexture( @@ -381,26 +349,17 @@ export function downloadPackedMatrixFromBuffer( gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget); gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); - const matrix = new Float32Array(util.sizeFromShape([batch, rows, cols])); - tex_util.decodeMatrixFromPackedRGBA( - downloadTarget, batch, rows, cols, matrix); - return matrix; + return downloadTarget; } export function downloadMatrixFromPackedOutputTexture( - gl: WebGLRenderingContext, debug: boolean, batch: number, rows: number, - cols: number, physicalRows: number, physicalCols: number, - textureConfig: TextureConfig): Float32Array { - const [w, h] = tex_util.getPackedMatrixTextureShapeWidthHeight( - physicalRows, physicalCols); - - const packedRGBA = - new Float32Array(tex_util.getPackedRGBAArraySizeFromMatrixShape( - physicalRows, physicalCols)); + gl: WebGLRenderingContext, debug: boolean, physicalRows: number, + physicalCols: number): Float32Array { + const packedRGBA = new Float32Array(physicalRows * physicalCols * 4); webgl_util.callAndCheck( gl, debug, - () => gl.readPixels(0, 0, w, h, gl.RGBA, gl.FLOAT, packedRGBA)); - const matrix = new Float32Array(util.sizeFromShape([batch, rows, cols])); - return tex_util.decodeMatrixFromPackedRGBA( - packedRGBA, batch, rows, cols, matrix); + () => gl.readPixels( + 0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA)); + + return packedRGBA; } diff --git a/src/backends/webgl/gpgpu_util_test.ts b/src/backends/webgl/gpgpu_util_test.ts index a95bdf09ad..553e542dcc 100644 --- a/src/backends/webgl/gpgpu_util_test.ts +++ b/src/backends/webgl/gpgpu_util_test.ts @@ -15,17 +15,11 @@ * ============================================================================= */ -import * as tf from '../../index'; import {describeWithFlags} from '../../jasmine_util'; -import {expectArraysClose} from '../../test_util'; import {WEBGL_ENVS} from './backend_webgl_test_registry'; import {GPGPUContext} from './gpgpu_context'; import * as gpgpu_util from './gpgpu_util'; -const DOWNLOAD_FLOAT_ENVS = { - flags: {'WEBGL_DOWNLOAD_FLOAT_ENABLED': true} -}; - describeWithFlags('gpgpu_util createWebGLContext', WEBGL_ENVS, () => { let gpgpu: GPGPUContext; @@ -137,132 +131,3 @@ describeWithFlags('gpgpu_util createPackedMatrixTexture', WEBGL_ENVS, () => { gpgpu.dispose(); }); }); - -describeWithFlags( - 'gpgpu_util downloadMatrixFromPackedOutputTexture', DOWNLOAD_FLOAT_ENVS, - () => { - it('should work when texture shape != logical shape', () => { - const gpgpu = new GPGPUContext(); - const textureConfig = gpgpu_util.getTextureConfig(gpgpu.gl); - const debug = false; - const tex = gpgpu_util.createPackedMatrixTexture( - gpgpu.gl, debug, 4, 6, textureConfig); - - const mat = - tf.tensor2d([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [1, 12]); - /* - This is how the tensor is arranged in a 2x3 packed texture - - 0|1 2|3 4|5 - ––– ––– ––– - x|x x|x x|x - - 6|7 8|9 10|11 - ––– ––– ––– - x|x x|x x|x - - Each group of four is one texel. x's represent empty channels. To - obtain the flattened representation in the call to gl.texSubImage2D - below, one moves through the texels from left to right, top to bottom, - reading off the 4 channels for each texel (x's become 0's). - */ - - gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, tex); - gpgpu.gl.texSubImage2D( - gpgpu.gl.TEXTURE_2D, 0, 0, 0, 3, 2, gpgpu.gl.RGBA, gpgpu.gl.FLOAT, - new Float32Array([ - 0, 1, 0, 0, 2, 3, 0, 0, 4, 5, 0, 0, - 6, 7, 0, 0, 8, 9, 0, 0, 10, 11, 0, 0 - ])); - - const result = - gpgpu.downloadMatrixFromPackedTexture(tex, 1, 1, 12, 4, 6); - - expectArraysClose(result, mat.dataSync()); - }); - - it('should work when different batches occupy the same physical row', - () => { - const gpgpu = new GPGPUContext(); - const textureConfig = gpgpu_util.getTextureConfig(gpgpu.gl); - - // these dimensions will be halved to create the packed texture - const physicalRows = 10; - const physicalCols = 16; - const debug = false; - const tex = gpgpu_util.createPackedMatrixTexture( - gpgpu.gl, debug, physicalRows, physicalCols, textureConfig); - - const mat = tf.tensor3d( - [ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, - 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, - 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, - 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, - 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, - 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, - 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, - 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, - 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120 - ], - [2, 20, 3]); - /* - Here we pretend that gl.MAX_TEXTURE_SIZE is small enough that the - texture dimensions must be squarified. This way, values from - different batches are encoded in the same physical row of the - texture - - Physical row 1: - 1| 2 3| x 7| 8 9| x 13|14 15| x 19|20 21| x - ----- ----- ----- ----- ----- ----- ----- ----- - 4| 5 6| x 10|11 12| x 16|17 18| x 22|23 24| x - - Row 2: - 25|26 27| x 31|32 33| x 37|38 39| x 43|44 45| x - ----- ----- ----- ----- ----- ----- ----- ----- - 28|29 30| x 34|35 36| x 40|41 42| x 46|47 48| x - - Row 3: - 49|50 51| x 55|56 57| x 61|62 63| x 67|68 69| x - ----- ----- ----- ----- ----- ----- ----- ----- - 52|53 54| x 58|59 60| x 64|65 66| x 70|71 72| x - - Row 4: - 73|74 75| x 79|80 81| x 85|86 87| x 91|92 93| x - ----- ----- ----- ----- ----- ----- ----- ----- - 76|77 78| x 82|83 84| x 88|89 90| x 94|95 96| x - - Row 5: - 97|98 99| x 103|104105| x 109|110111| x 115|116117| x - ----- ----- ----- ----- ----- ----- ----- ----- - 100|101102| x 106|107108| x 112|113114| x 118|119120| x - - Note that physical row 3 is split between the two batches. - */ - - gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, tex); - gpgpu.gl.texSubImage2D( - gpgpu.gl.TEXTURE_2D, 0, 0, 0, 8, 5, gpgpu.gl.RGBA, - gpgpu.gl.FLOAT, new Float32Array([ - 1, 2, 4, 5, 3, 0, 6, 0, 7, 8, 10, 11, - 9, 0, 12, 0, 13, 14, 16, 17, 15, 0, 18, 0, - 19, 20, 22, 23, 21, 0, 24, 0, 25, 26, 28, 29, - 27, 0, 30, 0, 31, 32, 34, 35, 33, 0, 36, 0, - 37, 38, 40, 41, 39, 0, 42, 0, 43, 44, 46, 47, - 45, 0, 48, 0, 49, 50, 52, 53, 51, 0, 54, 0, - 55, 56, 58, 59, 57, 0, 60, 0, 61, 62, 64, 65, - 63, 0, 66, 0, 67, 68, 70, 71, 69, 0, 72, 0, - 73, 74, 76, 77, 75, 0, 78, 0, 79, 80, 82, 83, - 81, 0, 84, 0, 85, 86, 88, 89, 87, 0, 90, 0, - 91, 92, 94, 95, 93, 0, 96, 0, 97, 98, 100, 101, - 99, 0, 102, 0, 103, 104, 106, 107, 105, 0, 108, 0, - 109, 110, 112, 113, 111, 0, 114, 0, 115, 116, 118, 119, - 117, 0, 120, 0 - ])); - - const result = gpgpu.downloadMatrixFromPackedTexture( - tex, 2, 20, 3, physicalRows, physicalCols); - expectArraysClose(result, mat.dataSync()); - }); - }); diff --git a/src/backends/webgl/pack_gpu.ts b/src/backends/webgl/pack_gpu.ts index 8c03453080..4feab34f52 100644 --- a/src/backends/webgl/pack_gpu.ts +++ b/src/backends/webgl/pack_gpu.ts @@ -22,7 +22,6 @@ import {getCoordsDataType} from './shader_compiler'; export class PackProgram implements GPGPUProgram { variableNames = ['A']; - isPackShader = true; outputShape: number[]; userCode: string; diff --git a/src/backends/webgl/tex_util.ts b/src/backends/webgl/tex_util.ts index a4eb7d21d6..6518190e0b 100644 --- a/src/backends/webgl/tex_util.ts +++ b/src/backends/webgl/tex_util.ts @@ -75,6 +75,15 @@ export function getColorMatrixTextureShapeWidthHeight( return [columns * 4, rows]; } +/** + * Get shape for densely packed RGBA texture. + */ +export function getDenseTexShape(shape: number[]): [number, number] { + const size = util.sizeFromShape(shape); + const texelsNeeded = Math.ceil(size / 4); + return util.sizeToSquarishShape(texelsNeeded); +} + export function getMatrixSizeFromUnpackedArraySize( unpackedSize: number, channelsPerTexture: number): number { if (unpackedSize % channelsPerTexture !== 0) { @@ -85,21 +94,6 @@ export function getMatrixSizeFromUnpackedArraySize( return unpackedSize / channelsPerTexture; } -export function decodeMatrixFromUnpackedArray( - unpackedArray: Float32Array, matrix: Float32Array, - channelsPerTexture: number) { - const requiredSize = getMatrixSizeFromUnpackedArraySize( - unpackedArray.length, channelsPerTexture); - if (matrix.length < requiredSize) { - throw new Error( - `matrix length (${matrix.length}) must be >= ${requiredSize}`); - } - let dst = 0; - for (let src = 0; src < unpackedArray.length; src += channelsPerTexture) { - matrix[dst++] = unpackedArray[src]; - } -} - export function decodeMatrixFromUnpackedColorRGBAArray( unpackedArray: Float32Array, matrix: Float32Array, channels: number) { const requiredSize = unpackedArray.length * channels / 4; @@ -127,81 +121,3 @@ export function getPackedRGBAArraySizeFromMatrixShape( const [w, h] = getPackedMatrixTextureShapeWidthHeight(rows, columns); return w * h * 4; } - -export function decodeMatrixFromPackedRGBA( - packedRGBA: Float32Array, batches: number, rows: number, columns: number, - matrix: Float32Array): Float32Array { - const requiredSize = rows * columns; - if (matrix.length < requiredSize) { - throw new Error( - `matrix length (${matrix.length}) must be >= ${requiredSize}`); - } - - const oddWidth = (columns % 2) === 1; - const oddHeight = (rows % 2) === 1; - const widthInFullBlocks = Math.floor(columns / 2); - const heightInFullBlocks = Math.floor(rows / 2); - - const texelsPerRow = Math.ceil(columns / 2); - const texelsPerBatch = texelsPerRow * Math.ceil(rows / 2); - - const flattenedMatrixSize = - util.nearestLargerEven(rows) * util.nearestLargerEven(columns); - - for (let batch = 0; batch < batches; batch++) { - const batchOffset = batch * rows * columns; - const sourceOffset = batch * flattenedMatrixSize; - - // loop over full 2x2 blocks - { - const srcStride = oddWidth ? 4 : 0; - const dstStride = columns + (oddWidth ? 1 : 0); - let src = sourceOffset; - let dstRow1 = batchOffset; - let dstRow2 = batchOffset + columns; - for (let blockY = 0; blockY < heightInFullBlocks; ++blockY) { - for (let blockX = 0; blockX < widthInFullBlocks; ++blockX) { - matrix[dstRow1++] = packedRGBA[src++]; - matrix[dstRow1++] = packedRGBA[src++]; - matrix[dstRow2++] = packedRGBA[src++]; - matrix[dstRow2++] = packedRGBA[src++]; - } - src += srcStride; - dstRow1 += dstStride; - dstRow2 += dstStride; - } - } - - // loop down final column - if (oddWidth) { - let src = sourceOffset + (texelsPerRow - 1) * 4; - let dst = batchOffset + columns - 1; - const srcStride = texelsPerRow * 4; - const dstStride = 2 * columns; - for (let blockY = 0; blockY < heightInFullBlocks; ++blockY) { - matrix[dst] = packedRGBA[src]; - matrix[dst + columns] = packedRGBA[src + 2]; - src += srcStride; - dst += dstStride; - } - } - - // loop across final row - if (oddHeight) { - let src = sourceOffset + (texelsPerBatch - texelsPerRow) * 4; - let dst = batchOffset + (rows - 1) * columns; - for (let blockX = 0; blockX < widthInFullBlocks; ++blockX) { - matrix[dst++] = packedRGBA[src++]; - matrix[dst++] = packedRGBA[src++]; - src += 2; - } - - // fill in bottom-right cell - if (oddWidth) { - matrix[batchOffset + (rows * columns) - 1] = packedRGBA[src]; - } - } - } - - return matrix; -} diff --git a/src/backends/webgl/tex_util_test.ts b/src/backends/webgl/tex_util_test.ts index b6fdf07bde..ceabcb8b91 100644 --- a/src/backends/webgl/tex_util_test.ts +++ b/src/backends/webgl/tex_util_test.ts @@ -87,152 +87,10 @@ describe('tex_util getPackedMatrixTextureShapeWidthHeight', () => { }); }); -describeWithFlags('tex_util decodeMatrixFromUnpackedArray', WEBGL_ENVS, () => { - it('1x1 writes the only matrix array value to the first element', () => { - const unpackedRGBA = new Float32Array([1, 0, 0, 0]); - const matrix = new Float32Array(1); - tex_util.decodeMatrixFromUnpackedArray(unpackedRGBA, matrix, 4); - expect(matrix.length).toEqual(1); - expect(matrix[0]).toEqual(1); +describeWithFlags('tex_util getDenseTexShape', WEBGL_ENVS, () => { + it('basic', () => { + const shape = [1, 3, 3, 4]; + const denseShape = tex_util.getDenseTexShape(shape); + expectArraysClose(denseShape, [3, 3]); }); - - it('1x2 writes the second texel R component to the second element', () => { - const unpackedRGBA = new Float32Array([1, 0, 0, 0, 2, 0, 0, 0]); - const matrix = new Float32Array(2); - tex_util.decodeMatrixFromUnpackedArray(unpackedRGBA, matrix, 4); - expect(matrix.length).toEqual(2); - expectArraysClose(matrix, new Float32Array([1, 2])); - }); -}); - -describeWithFlags('tex_util decodeMatrixFromPackedRGBA', WEBGL_ENVS, () => { - it('1x1 matrix only loads R component from only texel', () => { - const packedRGBA = new Float32Array([1, 0, 0, 0]); - const matrix = new Float32Array(1); - tex_util.decodeMatrixFromPackedRGBA(packedRGBA, 1, 1, 1, matrix); - expect(matrix[0]).toEqual(1); - }); - - it('1x2 matrix loads RG from only texel', () => { - const packedRGBA = new Float32Array([1, 2, 0, 0]); - const matrix = new Float32Array(2); - tex_util.decodeMatrixFromPackedRGBA(packedRGBA, 1, 1, 2, matrix); - expectArraysClose(matrix, new Float32Array([1, 2])); - }); - - it('2x1 matrix loads RB from only texel', () => { - const packedRGBA = new Float32Array([1, 0, 2, 0]); - const matrix = new Float32Array(2); - tex_util.decodeMatrixFromPackedRGBA(packedRGBA, 1, 2, 1, matrix); - expectArraysClose(matrix, new Float32Array([1, 2])); - }); - - it('2x2 matrix loads RGBA from only texel', () => { - const packedRGBA = new Float32Array([1, 2, 3, 4]); - const matrix = new Float32Array(4); - tex_util.decodeMatrixFromPackedRGBA(packedRGBA, 1, 2, 2, matrix); - expectArraysClose(matrix, new Float32Array([1, 2, 3, 4])); - }); - - it('4x3 final column only reads RB from edge texels', () => { - /* - 1 2 4 5 | 3 0 6 0 1 2 3 - -----------+----------- => 4 5 6 - 7 8 10 11 | 9 0 12 0 7 8 9 - 10 11 12 - */ - const packedRGBA = - new Float32Array([1, 2, 4, 5, 3, 0, 6, 0, 7, 8, 10, 11, 9, 0, 12, 0]); - const matrix = new Float32Array(12); - tex_util.decodeMatrixFromPackedRGBA(packedRGBA, 1, 4, 3, matrix); - expectArraysClose( - matrix, new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])); - }); - - it('3x4 final row only reads RG from edge texels', () => { - /* - 1 2 5 6 | 3 4 7 8 1 2 3 4 - ---------+---------- => 5 6 7 8 - 9 10 0 0 | 11 12 0 0 9 10 11 12 - */ - const packedRGBA = - new Float32Array([1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 0, 0, 11, 12, 0, 0]); - const matrix = new Float32Array(12); - tex_util.decodeMatrixFromPackedRGBA(packedRGBA, 1, 3, 4, matrix); - expectArraysClose( - matrix, new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])); - }); - - it('3x3 bottom-right only reads R from corner texel', () => { - /* - 1 2 4 5 | 3 0 6 0 1 2 3 - --------+-------- => 4 5 6 - 7 8 0 0 | 9 0 0 0 7 8 9 - */ - const packedRGBA = - new Float32Array([1, 2, 4, 5, 3, 0, 6, 0, 7, 8, 0, 0, 9, 0, 0, 0]); - const matrix = new Float32Array(9); - tex_util.decodeMatrixFromPackedRGBA(packedRGBA, 1, 3, 3, matrix); - expectArraysClose(matrix, new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9])); - }); - - it('2x3x4 bottom row in each batch only reads RG', () => { - const packedRGBA = new Float32Array([ - 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 0, 0, 11, 12, 0, 0, - 13, 14, 17, 18, 15, 16, 19, 20, 21, 22, 0, 0, 23, 24, 0, 0 - ]); - const matrix = new Float32Array(24); - tex_util.decodeMatrixFromPackedRGBA(packedRGBA, 2, 3, 4, matrix); - expectArraysClose(matrix, new Float32Array([ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 - ])); - }); - - it('2x4x3 final column in each batch only reads RB', () => { - const packedRGBA = new Float32Array([ - 1, 2, 4, 5, 3, 0, 6, 0, 7, 8, 10, 11, 9, 0, 12, 0, - 13, 14, 16, 17, 15, 0, 18, 0, 19, 20, 22, 23, 21, 0, 24, 0 - ]); - const matrix = new Float32Array(24); - tex_util.decodeMatrixFromPackedRGBA(packedRGBA, 2, 4, 3, matrix); - expectArraysClose(matrix, new Float32Array([ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 - ])); - }); - - it('2x3x3 bottom right texel in each batch only reads R', () => { - const packedRGBA = new Float32Array([ - 1, 2, 4, 5, 3, 0, 6, 0, 7, 8, 0, 0, 9, 0, 0, 0, - 10, 11, 13, 14, 12, 0, 15, 0, 16, 17, 0, 0, 18, 0, 0, 0 - ]); - const matrix = new Float32Array(18); - tex_util.decodeMatrixFromPackedRGBA(packedRGBA, 2, 3, 3, matrix); - expectArraysClose( - matrix, - new Float32Array( - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18])); - }); - - it('4D (2x3x3x4) is properly decoded', () => { - const packedRGBA = new Float32Array([ - 1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 0, 0, 11, 12, 0, 0, - 13, 14, 17, 18, 15, 16, 19, 20, 21, 22, 0, 0, 23, 24, 0, 0, - 25, 26, 29, 30, 27, 28, 31, 32, 33, 34, 0, 0, 35, 36, 0, 0, - 37, 38, 41, 42, 39, 40, 43, 44, 45, 46, 0, 0, 47, 48, 0, 0, - 49, 50, 53, 54, 51, 52, 55, 56, 57, 58, 0, 0, 59, 60, 0, 0, - 61, 62, 65, 66, 63, 64, 67, 68, 69, 70, 0, 0, 71, 72, 0, 0 - ]); - const matrix = new Float32Array(72); - tex_util.decodeMatrixFromPackedRGBA(packedRGBA, 6, 3, 4, matrix); - expectArraysClose( - matrix, new Float32Array([ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, - 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, - 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, - 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72 - ])); - }); -}); +}); \ No newline at end of file diff --git a/src/backends/webgl/webgl_util.ts b/src/backends/webgl/webgl_util.ts index 3b90fcbab0..f35ce52082 100644 --- a/src/backends/webgl/webgl_util.ts +++ b/src/backends/webgl/webgl_util.ts @@ -370,6 +370,16 @@ export function getRowsCols(shape: number[]): [number, number] { ]; } +export function getShapeAs3D(shape: number[]): [number, number, number] { + let shapeAs3D: [number, number, number] = [1, 1, 1]; + const isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1); + if (!isScalar) { + shapeAs3D = + [getBatchDim(shape), ...getRowsCols(shape)] as [number, number, number]; + } + return shapeAs3D; +} + export function getTextureShapeFromLogicalShape( logShape: number[], isPacked = false): [number, number] { let maxTexSize = ENV.getNumber('WEBGL_MAX_TEXTURE_SIZE');