Skip to content

Commit

Permalink
MNT Deprecate X_idx_sorted in tree module (#17614)
Browse files Browse the repository at this point in the history
  • Loading branch information
alfaro96 committed Jun 17, 2020
1 parent 03f8a2e commit e4ebcbc
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 59 deletions.
5 changes: 5 additions & 0 deletions doc/whats_new/v0.24.rst
Expand Up @@ -144,6 +144,11 @@ Changelog
- |Enhancement| :func:`tree.plot_tree` now uses colors from the matplotlib
configuration settings. :pr:`17187` by `Andreas Müller`_.

- |API|: The parameter ``X_idx_sorted`` is now deprecated in
:meth:`tree.DecisionTreeClassifier.fit` and
:meth:`tree.DecisionTreeRegressor.fit`, and has not effect.
:pr:`17614` by :user:`Juan Carlos Alfaro Jiménez <alfaro96>`.

:mod:`sklearn.neighbors`
.............................

Expand Down
12 changes: 5 additions & 7 deletions sklearn/ensemble/_gb.py
Expand Up @@ -166,7 +166,7 @@ def __init__(self, *, loss, learning_rate, n_estimators, criterion,
self.tol = tol

def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask,
random_state, X_idx_sorted, X_csc=None, X_csr=None):
random_state, X_csc=None, X_csr=None):
"""Fit another stage of ``n_classes_`` trees to the boosting model. """

assert sample_mask.dtype == np.bool
Expand Down Expand Up @@ -207,7 +207,7 @@ def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask,

X = X_csr if X_csr is not None else X
tree.fit(X, residual, sample_weight=sample_weight,
check_input=False, X_idx_sorted=X_idx_sorted)
check_input=False)

# update tree leaves
loss.update_terminal_regions(
Expand Down Expand Up @@ -482,12 +482,10 @@ def fit(self, X, y, sample_weight=None, monitor=None):
raw_predictions = self._raw_predict(X)
self._resize_state()

X_idx_sorted = None

# fit the boosting stages
n_stages = self._fit_stages(
X, y, raw_predictions, sample_weight, self._rng, X_val, y_val,
sample_weight_val, begin_at_stage, monitor, X_idx_sorted)
sample_weight_val, begin_at_stage, monitor)

# change shape of arrays after fit (early-stopping or additional ests)
if n_stages != self.estimators_.shape[0]:
Expand All @@ -501,7 +499,7 @@ def fit(self, X, y, sample_weight=None, monitor=None):

def _fit_stages(self, X, y, raw_predictions, sample_weight, random_state,
X_val, y_val, sample_weight_val,
begin_at_stage=0, monitor=None, X_idx_sorted=None):
begin_at_stage=0, monitor=None):
"""Iteratively fits the stages.
For each stage it computes the progress (OOB, train score)
Expand Down Expand Up @@ -544,7 +542,7 @@ def _fit_stages(self, X, y, raw_predictions, sample_weight, random_state,
# fit next stage of trees
raw_predictions = self._fit_stage(
i, X, y, raw_predictions, sample_weight, sample_mask,
random_state, X_idx_sorted, X_csc, X_csr)
random_state, X_csc, X_csr)

# track deviance (= loss)
if do_oob:
Expand Down
37 changes: 21 additions & 16 deletions sklearn/tree/_classes.py
Expand Up @@ -138,7 +138,7 @@ def get_n_leaves(self):
return self.tree_.n_leaves

def fit(self, X, y, sample_weight=None, check_input=True,
X_idx_sorted=None):
X_idx_sorted="deprecated"):

random_state = check_random_state(self.random_state)

Expand Down Expand Up @@ -317,6 +317,13 @@ def fit(self, X, y, sample_weight=None, check_input=True,
raise ValueError("min_impurity_decrease must be greater than "
"or equal to 0")

# TODO: Remove in v0.26
if X_idx_sorted != "deprecated":
warnings.warn("The parameter 'X_idx_sorted' is deprecated and has "
"no effect. It will be removed in v0.26. You can "
"suppress this warning by not passing any value to "
"the 'X_idx_sorted' parameter.", FutureWarning)

# Build tree
criterion = self.criterion
if not isinstance(criterion, Criterion):
Expand Down Expand Up @@ -363,7 +370,7 @@ def fit(self, X, y, sample_weight=None, check_input=True,
self.min_impurity_decrease,
min_impurity_split)

builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)
builder.build(self.tree_, X, y, sample_weight)

if self.n_outputs_ == 1 and is_classifier(self):
self.n_classes_ = self.n_classes_[0]
Expand Down Expand Up @@ -834,7 +841,7 @@ def __init__(self, *,
ccp_alpha=ccp_alpha)

def fit(self, X, y, sample_weight=None, check_input=True,
X_idx_sorted=None):
X_idx_sorted="deprecated"):
"""Build a decision tree classifier from the training set (X, y).
Parameters
Expand All @@ -858,12 +865,11 @@ def fit(self, X, y, sample_weight=None, check_input=True,
Allow to bypass several input checking.
Don't use this parameter unless you know what you do.
X_idx_sorted : array-like of shape (n_samples, n_features), \
default=None
The indexes of the sorted training input samples. If many tree
are grown on the same dataset, this allows the ordering to be
cached between trees. If None, the data will be sorted here.
Don't use this parameter unless you know what to do.
X_idx_sorted : deprecated, default="deprecated"
This parameter is deprecated and has no effect.
It will be removed in v0.26.
.. deprecated :: 0.24
Returns
-------
Expand Down Expand Up @@ -1180,7 +1186,7 @@ def __init__(self, *,
ccp_alpha=ccp_alpha)

def fit(self, X, y, sample_weight=None, check_input=True,
X_idx_sorted=None):
X_idx_sorted="deprecated"):
"""Build a decision tree regressor from the training set (X, y).
Parameters
Expand All @@ -1203,12 +1209,11 @@ def fit(self, X, y, sample_weight=None, check_input=True,
Allow to bypass several input checking.
Don't use this parameter unless you know what you do.
X_idx_sorted : array-like of shape (n_samples, n_features), \
default=None
The indexes of the sorted training input samples. If many tree
are grown on the same dataset, this allows the ordering to be
cached between trees. If None, the data will be sorted here.
Don't use this parameter unless you know what to do.
X_idx_sorted : deprecated, default="deprecated"
This parameter is deprecated and has no effect.
It will be removed in v0.26.
.. deprecated :: 0.24
Returns
-------
Expand Down
3 changes: 1 addition & 2 deletions sklearn/tree/_splitter.pxd
Expand Up @@ -78,8 +78,7 @@ cdef class Splitter:

# Methods
cdef int init(self, object X, const DOUBLE_t[:, ::1] y,
DOUBLE_t* sample_weight,
np.ndarray X_idx_sorted=*) except -1
DOUBLE_t* sample_weight) except -1

cdef int node_reset(self, SIZE_t start, SIZE_t end,
double* weighted_n_node_samples) nogil except -1
Expand Down
27 changes: 3 additions & 24 deletions sklearn/tree/_splitter.pyx
Expand Up @@ -116,8 +116,7 @@ cdef class Splitter:
cdef int init(self,
object X,
const DOUBLE_t[:, ::1] y,
DOUBLE_t* sample_weight,
np.ndarray X_idx_sorted=None) except -1:
DOUBLE_t* sample_weight) except -1:
"""Initialize the splitter.
Take in the input data X, the target Y, and optional sample weights.
Expand All @@ -137,9 +136,6 @@ cdef class Splitter:
The weights of the samples, where higher weighted samples are fit
closer than lower weight samples. If not provided, all samples
are assumed to have uniform weight.
X_idx_sorted : ndarray, default=None
The indexes of the sorted training input samples
"""

self.rand_r_state = self.random_state.randint(0, RAND_R_MAX)
Expand Down Expand Up @@ -240,25 +236,12 @@ cdef class Splitter:
cdef class BaseDenseSplitter(Splitter):
cdef const DTYPE_t[:, :] X

cdef np.ndarray X_idx_sorted
cdef INT32_t* X_idx_sorted_ptr
cdef SIZE_t X_idx_sorted_stride
cdef SIZE_t n_total_samples
cdef SIZE_t* sample_mask

def __cinit__(self, Criterion criterion, SIZE_t max_features,
SIZE_t min_samples_leaf, double min_weight_leaf,
object random_state):

self.X_idx_sorted_ptr = NULL
self.X_idx_sorted_stride = 0
self.sample_mask = NULL

cdef int init(self,
object X,
const DOUBLE_t[:, ::1] y,
DOUBLE_t* sample_weight,
np.ndarray X_idx_sorted=None) except -1:
DOUBLE_t* sample_weight) except -1:
"""Initialize the splitter
Returns -1 in case of failure to allocate memory (and raise MemoryError)
Expand Down Expand Up @@ -303,9 +286,6 @@ cdef class BestSplitter(BaseDenseSplitter):
cdef double min_weight_leaf = self.min_weight_leaf
cdef UINT32_t* random_state = &self.rand_r_state

cdef INT32_t* X_idx_sorted = self.X_idx_sorted_ptr
cdef SIZE_t* sample_mask = self.sample_mask

cdef SplitRecord best, current
cdef double current_proxy_improvement = -INFINITY
cdef double best_proxy_improvement = -INFINITY
Expand Down Expand Up @@ -818,8 +798,7 @@ cdef class BaseSparseSplitter(Splitter):
cdef int init(self,
object X,
const DOUBLE_t[:, ::1] y,
DOUBLE_t* sample_weight,
np.ndarray X_idx_sorted=None) except -1:
DOUBLE_t* sample_weight) except -1:
"""Initialize the splitter
Returns -1 in case of failure to allocate memory (and raise MemoryError)
Expand Down
3 changes: 1 addition & 2 deletions sklearn/tree/_tree.pxd
Expand Up @@ -100,6 +100,5 @@ cdef class TreeBuilder:
cdef double min_impurity_decrease # Impurity threshold for early stopping

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=*,
np.ndarray X_idx_sorted=*)
np.ndarray sample_weight=*)
cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight)
13 changes: 5 additions & 8 deletions sklearn/tree/_tree.pyx
Expand Up @@ -92,8 +92,7 @@ cdef class TreeBuilder:
"""Interface for different tree building strategies."""

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=None,
np.ndarray X_idx_sorted=None):
np.ndarray sample_weight=None):
"""Build a decision tree from the training set (X, y)."""
pass

Expand Down Expand Up @@ -144,8 +143,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
self.min_impurity_split = min_impurity_split

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=None,
np.ndarray X_idx_sorted=None):
np.ndarray sample_weight=None):
"""Build a decision tree from the training set (X, y)."""

# check input
Expand Down Expand Up @@ -175,7 +173,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
cdef double min_impurity_split = self.min_impurity_split

# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr, X_idx_sorted)
splitter.init(X, y, sample_weight_ptr)

cdef SIZE_t start
cdef SIZE_t end
Expand Down Expand Up @@ -314,8 +312,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
self.min_impurity_split = min_impurity_split

cpdef build(self, Tree tree, object X, np.ndarray y,
np.ndarray sample_weight=None,
np.ndarray X_idx_sorted=None):
np.ndarray sample_weight=None):
"""Build a decision tree from the training set (X, y)."""

# check input
Expand All @@ -333,7 +330,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
cdef SIZE_t min_samples_split = self.min_samples_split

# Recursive partition (without actual recursion)
splitter.init(X, y, sample_weight_ptr, X_idx_sorted)
splitter.init(X, y, sample_weight_ptr)

cdef PriorityHeap frontier = PriorityHeap(INITIAL_STACK_SIZE)
cdef PriorityHeapRecord record
Expand Down
13 changes: 13 additions & 0 deletions sklearn/tree/tests/test_tree.py
Expand Up @@ -1948,3 +1948,16 @@ def check_apply_path_readonly(name):
@pytest.mark.parametrize("name", ALL_TREES)
def test_apply_path_readonly_all_trees(name):
check_apply_path_readonly(name)


# TODO: Remove in v0.26
@pytest.mark.parametrize("TreeEstimator", [DecisionTreeClassifier,
DecisionTreeRegressor])
def test_X_idx_sorted_deprecated(TreeEstimator):
X_idx_sorted = np.argsort(X, axis=0)

tree = TreeEstimator()

with pytest.warns(FutureWarning,
match="The parameter 'X_idx_sorted' is deprecated"):
tree.fit(X, y, X_idx_sorted=X_idx_sorted)

0 comments on commit e4ebcbc

Please sign in to comment.