
Enable probability output from RF binary classifier (alternative implementation) #3869

Merged

Conversation

@hcho3 hcho3 commented May 17, 2021

Alternative implementation of #3862 that does not depend on #3854
Closes #3764
Closes #2518

@hcho3 hcho3 requested a review from a team as a code owner May 17, 2021 16:06
@hcho3 hcho3 added the 3 - Ready for Review, non-breaking, and improvement labels May 17, 2021

@vinaydes vinaydes left a comment

Apart from my comments, the PR looks good to me.

int max_class_idx = 0;
int max_count = 0;
int total_count = 0;
for (int i = 0; i < input.nclasses; ++i) {
@vinaydes vinaydes (Contributor)

This loop was previously executed collaboratively by the threads in the block; now it is executed redundantly by all the threads in the block. Any particular reason for that?

@hcho3 hcho3 (Author) commented May 19, 2021

Two reasons:

  1. I was following #3854 (Refactor to extract random forest objectives), which also computes the sum over the classes redundantly.
  2. This PR requires computing total_count, the sum of all elements of shist. If the loop were to run collaboratively, I'd need to define an extra data structure to perform the reduction for total_count.

@vinaydes vinaydes (Contributor)

total_count can be calculated with a call to cub::BlockReduce. Assuming nclasses is small, this is a small penalty; we could change it in the future if that assumption breaks.
In that case, can the whole loop be moved inside the if (tid == 0) block on line 91?
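
For reference, a minimal sketch of the collaborative sum being suggested, assuming a compile-time block size TPB and a shared-memory class histogram shist whose bins expose the count in a .x member (the helper name and surrounding details are illustrative, not the actual cuML kernel):

#include <cub/cub.cuh>

// Hypothetical helper: each thread sums a strided subset of the per-class
// counts, then cub::BlockReduce combines the partial sums.
template <int TPB, typename BinT>
__device__ int BlockTotalCount(const BinT* shist, int nclasses)
{
  typedef cub::BlockReduce<int, TPB> BlockReduceT;
  __shared__ typename BlockReduceT::TempStorage temp_storage;

  int partial = 0;
  for (int i = threadIdx.x; i < nclasses; i += TPB) {
    partial += shist[i].x;  // per-class sample count
  }
  // The aggregate is only valid in thread 0 of the block.
  return BlockReduceT(temp_storage).Sum(partial);
}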

@hcho3 hcho3 (Author)

@vinaydes Thanks. What do you think of #3854, where the summing is performed redundantly in all threads? For example:

static DI LabelT LeafPrediction(BinT* shist, int nclasses) {
  int class_idx = 0;
  int count = 0;
  for (int i = 0; i < nclasses; i++) {
    auto current_count = shist[i].x;
    if (current_count > count) {
      class_idx = i;
      count = current_count;
    }
  }
  return class_idx;
}

@vinaydes vinaydes (Contributor)

If I understand correctly, in #3854 LeafPrediction() is called at kernels.cuh#L141, which is already inside the tid == 0 block. I could be wrong if it is called somewhere else too.

@hcho3 hcho3 (Author)

@vinaydes Got it. In that case, can I move this loop inside the tid == 0 block?
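
Something like the following sketch of that change, assuming shist is the shared-memory class histogram, tid the thread index within the block, and input.nclasses the number of classes (names and the final use of the values are illustrative; the actual PR code may differ):

if (tid == 0) {
  // Single-thread argmax over the class histogram, plus the total sample
  // count needed for the probability output. Serial is fine here because
  // nclasses is small.
  int max_class_idx = 0;
  int max_count     = 0;
  int total_count   = 0;
  for (int i = 0; i < input.nclasses; ++i) {
    int count = shist[i].x;
    total_count += count;
    if (count > max_count) {
      max_count     = count;
      max_class_idx = i;
    }
  }
  // aux would then be derived from the class counts and total_count
  // (e.g. max_count / total_count), per the PR's definition.
}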

  info.prediction = pred;
  info.colid = Leaf;
- info.quesval = DataT(0); // don't care for leaf nodes
+ info.quesval = aux;
@vinaydes vinaydes (Contributor)

Is this reuse necessary? I understand quesval is unused in leaf nodes, but it may be better to use a separate member variable for this.

@hcho3 hcho3 (Author)

I wanted to avoid introducing a new member variable, since we plan to introduce a new data structure to store the probability distribution for multi-class classifiers anyway.
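
For context, a rough sketch of the trade-off being discussed, using a hypothetical node struct rather than the actual cuML definition:

// Hypothetical node layout: at leaf nodes the split-threshold field is
// otherwise unused, so this PR stores the leaf probability (aux) there.
template <typename DataT, typename LabelT, typename IdxT>
struct NodeInfo {
  LabelT prediction;  // majority class at a leaf
  IdxT colid;         // split feature index, or a Leaf sentinel at leaf nodes
  DataT quesval;      // split threshold; reused here to carry the probability
  // Alternative raised in review: a dedicated member, or (for multi-class)
  // a per-class probability array, instead of overloading quesval.
};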

@codecov-commenter

Codecov Report

❗ No coverage uploaded for pull request base (branch-21.06@4d06991).
The diff coverage is n/a.

Impacted file tree graph

@@               Coverage Diff               @@
##             branch-21.06    #3869   +/-   ##
===============================================
  Coverage                ?   85.41%           
===============================================
  Files                   ?      227           
  Lines                   ?    17315           
  Branches                ?        0           
===============================================
  Hits                    ?    14790           
  Misses                  ?     2525           
  Partials                ?        0           
Flag Coverage Δ
dask 48.93% <0.00%> (?)
non-dask 77.36% <0.00%> (?)

Flags with carried forward coverage won't be shown.

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 4d06991...e2a0ed3.

@caryr35 caryr35 added this to PR-WIP in v21.06 Release via automation May 21, 2021
@dantegd dantegd added the 4 - Waiting on Reviewer label and removed the 3 - Ready for Review label May 22, 2021
@dantegd dantegd (Member) commented May 24, 2021

rerun tests

v21.06 Release automation moved this from PR-WIP to PR-Reviewer approved May 27, 2021
@dantegd dantegd (Member) left a comment

Changes LGTM. @vinaydes, was wondering if you have any further comments, or does this look good to merge?

@vinaydes

Good to merge 👍

@dantegd dantegd (Member) commented May 27, 2021

@gpucibot merge

@rapids-bot rapids-bot bot merged commit 92484fb into rapidsai:branch-21.06 May 27, 2021
v21.06 Release automation moved this from PR-Reviewer approved to Done May 27, 2021
@hcho3 hcho3 deleted the rf_binary_prob_output_prototype_alt branch June 2, 2021 17:39
hcho3 added a commit to hcho3/cuml that referenced this pull request Jun 2, 2021
@hcho3 hcho3 mentioned this pull request Jun 2, 2021
rapids-bot bot pushed a commit that referenced this pull request Jun 2, 2021
Reverts #3869, as it was shown to reduce the test accuracy in some cases.

Closes #3910

Authors:
  - Philip Hyunsu Cho (https://github.com/hcho3)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #3933
vimarsh6739 pushed a commit to vimarsh6739/cuml that referenced this pull request Oct 9, 2023
Enable probability output from RF binary classifier (alternative implementation) (rapidsai#3869)

Alternative implementation of rapidsai#3862 that does not depend on rapidsai#3854
Closes rapidsai#3764
Closes rapidsai#2518

Authors:
  - Philip Hyunsu Cho (https://github.com/hcho3)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Vinay Deshpande (https://github.com/vinaydes)

URL: rapidsai#3869
vimarsh6739 pushed a commit to vimarsh6739/cuml that referenced this pull request Oct 9, 2023
Labels: 4 - Waiting on Reviewer, CUDA/C++, improvement, non-breaking
Projects: no open projects
4 participants