-
Notifications
You must be signed in to change notification settings - Fork 2k
webgpu: support EluGrad kernel #7197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| const COMPLEX_MULTIPLY_IMAG = 'return areal * bimag + aimag * breal;'; | ||
| const DIV = 'return a / b;'; | ||
| const ELU_DER = ` | ||
| let bGTEZero = f32(b > 0.); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| let bGTEZero = f32(b > 0.); | |
| let bGTEZero = f32(b >= 0.); |
| const DIV = 'return a / b;'; | ||
| const ELU_DER = ` | ||
| let bGTEZero = f32(b > 0.); | ||
| return (bGTEZero * a) + (1.0 - bGTEZero) * (a * (b + 1.0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use select can be simpler
| `; | ||
| const ELU_DER_VEC4 = ` | ||
| let bGTEZero = vec4<f32>(b >= vec4<f32>(0.)); | ||
| return (bGTEZero * a) + (vec4<f32>(1.0) - bGTEZero) * (a * (b + vec4<f32>(1.0))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here, please use select
gyagp
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| const COMPLEX_MULTIPLY_REAL = 'return areal * breal - aimag * bimag;'; | ||
| const COMPLEX_MULTIPLY_IMAG = 'return areal * bimag + aimag * breal;'; | ||
| const DIV = 'return a / b;'; | ||
| const ELU_DER = 'return select(a * (b + 1.0), a, b >= 0.);'; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see webgl is using b >= 1.0 https://github.com/tensorflow/tfjs/blob/master/tfjs-backend-webgl/src/kernels/EluGrad.ts#L24 . However, you are using b >= 0.0. Which one is correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On webgl backend, the non PACKED and PACKED versions are different,
| const ELU_DER_PACKED = ` |
I think the PACKED version is correct, but is not 100% sure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
According to http://bytedeco.org/javacpp-presets/tensorflow/apidocs/org/bytedeco/tensorflow/EluGrad.html, b >= 0. is correct.
Please also fix the WebGL unpacked version, and I think we also need some test cases to cover.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed them and added case in #7251
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Please hold this a bit until #7251 is reviewed.
qjia7
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with one nit. Thanks!
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * @license | |||
| * Copyright 2022 Google LLC. | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: 2022 -> 2023
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is