[FIX] Fix incorrect nullCount in get_json_object #11633
Fixes an issue where `get_json_object` returns an incorrect null count on large inputs. The source of the bug is that every thread reset `out_valid_count` to zero at the start of the kernel, so only the threads in the final block to execute contributed to the value of `out_valid_count`.
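To make the failure mode concrete, here is a minimal host-side sketch of the counting logic. It is not the cudf kernel: `count_valid`, its parameters, and the sequential block loop are hypothetical stand-ins, and the real kernel uses a grid-stride loop with `atomicAdd`. The model only assumes that blocks start at different times and that each one zeroes the shared counter when it starts.

```cpp
#include <vector>

// Hypothetical host-side model of the kernel's counting logic; not cudf code.
// Blocks run one after another to mimic blocks being scheduled at different times.
int count_valid(const std::vector<bool>& valid, int block_size, bool reset_in_kernel)
{
  int out_valid_count  = 0;  // zeroed once by the caller before the "launch"
  const int num_rows   = static_cast<int>(valid.size());
  const int num_blocks = (num_rows + block_size - 1) / block_size;

  for (int block = 0; block < num_blocks; ++block) {
    // Buggy variant: the counter is zeroed again when this block starts,
    // discarding whatever earlier blocks had already added.
    if (reset_in_kernel) { out_valid_count = 0; }

    for (int t = 0; t < block_size; ++t) {
      const int row = block * block_size + t;
      if (row < num_rows && valid[row]) { ++out_valid_count; }  // stands in for atomicAdd
    }
  }
  return out_valid_count;
}
```

With 1200 valid rows and an assumed block size of 512, the buggy variant reports only the rows of the last block, while the variant without the reset reports all 1200. With 512 rows or fewer there is a single block, so the two variants agree, which is why small inputs do not expose the problem in this model.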
Would it be possible to add a gtest that reproduces the issue?
@davidwendt this is difficult to replicate in a gtest because it'll only manifest when there are more rows than threads (which also makes the bug GPU-dependent?). Do we have a way to generate large amounts of test data? For reference, here's the valid counts before/after this change on a dataset with ~2M rows:
It seems this would occur when there are more rows than threads in a block, if I'm reading the code right:
cudf/cpp/src/strings/json/json_path.cu Line 987 in 1aa4ec8
So perhaps a test with more than 512 rows would work?
I no longer think a specific gtest is necessary here.
Codecov Report
Additional details and impacted files
@@ Coverage Diff @@
## branch-22.10 #11633 +/- ##
===============================================
Coverage ? 86.41%
===============================================
Files ? 145
Lines ? 22993
Branches ? 0
===============================================
Hits ? 19870
Misses ? 3123
Partials ? 0
@@ -909,7 +909,6 @@ __launch_bounds__(block_size) __global__
   size_type tid    = threadIdx.x + (blockDim.x * blockIdx.x);
   size_type stride = blockDim.x * gridDim.x;

-  if (out_valid_count.has_value()) { *(out_valid_count.value()) = 0; }
So this reset is no longer needed? I see that `out_valid_count` is then updated by `atomicAdd`, but what if `out_valid_count` was never initialized before being passed into this function?
It is initialized here:
cudf/cpp/src/strings/json/json_path.cu Line 1020 in 1aa4ec8
rmm::device_scalar<size_type> d_valid_count{0, stream};
I don't know why this if-statement was here. Maybe it was for some debug purpose.
If so, then this looks somewhat unsafe IMO because it relies on the caller to initialize the variable, with no guarantee anywhere that it does.
If we decide not to have the initialization here, it would be better to add a line to the function's doxygen comment that clearly states this requirement.
I moved the initialization. @ttnghia @davidwendt please re-review.
> If so then this looks somewhat unsafe IMO because this needs to rely on the caller to initialize the variable without any guarantee from anywhere.
Passing a pointer to an uninitialized device_scalar seems anti-RAII. If this were regular host code passing a pointer to a function and expecting that function to initialize it, without a very good reason (like that function has knowledge that the caller doesn't about how it should be initialized) I'd probably flag it as a code smell.
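As a rough illustration of the ownership argument, the sketch below keeps initialization with the owner of the counter and lets the workers only accumulate. It is a host-side analogy using `std::atomic` and `std::thread`, not cudf or RMM code; `count_rows` and `launch_count` are hypothetical names, and the zero-initialized `std::atomic<int>` merely plays the role of `rmm::device_scalar<size_type> d_valid_count{0, stream}`.

```cpp
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical stand-in for the kernel body: it only accumulates and never resets.
void count_rows(const std::vector<bool>& valid,
                std::size_t begin,
                std::size_t end,
                std::atomic<int>& out_valid_count)
{
  int local = 0;
  for (std::size_t i = begin; i < end; ++i) {
    if (valid[i]) { ++local; }
  }
  out_valid_count.fetch_add(local, std::memory_order_relaxed);  // analogous to atomicAdd
}

// Hypothetical stand-in for the host-side launcher that owns the counter.
int launch_count(const std::vector<bool>& valid, std::size_t num_workers)
{
  if (num_workers == 0) { return 0; }

  std::atomic<int> valid_count{0};  // zeroed at construction by its owner
  const std::size_t chunk = (valid.size() + num_workers - 1) / num_workers;

  std::vector<std::thread> workers;
  for (std::size_t w = 0; w < num_workers; ++w) {
    const std::size_t begin = std::min(w * chunk, valid.size());
    const std::size_t end   = std::min(begin + chunk, valid.size());
    workers.emplace_back([&valid, &valid_count, begin, end] {
      count_rows(valid, begin, end, valid_count);
    });
  }
  for (auto& worker : workers) { worker.join(); }
  return valid_count.load();
}
```

Because the additions are atomic and nothing resets the counter after construction, the result does not depend on how the rows are split across workers or on the order in which the workers finish.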
This proposes a method to ensure that unsafe initialization of `out_valid_field` is handled. I tested it locally with the sample code that @trxcllnt provided and saw neither a failure nor an incorrect null count.
@gpucibot merge
cc: @thomcom