Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions tfjs-backend-webgpu/src/binary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ const MOD = `
`;
const MOD_VEC4 = `
let isNaN = !vec4<bool>(b);
let valueForNaN = uniforms.NAN;
var resultTemp = vec4<f32>(a % b);
${CHECK_NAN_SNIPPET_VEC4}

Expand Down Expand Up @@ -186,7 +185,6 @@ const POW_VEC4 = `
resultTemp.a = 1.0;
}
let isNaN = (a < vec4<f32>(0.0)) & (floor(b) < b);
let valueForNaN = uniforms.NAN;
${CHECK_NAN_SNIPPET_VEC4}
return resultTemp;
`;
Expand All @@ -199,11 +197,9 @@ const PRELU_VEC4 = `
const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
const SUB = 'return a - b;';

function getBinaryWithNanString(
op: string, useVec4: boolean, valueForNaN = 'uniforms.NAN') {
function getBinaryWithNanString(op: string, useVec4: boolean) {
const checkNanSnippet = useVec4 ? CHECK_NAN_SNIPPET_VEC4 : CHECK_NAN_SNIPPET;
return useVec4 ? `
let valueForNaN = ${valueForNaN};
var resultTemp = vec4<f32>(${op}(a, b));
` + checkNanSnippet +
`
Expand Down
1 change: 1 addition & 0 deletions tfjs-backend-webgpu/src/binary_op_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ export class BinaryOpProgram implements WebGPUProgram {
const opFnStr = `
fn binaryOperation(a : ${dType}, b : ${dType}) -> ${dType} {
let isNaN = false;
let valueForNaN = uniforms.NAN;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently for NOT_EQUAL_VEC4, valueForNaN is 1.0 instead of uniforms.NAN. Any specific reason? @hujiajie @qjia7

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The op returns booleans, even for NaN operands. I think that's by design.

(Well, the doc says nothing about this, it's just explicitly checked in the unit tests. Another mess is that currently this is defined in NOT_EQUAL, but not in EQUAL!)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So do you think if uniforms.NAN is good enough for a default value? I'm not sure if we have more cases like (NOT_)EQUAL.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think so.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't see any reason that we use let valueForNaN = 1.0; for NOT_EQUAL_VEC4. For WebGL, there is no NaN check for NotEqual.
So my opinion is we directly use uniforms.NAN in CHECK_NAN_SNIPPET_VEC4. And don't need to define an additional variable valueForNaN.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. WGSL implementations may assume that NaNs and infinities are not present at runtime. In such an implementation, when an expression evaluation would produce an infinity or a NaN, an indeterminate value of the target type is produced instead. So, I feel explicit checks are necessary.
  2. We want i32(valueForNaN) to evaluate to 1/true later, but i32(uniforms.NAN) is 0/false.

{
${getBinaryOpString(this.op, this.isVec4)}
}
Expand Down