diff --git a/tfjs-backend-cpu/src/kernels/Dilation2D.ts b/tfjs-backend-cpu/src/kernels/Dilation2D.ts index f5d37064c20..14ab33423c5 100644 --- a/tfjs-backend-cpu/src/kernels/Dilation2D.ts +++ b/tfjs-backend-cpu/src/kernels/Dilation2D.ts @@ -92,9 +92,8 @@ export const dilation2dConfig: KernelConfig = { } } - const dataId = cpuBackend.write( - util.toTypedArray(output, x.dtype, false /* debug mode */), outShape, - x.dtype); + const dataId = + cpuBackend.write(util.toTypedArray(output, x.dtype), outShape, x.dtype); return {dataId, shape: outShape, dtype: x.dtype}; } diff --git a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts index 2d39c057b77..8bba97488fc 100644 --- a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts +++ b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts @@ -113,8 +113,7 @@ export const dilation2dBackpropFilterConfig: KernelConfig = { } const dataId = cpuBackend.write( - util.toTypedArray(gradients, x.dtype, false /* debug mode */), - filter.shape, filter.dtype); + util.toTypedArray(gradients, x.dtype), filter.shape, filter.dtype); return {dataId, shape: filter.shape, dtype: filter.dtype}; } diff --git a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts index 3e1405255b9..beb9307035f 100644 --- a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts +++ b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts @@ -113,8 +113,7 @@ export const dilation2dBackpropInputConfig: KernelConfig = { } const dataId = cpuBackend.write( - util.toTypedArray(gradients, x.dtype, false /* debug mode */), x.shape, - x.dtype); + util.toTypedArray(gradients, x.dtype), x.shape, x.dtype); return {dataId, shape: x.shape, dtype: x.dtype}; } diff --git a/tfjs-backend-webgl/src/gpgpu_context.ts b/tfjs-backend-webgl/src/gpgpu_context.ts index 4fc8aaf9477..7a167181b2f 100644 --- a/tfjs-backend-webgl/src/gpgpu_context.ts +++ b/tfjs-backend-webgl/src/gpgpu_context.ts @@ -62,10 +62,10 @@ export class GPGPUContext { const TEXTURE_HALF_FLOAT = 'OES_texture_half_float'; this.textureFloatExtension = - webgl_util.getExtensionOrThrow(this.gl, this.debug, TEXTURE_FLOAT); + webgl_util.getExtensionOrThrow(this.gl, TEXTURE_FLOAT); if (webgl_util.hasExtension(this.gl, TEXTURE_HALF_FLOAT)) { - this.textureHalfFloatExtension = webgl_util.getExtensionOrThrow( - this.gl, this.debug, TEXTURE_HALF_FLOAT); + this.textureHalfFloatExtension = + webgl_util.getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT); } else if (env().get('WEBGL_FORCE_F16_TEXTURES')) { throw new Error( 'GL context does not support half float textures, yet the ' + @@ -74,8 +74,8 @@ export class GPGPUContext { this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT); if (webgl_util.hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) { - this.colorBufferHalfFloatExtension = webgl_util.getExtensionOrThrow( - this.gl, this.debug, COLOR_BUFFER_HALF_FLOAT); + this.colorBufferHalfFloatExtension = + webgl_util.getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT); } else if (env().get('WEBGL_FORCE_F16_TEXTURES')) { throw new Error( 'GL context does not support color renderable half floats, yet ' + @@ -94,9 +94,9 @@ export class GPGPUContext { } } - this.vertexBuffer = gpgpu_util.createVertexBuffer(this.gl, this.debug); - this.indexBuffer = gpgpu_util.createIndexBuffer(this.gl, this.debug); - this.framebuffer = webgl_util.createFramebuffer(this.gl, this.debug); + this.vertexBuffer = gpgpu_util.createVertexBuffer(this.gl); + this.indexBuffer = gpgpu_util.createIndexBuffer(this.gl); + this.framebuffer = webgl_util.createFramebuffer(this.gl); this.textureConfig = tex_util.getTextureConfig(this.gl, this.textureHalfFloatExtension); @@ -124,17 +124,13 @@ export class GPGPUContext { 'disposing.'); } const gl = this.gl; - webgl_util.callAndCheck(gl, this.debug, () => gl.finish()); - webgl_util.callAndCheck( - gl, this.debug, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null)); - webgl_util.callAndCheck( - gl, this.debug, () => gl.deleteFramebuffer(this.framebuffer)); - webgl_util.callAndCheck( - gl, this.debug, () => gl.bindBuffer(gl.ARRAY_BUFFER, null)); + webgl_util.callAndCheck(gl, () => gl.finish()); + webgl_util.callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null)); + webgl_util.callAndCheck(gl, () => gl.deleteFramebuffer(this.framebuffer)); + webgl_util.callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, null)); webgl_util.callAndCheck( - gl, this.debug, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null)); - webgl_util.callAndCheck( - gl, this.debug, () => gl.deleteBuffer(this.indexBuffer)); + gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null)); + webgl_util.callAndCheck(gl, () => gl.deleteBuffer(this.indexBuffer)); this.disposed = true; } @@ -142,60 +138,58 @@ export class GPGPUContext { WebGLTexture { this.throwIfDisposed(); return gpgpu_util.createFloat32MatrixTexture( - this.gl, this.debug, rows, columns, this.textureConfig); + this.gl, rows, columns, this.textureConfig); } public createFloat16MatrixTexture(rows: number, columns: number): WebGLTexture { this.throwIfDisposed(); return gpgpu_util.createFloat16MatrixTexture( - this.gl, this.debug, rows, columns, this.textureConfig); + this.gl, rows, columns, this.textureConfig); } public createUnsignedBytesMatrixTexture(rows: number, columns: number): WebGLTexture { this.throwIfDisposed(); return gpgpu_util.createUnsignedBytesMatrixTexture( - this.gl, this.debug, rows, columns, this.textureConfig); + this.gl, rows, columns, this.textureConfig); } public uploadPixelDataToTexture( texture: WebGLTexture, pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement) { this.throwIfDisposed(); - gpgpu_util.uploadPixelDataToTexture(this.gl, this.debug, texture, pixels); + gpgpu_util.uploadPixelDataToTexture(this.gl, texture, pixels); } public uploadDenseMatrixToTexture( texture: WebGLTexture, width: number, height: number, data: TypedArray) { this.throwIfDisposed(); gpgpu_util.uploadDenseMatrixToTexture( - this.gl, this.debug, texture, width, height, data, this.textureConfig); + this.gl, texture, width, height, data, this.textureConfig); } public createFloat16PackedMatrixTexture(rows: number, columns: number): WebGLTexture { this.throwIfDisposed(); return gpgpu_util.createFloat16PackedMatrixTexture( - this.gl, this.debug, rows, columns, this.textureConfig); + this.gl, rows, columns, this.textureConfig); } public createPackedMatrixTexture(rows: number, columns: number): WebGLTexture { this.throwIfDisposed(); return gpgpu_util.createPackedMatrixTexture( - this.gl, this.debug, rows, columns, this.textureConfig); + this.gl, rows, columns, this.textureConfig); } public deleteMatrixTexture(texture: WebGLTexture) { this.throwIfDisposed(); if (this.outputTexture === texture) { - webgl_util.unbindColorTextureFromFramebuffer( - this.gl, this.debug, this.framebuffer); + webgl_util.unbindColorTextureFromFramebuffer(this.gl, this.framebuffer); this.outputTexture = null; } - webgl_util.callAndCheck( - this.gl, this.debug, () => this.gl.deleteTexture(texture)); + webgl_util.callAndCheck(this.gl, () => this.gl.deleteTexture(texture)); } public downloadByteEncodedFloatMatrixFromOutputTexture( @@ -203,7 +197,7 @@ export class GPGPUContext { return this.downloadMatrixDriver( texture, () => gpgpu_util.downloadByteEncodedFloatMatrixFromOutputTexture( - this.gl, this.debug, rows, columns, this.textureConfig)); + this.gl, rows, columns, this.textureConfig)); } public downloadPackedMatrixFromBuffer( @@ -223,8 +217,7 @@ export class GPGPUContext { texture: WebGLTexture, rows: number, columns: number): WebGLBuffer { this.bindTextureToFrameBuffer(texture); const result = gpgpu_util.createBufferFromOutputTexture( - this.gl as WebGL2RenderingContext, this.debug, rows, columns, - this.textureConfig); + this.gl as WebGL2RenderingContext, rows, columns, this.textureConfig); this.unbindTextureToFrameBuffer(); return result; } @@ -275,7 +268,7 @@ export class GPGPUContext { return this.downloadMatrixDriver( texture, () => gpgpu_util.downloadMatrixFromPackedOutputTexture( - this.gl, this.debug, physicalRows, physicalCols)); + this.gl, physicalRows, physicalCols)); } private vertexAttrsAreBound = false; @@ -284,25 +277,19 @@ export class GPGPUContext { this.throwIfDisposed(); const gl = this.gl; const fragmentShader: WebGLShader = - webgl_util.createFragmentShader(gl, this.debug, fragmentShaderSource); - const vertexShader: WebGLShader = - gpgpu_util.createVertexShader(gl, this.debug); - const program: WebGLProgram = webgl_util.createProgram( - gl, - this.debug, - ); - webgl_util.callAndCheck( - gl, this.debug, () => gl.attachShader(program, vertexShader)); - webgl_util.callAndCheck( - gl, this.debug, () => gl.attachShader(program, fragmentShader)); - webgl_util.linkProgram(gl, this.debug, program); + webgl_util.createFragmentShader(gl, fragmentShaderSource); + const vertexShader: WebGLShader = gpgpu_util.createVertexShader(gl); + const program: WebGLProgram = webgl_util.createProgram(gl); + webgl_util.callAndCheck(gl, () => gl.attachShader(program, vertexShader)); + webgl_util.callAndCheck(gl, () => gl.attachShader(program, fragmentShader)); + webgl_util.linkProgram(gl, program); if (this.debug) { - webgl_util.validateProgram(gl, this.debug, program); + webgl_util.validateProgram(gl, program); } if (!this.vertexAttrsAreBound) { this.setProgram(program); this.vertexAttrsAreBound = gpgpu_util.bindVertexProgramAttributeStreams( - gl, this.debug, this.program, this.vertexBuffer); + gl, this.program, this.vertexBuffer); } return program; } @@ -313,8 +300,7 @@ export class GPGPUContext { this.program = null; } if (program != null) { - webgl_util.callAndCheck( - this.gl, this.debug, () => this.gl.deleteProgram(program)); + webgl_util.callAndCheck(this.gl, () => this.gl.deleteProgram(program)); } } @@ -322,10 +308,9 @@ export class GPGPUContext { this.throwIfDisposed(); this.program = program; if ((this.program != null) && this.debug) { - webgl_util.validateProgram(this.gl, this.debug, this.program); + webgl_util.validateProgram(this.gl, this.program); } - webgl_util.callAndCheck( - this.gl, this.debug, () => this.gl.useProgram(program)); + webgl_util.callAndCheck(this.gl, () => this.gl.useProgram(program)); } public getUniformLocation( @@ -334,7 +319,7 @@ export class GPGPUContext { this.throwIfDisposed(); if (shouldThrow) { return webgl_util.getProgramUniformLocationOrThrow( - this.gl, this.debug, program, uniformName); + this.gl, program, uniformName); } else { return webgl_util.getProgramUniformLocation( this.gl, program, uniformName); @@ -345,8 +330,7 @@ export class GPGPUContext { number { this.throwIfDisposed(); return webgl_util.callAndCheck( - this.gl, this.debug, - () => this.gl.getAttribLocation(program, attribute)); + this.gl, () => this.gl.getAttribLocation(program, attribute)); } public getUniformLocationNoThrow(program: WebGLProgram, uniformName: string): @@ -361,8 +345,7 @@ export class GPGPUContext { this.throwIfDisposed(); this.throwIfNoProgram(); webgl_util.bindTextureToProgramUniformSampler( - this.gl, this.debug, this.program, inputMatrixTexture, uniformLocation, - textureUnit); + this.gl, inputMatrixTexture, uniformLocation, textureUnit); } public setOutputMatrixTexture( @@ -393,7 +376,7 @@ export class GPGPUContext { public debugValidate() { if (this.program != null) { - webgl_util.validateProgram(this.gl, this.debug, this.program); + webgl_util.validateProgram(this.gl, this.program); } webgl_util.validateFramebuffer(this.gl); } @@ -406,13 +389,12 @@ export class GPGPUContext { this.debugValidate(); } webgl_util.callAndCheck( - gl, this.debug, - () => gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0)); + gl, () => gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0)); } public blockUntilAllProgramsCompleted() { this.throwIfDisposed(); - webgl_util.callAndCheck(this.gl, this.debug, () => this.gl.finish()); + webgl_util.callAndCheck(this.gl, () => this.gl.finish()); } private getQueryTimerExtension(): WebGL1DisjointQueryTimerExtension @@ -420,7 +402,7 @@ export class GPGPUContext { if (this.disjointQueryTimerExtension == null) { this.disjointQueryTimerExtension = webgl_util.getExtensionOrThrow( - this.gl, this.debug, + this.gl, env().getNumber( 'WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ? 'EXT_disjoint_timer_query_webgl2' : @@ -564,7 +546,7 @@ export class GPGPUContext { private bindTextureToFrameBuffer(texture: WebGLTexture) { this.throwIfDisposed(); webgl_util.bindColorTextureToFramebuffer( - this.gl, this.debug, texture, this.framebuffer); + this.gl, texture, this.framebuffer); if (this.debug) { webgl_util.validateFramebuffer(this.gl); } @@ -573,13 +555,12 @@ export class GPGPUContext { private unbindTextureToFrameBuffer() { if (this.outputTexture != null) { webgl_util.bindColorTextureToFramebuffer( - this.gl, this.debug, this.outputTexture, this.framebuffer); + this.gl, this.outputTexture, this.framebuffer); if (this.debug) { webgl_util.validateFramebuffer(this.gl); } } else { - webgl_util.unbindColorTextureFromFramebuffer( - this.gl, this.debug, this.framebuffer); + webgl_util.unbindColorTextureFromFramebuffer(this.gl, this.framebuffer); } } @@ -599,22 +580,20 @@ export class GPGPUContext { this.throwIfDisposed(); const gl = this.gl; webgl_util.bindColorTextureToFramebuffer( - gl, this.debug, outputMatrixTextureMaybePacked, this.framebuffer); + gl, outputMatrixTextureMaybePacked, this.framebuffer); if (this.debug) { webgl_util.validateFramebuffer(gl); } this.outputTexture = outputMatrixTextureMaybePacked; - webgl_util.callAndCheck( - gl, this.debug, () => gl.viewport(0, 0, width, height)); - webgl_util.callAndCheck( - gl, this.debug, () => gl.scissor(0, 0, width, height)); + webgl_util.callAndCheck(gl, () => gl.viewport(0, 0, width, height)); + webgl_util.callAndCheck(gl, () => gl.scissor(0, 0, width, height)); } private setOutputMatrixWriteRegionDriver( x: number, y: number, width: number, height: number) { this.throwIfDisposed(); webgl_util.callAndCheck( - this.gl, this.debug, () => this.gl.scissor(x, y, width, height)); + this.gl, () => this.gl.scissor(x, y, width, height)); } private throwIfDisposed() { diff --git a/tfjs-backend-webgl/src/gpgpu_util.ts b/tfjs-backend-webgl/src/gpgpu_util.ts index 17bec6b2042..536fca90fd1 100644 --- a/tfjs-backend-webgl/src/gpgpu_util.ts +++ b/tfjs-backend-webgl/src/gpgpu_util.ts @@ -22,8 +22,7 @@ import * as tex_util from './tex_util'; import {TextureConfig} from './tex_util'; import * as webgl_util from './webgl_util'; -export function createVertexShader( - gl: WebGLRenderingContext, debug: boolean): WebGLShader { +export function createVertexShader(gl: WebGLRenderingContext): WebGLShader { const glsl = getGlslDifferences(); const vertexShaderSource = `${glsl.version} precision highp float; @@ -35,124 +34,116 @@ export function createVertexShader( gl_Position = vec4(clipSpacePos, 1); resultUV = uv; }`; - return webgl_util.createVertexShader(gl, debug, vertexShaderSource); + return webgl_util.createVertexShader(gl, vertexShaderSource); } -export function createVertexBuffer( - gl: WebGLRenderingContext, debug: boolean): WebGLBuffer { +export function createVertexBuffer(gl: WebGLRenderingContext): WebGLBuffer { // [x y z u v] * [upper-left, lower-left, upper-right, lower-right] const vertexArray = new Float32Array( [-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]); - return webgl_util.createStaticVertexBuffer(gl, debug, vertexArray); + return webgl_util.createStaticVertexBuffer(gl, vertexArray); } -export function createIndexBuffer( - gl: WebGLRenderingContext, debug: boolean): WebGLBuffer { +export function createIndexBuffer(gl: WebGLRenderingContext): WebGLBuffer { // OpenGL (and WebGL) have "CCW == front" winding const triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]); - return webgl_util.createStaticIndexBuffer(gl, debug, triangleVertexIndices); + return webgl_util.createStaticIndexBuffer(gl, triangleVertexIndices); } function createAndConfigureTexture( - gl: WebGLRenderingContext, debug: boolean, width: number, height: number, + gl: WebGLRenderingContext, width: number, height: number, internalFormat: number, textureFormat: number, textureType: number): WebGLTexture { webgl_util.validateTextureSize(width, height); - const texture = webgl_util.createTexture(gl, debug); + const texture = webgl_util.createTexture(gl); const tex2d = gl.TEXTURE_2D; - webgl_util.callAndCheck(gl, debug, () => gl.bindTexture(tex2d, texture)); + webgl_util.callAndCheck(gl, () => gl.bindTexture(tex2d, texture)); webgl_util.callAndCheck( - gl, debug, - () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE)); + gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE)); webgl_util.callAndCheck( - gl, debug, - () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE)); + gl, () => gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE)); webgl_util.callAndCheck( - gl, debug, - () => gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST)); + gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST)); webgl_util.callAndCheck( - gl, debug, - () => gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST)); + gl, () => gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST)); webgl_util.callAndCheck( - gl, debug, + gl, () => gl.texImage2D( tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null)); - webgl_util.callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, null)); + webgl_util.callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null)); return texture; } export function createFloat32MatrixTexture( - gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number, + gl: WebGLRenderingContext, rows: number, columns: number, textureConfig: TextureConfig): WebGLTexture { const [width, height] = tex_util.getUnpackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture( - gl, debug, width, height, textureConfig.internalFormatFloat, + gl, width, height, textureConfig.internalFormatFloat, textureConfig.textureFormatFloat, gl.FLOAT); } export function createFloat16MatrixTexture( - gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number, + gl: WebGLRenderingContext, rows: number, columns: number, textureConfig: TextureConfig): WebGLTexture { const [width, height] = tex_util.getUnpackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture( - gl, debug, width, height, textureConfig.internalFormatHalfFloat, + gl, width, height, textureConfig.internalFormatHalfFloat, textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat); } export function createUnsignedBytesMatrixTexture( - gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number, + gl: WebGLRenderingContext, rows: number, columns: number, textureConfig: TextureConfig): WebGLTexture { const [width, height] = tex_util.getUnpackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture( - gl, debug, width, height, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE); + gl, width, height, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE); } export function createPackedMatrixTexture( - gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number, + gl: WebGLRenderingContext, rows: number, columns: number, textureConfig: TextureConfig): WebGLTexture { const [width, height] = tex_util.getPackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture( - gl, debug, width, height, textureConfig.internalFormatPackedFloat, - gl.RGBA, gl.FLOAT); + gl, width, height, textureConfig.internalFormatPackedFloat, gl.RGBA, + gl.FLOAT); } export function createFloat16PackedMatrixTexture( - gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number, + gl: WebGLRenderingContext, rows: number, columns: number, textureConfig: TextureConfig): WebGLTexture { const [width, height] = tex_util.getPackedMatrixTextureShapeWidthHeight(rows, columns); return createAndConfigureTexture( - gl, debug, width, height, textureConfig.internalFormatPackedHalfFloat, - gl.RGBA, textureConfig.textureTypeHalfFloat); + gl, width, height, textureConfig.internalFormatPackedHalfFloat, gl.RGBA, + textureConfig.textureTypeHalfFloat); } export function bindVertexProgramAttributeStreams( - gl: WebGLRenderingContext, debug: boolean, program: WebGLProgram, + gl: WebGLRenderingContext, program: WebGLProgram, vertexBuffer: WebGLBuffer): boolean { const posOffset = 0; // x is the first buffer element const uvOffset = 3 * 4; // uv comes after [x y z] const stride = (3 * 4) + (2 * 4); // xyz + uv, each entry is 4-byte float. webgl_util.callAndCheck( - gl, debug, () => gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer)); + gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer)); const success = webgl_util.bindVertexBufferToProgramAttribute( - gl, debug, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset); + gl, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset); return success && webgl_util.bindVertexBufferToProgramAttribute( - gl, debug, program, 'uv', vertexBuffer, 2, stride, uvOffset); + gl, program, 'uv', vertexBuffer, 2, stride, uvOffset); } export function uploadDenseMatrixToTexture( - gl: WebGLRenderingContext, debug: boolean, texture: WebGLTexture, - width: number, height: number, data: TypedArray, - textureConfig: TextureConfig) { - webgl_util.callAndCheck( - gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, texture)); + gl: WebGLRenderingContext, texture: WebGLTexture, width: number, + height: number, data: TypedArray, textureConfig: TextureConfig) { + webgl_util.callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture)); let dataForUpload: TypedArray, texelDataType: number, internalFormat: number; if (data instanceof Uint8Array) { @@ -168,45 +159,44 @@ export function uploadDenseMatrixToTexture( dataForUpload.set(data); webgl_util.callAndCheck( - gl, debug, + gl, () => gl.texImage2D( gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload)); - webgl_util.callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, null)); + webgl_util.callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null)); } export function uploadPixelDataToTexture( - gl: WebGLRenderingContext, debug: boolean, texture: WebGLTexture, + gl: WebGLRenderingContext, texture: WebGLTexture, pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement| HTMLVideoElement) { - webgl_util.callAndCheck( - gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, texture)); + webgl_util.callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture)); if ((pixels as PixelData).data instanceof Uint8Array) { webgl_util.callAndCheck( - gl, debug, + gl, () => gl.texImage2D( gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, (pixels as PixelData).data)); } else { webgl_util.callAndCheck( - gl, debug, + gl, () => gl.texImage2D( gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels as ImageData | HTMLImageElement | HTMLCanvasElement | HTMLVideoElement)); } - webgl_util.callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, null)); + webgl_util.callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null)); } export function createBufferFromOutputTexture( - gl2: WebGL2RenderingContext, debug: boolean, rows: number, columns: number, + gl2: WebGL2RenderingContext, rows: number, columns: number, textureConfig: TextureConfig): WebGLBuffer { // Create and bind the buffer. const buffer = gl2.createBuffer(); webgl_util.callAndCheck( - gl2, debug, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer)); + gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer)); // Initialize the buffer to the size of the texture in bytes. const bytesPerFloat = 4; @@ -214,18 +204,17 @@ export function createBufferFromOutputTexture( const bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns; webgl_util.callAndCheck( - gl2, debug, + gl2, () => gl2.bufferData( gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ)); // Enqueue a command on the GPU command queue to copy of texture into the // buffer. webgl_util.callAndCheck( - gl2, debug, - () => gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0)); + gl2, () => gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0)); webgl_util.callAndCheck( - gl2, debug, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null)); + gl2, () => gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null)); return buffer; } @@ -245,7 +234,7 @@ export function downloadFloat32MatrixFromBuffer( } export function downloadByteEncodedFloatMatrixFromOutputTexture( - gl: WebGLRenderingContext, debug: boolean, rows: number, columns: number, + gl: WebGLRenderingContext, rows: number, columns: number, textureConfig: TextureConfig) { const [w, h] = tex_util.getUnpackedMatrixTextureShapeWidthHeight(rows, columns); @@ -255,7 +244,7 @@ export function downloadByteEncodedFloatMatrixFromOutputTexture( tex_util.getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels)); webgl_util.callAndCheck( - gl, debug, + gl, () => gl.readPixels( 0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, downloadTarget)); @@ -283,11 +272,11 @@ export function downloadPackedMatrixFromBuffer( } export function downloadMatrixFromPackedOutputTexture( - gl: WebGLRenderingContext, debug: boolean, physicalRows: number, + gl: WebGLRenderingContext, physicalRows: number, physicalCols: number): Float32Array { const packedRGBA = new Float32Array(physicalRows * physicalCols * 4); webgl_util.callAndCheck( - gl, debug, + gl, () => gl.readPixels( 0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA)); diff --git a/tfjs-backend-webgl/src/gpgpu_util_test.ts b/tfjs-backend-webgl/src/gpgpu_util_test.ts index c53063405d4..40253e9595c 100644 --- a/tfjs-backend-webgl/src/gpgpu_util_test.ts +++ b/tfjs-backend-webgl/src/gpgpu_util_test.ts @@ -62,9 +62,8 @@ describeWithFlags('gpgpu_util createFloat32MatrixTexture', WEBGL_ENVS, () => { it('sets the TEXTURE_WRAP S+T parameters to CLAMP_TO_EDGE', () => { const gpgpu = new GPGPUContext(); const textureConfig = tex_util.getTextureConfig(gpgpu.gl); - const debug = false; - const tex = gpgpu_util.createFloat32MatrixTexture( - gpgpu.gl, debug, 32, 32, textureConfig); + const tex = + gpgpu_util.createFloat32MatrixTexture(gpgpu.gl, 32, 32, textureConfig); gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, tex); expect( gpgpu.gl.getTexParameter(gpgpu.gl.TEXTURE_2D, gpgpu.gl.TEXTURE_WRAP_S)) @@ -80,9 +79,8 @@ describeWithFlags('gpgpu_util createFloat32MatrixTexture', WEBGL_ENVS, () => { it('sets the TEXTURE_[MIN|MAG]_FILTER parameters to NEAREST', () => { const gpgpu = new GPGPUContext(); const textureConfig = tex_util.getTextureConfig(gpgpu.gl); - const debug = false; - const tex = gpgpu_util.createFloat32MatrixTexture( - gpgpu.gl, debug, 32, 32, textureConfig); + const tex = + gpgpu_util.createFloat32MatrixTexture(gpgpu.gl, 32, 32, textureConfig); gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, tex); expect(gpgpu.gl.getTexParameter( gpgpu.gl.TEXTURE_2D, gpgpu.gl.TEXTURE_MIN_FILTER)) @@ -100,9 +98,8 @@ describeWithFlags('gpgpu_util createPackedMatrixTexture', WEBGL_ENVS, () => { it('sets the TEXTURE_WRAP S+T parameters to CLAMP_TO_EDGE', () => { const gpgpu = new GPGPUContext(); const textureConfig = tex_util.getTextureConfig(gpgpu.gl); - const debug = false; - const tex = gpgpu_util.createPackedMatrixTexture( - gpgpu.gl, debug, 32, 32, textureConfig); + const tex = + gpgpu_util.createPackedMatrixTexture(gpgpu.gl, 32, 32, textureConfig); gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, tex); expect( gpgpu.gl.getTexParameter(gpgpu.gl.TEXTURE_2D, gpgpu.gl.TEXTURE_WRAP_S)) @@ -118,9 +115,8 @@ describeWithFlags('gpgpu_util createPackedMatrixTexture', WEBGL_ENVS, () => { it('sets the TEXTURE_[MIN|MAG]_FILTER parameters to NEAREST', () => { const gpgpu = new GPGPUContext(); const textureConfig = tex_util.getTextureConfig(gpgpu.gl); - const debug = false; - const tex = gpgpu_util.createPackedMatrixTexture( - gpgpu.gl, debug, 32, 32, textureConfig); + const tex = + gpgpu_util.createPackedMatrixTexture(gpgpu.gl, 32, 32, textureConfig); gpgpu.gl.bindTexture(gpgpu.gl.TEXTURE_2D, tex); expect(gpgpu.gl.getTexParameter( gpgpu.gl.TEXTURE_2D, gpgpu.gl.TEXTURE_MIN_FILTER)) diff --git a/tfjs-backend-webgl/src/webgl_util.ts b/tfjs-backend-webgl/src/webgl_util.ts index 62369a7b177..90870bfff2b 100644 --- a/tfjs-backend-webgl/src/webgl_util.ts +++ b/tfjs-backend-webgl/src/webgl_util.ts @@ -20,10 +20,9 @@ import {env, util} from '@tensorflow/tfjs-core'; import {getWebGLContext} from './canvas_util'; import {getTextureConfig} from './tex_util'; -export function callAndCheck( - gl: WebGLRenderingContext, debugMode: boolean, func: () => T): T { +export function callAndCheck(gl: WebGLRenderingContext, func: () => T): T { const returnValue = func(); - if (debugMode) { + if (env().getBool('DEBUG')) { checkWebGLError(gl); } return returnValue; @@ -71,21 +70,19 @@ export function getWebGLErrorMessage( } export function getExtensionOrThrow( - gl: WebGLRenderingContext, debug: boolean, extensionName: string): {} { + gl: WebGLRenderingContext, extensionName: string): {} { return throwIfNull<{}>( - gl, debug, () => gl.getExtension(extensionName), + gl, () => gl.getExtension(extensionName), 'Extension "' + extensionName + '" not supported on this browser.'); } export function createVertexShader( - gl: WebGLRenderingContext, debug: boolean, - vertexShaderSource: string): WebGLShader { + gl: WebGLRenderingContext, vertexShaderSource: string): WebGLShader { const vertexShader: WebGLShader = throwIfNull( - gl, debug, () => gl.createShader(gl.VERTEX_SHADER), + gl, () => gl.createShader(gl.VERTEX_SHADER), 'Unable to create vertex WebGLShader.'); - callAndCheck( - gl, debug, () => gl.shaderSource(vertexShader, vertexShaderSource)); - callAndCheck(gl, debug, () => gl.compileShader(vertexShader)); + callAndCheck(gl, () => gl.shaderSource(vertexShader, vertexShaderSource)); + callAndCheck(gl, () => gl.compileShader(vertexShader)); if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) { console.log(gl.getShaderInfoLog(vertexShader)); throw new Error('Failed to compile vertex shader.'); @@ -94,14 +91,12 @@ export function createVertexShader( } export function createFragmentShader( - gl: WebGLRenderingContext, debug: boolean, - fragmentShaderSource: string): WebGLShader { + gl: WebGLRenderingContext, fragmentShaderSource: string): WebGLShader { const fragmentShader: WebGLShader = throwIfNull( - gl, debug, () => gl.createShader(gl.FRAGMENT_SHADER), + gl, () => gl.createShader(gl.FRAGMENT_SHADER), 'Unable to create fragment WebGLShader.'); - callAndCheck( - gl, debug, () => gl.shaderSource(fragmentShader, fragmentShaderSource)); - callAndCheck(gl, debug, () => gl.compileShader(fragmentShader)); + callAndCheck(gl, () => gl.shaderSource(fragmentShader, fragmentShaderSource)); + callAndCheck(gl, () => gl.compileShader(fragmentShader)); if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) { logShaderSourceAndInfoLog( fragmentShaderSource, gl.getShaderInfoLog(fragmentShader)); @@ -144,15 +139,13 @@ function logShaderSourceAndInfoLog( console.log(afterErrorLines.join('\n')); } -export function createProgram( - gl: WebGLRenderingContext, debug: boolean): WebGLProgram { +export function createProgram(gl: WebGLRenderingContext): WebGLProgram { return throwIfNull( - gl, debug, () => gl.createProgram(), 'Unable to create WebGLProgram.'); + gl, () => gl.createProgram(), 'Unable to create WebGLProgram.'); } -export function linkProgram( - gl: WebGLRenderingContext, debug: boolean, program: WebGLProgram) { - callAndCheck(gl, debug, () => gl.linkProgram(program)); +export function linkProgram(gl: WebGLRenderingContext, program: WebGLProgram) { + callAndCheck(gl, () => gl.linkProgram(program)); if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) { console.log(gl.getProgramInfoLog(program)); throw new Error('Failed to link vertex and fragment shaders.'); @@ -160,8 +153,8 @@ export function linkProgram( } export function validateProgram( - gl: WebGLRenderingContext, debug: boolean, program: WebGLProgram) { - callAndCheck(gl, debug, () => gl.validateProgram(program)); + gl: WebGLRenderingContext, program: WebGLProgram) { + callAndCheck(gl, () => gl.validateProgram(program)); if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) { console.log(gl.getProgramInfoLog(program)); throw new Error('Shader program validation failed.'); @@ -169,24 +162,21 @@ export function validateProgram( } export function createStaticVertexBuffer( - gl: WebGLRenderingContext, debug: boolean, - data: Float32Array): WebGLBuffer { + gl: WebGLRenderingContext, data: Float32Array): WebGLBuffer { const buffer: WebGLBuffer = throwIfNull( - gl, debug, () => gl.createBuffer(), 'Unable to create WebGLBuffer'); - callAndCheck(gl, debug, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer)); - callAndCheck( - gl, debug, () => gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW)); + gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer'); + callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer)); + callAndCheck(gl, () => gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW)); return buffer; } export function createStaticIndexBuffer( - gl: WebGLRenderingContext, debug: boolean, data: Uint16Array): WebGLBuffer { + gl: WebGLRenderingContext, data: Uint16Array): WebGLBuffer { const buffer: WebGLBuffer = throwIfNull( - gl, debug, () => gl.createBuffer(), 'Unable to create WebGLBuffer'); - callAndCheck(gl, debug, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer)); + gl, () => gl.createBuffer(), 'Unable to create WebGLBuffer'); + callAndCheck(gl, () => gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer)); callAndCheck( - gl, debug, - () => gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW)); + gl, () => gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW)); return buffer; } @@ -197,10 +187,9 @@ export function getNumChannels(): number { return 4; } -export function createTexture( - gl: WebGLRenderingContext, debug: boolean): WebGLTexture { +export function createTexture(gl: WebGLRenderingContext): WebGLTexture { return throwIfNull( - gl, debug, () => gl.createTexture(), 'Unable to create WebGLTexture.'); + gl, () => gl.createTexture(), 'Unable to create WebGLTexture.'); } export function validateTextureSize(width: number, height: number) { @@ -218,53 +207,50 @@ export function validateTextureSize(width: number, height: number) { } } -export function createFramebuffer( - gl: WebGLRenderingContext, debug: boolean): WebGLFramebuffer { +export function createFramebuffer(gl: WebGLRenderingContext): WebGLFramebuffer { return throwIfNull( - gl, debug, () => gl.createFramebuffer(), - 'Unable to create WebGLFramebuffer.'); + gl, () => gl.createFramebuffer(), 'Unable to create WebGLFramebuffer.'); } export function bindVertexBufferToProgramAttribute( - gl: WebGLRenderingContext, debug: boolean, program: WebGLProgram, - attribute: string, buffer: WebGLBuffer, arrayEntriesPerItem: number, - itemStrideInBytes: number, itemOffsetInBytes: number): boolean { + gl: WebGLRenderingContext, program: WebGLProgram, attribute: string, + buffer: WebGLBuffer, arrayEntriesPerItem: number, itemStrideInBytes: number, + itemOffsetInBytes: number): boolean { const loc = gl.getAttribLocation(program, attribute); if (loc === -1) { // The GPU compiler decided to strip out this attribute because it's unused, // thus no need to bind. return false; } - callAndCheck(gl, debug, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer)); + callAndCheck(gl, () => gl.bindBuffer(gl.ARRAY_BUFFER, buffer)); callAndCheck( - gl, debug, + gl, () => gl.vertexAttribPointer( loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes)); - callAndCheck(gl, debug, () => gl.enableVertexAttribArray(loc)); + callAndCheck(gl, () => gl.enableVertexAttribArray(loc)); return true; } export function bindTextureUnit( - gl: WebGLRenderingContext, debug: boolean, texture: WebGLTexture, - textureUnit: number) { + gl: WebGLRenderingContext, texture: WebGLTexture, textureUnit: number) { validateTextureUnit(gl, textureUnit); - callAndCheck(gl, debug, () => gl.activeTexture(gl.TEXTURE0 + textureUnit)); - callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, texture)); + callAndCheck(gl, () => gl.activeTexture(gl.TEXTURE0 + textureUnit)); + callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, texture)); } export function unbindTextureUnit( - gl: WebGLRenderingContext, debug: boolean, textureUnit: number) { + gl: WebGLRenderingContext, textureUnit: number) { validateTextureUnit(gl, textureUnit); - callAndCheck(gl, debug, () => gl.activeTexture(gl.TEXTURE0 + textureUnit)); - callAndCheck(gl, debug, () => gl.bindTexture(gl.TEXTURE_2D, null)); + callAndCheck(gl, () => gl.activeTexture(gl.TEXTURE0 + textureUnit)); + callAndCheck(gl, () => gl.bindTexture(gl.TEXTURE_2D, null)); } export function getProgramUniformLocationOrThrow( - gl: WebGLRenderingContext, debug: boolean, program: WebGLProgram, + gl: WebGLRenderingContext, program: WebGLProgram, uniformName: string): WebGLUniformLocation { return throwIfNull( - gl, debug, () => gl.getUniformLocation(program, uniformName), + gl, () => gl.getUniformLocation(program, uniformName), 'uniform "' + uniformName + '" not present in program.'); } @@ -275,41 +261,33 @@ export function getProgramUniformLocation( } export function bindTextureToProgramUniformSampler( - gl: WebGLRenderingContext, debug: boolean, program: WebGLProgram, - texture: WebGLTexture, uniformSamplerLocation: WebGLUniformLocation, - textureUnit: number) { - callAndCheck( - gl, debug, () => bindTextureUnit(gl, debug, texture, textureUnit)); - callAndCheck( - gl, debug, () => gl.uniform1i(uniformSamplerLocation, textureUnit)); + gl: WebGLRenderingContext, texture: WebGLTexture, + uniformSamplerLocation: WebGLUniformLocation, textureUnit: number) { + callAndCheck(gl, () => bindTextureUnit(gl, texture, textureUnit)); + callAndCheck(gl, () => gl.uniform1i(uniformSamplerLocation, textureUnit)); } -export function bindCanvasToFramebuffer( - gl: WebGLRenderingContext, debug: boolean) { - callAndCheck(gl, debug, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null)); - callAndCheck( - gl, debug, () => gl.viewport(0, 0, gl.canvas.width, gl.canvas.height)); - callAndCheck( - gl, debug, () => gl.scissor(0, 0, gl.canvas.width, gl.canvas.height)); +export function bindCanvasToFramebuffer(gl: WebGLRenderingContext) { + callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, null)); + callAndCheck(gl, () => gl.viewport(0, 0, gl.canvas.width, gl.canvas.height)); + callAndCheck(gl, () => gl.scissor(0, 0, gl.canvas.width, gl.canvas.height)); } export function bindColorTextureToFramebuffer( - gl: WebGLRenderingContext, debug: boolean, texture: WebGLTexture, + gl: WebGLRenderingContext, texture: WebGLTexture, framebuffer: WebGLFramebuffer) { + callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer)); callAndCheck( - gl, debug, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer)); - callAndCheck( - gl, debug, + gl, () => gl.framebufferTexture2D( gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0)); } export function unbindColorTextureFromFramebuffer( - gl: WebGLRenderingContext, debug: boolean, framebuffer: WebGLFramebuffer) { - callAndCheck( - gl, debug, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer)); + gl: WebGLRenderingContext, framebuffer: WebGLFramebuffer) { + callAndCheck(gl, () => gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer)); callAndCheck( - gl, debug, + gl, () => gl.framebufferTexture2D( gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0)); } @@ -339,9 +317,9 @@ export function getFramebufferErrorMessage( } function throwIfNull( - gl: WebGLRenderingContext, debug: boolean, returnTOrNull: () => T | null, + gl: WebGLRenderingContext, returnTOrNull: () => T | null, failureMessage: string): T { - const tOrNull: T|null = callAndCheck(gl, debug, () => returnTOrNull()); + const tOrNull: T|null = callAndCheck(gl, () => returnTOrNull()); if (tOrNull == null) { throw new Error(failureMessage); } diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index a61cb9c78b7..5ce901cf6e9 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -1171,8 +1171,8 @@ export class WebGPUBackend extends KernelBackend { paddedX.shape, blockShape, prod, false); return paddedX.reshape(reshapedPaddedShape) - .transpose(permutedReshapedPaddedPermutation) - .reshape(flattenShape); + .transpose(permutedReshapedPaddedPermutation) + .reshape(flattenShape); } batchMatMul( diff --git a/tfjs-core/src/ops/tensor_ops.ts b/tfjs-core/src/ops/tensor_ops.ts index 1255a1ac7e8..6248a96c7f0 100644 --- a/tfjs-core/src/ops/tensor_ops.ts +++ b/tfjs-core/src/ops/tensor_ops.ts @@ -16,7 +16,6 @@ */ import {ENGINE} from '../engine'; -import {env} from '../environment'; import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, Tensor6D, Variable} from '../tensor'; import {convertToTensor, inferShape} from '../tensor_util_env'; import {TensorLike, TensorLike1D, TensorLike2D, TensorLike3D, TensorLike4D, TensorLike5D, TensorLike6D, TypedArray} from '../types'; @@ -110,7 +109,7 @@ function makeTensor( shape = shape || inferredShape; values = dtype !== 'string' ? - toTypedArray(values, dtype, env().getBool('DEBUG')) : + toTypedArray(values, dtype) : flatten(values as string[], [], true) as string[]; return ENGINE.makeTensor(values as TypedArray, shape, dtype); } diff --git a/tfjs-core/src/tensor_util_env.ts b/tfjs-core/src/tensor_util_env.ts index 46bd8d96a02..381f0df878b 100644 --- a/tfjs-core/src/tensor_util_env.ts +++ b/tfjs-core/src/tensor_util_env.ts @@ -113,7 +113,7 @@ export function convertToTensor( } const skipTypedArray = true; const values = inferredDtype !== 'string' ? - toTypedArray(x, inferredDtype as DataType, env().getBool('DEBUG')) : + toTypedArray(x, inferredDtype as DataType) : flatten(x as string[], [], skipTypedArray) as string[]; return ENGINE.makeTensor(values, inferredShape, inferredDtype) as T; } diff --git a/tfjs-core/src/util.ts b/tfjs-core/src/util.ts index 4b3b5641132..b5ead6f9150 100644 --- a/tfjs-core/src/util.ts +++ b/tfjs-core/src/util.ts @@ -548,15 +548,15 @@ export function computeStrides(shape: number[]): number[] { return strides; } -export function toTypedArray( - a: TensorLike, dtype: DataType, debugMode: boolean): TypedArray { +export function toTypedArray(a: TensorLike, dtype: DataType): TypedArray { if (dtype === 'string') { throw new Error('Cannot convert a string[] to a TypedArray'); } if (Array.isArray(a)) { a = flatten(a); } - if (debugMode) { + + if (env().getBool('DEBUG')) { checkConversionForErrors(a as number[], dtype); } if (noConversionNeeded(a, dtype)) {