[python-package] Expose ObjectiveFunction class #6586

Open · wants to merge 18 commits into base: master
Run pre-commit hooks
Atanas Dimitrov committed Aug 1, 2024
commit caa5a499756ac7819c8f3b96c4cc6276c373d30b
2 changes: 1 addition & 1 deletion python-package/lightgbm/__init__.py
@@ -6,7 +6,7 @@

 from pathlib import Path

-from .basic import Booster, Dataset, Sequence, ObjectiveFunction, register_logger
+from .basic import Booster, Dataset, ObjectiveFunction, Sequence, register_logger
 from .callback import EarlyStopException, early_stopping, log_evaluation, record_evaluation, reset_parameter
 from .engine import CVBooster, cv, train

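With the re-exported (and now alphabetically sorted) import in place, the class is reachable from the top-level package. A minimal sketch, assuming the constructor shown in the basic.py hunk below:

    import lightgbm as lgb

    # Construct a wrapper around a built-in objective by name;
    # "num_class" defaults to 1 when omitted (see basic.py below).
    fobj = lgb.ObjectiveFunction("regression", {})
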
3 changes: 1 addition & 2 deletions python-package/lightgbm/basic.py
@@ -5311,7 +5311,7 @@ def __init__(self, name: str, params: Dict[str, Any]):
         self.num_class = params.get("num_class", 1)

         if "objective" in params and params["objective"] != self.name:
-            raise ValueError("The name should be consistent with the params[\"objective\"] field.")
+            raise ValueError('The name should be consistent with the params["objective"] field.')

         self.__create()

@@ -5383,7 +5383,6 @@ def __init_from_dataset(self, dataset: Dataset) -> "ObjectiveFunction":
         if self._handle is None:
             raise ValueError("Deallocated ObjectiveFunction cannot be initialized")

-        ref_dataset = dataset._handle
         tmp_num_data = ctypes.c_int(0)
         _safe_call(
             _LIB.LGBM_ObjectiveFunctionInit(
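For illustration, the consistency check in __init__ above means a params["objective"] entry that disagrees with the name argument is rejected before the native handle is created. A hypothetical sketch:

    import lightgbm as lgb

    # Mismatched name vs. params["objective"] raises:
    # ValueError: The name should be consistent with the params["objective"] field.
    lgb.ObjectiveFunction("regression", {"objective": "binary"})
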
37 changes: 21 additions & 16 deletions tests/python_package_test/test_engine.py
@@ -32,8 +32,8 @@
     logistic_sigmoid,
     make_synthetic_regression,
     mse_obj,
-    pickle_and_unpickle_object,
     multiclass_custom_objective,
+    pickle_and_unpickle_object,
     softmax,
 )

@@ -4393,25 +4393,30 @@ def test_quantized_training():
     quant_rmse = np.sqrt(np.mean((quant_bst.predict(X) - y) ** 2))
     assert quant_rmse < rmse + 6.0

+
 @pytest.mark.parametrize("use_weight", [False, True])
-@pytest.mark.parametrize("test_data", [
-    {
-        "custom_objective": mse_obj,
-        "objective_name": "regression",
-        "df": make_synthetic_regression(),
-        "num_class": 1
-    },
-    {
-        "custom_objective": multiclass_custom_objective,
-        "objective_name": "multiclass",
-        "df": make_blobs(n_samples=100, centers=[[-4, -4], [4, 4], [-4, 4]], random_state=42),
-        "num_class": 3
-    },
-])
+@pytest.mark.parametrize(
+    "test_data",
+    [
+        {
+            "custom_objective": mse_obj,
+            "objective_name": "regression",
+            "df": make_synthetic_regression(),
+            "num_class": 1,
+        },
+        {
+            "custom_objective": multiclass_custom_objective,
+            "objective_name": "multiclass",
+            "df": make_blobs(n_samples=100, centers=[[-4, -4], [4, 4], [-4, 4]], random_state=42),
+            "num_class": 3,
+        },
+    ],
+)
 @pytest.mark.parametrize("num_boost_round", [5, 15])
 def test_objective_function_class(use_weight, test_data, num_boost_round):
     X, y = test_data["df"]
-    weight = np.random.choice([1, 2], y.shape) if use_weight else None
+    rng = np.random.default_rng()
+    weight = rng.choice([1, 2], y.shape) if use_weight else None
     lgb_train = lgb.Dataset(X, y, weight=weight, init_score=np.zeros((len(y), test_data["num_class"])))

     params = {"verbose": -1, "objective": test_data["objective_name"], "num_class": test_data["num_class"]}
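The weight line above replaces the legacy np.random global-state API with NumPy's Generator interface, which pre-commit linters commonly enforce. A standalone sketch of the equivalent call (the seed is illustrative, not part of the test):

    import numpy as np

    rng = np.random.default_rng(42)        # seeded only for reproducibility here
    weight = rng.choice([1, 2], size=100)  # sample weights drawn from {1, 2}
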
5 changes: 3 additions & 2 deletions tests/python_package_test/utils.py
@@ -169,8 +169,8 @@ def multiclass_custom_objective(y_pred, ds):


 def builtin_objective(name, params):
-    """Mimics the builtin objective functions to mock training.
-    """
+    """Mimics the builtin objective functions to mock training."""
+
     def wrapper(y_pred, dtrain):
         fobj = lgb.ObjectiveFunction(name, params)
         fobj.init(dtrain)
@@ -181,6 +181,7 @@ def wrapper(y_pred, dtrain):
         hess = hess.reshape((fobj.num_class, -1)).transpose()
         print(grad, hess)
         return (grad, hess)
+
     return wrapper
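Since wrapper matches LightGBM's custom-objective signature of (y_pred, dtrain) -> (grad, hess), builtin_objective can stand in for a built-in objective during training. A usage sketch with hypothetical data, assuming a LightGBM version that accepts a callable objective via params (as 4.x does):

    import lightgbm as lgb
    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 5))                   # hypothetical features
    y = X[:, 0] * 2.0 + rng.normal(size=100) * 0.1  # hypothetical target

    dtrain = lgb.Dataset(X, y, init_score=np.zeros(len(y)))
    params = {"objective": builtin_objective("regression", {}), "verbose": -1}
    bst = lgb.train(params, dtrain, num_boost_round=5)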