Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] Introduce refit_tree_manual to Booster class. #6617

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Init
  • Loading branch information
Atanas Dimitrov committed Aug 8, 2024
commit 954e6a949b826a65b57ddfeabfe7357f0350f61d
5 changes: 5 additions & 0 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
@@ -76,6 +76,11 @@ class LIGHTGBM_EXPORT Boosting {
*/
virtual void RefitTree(const int* tree_leaf_prediction, const size_t nrow, const size_t ncol) = 0;

/*!
 * \brief Change the leaf values of a tree and update the scores
 * \param tree_idx Index of the tree whose leaf outputs will be replaced
 * \param vals New output value for each leaf; the number of elements should be equal to the number of leaves in the tree
 */
virtual void RefitTreeManual(int tree_idx, const double *vals) = 0;

/*!
* \brief Training logic
* \param gradients nullptr for using default objective, otherwise use self-defined boosting
6 changes: 6 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
@@ -778,6 +778,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterRefit(BoosterHandle handle,
int32_t nrow,
int32_t ncol);

/*!
 * \brief Set the output values of all leaves of one tree and update the dataset scores accordingly.
 * \param handle Handle of booster
 * \param tree_idx Index of the tree whose leaf outputs will be replaced
 * \param values New leaf values; the number of elements should be equal to the number of leaves in the tree
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterRefitTreeManual(BoosterHandle handle,
int32_t tree_idx,
const double *values);

/*!
* \brief Update the model by specifying gradient and Hessian directly
* (this can be used to support customized loss functions).
35 changes: 35 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
@@ -4912,6 +4912,41 @@ def refit(
new_booster._network = self._network
return new_booster

def refit_tree_manual(
    self,
    tree_id: int,
    values: np.ndarray,
) -> "Booster":
    """Set all the outputs of a tree and recalculate the dataset scores.

    .. versionadded:: 4.6.0

    Parameters
    ----------
    tree_id : int
        The index of the tree.
    values : numpy 1-D array
        Value to set as the outputs of the tree.
        The number of elements should be equal to the number of leaves in the tree.

    Returns
    -------
    self : Booster
        Booster with the leaf outputs set.
    """
    # Coerce list-like input to a contiguous float64 array for the C API.
    values = _list_to_1d_numpy(values, dtype=np.float64, name="leaf_values")

    _safe_call(
        _LIB.LGBM_BoosterRefitTreeManual(
            self._handle,
            ctypes.c_int(tree_id),
            values.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
        )
    )
    # Leaf values changed, so any cached per-dataset predictions are stale.
    self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
    return self


def get_leaf_output(self, tree_id: int, leaf_id: int) -> float:
"""Get the output of a leaf.

17 changes: 17 additions & 0 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
@@ -289,6 +289,23 @@ void GBDT::RefitTree(const int* tree_leaf_prediction, const size_t nrow, const s
}
}

void GBDT::RefitTreeManual(int tree_idx, const double *vals) {
CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size());
// reset score
for (int leaf_id = 0; leaf_id < models_[tree_idx]->num_leaves(); ++leaf_id) {
models_[tree_idx]->SetLeafOutput(leaf_id, vals[leaf_id] - models_[tree_idx]->LeafOutput(leaf_id));
}
// add the delta
train_score_updater_->AddScore(models_[tree_idx].get(), tree_idx % num_tree_per_iteration_);
for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(models_[tree_idx].get(), tree_idx % num_tree_per_iteration_);
}
// update the model
for (int leaf_id = 0; leaf_id < models_[tree_idx]->num_leaves(); ++leaf_id) {
models_[tree_idx]->SetLeafOutput(leaf_id, vals[leaf_id]);
}
}

/* If the custom "average" is implemented it will be used in place of the label average (if enabled)
*
* An improvement to this is to have options to explicitly choose
2 changes: 2 additions & 0 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
@@ -145,6 +145,8 @@ class GBDT : public GBDTBase {

void RefitTree(const int* tree_leaf_prediction, const size_t nrow, const size_t ncol) override;

void RefitTreeManual(int tree_idx, const double *vals) override;

/*!
* \brief Training logic
* \param gradients nullptr for using default objective, otherwise use self-defined boosting
15 changes: 15 additions & 0 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
@@ -412,6 +412,11 @@ class Booster {
boosting_->RefitTree(leaf_preds, nrow, ncol);
}

/*!
 * \brief Forward a manual tree refit to the boosting object under the booster lock.
 */
void RefitTreeManual(int tree_idx, const double *vals) {
UNIQUE_LOCK(mutex_)
boosting_->RefitTreeManual(tree_idx, vals);
}

bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
UNIQUE_LOCK(mutex_)
return boosting_->TrainOneIter(gradients, hessians);
@@ -2058,6 +2063,16 @@ int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t n
API_END();
}

int LGBM_BoosterRefitTreeManual(BoosterHandle handle,
int tree_idx,
const double *val) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
ref_booster->RefitTreeManual(tree_idx, val);
API_END();
}


int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);