Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] Introduce refit_tree_manual to Booster class. #6617

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Comments after code review
  • Loading branch information
Atanas Dimitrov committed Sep 2, 2024
commit 1d73e4e6a6ba92098c35de90e9a95d01b6482476
6 changes: 4 additions & 2 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
@@ -290,8 +290,10 @@ void GBDT::RefitTree(const int* tree_leaf_prediction, const size_t nrow, const s
}

void GBDT::RefitTreeManual(int tree_idx, const double *vals, const int vals_size) {
CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size() && vals_size == models_[tree_idx]->num_leaves());
// reset score
CHECK(tree_idx >= 0);
CHECK(static_cast<size_t>(tree_idx) < models_.size());
CHECK(vals_size == models_[tree_idx]->num_leaves());
// reset score by adding the difference
for (int leaf_id = 0; leaf_id < models_[tree_idx]->num_leaves(); ++leaf_id) {
models_[tree_idx]->SetLeafOutput(leaf_id, vals[leaf_id] - models_[tree_idx]->LeafOutput(leaf_id));
}
6 changes: 5 additions & 1 deletion tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
@@ -2366,9 +2366,13 @@ def debias_callback(env):
"num_leaves": 5,
"objective": "gamma",
}
bst = lgb.train(params, ds, callbacks=[debias_callback])

# Check that the model is biased when no callback is provided
bst = lgb.train(params, ds)
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, bst.predict(df).mean(), y.mean())

# Check if debiasing worked
bst = lgb.train(params, ds, callbacks=[debias_callback])
np.testing.assert_allclose(bst.predict(df).mean(), y.mean())


Loading
Oops, something went wrong.