Skip to content

Conversation

@xhcao
Copy link
Contributor

@xhcao xhcao commented Nov 7, 2022

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.


This change is Reviewable

@xhcao xhcao force-pushed the conv3d branch 2 times, most recently from 3f7772d to d69431f Compare November 8, 2022 01:04
@xhcao
Copy link
Contributor Author

xhcao commented Jan 9, 2023

Hi, @qjia7 @gyagp Some other kernels also depend on conv3d kernel, so I firstly added a naive conv3d implementation here, in order to quickly implement the other kernels. I will to implement a shared conv3d version in the future.

@xhcao xhcao requested review from gyagp and qjia7 January 9, 2023 03:48
Copy link

@gyagp gyagp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with some nits. Overall, we should use more WGSL syntax sugar, like ++ and +=.

let inputDepthVec4Remainder = uniforms.xShape.u % 4;
var dotProd = 0.0;
for (var wF = 0; wF < uniforms.filterDims[0]; wF = wF + 1) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for (var wF = 0; wF < uniforms.filterDims[0]; wF = wF + 1) {
for (var wF = 0; wF < uniforms.filterDims[0]; wF++) {

continue;
}
for (var wR = 0; wR < uniforms.filterDims[1]; wR = wR + 1) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use ++

continue;
}
for (var wC = 0; wC < uniforms.filterDims[2]; wC = wC + 1) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use ++

continue;
}
for (var d1 = 0; d1 < inputDepthNearestVec4; d1 = d1 + 4) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use += 4

getW(wF, wR, wC, d1 + 3, d2)
);
dotProd = dotProd + dot(xValues, wValues);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use +=

getW(wF, wR, wC, inputDepthNearestVec4, d2),
getW(wF, wR, wC, inputDepthNearestVec4 + 1, d2)
);
dotProd = dotProd + dot(xValues, wValues);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use +=

getW(wF, wR, wC, inputDepthNearestVec4 + 1, d2),
getW(wF, wR, wC, inputDepthNearestVec4 + 2, d2)
);
dotProd = dotProd + dot(xValues, wValues);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use +=


const numCoords = indicesArr.length;
const shape = indicesArr.map(d => `${variableName}[${d}]`);
const indicesStr = ['.x', '.y', '.z', '.w', '.u', '.v'];
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the sequence wuv instead of uvw?
You may just define indicesStr as 'xyzuvw', then use it as ${variableName}.${indicesStr[d]}.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @gyagp, xyzw is used for single component selection by WGSL https://gpuweb.github.io/gpuweb/wgsl/#:~:text=7.7.1.1.%20Vector%20Single%20Component%20Selection
TFJS webgpu backend extends the vec4 to vec5 and vec6

struct vec5 {x: i32, y: i32, z: i32, w: i32, u: i32};
, and uses uv to select the fifth and sixth components.

All other comments are addressed, thank you.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the info!

@@ -0,0 +1,129 @@
/**
* @license
* Copyright 2022 Google LLC.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: 2022 -> 2023 and similar for other places

variableNames = ['x', 'W'];
uniforms =
'filterDims: vec3<i32>, pad: vec3<i32>, strides: vec3<i32>, dilations: vec3<i32>,';
workgroupSize: [number, number, number] = [4, 4, 8];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use the flat dispatch layout here? Previous experience shows that flat dispatch layout always has good performance if the algorithm is irrelevant with dispatch layout. You can have a micro-benchmark to verify it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@xhcao xhcao merged commit 94c0c2b into tensorflow:master Jan 12, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants