diff --git a/tfjs-backend-webgpu/src/binary_op_util.ts b/tfjs-backend-webgpu/src/binary_op_util.ts index 915acd7d51d..bfdde48a331 100644 --- a/tfjs-backend-webgpu/src/binary_op_util.ts +++ b/tfjs-backend-webgpu/src/binary_op_util.ts @@ -119,7 +119,6 @@ const MOD = ` `; const MOD_VEC4 = ` let isNaN = !vec4(b); - let valueForNaN = uniforms.NAN; var resultTemp = vec4(a % b); ${CHECK_NAN_SNIPPET_VEC4} @@ -186,7 +185,6 @@ const POW_VEC4 = ` resultTemp.a = 1.0; } let isNaN = (a < vec4(0.0)) & (floor(b) < b); - let valueForNaN = uniforms.NAN; ${CHECK_NAN_SNIPPET_VEC4} return resultTemp; `; @@ -199,11 +197,9 @@ const PRELU_VEC4 = ` const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);'; const SUB = 'return a - b;'; -function getBinaryWithNanString( - op: string, useVec4: boolean, valueForNaN = 'uniforms.NAN') { +function getBinaryWithNanString(op: string, useVec4: boolean) { const checkNanSnippet = useVec4 ? CHECK_NAN_SNIPPET_VEC4 : CHECK_NAN_SNIPPET; return useVec4 ? ` - let valueForNaN = ${valueForNaN}; var resultTemp = vec4(${op}(a, b)); ` + checkNanSnippet + ` diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index 96aac7966af..a3db4577148 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -88,6 +88,7 @@ export class BinaryOpProgram implements WebGPUProgram { const opFnStr = ` fn binaryOperation(a : ${dType}, b : ${dType}) -> ${dType} { let isNaN = false; + let valueForNaN = uniforms.NAN; { ${getBinaryOpString(this.op, this.isVec4)} }