Commit 18a4af6
Update documents and tests. (dmlc#7659)
* Revise documents after recent refactoring and cat support.
* Add tests for the behavior of max_depth and max_leaves.
1 parent 5eed299 commit 18a4af6

7 files changed (+142, -44)

doc/parameter.rst (+6, -8)
@@ -74,8 +74,8 @@ Parameters for Tree Booster
 
 * ``max_depth`` [default=6]
 
-  - Maximum depth of a tree. Increasing this value will make the model more complex and more likely to overfit. 0 is only accepted in ``lossguide`` growing policy when ``tree_method`` is set as ``hist`` or ``gpu_hist`` and it indicates no limit on depth. Beware that XGBoost aggressively consumes memory when training a deep tree.
-  - range: [0,∞] (0 is only accepted in ``lossguide`` growing policy when ``tree_method`` is set as ``hist`` or ``gpu_hist``)
+  - Maximum depth of a tree. Increasing this value will make the model more complex and more likely to overfit. 0 indicates no limit on depth. Beware that XGBoost aggressively consumes memory when training a deep tree. The ``exact`` tree method requires a non-zero value.
+  - range: [0,∞]
 
 * ``min_child_weight`` [default=1]
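
For illustration, a minimal sketch of the new semantics on made-up data (parameter values are arbitrary; ``hist`` is used since ``exact`` now rejects a zero depth):

.. code-block:: python

  import numpy as np
  import xgboost as xgb

  # Toy data, for illustration only.
  X = np.random.rand(256, 4)
  y = np.random.rand(256)
  dtrain = xgb.DMatrix(X, label=y)

  # max_depth=0 removes the depth limit, so bound tree growth
  # through max_leaves instead.
  params = {"tree_method": "hist", "max_depth": 0, "max_leaves": 32}
  booster = xgb.train(params, dtrain, num_boost_round=4)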

@@ -164,7 +164,7 @@ Parameters for Tree Booster
 
   - Control the balance of positive and negative weights, useful for unbalanced classes. A typical value to consider: ``sum(negative instances) / sum(positive instances)``. See :doc:`Parameters Tuning </tutorials/param_tuning>` for more discussion. Also, see Higgs Kaggle competition demo for examples: `R <https://github.com/dmlc/xgboost/blob/master/demo/kaggle-higgs/higgs-train.R>`_, `py1 <https://github.com/dmlc/xgboost/blob/master/demo/kaggle-higgs/higgs-numpy.py>`_, `py2 <https://github.com/dmlc/xgboost/blob/master/demo/kaggle-higgs/higgs-cv.py>`_, `py3 <https://github.com/dmlc/xgboost/blob/master/demo/guide-python/cross_validation.py>`_.
 
-* ``updater`` [default= ``grow_colmaker,prune``]
+* ``updater``
 
   - A comma separated string defining the sequence of tree updaters to run, providing a modular way to construct and to modify the trees. This is an advanced parameter that is usually set automatically, depending on some other parameters. However, it could also be set explicitly by a user. The following updaters exist:
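
The suggested ``scale_pos_weight`` heuristic is easy to compute; a small sketch on made-up labels:

.. code-block:: python

  import numpy as np

  y = np.array([0] * 900 + [1] * 100)  # hypothetical imbalanced binary labels
  ratio = float((y == 0).sum()) / (y == 1).sum()  # 900 / 100 = 9.0
  params = {"objective": "binary:logistic", "scale_pos_weight": ratio}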

@@ -177,8 +177,6 @@ Parameters for Tree Booster
   - ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
   - ``prune``: prunes the splits where loss < min_split_loss (or gamma) and nodes that have depth greater than ``max_depth``.
 
-  In a distributed setting, the implicit updater sequence value would be adjusted to ``grow_histmaker,prune`` by default, and you can set ``tree_method`` as ``hist`` to use ``grow_histmaker``.
-
 * ``refresh_leaf`` [default=1]
 
   - This is a parameter of the ``refresh`` updater. When this flag is 1, tree leafs as well as tree nodes' stats are updated. When it is 0, only node stats are updated.
@@ -194,19 +192,19 @@ Parameters for Tree Booster
 * ``grow_policy`` [default= ``depthwise``]
 
   - Controls the way new nodes are added to the tree.
-  - Currently supported only if ``tree_method`` is set to ``hist`` or ``gpu_hist``.
+  - Currently supported only if ``tree_method`` is set to ``hist``, ``approx`` or ``gpu_hist``.
   - Choices: ``depthwise``, ``lossguide``
 
     - ``depthwise``: split at nodes closest to the root.
     - ``lossguide``: split at nodes with highest loss change.
 
 * ``max_leaves`` [default=0]
 
-  - Maximum number of nodes to be added. Only relevant when ``grow_policy=lossguide`` is set.
+  - Maximum number of nodes to be added. Not used by the ``exact`` tree method.
 
 * ``max_bin`` [default=256]
 
-  - Only used if ``tree_method`` is set to ``hist`` or ``gpu_hist``.
+  - Only used if ``tree_method`` is set to ``hist``, ``approx`` or ``gpu_hist``.
   - Maximum number of discrete bins to bucket continuous features.
   - Increasing this number improves the optimality of splits at the cost of higher computation time.
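
A sketch combining the parameters touched by this hunk, now also valid for ``approx`` (toy data; values are arbitrary):

.. code-block:: python

  import numpy as np
  import xgboost as xgb

  X = np.random.rand(512, 8)
  y = np.random.rand(512)
  dtrain = xgb.DMatrix(X, label=y)

  params = {
      "tree_method": "approx",     # grow_policy and max_bin now apply here too
      "grow_policy": "lossguide",  # split where the loss change is highest
      "max_leaves": 16,            # cap on leaves; not used by "exact"
      "max_bin": 512,              # more candidate splits, more computation
  }
  booster = xgb.train(params, dtrain, num_boost_round=2)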

doc/treemethod.rst (+29, -0)
@@ -114,3 +114,32 @@ was never tested and contained some unknown bugs, we decided to remove it and focus
 resources on more promising algorithms instead. For accuracy, most of the time
 ``approx``, ``hist`` and ``gpu_hist`` are enough with some parameter tuning, so removing
 them doesn't have any real practical impact.
+
+**************
+Feature Matrix
+**************
+
+The following table summarizes some differences in supported features between the 4 tree
+methods; `T` means supported while `F` means unsupported.
+
++------------------+-----------+---------------------+---------------------+------------------------+
+|                  | Exact     | Approx              | Hist                | GPU Hist               |
++==================+===========+=====================+=====================+========================+
+| grow_policy      | depthwise | depthwise/lossguide | depthwise/lossguide | depthwise/lossguide    |
++------------------+-----------+---------------------+---------------------+------------------------+
+| max_leaves       | F         | T                   | T                   | T                      |
++------------------+-----------+---------------------+---------------------+------------------------+
+| sampling method  | uniform   | uniform             | uniform             | gradient_based/uniform |
++------------------+-----------+---------------------+---------------------+------------------------+
+| categorical data | F         | T                   | T                   | T                      |
++------------------+-----------+---------------------+---------------------+------------------------+
+| External memory  | F         | T                   | P                   | P                      |
++------------------+-----------+---------------------+---------------------+------------------------+
+| Distributed      | F         | T                   | T                   | T                      |
++------------------+-----------+---------------------+---------------------+------------------------+
+
+Features/parameters that are not mentioned here are universally supported for all 4 tree
+methods (for instance, column sampling and constraints). The `P` in external memory means
+partially supported. Please note that both categorical data and external memory are
+experimental.
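
For the categorical-data row of the table, a minimal sketch (the feature is experimental, as noted above; the frame and labels are made up):

.. code-block:: python

  import pandas as pd
  import xgboost as xgb

  df = pd.DataFrame({
      "color": pd.Categorical(["red", "green", "blue", "green"]),
      "size": [1.0, 2.0, 0.5, 1.5],
  })
  y = [0, 1, 0, 1]

  # Requires a tree method marked `T` in the table, e.g. hist.
  dtrain = xgb.DMatrix(df, label=y, enable_categorical=True)
  booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=2)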

doc/tutorials/feature_interaction_constraint.rst (+1, -6)
@@ -129,7 +129,7 @@ first and second constraints (``[0, 1]``, ``[2, 3, 4]``).
 
 .. |fig1| image:: ../_static/feature_interaction_illustration2.svg
    :scale: 7%
-  :align: middle
+   :align: middle
 
 .. |fig2| image:: ../_static/feature_interaction_illustration3.svg
    :scale: 7%
@@ -174,11 +174,6 @@ parameter:
                                  num_boost_round = 1000, evals = evallist,
                                  early_stopping_rounds = 10)
 
-**Choice of tree construction algorithm**. To use feature interaction constraints, be sure
-to set the ``tree_method`` parameter to one of the following: ``exact``, ``hist``,
-``approx`` or ``gpu_hist``. Support for ``gpu_hist`` and ``approx`` is added only in
-1.0.0.
-
 **************
 Advanced topic
 **************
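
With the tree-method caveat gone, the tutorial's constraint groups should work with any method; a minimal sketch on random data:

.. code-block:: python

  import numpy as np
  import xgboost as xgb

  X = np.random.rand(256, 5)
  y = np.random.rand(256)
  dtrain = xgb.DMatrix(X, label=y)

  params = {
      "tree_method": "hist",
      # Features may interact only within their own group.
      "interaction_constraints": "[[0, 1], [2, 3, 4]]",
  }
  booster = xgb.train(params, dtrain, num_boost_round=4)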

doc/tutorials/monotonic.rst (+12, -15)
@@ -2,23 +2,23 @@
 Monotonic Constraints
 #####################
 
-It is often the case in a modeling problem or project that the functional form of an acceptable model is constrained in some way. This may happen due to business considerations, or because of the type of scientific question being investigated. In some cases, where there is a very strong prior belief that the true relationship has some quality, constraints can be used to improve the predictive performance of the model.
+It is often the case in a modeling problem or project that the functional form of an acceptable model is constrained in some way. This may happen due to business considerations, or because of the type of scientific question being investigated. In some cases, where there is a very strong prior belief that the true relationship has some quality, constraints can be used to improve the predictive performance of the model.
 
 A common type of constraint in this situation is that certain features bear a **monotonic** relationship to the predicted response:
 
 .. math::
 
   f(x_1, x_2, \ldots, x, \ldots, x_{n-1}, x_n) \leq f(x_1, x_2, \ldots, x', \ldots, x_{n-1}, x_n)
 
-whenever :math:`x \leq x'` is an **increasing constraint**; or
+whenever :math:`x \leq x'` is an **increasing constraint**; or
 
 .. math::
 
   f(x_1, x_2, \ldots, x, \ldots, x_{n-1}, x_n) \geq f(x_1, x_2, \ldots, x', \ldots, x_{n-1}, x_n)
 
 whenever :math:`x \leq x'` is a **decreasing constraint**.
 
-XGBoost has the ability to enforce monotonicity constraints on any features used in a boosted model.
+XGBoost has the ability to enforce monotonicity constraints on any features used in a boosted model.
 
 ****************
 A Simple Example
@@ -60,8 +60,8 @@ Suppose the following code fits your model without monotonicity constraints
 
 .. code-block:: python
 
-  model_no_constraints = xgb.train(params, dtrain,
-                                   num_boost_round = 1000, evals = evallist,
+  model_no_constraints = xgb.train(params, dtrain,
+                                   num_boost_round = 1000, evals = evallist,
                                    early_stopping_rounds = 10)
 
 Then fitting with monotonicity constraints only requires adding a single parameter
@@ -71,8 +71,8 @@ Then fitting with monotonicity constraints only requires adding a single parameter
   params_constrained = params.copy()
   params_constrained['monotone_constraints'] = "(1,-1)"
 
-  model_with_constraints = xgb.train(params_constrained, dtrain,
-                                     num_boost_round = 1000, evals = evallist,
+  model_with_constraints = xgb.train(params_constrained, dtrain,
+                                     num_boost_round = 1000, evals = evallist,
                                      early_stopping_rounds = 10)
 
 In this example the training data ``X`` has two columns, and by using the parameter values ``(1,-1)`` we are telling XGBoost to impose an increasing constraint on the first predictor and a decreasing constraint on the second.
@@ -82,14 +82,11 @@ Some other examples:
 - ``(1,0)``: An increasing constraint on the first predictor and no constraint on the second.
 - ``(0,-1)``: No constraint on the first predictor and a decreasing constraint on the second.
 
-**Choice of tree construction algorithm**. To use monotonic constraints, be
-sure to set the ``tree_method`` parameter to one of ``exact``, ``hist``, and
-``gpu_hist``.
 
 **Note for the 'hist' tree construction algorithm**.
-If ``tree_method`` is set to either ``hist`` or ``gpu_hist``, enabling monotonic
-constraints may produce unnecessarily shallow trees. This is because the
+If ``tree_method`` is set to either ``hist``, ``approx`` or ``gpu_hist``, enabling
+monotonic constraints may produce unnecessarily shallow trees. This is because the
 ``hist`` method reduces the number of candidate splits to be considered at each
-split. Monotonic constraints may wipe out all available split candidates, in
-which case no split is made. To reduce the effect, you may want to increase
-the ``max_bin`` parameter to consider more split candidates.
+split. Monotonic constraints may wipe out all available split candidates, in which case no
+split is made. To reduce the effect, you may want to increase the ``max_bin`` parameter to
+consider more split candidates.
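
A short sketch of how one might sanity-check the tutorial's constraint (synthetic data; the assertion assumes the constraint is enforced exactly):

.. code-block:: python

  import numpy as np
  import xgboost as xgb

  # y increases with x0 and decreases with x1, plus a little noise.
  rng = np.random.RandomState(0)
  X = rng.rand(1000, 2)
  y = X[:, 0] - X[:, 1] + rng.normal(scale=0.01, size=1000)
  dtrain = xgb.DMatrix(X, label=y)

  params = {"tree_method": "hist", "monotone_constraints": "(1,-1)"}
  booster = xgb.train(params, dtrain, num_boost_round=50)

  # Predictions should never decrease as the first feature grows,
  # with the second feature held fixed.
  grid = np.column_stack([np.linspace(0, 1, 100), np.full(100, 0.5)])
  preds = booster.predict(xgb.DMatrix(grid))
  assert np.all(np.diff(preds) >= 0)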

src/tree/updater_colmaker.cc (+2, -0)
@@ -174,6 +174,8 @@ class ColMaker: public TreeUpdater {
     std::vector<int> newnodes;
     this->InitData(gpair, *p_fmat);
     this->InitNewNode(qexpand_, gpair, *p_fmat, *p_tree);
+    // We can check max_leaves too, but it might break some grid-searching pipelines.
+    CHECK_GT(param_.max_depth, 0) << "exact tree method doesn't support unlimited depth.";
     for (int depth = 0; depth < param_.max_depth; ++depth) {
       this->FindSplit(depth, qexpand_, gpair, p_fmat, p_tree);
       this->ResetPosition(qexpand_, p_fmat, *p_tree);
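
From the Python side, the new check should surface roughly like this (a sketch; the error text comes from the ``CHECK_GT`` message above and may be wrapped differently):

.. code-block:: python

  import numpy as np
  import xgboost as xgb

  X = np.random.rand(64, 2)
  y = np.random.rand(64)
  dtrain = xgb.DMatrix(X, label=y)

  try:
      xgb.train({"tree_method": "exact", "max_depth": 0}, dtrain, num_boost_round=1)
  except xgb.core.XGBoostError as err:
      # "exact tree method doesn't support unlimited depth."
      print(err)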

tests/cpp/tree/test_tree_policy.cc (+91, -14)
@@ -20,15 +20,89 @@ class TestGrowPolicy : public ::testing::Test {
                            true);
   }
 
-  void TestTreeGrowPolicy(std::string tree_method, std::string policy) {
-    {
-      std::unique_ptr<Learner> learner{Learner::Create({this->Xy_})};
-      learner->SetParam("tree_method", tree_method);
-      learner->SetParam("max_leaves", "16");
-      learner->SetParam("grow_policy", policy);
-      learner->Configure();
+  std::unique_ptr<Learner> TrainOneIter(std::string tree_method, std::string policy,
+                                        int32_t max_leaves, int32_t max_depth) {
+    std::unique_ptr<Learner> learner{Learner::Create({this->Xy_})};
+    learner->SetParam("tree_method", tree_method);
+    if (max_leaves >= 0) {
+      learner->SetParam("max_leaves", std::to_string(max_leaves));
+    }
+    if (max_depth >= 0) {
+      learner->SetParam("max_depth", std::to_string(max_depth));
+    }
+    learner->SetParam("grow_policy", policy);
+
+    auto check_max_leave = [&]() {
+      Json model{Object{}};
+      learner->SaveModel(&model);
+      auto j_tree = model["learner"]["gradient_booster"]["model"]["trees"][0];
+      RegTree tree;
+      tree.LoadModel(j_tree);
+      CHECK_LE(tree.GetNumLeaves(), max_leaves);
+    };
+
+    auto check_max_depth = [&](int32_t sol) {
+      Json model{Object{}};
+      learner->SaveModel(&model);
+
+      auto j_tree = model["learner"]["gradient_booster"]["model"]["trees"][0];
+      RegTree tree;
+      tree.LoadModel(j_tree);
+      bst_node_t depth = 0;
+      tree.WalkTree([&](bst_node_t nidx) {
+        depth = std::max(tree.GetDepth(nidx), depth);
+        return true;
+      });
+      if (sol > -1) {
+        CHECK_EQ(depth, sol);
+      } else {
+        CHECK_EQ(depth, max_depth) << "tree method: " << tree_method << " policy: " << policy
+                                   << " leaves:" << max_leaves << ", depth:" << max_depth;
+      }
+    };
 
+    if (max_leaves == 0 && max_depth == 0) {
+      // unconstrained
+      if (tree_method != "gpu_hist") {
+        // GPU pre-allocates for all nodes.
+        learner->UpdateOneIter(0, Xy_);
+      }
+    } else if (max_leaves > 0 && max_depth == 0) {
+      learner->UpdateOneIter(0, Xy_);
+      check_max_leave();
+    } else if (max_leaves == 0 && max_depth > 0) {
+      learner->UpdateOneIter(0, Xy_);
+      check_max_depth(-1);
+    } else if (max_leaves > 0 && max_depth > 0) {
       learner->UpdateOneIter(0, Xy_);
+      check_max_leave();
+      check_max_depth(2);
+    } else if (max_leaves == -1 && max_depth == 0) {
+      // default max_leaves is 0, so both of them are now 0
+    } else {
+      // default parameters
+      learner->UpdateOneIter(0, Xy_);
+    }
+    return learner;
+  }
+
+  void TestCombination(std::string tree_method) {
+    for (auto policy : {"depthwise", "lossguide"}) {
+      // -1 means default
+      for (auto leaves : {-1, 0, 3}) {
+        for (auto depth : {-1, 0, 3}) {
+          this->TrainOneIter(tree_method, policy, leaves, depth);
+        }
+      }
+    }
+  }
+
+  void TestTreeGrowPolicy(std::string tree_method, std::string policy) {
+    {
+      /**
+       * max_leaves
+       */
+      auto learner = this->TrainOneIter(tree_method, policy, 16, -1);
       Json model{Object{}};
       learner->SaveModel(&model);
 
@@ -38,13 +112,10 @@ class TestGrowPolicy : public ::testing::Test {
       ASSERT_EQ(tree.GetNumLeaves(), 16);
     }
     {
-      std::unique_ptr<Learner> learner{Learner::Create({this->Xy_})};
-      learner->SetParam("tree_method", tree_method);
-      learner->SetParam("max_depth", "3");
-      learner->SetParam("grow_policy", policy);
-      learner->Configure();
-
-      learner->UpdateOneIter(0, Xy_);
+      /**
+       * max_depth
+       */
+      auto learner = this->TrainOneIter(tree_method, policy, -1, 3);
       Json model{Object{}};
       learner->SaveModel(&model);
 
@@ -64,17 +135,23 @@ class TestGrowPolicy : public ::testing::Test {
 TEST_F(TestGrowPolicy, Approx) {
   this->TestTreeGrowPolicy("approx", "depthwise");
   this->TestTreeGrowPolicy("approx", "lossguide");
+
+  this->TestCombination("approx");
 }
 
 TEST_F(TestGrowPolicy, Hist) {
   this->TestTreeGrowPolicy("hist", "depthwise");
   this->TestTreeGrowPolicy("hist", "lossguide");
+
+  this->TestCombination("hist");
 }
 
 #if defined(XGBOOST_USE_CUDA)
 TEST_F(TestGrowPolicy, GpuHist) {
   this->TestTreeGrowPolicy("gpu_hist", "depthwise");
   this->TestTreeGrowPolicy("gpu_hist", "lossguide");
+
+  this->TestCombination("gpu_hist");
 }
 #endif  // defined(XGBOOST_USE_CUDA)
 } // namespace xgboost
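
A rough Python analogue of ``TestCombination`` (a sketch on made-up data; it relies on the tab-indented text dump to measure depth):

.. code-block:: python

  import numpy as np
  import xgboost as xgb

  X = np.random.rand(512, 8)
  y = np.random.rand(512)
  dtrain = xgb.DMatrix(X, label=y)

  def tree_depth(booster):
      # Each level of the text dump is indented with one extra tab.
      lines = booster.get_dump()[0].splitlines()
      return max(len(l) - len(l.lstrip("\t")) for l in lines)

  for policy in ("depthwise", "lossguide"):
      params = {"tree_method": "hist", "grow_policy": policy,
                "max_depth": 3, "max_leaves": 8}
      booster = xgb.train(params, dtrain, num_boost_round=1)
      assert booster.get_dump()[0].count("leaf=") <= 8
      assert tree_depth(booster) <= 3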

tests/python-gpu/test_gpu_prediction.py (+1, -1)
@@ -22,7 +22,7 @@ def noop(*args, **kwargs):
 rng = np.random.RandomState(1994)
 
 shap_parameter_strategy = strategies.fixed_dictionaries({
-    'max_depth': strategies.integers(0, 11),
+    'max_depth': strategies.integers(1, 11),
     'max_leaves': strategies.integers(0, 256),
     'num_parallel_tree': strategies.sampled_from([1, 10]),
 }).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0)
