[MRG] Fast PDPs for histogram-based GBDT #13769

Merged on Jul 11, 2019 (28 commits).

Changes from 11 commits

Commits (28)
290669d - fast partial dep cleaning (NicolasHug, Apr 27, 2019)
cd1e4a3 - minor pep8 (NicolasHug, Apr 28, 2019)
028beda - WIP (NicolasHug, May 2, 2019)
72269fc - Merge branch 'master' into fast_partial_dep_hist_gbdt (NicolasHug, May 2, 2019)
e54fe79 - Merge branch 'master' into fast_partial_dep_hist_gbdt (NicolasHug, May 2, 2019)
444c4f6 - tests and fixes (NicolasHug, May 3, 2019)
cd7d64d - docstrings (NicolasHug, May 3, 2019)
b60120c - docstrings (NicolasHug, May 3, 2019)
52ee07b - Merge branch 'fast_partial_dep_hist_gbdt' of github.com:NicolasHug/sc… (NicolasHug, May 3, 2019)
4ec6818 - pep8 (NicolasHug, May 3, 2019)
a16007b - more docs (NicolasHug, May 3, 2019)
ed06695 - Use fast gradient boosting in PDP example (ogrisel, May 9, 2019)
dc5f944 - Better MLP, better notebook layout (ogrisel, May 9, 2019)
03308bd - Fix shift in y in example (ogrisel, May 9, 2019)
341bcc3 - Small example reorg (ogrisel, May 9, 2019)
18862a4 - Make the example run faster without changing too much the plots (ogrisel, May 9, 2019)
369cd5d - Avoid oversubscription in Circle CI docker container (ogrisel, May 9, 2019)
879147f - One more tweak to the example (ogrisel, May 9, 2019)
f0f8641 - Better colormap for the 2D interaction plot (ogrisel, May 9, 2019)
e7a700a - Revert "Better colormap for the 2D interaction plot" (ogrisel, May 9, 2019)
67b708d - Various nitpicks (ogrisel, May 23, 2019)
0319e9d - Merge branch 'master' of github.com:scikit-learn/scikit-learn into fa… (NicolasHug, May 25, 2019)
8a81d69 - Merge branch 'fast_partial_dep_hist_gbdt' of github.com:NicolasHug/sc… (NicolasHug, May 25, 2019)
2e7a79d - Merge branch 'master' of github.com:scikit-learn/scikit-learn into fa… (NicolasHug, Jun 21, 2019)
6100123 - Merge branch 'master' of github.com:scikit-learn/scikit-learn into fa… (NicolasHug, Jul 4, 2019)
fa5485a - bigger plots? (NicolasHug, Jul 4, 2019)
fcd0fc2 - minor change in words (NicolasHug, Jul 5, 2019)
bc30c60 - tight layout? (NicolasHug, Jul 5, 2019)
111 changes: 111 additions & 0 deletions sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx
@@ -12,6 +12,7 @@ cimport numpy as np

from .types cimport X_DTYPE_C
from .types cimport Y_DTYPE_C
from .types import Y_DTYPE
from .types cimport X_BINNED_DTYPE_C


@@ -98,3 +99,113 @@ cdef inline Y_DTYPE_C _predict_one_from_binned_data(
            node = nodes[node.left]
        else:
            node = nodes[node.right]

def _compute_partial_dependence(
        node_struct [:] nodes,
        const X_DTYPE_C [:, ::1] X,
        int [:] target_features,
        Y_DTYPE_C [:] out):
    """Partial dependence of the response on the ``target_features`` set.

    For each sample in ``X`` a tree traversal is performed.
    Each traversal starts from the root with weight 1.0.

    At each non-leaf node that splits on a target feature, either
    the left child or the right child is visited based on the feature
    value of the current sample, and the weight is not modified.
    At each non-leaf node that splits on a complementary feature,
    both children are visited and the weight is multiplied by the fraction
    of training samples which went to each child.

    At each leaf, the value of the node is multiplied by the current
    weight (weights sum to 1 for all visited terminal nodes).

    Parameters
    ----------
    nodes : view on array of PREDICTOR_RECORD_DTYPE, shape (n_nodes,)
        The array representing the predictor tree.
    X : view on 2d ndarray, shape (n_samples, n_target_features)
        The grid points on which the partial dependence should be
        evaluated.
    target_features : view on 1d ndarray, shape (n_target_features,)
        The set of target features for which the partial dependence
        should be evaluated.
    out : view on 1d ndarray, shape (n_samples,)
        The value of the partial dependence function on each grid
        point.
    """

    cdef:
        unsigned int current_node_idx
        unsigned int [:] node_idx_stack = np.zeros(shape=nodes.shape[0],
                                                   dtype=np.uint32)
        Y_DTYPE_C [::1] weight_stack = np.zeros(shape=nodes.shape[0],
                                                dtype=Y_DTYPE)
        node_struct * current_node  # pointer to avoid copying attributes

        unsigned int sample_idx
        unsigned int feature_idx
        unsigned int stack_size
        Y_DTYPE_C left_sample_frac
        Y_DTYPE_C current_weight
        Y_DTYPE_C total_weight  # used for sanity check only
        bint is_target_feature

    for sample_idx in range(X.shape[0]):
        # init stacks for current sample
        stack_size = 1
        node_idx_stack[0] = 0  # root node
        weight_stack[0] = 1  # all the samples are in the root node
        total_weight = 0

        while stack_size > 0:

            # pop the stack
            stack_size -= 1
            current_node_idx = node_idx_stack[stack_size]
            current_node = &nodes[current_node_idx]

            if current_node.is_leaf:
                out[sample_idx] += (weight_stack[stack_size] *
                                    current_node.value)
                total_weight += weight_stack[stack_size]
            else:
                # determine if the split feature is a target feature
                is_target_feature = False
                for feature_idx in range(target_features.shape[0]):
                    if target_features[feature_idx] == current_node.feature_idx:
                        is_target_feature = True
                        break

                if is_target_feature:
                    # In this case, we push the left or the right child
                    # onto the stack, and the weight is unchanged.
                    if X[sample_idx, feature_idx] <= current_node.threshold:
                        node_idx_stack[stack_size] = current_node.left
                    else:
                        node_idx_stack[stack_size] = current_node.right
                    stack_size += 1
                else:
                    # In this case, we push both children onto the stack,
                    # and give each a weight proportional to the number of
                    # samples going through the corresponding branch.

                    # push left child
                    node_idx_stack[stack_size] = current_node.left
                    left_sample_frac = (
                        <Y_DTYPE_C> nodes[current_node.left].count /
                        current_node.count)
                    current_weight = weight_stack[stack_size]
                    weight_stack[stack_size] = current_weight * left_sample_frac
                    stack_size += 1

                    # push right child
                    node_idx_stack[stack_size] = current_node.right
                    weight_stack[stack_size] = (
                        current_weight * (1 - left_sample_frac))
                    stack_size += 1

        # Sanity check. Should never happen.
        if not (0.999 < total_weight < 1.001):
            raise ValueError("Total weight should be 1.0 but was %.9f" %
                             total_weight)
31 changes: 31 additions & 0 deletions sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -462,6 +462,37 @@ def _raw_predict(self, X):

        return raw_predictions

    def _compute_partial_dependence_recursion(self, grid, target_features):
        """Fast partial dependence computation.

        Parameters
        ----------
        grid : ndarray, shape (n_samples, n_target_features)
            The grid points on which the partial dependence should be
            evaluated.
        target_features : ndarray, shape (n_target_features,)
            The set of target features for which the partial dependence
            should be evaluated.

        Returns
        -------
        averaged_predictions : ndarray, shape \
                (n_trees_per_iteration, n_samples)
            The value of the partial dependence function on each grid point.
        """
        grid = np.asarray(grid, dtype=X_DTYPE, order='C')
        averaged_predictions = np.zeros(
            (self.n_trees_per_iteration_, grid.shape[0]), dtype=Y_DTYPE)

        for predictors_of_ith_iteration in self._predictors:
            for k, predictor in enumerate(predictors_of_ith_iteration):
                predictor.compute_partial_dependence(grid, target_features,
                                                     averaged_predictions[k])
        # Note that the learning rate is already accounted for in the leaf
        # values.

        return averaged_predictions

    @abstractmethod
    def _get_loss(self):
        pass
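
For context, ``_compute_partial_dependence_recursion`` is the hook that the inspection module dispatches to when ``method='recursion'`` is requested. A usage sketch, assuming the public API of the 0.21/0.22 era (the experimental import flag was required for the hist estimators at the time):

import numpy as np
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.inspection import partial_dependence

rng = np.random.RandomState(0)
X = rng.randn(1000, 4)
y = X[:, 0] - 2 * X[:, 1] + rng.randn(1000)

est = HistGradientBoostingRegressor().fit(X, y)
# 'recursion' avoids one predict call per grid point: each tree is
# traversed once per grid point using the weighting scheme shown above.
avg_preds, values = partial_dependence(est, X, features=[0],
                                       method='recursion')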
18 changes: 18 additions & 0 deletions sklearn/ensemble/_hist_gradient_boosting/predictor.py
@@ -10,6 +10,7 @@
from .types import X_BINNED_DTYPE
from ._predictor import _predict_from_numeric_data
from ._predictor import _predict_from_binned_data
from ._predictor import _compute_partial_dependence


PREDICTOR_RECORD_DTYPE = np.dtype([
@@ -78,3 +79,20 @@ def predict_binned(self, X):
        out = np.empty(X.shape[0], dtype=Y_DTYPE)
        _predict_from_binned_data(self.nodes, X, out)
        return out

    def compute_partial_dependence(self, grid, target_features, out):
        """Fast partial dependence computation.

        Parameters
        ----------
        grid : ndarray, shape (n_samples, n_target_features)
            The grid points on which the partial dependence should be
            evaluated.
        target_features : ndarray, shape (n_target_features,)
            The set of target features for which the partial dependence
            should be evaluated.
        out : ndarray, shape (n_samples,)
            The value of the partial dependence function on each grid
            point.
        """
        _compute_partial_dependence(self.nodes, grid, target_features, out)
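
Note that ``compute_partial_dependence`` fills ``out`` in place rather than returning a new array. A minimal sketch of the calling convention; ``grid``, ``target_features`` and ``predictors`` are hypothetical placeholders standing in for a fitted ensemble's internals:

import numpy as np

# The caller owns the accumulator and zero-initializes it once; every
# TreePredictor then adds its tree's contribution (Y_DTYPE is float64).
out = np.zeros(grid.shape[0], dtype=np.float64)
for predictor in predictors:
    predictor.compute_partial_dependence(grid, target_features, out)
# `out` now holds the summed per-tree partial dependence.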
40 changes: 40 additions & 0 deletions sklearn/ensemble/gradient_boosting.py
@@ -22,6 +22,7 @@

from abc import ABCMeta
from abc import abstractmethod
import warnings

from .base import BaseEnsemble
from ..base import ClassifierMixin
@@ -1730,6 +1731,45 @@ def feature_importances_(self):
                                          axis=0, dtype=np.float64)
        return avg_feature_importances / np.sum(avg_feature_importances)

    def _compute_partial_dependence_recursion(self, grid, target_features):
        """Fast partial dependence computation.

        Parameters
        ----------
        grid : ndarray, shape (n_samples, n_target_features)
            The grid points on which the partial dependence should be
            evaluated.
        target_features : ndarray, shape (n_target_features,)
            The set of target features for which the partial dependence
            should be evaluated.

        Returns
        -------
        averaged_predictions : ndarray, shape \
                (n_trees_per_iteration, n_samples)
            The value of the partial dependence function on each grid point.
        """
        check_is_fitted(self, 'estimators_',
                        msg="'estimator' parameter must be a fitted estimator")
        if self.init is not None:
            warnings.warn(
                'Using recursion method with a non-constant init predictor '
                'will lead to incorrect partial dependence values.',
                UserWarning
            )
        grid = np.asarray(grid, dtype=DTYPE, order='C')
        n_estimators, n_trees_per_stage = self.estimators_.shape
        averaged_predictions = np.zeros((n_trees_per_stage, grid.shape[0]),
                                        dtype=np.float64, order='C')
        for stage in range(n_estimators):
            for k in range(n_trees_per_stage):
                tree = self.estimators_[stage, k].tree_
                tree.compute_partial_dependence(grid, target_features,
                                                averaged_predictions[k])
        averaged_predictions *= self.learning_rate

        return averaged_predictions

    def _validate_y(self, y, sample_weight):
        # 'sample_weight' is not used here, but it is accepted for
        # consistency with the similar method _validate_y of GBC
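
As a side note on the warning added above, here is a hedged sketch of how the discrepancy would surface; this is a hypothetical comparison using the exact 'brute' method as the reference, with API names of the same era:

import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.inspection import partial_dependence
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(0)
X = rng.randn(500, 3)
y = X[:, 0] + rng.randn(500)

# A non-constant init estimator contributes to predict() but is
# invisible to the tree-traversal ('recursion') method.
est = GradientBoostingRegressor(init=DecisionTreeRegressor()).fit(X, y)
pd_rec, _ = partial_dependence(est, X, features=[0], method='recursion')
pd_brute, _ = partial_dependence(est, X, features=[0], method='brute')
# pd_rec misses the init model's contribution (and the UserWarning
# above is raised); pd_brute averages full predict() calls instead.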