
Commit

feat: better xgb support
Christopher Hua committed Jan 17, 2023
1 parent d076935 commit 0a2fc96
Showing 2 changed files with 20 additions and 0 deletions.
10 changes: 10 additions & 0 deletions hummingbird/ml/_parse.py
@@ -35,6 +35,11 @@
except ImportError:
    StackingClassifier = None

try:
    from xgboost.sklearn import XGBClassifier
except ImportError:
    XGBClassifier = None

do_not_merge_columns = tuple(filter(lambda op: op is not None, [OneHotEncoder, ColumnTransformer]))


@@ -257,6 +262,11 @@ def _parse_sklearn_single_model(topology, model, inputs):
    variable = topology.declare_logical_variable("variable")
    this_operator.outputs.append(variable)

    # ... unless the model is an XGBClassifier, which needs a second
    # "predictions" output alongside the labels.
    if type(model) == XGBClassifier:
        predictions = topology.declare_logical_variable("predictions")
        this_operator.outputs.append(predictions)

    return this_operator.outputs


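For context, a minimal standalone sketch of the guarded-import pattern relied on above (the helper name is hypothetical, not part of the commit): if xgboost is not installed, XGBClassifier stays None and the type check in _parse_sklearn_single_model simply never matches, so the regular scikit-learn path is unaffected.

try:
    from xgboost.sklearn import XGBClassifier
except ImportError:
    # xgboost is an optional dependency; fall back to None so the check below fails safely.
    XGBClassifier = None


def needs_extra_prediction_output(model):
    # Hypothetical helper mirroring the check added in _parse_sklearn_single_model:
    # only XGBClassifier models get a second "predictions" output declared.
    return XGBClassifier is not None and type(model) == XGBClassifier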
10 changes: 10 additions & 0 deletions hummingbird/ml/convert.py
@@ -134,6 +134,8 @@ def _convert_xgboost(model, backend, test_input, device, extra_config={}):
    booster = model.get_booster() if hasattr(model, "get_booster") else model
    if hasattr(booster, "num_features"):
        extra_config[constants.N_FEATURES] = booster.num_features()
    elif hasattr(booster, "feature_names"):
        extra_config[constants.N_FEATURES] = len(booster.feature_names)
    elif "_features_count" in dir(model):
        extra_config[constants.N_FEATURES] = model._features_count
    elif test_input is not None:
@@ -150,6 +152,14 @@ def _convert_xgboost(model, backend, test_input, device, extra_config={}):
            "XGBoost converter is not able to infer the number of input features.\
                Please pass some test_input to the converter."
        )

    # Add extra output names (labels, predictions) when the model is a binary classifier
    if hasattr(model, "get_xgb_params"):
        params = model.get_xgb_params()
        if params.get("objective") == "binary:logistic":
            extra_config[constants.OUTPUT_NAMES] = ['labels', 'predictions']

    return _convert_sklearn(model, backend, test_input, device, extra_config)


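For context, a hedged end-to-end sketch of how these changes would typically be exercised (the synthetic data and hyperparameters are illustrative, not part of the commit). With a binary:logistic objective, _convert_xgboost now also falls back to booster.feature_names for the feature count and names the two outputs 'labels' and 'predictions', so the converted container exposes both class labels and per-class scores.

import numpy as np
from xgboost import XGBClassifier

from hummingbird.ml import convert

# Illustrative synthetic binary-classification data.
X = np.random.rand(200, 20).astype(np.float32)
y = np.random.randint(2, size=200)

model = XGBClassifier(objective="binary:logistic", n_estimators=10)
model.fit(X, y)

# convert() routes XGBoost models through _convert_xgboost, which fills in
# constants.N_FEATURES (num_features / feature_names / test_input fallback)
# and, for binary:logistic, sets OUTPUT_NAMES to ['labels', 'predictions'].
container = convert(model, "pytorch", X)

labels = container.predict(X)        # class labels
scores = container.predict_proba(X)  # per-class probabilities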
