From 6675c9e3429ec6a89c8e08c84ce24518a52f3236 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 4 Jun 2019 11:25:26 -0400 Subject: [PATCH] MAINT pass n_samples instead of sample_indices in GBDT (#14017) --- sklearn/ensemble/_hist_gradient_boosting/grower.py | 2 +- .../ensemble/_hist_gradient_boosting/splitting.pyx | 8 +++----- .../_hist_gradient_boosting/tests/test_splitting.py | 12 ++++++------ 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/grower.py b/sklearn/ensemble/_hist_gradient_boosting/grower.py index 691d10f16ac61..d6836c2bd4c75 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/grower.py +++ b/sklearn/ensemble/_hist_gradient_boosting/grower.py @@ -276,7 +276,7 @@ def _compute_best_split_and_push(self, node): """ node.split_info = self.splitter.find_node_split( - node.sample_indices, node.histograms, node.sum_gradients, + node.n_samples, node.histograms, node.sum_gradients, node.sum_hessians) if node.split_info.gain <= 0: # no valid split diff --git a/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx b/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx index 3beffff125af2..6dc6e58d9acff 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx +++ b/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx @@ -319,7 +319,7 @@ cdef class Splitter: def find_node_split( Splitter self, - const unsigned int [::1] sample_indices, # IN + unsigned int n_samples, hist_struct [:, ::1] histograms, # IN const Y_DTYPE_C sum_gradients, const Y_DTYPE_C sum_hessians): @@ -329,8 +329,8 @@ cdef class Splitter: Parameters ---------- - sample_indices : ndarray of unsigned int, shape (n_samples_at_node,) - The indices of the samples at the node to split. + n_samples : int + The number of samples at the node. histograms : ndarray of HISTOGRAM_DTYPE of \ shape (n_features, max_bins) The histograms of the current node. @@ -345,7 +345,6 @@ cdef class Splitter: The info about the best possible split among all features. """ cdef: - int n_samples int feature_idx int best_feature_idx int n_features = self.n_features @@ -353,7 +352,6 @@ cdef class Splitter: split_info_struct * split_infos with nogil: - n_samples = sample_indices.shape[0] split_infos = malloc( self.n_features * sizeof(split_info_struct)) diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py index 1f72eac151bba..2fa94e06830ec 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py @@ -49,7 +49,7 @@ def test_histogram_split(n_bins): histograms = builder.compute_histograms_brute(sample_indices) split_info = splitter.find_node_split( - sample_indices, histograms, sum_gradients, + sample_indices.shape[0], histograms, sum_gradients, sum_hessians) assert split_info.bin_idx == true_bin @@ -103,17 +103,17 @@ def test_gradient_and_hessian_sanity(constant_hessian): min_samples_leaf, min_gain_to_split, constant_hessian) hists_parent = builder.compute_histograms_brute(sample_indices) - si_parent = splitter.find_node_split(sample_indices, hists_parent, + si_parent = splitter.find_node_split(n_samples, hists_parent, sum_gradients, sum_hessians) sample_indices_left, sample_indices_right, _ = splitter.split_indices( si_parent, sample_indices) hists_left = builder.compute_histograms_brute(sample_indices_left) hists_right = builder.compute_histograms_brute(sample_indices_right) - si_left = splitter.find_node_split(sample_indices_left, hists_left, + si_left = splitter.find_node_split(n_samples, hists_left, si_parent.sum_gradient_left, si_parent.sum_hessian_left) - si_right = splitter.find_node_split(sample_indices_right, hists_right, + si_right = splitter.find_node_split(n_samples, hists_right, si_parent.sum_gradient_right, si_parent.sum_hessian_right) @@ -203,7 +203,7 @@ def test_split_indices(): assert np.all(sample_indices == splitter.partition) histograms = builder.compute_histograms_brute(sample_indices) - si_root = splitter.find_node_split(sample_indices, histograms, + si_root = splitter.find_node_split(n_samples, histograms, sum_gradients, sum_hessians) # sanity checks for best split @@ -256,6 +256,6 @@ def test_min_gain_to_split(): hessians_are_constant) histograms = builder.compute_histograms_brute(sample_indices) - split_info = splitter.find_node_split(sample_indices, histograms, + split_info = splitter.find_node_split(n_samples, histograms, sum_gradients, sum_hessians) assert split_info.gain == -1