Support SparseSegmentSum and SparseSegmentSumWithNumSegments on GPU #45085
Conversation
@Lifann Can you please resolve conflicts? Thanks!
(force-pushed from 0ef4ada to 69228f3)
Thanks for your reminder! I have resolved the conflicts.
segment_indices = [0, 1, 2, 2]
num_indices = len(segment_indices)
for tf_op in [math_ops.sparse_segment_sum_with_num_segments]:
  with self.cached_session(use_gpu=test_util.is_gpu_available()):
Generally you can do use_gpu=True; it will use GPUs only when available.
Sorry for the delay. It's very nice to have your review. I've changed use_gpu in the session parameter to True.
const TensorShape input_shape = input_data.shape();
Index element_size_muta = 1;
if (input_dims > 1) {
  ACCUMULATE_MUL(input_shape, input_dims, accum);
I'd suggest just inlining the macro here. If you definitely need some abstraction, please use a function.
        "same length."),
    done);

ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);
Please use the following spelling:
ScratchSpace<Index> output_rows_host(context, /*size=*/1, /*on_host=*/true);
Our internal tooling recognizes this format and will check that the parameter names actually match up.
const Tensor& num_segments = context->input(3);
se::DeviceMemoryBase num_segments_device(
    const_cast<Tensor&>(num_segments).template flat<Index>().data());
OP_REQUIRES_ASYNC(
    context,
    stream
        ->ThenMemcpy(output_rows_host.mutable_data(), num_segments_device,
                     sizeof(Index))
        .ok(),
    errors::Internal(
        "SparseSegmentSumGpuOp: failed to copy num_segments to host."),
    done);
If I'm reading the code correctly, the correct way to do this is to use the HostMemory annotation. The runtime will then copy that input to the host before running the OpKernel.
Great idea! It seems more concise!
registration
simpler impl
se::DeviceMemoryBase last_segment_id_on_device(
    const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
    (num_indices - 1));
Are you sure the const_cast is needed? I think you should be able to create and pass a const DeviceMemoryBase to ThenMemcpy.
The pointer here is used to get the value of the last id in segment_ids, so I think it always needs a conversion for the pointer offset; maybe it doesn't need to be const, though.
se::DeviceMemoryBase last_segment_id_on_device(
    const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
    (num_indices - 1));
OP_REQUIRES_ASYNC(
Please use OP_REQUIRES_OK_ASYNC.
I checked the OP_REQUIRES_OK_ASYNC macro. It seems that a Status condition is needed, but ThenMemcpy returns a Stream, which only has an ok() method with a bool return value. It's a little confusing with Status::ok().
}
const Index element_size = element_size_muta;

functor::SparseSegmentSumFunctor<T, Index> functor_;
Local vars don't have a _ suffix.
Accepted. To distinguish it from the functor namespace, I renamed it to executant.
auto create_and_check_output = [context, num_indices, element_size,
                                output_rows_host, &input_data, &indices,
                                &segment_ids, &functor_, done]() {
Wouldn't we have a use-after-free on segment_ids and functor_?
I'd recommend creating a functor that has the captured values as member fields. This will make the types explicit and catch use-after-free issues.
if (!has_num_segments) {
  output_rows++;
}
Please add a short comment on why this is needed.
Thanks for your advice. I've added the comment.
@Lifann Any update on this PR? Please. Thanks!
(force-pushed from 1636b36 to 13c00d7)
Sorry for the delay, and thanks for @sanjoy's review. Besides, I also found and fixed an asynchronicity problem in the D2H copy, which may lead to the host getting
(force-pushed from 13c00d7 to 3c1dffa)
(force-pushed from 3c1dffa to 6ac8ef4)
(force-pushed from 324bd24 to dcf8c57)
@Lifann Is this ready for review or are you still working on it?
This PR is ready for review.
Hi @Lifann, sorry for the delayed response; this PR somehow fell through the cracks. It seems like @benbarsdell has also implemented these ops in #47974. Would it be OK if we reviewed & merged that PR (please feel free to leave a review there)? At first glance that implementation looks vectorized, so I'm expecting it to be faster.
Thanks for the reply. I checked #47974, and it seems to be much better than this PR. It's good to know there is a more experienced developer aiming at the same goal, and it's also a nice chance for me to learn from a better implementation.
I will close this PR, then.
This is a PR from the JIZHI Team & TaiJi AI platform at Tencent.
Currently, there is no GPU support for the sparse_segment_reduction ops in TensorFlow. This commit adds GPU support for SparseSegmentSum and SparseSegmentSumWithNumSegments. We are looking forward to supporting more sparse reduction ops for training and inference in sparse scenarios.