Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: [python-package] preserve params when copying Booster (fixes #5539) #6101

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
more changes
  • Loading branch information
jameslamb committed Sep 16, 2023
commit c0e00c0b9a0aff057a8c3bef6e155b6d1c203828
5 changes: 5 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
@@ -3246,8 +3246,13 @@ def __init__(
elif model_str is not None:
self.model_from_string(model_str)
# ensure params are updated on the C++ side
# NOTE: models loaded from file are initially set to "boosting: GBDT", so "boosting"
# shouldn't be passed through here
self.params = params
boosting_type = params.pop("boosting", None)
self.reset_parameter(params)
if boosting_type is not None:
params["boosting"] = boosting_type
else:
raise TypeError('Need at least one training dataset or model file or model string '
'to create Booster instance')
30 changes: 25 additions & 5 deletions tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
import lightgbm as lgb
from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series

from .utils import dummy_obj, load_breast_cancer, mse_obj
from .utils import BOOSTING_TYPES, dummy_obj, load_breast_cancer, mse_obj


def test_basic(tmp_path):
@@ -825,8 +825,11 @@ def test_feature_names_are_set_correctly_when_no_feature_names_passed_into_Datas
assert ds.construct().feature_name == ["Column_0", "Column_1", "Column_2"]


def test_booster_deepcopy_preserves_parameters():
@pytest.mark.parametrize('boosting_type', BOOSTING_TYPES)
def test_booster_deepcopy_preserves_parameters(boosting_type):
orig_params = {
'boosting': boosting_type,
'feature_fraction': 0.708,
'num_leaves': 5,
'verbosity': -1
}
@@ -841,11 +844,19 @@ def test_booster_deepcopy_preserves_parameters():
assert bst.params["verbosity"] == -1

# passed-in params shouldn't have been modified outside of lightgbm
assert orig_params == {'num_leaves': 5, 'verbosity': -1}
assert orig_params == {
'boosting': boosting_type,
'feature_fraction': 0.708,
'num_leaves': 5,
'verbosity': -1
}


def test_booster_params_kwarg_overrides_params_from_model_string():
@pytest.mark.parametrize('boosting_type', BOOSTING_TYPES)
def test_booster_params_kwarg_overrides_params_from_model_string(boosting_type):
orig_params = {
'boosting': boosting_type,
'feature_fraction': 0.708,
'num_leaves': 5,
'verbosity': -1
}
@@ -863,5 +874,14 @@ def test_booster_params_kwarg_overrides_params_from_model_string():
assert bst2.params["num_leaves"] == 7
assert "[num_leaves: 7]" in bst2.model_to_string()

# boosting type should have been preserved in the new model
if boosting_type != "gbdt":
raise RuntimeError

# passed-in params shouldn't have been modified outside of lightgbm
assert orig_params == {'num_leaves': 5, 'verbosity': -1}
assert orig_params == {
'boosting': boosting_type,
'feature_fraction': 0.708,
'num_leaves': 5,
'verbosity': -1
}
1 change: 1 addition & 0 deletions tests/python_package_test/utils.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@

import lightgbm as lgb

BOOSTING_TYPES = ['gbdt', 'dart', 'goss', 'rf']
SERIALIZERS = ["pickle", "joblib", "cloudpickle"]