Skip to content

Commit

Permalink
[webgpu] More preparation for element-wise binary op restructuring (#…
Browse files Browse the repository at this point in the history
…7666)

The current code isn't great in that the vec4 shaders have diverged from
the scalar ones more than necessary. Here is the common preparation
work, so that following refactoring can be done on a per-op basis.
  • Loading branch information
hujiajie committed Jun 5, 2023
1 parent 04d4a86 commit 281d8f8
Showing 1 changed file with 30 additions and 13 deletions.
43 changes: 30 additions & 13 deletions tfjs-backend-webgpu/src/binary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,6 @@ export enum BinaryOpType {
SUB
}

const CHECK_NAN_SNIPPET = `
resultTemp = select(resultTemp, valueForNaN, isNaN | isnan(a) | isnan(b));`;

const CHECK_NAN_SNIPPET_VEC4 = `
resultTemp = select(
resultTemp, vec4<f32>(valueForNaN),
vec4<bool>(isNaN) | isnanVec4(a) | isnanVec4(b));
`;

const ADD = 'return a + b;';
const ATAN2 = 'var resultTemp = atan2(a, b);';
// (Ar + Ai)(Br + Bi) =
Expand Down Expand Up @@ -183,9 +174,10 @@ const SUB = 'return a - b;';

export function getBinaryOpString(
type: BinaryOpType, useVec4?: boolean): string {
let doOpSnippet: string;

// Ops with NaN check
do {
let doOpSnippet: string;
switch (type) {
case BinaryOpType.ATAN2:
doOpSnippet = ATAN2;
Expand All @@ -208,13 +200,34 @@ export function getBinaryOpString(
default:
continue;
}

let isNaN: string;
let dTypeN: string;
let boolN: string;
if (useVec4) {
isNaN = 'isnanVec4';
dTypeN = 'vec4<f32>';
boolN = 'vec4<bool>';
} else {
isNaN = 'isnan';
dTypeN = 'f32';
boolN = 'bool';
}

return `
let aIsNaN = ${isNaN}(a);
let aPostLegalization = select(a, ${dTypeN}(42), aIsNaN);
let bIsNaN = ${isNaN}(b);
let bPostLegalization = select(b, ${dTypeN}(42), bIsNaN);
let isNaN = false;
let valueForNaN = uniforms.NAN;
{
let a = aPostLegalization;
let b = bPostLegalization;
${doOpSnippet}
${useVec4 ? CHECK_NAN_SNIPPET_VEC4 : CHECK_NAN_SNIPPET}
return resultTemp;
return select(
resultTemp, ${dTypeN}(valueForNaN),
${boolN}(isNaN) | aIsNaN | bIsNaN);
}
`;
} while (false);
Expand Down Expand Up @@ -256,6 +269,10 @@ export function getBinaryOpString(
case BinaryOpType.SUB:
return SUB;
default:
throw new Error(`BinaryType ${type} is not implemented!`);
// throw new Error(`BinaryType ${type} is not implemented!`);
}
return `
${doOpSnippet}
return resultTemp;
`;
}

0 comments on commit 281d8f8

Please sign in to comment.