
Commit 9254c58

BlueTea88 authored and hcho3 committed
[TREE] add interaction constraints (dmlc#3466)
* add interaction constraints
* enable both interaction and monotonic constraints at the same time
* fix lint
* add R test, fix lint, update demo
* Use dmlc::JSONReader to express interaction constraints as nested lists; Use sparse arrays for bookkeeping
* Add Python test for interaction constraints
* make R interaction constraints parameter based on feature index instead of column names, fix R coding style
* Fix lint
* Add BlueTea88 to CONTRIBUTORS.md
* Short circuit when no constraint is specified; address review comments
* Add tutorial for feature interaction constraints
* allow interaction constraints to be passed as string, remove redundant column_names argument
* Fix typo
* Address review comments
* Add comments to Python test
1 parent dee0b69 commit 9254c58

12 files changed, +581 -3 lines

CONTRIBUTORS.md

Lines changed: 2 additions & 0 deletions
@@ -78,3 +78,5 @@ List of Contributors
 * [Pierre de Sahb](https://github.com/pdesahb)
 * [liuliang01](https://github.com/liuliang01)
   - liuliang01 added support for the qid column for LibSVM input format. This makes ranking task easier in distributed setting.
+* [Andrew Thia](https://github.com/BlueTea88)
+  - Andrew Thia implemented feature interaction constraints

R-package/R/utils.R

Lines changed: 13 additions & 0 deletions
@@ -74,6 +74,19 @@ check.booster.params <- function(params, ...) {
     params[['monotone_constraints']] = vec2str
   }

+  # interaction constraints parser (convert from list of column indices to string)
+  if (!is.null(params[['interaction_constraints']]) &&
+      typeof(params[['interaction_constraints']]) != "character") {
+    # check input class
+    if (class(params[['interaction_constraints']]) != 'list') stop('interaction_constraints should be class list')
+    if (!all(unique(sapply(params[['interaction_constraints']], class)) %in% c('numeric', 'integer'))) {
+      stop('interaction_constraints should be a list of numeric/integer vectors')
+    }
+
+    # recast parameter as string
+    interaction_constraints <- sapply(params[['interaction_constraints']], function(x) paste0('[', paste(x, collapse = ','), ']'))
+    params[['interaction_constraints']] <- paste0('[', paste(interaction_constraints, collapse = ','), ']')
+  }
   return(params)
 }
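For illustration, a minimal sketch of what this parser produces (hypothetical input, not part of the commit):

  # Hypothetical params list using 0-based feature index vectors
  params <- list(interaction_constraints = list(c(0, 1), c(2, 3, 4)))
  # After the parser above runs (inside xgb.train), the parameter becomes
  # the single string: "[[0,1],[2,3,4]]"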

R-package/R/xgb.train.R

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
 #' \item \code{colsample_bytree} subsample ratio of columns when constructing each tree. Default: 1
 #' \item \code{num_parallel_tree} Experimental parameter. number of trees to grow per round. Useful to test Random Forest through Xgboost (set \code{colsample_bytree < 1}, \code{subsample < 1} and \code{round = 1}) accordingly. Default: 1
 #' \item \code{monotone_constraints} A numerical vector consists of \code{1}, \code{0} and \code{-1} with its length equals to the number of features in the training data. \code{1} is increasing, \code{-1} is decreasing and \code{0} is no constraint.
+#' \item \code{interaction_constraints} A list of vectors specifying feature indices of permitted interactions. Each item of the list represents one permitted interaction where specified features are allowed to interact with each other. Feature index values should start from \code{0} (\code{0} references the first column). Leave argument unspecified for no interaction constraints.
 #' }
 #'
 #' 2.2. Parameter for Linear Booster
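A minimal usage sketch of the new parameter (hypothetical data; the list form is converted to a string by the parser in utils.R above):

  library(xgboost)
  x <- matrix(rnorm(500), ncol = 5)
  y <- rnorm(100)
  # Permit interactions only between the first two features (0-based indices)
  bst <- xgboost(data = x, label = y, nrounds = 10, max_depth = 3,
                 interaction_constraints = list(c(0, 1)))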
R-package/demo/interaction_constraints.R

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
library(xgboost)
library(data.table)

set.seed(1024)

# Function to obtain a list of interactions fitted in trees, requires input of maximum depth
treeInteractions <- function(input_tree, input_max_depth){
  trees <- copy(input_tree)  # copy tree input to prevent overwriting
  if (input_max_depth < 2) return(list())  # no interactions if max depth < 2
  if (nrow(input_tree) == 1) return(list())

  # Attach parent nodes
  for (i in 2:input_max_depth){
    if (i == 2) trees[, ID_merge:=ID] else trees[, ID_merge:=get(paste0('parent_',i-2))]
    parents_left <- trees[!is.na(Split), list(i.id=ID, i.feature=Feature, ID_merge=Yes)]
    parents_right <- trees[!is.na(Split), list(i.id=ID, i.feature=Feature, ID_merge=No)]

    setorderv(trees, 'ID_merge')
    setorderv(parents_left, 'ID_merge')
    setorderv(parents_right, 'ID_merge')

    trees <- merge(trees, parents_left, by='ID_merge', all.x=TRUE)
    trees[!is.na(i.id), c(paste0('parent_', i-1), paste0('parent_feat_', i-1)):=list(i.id, i.feature)]
    trees[, c('i.id','i.feature'):=NULL]

    trees <- merge(trees, parents_right, by='ID_merge', all.x=TRUE)
    trees[!is.na(i.id), c(paste0('parent_', i-1), paste0('parent_feat_', i-1)):=list(i.id, i.feature)]
    trees[, c('i.id','i.feature'):=NULL]
  }

  # Extract nodes with interactions
  interaction_trees <- trees[!is.na(Split) & !is.na(parent_1),
                             c('Feature', paste0('parent_feat_', 1:(input_max_depth-1))), with=FALSE]
  interaction_trees_split <- split(interaction_trees, 1:nrow(interaction_trees))
  interaction_list <- lapply(interaction_trees_split, as.character)

  # Remove NAs (no parent interaction)
  interaction_list <- lapply(interaction_list, function(x) x[!is.na(x)])

  # Remove non-interactions (same variable)
  interaction_list <- lapply(interaction_list, unique)  # remove same variables
  interaction_length <- sapply(interaction_list, length)
  interaction_list <- interaction_list[interaction_length > 1]
  interaction_list <- unique(lapply(interaction_list, sort))
  return(interaction_list)
}

# Generate sample data
x <- list()
for (i in 1:10){
  x[[i]] <- i * rnorm(1000, 10)
}
x <- as.data.table(x)

y <- -1 * x[, rowSums(.SD)] + x[['V1']] * x[['V2']] + x[['V3']] * x[['V4']] * x[['V5']] +
  rnorm(1000, 0.001) + 3 * sin(x[['V7']])

train <- as.matrix(x)

# Interaction constraint list (column names form)
interaction_list <- list(c('V1', 'V2'), c('V3', 'V4', 'V5'))

# Convert interaction constraint list into feature index form
cols2ids <- function(object, col_names) {
  LUT <- seq_along(col_names) - 1
  names(LUT) <- col_names
  rapply(object, function(x) LUT[x], classes = "character", how = "replace")
}
interaction_list_fid <- cols2ids(interaction_list, colnames(train))

# Fit model with interaction constraints
bst <- xgboost(data = train, label = y, max_depth = 4,
               eta = 0.1, nthread = 2, nrounds = 1000,
               interaction_constraints = interaction_list_fid)

bst_tree <- xgb.model.dt.tree(colnames(train), bst)
bst_interactions <- treeInteractions(bst_tree, 4)  # interactions constrained to combinations of V1*V2 and V3*V4*V5

# Fit model without interaction constraints
bst2 <- xgboost(data = train, label = y, max_depth = 4,
                eta = 0.1, nthread = 2, nrounds = 1000)

bst2_tree <- xgb.model.dt.tree(colnames(train), bst2)
bst2_interactions <- treeInteractions(bst2_tree, 4)  # many more interactions

# Fit model with both interaction and monotonicity constraints
bst3 <- xgboost(data = train, label = y, max_depth = 4,
                eta = 0.1, nthread = 2, nrounds = 1000,
                interaction_constraints = interaction_list_fid,
                monotone_constraints = c(-1, 0, 0, 0, 0, 0, 0, 0, 0, 0))

bst3_tree <- xgb.model.dt.tree(colnames(train), bst3)
bst3_interactions <- treeInteractions(bst3_tree, 4)  # interactions still constrained to combinations of V1*V2 and V3*V4*V5

# Show monotonic constraints still apply by checking scores after incrementing V1
x1 <- sort(unique(x[['V1']]))
for (i in 1:length(x1)){
  testdata <- copy(x[, -c('V1')])
  testdata[['V1']] <- x1[i]
  testdata <- testdata[, paste0('V', 1:10), with = FALSE]
  pred <- predict(bst3, as.matrix(testdata))

  # Should not print out anything due to monotonic constraints
  if (i > 1) if (any(pred > prev_pred)) print(i)
  prev_pred <- pred
}
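After running the demo, a quick sanity check on the fitted interactions (a sketch using the objects defined above):

  # Expected: only subsets of {V1, V2} and {V3, V4, V5}
  print(bst_interactions)
  # The unconstrained model typically fits many more combinations
  print(length(bst2_interactions))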
R-package/tests/testthat/test_interaction_constraints.R

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
require(xgboost)

context("interaction constraints")

set.seed(1024)
x1 <- rnorm(1000, 1)
x2 <- rnorm(1000, 1)
x3 <- sample(c(1, 2, 3), size = 1000, replace = TRUE)
y <- x1 + x2 + x3 + x1 * x2 * x3 + rnorm(1000, 0.001) + 3 * sin(x1)
train <- matrix(c(x1, x2, x3), ncol = 3)

test_that("interaction constraints for regression", {
  # Fit a model that only allows interaction between x1 and x2
  bst <- xgboost(data = train, label = y, max_depth = 3,
                 eta = 0.1, nthread = 2, nrounds = 100, verbose = 0,
                 interaction_constraints = list(c(0, 1)))

  # Set all observations to have the same x3 values then increment
  # by the same amount
  preds <- lapply(c(1, 2, 3), function(x){
    tmat <- matrix(c(x1, x2, rep(x, 1000)), ncol = 3)
    return(predict(bst, tmat))
  })

  # Check incrementing x3 has the same effect on all observations
  # since x3 is constrained to be independent of x1 and x2
  # and all observations start off from the same x3 value
  diff1 <- preds[[2]] - preds[[1]]
  test1 <- all(abs(diff1 - diff1[1]) < 1e-4)

  diff2 <- preds[[3]] - preds[[2]]
  test2 <- all(abs(diff2 - diff2[1]) < 1e-4)

  expect_true({
    test1 & test2
  }, "Interaction Constraint Satisfied")
})
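To run just this test file locally, something like the following should work (a sketch; assumes the xgboost R package and testthat are installed, and that the path below matches the file's location):

  library(testthat)
  library(xgboost)
  test_file('R-package/tests/testthat/test_interaction_constraints.R')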

doc/conf.py

Lines changed: 8 additions & 1 deletion
@@ -41,7 +41,7 @@

 # -- mock out modules
 import mock
-MOCK_MODULES = ['numpy', 'scipy', 'scipy.sparse', 'sklearn', 'matplotlib', 'pandas', 'graphviz']
+MOCK_MODULES = ['scipy', 'scipy.sparse', 'sklearn', 'pandas']
 for mod_name in MOCK_MODULES:
     sys.modules[mod_name] = mock.Mock()

@@ -62,13 +62,20 @@
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones
 extensions = [
+    'matplotlib.sphinxext.only_directives',
+    'matplotlib.sphinxext.plot_directive',
     'sphinx.ext.autodoc',
     'sphinx.ext.napoleon',
     'sphinx.ext.mathjax',
     'sphinx.ext.intersphinx',
     'breathe'
 ]

+graphviz_output_format = 'png'
+plot_formats = [('svg', 300), ('png', 100), ('hires.png', 300)]
+plot_html_show_source_link = False
+plot_html_show_formats = False
+
 # Breathe extension variables
 breathe_projects = {"xgboost": "doxyxml/"}
 breathe_default_project = "xgboost"
doc/tutorials/feature_interaction_constraint.rst

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
###############################
Feature Interaction Constraints
###############################

The decision tree is a powerful tool to discover interactions among independent
variables (features). Variables that appear together in a traversal path
are interacting with one another, since the condition of a child node is
predicated on the condition of the parent node. For example, the highlighted
red path in the diagram below contains three variables: :math:`x_1`, :math:`x_7`,
and :math:`x_{10}`, so the highlighted prediction (at the highlighted leaf node)
is the product of interaction between :math:`x_1`, :math:`x_7`, and
:math:`x_{10}`.

.. plot::
  :nofigs:

  from graphviz import Source
  source = r"""
    digraph feature_interaction_illustration1 {
      graph [fontname = "helvetica"];
      node [fontname = "helvetica"];
      edge [fontname = "helvetica"];
      0 [label=<x<SUB><FONT POINT-SIZE="11">10</FONT></SUB> &lt; -1.5 ?>, shape=box, color=red, fontcolor=red];
      1 [label=<x<SUB><FONT POINT-SIZE="11">2</FONT></SUB> &lt; 2 ?>, shape=box];
      2 [label=<x<SUB><FONT POINT-SIZE="11">7</FONT></SUB> &lt; 0.3 ?>, shape=box, color=red, fontcolor=red];
      3 [label="...", shape=none];
      4 [label="...", shape=none];
      5 [label=<x<SUB><FONT POINT-SIZE="11">1</FONT></SUB> &lt; 0.5 ?>, shape=box, color=red, fontcolor=red];
      6 [label="...", shape=none];
      7 [label="...", shape=none];
      8 [label="Predict +1.3", color=red, fontcolor=red];
      0 -> 1 [labeldistance=2.0, labelangle=45, headlabel="Yes/Missing "];
      0 -> 2 [labeldistance=2.0, labelangle=-45,
              headlabel="No", color=red, fontcolor=red];
      1 -> 3 [labeldistance=2.0, labelangle=45, headlabel="Yes"];
      1 -> 4 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"];
      2 -> 5 [labeldistance=2.0, labelangle=-45, headlabel="Yes",
              color=red, fontcolor=red];
      2 -> 6 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"];
      5 -> 7;
      5 -> 8 [color=red];
    }
  """
  Source(source, format='png').render('../_static/feature_interaction_illustration1', view=False)
  Source(source, format='svg').render('../_static/feature_interaction_illustration1', view=False)

.. raw:: html

  <p>
  <img src="../_static/feature_interaction_illustration1.svg"
       onerror="this.src='../_static/feature_interaction_illustration1.png'; this.onerror=null;">
  </p>

When the tree depth is larger than one, many variables interact on
the sole basis of minimizing training loss, and the resulting decision tree may
capture a spurious relationship (noise) rather than a legitimate relationship
that generalizes across different datasets. **Feature interaction constraints**
allow users to decide which variables are allowed to interact and which are not.

Potential benefits include:

* Better predictive performance from focusing on interactions that work --
  whether through domain specific knowledge or algorithms that rank interactions
* Less noise in predictions; better generalization
* More control for the user over what the model can fit. For example, the user
  may want to exclude some interactions even if they perform well, due to
  regulatory constraints

****************
A Simple Example
****************

Feature interaction constraints are expressed in terms of groups of variables
that are allowed to interact. For example, the constraint
``[0, 1]`` indicates that variables :math:`x_0` and :math:`x_1` are allowed to
interact with each other but with no other variable. Similarly, ``[2, 3, 4]``
indicates that :math:`x_2`, :math:`x_3`, and :math:`x_4` are allowed to
interact with one another but with no other variable. A set of feature
interaction constraints is expressed as a nested list, e.g.
``[[0, 1], [2, 3, 4]]``, where each inner list is a group of indices of features
that are allowed to interact with each other.

In the following diagram, the left decision tree is in violation of the first
constraint (``[0, 1]``), whereas the right decision tree complies with both the
first and second constraints (``[0, 1]``, ``[2, 3, 4]``).

.. plot::
  :nofigs:

  from graphviz import Source
  source = r"""
    digraph feature_interaction_illustration2 {
      graph [fontname = "helvetica"];
      node [fontname = "helvetica"];
      edge [fontname = "helvetica"];
      0 [label=<x<SUB><FONT POINT-SIZE="11">0</FONT></SUB> &lt; 5.0 ?>, shape=box];
      1 [label=<x<SUB><FONT POINT-SIZE="11">2</FONT></SUB> &lt; -3.0 ?>, shape=box];
      2 [label="+0.6"];
      3 [label="-0.4"];
      4 [label="+1.2"];
      0 -> 1 [labeldistance=2.0, labelangle=45, headlabel="Yes/Missing "];
      0 -> 2 [labeldistance=2.0, labelangle=-45, headlabel="No"];
      1 -> 3 [labeldistance=2.0, labelangle=45, headlabel="Yes"];
      1 -> 4 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"];
    }
  """
  Source(source, format='png').render('../_static/feature_interaction_illustration2', view=False)
  Source(source, format='svg').render('../_static/feature_interaction_illustration2', view=False)

.. plot::
  :nofigs:

  from graphviz import Source
  source = r"""
    digraph feature_interaction_illustration3 {
      graph [fontname = "helvetica"];
      node [fontname = "helvetica"];
      edge [fontname = "helvetica"];
      0 [label=<x<SUB><FONT POINT-SIZE="11">3</FONT></SUB> &lt; 2.5 ?>, shape=box];
      1 [label="+1.6"];
      2 [label=<x<SUB><FONT POINT-SIZE="11">2</FONT></SUB> &lt; -1.2 ?>, shape=box];
      3 [label="+0.1"];
      4 [label="-0.3"];
      0 -> 1 [labeldistance=2.0, labelangle=45, headlabel="Yes"];
      0 -> 2 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"];
      2 -> 3 [labeldistance=2.0, labelangle=45, headlabel="Yes/Missing "];
      2 -> 4 [labeldistance=2.0, labelangle=-45, headlabel="No"];
    }
  """
  Source(source, format='png').render('../_static/feature_interaction_illustration3', view=False)
  Source(source, format='svg').render('../_static/feature_interaction_illustration3', view=False)

.. raw:: html

  <p>
  <img src="../_static/feature_interaction_illustration2.svg"
       onerror="this.src='../_static/feature_interaction_illustration2.png'; this.onerror=null;">
  <img src="../_static/feature_interaction_illustration3.svg"
       onerror="this.src='../_static/feature_interaction_illustration3.png'; this.onerror=null;">
  </p>

****************************************************
Enforcing Feature Interaction Constraints in XGBoost
****************************************************

It is very simple to enforce feature interaction constraints in XGBoost. Here
we will give an example using Python, but the same general idea applies to
other platforms.

Suppose the following code fits your model without feature interaction
constraints:

.. code-block:: python

  model_no_constraints = xgb.train(params, dtrain,
                                   num_boost_round = 1000, evals = evallist,
                                   early_stopping_rounds = 10)

Then fitting with feature interaction constraints only requires adding a
single parameter:

.. code-block:: python

  params_constrained = params.copy()
  # Use nested list to define feature interaction constraints
  params_constrained['interaction_constraints'] = '[[0, 2], [1, 3, 4], [5, 6]]'
  # Features 0 and 2 are allowed to interact with each other but with no other feature
  # Features 1, 3, 4 are allowed to interact with one another but with no other feature
  # Features 5 and 6 are allowed to interact with each other but with no other feature

  model_with_constraints = xgb.train(params_constrained, dtrain,
                                     num_boost_round = 1000, evals = evallist,
                                     early_stopping_rounds = 10)

**Choice of tree construction algorithm.** To use feature interaction
constraints, be sure to set the ``tree_method`` parameter to either ``exact``
or ``hist``. Currently, GPU algorithms (``gpu_hist``, ``gpu_exact``) do not
support feature interaction constraints.
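For R users, the same constraint can be passed as a list of 0-based feature index vectors (a sketch mirroring the Python example above; see the R demo added in this commit):

  # R sketch of the constrained model above (hypothetical dtrain)
  params_constrained <- list(interaction_constraints = list(c(0, 2), c(1, 3, 4), c(5, 6)))
  bst <- xgb.train(params_constrained, dtrain, nrounds = 1000)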

doc/tutorials/index.rst

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
   Distributed XGBoost with XGBoost4J-Spark <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html>
   dart
   monotonic
+  feature_interaction_constraint
   input_format
   param_tuning
   external_memory
