
Performance optimization of RF split kernels by removing empty cycles #3818

Merged

Conversation

vinaydes
Contributor

@vinaydes vinaydes commented May 3, 2021

The compute split kernels for classification and regression end up doing a lot of work that is not required. This PR removes many of these empty work cycles by making the following changes:

  1. For computing the split for a node, launch a number of thread blocks proportional to the number of samples in that node. Before this PR the number of thread blocks was fixed for all nodes (see the launch sketch after this list).
  2. Check whether a node is a leaf before launching the kernel, and if it is, do not launch any thread blocks for it.
  3. Don't call update on the split if no valid split is found for a feature.
  4. Skip the round trip to global memory before evaluating the best split if only one thread block is operating on a node (see the kernel sketch after the timings below).
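
As a rough illustration of items 1 and 2, here is a minimal host-side sketch. It is not the actual cuML code: `Node`, `SAMPLES_PER_BLOCK`, `TPB`, and `computeSplitKernel` are hypothetical placeholders, and the real kernels take many more arguments.

```
#include <cuda_runtime.h>
#include <vector>

struct Node {
  int start;    // first sample index owned by this node
  int count;    // number of samples in this node
  bool isLeaf;  // decided on the host before launching anything
};

// Placeholder kernel: per-block histogram building and split evaluation would go here.
__global__ void computeSplitKernel(const Node* nodes, int nodeId) {}

void launchSplits(const std::vector<Node>& hNodes, const Node* dNodes, cudaStream_t stream)
{
  constexpr int TPB = 256;                 // threads per block (assumed)
  constexpr int SAMPLES_PER_BLOCK = 4096;  // tuning knob (assumed)
  for (int n = 0; n < static_cast<int>(hNodes.size()); ++n) {
    if (hNodes[n].isLeaf) continue;  // item 2: launch no thread blocks for leaf nodes
    // item 1: grid size proportional to this node's sample count instead of a fixed value
    int nBlocks = (hNodes[n].count + SAMPLES_PER_BLOCK - 1) / SAMPLES_PER_BLOCK;
    computeSplitKernel<<<nBlocks, TPB, 0, stream>>>(dNodes, n);
  }
}
```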

Performance improvement observed

Classification problem on a synthetic dataset: computeSplitClassificationKernel timings

branch-0.20: 22.91 seconds
This branch:  5.27 seconds
Gain: 4.35x

Regression problem on a synthetic dataset: computeSplitRegressionKernel timings

branch-0.20: 36.46 seconds
This branch: 34.03 seconds
Gain: 1.07x

Empty cycles are not the major performance issue in the regression code, therefore we do not see a large improvement there currently.
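
For item 4, here is a minimal, hypothetical CUDA sketch (not the actual cuML kernel; the histogram and "best bin" logic are stand-ins for the real gain computation). When a single thread block handles a node, its shared-memory histogram is already the node's full histogram, so the block can evaluate the split locally and skip the global-memory round trip used in the multi-block path.

```
#include <cuda_runtime.h>

// Launch with dynamic shared memory, e.g.:
//   splitFromHistogramSketch<<<nBlocks, 256, nBins * sizeof(int), stream>>>(...);
__global__ void splitFromHistogramSketch(const int* sampleBins, int nSamples,
                                         int nBins, int* globalHist, int* bestBinOut)
{
  extern __shared__ int shist[];  // per-block histogram in shared memory
  for (int b = threadIdx.x; b < nBins; b += blockDim.x) shist[b] = 0;
  __syncthreads();
  // grid-stride loop over the node's samples
  for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < nSamples;
       i += blockDim.x * gridDim.x) {
    atomicAdd(&shist[sampleBins[i]], 1);
  }
  __syncthreads();
  if (gridDim.x == 1) {
    // Single-block fast path: shared memory already holds the full histogram,
    // so evaluate the split directly and skip global memory.
    if (threadIdx.x == 0) {
      int best = 0;
      for (int b = 1; b < nBins; ++b)
        if (shist[b] > shist[best]) best = b;  // stand-in for a real gain metric
      *bestBinOut = best;
    }
  } else {
    // Multi-block path: accumulate partial histograms into global memory;
    // a later step (not shown) reads them back to pick the best split.
    for (int b = threadIdx.x; b < nBins; b += blockDim.x)
      atomicAdd(&globalHist[b], shist[b]);
  }
}
```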

@hcho3 hcho3 self-requested a review May 18, 2021 23:50

// variables
auto end = range_start + range_len;
auto len = nbins * 2;
// auto len = nbins * 2;
Contributor

Can we simply remove this line instead of commenting it?

Contributor Author

Yes. I need to redo the regression part anyway after merging with #3845. I'll remove it then.

auto cdf_spred_len = 2 * nbins;
IdxT stride = blockDim.x * gridDim.x;
IdxT tid = threadIdx.x + blockIdx.x * blockDim.x;
// auto cdf_spred_len = 2 * nbins;
Contributor

Can we simply remove this line instead of commenting it?

Contributor Author

Same as above.

@@ -655,7 +708,7 @@ __global__ void computeSplitRegressionKernel(
__syncthreads();

/* Make a second pass over the data to compute gain */

auto coloffset = col * input.M;
Contributor

Is coloffset used anywhere in the kernel?

Contributor

seems to be used in L716 and L729

@vinaydes vinaydes requested a review from a team as a code owner May 25, 2021 17:00
@github-actions github-actions bot added the Cython / Python Cython or Python issue label May 25, 2021
@github-actions github-actions bot removed the Cython / Python Cython or Python issue label May 25, 2021
@codecov-commenter

Codecov Report

❗ No coverage uploaded for pull request base (branch-21.06@29a8390). Click here to learn what that means.
The diff coverage is n/a.

Impacted file tree graph

@@               Coverage Diff               @@
##             branch-21.06    #3818   +/-   ##
===============================================
  Coverage                ?   85.43%           
===============================================
  Files                   ?      226           
  Lines                   ?    17281           
  Branches                ?        0           
===============================================
  Hits                    ?    14764           
  Misses                  ?     2517           
  Partials                ?        0           
Flag Coverage Δ
dask 48.96% <0.00%> (?)
non-dask 77.41% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.


Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 29a8390...64c9cb5. Read the comment docs.

@vinaydes
Contributor Author

gbm-bench results for this PR
Datasets benchmarked
For classification: airline, Fraud, Higgs, Covtype, Epsilon
For regression: airline_regression, year
Code used: https://github.com/NVIDIA/gbm-bench
Number of estimators: 100 estimators trained for Fraud, Higgs, and Epsilon; 50 estimators for the rest of the datasets.
max_samples is set to 0.5 for all the experiments.

Accuracy remains unchanged for both classification and regression

[charts: accuracy comparison for classification and regression]

Fit time improves for both classification and regression

[chart: fit time comparison for classification and regression]

Removing sklearn to zoom in on the impact of this PR

[chart: fit time comparison excluding sklearn]

Improvement in percentage terms

[chart: fit time improvement in percentage terms]

Covtype and Fraud are relatively tiny datasets, so their fit time is not dominated by the computeSplit kernels; instead, the nodeSplit kernel becomes the dominant one (>80% of GPU time) at that size. Therefore this PR has little to no impact on them.

@teju85
Member

teju85 commented May 26, 2021

@JohnZed or @dantegd can we get python-side approval so that this PR can be merged?

@dantegd
Member

dantegd commented May 26, 2021

@gpucibot merge

@rapids-bot rapids-bot bot merged commit f5a3483 into rapidsai:branch-21.06 May 26, 2021
vimarsh6739 pushed a commit to vimarsh6739/cuml that referenced this pull request Oct 9, 2023
Performance optimization of RF split kernels by removing empty cycles (rapidsai#3818)

The compute split kernels for classification and regression end up doing a lot of work that is not required. This PR removes many of these empty work cycles by making the following changes:
1. For computing the split for a node, launch a number of thread blocks proportional to the number of samples in that node. Before this PR the number of thread blocks was fixed for all nodes.
2. Check whether a node is a leaf before launching the kernel, and if it is, do not launch any thread blocks for it.
3. Don't call update on the split if no valid split is found for a feature.
4. Skip the round trip to global memory before evaluating the best split if only one thread block is operating on a node.

**Performance improvement observed**

Classification problem on a synthetic dataset: `computeSplitClassificationKernel` timings
```
branch-0.20: 22.91 seconds
This branch:  5.27 seconds
Gain: 4.35x
```

Regression problem on a synthetic dataset: `computeSplitRegressionKernel` timings
```
branch-0.20: 36.46 seconds
This branch: 34.03 seconds
Gain: 1.07x
```
Empty cycles are not the major performance issue in the regression code, therefore we do not see a large improvement there currently.

Authors:
  - Vinay Deshpande (https://github.com/vinaydes)
  - Rory Mitchell (https://github.com/RAMitchell)

Approvers:
  - Rory Mitchell (https://github.com/RAMitchell)
  - Thejaswi. N. S (https://github.com/teju85)
  - Philip Hyunsu Cho (https://github.com/hcho3)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#3818
Labels
5 - Merge After Dependencies (Depends on another PR: do not merge out of order)
CUDA/C++
improvement (Improvement / enhancement to an existing function)
non-breaking (Non-breaking change)
Perf (Related to runtime performance of the underlying code)