From e0566396282da86d586a37ff502bd89f54071910 Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Fri, 23 Dec 2022 10:16:40 +0800 Subject: [PATCH] [webgpu] Let vectorized binary op return an overridable value on NaN This is like 2c95a67 [webgpu] Further tweak vectorized NaN handling in binary ops but for `valueForNaN`. --- tfjs-backend-webgpu/src/binary_op_util.ts | 6 +----- tfjs-backend-webgpu/src/binary_op_webgpu.ts | 1 + 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tfjs-backend-webgpu/src/binary_op_util.ts b/tfjs-backend-webgpu/src/binary_op_util.ts index 915acd7d51d..bfdde48a331 100644 --- a/tfjs-backend-webgpu/src/binary_op_util.ts +++ b/tfjs-backend-webgpu/src/binary_op_util.ts @@ -119,7 +119,6 @@ const MOD = ` `; const MOD_VEC4 = ` let isNaN = !vec4(b); - let valueForNaN = uniforms.NAN; var resultTemp = vec4(a % b); ${CHECK_NAN_SNIPPET_VEC4} @@ -186,7 +185,6 @@ const POW_VEC4 = ` resultTemp.a = 1.0; } let isNaN = (a < vec4(0.0)) & (floor(b) < b); - let valueForNaN = uniforms.NAN; ${CHECK_NAN_SNIPPET_VEC4} return resultTemp; `; @@ -199,11 +197,9 @@ const PRELU_VEC4 = ` const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);'; const SUB = 'return a - b;'; -function getBinaryWithNanString( - op: string, useVec4: boolean, valueForNaN = 'uniforms.NAN') { +function getBinaryWithNanString(op: string, useVec4: boolean) { const checkNanSnippet = useVec4 ? CHECK_NAN_SNIPPET_VEC4 : CHECK_NAN_SNIPPET; return useVec4 ? ` - let valueForNaN = ${valueForNaN}; var resultTemp = vec4(${op}(a, b)); ` + checkNanSnippet + ` diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index 96aac7966af..a3db4577148 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -88,6 +88,7 @@ export class BinaryOpProgram implements WebGPUProgram { const opFnStr = ` fn binaryOperation(a : ${dType}, b : ${dType}) -> ${dType} { let isNaN = false; + let valueForNaN = uniforms.NAN; { ${getBinaryOpString(this.op, this.isVec4)} }