
[python-package] Expose ObjectiveFunction class #6586

Open · wants to merge 18 commits into base: master
Fix multiclass
Atanas Dimitrov committed Aug 1, 2024
commit e600a25fac50c21ec6970c890b25664596ad45d6
6 changes: 3 additions & 3 deletions python-package/lightgbm/basic.py
@@ -5341,11 +5341,11 @@ def __call__(self, y_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
raise ValueError("Objective function seems uninitialized")

if self.num_data is None or self.num_class is None:
# TODO: Be more descriptive
raise ValueError("ObjectiveFunction was not created properly")

grad = np.zeros(dtype=np.float32, shape=self.num_data * self.num_class)
hess = np.zeros(dtype=np.float32, shape=self.num_data * self.num_class)
data_shape = self.num_data * self.num_class
grad = np.zeros(dtype=np.float32, shape=data_shape)
hess = np.zeros(dtype=np.float32, shape=data_shape)

_safe_call(
_LIB.LGBM_ObjectiveFunctionEval(
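For orientation, here is a minimal sketch of driving the exposed class directly. The ObjectiveFunction constructor, init(), and call signatures are the ones visible in this diff and in the test helpers below; the toy dataset and parameter values are made up for illustration:

    # Hedged sketch: exercising lgb.ObjectiveFunction outside of training.
    # The toy data and params here are illustrative, not from the PR.
    import numpy as np
    import lightgbm as lgb

    X = np.random.rand(100, 5)
    y = np.random.randint(0, 3, size=100)
    ds = lgb.Dataset(X, y, init_score=np.zeros((len(y), 3))).construct()

    fobj = lgb.ObjectiveFunction("multiclass", {"num_class": 3, "verbose": -1})
    fobj.init(ds)

    # __call__ takes raw scores and returns gradient/hessian buffers of
    # length num_data * num_class, matching the allocation above.
    grad, hess = fobj(np.zeros(len(y) * 3))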
1 change: 0 additions & 1 deletion src/objective/multiclass_objective.hpp
@@ -25,7 +25,6 @@ class MulticlassSoftmax: public ObjectiveFunction {
  public:
   explicit MulticlassSoftmax(const Config& config) {
     num_class_ = config.num_class;
-    std::cout << "We have set " << num_class_ << std::endl;
     // This factor is to rescale the redundant form of K-classification, to the non-redundant form.
     // In the traditional settings of K-classification, there is one redundant class, whose output is set to 0 (like the class 0 in binary classification).
     // This is from the Friedman GBDT paper.
73 changes: 29 additions & 44 deletions tests/python_package_test/test_engine.py
@@ -33,7 +33,7 @@
     make_synthetic_regression,
     mse_obj,
     pickle_and_unpickle_object,
-    sklearn_multiclass_custom_objective,
+    multiclass_custom_objective,
     softmax,
 )

@@ -2927,12 +2927,6 @@ def test_default_objective_and_metric():

@pytest.mark.parametrize("use_weight", [True, False])
def test_multiclass_custom_objective(use_weight):
def custom_obj(y_pred, ds):
y_true = ds.get_label()
weight = ds.get_weight()
grad, hess = sklearn_multiclass_custom_objective(y_true, y_pred, weight)
return grad, hess

centers = [[-4, -4], [4, 4], [-4, 4]]
X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42)
weight = np.full_like(y, 2)
@@ -4400,47 +4394,38 @@ def test_quantized_training():
     assert quant_rmse < rmse + 6.0

 @pytest.mark.parametrize("use_weight", [False, True])
-def test_objective_function_regression(use_weight):
-    X, y = make_synthetic_regression()
-    weight = np.random.choice([1, 2], len(X)) if use_weight else None
-    lgb_train = lgb.Dataset(X, y, weight=weight, init_score=np.zeros(len(X)))
-
-    params = {"verbose": -1, "objective": "regression"}
-    builtin_loss = builtin_objective("multiclass", copy.deepcopy(params))
-
-    booster = lgb.train(params, lgb_train, num_boost_round=20)
-    params["objective"] = mse_obj
-    booster_custom = lgb.train(params, lgb_train, num_boost_round=20)
-    params["objective"] = builtin_loss
-    booster_exposed = lgb.train(params, lgb_train, num_boost_round=20)
-    np.testing.assert_allclose(booster_exposed.predict(X), booster.predict(X))
-    np.testing.assert_allclose(booster_exposed.predict(X), booster_custom.predict(X))
-
-    y_pred = booster.predict(X)
-    np.testing.assert_allclose(builtin_loss(y_pred, lgb_train), mse_obj(y_pred, lgb_train))
-
-@pytest.mark.parametrize("use_weight", [False, True])
-def test_objective_function_multiclass(use_weight):
-    def custom_obj(y_pred, ds):
-        y_true = ds.get_label()
-        weight = ds.get_weight()
-        grad, hess = sklearn_multiclass_custom_objective(y_true, y_pred, weight)
-        return grad, hess
-
-    X, y = make_blobs(n_samples=1_000, centers=[[-4, -4], [4, 4], [-4, 4]], random_state=42)
+@pytest.mark.parametrize("test_data", [
+    {
+        "custom_objective": mse_obj,
+        "objective_name": "regression",
+        "df": make_synthetic_regression(),
+        "num_class": 1
+    },
+    {
+        "custom_objective": multiclass_custom_objective,
+        "objective_name": "multiclass",
+        "df": make_blobs(n_samples=100, centers=[[-4, -4], [4, 4], [-4, 4]], random_state=42),
+        "num_class": 3
+    },
+])
+@pytest.mark.parametrize("num_boost_round", [5, 15])
+def test_objective_function_multiclass(use_weight, test_data, num_boost_round):
+    X, y = test_data["df"]
     weight = np.random.choice([1, 2], y.shape) if use_weight else None
-    lgb_train = lgb.Dataset(X, y, weight=weight, init_score=np.zeros((len(y), 3)))
+    lgb_train = lgb.Dataset(X, y, weight=weight, init_score=np.zeros((len(y), test_data["num_class"])))

-    params = {"verbose": -1, "objective": "multiclass", "num_class": 3}
-    builtin_loss = builtin_objective("multiclass", copy.deepcopy(params))
-    booster = lgb.train(params, lgb_train, num_boost_round=20)
-    params["objective"] = custom_obj
-    booster_custom = lgb.train(params, lgb_train, num_boost_round=20)
+    params = {"verbose": -1, "objective": test_data["objective_name"], "num_class": test_data["num_class"]}
+    builtin_loss = builtin_objective(test_data["objective_name"], copy.deepcopy(params))
+
     params["objective"] = builtin_loss
-    booster_exposed = lgb.train(params, lgb_train, num_boost_round=20)
+    booster_exposed = lgb.train(params, lgb_train, num_boost_round=num_boost_round)
+    params["objective"] = test_data["objective_name"]
+    booster = lgb.train(params, lgb_train, num_boost_round=num_boost_round)
+    params["objective"] = test_data["custom_objective"]
+    booster_custom = lgb.train(params, lgb_train, num_boost_round=num_boost_round)

     np.testing.assert_allclose(booster_exposed.predict(X), booster.predict(X, raw_score=True))
     np.testing.assert_allclose(booster_exposed.predict(X), booster_custom.predict(X))

-    y_pred = booster.predict(X, raw_score=True)
-    np.testing.assert_allclose(builtin_loss(y_pred, lgb_train), mse_obj(y_pred, lgb_train))
+    y_pred = np.zeros_like(booster.predict(X, raw_score=True))
+    np.testing.assert_allclose(builtin_loss(y_pred, lgb_train), test_data["custom_objective"](y_pred, lgb_train))
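Note on the assertions above: a booster trained with any callable objective (the exposed builtin_loss included) applies no inverse link function at prediction time, which is why booster_exposed.predict(X) is compared against booster.predict(X, raw_score=True). For the multiclass case, an equivalent probability-space check would look like this sketch, reusing the softmax helper already imported from the test utils:

    # Hedged sketch (multiclass case only): map raw scores to probabilities
    # before comparing with the built-in booster's default predict output.
    np.testing.assert_allclose(softmax(booster_exposed.predict(X)), booster.predict(X))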
31 changes: 23 additions & 8 deletions tests/python_package_test/utils.py
@@ -161,6 +161,29 @@ def sklearn_multiclass_custom_objective(y_true, y_pred, weight=None):
     return grad, hess


+def multiclass_custom_objective(y_pred, ds):
+    y_true = ds.get_label()
+    weight = ds.get_weight()
+    grad, hess = sklearn_multiclass_custom_objective(y_true, y_pred, weight)
+    return grad, hess
+
+
+def builtin_objective(name, params):
+    """Mimics the builtin objective functions to mock training.
+    """
+    def wrapper(y_pred, dtrain):
+        fobj = lgb.ObjectiveFunction(name, params)
+        fobj.init(dtrain)
+        (grad, hess) = fobj(y_pred)
+        print(grad, hess)
+        if fobj.num_class != 1:
+            grad = grad.reshape((fobj.num_class, -1)).transpose()
+            hess = hess.reshape((fobj.num_class, -1)).transpose()
+        print(grad, hess)
+        return (grad, hess)
+    return wrapper
+
+
 def pickle_obj(obj, filepath, serializer):
     if serializer == "pickle":
         with open(filepath, "wb") as f:
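The reshape/transpose in wrapper is the important detail: judging by that reshape, the flat buffers returned by fobj(y_pred) are grouped class-major (all rows for class 0, then all rows for class 1, and so on), whereas Python custom objectives hand LightGBM one row per sample. A small self-contained illustration of that layout, under the class-major ordering inferred above:

    # Hedged illustration of the layout wrapper's reshape handles; class-major
    # ordering is inferred from reshape((num_class, -1)).transpose() above.
    import numpy as np

    num_class, num_data = 3, 4
    flat = np.arange(num_class * num_data, dtype=np.float32)  # stands in for fobj(y_pred) output
    per_sample = flat.reshape((num_class, -1)).transpose()
    assert per_sample.shape == (num_data, num_class)
    assert per_sample[0, 1] == flat[num_data]  # sample 0, class 1 sits one block in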
@@ -194,14 +217,6 @@ def pickle_and_unpickle_object(obj, serializer):
     return obj_from_disk  # noqa: RET504


-def builtin_objective(name, params):
-    def wrapper(y_pred, dtrain):
-        fobj = lgb.ObjectiveFunction(name, params)
-        fobj.init(dtrain)
-        return fobj(y_pred)
-    return wrapper
-
-
 # doing this here, at import time, to ensure it only runs once per import
 # instead of once per assertion
 _numpy_testing_supports_strict_kwarg = "strict" in getfullargspec(np.testing.assert_array_equal).kwonlyargs