Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] Introduce refit_tree_manual to Booster class. #6617

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Add validation checks
  • Loading branch information
Atanas Dimitrov committed Aug 9, 2024
commit 75a4ec5b9878206dc2c671a8466aadd3d6f5bac0
2 changes: 1 addition & 1 deletion include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
@@ -79,7 +79,7 @@ class LIGHTGBM_EXPORT Boosting {
/*!
* \brief Change the leaf values of a tree and update the scores
*/
virtual void RefitTreeManual(int tree_idx, const double *vals) = 0;
virtual void RefitTreeManual(int tree_idx, const double *vals, const int vals_size) = 0;

/*!
* \brief Training logic
3 changes: 2 additions & 1 deletion include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
@@ -782,7 +782,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterRefit(BoosterHandle handle,
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterRefitTreeManual(BoosterHandle handle,
int32_t tree_idx,
const double *values);
const double *vals,
const int vals_size);

/*!
* \brief Update the model by specifying gradient and Hessian directly
6 changes: 2 additions & 4 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
@@ -4936,14 +4936,12 @@ def refit_tree_manual(
"""
values = _list_to_1d_numpy(values, dtype=np.float64, name="leaf_values")

if len(values) != self.num_leaves(tree_id):
raise ValueError("Length of values should be equal to the number of leaves in the tree")

_safe_call(
_LIB.LGBM_BoosterRefitTreeManual(
self._handle,
ctypes.c_int(tree_id),
values.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
values.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
ctypes.c_int(len(values)),
)
)
self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
4 changes: 2 additions & 2 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
@@ -289,8 +289,8 @@ void GBDT::RefitTree(const int* tree_leaf_prediction, const size_t nrow, const s
}
}

void GBDT::RefitTreeManual(int tree_idx, const double *vals) {
CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size());
void GBDT::RefitTreeManual(int tree_idx, const double *vals, const int vals_size) {
CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size() && vals_size == models_[tree_idx]->num_leaves());
// reset score
for (int leaf_id = 0; leaf_id < models_[tree_idx]->num_leaves(); ++leaf_id) {
models_[tree_idx]->SetLeafOutput(leaf_id, vals[leaf_id] - models_[tree_idx]->LeafOutput(leaf_id));
2 changes: 1 addition & 1 deletion src/boosting/gbdt.h
Original file line number Diff line number Diff line change
@@ -145,7 +145,7 @@ class GBDT : public GBDTBase {

void RefitTree(const int* tree_leaf_prediction, const size_t nrow, const size_t ncol) override;

void RefitTreeManual(int tree_idx, const double *vals) override;
void RefitTreeManual(int tree_idx, const double *vals, const int vals_size) override;

/*!
* \brief Training logic
9 changes: 5 additions & 4 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
@@ -412,9 +412,9 @@ class Booster {
boosting_->RefitTree(leaf_preds, nrow, ncol);
}

void RefitTreeManual(int tree_idx, const double *vals) {
void RefitTreeManual(int tree_idx, const double *vals, const int vals_size) {
UNIQUE_LOCK(mutex_)
boosting_->RefitTreeManual(tree_idx, vals);
boosting_->RefitTreeManual(tree_idx, vals, vals_size);
}

bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
@@ -2065,10 +2065,11 @@ int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t n

int LGBM_BoosterRefitTreeManual(BoosterHandle handle,
int tree_idx,
const double *val) {
const double *vals,
const int vals_size) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->RefitTreeManual(tree_idx, val);
ref_booster->RefitTreeManual(tree_idx, vals, vals_size);
API_END();
}