Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explainable boosting parameters #6335

Draft
wants to merge 27 commits into
base: master
Choose a base branch
from
Draft
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
f47f823
First version of the new parameter "tree_interaction_constraints""
Apr 7, 2022
5730198
readme update
Apr 7, 2022
5d69338
First version of the new parameter "tree_interaction_constraints""
Apr 7, 2022
ec9ed61
readme update
Apr 7, 2022
bfac4e1
Merge branch 'master' into microsoft-master
veneres Feb 14, 2024
f5b391e
Merge pull request #2 from veneres/microsoft-master
veneres Feb 14, 2024
d1966c2
Updated readme
veneres Feb 14, 2024
6438f0e
Merge remote-tracking branch 'upstream/master'
veneres Feb 14, 2024
848fd58
Fix missing parenthesis
veneres Feb 14, 2024
d32b7f6
Temporarly remove a new test
veneres Feb 14, 2024
d216823
Merge with private repository edits
veneres Feb 15, 2024
8dabbb2
Merge remote-tracking branch 'upstream/master'
veneres Feb 15, 2024
137bc6d
Resolved lint errors identified by github actions
veneres Feb 15, 2024
9b3fb5e
Fix docs
veneres Feb 15, 2024
997e06b
Fix docs
veneres Feb 15, 2024
64ff80c
Fix docs and linting
veneres Feb 15, 2024
ee8d6e6
Fix docs
veneres Feb 15, 2024
09acfcf
Fix docs
veneres Feb 15, 2024
0d66bea
Boolean guards added for constrained learning
veneres Feb 16, 2024
84287f1
test and small fix added
veneres Feb 16, 2024
61727ca
Merge branch 'microsoft:master' into master
veneres Feb 21, 2024
227ec1b
Param name refactor
veneres Feb 21, 2024
ca3dac5
Interaction constraints test added
veneres Feb 21, 2024
ab04352
Addressed: Unnecessary `list` comprehension (rewrite using `list()`)
veneres Feb 21, 2024
2fc53c0
Merge remote-tracking branch 'upstream/master'
veneres Feb 22, 2024
c0a4591
Skip constraint test on CUDA for the moment
veneres Feb 22, 2024
8165317
Reformat file for ruff check
veneres Feb 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
First version of the new parameter "tree_interaction_constraints""
  • Loading branch information
Alberto Veneri committed Apr 7, 2022
commit f47f82301601e234536dcb6592253dce7ad84be6
10 changes: 10 additions & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
@@ -543,6 +543,14 @@ struct Config {
// desc = any two features can only appear in the same branch only if there exists a constraint containing both features
std::string interaction_constraints = "";

// desc = controls which features can appear in the same tree
// desc = by default interaction constraints are disabled, to enable them you can specify
// descl2 = for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
// descl2 = for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
// descl2 = for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc
// desc = any two features can only appear in the same tree only if there exists a constraint containing both features
std::string tree_interaction_constraints = "";

// alias = verbose
// desc = controls the level of LightGBM's verbosity
// desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug
@@ -1065,6 +1073,7 @@ struct Config {
static const std::unordered_set<std::string>& parameter_set();
std::vector<std::vector<double>> auc_mu_weights_matrix;
std::vector<std::vector<int>> interaction_constraints_vector;
std::vector<std::vector<int>> tree_interaction_constraints_vector;
static const std::string DumpAliases();

private:
@@ -1073,6 +1082,7 @@ struct Config {
std::string SaveMembersToString() const;
void GetAucMuWeights();
void GetInteractionConstraints();
void GetTreeInteractionConstraints();
};

inline bool Config::GetString(
10 changes: 10 additions & 0 deletions include/LightGBM/tree.h
Original file line number Diff line number Diff line change
@@ -158,6 +158,10 @@ class Tree {
/*! \brief Get features on leaf's branch*/
inline std::vector<int> branch_features(int leaf) const { return branch_features_[leaf]; }

std::set<int> tree_features() const {
return tree_features_;
}

inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }

inline double internal_value(int node_idx) const {
@@ -520,6 +524,10 @@ class Tree {
bool track_branch_features_;
/*! \brief Features on leaf's branch, original index */
std::vector<std::vector<int>> branch_features_;

/*! \brief Features used by the tree, original index */
std::set<int> tree_features_;

double shrinkage_;
int max_depth_;
/*! \brief Tree has linear model at each leaf */
@@ -579,7 +587,9 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
branch_features_[num_leaves_] = branch_features_[leaf];
branch_features_[num_leaves_].push_back(split_feature_[new_node_idx]);
branch_features_[leaf].push_back(split_feature_[new_node_idx]);
tree_features_.insert(split_feature_[new_node_idx]);
}

}

inline double Tree::Predict(const double* feature_values) const {
12 changes: 11 additions & 1 deletion src/io/config.cpp
Original file line number Diff line number Diff line change
@@ -185,13 +185,21 @@ void Config::GetAucMuWeights() {
}

void Config::GetInteractionConstraints() {
if (interaction_constraints == "") {
if (interaction_constraints.empty()) {
interaction_constraints_vector = std::vector<std::vector<int>>();
} else {
interaction_constraints_vector = Common::StringToArrayofArrays<int>(interaction_constraints, '[', ']', ',');
}
}

void Config::GetTreeInteractionConstraints() {
if (tree_interaction_constraints.empty()) {
tree_interaction_constraints_vector = std::vector<std::vector<int>>();
} else {
tree_interaction_constraints_vector = Common::StringToArrayofArrays<int>(tree_interaction_constraints, '[', ']', ',');
}
}

void Config::Set(const std::unordered_map<std::string, std::string>& params) {
// generate seeds by seed.
if (GetInt(params, "seed", &seed)) {
@@ -221,6 +229,8 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {

GetInteractionConstraints();

GetTreeInteractionConstraints();

// sort eval_at
std::sort(eval_at.begin(), eval_at.end());

5 changes: 5 additions & 0 deletions src/io/config_auto.cpp
Original file line number Diff line number Diff line change
@@ -245,6 +245,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"cegb_penalty_feature_coupled",
"path_smooth",
"interaction_constraints",
"tree_interaction_constraints",
"verbosity",
"input_model",
"output_model",
@@ -482,6 +483,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str

GetString(params, "interaction_constraints", &interaction_constraints);

GetString(params, "tree_interaction_constraints", &tree_interaction_constraints);

GetInt(params, "verbosity", &verbosity);

GetString(params, "input_model", &input_model);
@@ -703,6 +706,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n";
str_buf << "[path_smooth: " << path_smooth << "]\n";
str_buf << "[interaction_constraints: " << interaction_constraints << "]\n";
str_buf << "[tree_interaction_constraints: " << tree_interaction_constraints << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[saved_feature_importance_type: " << saved_feature_importance_type << "]\n";
str_buf << "[linear_tree: " << linear_tree << "]\n";
@@ -822,6 +826,7 @@ const std::string Config::DumpAliases() {
str_buf << "\"cegb_penalty_feature_coupled\": [], ";
str_buf << "\"path_smooth\": [], ";
str_buf << "\"interaction_constraints\": [], ";
str_buf << "\"tree_interaction_constraints\": [], ";
str_buf << "\"verbosity\": [\"verbose\"], ";
str_buf << "\"input_model\": [\"model_input\", \"model_in\"], ";
str_buf << "\"output_model\": [\"model_output\", \"model_out\"], ";
59 changes: 51 additions & 8 deletions src/treelearner/col_sampler.hpp
Original file line number Diff line number Diff line change
@@ -28,6 +28,10 @@ class ColSampler {
std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
interaction_constraints_.push_back(constraint_set);
}
for (auto constraint : config->tree_interaction_constraints_vector) {
std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
tree_interaction_constraints_.push_back(constraint_set);
}
}

static int GetCnt(size_t total_cnt, double fraction) {
@@ -89,30 +93,67 @@ class ColSampler {
}

std::vector<int8_t> GetByNode(const Tree* tree, int leaf) {
std::unordered_set<int> tree_allowed_features;
if (!tree_interaction_constraints_.empty()) {
std::set<int> tree_features = tree->tree_features();
tree_allowed_features.insert(tree_features.begin(), tree_features.end());
for (auto constraint : tree_interaction_constraints_) {
int num_feat_found = 0;

if (tree_features.empty()) {
tree_allowed_features.insert(constraint.begin(), constraint.end());
}

for (int feat : tree_features) {
if (constraint.count(feat) == 0) { break; }
++num_feat_found;
if (num_feat_found == static_cast<int>(tree_features.size())) {
tree_allowed_features.insert(constraint.begin(), constraint.end());
break;
}
}
}
}

// get interaction constraints for current branch
std::unordered_set<int> allowed_features;
std::unordered_set<int> branch_allowed_features;
if (!interaction_constraints_.empty()) {
std::vector<int> branch_features = tree->branch_features(leaf);
allowed_features.insert(branch_features.begin(), branch_features.end());
for (auto constraint : interaction_constraints_) {
int num_feat_found = 0;
if (branch_features.size() == 0) {
allowed_features.insert(constraint.begin(), constraint.end());
if (branch_features.empty()) {
branch_allowed_features.insert(constraint.begin(), constraint.end());
}
for (int feat : branch_features) {
if (constraint.count(feat) == 0) { break; }
++num_feat_found;
if (num_feat_found == static_cast<int>(branch_features.size())) {
allowed_features.insert(constraint.begin(), constraint.end());
branch_allowed_features.insert(constraint.begin(), constraint.end());
break;
}
}
}
}

// intersect allowed features for branch and tree
std::unordered_set<int> allowed_features;

if(tree_interaction_constraints_.empty() && !interaction_constraints_.empty()) {
allowed_features.insert(branch_allowed_features.begin(), branch_allowed_features.end());
} else if(!tree_interaction_constraints_.empty() && interaction_constraints_.empty()){
allowed_features.insert(tree_allowed_features.begin(), tree_allowed_features.end());
} else {
for (int element : tree_allowed_features) {
if (branch_allowed_features.count(element) > 0) {
allowed_features.insert(element);
}
}
}


std::vector<int8_t> ret(train_data_->num_features(), 0);
if (fraction_bynode_ >= 1.0f) {
if (interaction_constraints_.empty()) {
if (interaction_constraints_.empty() && tree_interaction_constraints_.empty()) {
return std::vector<int8_t>(train_data_->num_features(), 1);
} else {
for (int feat : allowed_features) {
@@ -128,7 +169,7 @@ class ColSampler {
auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_);
std::vector<int>* allowed_used_feature_indices;
std::vector<int> filtered_feature_indices;
if (interaction_constraints_.empty()) {
if (interaction_constraints_.empty() && tree_interaction_constraints_.empty()) {
allowed_used_feature_indices = &used_feature_indices_;
} else {
for (int feat_ind : used_feature_indices_) {
@@ -154,7 +195,7 @@ class ColSampler {
GetCnt(valid_feature_indices_.size(), fraction_bynode_);
std::vector<int>* allowed_valid_feature_indices;
std::vector<int> filtered_feature_indices;
if (interaction_constraints_.empty()) {
if (interaction_constraints_.empty() && tree_interaction_constraints_.empty()) {
allowed_valid_feature_indices = &valid_feature_indices_;
} else {
for (int feat : valid_feature_indices_) {
@@ -199,6 +240,8 @@ class ColSampler {
std::vector<int> valid_feature_indices_;
/*! \brief interaction constraints index in original (raw data) features */
std::vector<std::unordered_set<int>> interaction_constraints_;
/*! \brief tree nteraction constraints index in original (raw data) features */
std::vector<std::unordered_set<int>> tree_interaction_constraints_;
};

} // namespace LightGBM
18 changes: 16 additions & 2 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
@@ -172,7 +172,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// some initial works before training
BeforeTrain();

bool track_branch_features = !(config_->interaction_constraints_vector.empty());
bool track_branch_features = !(config_->interaction_constraints_vector.empty()
&& config_->tree_interaction_constraints_vector.empty());
auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, track_branch_features, false));
auto tree_ptr = tree.get();
constraints_->ShareTreePointer(tree_ptr);
@@ -282,6 +283,19 @@ void SerialTreeLearner::BeforeTrain() {

bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
Common::FunctionTimer fun_timer("SerialTreeLearner::BeforeFindBestSplit", global_timer);

#pragma omp parallel for schedule(static)
for (int i = 0; i < config_->num_leaves; ++i) {
int feat_index = best_split_per_leaf_[i].feature;
if(feat_index == -1) continue;

int inner_feat_index = train_data_->InnerFeatureIndex(feat_index);
auto allowed_feature = col_sampler_.GetByNode(tree, i);
if(!allowed_feature[inner_feat_index]){
RecomputeBestSplitForLeaf(tree, i, &best_split_per_leaf_[i]);
}
}

// check depth of current leaf
if (config_->max_depth > 0) {
// only need to check left leaf, since right leaf is in same level of left leaf
@@ -801,7 +815,7 @@ double SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* le
return parent_output;
}

void SerialTreeLearner::RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split) {
void SerialTreeLearner::RecomputeBestSplitForLeaf(const Tree* tree, int leaf, SplitInfo* split) {
FeatureHistogram* histogram_array_;
if (!histogram_pool_.Get(leaf, &histogram_array_)) {
Log::Warning(
2 changes: 1 addition & 1 deletion src/treelearner/serial_tree_learner.h
Original file line number Diff line number Diff line change
@@ -129,7 +129,7 @@ class SerialTreeLearner: public TreeLearner {

void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time);

void RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split);
void RecomputeBestSplitForLeaf(const Tree* tree, int leaf, SplitInfo* split);

/*!
* \brief Some initial works before training
52 changes: 51 additions & 1 deletion tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
@@ -2600,7 +2600,7 @@ def metrics_combination_cv_regression(metric_list, assumed_iteration,
feval=lambda preds, train_data: [constant_metric(preds, train_data),
decreasing_metric(preds, train_data)])


#TODO investigate why this test fails
def test_node_level_subcol():
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
@@ -2918,6 +2918,56 @@ def test_interaction_constraints():
[1] + list(range(2, num_features))]),
train_data, num_boost_round=10)

@pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version')
def test_tree_interaction_constraints():
def check_consistency(est, tree_interaction_constraints):
feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
tree_df = est.trees_to_dataframe()
inter_found = set()
for tree_index in tree_df["tree_index"].unique():
tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
feat_used = [feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None]
inter_found.add(tuple(sorted(feat_used)))
print(inter_found)
for feats_found in inter_found:
found = False
for real_contraints in tree_interaction_constraints:
if set(feats_found) <= set(real_contraints):
found = True
break
assert found is True
X, y = load_boston(return_X_y=True)
num_features = X.shape[1]
train_data = lgb.Dataset(X, label=y)
# check that tree constraint containing all features is equivalent to no constraint
params = {'verbose': -1,
'seed': 0}
est = lgb.train(params, train_data, num_boost_round=10)
pred1 = est.predict(X)
est = lgb.train(dict(params, tree_interaction_constraints=[list(range(num_features))]), train_data,
num_boost_round=10)
pred2 = est.predict(X)
np.testing.assert_allclose(pred1, pred2)
# check that each tree is composed exactly of 2 features contained in the contrained set
tree_interaction_constraints = [[i, i + 1] for i in range(0, num_features - 1, 2)]
print(tree_interaction_constraints)
new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
est = lgb.train(new_params, train_data, num_boost_round=100)
check_consistency(est, tree_interaction_constraints)
# check if tree features interaction constraints works with multiple set of features
tree_interaction_constraints = [[i for i in range(i, i + 5)] for i in range(0, num_features - 5, 5)]
print(tree_interaction_constraints)
new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
est = lgb.train(new_params, train_data, num_boost_round=100)
check_consistency(est, tree_interaction_constraints)
# check if tree features interaction constraints works with multiple set of features
tree_interaction_constraints = [[i] for i in range(num_features)]
print(tree_interaction_constraints)
new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
est = lgb.train(new_params, train_data, num_boost_round=100)
check_consistency(est, tree_interaction_constraints)



def test_linear_trees(tmp_path):
# check that setting linear_tree=True fits better than ordinary trees when data has linear relationship