Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] Expose ObjectiveFunction class #6586

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
Fix debug prints
  • Loading branch information
Atanas Dimitrov committed Aug 1, 2024
commit 58e400378781926ad5b24c1b4d9bd2da5a41e28c
6 changes: 1 addition & 5 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
@@ -27,7 +27,6 @@
#include <mutex>
#include <stdexcept>
#include <vector>
#include <fstream>

#include "application/predictor.hpp"
#include <LightGBM/utils/yamc/alternate_shared_mutex.hpp>
@@ -44,8 +43,7 @@ inline int LGBM_APIHandleException(const std::string& ex) {
return -1;
}

#define API_BEGIN() std::ofstream outf("logs.txt", std::ios_base::app); \
try {
#define API_BEGIN() try {
#define API_END() } \
catch(std::exception& ex) { return LGBM_APIHandleException(ex); } \
catch(std::string& ex) { return LGBM_APIHandleException(ex); } \
@@ -2590,7 +2588,6 @@ int LGBM_BoosterPredictForMats(BoosterHandle handle,
int64_t* out_len,
double* out_result) {
API_BEGIN();
outf << parameter << std::endl;
auto param = Config::Str2Map(parameter);
Config config;
config.Set(param);
@@ -2758,7 +2755,6 @@ LIGHTGBM_C_EXPORT int LGBM_ObjectiveFunctionCreate(const char *typ,
auto param = Config::Str2Map(parameter);
Config config(param);
*out = ObjectiveFunction::CreateObjectiveFunction(std::string(typ), config);
outf << parameter << std::endl;
API_END();
}

1 change: 0 additions & 1 deletion src/objective/regression_objective.hpp
Original file line number Diff line number Diff line change
@@ -111,7 +111,6 @@ class RegressionL2loss: public ObjectiveFunction {
}

void Init(const Metadata& metadata, data_size_t num_data) override {
Log::Debug("We are here");
num_data_ = num_data;
label_ = metadata.label();
if (sqrt_) {
2 changes: 1 addition & 1 deletion tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
@@ -2937,7 +2937,7 @@ def test_multiclass_custom_objective(use_weight):
builtin_obj_bst = lgb.train(params, ds, num_boost_round=10)
builtin_obj_preds = builtin_obj_bst.predict(X)

params["objective"] = custom_obj
params["objective"] = multiclass_custom_objective
custom_obj_bst = lgb.train(params, ds, num_boost_round=10)
custom_obj_preds = softmax(custom_obj_bst.predict(X))