Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebGPU] support SparseSegmentSum and SparseSegmentMean #7746

Merged
merged 2 commits into from
Jun 12, 2023

Conversation

xhcao
Copy link
Collaborator

@xhcao xhcao commented Jun 8, 2023

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.

@xhcao xhcao requested review from gyagp and qjia7 June 8, 2023 08:14
Copy link
Collaborator

@qjia7 qjia7 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with two nits.

const $segmentIds = backend.readSync(segmentIds.dataId) as TypedArray;
const lastSegmentIdPlusOne =
numIndices > 0 ? $segmentIds[numIndices - 1] + 1 : 0;
const outputRows = lastSegmentIdPlusOne;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

outputRows -> numSegments?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

outputRows is not equal with numSegments here. For example, numSegments = [1, 2, 10], the output shape is [11, x, x, x], not [3, x, x, x] here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/sparse-segment-mean,

Output: Has same shape as data, except for dimension 0 which has size k, the number of segments.

Here numSegments, I mean the number of segments, which I referenced the name numSegments in tf.unsortedSegmentSum(x, segmentIds, numSegments). It's a number not an array. What do you mean numSegments = [1, 2, 10]?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [6,7,8,9]]);
const result = tf.sparse.sparseSegmentMean(c, tf.tensor1d([0, 1, 2], 'int32'), tf.tensor1d([0, 2, 10], 'int32'));
The result's shape is [11, 4], and outputRows = 11, if defined numSegments = $segmentIds[numIndices - 1] + 1 = 11here, I think it is confused and mostly we will thinknumSegments = 3here, only three rows are filled values fromdata`, other rows are initialized into zeros.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 3 stands for numSegmentIds not numSegments, which are different. I just think outputRows is not a good name. But if you persist, I am ok you proceed without change.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The outputRows comes from cpu backend, let us keep them the same name. Thanks.

getUserCode(): string {
const userCode = `
${main('index')} {
if (index < uniforms.segmentIdsSize) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe directly use uniforms.segmentIdsShape which is the input's shape to avoid the extra uniform segmentIdsSize?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Collaborator

@gyagp gyagp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@xhcao xhcao merged commit 680101d into tensorflow:master Jun 12, 2023
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants