diff --git a/tfjs-backend-webgpu/src/binary_op_util.ts b/tfjs-backend-webgpu/src/binary_op_util.ts index f6ca1048231..add30e2052e 100644 --- a/tfjs-backend-webgpu/src/binary_op_util.ts +++ b/tfjs-backend-webgpu/src/binary_op_util.ts @@ -42,9 +42,7 @@ export enum BinaryOpType { } const CHECK_NAN_SNIPPET = ` - if (isnan(a)) { return a; } - if (isnan(b)) { return b; } - `; + resultTemp = select(resultTemp, valueForNaN, isNaN | isnan(a) | isnan(b));`; const CHECK_NAN_SNIPPET_VEC4 = ` resultTemp = select( @@ -53,6 +51,7 @@ const CHECK_NAN_SNIPPET_VEC4 = ` `; const ADD = 'return a + b;'; +const ATAN2 = 'var resultTemp = atan2(a, b);'; // (Ar + Ai)(Br + Bi) = // ArBr + ArBi + AiBr + AiBi = ArBr - AB + ArBi + AiBr // Yr = ArBr - AB @@ -109,23 +108,17 @@ const LOGICAL_AND_VEC4 = `return (vec4(a >= vec4(1.0)) * const LOGICAL_OR = 'return f32(a >= 1.0 || b >= 1.0);'; const LOGICAL_OR_VEC4 = `return min(vec4(a >= vec4(1.0)) + vec4(b >= vec4(1.0)), vec4(1.0));`; +const MAX = 'var resultTemp = max(a, b);'; +const MIN = 'var resultTemp = min(a, b);'; const MOD = ` - ${CHECK_NAN_SNIPPET} - if (b == 0.) { - return uniforms.NAN; - } + let isNaN = b == 0.; var resultTemp = a % b; - if ((a < 0. && b < 0.) || (a >= 0. && b > 0.)) { - return resultTemp; - } else { - return (resultTemp + b) % b; - } + resultTemp = select((resultTemp + b) % b, resultTemp, + (a < 0. && b < 0.) || (a >= 0. && b > 0.)); `; const MOD_VEC4 = ` let isNaN = !vec4(b); var resultTemp = vec4(a % b); - ${CHECK_NAN_SNIPPET_VEC4} - if (!((a[0] < 0. && b[0] < 0.) || (a[0] >= 0. && b[0] > 0.))) { resultTemp[0] = (resultTemp[0] + b[0]) % b[0]; } @@ -138,35 +131,24 @@ const MOD_VEC4 = ` if (!((a[3] < 0. && b[3] < 0.) || (a[3] >= 0. && b[3] > 0.))) { resultTemp[3] = (resultTemp[3] + b[3]) % b[3]; } - - return resultTemp; `; const MUL = 'return a * b;'; const NOT_EQUAL = ` - if (isnan(a) || isnan(b)) { - return 1.0; - } - return f32(a != b); + var resultTemp = f32(a != b); + let valueForNaN = 1.0; `; const NOT_EQUAL_VEC4 = ` var resultTemp = vec4(a != b); let valueForNaN = 1.0; - ${CHECK_NAN_SNIPPET_VEC4} - - return resultTemp; `; const POW = ` - if(a < 0.0 && floor(b) < b) { - return uniforms.NAN; - } + let isNaN = a < 0.0 && floor(b) < b; if (b == 0.0) { return 1.0; } - if (round(abs(b) % 2.0) != 1.0) { - return pow(abs(a), b); - } - return sign(a) * pow(abs(a), b); + var resultTemp = select(sign(a) * pow(abs(a), b), pow(abs(a), b), + round(abs(b) % 2.0) != 1.0); `; const POW_VEC4 = ` let isModRound1Bool = vec4(round(abs(b) % vec4(2.0))) == vec4(1); @@ -189,8 +171,6 @@ const POW_VEC4 = ` resultTemp.a = 1.0; } let isNaN = (a < vec4(0.0)) & (floor(b) < b); - ${CHECK_NAN_SNIPPET_VEC4} - return resultTemp; `; const PRELU = `if (a < 0.0) { return b * a; } return a;`; @@ -201,26 +181,48 @@ const PRELU_VEC4 = ` const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);'; const SUB = 'return a - b;'; -function getBinaryWithNanString(op: string, useVec4: boolean) { - const checkNanSnippet = useVec4 ? CHECK_NAN_SNIPPET_VEC4 : CHECK_NAN_SNIPPET; - return useVec4 ? ` - var resultTemp = vec4(${op}(a, b)); - ` + checkNanSnippet + - ` - return resultTemp; - ` : - checkNanSnippet + ` - return ${op}(a, b); - `; -} - export function getBinaryOpString( type: BinaryOpType, useVec4?: boolean): string { + // Ops with NaN check + do { + let doOpSnippet: string; + switch (type) { + case BinaryOpType.ATAN2: + doOpSnippet = ATAN2; + break; + case BinaryOpType.MAX: + doOpSnippet = MAX; + break; + case BinaryOpType.MIN: + doOpSnippet = MIN; + break; + case BinaryOpType.MOD: + doOpSnippet = useVec4 ? MOD_VEC4 : MOD; + break; + case BinaryOpType.NOT_EQUAL: + doOpSnippet = useVec4 ? NOT_EQUAL_VEC4 : NOT_EQUAL; + break; + case BinaryOpType.POW: + doOpSnippet = useVec4 ? POW_VEC4 : POW; + break; + default: + continue; + } + return ` + let isNaN = false; + let valueForNaN = uniforms.NAN; + { + ${doOpSnippet} + ${useVec4 ? CHECK_NAN_SNIPPET_VEC4 : CHECK_NAN_SNIPPET} + return resultTemp; + } + `; + } while (false); + + // Ops without NaN check switch (type) { case BinaryOpType.ADD: return ADD; - case BinaryOpType.ATAN2: - return getBinaryWithNanString('atan2', useVec4); case BinaryOpType.COMPLEX_MULTIPLY_IMAG: return COMPLEX_MULTIPLY_IMAG; case BinaryOpType.COMPLEX_MULTIPLY_REAL: @@ -245,18 +247,8 @@ export function getBinaryOpString( return useVec4 ? LOGICAL_AND_VEC4 : LOGICAL_AND; case BinaryOpType.LOGICAL_OR: return useVec4 ? LOGICAL_OR_VEC4 : LOGICAL_OR; - case BinaryOpType.MAX: - return getBinaryWithNanString('max', useVec4); - case BinaryOpType.MIN: - return getBinaryWithNanString('min', useVec4); - case BinaryOpType.MOD: - return useVec4 ? MOD_VEC4 : MOD; case BinaryOpType.MUL: return MUL; - case BinaryOpType.NOT_EQUAL: - return useVec4 ? NOT_EQUAL_VEC4 : NOT_EQUAL; - case BinaryOpType.POW: - return useVec4 ? POW_VEC4 : POW; case BinaryOpType.PRELU: return useVec4 ? PRELU_VEC4 : PRELU; case BinaryOpType.SQUARED_DIFFERENCE: diff --git a/tfjs-backend-webgpu/src/binary_op_webgpu.ts b/tfjs-backend-webgpu/src/binary_op_webgpu.ts index a3db4577148..891785d30d0 100644 --- a/tfjs-backend-webgpu/src/binary_op_webgpu.ts +++ b/tfjs-backend-webgpu/src/binary_op_webgpu.ts @@ -87,11 +87,7 @@ export class BinaryOpProgram implements WebGPUProgram { const dType = this.isVec4 ? 'vec4' : 'f32'; const opFnStr = ` fn binaryOperation(a : ${dType}, b : ${dType}) -> ${dType} { - let isNaN = false; - let valueForNaN = uniforms.NAN; - { - ${getBinaryOpString(this.op, this.isVec4)} - } + ${getBinaryOpString(this.op, this.isVec4)} }; `;