Skip to content

Commit

Permalink
MAINT pass n_samples instead of sample_indices in GBDT (#14017)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug authored and glemaitre committed Jun 4, 2019
1 parent ccd3331 commit 6675c9e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 12 deletions.
2 changes: 1 addition & 1 deletion sklearn/ensemble/_hist_gradient_boosting/grower.py
Expand Up @@ -276,7 +276,7 @@ def _compute_best_split_and_push(self, node):
"""

node.split_info = self.splitter.find_node_split(
node.sample_indices, node.histograms, node.sum_gradients,
node.n_samples, node.histograms, node.sum_gradients,
node.sum_hessians)

if node.split_info.gain <= 0: # no valid split
Expand Down
8 changes: 3 additions & 5 deletions sklearn/ensemble/_hist_gradient_boosting/splitting.pyx
Expand Up @@ -319,7 +319,7 @@ cdef class Splitter:

def find_node_split(
Splitter self,
const unsigned int [::1] sample_indices, # IN
unsigned int n_samples,
hist_struct [:, ::1] histograms, # IN
const Y_DTYPE_C sum_gradients,
const Y_DTYPE_C sum_hessians):
Expand All @@ -329,8 +329,8 @@ cdef class Splitter:
Parameters
----------
sample_indices : ndarray of unsigned int, shape (n_samples_at_node,)
The indices of the samples at the node to split.
n_samples : int
The number of samples at the node.
histograms : ndarray of HISTOGRAM_DTYPE of \
shape (n_features, max_bins)
The histograms of the current node.
Expand All @@ -345,15 +345,13 @@ cdef class Splitter:
The info about the best possible split among all features.
"""
cdef:
int n_samples
int feature_idx
int best_feature_idx
int n_features = self.n_features
split_info_struct split_info
split_info_struct * split_infos

with nogil:
n_samples = sample_indices.shape[0]

split_infos = <split_info_struct *> malloc(
self.n_features * sizeof(split_info_struct))
Expand Down
12 changes: 6 additions & 6 deletions sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py
Expand Up @@ -49,7 +49,7 @@ def test_histogram_split(n_bins):

histograms = builder.compute_histograms_brute(sample_indices)
split_info = splitter.find_node_split(
sample_indices, histograms, sum_gradients,
sample_indices.shape[0], histograms, sum_gradients,
sum_hessians)

assert split_info.bin_idx == true_bin
Expand Down Expand Up @@ -103,17 +103,17 @@ def test_gradient_and_hessian_sanity(constant_hessian):
min_samples_leaf, min_gain_to_split, constant_hessian)

hists_parent = builder.compute_histograms_brute(sample_indices)
si_parent = splitter.find_node_split(sample_indices, hists_parent,
si_parent = splitter.find_node_split(n_samples, hists_parent,
sum_gradients, sum_hessians)
sample_indices_left, sample_indices_right, _ = splitter.split_indices(
si_parent, sample_indices)

hists_left = builder.compute_histograms_brute(sample_indices_left)
hists_right = builder.compute_histograms_brute(sample_indices_right)
si_left = splitter.find_node_split(sample_indices_left, hists_left,
si_left = splitter.find_node_split(n_samples, hists_left,
si_parent.sum_gradient_left,
si_parent.sum_hessian_left)
si_right = splitter.find_node_split(sample_indices_right, hists_right,
si_right = splitter.find_node_split(n_samples, hists_right,
si_parent.sum_gradient_right,
si_parent.sum_hessian_right)

Expand Down Expand Up @@ -203,7 +203,7 @@ def test_split_indices():
assert np.all(sample_indices == splitter.partition)

histograms = builder.compute_histograms_brute(sample_indices)
si_root = splitter.find_node_split(sample_indices, histograms,
si_root = splitter.find_node_split(n_samples, histograms,
sum_gradients, sum_hessians)

# sanity checks for best split
Expand Down Expand Up @@ -256,6 +256,6 @@ def test_min_gain_to_split():
hessians_are_constant)

histograms = builder.compute_histograms_brute(sample_indices)
split_info = splitter.find_node_split(sample_indices, histograms,
split_info = splitter.find_node_split(n_samples, histograms,
sum_gradients, sum_hessians)
assert split_info.gain == -1

0 comments on commit 6675c9e

Please sign in to comment.