Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,23 @@ def get_basescore(model):
Copy-pasted from XGBoost unit test code.

See also:
* https://github.com/dmlc/xgboost/blob/a99bb38bd2762e35e6a1673a0c11e09eddd8e723/python-package/xgboost/testing/updater.py#L13
* https://github.com/dmlc/xgboost/blob/2463938/python-package/xgboost/testing/updater.py#L43
* https://github.com/dmlc/xgboost/issues/9347
* https://discuss.xgboost.ai/t/how-to-get-base-score-from-trained-booster/3192
"""
base_score = float(json.loads(model.get_booster().save_config())["learner"]["learner_model_param"]["base_score"])
return base_score
jintercept = json.loads(model.get_booster().save_config())["learner"]["learner_model_param"]["base_score"]
out = json.loads(jintercept)
if isinstance(out, float):
return out
# For XGBoost 3.1.0 and after, the value is itself a list.
# However, we don't support multiple base scores yet.
if len(out) > 1:
raise ValueError(
f"Model contains multiple base scores ({out}). "
"This typically occurs with XGBoost ≥ 3.1.0, which supports multi-target base scores. "
"This function only supports a single base score. "
)
return out[0]


def SaveXGBoost(xgb_model, key_name, output_path, num_inputs):
Expand Down
2 changes: 2 additions & 0 deletions tmva/tmva/test/rbdt_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def test_XGBMulticlass_default(self):
"""
Test model trained with multiclass XGBClassifier.
"""
if xgboost.__version__ >= "3.1.0":
self.skipTest("We don't support multiclassification with xgboost>=3.1.0 yet")
_test_XGBMulticlass("default")

def test_XGBRegression_default(self):
Expand Down
Loading