webgpu: add a naive implementation of conv3d #7016
3f7772d to d69431f
gyagp left a comment
LGTM with some nits. Overall, we should use more WGSL syntactic sugar, such as ++ and +=.
let inputDepthVec4Remainder = uniforms.xShape.u % 4;
var dotProd = 0.0;
for (var wF = 0; wF < uniforms.filterDims[0]; wF = wF + 1) {
Suggested change:
- for (var wF = 0; wF < uniforms.filterDims[0]; wF = wF + 1) {
+ for (var wF = 0; wF < uniforms.filterDims[0]; wF++) {
continue;
}
for (var wR = 0; wR < uniforms.filterDims[1]; wR = wR + 1) {
Use ++
continue;
}
for (var wC = 0; wC < uniforms.filterDims[2]; wC = wC + 1) {
Use ++
continue;
}
for (var d1 = 0; d1 < inputDepthNearestVec4; d1 = d1 + 4) {
Use += 4
getW(wF, wR, wC, d1 + 3, d2)
);
dotProd = dotProd + dot(xValues, wValues);
Use +=
getW(wF, wR, wC, inputDepthNearestVec4, d2),
getW(wF, wR, wC, inputDepthNearestVec4 + 1, d2)
);
dotProd = dotProd + dot(xValues, wValues);
Use +=
getW(wF, wR, wC, inputDepthNearestVec4 + 1, d2),
getW(wF, wR, wC, inputDepthNearestVec4 + 2, d2)
);
dotProd = dotProd + dot(xValues, wValues);
Use +=
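For context on the loops quoted above, here is a minimal CPU-side sketch of the accumulation pattern they appear to follow: the input depth is consumed four channels at a time, and the leftover one to three channels are handled separately. The function name `chunkedDot` and the flat-array inputs are assumptions made for illustration; this is not the shader code from the PR.

```typescript
// Hypothetical helper: accumulates a dot product in chunks of four,
// mirroring the inputDepthNearestVec4 / inputDepthVec4Remainder split
// seen in the shader excerpt. Not the PR's actual code.
function chunkedDot(x: number[], w: number[]): number {
  if (x.length !== w.length) {
    throw new RangeError('x and w must have the same length');
  }
  const depth = x.length;
  const nearestVec4 = Math.floor(depth / 4) * 4; // largest multiple of 4 <= depth
  const remainder = depth % 4;                   // 0..3 trailing channels

  let dotProd = 0;
  // Main loop: one "vec4" worth of channels per iteration.
  for (let d1 = 0; d1 < nearestVec4; d1 += 4) {
    dotProd += x[d1] * w[d1] +
               x[d1 + 1] * w[d1 + 1] +
               x[d1 + 2] * w[d1 + 2] +
               x[d1 + 3] * w[d1 + 3];
  }
  // Tail: the remaining 1, 2, or 3 channels.
  for (let r = 0; r < remainder; r++) {
    dotProd += x[nearestVec4 + r] * w[nearestVec4 + r];
  }
  return dotProd;
}
```

In the shader the tail is written out as separate branches for each remainder size; the sketch folds them into one loop for brevity.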
const numCoords = indicesArr.length;
const shape = indicesArr.map(d => `${variableName}[${d}]`);
const indicesStr = ['.x', '.y', '.z', '.w', '.u', '.v'];
Why is the sequence wuv instead of uvw?
You may just define indicesStr as 'xyzuvw', then use it as ${variableName}.${indicesStr[d]}.
Hi @gyagp, xyzw is used by WGSL for single component selection: https://gpuweb.github.io/gpuweb/wgsl/#:~:text=7.7.1.1.%20Vector%20Single%20Component%20Selection
The TFJS WebGPU backend extends vec4 to vec5 and vec6, for example:
struct vec5 {x: i32, y: i32, z: i32, w: i32, u: i32};
so u and v are used to select the fifth and sixth components.
All other comments are addressed, thank you.
Thanks for the info!
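To make the ordering discussed in this thread concrete, here is a small sketch of how component accessors could be generated for ranks up to six. It assumes, per the reply above, that x, y, z, w are the WGSL vector components and that u and v are the extra fields of the backend's vec5 and vec6 structs. The helper name `shapeAccessors` is hypothetical and is not part of the PR.

```typescript
// Component names in selection order: the four WGSL vector components,
// followed by the two extra struct fields described in the reply above.
const COMPONENTS = ['x', 'y', 'z', 'w', 'u', 'v'] as const;

// Hypothetical helper: returns one accessor string per dimension,
// e.g. "uniforms.xShape.u" for the fifth dimension.
function shapeAccessors(variableName: string, rank: number): string[] {
  if (!Number.isInteger(rank) || rank < 1 || rank > COMPONENTS.length) {
    throw new RangeError(`rank must be an integer in 1..${COMPONENTS.length}`);
  }
  return COMPONENTS.slice(0, rank).map(c => `${variableName}.${c}`);
}
```

For a rank-5 shape such as the conv3d input, the fifth accessor ends in `.u`, which matches the `uniforms.xShape.u` expression in the shader excerpt earlier in the review.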
@@ -0,0 +1,129 @@
/**
 * @license
 * Copyright 2022 Google LLC.
nit: 2022 -> 2023, and likewise in other places.
variableNames = ['x', 'W'];
uniforms =
'filterDims: vec3<i32>, pad: vec3<i32>, strides: vec3<i32>, dilations: vec3<i32>,';
workgroupSize: [number, number, number] = [4, 4, 8];
Can we use the flat dispatch layout here? In our experience, the flat dispatch layout consistently performs well when the algorithm does not depend on the dispatch layout. You could run a micro-benchmark to verify it.
Done
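As a rough illustration of what a flat dispatch implies, the sketch below assumes that every output element is addressed by a single linear index and that each invocation recovers its N-D output coordinates from that index. Both function names are hypothetical, and the row-major decomposition is an assumption about the layout rather than the backend's actual implementation.

```typescript
// Hypothetical helper: converts a linear (flat) index into row-major
// coordinates for the given shape, innermost dimension last.
function coordsFromFlatIndex(index: number, shape: number[]): number[] {
  const coords = new Array<number>(shape.length);
  let rest = index;
  for (let i = shape.length - 1; i >= 0; i--) {
    coords[i] = rest % shape[i];
    rest = Math.floor(rest / shape[i]);
  }
  return coords;
}

// Hypothetical helper: number of 1-D workgroups needed to cover every
// output element when each workgroup has workgroupSize invocations.
function workgroupCount(outputSize: number, workgroupSize: number): number {
  return Math.ceil(outputSize / workgroupSize);
}
```

Because the last workgroup may overshoot the output size, a shader using this scheme would also need a bounds check on the flat index before writing.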