Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] Introduce refit_tree_manual to Booster class. #6617

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Add tests
  • Loading branch information
Atanas Dimitrov committed Aug 16, 2024
commit b26a21cbc43aa2b90e68c2a9d62e1290230b7c6f
7 changes: 1 addition & 6 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
@@ -4912,11 +4912,7 @@ def refit(
new_booster._network = self._network
return new_booster

def refit_tree_manual(
self,
tree_id: int,
values: np.ndarray
) -> None:
def refit_tree_manual(self, tree_id: int, values: np.ndarray) -> "Booster":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a user it would not yet be clear to me how this function differs from set_leaf_output, i.e. why is this not just called set_leaf_outputs?

Copy link
Contributor Author

@neNasko1 neNasko1 Sep 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you propose changing the name of the function, or just writing better docs? I am a bit unsure about calling this function set_leaf_outputs, since it does more than just update the leaf values.

"""Set all the outputs of a tree and recalculate the dataset scores.

.. versionadded:: 4.6.0
@@ -4947,7 +4943,6 @@ def refit_tree_manual(
self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
return self


def get_leaf_output(self, tree_id: int, leaf_id: int) -> float:
"""Get the output of a leaf.

39 changes: 39 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
@@ -2332,6 +2332,45 @@ def test_refit_dataset_params(rng):
np.testing.assert_allclose(stored_weights, refit_weight)


def test_refit_tree_manual():
    """refit_tree_manual should let a callback debias each tree as it is built."""

    def collect_leaf_values(node):
        # Map leaf_index -> leaf_value for every leaf reachable from *node*,
        # using an explicit stack instead of recursion.
        leaves = {}
        pending = [node]
        while pending:
            current = pending.pop()
            if "leaf_index" in current:
                leaves[current["leaf_index"]] = current["leaf_value"]
            else:
                pending.append(current["left_child"])
                pending.append(current["right_child"])
        return leaves

    def leaves_of_iteration(booster, iteration):
        # Only the single tree added at *iteration* is of interest here.
        root = booster.dump_model(0, iteration)["tree_info"][0]["tree_structure"]
        return collect_leaf_values(root)

    def debias_callback(env):
        booster = env.model
        leaf_values = leaves_of_iteration(booster, env.iteration)
        predictions = booster.predict(df)
        # Shift every leaf by the log-ratio of the target mean to the
        # prediction mean so the refitted tree removes the current bias.
        shift = np.log(np.mean(y) / np.mean(predictions))
        adjusted = [leaf_values[leaf] + shift for leaf in range(len(leaf_values))]
        booster.refit_tree_manual(env.iteration, adjusted)

    X, y = make_synthetic_regression()
    y = np.abs(y)
    df = pd_DataFrame(X, columns=["x1", "x2", "x3", "x4"])
    ds = lgb.Dataset(df, y)

    params = {
        "verbose": -1,
        "n_estimators": 5,
        "num_leaves": 5,
        "objective": "gamma",
    }
    bst = lgb.train(params, ds, callbacks=[debias_callback])

    # Check if debiasing worked
    np.testing.assert_allclose(bst.predict(df).mean(), y.mean())


@pytest.mark.parametrize("boosting_type", ["rf", "dart"])
def test_mape_for_specific_boosting_types(boosting_type):
X, y = make_synthetic_regression()