diff --git a/benchmarks/bench_tree_nocats.py b/benchmarks/bench_tree_nocats.py new file mode 100644 index 0000000000000..803e85f87318c --- /dev/null +++ b/benchmarks/bench_tree_nocats.py @@ -0,0 +1,103 @@ +from timeit import timeit +from itertools import product +import numpy as np +import pandas as pd + +from sklearn.preprocessing import OneHotEncoder +from sklearn.model_selection import StratifiedKFold +from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier +from sklearn.metrics import roc_auc_score +from sklearn.datasets import fetch_openml + + +def get_data(trunc_ncat): + # the data is located here: https://www.openml.org/d/4135 + X, y = fetch_openml(data_id=4135, return_X_y=True) + X = pd.DataFrame(X) + + Xdicts = [] + for trunc in trunc_ncat: + X_trunc = X % trunc if trunc > 0 else X + keep_idx = np.array([idx[0] for idx in + X_trunc.groupby(list(X.columns)).groups.values()]) + X_trunc = X_trunc.values[keep_idx] + y_trunc = y[keep_idx] + + X_ohe = OneHotEncoder(categories='auto').fit_transform(X_trunc) + + Xdicts.append({'X': X_trunc, 'y': y_trunc, 'ohe': False, + 'trunc': trunc}) + Xdicts.append({'X': X_ohe, 'y': y_trunc, 'ohe': True, + 'trunc': trunc}) + + return Xdicts + + +# Training dataset +trunc_factor = [2, 3, 4, 5, 6, 8, 10, 12, 14, 16, 64, 0] +data = get_data(trunc_factor) +results = [] +# Loop over classifiers and datasets +for Xydict, clf_type in product( + data, [RandomForestClassifier, ExtraTreesClassifier]): + + # Can't use non-truncated categorical data with RandomForest + # and it becomes intractable with too many categories + if (clf_type is RandomForestClassifier and + not Xydict['ohe'] and + (not Xydict['trunc'] or Xydict['trunc'] > 16)): + continue + + X, y = Xydict['X'], Xydict['y'] + tech = 'One-hot' if Xydict['ohe'] else 'NOCATS' + trunc = ('truncated({})'.format(Xydict['trunc']) if Xydict['trunc'] > 0 + else 'full') + cat = 'none' if Xydict['ohe'] else 'all' + cv = StratifiedKFold(n_splits=5, shuffle=True, + random_state=17).split(X, y) + + traintimes = [] + testtimes = [] + aucs = [] + name = '({}, {}, {})'.format(clf_type.__name__, trunc, tech) + + for train, test in cv: + # Train + clf = clf_type(n_estimators=10, max_features=None, + min_samples_leaf=1, random_state=23, + bootstrap=False, max_depth=None, + categorical=cat) + + traintimes.append(timeit( + "clf.fit(X[train], y[train])".format(cat), + 'from __main__ import clf, X, y, train', number=1)) + + """ + # Check that all leaf nodes are pure + for est in clf.estimators_: + leaves = est.tree_.children_left < 0 + print(np.max(est.tree_.impurity[leaves])) + #assert(np.all(est.tree_.impurity[leaves] == 0)) + """ + + # Test + probs = [] + testtimes.append(timeit( + 'probs.append(clf.predict_proba(X[test]))', + 'from __main__ import probs, clf, X, test', number=1)) + + aucs.append(roc_auc_score(y[test], probs[0][:, 1])) + + traintimes = np.array(traintimes) + testtimes = np.array(testtimes) + aucs = np.array(aucs) + results.append([name, traintimes.mean(), traintimes.std(), + testtimes.mean(), testtimes.std(), + aucs.mean(), aucs.std()]) + + results_df = pd.DataFrame(results) + results_df.columns = ['name', 'train time mean', 'train time std', + 'test time mean', 'test time std', + 'auc mean', 'auc std'] + results_df = results_df.set_index('name') + print(results_df) diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index 2902502927cea..c1c3f4ffeb1cf 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ 
b/sklearn/ensemble/_gradient_boosting.pyx @@ -20,10 +20,14 @@ from scipy.sparse import csr_matrix from sklearn.tree._tree cimport Node from sklearn.tree._tree cimport Tree +from sklearn.tree._tree cimport CategoryCacheMgr from sklearn.tree._tree cimport DTYPE_t from sklearn.tree._tree cimport SIZE_t from sklearn.tree._tree cimport INT32_t +from sklearn.tree._tree cimport UINT32_t +from sklearn.tree._tree cimport BITSET_t from sklearn.tree._utils cimport safe_realloc +from sklearn.tree._utils cimport goes_left ctypedef np.int32_t int32 ctypedef np.float64_t float64 @@ -48,6 +52,8 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, Py_ssize_t K, Py_ssize_t n_samples, Py_ssize_t n_features, + INT32_t* n_categories, + BITSET_t** cachebits, float64 *out): """Predicts output for regression tree and stores it in ``out[i, k]``. @@ -82,6 +88,12 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, ``n_samples == X.shape[0]``. n_features : int The number of features; ``n_samples == X.shape[1]``. + n_categories : INT32_t pointer + Array of length n_features containing the number of categories + (for categorical features) or -1 (for non-categorical features) + cachebits : BITSET_t pointer pointer + Array of length node_count containing category cache buffers + for categorical features out : np.float64_t pointer The pointer to the data array where the predictions are stored. ``out`` is assumed to be a two-dimensional array of @@ -89,13 +101,19 @@ cdef void _predict_regression_tree_inplace_fast_dense(DTYPE_t *X, """ cdef Py_ssize_t i cdef Node *node + cdef BITSET_t* node_cache + for i in range(n_samples): node = root_node + node_cache = cachebits[0] # While node not a leaf while node.left_child != TREE_LEAF: - if X[i * n_features + node.feature] <= node.threshold: + if goes_left(X[i * n_features + node.feature], node.split_value, + n_categories[node.feature], node_cache): + node_cache = cachebits[node.left_child] node = root_node + node.left_child else: + node_cache = cachebits[node.right_child] node = root_node + node.right_child out[i * K + k] += scale * value[node - root_node] @@ -130,8 +148,8 @@ def _predict_regression_tree_stages_sparse(np.ndarray[object, ndim=2] estimators cdef Tree tree cdef Node** nodes = NULL cdef double** values = NULL - safe_realloc(&nodes, n_stages * n_outputs) - safe_realloc(&values, n_stages * n_outputs) + safe_realloc(&nodes, n_stages * n_outputs, sizeof(void*)) + safe_realloc(&values, n_stages * n_outputs, sizeof(void*)) for stage_i in range(n_stages): for output_i in range(n_outputs): tree = estimators[stage_i, output_i].tree_ @@ -147,8 +165,8 @@ def _predict_regression_tree_stages_sparse(np.ndarray[object, ndim=2] estimators # which features are nonzero in the present sample. cdef SIZE_t* feature_to_sample = NULL - safe_realloc(&X_sample, n_features) - safe_realloc(&feature_to_sample, n_features) + safe_realloc(&X_sample, n_features, sizeof(DTYPE_t)) + safe_realloc(&feature_to_sample, n_features, sizeof(SIZE_t)) memset(feature_to_sample, -1, n_features * sizeof(SIZE_t)) @@ -174,7 +192,7 @@ def _predict_regression_tree_stages_sparse(np.ndarray[object, ndim=2] estimators else: feature_value = 0. 
- if feature_value <= node.threshold: + if feature_value <= node.split_value.threshold: node = root_node + node.left_child else: node = root_node + node.right_child @@ -216,6 +234,10 @@ def predict_stages(np.ndarray[object, ndim=2] estimators, for k in range(K): tree = estimators[i, k].tree_ + # Make category cache buffers for this tree's nodes + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(tree.nodes, tree.node_count, tree.n_categories) + # avoid buffer validation by casting to ndarray # and get data pointer # need brackets because of casting operator priority @@ -223,6 +245,7 @@ def predict_stages(np.ndarray[object, ndim=2] estimators, ( X).data, tree.nodes, tree.value, scale, k, K, X.shape[0], X.shape[1], + tree.n_categories, cache_mgr.bits, ( out).data) ## out += scale * tree.predict(X).reshape((X.shape[0], 1)) @@ -293,27 +316,34 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, cdef SIZE_t node_count = tree.node_count cdef SIZE_t stack_capacity = node_count * 2 - cdef Node **node_stack cdef double[::1] weight_stack = np_ones((stack_capacity,), dtype=np_float64) cdef SIZE_t stack_size = 1 cdef double left_sample_frac cdef double current_weight cdef double total_weight = 0.0 cdef Node *current_node - underlying_stack = np_zeros((stack_capacity,), dtype=np.intp) - node_stack = ( underlying_stack).data + cdef SIZE_t[::1] node_stack = np_zeros((stack_capacity,), dtype=np.intp) + cdef BITSET_t** cachebits + cdef BITSET_t* node_cache + + # Make category cache buffers for this tree's nodes + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(root_node, node_count, tree.n_categories) + cachebits = cache_mgr.bits for i in range(X.shape[0]): # init stacks for new example stack_size = 1 - node_stack[0] = root_node + node_stack[0] = 0 + node_cache = cachebits[0] weight_stack[0] = 1.0 total_weight = 0.0 while stack_size > 0: # get top node on stack stack_size -= 1 - current_node = node_stack[stack_size] + current_node = root_node + node_stack[stack_size] + node_cache = cachebits[node_stack[stack_size]] if current_node.left_child == TREE_LEAF: out[i] += weight_stack[stack_size] * value[current_node - root_node] * \ @@ -325,21 +355,21 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, if feature_index != -1: # split feature in target set # push left or right child on stack - if X[i, feature_index] <= current_node.threshold: + if goes_left(X[i, feature_index], current_node.split_value, + tree.n_categories[current_node.feature], + node_cache): # left - node_stack[stack_size] = (root_node + - current_node.left_child) + node_stack[stack_size] = current_node.left_child else: # right - node_stack[stack_size] = (root_node + - current_node.right_child) + node_stack[stack_size] = current_node.right_child stack_size += 1 else: # split feature in complement set # push both children onto stack # push left child - node_stack[stack_size] = root_node + current_node.left_child + node_stack[stack_size] = current_node.left_child current_weight = weight_stack[stack_size] left_sample_frac = root_node[current_node.left_child].n_node_samples / \ current_node.n_node_samples @@ -354,7 +384,7 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X, stack_size +=1 # push right child - node_stack[stack_size] = root_node + current_node.right_child + node_stack[stack_size] = current_node.right_child weight_stack[stack_size] = current_weight * \ (1.0 - left_sample_frac) stack_size +=1 diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index aae9dd8c72349..49dbaee0821ac 
100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -859,6 +859,19 @@ class RandomForestClassifier(ForestClassifier): will be removed in 0.25. Use ``min_impurity_decrease`` instead. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + + .. versionadded:: 0.21 + bootstrap : boolean, optional (default=True) Whether bootstrap samples are used when building trees. If False, the whole datset is used to build each tree. @@ -953,18 +966,11 @@ class labels (multi-output problem). ... n_informative=2, n_redundant=0, ... random_state=0, shuffle=False) >>> clf = RandomForestClassifier(n_estimators=100, max_depth=2, - ... random_state=0) - >>> clf.fit(X, y) # doctest: +NORMALIZE_WHITESPACE - RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini', - max_depth=2, max_features='auto', max_leaf_nodes=None, - min_impurity_decrease=0.0, min_impurity_split=None, - min_samples_leaf=1, min_samples_split=2, - min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None, - oob_score=False, random_state=0, verbose=0, warm_start=False) - >>> print(clf.feature_importances_) - [0.14205973 0.76664038 0.0282433 0.06305659] - >>> print(clf.predict([[0, 0, 0, 0]])) - [1] + ... random_state=0).fit(X, y) + >>> clf.feature_importances_ + array([0.14205973, 0.76664038, 0.0282433 , 0.06305659]) + >>> clf.predict([[0, 0, 0, 0]]) + array([1]) Notes ----- @@ -1001,6 +1007,7 @@ def __init__(self, max_leaf_nodes=None, min_impurity_decrease=0., min_impurity_split=None, + categorical='none', bootstrap=True, oob_score=False, n_jobs=None, @@ -1015,7 +1022,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1033,6 +1040,7 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split + self.categorical = categorical class RandomForestRegressor(ForestRegressor): @@ -1150,6 +1158,19 @@ class RandomForestRegressor(ForestRegressor): ``min_impurity_split`` will change from 1e-7 to 0 in 0.23 and it will be removed in 0.25. Use ``min_impurity_decrease`` instead. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + + .. 
versionadded:: 0.21 + bootstrap : boolean, optional (default=True) Whether bootstrap samples are used when building trees. If False, the whole datset is used to build each tree. @@ -1206,18 +1227,11 @@ class RandomForestRegressor(ForestRegressor): >>> X, y = make_regression(n_features=4, n_informative=2, ... random_state=0, shuffle=False) >>> regr = RandomForestRegressor(max_depth=2, random_state=0, - ... n_estimators=100) - >>> regr.fit(X, y) # doctest: +NORMALIZE_WHITESPACE - RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=2, - max_features='auto', max_leaf_nodes=None, - min_impurity_decrease=0.0, min_impurity_split=None, - min_samples_leaf=1, min_samples_split=2, - min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None, - oob_score=False, random_state=0, verbose=0, warm_start=False) - >>> print(regr.feature_importances_) - [0.18146984 0.81473937 0.00145312 0.00233767] - >>> print(regr.predict([[0, 0, 0, 0]])) - [-8.32987858] + ... n_estimators=100).fit(X, y) + >>> regr.feature_importances_ + array([0.18146984, 0.81473937, 0.00145312, 0.00233767]) + >>> regr.predict([[0, 0, 0, 0]]) + array([-8.32987858]) Notes ----- @@ -1261,6 +1275,7 @@ def __init__(self, max_leaf_nodes=None, min_impurity_decrease=0., min_impurity_split=None, + categorical='none', bootstrap=True, oob_score=False, n_jobs=None, @@ -1274,7 +1289,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1291,6 +1306,7 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split + self.categorical = categorical class ExtraTreesClassifier(ForestClassifier): @@ -1401,6 +1417,15 @@ class ExtraTreesClassifier(ForestClassifier): ``min_impurity_split`` will change from 1e-7 to 0 in 0.23 and it will be removed in 0.25. Use ``min_impurity_decrease`` instead. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + + .. versionadded:: 0.21 + bootstrap : boolean, optional (default=False) Whether bootstrap samples are used when building trees. If False, the whole datset is used to build each tree. @@ -1516,6 +1541,7 @@ def __init__(self, max_leaf_nodes=None, min_impurity_decrease=0., min_impurity_split=None, + categorical='none', bootstrap=False, oob_score=False, n_jobs=None, @@ -1530,7 +1556,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1548,6 +1574,7 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split + self.categorical = categorical class ExtraTreesRegressor(ForestRegressor): @@ -1663,6 +1690,15 @@ class ExtraTreesRegressor(ForestRegressor): ``min_impurity_split`` will change from 1e-7 to 0 in 0.23 and it will be removed in 0.25. Use ``min_impurity_decrease`` instead. 
+ categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + + .. versionadded:: 0.21 + bootstrap : boolean, optional (default=False) Whether bootstrap samples are used when building trees. If False, the whole datset is used to build each tree. @@ -1740,6 +1776,7 @@ def __init__(self, max_leaf_nodes=None, min_impurity_decrease=0., min_impurity_split=None, + categorical='none', bootstrap=False, oob_score=False, n_jobs=None, @@ -1753,7 +1790,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=bootstrap, oob_score=oob_score, n_jobs=n_jobs, @@ -1770,6 +1807,7 @@ def __init__(self, self.max_leaf_nodes = max_leaf_nodes self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split + self.categorical = categorical class RandomTreesEmbedding(BaseForest): @@ -1865,6 +1903,13 @@ class RandomTreesEmbedding(BaseForest): ``min_impurity_split`` will change from 1e-7 to 0 in 0.23 and it will be removed in 0.25. Use ``min_impurity_decrease`` instead. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + sparse_output : bool, optional (default=True) Whether or not to return a sparse CSR matrix, as default behavior, or to return a dense array compatible with dense pipeline operators. 
@@ -1916,6 +1961,7 @@ def __init__(self, max_leaf_nodes=None, min_impurity_decrease=0., min_impurity_split=None, + categorical="none", sparse_output=True, n_jobs=None, random_state=None, @@ -1928,7 +1974,7 @@ def __init__(self, "min_samples_leaf", "min_weight_fraction_leaf", "max_features", "max_leaf_nodes", "min_impurity_decrease", "min_impurity_split", - "random_state"), + "random_state", "categorical"), bootstrap=False, oob_score=False, n_jobs=n_jobs, @@ -1944,6 +1990,7 @@ def __init__(self, self.min_impurity_decrease = min_impurity_decrease self.min_impurity_split = min_impurity_split self.sparse_output = sparse_output + self.categorical = categorical def _set_oob_score(self, X, y): raise NotImplementedError("OOB score not supported by tree embedding") diff --git a/sklearn/ensemble/gradient_boosting.py b/sklearn/ensemble/gradient_boosting.py index 413cc8a5ad3fd..d35e1afc0ec18 100644 --- a/sklearn/ensemble/gradient_boosting.py +++ b/sklearn/ensemble/gradient_boosting.py @@ -1126,7 +1126,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion, random_state, alpha=0.9, verbose=0, max_leaf_nodes=None, warm_start=False, presort='auto', validation_fraction=0.1, n_iter_no_change=None, - tol=1e-4): + tol=1e-4, categorical='none'): self.n_estimators = n_estimators self.learning_rate = learning_rate @@ -1150,6 +1150,7 @@ def __init__(self, loss, learning_rate, n_estimators, criterion, self.validation_fraction = validation_fraction self.n_iter_no_change = n_iter_no_change self.tol = tol + self.categorical = categorical def _fit_stage(self, i, X, y, y_pred, sample_weight, sample_mask, random_state, X_idx_sorted, X_csc=None, X_csr=None): @@ -1185,7 +1186,8 @@ def _fit_stage(self, i, X, y, y_pred, sample_weight, sample_mask, max_features=self.max_features, max_leaf_nodes=self.max_leaf_nodes, random_state=random_state, - presort=self.presort) + presort=self.presort, + categorical=self.categorical) if self.subsample < 1.0: # no inplace multiplication! @@ -1872,6 +1874,19 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin): .. versionadded:: 0.20 + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. + + .. 
versionadded:: 0.21
+
     Attributes
     ----------
     n_estimators_ : int
@@ -1942,7 +1957,7 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
                  random_state=None, max_features=None, verbose=0,
                  max_leaf_nodes=None, warm_start=False, presort='auto',
                  validation_fraction=0.1,
-                 n_iter_no_change=None, tol=1e-4):
+                 n_iter_no_change=None, tol=1e-4, categorical='none'):
 
         super().__init__(
             loss=loss, learning_rate=learning_rate, n_estimators=n_estimators,
@@ -1957,7 +1972,8 @@ def __init__(self, loss='deviance', learning_rate=0.1, n_estimators=100,
             min_impurity_split=min_impurity_split,
             warm_start=warm_start, presort=presort,
             validation_fraction=validation_fraction,
-            n_iter_no_change=n_iter_no_change, tol=tol)
+            n_iter_no_change=n_iter_no_change, tol=tol,
+            categorical=categorical)
 
     def _validate_y(self, y, sample_weight):
         check_classification_targets(y)
@@ -2334,6 +2350,18 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
 
         .. versionadded:: 0.20
 
+    categorical : array-like or str
+        Array of feature indices, boolean array of length n_features,
+        ``'all'`` or ``'none'``. Indicates which features should be
+        considered as categorical rather than ordinal. For decision trees,
+        the maximum number of categories is 64. In practice, the limit will
+        often be lower because the process of searching for the best possible
+        split grows exponentially with the number of categories. However, a
+        shortcut due to Breiman (1984) is used when fitting data with binary
+        labels using the ``Gini`` or ``Entropy`` criteria. In this case,
+        the runtime is linear in the number of categories.
+
+        .. versionadded:: 0.21
     Attributes
     ----------
 
@@ -2394,7 +2422,7 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
                  min_impurity_split=None, init=None, random_state=None,
                  max_features=None, alpha=0.9, verbose=0, max_leaf_nodes=None,
                  warm_start=False, presort='auto', validation_fraction=0.1,
-                 n_iter_no_change=None, tol=1e-4):
+                 n_iter_no_change=None, tol=1e-4, categorical='none'):
 
         super().__init__(
             loss=loss, learning_rate=learning_rate, n_estimators=n_estimators,
@@ -2408,7 +2436,8 @@ def __init__(self, loss='ls', learning_rate=0.1, n_estimators=100,
             random_state=random_state, alpha=alpha, verbose=verbose,
             max_leaf_nodes=max_leaf_nodes, warm_start=warm_start,
             presort=presort, validation_fraction=validation_fraction,
-            n_iter_no_change=n_iter_no_change, tol=tol)
+            n_iter_no_change=n_iter_no_change, tol=tol,
+            categorical=categorical)
 
     def predict(self, X):
         """Predict regression target for X.
diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py
index aa02f46db68f3..063917dcd1431 100644
--- a/sklearn/ensemble/tests/test_forest.py
+++ b/sklearn/ensemble/tests/test_forest.py
@@ -53,6 +53,7 @@
 from sklearn.utils.fixes import comb
 
 from sklearn.tree.tree import SPARSE_SPLITTERS
+from sklearn.tree.tests.test_tree import _make_categorical
 
 
 # toy sample
@@ -1339,6 +1340,60 @@ def test_backend_respected():
     assert ba.count == 0
 
 
+@pytest.mark.parametrize('model', FOREST_CLASSIFIERS_REGRESSORS)
+@pytest.mark.parametrize('data_params', [
+    {'n_rows': 10000,
+     'n_numerical': 10,
+     'n_categorical': 5,
+     'cat_size': 3,
+     'n_num_meaningful': 1,
+     'n_cat_meaningful': 2},
+    {'n_rows': 1000,
+     'n_numerical': 0,
+     'n_categorical': 5,
+     'cat_size': 3,
+     'n_num_meaningful': 0,
+     'n_cat_meaningful': 3},
+    {'n_rows': 1000,
+     'n_numerical': 5,
+     'n_categorical': 5,
+     'cat_size': 64,
+     'n_num_meaningful': 0,
+     'n_cat_meaningful': 2},
+    {'n_rows': 1000,
+     'n_numerical': 5,
+     'n_categorical': 5,
+     'cat_size': 3,
+     'n_num_meaningful': 0,
+     'n_cat_meaningful': 3}])
+def test_categorical_data(model, data_params):
+    # RandomForest is too slow for large category sizes.
+    if data_params['cat_size'] > 8 and 'RandomForest' in model:
+        pytest.skip('RandomForest is too slow for large category sizes')
+
+    X, y, meaningful_features = _make_categorical(
+        **data_params,
+        regression=model in FOREST_REGRESSORS,
+        return_tuple=True,
+        random_state=42)
+    rows, cols = X.shape
+    categorical_features = (np.arange(data_params['n_categorical'])
+                            + data_params['n_numerical'])
+
+    model = FOREST_CLASSIFIERS_REGRESSORS[model](
+        random_state=42, categorical=categorical_features,
+        n_estimators=100).fit(X, y)
+    fi = model.feature_importances_
+    bad_features = np.array([True]*cols)
+    bad_features[meaningful_features] = False
+
+    good_ones = fi[meaningful_features]
+    bad_ones = fi[bad_features]
+
+    # all good features should be more important than all bad features.
+ assert np.all([np.all(x > bad_ones) for x in good_ones]) + + @pytest.mark.filterwarnings('ignore:The default value of n_estimators') @pytest.mark.parametrize('name', FOREST_CLASSIFIERS) @pytest.mark.parametrize('oob_score', (True, False)) diff --git a/sklearn/neighbors/quad_tree.pyx b/sklearn/neighbors/quad_tree.pyx index fbe736636c89d..26491e924e0f6 100644 --- a/sklearn/neighbors/quad_tree.pyx +++ b/sklearn/neighbors/quad_tree.pyx @@ -605,7 +605,7 @@ cdef class _QuadTree: else: capacity = 2 * self.capacity - safe_realloc(&self.cells, capacity) + safe_realloc(&self.cells, capacity, sizeof(Cell)) # if capacity smaller than cell_count, adjust the counter if capacity < self.cell_count: diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 1cbd395af8e37..7d2802487f416 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -1,3 +1,5 @@ +# cython: language_level=3 + # Authors: Gilles Louppe # Peter Prettenhofer # Brian Holt diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index a2b362334de54..cceb358e94f2b 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -1,3 +1,4 @@ +# cython: language_level=3 # cython: cdivision=True # cython: boundscheck=False # cython: wraparound=False @@ -246,7 +247,7 @@ cdef class ClassificationCriterion(Criterion): self.sum_right = NULL self.n_classes = NULL - safe_realloc(&self.n_classes, n_outputs) + safe_realloc(&self.n_classes, n_outputs, sizeof(SIZE_t)) cdef SIZE_t k = 0 cdef SIZE_t sum_stride = 0 @@ -1035,7 +1036,7 @@ cdef class MAE(RegressionCriterion): self.node_medians = NULL # Allocate memory for the accumulators - safe_realloc(&self.node_medians, n_outputs) + safe_realloc(&self.node_medians, n_outputs, sizeof(DOUBLE_t)) self.left_child = np.empty(n_outputs, dtype='object') self.right_child = np.empty(n_outputs, dtype='object') diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 4d5c5ae46bceb..7aaf4f455d1fd 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -1,3 +1,5 @@ +# cython: language_level=3 + # Authors: Gilles Louppe # Peter Prettenhofer # Brian Holt @@ -12,6 +14,8 @@ import numpy as np cimport numpy as np +from ._utils cimport SplitValue, SplitRecord, BITSET_t + from ._criterion cimport Criterion ctypedef np.npy_float32 DTYPE_t # Type of X @@ -19,17 +23,7 @@ ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer - -cdef struct SplitRecord: - # Data to track sample split - SIZE_t feature # Which feature to split on. - SIZE_t pos # Split samples array at the given position, - # i.e. count of samples below threshold for feature. - # pos is >= end if the node is a leaf. - double threshold # Threshold to split at. - double improvement # Impurity improvement given parent node. - double impurity_left # Impurity of the left split. - double impurity_right # Impurity of the right split. 
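The removed ``SplitRecord`` (together with the ``SplitValue`` it now carries) is re-declared in ``sklearn/tree/_utils.pxd``, which is not part of this excerpt. The following is a minimal sketch of what those declarations presumably look like, reconstructed only from how ``split_value.threshold``, ``split_value.cat_split`` and ``BITSET_t`` are used elsewhere in this diff; field order, comments and the exact ``BITSET_t`` typedef are assumptions, not the patch's actual code:

    ctypedef np.npy_uint64 BITSET_t   # 64-bit category bitset (assumed width)

    cdef union SplitValue:
        # Generalizes the old scalar threshold: numerical features keep using
        # `threshold`, while categorical features store a bitset of the
        # categories sent left (or, for random splits, a 32-bit seed shifted
        # left with the low bit set).
        DOUBLE_t threshold
        BITSET_t cat_split

    cdef struct SplitRecord:
        # Data to track sample split
        SIZE_t feature          # Which feature to split on.
        SIZE_t pos              # Split samples array at the given position,
                                # i.e. count of samples below threshold for feature.
        SplitValue split_value  # Threshold for numerical features,
                                # category bitset for categorical ones.
        double improvement      # Impurity improvement given parent node.
        double impurity_left    # Impurity of the left split.
        double impurity_right   # Impurity of the right split.

``goes_left`` and ``setup_cat_cache`` (also cimported from ``_utils`` in this diff) presumably dispatch on ``n_categories``: a plain ``<=`` threshold comparison for non-categorical features, and a lookup of the sample's integer category in the cached bitset for categorical ones.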
+ctypedef np.npy_uint64 UINT64_t # Unsigned 64 bit integer cdef class Splitter: # The splitter searches in the input space for a feature and a threshold @@ -59,10 +53,15 @@ cdef class Splitter: cdef bint presort # Whether to use presorting, only # allowed on dense data + cdef bint breiman_shortcut # Whether decision trees are allowed to use the + # Breiman shortcut for categorical features cdef DOUBLE_t* y cdef SIZE_t y_stride cdef DOUBLE_t* sample_weight + cdef INT32_t[:] n_categories # (n_features,) array giving number of + # categories (<0 for non-categorical) + cdef BITSET_t* cat_cache # Cache buffer for fast categorical split evaluation # The samples vector `samples` is maintained by the Splitter object such # that the samples contained in a node are contiguous. With this setting, @@ -83,6 +82,7 @@ cdef class Splitter: # Methods cdef int init(self, object X, np.ndarray y, DOUBLE_t* sample_weight, + INT32_t[:] n_categories, np.ndarray X_idx_sorted=*) except -1 cdef int node_reset(self, SIZE_t start, SIZE_t end, diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 3f5a176d9171a..de49219abc19a 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -1,3 +1,4 @@ +# cython: language_level=3 # cython: cdivision=True # cython: boundscheck=False # cython: wraparound=False @@ -33,6 +34,11 @@ from ._utils cimport rand_int from ._utils cimport rand_uniform from ._utils cimport RAND_R_MAX from ._utils cimport safe_realloc +from ._utils cimport setup_cat_cache +from ._utils cimport goes_left +from ._utils cimport (BITSET_t, bs_get, bs_set, bs_flip_all, + bs_from_template) + cdef double INFINITY = np.inf @@ -48,7 +54,7 @@ cdef inline void _init_split(SplitRecord* self, SIZE_t start_pos) nogil: self.impurity_right = INFINITY self.pos = start_pos self.feature = 0 - self.threshold = 0. + self.split_value.threshold = 0. self.improvement = -INFINITY cdef class Splitter: @@ -60,7 +66,7 @@ cdef class Splitter: def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint breiman_shortcut): """ Parameters ---------- @@ -95,12 +101,14 @@ cdef class Splitter: self.y = NULL self.y_stride = 0 self.sample_weight = NULL + self.cat_cache = NULL self.max_features = max_features self.min_samples_leaf = min_samples_leaf self.min_weight_leaf = min_weight_leaf self.random_state = random_state self.presort = presort + self.breiman_shortcut = breiman_shortcut def __dealloc__(self): """Destructor.""" @@ -109,6 +117,7 @@ cdef class Splitter: free(self.features) free(self.constant_features) free(self.feature_values) + free(self.cat_cache) def __getstate__(self): return {} @@ -120,6 +129,7 @@ cdef class Splitter: object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, + INT32_t[:] n_categories, np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter. @@ -140,6 +150,10 @@ cdef class Splitter: The weights of the samples, where higher weighted samples are fit closer than lower weight samples. If not provided, all samples are assumed to have uniform weight. 
+ + n_categories : array of INT32_t, shape=(n_features,) + Number of categories for categorical features, or -1 for + non-categorical features """ self.rand_r_state = self.random_state.randint(0, RAND_R_MAX) @@ -147,7 +161,7 @@ cdef class Splitter: # Create a new array which will be used to store nonzero # samples from the feature of interest - cdef SIZE_t* samples = safe_realloc(&self.samples, n_samples) + cdef SIZE_t* samples = safe_realloc(&self.samples, n_samples, sizeof(SIZE_t)) cdef SIZE_t i, j cdef double weighted_n_samples = 0.0 @@ -169,20 +183,36 @@ cdef class Splitter: self.weighted_n_samples = weighted_n_samples cdef SIZE_t n_features = X.shape[1] - cdef SIZE_t* features = safe_realloc(&self.features, n_features) + cdef SIZE_t* features = safe_realloc(&self.features, n_features, + sizeof(SIZE_t)) for i in range(n_features): features[i] = i self.n_features = n_features - safe_realloc(&self.feature_values, n_samples) - safe_realloc(&self.constant_features, n_features) + safe_realloc(&self.feature_values, n_samples, sizeof(DTYPE_t)) + safe_realloc(&self.constant_features, n_features, sizeof(SIZE_t)) self.y = y.data self.y_stride = y.strides[0] / y.itemsize self.sample_weight = sample_weight + + # Initialize the number of categories for each feature + # A value of -1 indicates a non-categorical feature + if n_categories is None: + self.n_categories = np.array([-1] * n_features, dtype=np.int32) + else: + self.n_categories = np.empty_like(n_categories, dtype=np.int32) + self.n_categories[:] = n_categories + + # If needed, allocate cache space for categorical splits + cdef INT32_t max_n_categories = max(self.n_categories) + if max_n_categories > 0: + cache_size = (max_n_categories + 63) // 64 + safe_realloc(&self.cat_cache, cache_size, sizeof(BITSET_t)) + return 0 cdef int node_reset(self, SIZE_t start, SIZE_t end, @@ -252,7 +282,7 @@ cdef class BaseDenseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint breiman_shortcut): self.X = NULL self.X_sample_stride = 0 @@ -271,6 +301,7 @@ cdef class BaseDenseSplitter(Splitter): object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, + INT32_t[:] n_categories, np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter @@ -279,7 +310,7 @@ cdef class BaseDenseSplitter(Splitter): """ # Call parent init - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X, y, sample_weight, n_categories) # Initialize X cdef np.ndarray X_ndarray = X @@ -295,7 +326,7 @@ cdef class BaseDenseSplitter(Splitter): self.X_idx_sorted.itemsize) self.n_total_samples = X.shape[0] - safe_realloc(&self.sample_mask, self.n_total_samples) + safe_realloc(&self.sample_mask, self.n_total_samples, sizeof(SIZE_t)) memset(self.sample_mask, 0, self.n_total_samples*sizeof(SIZE_t)) return 0 @@ -311,6 +342,51 @@ cdef class BestSplitter(BaseDenseSplitter): self.random_state, self.presort), self.__getstate__()) + + cdef void _breiman_sort_categories(self, SIZE_t start, SIZE_t end, + INT32_t ncat, SIZE_t ncat_present, + const INT32_t *cat_offset, + SIZE_t *sorted_cat) nogil: + """The Breiman shortcut for finding the best split involves a + preprocessing step wherein we sort the categories by + increasing (weighted) mean of the outcome y (whether 0/1 + binary for classification or quantitative for + regression). 
This function implements this preprocessing step + and produces a sorted list of category values. + """ + cdef: + SIZE_t *samples = self.samples + DTYPE_t *Xf = self.feature_values + DOUBLE_t *y = self.y + SIZE_t y_stride = self.y_stride + DOUBLE_t *sample_weight = self.sample_weight + DOUBLE_t w + SIZE_t cat, localcat + SIZE_t q, partition_end + DTYPE_t sort_value[64] + DTYPE_t sort_density[64] + + # categorical features with more than 64 categories are not supported + # here. + memset(sort_value, 0, 64 * sizeof(DTYPE_t)) + memset(sort_density, 0, 64 * sizeof(DTYPE_t)) + + for q in range(start, end): + cat = Xf[q] + w = sample_weight[samples[q]] if sample_weight else 1.0 + sort_value[cat] += w * (y[y_stride * samples[q]]) + sort_density[cat] += w + + for localcat in range(ncat_present): + cat = localcat + cat_offset[localcat] + if sort_density[cat] == 0: # Avoid dividing by zero + sort_density[cat] = 1 + sort_value[localcat] = sort_value[cat] / sort_density[cat] + sorted_cat[localcat] = cat + + sort(&sort_value[0], sorted_cat, ncat_present) + + cdef int node_split(self, double impurity, SplitRecord* split, SIZE_t* n_constant_features) nogil except -1: """Find the best split on node samples[start:end] @@ -345,12 +421,12 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t f_i = n_features cdef SIZE_t f_j - cdef SIZE_t tmp cdef SIZE_t p cdef SIZE_t feature_idx_offset cdef SIZE_t feature_offset cdef SIZE_t i cdef SIZE_t j + cdef UINT64_t ui # unsigned long int i cdef SIZE_t n_visited_features = 0 # Number of features discovered to be constant during the split search @@ -362,6 +438,12 @@ cdef class BestSplitter(BaseDenseSplitter): cdef SIZE_t n_total_constants = n_known_constants cdef DTYPE_t current_feature_value cdef SIZE_t partition_end + cdef bint is_categorical + cdef UINT64_t cat_idx, ncat_present + cdef INT32_t cat_offs[64] + cdef bint breiman_shortcut = self.breiman_shortcut + cdef SIZE_t sorted_cat[64] + cdef BITSET_t cat_split = 0 _init_split(&best, end) @@ -403,9 +485,8 @@ cdef class BestSplitter(BaseDenseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 @@ -447,66 +528,138 @@ cdef class BestSplitter(BaseDenseSplitter): f_i -= 1 features[f_i], features[f_j] = features[f_j], features[f_i] + # Identify the number of categories present in this node + is_categorical = self.n_categories[current.feature] > 0 + if is_categorical: + cat_split = 0 + ncat_present = 0 + for i in range(start, end): + # Xf[i] < 64 already verified in tree.py + cat_split = bs_set(cat_split, Xf[i]) + for i in range(self.n_categories[current.feature]): + if bs_get(cat_split, i): + cat_offs[ncat_present] = i - ncat_present + ncat_present += 1 + if ncat_present <= 3: + breiman_shortcut = False # No benefit for small N + if breiman_shortcut: + self._breiman_sort_categories( + start, end, self.n_categories[current.feature], + ncat_present, cat_offs, &sorted_cat[0]) + # Evaluate all splits self.criterion.reset() p = start + cat_idx = 0 + + while True: + if is_categorical: + cat_idx += 1 + if breiman_shortcut: + if cat_idx >= ncat_present: + break + + cat_split = 0 + for ui in range(cat_idx): + cat_split = bs_set(cat_split, + sorted_cat[ui]) + # check if the first bit is 1, if yes, flip all + if bs_get(cat_split, 0): + cat_split = bs_flip_all( + 
cat_split, + self.n_categories[current.feature]) + else: + if cat_idx >= ( 1) << (ncat_present - 1): + break + + # Expand the bits of (2 * cat_idx) out into + # cat_split. We double cat_idx to avoid + # double-counting equivalent splits. This also + # ensures that cat_split & 1 == 0 as required + cat_split = bs_from_template( + cat_idx << 1, + cat_offs, ncat_present) + + # Partition + j = start + partition_end = end + while j < partition_end: + if bs_get(cat_split, Xf[j]): + j += 1 + else: + partition_end -= 1 + Xf[j], Xf[partition_end] = ( + Xf[partition_end], Xf[j]) + samples[j], samples[partition_end] = ( + samples[partition_end], samples[j]) + current.pos = j + + # Must reset criterion since we've reordered the + # samples + self.criterion.reset() + else: + # Non-categorical feature + while (p + 1 < end and + Xf[p + 1] <= Xf[p] + FEATURE_THRESHOLD): + p += 1 - while p < end: - while (p + 1 < end and - Xf[p + 1] <= Xf[p] + FEATURE_THRESHOLD): + # (p + 1 >= end) or (X[samples[p + 1], current.feature] > + # X[samples[p], current.feature]) p += 1 + # (p >= end) or (X[samples[p], current.feature] > + # X[samples[p - 1], current.feature]) - # (p + 1 >= end) or (X[samples[p + 1], current.feature] > - # X[samples[p], current.feature]) - p += 1 - # (p >= end) or (X[samples[p], current.feature] > - # X[samples[p - 1], current.feature]) + if p >= end: + break - if p < end: current.pos = p - # Reject if min_samples_leaf is not guaranteed - if (((current.pos - start) < min_samples_leaf) or - ((end - current.pos) < min_samples_leaf)): - continue - - self.criterion.update(current.pos) + # Reject if min_samples_leaf is not guaranteed + if (((current.pos - start) < min_samples_leaf) or + ((end - current.pos) < min_samples_leaf)): + continue - # Reject if min_weight_leaf is not satisfied - if ((self.criterion.weighted_n_left < min_weight_leaf) or - (self.criterion.weighted_n_right < min_weight_leaf)): - continue + self.criterion.update(current.pos) - current_proxy_improvement = self.criterion.proxy_impurity_improvement() + # Reject if min_weight_leaf is not satisfied + if ((self.criterion.weighted_n_left < min_weight_leaf) or + (self.criterion.weighted_n_right < min_weight_leaf)): + continue - if current_proxy_improvement > best_proxy_improvement: - best_proxy_improvement = current_proxy_improvement - # sum of halves is used to avoid infinite value - current.threshold = Xf[p - 1] / 2.0 + Xf[p] / 2.0 + current_proxy_improvement = self.criterion.proxy_impurity_improvement() - if ((current.threshold == Xf[p]) or - (current.threshold == INFINITY) or - (current.threshold == -INFINITY)): - current.threshold = Xf[p - 1] + if current_proxy_improvement > best_proxy_improvement: + best_proxy_improvement = current_proxy_improvement + if is_categorical: + current.split_value.cat_split = cat_split + else: + current.split_value.threshold = (Xf[p - 1] + Xf[p]) / 2.0 + if (current.split_value.threshold == Xf[p] + or current.split_value.threshold == INFINITY + or current.split_value.threshold == -INFINITY): + current.split_value.threshold = Xf[p - 1] - best = current # copy + best = current # copy # Reorganize into samples[start:best.pos] + samples[best.pos:end] if best.pos < end: + setup_cat_cache(self.cat_cache, best.split_value.cat_split, + self.n_categories[best.feature]) feature_offset = X_feature_stride * best.feature partition_end = end p = start while p < partition_end: - if X[X_sample_stride * samples[p] + feature_offset] <= best.threshold: + if goes_left(X[X_sample_stride * samples[p] + feature_offset], + 
best.split_value, self.n_categories[best.feature], + self.cat_cache): p += 1 else: partition_end -= 1 - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + samples[partition_end], samples[p] = ( + samples[p], samples[partition_end]) self.criterion.reset() self.criterion.update(best.pos) @@ -690,8 +843,7 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef SIZE_t f_i = n_features cdef SIZE_t f_j - cdef SIZE_t p - cdef SIZE_t tmp + cdef SIZE_t p, q cdef SIZE_t feature_stride # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 @@ -705,6 +857,8 @@ cdef class RandomSplitter(BaseDenseSplitter): cdef DTYPE_t max_feature_value cdef DTYPE_t current_feature_value cdef SIZE_t partition_end + cdef bint is_categorical + cdef UINT64_t split_seed _init_split(&best, end) @@ -741,9 +895,8 @@ cdef class RandomSplitter(BaseDenseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 @@ -780,32 +933,45 @@ cdef class RandomSplitter(BaseDenseSplitter): f_i -= 1 features[f_i], features[f_j] = features[f_j], features[f_i] - # Draw a random threshold - current.threshold = rand_uniform(min_feature_value, - max_feature_value, - random_state) - - if current.threshold == max_feature_value: - current.threshold = min_feature_value - - # Partition - partition_end = end - p = start - while p < partition_end: - current_feature_value = Xf[p] - if current_feature_value <= current.threshold: - p += 1 + # Repeat split & partition if split is trivial, up to 60 times + # (Can only happen with categorical features) + for q in range(60): + # Construct a random split + is_categorical = self.n_categories[current.feature] > 0 + if is_categorical: + split_seed = rand_int(0, RAND_R_MAX + 1, + random_state) + current.split_value.cat_split = (split_seed << 32) | 1 else: - partition_end -= 1 + current.split_value.threshold = rand_uniform( + min_feature_value, max_feature_value, random_state) + if current.split_value.threshold == max_feature_value: + current.split_value.threshold = min_feature_value + + # Partition + setup_cat_cache(self.cat_cache, current.split_value.cat_split, + self.n_categories[current.feature]) + partition_end = end + p = start + while p < partition_end: + current_feature_value = Xf[p] + if goes_left(current_feature_value, current.split_value, + self.n_categories[current.feature], self.cat_cache): + p += 1 + else: + partition_end -= 1 + + Xf[p] = Xf[partition_end] + Xf[partition_end] = current_feature_value - Xf[p] = Xf[partition_end] - Xf[partition_end] = current_feature_value + samples[partition_end], samples[p] = ( + samples[p], samples[partition_end]) - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + current.pos = partition_end - current.pos = partition_end + # Break early if the split is non-trivial + if current.pos != start and current.pos != end: + break # Reject if min_samples_leaf is not guaranteed if (((current.pos - start) < min_samples_leaf) or @@ -830,20 +996,23 @@ cdef class RandomSplitter(BaseDenseSplitter): # Reorganize into samples[start:best.pos] + samples[best.pos:end] feature_stride = X_feature_stride * best.feature if best.pos < end: + setup_cat_cache(self.cat_cache, best.split_value.cat_split, + 
self.n_categories[best.feature]) if current.feature != best.feature: partition_end = end p = start while p < partition_end: - if X[X_sample_stride * samples[p] + feature_stride] <= best.threshold: + if goes_left(X[X_sample_stride * samples[p] + feature_stride], + best.split_value, self.n_categories[best.feature], + self.cat_cache): p += 1 else: partition_end -= 1 - tmp = samples[partition_end] - samples[partition_end] = samples[p] - samples[p] = tmp + samples[partition_end], samples[p] = ( + samples[p], samples[partition_end]) self.criterion.reset() @@ -881,7 +1050,7 @@ cdef class BaseSparseSplitter(Splitter): def __cinit__(self, Criterion criterion, SIZE_t max_features, SIZE_t min_samples_leaf, double min_weight_leaf, - object random_state, bint presort): + object random_state, bint presort, bint breiman_shortcut): # Parent __cinit__ is automatically called self.X_data = NULL @@ -902,6 +1071,7 @@ cdef class BaseSparseSplitter(Splitter): object X, np.ndarray[DOUBLE_t, ndim=2, mode="c"] y, DOUBLE_t* sample_weight, + INT32_t[:] n_categories, np.ndarray X_idx_sorted=None) except -1: """Initialize the splitter @@ -909,7 +1079,7 @@ cdef class BaseSparseSplitter(Splitter): or 0 otherwise. """ # Call parent init - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X, y, sample_weight, n_categories) if not isinstance(X, csc_matrix): raise ValueError("X should be in csc format") @@ -929,8 +1099,8 @@ cdef class BaseSparseSplitter(Splitter): self.n_total_samples = n_total_samples # Initialize auxiliary array used to perform split - safe_realloc(&self.index_to_samples, n_total_samples) - safe_realloc(&self.sorted_samples, n_samples) + safe_realloc(&self.index_to_samples, n_total_samples, sizeof(SIZE_t)) + safe_realloc(&self.sorted_samples, n_samples, sizeof(SIZE_t)) cdef SIZE_t* index_to_samples = self.index_to_samples cdef SIZE_t p @@ -1237,7 +1407,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): cdef double best_proxy_improvement = - INFINITY cdef SIZE_t f_i = n_features - cdef SIZE_t f_j, p, tmp + cdef SIZE_t f_j, p cdef SIZE_t n_visited_features = 0 # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 @@ -1292,9 +1462,8 @@ cdef class BestSparseSplitter(BaseSparseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 @@ -1386,12 +1555,12 @@ cdef class BestSparseSplitter(BaseSparseSplitter): if current_proxy_improvement > best_proxy_improvement: best_proxy_improvement = current_proxy_improvement # sum of halves used to avoid infinite values - current.threshold = Xf[p_prev] / 2.0 + Xf[p] / 2.0 + current.split_value.threshold = Xf[p_prev] / 2.0 + Xf[p] / 2.0 - if ((current.threshold == Xf[p]) or - (current.threshold == INFINITY) or - (current.threshold == -INFINITY)): - current.threshold = Xf[p_prev] + if ((current.split_value.threshold == Xf[p]) or + (current.split_value.threshold == INFINITY) or + (current.split_value.threshold == -INFINITY)): + current.split_value.threshold = Xf[p_prev] best = current @@ -1400,7 +1569,7 @@ cdef class BestSparseSplitter(BaseSparseSplitter): self.extract_nnz(best.feature, &end_negative, &start_positive, &is_samples_sorted) - self._partition(best.threshold, end_negative, start_positive, + self._partition(best.split_value.threshold, 
end_negative, start_positive, best.pos) self.criterion.reset() @@ -1472,7 +1641,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): cdef DTYPE_t current_feature_value cdef SIZE_t f_i = n_features - cdef SIZE_t f_j, p, tmp + cdef SIZE_t f_j, p cdef SIZE_t n_visited_features = 0 # Number of features discovered to be constant during the split search cdef SIZE_t n_found_constants = 0 @@ -1528,9 +1697,8 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): if f_j < n_known_constants: # f_j in the interval [n_drawn_constants, n_known_constants[ - tmp = features[f_j] - features[f_j] = features[n_drawn_constants] - features[n_drawn_constants] = tmp + features[f_j], features[n_drawn_constants] = ( + features[n_drawn_constants], features[f_j]) n_drawn_constants += 1 @@ -1587,15 +1755,14 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): features[f_i], features[f_j] = features[f_j], features[f_i] # Draw a random threshold - current.threshold = rand_uniform(min_feature_value, - max_feature_value, - random_state) + current.split_value.threshold = rand_uniform( + min_feature_value, max_feature_value, random_state) - if current.threshold == max_feature_value: - current.threshold = min_feature_value + if current.split_value.threshold == max_feature_value: + current.split_value.threshold = min_feature_value # Partition - current.pos = self._partition(current.threshold, + current.pos = self._partition(current.split_value.threshold, end_negative, start_positive, start_positive + @@ -1631,7 +1798,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter): self.extract_nnz(best.feature, &end_negative, &start_positive, &is_samples_sorted) - self._partition(best.threshold, end_negative, start_positive, + self._partition(best.split_value.threshold, end_negative, start_positive, best.pos) self.criterion.reset() diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 14b03103deff0..3839f837ce2d3 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -1,3 +1,5 @@ +# cython: language_level=3 + # Authors: Gilles Louppe # Peter Prettenhofer # Brian Holt @@ -19,19 +21,20 @@ ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer +from ._utils cimport SplitValue +from ._utils cimport SplitRecord +from ._utils cimport Node +from ._utils cimport BITSET_t from ._splitter cimport Splitter -from ._splitter cimport SplitRecord -cdef struct Node: - # Base storage structure for the nodes in a Tree object - SIZE_t left_child # id of the left child of the node - SIZE_t right_child # id of the right child of the node - SIZE_t feature # Feature used for splitting the node - DOUBLE_t threshold # Threshold value at the node - DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) - SIZE_t n_node_samples # Number of samples at the node - DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node +cdef class CategoryCacheMgr: + # Class to manage the category cache memory during Tree.apply() + + cdef SIZE_t n_nodes + cdef BITSET_t **bits + + cdef void populate(self, Node *nodes, SIZE_t n_nodes, INT32_t *n_categories) cdef class Tree: @@ -53,10 +56,12 @@ cdef class Tree: cdef Node* nodes # Array of nodes cdef double* value # (capacity, n_outputs, max_n_classes) array of values cdef SIZE_t value_stride # = n_outputs * max_n_classes + cdef INT32_t *n_categories # (n_features,) array giving number of + # categories (<0 for non-categorical) # Methods cdef 
SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, - SIZE_t feature, double threshold, double impurity, + SIZE_t feature, SplitValue split_value, double impurity, SIZE_t n_node_samples, double weighted_n_samples) nogil except -1 cdef int _resize(self, SIZE_t capacity) nogil except -1 @@ -101,5 +106,6 @@ cdef class TreeBuilder: cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=*, + np.ndarray n_categories=*, np.ndarray X_idx_sorted=*) cdef _check_input(self, object X, np.ndarray y, np.ndarray sample_weight) diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 2aa67b0f62a17..926aeb35e1758 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1,3 +1,4 @@ +# cython: language_level=3 # cython: cdivision=True # cython: boundscheck=False # cython: wraparound=False @@ -37,6 +38,9 @@ from ._utils cimport PriorityHeap from ._utils cimport PriorityHeapRecord from ._utils cimport safe_realloc from ._utils cimport sizet_ptr_to_ndarray +from ._utils cimport int32_ptr_to_ndarray +from ._utils cimport setup_cat_cache +from ._utils cimport goes_left cdef extern from "numpy/arrayobject.h": object PyArray_NewFromDescr(object subtype, np.dtype descr, @@ -66,17 +70,37 @@ cdef SIZE_t _TREE_LEAF = TREE_LEAF cdef SIZE_t _TREE_UNDEFINED = TREE_UNDEFINED cdef SIZE_t INITIAL_STACK_SIZE = 10 -# Repeat struct definition for numpy +""" +this includes cat_split, but it breaks joblib.hash + +NODE_DTYPE = np.dtype({ + 'names': ['left_child', 'right_child', 'feature', 'threshold', 'cat_split', + 'impurity', 'n_node_samples', 'weighted_n_node_samples'], + 'formats': [np.intp, np.intp, np.intp, np.float64, np.uint64, np.float64, + np.intp, np.float64], + 'offsets': [ + &( NULL).left_child, + &( NULL).right_child, + &( NULL).feature, + &( NULL).split_value, + &( NULL).split_value, + &( NULL).impurity, + &( NULL).n_node_samples, + &( NULL).weighted_n_node_samples + ] +}) +""" + NODE_DTYPE = np.dtype({ - 'names': ['left_child', 'right_child', 'feature', 'threshold', 'impurity', - 'n_node_samples', 'weighted_n_node_samples'], - 'formats': [np.intp, np.intp, np.intp, np.float64, np.float64, np.intp, - np.float64], + 'names': ['left_child', 'right_child', 'feature', 'threshold', + 'impurity', 'n_node_samples', 'weighted_n_node_samples'], + 'formats': [np.intp, np.intp, np.intp, np.float64, np.float64, + np.intp, np.float64], 'offsets': [ &( NULL).left_child, &( NULL).right_child, &( NULL).feature, - &( NULL).threshold, + &( NULL).split_value, &( NULL).impurity, &( NULL).n_node_samples, &( NULL).weighted_n_node_samples @@ -92,6 +116,7 @@ cdef class TreeBuilder: cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, + np.ndarray n_categories=None, np.ndarray X_idx_sorted=None): """Build a decision tree from the training set (X, y).""" pass @@ -144,6 +169,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cpdef build(self, Tree tree, object X, np.ndarray y, np.ndarray sample_weight=None, + np.ndarray n_categories=None, np.ndarray X_idx_sorted=None): """Build a decision tree from the training set (X, y).""" @@ -154,6 +180,9 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): if sample_weight is not None: sample_weight_ptr = sample_weight.data + if n_categories is not None: + n_categories = np.asarray(n_categories, dtype=np.int32, order='C') + # Initial capacity cdef int init_capacity @@ -174,7 +203,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double min_impurity_split = self.min_impurity_split # Recursive partition (without 
actual recursion)
-        splitter.init(X, y, sample_weight_ptr, X_idx_sorted)
+        splitter.init(X, y, sample_weight_ptr, n_categories, X_idx_sorted)

        cdef SIZE_t start
        cdef SIZE_t end
@@ -241,7 +270,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
                                           min_impurity_decrease))

                    node_id = tree._add_node(parent, is_left, is_leaf, split.feature,
-                                             split.threshold, impurity, n_node_samples,
+                                             split.split_value, impurity, n_node_samples,
                                             weighted_n_node_samples)

                    if node_id == <SIZE_t>(-1):
@@ -314,6 +343,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
    cpdef build(self, Tree tree, object X, np.ndarray y,
                np.ndarray sample_weight=None,
+                np.ndarray n_categories=None,
                np.ndarray X_idx_sorted=None):
        """Build a decision tree from the training set (X, y)."""

@@ -324,6 +354,9 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
        if sample_weight is not None:
            sample_weight_ptr = sample_weight.data

+        if n_categories is not None:
+            n_categories = np.asarray(n_categories, dtype=np.int32, order='C')
+
        # Parameters
        cdef Splitter splitter = self.splitter
        cdef SIZE_t max_leaf_nodes = self.max_leaf_nodes
@@ -332,7 +365,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
        cdef SIZE_t min_samples_split = self.min_samples_split

        # Recursive partition (without actual recursion)
-        splitter.init(X, y, sample_weight_ptr, X_idx_sorted)
+        splitter.init(X, y, sample_weight_ptr, n_categories, X_idx_sorted)

        cdef PriorityHeap frontier = PriorityHeap(INITIAL_STACK_SIZE)
        cdef PriorityHeapRecord record
@@ -373,7 +406,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
                node.left_child = _TREE_LEAF
                node.right_child = _TREE_LEAF
                node.feature = _TREE_UNDEFINED
-                node.threshold = _TREE_UNDEFINED
+                node.split_value.threshold = _TREE_UNDEFINED

            else:
                # Node is expandable
@@ -466,7 +499,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
                                 if parent != NULL
                                 else _TREE_UNDEFINED,
                                 is_left, is_leaf,
-                                 split.feature, split.threshold, impurity, n_node_samples,
+                                 split.feature, split.split_value, impurity, n_node_samples,
                                 weighted_n_node_samples)
        if node_id == <SIZE_t>(-1):
            return -1
@@ -499,6 +532,46 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
        return 0


+cdef class CategoryCacheMgr:
+    """Class to manage the category cache memory during Tree.apply()."""
+
+    def __cinit__(self):
+        self.n_nodes = 0
+        self.bits = NULL
+
+    def __dealloc__(self):
+        cdef int i
+
+        if self.bits != NULL:
+            for i in range(self.n_nodes):
+                free(self.bits[i])
+            free(self.bits)
+
+    cdef void populate(self, Node *nodes, SIZE_t n_nodes,
+                       INT32_t *n_categories):
+        cdef SIZE_t i
+        cdef INT32_t ncat
+
+        if nodes == NULL or n_categories == NULL:
+            return
+
+        self.n_nodes = n_nodes
+        safe_realloc(<void ***> &self.bits, n_nodes, sizeof(void *))
+        for i in range(n_nodes):
+            self.bits[i] = NULL
+            if nodes[i].left_child != _TREE_LEAF:
+                ncat = n_categories[nodes[i].feature]
+                if ncat > 0:
+                    cache_size = (ncat + 63) // 64
+                    safe_realloc(&self.bits[i],
+                                 cache_size,
+                                 sizeof(BITSET_t))
+                    setup_cat_cache(self.bits[i],
+                                    nodes[i].split_value.cat_split,
+                                    ncat)
+
+
 # =============================================================================
 # Tree
 # =============================================================================
@@ -546,6 +619,10 @@ cdef class Tree:
    value : array of double, shape [node_count, n_outputs, max_n_classes]
        Contains the constant prediction value of each node.

+    n_categories : array of int32, shape [n_features]
+        Number of expected category values for categorical features, or
+        -1 for non-categorical features.
+ impurity : array of double, shape [node_count] impurity[i] holds the impurity (i.e., the value of the splitting criterion) at node i. @@ -603,14 +680,20 @@ cdef class Tree: def __get__(self): return self._get_value_ndarray()[:self.node_count] + property n_categories: + def __get__(self): + return int32_ptr_to_ndarray(self.n_categories, self.n_features).copy() + def __cinit__(self, int n_features, np.ndarray[SIZE_t, ndim=1] n_classes, - int n_outputs): + int n_outputs, np.ndarray[INT32_t, ndim=1] n_categories): """Constructor.""" # Input/Output layout self.n_features = n_features self.n_outputs = n_outputs self.n_classes = NULL - safe_realloc(&self.n_classes, n_outputs) + self.n_categories = NULL + safe_realloc(&self.n_classes, n_outputs, sizeof(SIZE_t)) + safe_realloc(&self.n_categories, n_features, sizeof(INT32_t)) self.max_n_classes = np.max(n_classes) self.value_stride = n_outputs * self.max_n_classes @@ -618,6 +701,16 @@ cdef class Tree: cdef SIZE_t k for k in range(n_outputs): self.n_classes[k] = n_classes[k] + for k in range(n_features): + self.n_categories[k] = n_categories[k] + + # Ensure cython and numpy node sizes match up + np_node_size = ( NODE_DTYPE).itemsize + node_size = sizeof(Node) + if (np_node_size != node_size): + raise TypeError('Size of numpy NODE_DTYPE ({} bytes) does not' + ' match size of Node ({} bytes)'.format( + np_node_size, node_size)) # Inner structures self.max_depth = 0 @@ -632,12 +725,15 @@ cdef class Tree: free(self.n_classes) free(self.value) free(self.nodes) + free(self.n_categories) def __reduce__(self): """Reduce re-implementation, for pickling.""" return (Tree, (self.n_features, sizet_ptr_to_ndarray(self.n_classes, self.n_outputs), - self.n_outputs), self.__getstate__()) + self.n_outputs, + int32_ptr_to_ndarray(self.n_categories, self.n_features)), + self.__getstate__()) def __getstate__(self): """Getstate re-implementation, for pickling.""" @@ -708,8 +804,8 @@ cdef class Tree: else: capacity = 2 * self.capacity - safe_realloc(&self.nodes, capacity) - safe_realloc(&self.value, capacity * self.value_stride) + safe_realloc(&self.nodes, capacity, sizeof(Node)) + safe_realloc(&self.value, capacity * self.value_stride, sizeof(double)) # value memory is initialised to 0 to enable classifier argmax if capacity > self.capacity: @@ -725,7 +821,7 @@ cdef class Tree: return 0 cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, - SIZE_t feature, double threshold, double impurity, + SIZE_t feature, SplitValue split_value, double impurity, SIZE_t n_node_samples, double weighted_n_node_samples) nogil except -1: """Add a node to the tree. @@ -755,12 +851,12 @@ cdef class Tree: node.left_child = _TREE_LEAF node.right_child = _TREE_LEAF node.feature = _TREE_UNDEFINED - node.threshold = _TREE_UNDEFINED + node.split_value.threshold = _TREE_UNDEFINED else: # left_child and right_child will be set later node.feature = feature - node.threshold = threshold + node.split_value = split_value self.node_count += 1 @@ -806,17 +902,24 @@ cdef class Tree: # Initialize auxiliary data-structure cdef Node* node = NULL cdef SIZE_t i = 0 + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef BITSET_t** cat_caches = cache_mgr.bits + cdef BITSET_t* cache = NULL with nogil: for i in range(n_samples): node = self.nodes + cache = cat_caches[0] # While node not a leaf while node.left_child != _TREE_LEAF: # ... 
and node.right_child != _TREE_LEAF: - if X_ptr[X_sample_stride * i + - X_fx_stride * node.feature] <= node.threshold: + if goes_left(X_ptr[X_sample_stride * i + X_fx_stride * node.feature], + node.split_value, self.n_categories[node.feature], cache): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] out_ptr[i] = (node - self.nodes) # node offset @@ -857,20 +960,25 @@ cdef class Tree: cdef DTYPE_t* X_sample = NULL cdef SIZE_t i = 0 cdef INT32_t k = 0 + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef BITSET_t** cat_caches = cache_mgr.bits + cdef BITSET_t* cache = NULL # feature_to_sample as a data structure records the last seen sample # for each feature; functionally, it is an efficient way to identify # which features are nonzero in the present sample. cdef SIZE_t* feature_to_sample = NULL - safe_realloc(&X_sample, n_features) - safe_realloc(&feature_to_sample, n_features) + safe_realloc(&X_sample, n_features, sizeof(DTYPE_t)) + safe_realloc(&feature_to_sample, n_features, sizeof(SIZE_t)) with nogil: memset(feature_to_sample, -1, n_features * sizeof(SIZE_t)) for i in range(n_samples): node = self.nodes + cache = cat_caches[0] for k in range(X_indptr[i], X_indptr[i + 1]): feature_to_sample[X_indices[k]] = i @@ -885,9 +993,12 @@ cdef class Tree: else: feature_value = 0. - if feature_value <= node.threshold: + if goes_left(feature_value, node.split_value, + self.n_categories[node.feature], cache): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] out_ptr[i] = (node - self.nodes) # node offset @@ -935,10 +1046,15 @@ cdef class Tree: # Initialize auxiliary data-structure cdef Node* node = NULL cdef SIZE_t i = 0 + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef BITSET_t** cat_caches = cache_mgr.bits + cdef BITSET_t* cache = NULL with nogil: for i in range(n_samples): node = self.nodes + cache = cat_caches[0] indptr_ptr[i + 1] = indptr_ptr[i] # Add all external nodes @@ -947,10 +1063,12 @@ cdef class Tree: indices_ptr[indptr_ptr[i + 1]] = (node - self.nodes) indptr_ptr[i + 1] += 1 - if X_ptr[X_sample_stride * i + - X_fx_stride * node.feature] <= node.threshold: + if goes_left(X_ptr[X_sample_stride * i + X_fx_stride * node.feature], + node.split_value, self.n_categories[node.feature], cache): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] # Add the leave node @@ -1003,20 +1121,25 @@ cdef class Tree: cdef DTYPE_t* X_sample = NULL cdef SIZE_t i = 0 cdef INT32_t k = 0 + cache_mgr = CategoryCacheMgr() + cache_mgr.populate(self.nodes, self.node_count, self.n_categories) + cdef BITSET_t** cat_caches = cache_mgr.bits + cdef BITSET_t* cache = NULL # feature_to_sample as a data structure records the last seen sample # for each feature; functionally, it is an efficient way to identify # which features are nonzero in the present sample. 
cdef SIZE_t* feature_to_sample = NULL - safe_realloc(&X_sample, n_features) - safe_realloc(&feature_to_sample, n_features) + safe_realloc(&X_sample, n_features, sizeof(DTYPE_t)) + safe_realloc(&feature_to_sample, n_features, sizeof(SIZE_t)) with nogil: memset(feature_to_sample, -1, n_features * sizeof(SIZE_t)) for i in range(n_samples): node = self.nodes + cache = cat_caches[0] indptr_ptr[i + 1] = indptr_ptr[i] for k in range(X_indptr[i], X_indptr[i + 1]): @@ -1036,9 +1159,12 @@ cdef class Tree: else: feature_value = 0. - if feature_value <= node.threshold: + if goes_left(feature_value, node.split_value, + self.n_categories[node.feature], cache): + cache = cat_caches[node.left_child] node = &self.nodes[node.left_child] else: + cache = cat_caches[node.right_child] node = &self.nodes[node.right_child] # Add the leave node @@ -1057,7 +1183,6 @@ cdef class Tree: return out - cpdef compute_feature_importances(self, normalize=True): """Computes the importance of each feature (aka variable).""" cdef Node* left diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd index 04806ade180c2..7d2cf332be241 100644 --- a/sklearn/tree/_utils.pxd +++ b/sklearn/tree/_utils.pxd @@ -1,3 +1,5 @@ +# cython: language_level=3 + # Authors: Gilles Louppe # Peter Prettenhofer # Arnaud Joly @@ -10,14 +12,63 @@ import numpy as np cimport numpy as np -from _tree cimport Node -from sklearn.neighbors.quad_tree cimport Cell + +from ..neighbors.quad_tree cimport Cell ctypedef np.npy_float32 DTYPE_t # Type of X ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight ctypedef np.npy_intp SIZE_t # Type for indices and counters ctypedef np.npy_int32 INT32_t # Signed 32 bit integer ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer +ctypedef np.npy_uint64 UINT64_t # Unsigned 64 bit integer +ctypedef UINT64_t BITSET_t + +ctypedef union SplitValue: + # Union type to generalize the concept of a threshold to categorical + # features. The floating point view, i.e. ``SplitValue.threshold`` is used + # for numerical features, where feature values less than or equal to the + # threshold go left, and values greater than the threshold go right. + # + # For categorical features, the BITSET_t view (`SplitValue.cat_split``) is + # used. It works in one of two ways, indicated by the value of its least + # significant bit (LSB). If the LSB is 0, then cat_split acts as a bitfield + # for up to 64 categories, sending samples left if the bit corresponding to + # their category is 1 or right if it is 0. If the LSB is 1, then the most + # significant 32 bits of cat_split make a random seed. To evaluate a + # sample, use the random seed to flip a coin (category_value + 1) times and + # send it left if the last flip gives 1; otherwise right. This second + # method allows up to 2**31 category values, but can only be used for + # RandomSplitter. + DOUBLE_t threshold + BITSET_t cat_split + + +ctypedef struct SplitRecord: + # Data to track sample split + SIZE_t feature # Which feature to split on. + SIZE_t pos # Split samples array at the given position, + # i.e. count of samples below threshold for feature. + # pos is >= end if the node is a leaf. + SplitValue split_value # Generalized threshold for categorical and + # non-categorical features + double improvement # Impurity improvement given parent node. + double impurity_left # Impurity of the left split. + double impurity_right # Impurity of the right split. 
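Editor's note (not part of the patch): the ``cat_split`` encoding documented in the ``SplitValue`` union above packs two different split representations into one 64-bit word, which is easy to misread. The short pure-Python sketch below restates the decision rule from that comment; the function name and the ``coin_flip`` callable are illustrative placeholders standing in for the Cython ``rand_int``-based stream used by ``setup_cat_cache``/``goes_left`` further down.

def sends_sample_left(cat_split, category, coin_flip):
    """Illustrative decoder for SplitValue.cat_split.

    cat_split : 64-bit integer as described in the union comment
    category  : non-negative integer category value of the sample
    coin_flip : callable(seed, n) returning the n-th pseudo-random bit
                produced from `seed` (placeholder for the library RNG)
    """
    if cat_split & 1 == 0:
        # Bitfield mode (BestSplitter): bit k set => category k goes left.
        # Only categories 0..63 can be represented in this mode.
        return bool((cat_split >> category) & 1)
    # Random mode (RandomSplitter): the 32 most significant bits hold the
    # RNG seed; the sample goes left if the (category + 1)-th flip is 1.
    seed = cat_split >> 32
    return bool(coin_flip(seed, category + 1))

For example, with ``cat_split = 0b1010`` (least significant bit 0, so bitfield mode) category 3 is sent left and category 2 is sent right.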
+ + +cdef struct Node: + # Base storage structure for the nodes in a Tree object + + SIZE_t left_child # id of the left child of the node + SIZE_t right_child # id of the right child of the node + SIZE_t feature # Feature used for splitting the node + SplitValue split_value # Generalized threshold for categorical and + # non-categorical features + DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) + SIZE_t n_node_samples # Number of samples at the node + DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node + +# cdef struct Node # Forward declaration cdef enum: # Max value for our rand_r replacement (near the bottom). @@ -38,18 +89,22 @@ ctypedef fused realloc_ptr: (unsigned char*) (WeightedPQueueRecord*) (DOUBLE_t*) - (DOUBLE_t**) (Node*) (Cell*) (Node**) (StackRecord*) (PriorityHeapRecord*) + (void**) + (INT32_t*) + (UINT32_t*) + (BITSET_t*) -cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) nogil except * +cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems, size_t elem_bytes) nogil except * cdef np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size) +cdef np.ndarray int32_ptr_to_ndarray(INT32_t* data, SIZE_t size) cdef SIZE_t rand_int(SIZE_t low, SIZE_t high, UINT32_t* random_state) nogil @@ -61,6 +116,15 @@ cdef double rand_uniform(double low, double high, cdef double log(double x) nogil + +cdef void setup_cat_cache(BITSET_t* cachebits, UINT64_t cat_split, + INT32_t n_categories) nogil + + +cdef bint goes_left(DTYPE_t feature_value, SplitValue split, + INT32_t n_categories, BITSET_t* cachebits) nogil + + # ============================================================================= # Stack data structure # ============================================================================= @@ -167,3 +231,13 @@ cdef class WeightedMedianCalculator: self, DOUBLE_t data, DOUBLE_t weight, DOUBLE_t original_median) nogil cdef DOUBLE_t get_median(self) nogil + + +cdef BITSET_t bs_set(BITSET_t value, SIZE_t i) nogil +cdef BITSET_t bs_reset(BITSET_t value, SIZE_t i) nogil +cdef BITSET_t bs_flip(BITSET_t value, SIZE_t i) nogil +cdef BITSET_t bs_flip_all(BITSET_t value, SIZE_t n_low_bits) nogil +cdef bint bs_get(BITSET_t value, SIZE_t i) nogil +cdef BITSET_t bs_from_template(UINT64_t template, + INT32_t *cat_offs, + SIZE_t ncats_present) nogil diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 9c646730d170b..ecfcbdf308c52 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -1,3 +1,4 @@ +# cython: language_level=3 # cython: cdivision=True # cython: boundscheck=False # cython: wraparound=False @@ -24,15 +25,15 @@ np.import_array() # Helper functions # ============================================================================= -cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems) nogil except *: +cdef realloc_ptr safe_realloc(realloc_ptr* p, size_t nelems, size_t nbytes_elem) nogil except *: # sizeof(realloc_ptr[0]) would be more like idiomatic C, but causes Cython # 0.20.1 to crash. - cdef size_t nbytes = nelems * sizeof(p[0][0]) - if nbytes / sizeof(p[0][0]) != nelems: + cdef size_t nbytes = nelems * nbytes_elem + if nbytes / nbytes_elem != nelems: # Overflow in the multiplication with gil: raise MemoryError("could not allocate (%d * %d) bytes" - % (nelems, sizeof(p[0][0]))) + % (nelems, nbytes_elem)) cdef realloc_ptr tmp = realloc(p[0], nbytes) if tmp == NULL: with gil: @@ -46,7 +47,7 @@ def _realloc_test(): # Helper for tests. 
Tries to allocate <size_t>(-1) / 2 * sizeof(size_t)
    # bytes, which will always overflow.
    cdef SIZE_t* p = NULL
-    safe_realloc(&p, <size_t>(-1) / 2)
+    safe_realloc(&p, <size_t>(-1) / 2, sizeof(SIZE_t))
    if p != NULL:
        free(p)
        assert False
@@ -69,6 +70,13 @@ cdef inline np.ndarray sizet_ptr_to_ndarray(SIZE_t* data, SIZE_t size):
    return np.PyArray_SimpleNewFromData(1, shape, np.NPY_INTP, data).copy()


+cdef inline np.ndarray int32_ptr_to_ndarray(INT32_t* data, SIZE_t size):
+    """Encapsulate data into a 1D numpy array of int32's."""
+    cdef np.npy_intp shape[1]
+    shape[0] = size
+    return np.PyArray_SimpleNewFromData(1, shape, np.NPY_INT32, data)
+
+
 cdef inline SIZE_t rand_int(SIZE_t low, SIZE_t high,
                             UINT32_t* random_state) nogil:
    """Generate a random integer in [low; end)."""
@@ -86,6 +94,103 @@ cdef inline double log(double x) nogil:
    return ln(x) / ln(2.0)


+cdef inline void setup_cat_cache(BITSET_t* cachebits, BITSET_t cat_split,
+                                 INT32_t n_categories) nogil:
+    """Populate the bits of the category cache from a split.
+
+    Parameters
+    ----------
+    cachebits : BITSET_t*
+        This is a pointer to the output array. The size of the array should be
+        ``ceil(n_categories / 64)``. This function assumes the required
+        memory is allocated for the array by the caller.
+
+    cat_split : BITSET_t
+        If ``least significant bit == 0``:
+            It stores the split of the maximum 64 categories in its bits.
+            This is used in `BestSplitter`, and without loss of generality it
+            is assumed to be even, i.e. for any odd value there is an
+            equivalent even ``cat_split``.
+        If ``least significant bit == 1``:
+            It is a random split, and the 32 most significant bits of
+            ``cat_split`` contain the random seed of the split. The
+            ``n_categories`` lowest bits of ``cachebits`` are then filled with
+            random zeros and ones given the random seed.
+
+    n_categories : INT32_t
+        The number of categories.
+    """
+    cdef INT32_t j
+    cdef UINT32_t rng_seed, val
+    cdef SIZE_t cache_size = (n_categories + 63) // 64
+    if n_categories > 0:
+        if cat_split & 1:
+            # RandomSplitter
+            for j in range(cache_size):
+                cachebits[j] = 0
+            rng_seed = cat_split >> 32
+            for j in range(n_categories):
+                val = rand_int(0, 2, &rng_seed)
+                if not val:
+                    continue
+                cachebits[j // 64] = bs_set(cachebits[j // 64], j % 64)
+        else:
+            # BestSplitter
+            # In practice, cache_size here should ALWAYS be 1
+            # XXX TODO: check cache_size == 1?
+            cachebits[0] = cat_split
+
+
+cdef inline bint goes_left(DTYPE_t feature_value, SplitValue split,
+                           INT32_t n_categories, BITSET_t *cachebits) nogil:
+    """Determine whether a sample goes to the left or right child node.
+
+    For numerical features, ``(-inf, split.threshold]`` is the left child, and
+    ``(split.threshold, inf)`` the right child.
+
+    For categorical features, if the corresponding bit for the category is set
+    in cachebits, the left child is used; if it is not set, the right child.
+    If the given input category is ``n_categories`` or larger, the right
+    child is assumed.
+
+    Parameters
+    ----------
+    feature_value : DTYPE_t
+        The value of the feature for which the decision needs to be made.
+
+    split : SplitValue
+        The union (of DOUBLE_t and BITSET_t) indicating the split. However, it
+        is used (as a DOUBLE_t) only for numerical features.
+
+    n_categories : INT32_t
+        The number of categories present in the feature in question. The
+        feature is considered a numerical one and not a categorical one if
+        n_categories is negative.
+
+    cachebits : BITSET_t*
+        The array containing the expansion of split.cat_split.
+        The function setup_cat_cache is the one filling it.
+
+    Returns
+    -------
+    result : bint
+        Indicating whether the left branch should be used.
+    """
+    cdef SIZE_t idx, offset
+
+    if n_categories < 0:
+        # Non-categorical feature
+        return feature_value <= split.threshold
+    else:
+        # Categorical feature, using bit cache
+        if (<SIZE_t> feature_value) < n_categories:
+            idx = (<SIZE_t> feature_value) // 64
+            offset = (<SIZE_t> feature_value) % 64
+            return bs_get(cachebits[idx], offset)
+        else:
+            return 0
+
+
 # =============================================================================
 # Stack data structure
 # =============================================================================
@@ -132,7 +237,7 @@ cdef class Stack:
        if top >= self.capacity:
            self.capacity *= 2
            # Since safe_realloc can raise MemoryError, use `except -1`
-            safe_realloc(&self.stack_, self.capacity)
+            safe_realloc(&self.stack_, self.capacity, sizeof(StackRecord))

        stack = self.stack_
        stack[top].start = start
@@ -192,7 +297,7 @@ cdef class PriorityHeap:
    def __cinit__(self, SIZE_t capacity):
        self.capacity = capacity
        self.heap_ptr = 0
-        safe_realloc(&self.heap_, capacity)
+        safe_realloc(&self.heap_, capacity, sizeof(PriorityHeapRecord))

    def __dealloc__(self):
        free(self.heap_)
@@ -248,7 +353,7 @@ cdef class PriorityHeap:
        if heap_ptr >= self.capacity:
            self.capacity *= 2
            # Since safe_realloc can raise MemoryError, use `except -1`
-            safe_realloc(&self.heap_, self.capacity)
+            safe_realloc(&self.heap_, self.capacity, sizeof(PriorityHeapRecord))

        # Put element as last element of heap
        heap = self.heap_
@@ -318,7 +423,7 @@ cdef class WeightedPQueue:
    def __cinit__(self, SIZE_t capacity):
        self.capacity = capacity
        self.array_ptr = 0
-        safe_realloc(&self.array_, capacity)
+        safe_realloc(&self.array_, capacity, sizeof(WeightedPQueueRecord))

    def __dealloc__(self):
        free(self.array_)
@@ -331,7 +436,7 @@ cdef class WeightedPQueue:
        """
        self.array_ptr = 0
        # Since safe_realloc can raise MemoryError, use `except *`
-        safe_realloc(&self.array_, self.capacity)
+        safe_realloc(&self.array_, self.capacity, sizeof(WeightedPQueueRecord))
        return 0

    cdef bint is_empty(self) nogil:
@@ -354,7 +459,7 @@ cdef class WeightedPQueue:
        if array_ptr >= self.capacity:
            self.capacity *= 2
            # Since safe_realloc can raise MemoryError, use `except -1`
-            safe_realloc(&self.array_, self.capacity)
+            safe_realloc(&self.array_, self.capacity, sizeof(WeightedPQueueRecord))

        # Put element as last element of array
        array = self.array_
@@ -666,3 +771,29 @@ cdef class WeightedMedianCalculator:
        if self.sum_w_0_k > (self.total_weight / 2.0):
            # whole median
            return self.samples.get_value_from_index(self.k-1)
+
+
+cdef inline BITSET_t bs_set(BITSET_t value, SIZE_t i) nogil:
+    return value | (<BITSET_t> 1) << i
+
+cdef inline BITSET_t bs_reset(BITSET_t value, SIZE_t i) nogil:
+    return value & ~((<BITSET_t> 1) << i)
+
+cdef inline BITSET_t bs_flip(BITSET_t value, SIZE_t i) nogil:
+    return value ^ (<BITSET_t> 1) << i
+
+cdef inline BITSET_t bs_flip_all(BITSET_t value, SIZE_t n_low_bits) nogil:
+    return (~value) & ((~(<BITSET_t> 0)) >> (64 - n_low_bits))
+
+cdef inline bint bs_get(BITSET_t value, SIZE_t i) nogil:
+    return (value >> i) & (<BITSET_t> 1)
+
+cdef inline BITSET_t bs_from_template(UINT64_t template,
+                                      INT32_t *cat_offs,
+                                      SIZE_t ncats_present) nogil:
+    cdef SIZE_t i
+    cdef BITSET_t value = 0
+    for i in range(ncats_present):
+        value |= (template &
+                  ((<UINT64_t> 1) << i)) << cat_offs[i]
+    return value
diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py
index 4949bfe72a92b..7554ba22acf80 100644
--- a/sklearn/tree/tests/test_tree.py
+++ 
b/sklearn/tree/tests/test_tree.py @@ -177,8 +177,8 @@ def assert_tree_equal(d, s, message): assert_array_equal(d.feature[internal], s.feature[internal], message + ": inequal features") - assert_array_equal(d.threshold[internal], s.threshold[internal], - message + ": inequal threshold") + assert_array_almost_equal(d.threshold[internal], s.threshold[internal], + err_msg=message + ": inequal threshold") assert_array_equal(d.n_node_samples.sum(), s.n_node_samples.sum(), message + ": inequal sum(n_node_samples)") assert_array_equal(d.n_node_samples, s.n_node_samples, @@ -1813,6 +1813,31 @@ def _pickle_copy(obj): assert_equal(n_samples, n_samples_) +@pytest.mark.parametrize('name', ALL_TREES) +@pytest.mark.parametrize('categorical', ['invalid string', [[0]], + [False, False, False], [1, 2], [-3], + [0, 0, 1]]) +def test_invalid_categorical(name, categorical): + Tree = ALL_TREES[name] + with pytest.raises(ValueError, match="Invalid value for categorical"): + Tree(categorical=categorical).fit(X, y) + + +@pytest.mark.parametrize('name', ALL_TREES) +def test_no_sparse_with_categorical(name): + # Currently we do not support sparse categorical features + X, y, X_sparse = [DATASETS['clf_small'][z] + for z in ['X', 'y', 'X_sparse']] + Tree = ALL_TREES[name] + with pytest.raises(NotImplementedError, + match="Categorical features not supported with sparse"): + Tree(categorical=[6, 10]).fit(X_sparse, y) + + with pytest.raises(NotImplementedError, + match="Categorical features not supported with sparse"): + Tree(categorical=[6, 10]).fit(X, y).predict(X_sparse) + + def test_empty_leaf_infinite_threshold(): # try to make empty leaf by using near infinite value. data = np.random.RandomState(0).randn(100, 11) * 2e38 @@ -1830,6 +1855,97 @@ def test_empty_leaf_infinite_threshold(): assert len(empty_leaf) == 0 +def _make_categorical(n_rows: int, n_numerical: int, n_categorical: int, + cat_size: int, n_num_meaningful: int, + n_cat_meaningful: int, regression: bool, + return_tuple: bool, random_state: int): + + from sklearn.preprocessing import OneHotEncoder + np.random.seed(random_state) + numeric = np.random.standard_normal((n_rows, n_numerical)) + categorical = np.random.randint(0, cat_size, (n_rows, n_categorical)) + categorical_ohe = OneHotEncoder(categories='auto').fit_transform( + categorical[:, :n_cat_meaningful]) + + data_meaningful = np.hstack((numeric[:, :n_num_meaningful], + categorical_ohe.todense())) + _, cols = data_meaningful.shape + coefs = np.random.standard_normal(cols) + y = np.dot(data_meaningful, coefs) + y = np.asarray(y).reshape(-1) + X = np.hstack((numeric, categorical)) + + if not regression: + y = (y < y.mean()).astype(int) + + meaningful_features = np.r_[np.arange(n_num_meaningful), + np.arange(n_cat_meaningful) + + n_numerical] + + if return_tuple: + return X, y, meaningful_features + else: + return {'X': X, + 'y': y, + 'meaningful_features': meaningful_features} + + +@pytest.mark.parametrize('model', ALL_TREES) +@pytest.mark.parametrize('data_params', [ + {'n_rows': 1000, + 'n_numerical': 5, + 'n_categorical': 5, + 'cat_size': 3, + 'n_num_meaningful': 2, + 'n_cat_meaningful': 3}, + {'n_rows': 1000, + 'n_numerical': 0, + 'n_categorical': 5, + 'cat_size': 3, + 'n_num_meaningful': 0, + 'n_cat_meaningful': 3}, + {'n_rows': 1000, + 'n_numerical': 5, + 'n_categorical': 5, + 'cat_size': 64, + 'n_num_meaningful': 0, + 'n_cat_meaningful': 2}, + {'n_rows': 1000, + 'n_numerical': 5, + 'n_categorical': 5, + 'cat_size': 3, + 'n_num_meaningful': 0, + 'n_cat_meaningful': 3}]) +def 
test_categorical_data(model, data_params):
+    # DecisionTrees are too slow for large category sizes.
+    if data_params['cat_size'] > 8 and 'DecisionTree' in model:
+        pytest.skip('DecisionTree is too slow for large category sizes')
+
+    X, y, meaningful_features = _make_categorical(
+        **data_params,
+        regression=model in REG_TREES,
+        return_tuple=True,
+        random_state=42)
+    rows, cols = X.shape
+    categorical_features = (np.arange(data_params['n_categorical']) +
+                            data_params['n_numerical'])
+
+    model = ALL_TREES[model](random_state=42,
+                             categorical=categorical_features).fit(X, y)
+    fi = model.feature_importances_
+    bad_features = np.array([True]*cols)
+    bad_features[meaningful_features] = False
+
+    good_ones = fi[meaningful_features]
+    bad_ones = fi[bad_features]
+
+    # All good features should be more important than all bad features.
+    assert np.all([np.all(x > bad_ones) for x in good_ones])
+
+    leaves = model.tree_.children_left < 0
+    assert np.all(model.tree_.impurity[leaves] < 1e-6)
+
+
 @pytest.mark.parametrize('name', CLF_TREES)
 def test_multi_target(name):
    Tree = CLF_TREES[name]
diff --git a/sklearn/tree/tree.py b/sklearn/tree/tree.py
index a07e6a0ca5d9a..25d813b306490 100644
--- a/sklearn/tree/tree.py
+++ b/sklearn/tree/tree.py
@@ -92,7 +92,8 @@ def __init__(self,
                 min_impurity_decrease,
                 min_impurity_split,
                 class_weight=None,
-                 presort=False):
+                 presort=False,
+                 categorical="none"):
        self.criterion = criterion
        self.splitter = splitter
        self.max_depth = max_depth
@@ -106,6 +107,7 @@ def __init__(self,
        self.min_impurity_split = min_impurity_split
        self.class_weight = class_weight
        self.presort = presort
+        self.categorical = categorical

    def get_depth(self):
        """Returns the depth of the decision tree.
@@ -281,6 +283,56 @@ def fit(self, X, y, sample_weight=None, check_input=True,
            else:
                sample_weight = expanded_class_weight

+        # Validate categorical features
+        if isinstance(self.categorical, str):
+            if self.categorical == 'none':
+                categorical = np.array([], dtype=np.int)
+            elif self.categorical == 'all':
+                categorical = np.arange(self.n_features_)
+            else:
+                raise ValueError("Invalid value for categorical: {}. Allowed"
+                                 " strings are 'all' or 'none'"
+                                 "".format(self.categorical))
+        else:
+            categorical = np.atleast_1d(self.categorical).flatten()
+            if categorical.dtype == np.bool:
+                if categorical.size != self.n_features_:
+                    raise ValueError("Invalid value for categorical: Shape of "
+                                     "boolean parameter categorical must "
+                                     "be (n_features,)")
+                categorical = np.nonzero(categorical)[0]
+            if (np.size(categorical) > self.n_features_ or
+                    (categorical.size > 0 and
+                     (categorical.min() < 0 or
+                      categorical.max() >= self.n_features_))):
+                raise ValueError("Invalid value for categorical: invalid "
+                                 "shape or feature index for parameter "
+                                 "categorical.")
+        if issparse(X):
+            if categorical.size > 0:
+                raise NotImplementedError("Categorical features not supported"
+                                          " with sparse inputs")
+        else:
+            if np.any(X[:, categorical].astype(np.int) < 0):
+                raise ValueError("Invalid value for categorical: given values "
+                                 "for categorical features must be "
+                                 "non-negative.")
+
+        # Calculate n_categories and verify they are all at least 1% populated
+        n_categories = np.array([np.int(X[:, i].max()) + 1 if i in categorical
+                                 else -1 for i in range(self.n_features_)],
+                                dtype=np.int32)
+        n_cat_present = np.array([np.unique(X[:, i].astype(np.int)).size
+                                  if i in categorical else -1
+                                  for i in range(self.n_features_)],
+                                 dtype=np.int32)
+        if np.any((n_cat_present < 0.01 * n_categories)[categorical]):
+            warnings.warn("At least one categorical feature has less than 1%"
+                          " of its categories present in the sample. Runtime"
+                          " and memory usage will be much smaller if you"
+                          " represent the categories as sequential integers.",
+                          UserWarning)
+
        # Set min_weight_leaf from min_weight_fraction_leaf
        if sample_weight is None:
            min_weight_leaf = (self.min_weight_fraction_leaf *
@@ -346,6 +398,12 @@ def fit(self, X, y, sample_weight=None, check_input=True,
        else:
            criterion = CRITERIA_REG[self.criterion](self.n_outputs_,
                                                     n_samples)
+        if is_classification:
+            breiman_shortcut = (self.n_classes_.tolist() == [2] and
+                                (isinstance(criterion, _criterion.Gini) or
+                                 isinstance(criterion, _criterion.Entropy)))
+        else:
+            breiman_shortcut = isinstance(criterion, _criterion.MSE)

        SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS

@@ -356,9 +414,17 @@ def fit(self, X, y, sample_weight=None, check_input=True,
                                                min_samples_leaf,
                                                min_weight_leaf,
                                                random_state,
-                                                self.presort)
+                                                self.presort,
+                                                breiman_shortcut)

-        self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_)
+        if (not isinstance(splitter, _splitter.RandomSplitter) and
+                np.max(n_categories) > 64):
+            raise ValueError("Categorical features with more than 64"
+                             " categories not supported with DecisionTree;"
+                             " try ExtraTree.")
+
+        self.tree_ = Tree(self.n_features_, self.n_classes_, self.n_outputs_,
+                          n_categories)

        # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
        if max_leaf_nodes < 0:
@@ -377,7 +443,8 @@ def fit(self, X, y, sample_weight=None, check_input=True,
                                           self.min_impurity_decrease,
                                           min_impurity_split)

-        builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)
+        builder.build(self.tree_, X, y, sample_weight, n_categories,
+                      X_idx_sorted)

        if self.n_outputs_ == 1:
            self.n_classes_ = self.n_classes_[0]
@@ -393,6 +460,9 @@ def _validate_X_predict(self, X, check_input):
                    X.indptr.dtype != np.intc):
                raise ValueError("No support for np.int64 index based "
                                 "sparse matrices")
+        if issparse(X) and np.any(self.tree_.n_categories > 0):
+            raise NotImplementedError("Categorical features not supported"
+                                      " with sparse inputs")

        n_features
= X.shape[1] if self.n_features_ != n_features: @@ -666,6 +736,19 @@ class DecisionTreeClassifier(BaseDecisionTree, ClassifierMixin): When using either a smaller dataset or a restricted depth, this may speed up the training. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data with binary + labels using the ``Gini`` or ``Entropy`` criteria. In this case, + the runtime is linear in the number of categories. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + Attributes ---------- classes_ : array of shape = [n_classes] or a list of such arrays @@ -757,7 +840,8 @@ def __init__(self, min_impurity_decrease=0., min_impurity_split=None, class_weight=None, - presort=False): + presort=False, + categorical="none"): super().__init__( criterion=criterion, splitter=splitter, @@ -771,7 +855,8 @@ def __init__(self, random_state=random_state, min_impurity_decrease=min_impurity_decrease, min_impurity_split=min_impurity_split, - presort=presort) + presort=presort, + categorical=categorical) def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): @@ -1018,6 +1103,18 @@ class DecisionTreeRegressor(BaseDecisionTree, RegressorMixin): When using either a smaller dataset or a restricted depth, this may speed up the training. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. For decision trees, + the maximum number of categories is 64. In practice, the limit will + often be lower because the process of searching for the best possible + split grows exponentially with the number of categories. However, a + shortcut due to Breiman (1984) is used when fitting data using the + ``MSE`` criterion. In this case, the runtime is linear in the number + of categories. Extra-random trees have an upper limit of :math:`2^{31}` + categories, and runtimes linear in the number of categories. + Attributes ---------- feature_importances_ : array of shape = [n_features] @@ -1100,7 +1197,8 @@ def __init__(self, max_leaf_nodes=None, min_impurity_decrease=0., min_impurity_split=None, - presort=False): + presort=False, + categorical='none'): super().__init__( criterion=criterion, splitter=splitter, @@ -1113,7 +1211,8 @@ def __init__(self, random_state=random_state, min_impurity_decrease=min_impurity_decrease, min_impurity_split=min_impurity_split, - presort=presort) + presort=presort, + categorical=categorical) def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None): @@ -1295,6 +1394,15 @@ class ExtraTreeClassifier(DecisionTreeClassifier): Note that these weights will be multiplied with sample_weight (passed through the fit method) if sample_weight is specified. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. 
Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + + .. versionadded:: 0.21 + See also -------- ExtraTreeRegressor, sklearn.ensemble.ExtraTreesClassifier, @@ -1326,7 +1434,8 @@ def __init__(self, max_leaf_nodes=None, min_impurity_decrease=0., min_impurity_split=None, - class_weight=None): + class_weight=None, + categorical='none'): super().__init__( criterion=criterion, splitter=splitter, @@ -1339,7 +1448,8 @@ def __init__(self, class_weight=class_weight, min_impurity_decrease=min_impurity_decrease, min_impurity_split=min_impurity_split, - random_state=random_state) + random_state=random_state, + categorical=categorical) class ExtraTreeRegressor(DecisionTreeRegressor): @@ -1463,6 +1573,14 @@ class ExtraTreeRegressor(DecisionTreeRegressor): Best nodes are defined as relative reduction in impurity. If None then unlimited number of leaf nodes. + categorical : array-like or str + Array of feature indices, boolean array of length n_features, + ``'all'`` or ``'none'``. Indicates which features should be + considered as categorical rather than ordinal. Extra-random trees + have an upper limit of :math:`2^{31}` categories, and runtimes + linear in the number of categories. + + .. versionadded:: 0.21 See also -------- @@ -1494,7 +1612,8 @@ def __init__(self, random_state=None, min_impurity_decrease=0., min_impurity_split=None, - max_leaf_nodes=None): + max_leaf_nodes=None, + categorical='none'): super().__init__( criterion=criterion, splitter=splitter, @@ -1506,4 +1625,5 @@ def __init__(self, max_leaf_nodes=max_leaf_nodes, min_impurity_decrease=min_impurity_decrease, min_impurity_split=min_impurity_split, - random_state=random_state) + random_state=random_state, + categorical=categorical) diff --git a/sklearn/utils/fast_dict.pxd b/sklearn/utils/fast_dict.pxd index 62e0a08739b14..56b98198ad5a1 100644 --- a/sklearn/utils/fast_dict.pxd +++ b/sklearn/utils/fast_dict.pxd @@ -8,6 +8,7 @@ integers, and values float. from libcpp.map cimport map as cpp_map # Import the C-level symbols of numpy +import numpy as np cimport numpy as np ctypedef np.float64_t DTYPE_t
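Editor's note (not part of the patch): taken together, the tree.py changes above give the following end-to-end usage. This is a sketch only; it assumes a scikit-learn build that includes this patch, since the ``categorical`` keyword and ``tree_.n_categories`` do not exist in released versions, and the synthetic data and variable names are illustrative.

import numpy as np
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier

rng = np.random.RandomState(0)
# Column 0 is numerical; columns 1 and 2 are categories encoded as 0..4.
X = np.hstack([rng.standard_normal((200, 1)),
               rng.randint(0, 5, size=(200, 2))])
y = (X[:, 1] == 2).astype(int)

# Feature indices (or a boolean mask, or 'all'/'none') mark the
# categorical columns.
clf = DecisionTreeClassifier(categorical=[1, 2], random_state=0).fit(X, y)

# Non-categorical features are reported as -1; categorical ones report the
# number of categories inferred from the training data (max value + 1).
print(clf.tree_.n_categories)   # expected: [-1  5  5]

# DecisionTree* is limited to 64 categories per feature; a high-cardinality
# column has to go through the extra-random variants instead.
big = np.hstack([X, rng.randint(0, 500, size=(200, 1))])
ExtraTreeClassifier(categorical=[1, 2, 3], random_state=0).fit(big, y)

Fitting the high-cardinality column with ``DecisionTreeClassifier`` would trigger the ``ValueError`` added in ``fit()``; routing it to ``ExtraTreeClassifier`` uses the random-split encoding, whose runtime is linear in the number of categories.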