Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 49 additions & 57 deletions tfjs-backend-webgpu/src/binary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ export enum BinaryOpType {
}

const CHECK_NAN_SNIPPET = `
if (isnan(a)) { return a; }
if (isnan(b)) { return b; }
`;
resultTemp = select(resultTemp, valueForNaN, isNaN | isnan(a) | isnan(b));`;

const CHECK_NAN_SNIPPET_VEC4 = `
resultTemp = select(
Expand All @@ -53,6 +51,7 @@ const CHECK_NAN_SNIPPET_VEC4 = `
`;

const ADD = 'return a + b;';
const ATAN2 = 'var resultTemp = atan2(a, b);';
// (Ar + Ai)(Br + Bi) =
// ArBr + ArBi + AiBr + AiBi = ArBr - AB + ArBi + AiBr
// Yr = ArBr - AB
Expand Down Expand Up @@ -109,23 +108,17 @@ const LOGICAL_AND_VEC4 = `return (vec4<f32>(a >= vec4<f32>(1.0)) *
const LOGICAL_OR = 'return f32(a >= 1.0 || b >= 1.0);';
const LOGICAL_OR_VEC4 = `return min(vec4<f32>(a >= vec4<f32>(1.0)) +
vec4<f32>(b >= vec4<f32>(1.0)), vec4<f32>(1.0));`;
const MAX = 'var resultTemp = max(a, b);';
const MIN = 'var resultTemp = min(a, b);';
const MOD = `
${CHECK_NAN_SNIPPET}
if (b == 0.) {
return uniforms.NAN;
}
let isNaN = b == 0.;
var resultTemp = a % b;
if ((a < 0. && b < 0.) || (a >= 0. && b > 0.)) {
return resultTemp;
} else {
return (resultTemp + b) % b;
}
resultTemp = select((resultTemp + b) % b, resultTemp,
(a < 0. && b < 0.) || (a >= 0. && b > 0.));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the line breaking rule? Follow clang-format?

`;
const MOD_VEC4 = `
let isNaN = !vec4<bool>(b);
var resultTemp = vec4<f32>(a % b);
${CHECK_NAN_SNIPPET_VEC4}

if (!((a[0] < 0. && b[0] < 0.) || (a[0] >= 0. && b[0] > 0.))) {
resultTemp[0] = (resultTemp[0] + b[0]) % b[0];
}
Expand All @@ -138,35 +131,24 @@ const MOD_VEC4 = `
if (!((a[3] < 0. && b[3] < 0.) || (a[3] >= 0. && b[3] > 0.))) {
resultTemp[3] = (resultTemp[3] + b[3]) % b[3];
}

return resultTemp;
`;
const MUL = 'return a * b;';
const NOT_EQUAL = `
if (isnan(a) || isnan(b)) {
return 1.0;
}
return f32(a != b);
var resultTemp = f32(a != b);
let valueForNaN = 1.0;
`;
const NOT_EQUAL_VEC4 = `
var resultTemp = vec4<f32>(a != b);
let valueForNaN = 1.0;
${CHECK_NAN_SNIPPET_VEC4}

return resultTemp;
`;

const POW = `
if(a < 0.0 && floor(b) < b) {
return uniforms.NAN;
}
let isNaN = a < 0.0 && floor(b) < b;
if (b == 0.0) {
return 1.0;
}
if (round(abs(b) % 2.0) != 1.0) {
return pow(abs(a), b);
}
return sign(a) * pow(abs(a), b);
var resultTemp = select(sign(a) * pow(abs(a), b), pow(abs(a), b),
round(abs(b) % 2.0) != 1.0);
`;
const POW_VEC4 = `
let isModRound1Bool = vec4<i32>(round(abs(b) % vec4<f32>(2.0))) == vec4<i32>(1);
Expand All @@ -189,8 +171,6 @@ const POW_VEC4 = `
resultTemp.a = 1.0;
}
let isNaN = (a < vec4<f32>(0.0)) & (floor(b) < b);
${CHECK_NAN_SNIPPET_VEC4}
return resultTemp;
`;

const PRELU = `if (a < 0.0) { return b * a; } return a;`;
Expand All @@ -201,26 +181,48 @@ const PRELU_VEC4 = `
const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
const SUB = 'return a - b;';

function getBinaryWithNanString(op: string, useVec4: boolean) {
const checkNanSnippet = useVec4 ? CHECK_NAN_SNIPPET_VEC4 : CHECK_NAN_SNIPPET;
return useVec4 ? `
var resultTemp = vec4<f32>(${op}(a, b));
` + checkNanSnippet +
`
return resultTemp;
` :
checkNanSnippet + `
return ${op}(a, b);
`;
}

export function getBinaryOpString(
type: BinaryOpType, useVec4?: boolean): string {
// Ops with NaN check
do {
let doOpSnippet: string;
switch (type) {
case BinaryOpType.ATAN2:
doOpSnippet = ATAN2;
break;
case BinaryOpType.MAX:
doOpSnippet = MAX;
break;
case BinaryOpType.MIN:
doOpSnippet = MIN;
break;
case BinaryOpType.MOD:
doOpSnippet = useVec4 ? MOD_VEC4 : MOD;
break;
case BinaryOpType.NOT_EQUAL:
doOpSnippet = useVec4 ? NOT_EQUAL_VEC4 : NOT_EQUAL;
break;
case BinaryOpType.POW:
doOpSnippet = useVec4 ? POW_VEC4 : POW;
break;
default:
continue;
}
return `
let isNaN = false;
let valueForNaN = uniforms.NAN;
{
${doOpSnippet}
${useVec4 ? CHECK_NAN_SNIPPET_VEC4 : CHECK_NAN_SNIPPET}
return resultTemp;
}
`;
} while (false);

// Ops without NaN check
switch (type) {
case BinaryOpType.ADD:
return ADD;
case BinaryOpType.ATAN2:
return getBinaryWithNanString('atan2', useVec4);
case BinaryOpType.COMPLEX_MULTIPLY_IMAG:
return COMPLEX_MULTIPLY_IMAG;
case BinaryOpType.COMPLEX_MULTIPLY_REAL:
Expand All @@ -245,18 +247,8 @@ export function getBinaryOpString(
return useVec4 ? LOGICAL_AND_VEC4 : LOGICAL_AND;
case BinaryOpType.LOGICAL_OR:
return useVec4 ? LOGICAL_OR_VEC4 : LOGICAL_OR;
case BinaryOpType.MAX:
return getBinaryWithNanString('max', useVec4);
case BinaryOpType.MIN:
return getBinaryWithNanString('min', useVec4);
case BinaryOpType.MOD:
return useVec4 ? MOD_VEC4 : MOD;
case BinaryOpType.MUL:
return MUL;
case BinaryOpType.NOT_EQUAL:
return useVec4 ? NOT_EQUAL_VEC4 : NOT_EQUAL;
case BinaryOpType.POW:
return useVec4 ? POW_VEC4 : POW;
case BinaryOpType.PRELU:
return useVec4 ? PRELU_VEC4 : PRELU;
case BinaryOpType.SQUARED_DIFFERENCE:
Expand Down
6 changes: 1 addition & 5 deletions tfjs-backend-webgpu/src/binary_op_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,7 @@ export class BinaryOpProgram implements WebGPUProgram {
const dType = this.isVec4 ? 'vec4<f32>' : 'f32';
const opFnStr = `
fn binaryOperation(a : ${dType}, b : ${dType}) -> ${dType} {
let isNaN = false;
let valueForNaN = uniforms.NAN;
{
${getBinaryOpString(this.op, this.isVec4)}
}
${getBinaryOpString(this.op, this.isVec4)}
};
`;

Expand Down