[WebGPU] support SparseSegmentSum and SparseSegmentMean #7746
LGTM with two nits.
const $segmentIds = backend.readSync(segmentIds.dataId) as TypedArray;
const lastSegmentIdPlusOne =
    numIndices > 0 ? $segmentIds[numIndices - 1] + 1 : 0;
const outputRows = lastSegmentIdPlusOne;
outputRows -> numSegments?
outputRows is not equal to numSegments here. For example, with numSegments = [1, 2, 10], the output shape is [11, x, x, x], not [3, x, x, x].
See https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/sparse-segment-mean:
"Output: Has same shape as data, except for dimension 0 which has size k, the number of segments."
By numSegments I mean the number of segments. I took the name from numSegments in tf.unsortedSegmentSum(x, segmentIds, numSegments). It is a number, not an array. What do you mean by numSegments = [1, 2, 10]?
const c = tf.tensor2d([[1,2,3,4], [-1,-2,-3,-4], [6,7,8,9]]);
const result = tf.sparse.sparseSegmentMean(c, tf.tensor1d([0, 1, 2], 'int32'), tf.tensor1d([0, 2, 10], 'int32'));
The result's shape is [11, 4], and outputRows = 11. If we defined numSegments = $segmentIds[numIndices - 1] + 1 = 11 here, I think it would be confusing: most readers would assume numSegments = 3, because only three rows are filled with values from data and the other rows are initialized to zeros.
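The shape argument above can be checked with a small reference sketch in plain TypeScript. It has no tfjs dependency and is not the backend implementation; the function name and the nested-array representation are assumptions made for illustration, and segmentIds is assumed to be sorted in ascending order.

```typescript
// Reference sketch only: a plain-array version of sparse segment mean.
// Assumes segmentIds is sorted ascending, so its last entry is the largest id.
function sparseSegmentMeanSketch(
    data: number[][], indices: number[], segmentIds: number[]): number[][] {
  const numIndices = indices.length;
  const numCols = data.length > 0 ? data[0].length : 0;
  // Same rule as in the snippet under review: last segment id plus one.
  const outputRows = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0;

  // Rows that receive no input stay at zero.
  const output: number[][] = [];
  for (let r = 0; r < outputRows; r++) {
    output.push(new Array<number>(numCols).fill(0));
  }
  const counts = new Array<number>(outputRows).fill(0);

  // Accumulate each selected data row into its segment.
  for (let i = 0; i < numIndices; i++) {
    const row = data[indices[i]];
    const seg = segmentIds[i];
    counts[seg]++;
    for (let j = 0; j < numCols; j++) {
      output[seg][j] += row[j];
    }
  }

  // Divide by the number of rows that landed in each non-empty segment.
  for (let s = 0; s < outputRows; s++) {
    if (counts[s] > 0) {
      for (let j = 0; j < numCols; j++) {
        output[s][j] /= counts[s];
      }
    }
  }
  return output;
}

const sketchResult = sparseSegmentMeanSketch(
    [[1, 2, 3, 4], [-1, -2, -3, -4], [6, 7, 8, 9]], [0, 1, 2], [0, 2, 10]);
console.log(sketchResult.length);  // 11
```

With segment ids [0, 2, 10], rows 0, 2 and 10 hold the three input rows and the remaining eight rows are zeros, so the row count is 11 even though only three segment ids were supplied.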
I think 3 is numSegmentIds, not numSegments; the two are different. I just think outputRows is not a good name. But if you insist, I am OK with you proceeding without the change.
The name outputRows comes from the CPU backend, so let us keep the same name. Thanks.
getUserCode(): string {
  const userCode = `
    ${main('index')} {
      if (index < uniforms.segmentIdsSize) {
Maybe use uniforms.segmentIdsShape, which is the input's shape, directly to avoid the extra uniform segmentIdsSize?
Done
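The reasoning behind this suggestion can be sketched as follows: for a rank-1 tensor such as segmentIds, the element count equals the single shape dimension, so a separate size uniform carries no extra information. The helper name below is hypothetical, and the sketch says nothing about how the shape uniform is actually laid out in the shader.

```typescript
// Illustrative helper (hypothetical name): element count implied by a shape.
function elementCountFromShape(shape: number[]): number {
  return shape.reduce((acc, dim) => acc * dim, 1);
}

// A rank-1 segmentIds tensor with three entries, as in the example above.
const segmentIdsShape = [3];
const segmentIdsSize = elementCountFromShape(segmentIdsShape);

// For rank 1 the size and the only shape dimension coincide,
// so a bounds check against either value is equivalent.
const sizeIsRedundant = segmentIdsSize === segmentIdsShape[0];
console.log(sizeIsRedundant);  // true
```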
LGTM
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.