"""Implements XGBoost models."""
import platform
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy
import xgboost.sklearn
from concrete.ml.quantization.quantizers import UniformQuantizer
from ..common.debugging.custom_assert import assert_true
# The sigmoid and softmax functions are already defined in the ONNX module and thus are imported
# here in order to avoid duplicating them.
from ..onnx.ops_impl import numpy_sigmoid, numpy_softmax
from ..quantization import QuantizedArray
from .base import BaseTreeClassifierMixin, BaseTreeRegressorMixin


# Disabling invalid-name to use uppercase X
# pylint: disable=invalid-name,too-many-instance-attributes
class XGBClassifier(BaseTreeClassifierMixin):
"""Implements the XGBoost classifier."""
sklearn_alg = xgboost.sklearn.XGBClassifier
q_x_byfeatures: List[QuantizedArray]
n_bits: int
output_quantizers: List[UniformQuantizer]
_tensor_tree_predict: Optional[Callable]
n_classes_: int
sklearn_model: Any
    framework: str = "xgboost"

    # pylint: disable=too-many-arguments,too-many-locals
def __init__(
self,
n_bits: int = 6,
max_depth: Optional[int] = 3,
learning_rate: Optional[float] = 0.1,
n_estimators: Optional[int] = 20,
objective: Optional[str] = "binary:logistic",
booster: Optional[str] = None,
tree_method: Optional[str] = None,
n_jobs: Optional[int] = None,
gamma: Optional[float] = None,
min_child_weight: Optional[float] = None,
max_delta_step: Optional[float] = None,
subsample: Optional[float] = None,
colsample_bytree: Optional[float] = None,
colsample_bylevel: Optional[float] = None,
colsample_bynode: Optional[float] = None,
reg_alpha: Optional[float] = None,
reg_lambda: Optional[float] = None,
scale_pos_weight: Optional[float] = None,
base_score: Optional[float] = None,
missing: float = numpy.nan,
num_parallel_tree: Optional[int] = None,
monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
interaction_constraints: Optional[Union[str, List[Tuple[str]]]] = None,
importance_type: Optional[str] = None,
gpu_id: Optional[int] = None,
validate_parameters: Optional[bool] = None,
predictor: Optional[str] = None,
enable_categorical: bool = False,
use_label_encoder: bool = False,
random_state: Optional[
Union[numpy.random.RandomState, int] # pylint: disable=no-member
] = None,
verbosity: Optional[int] = None,
):
# See https://xgboost.readthedocs.io/en/stable/python/python_api.html#module-xgboost.sklearn
# for more information about the parameters used.
        # base_score values other than 0.5 and None do not seem to pass our tests (see #474)
assert_true(
base_score in [0.5, None],
f"Currently, only 0.5 or None are supported for base_score. Got {base_score}",
)
# FIXME: see https://github.com/zama-ai/concrete-ml-internal/issues/503, there is currently
# an issue with n_jobs != 1 on macOS
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/506, remove this workaround
# once https://github.com/zama-ai/concrete-ml-internal/issues/503 is fixed
if platform.system() == "Darwin":
if n_jobs != 1: # pragma: no cover
warnings.warn("forcing n_jobs = 1 on mac for segfault issue") # pragma: no cover
n_jobs = 1 # pragma: no cover
BaseTreeClassifierMixin.__init__(self, n_bits=n_bits)
self.max_depth = max_depth
self.learning_rate = learning_rate
self.n_estimators = n_estimators
self.objective = objective
self.booster = booster
self.tree_method = tree_method
self.n_jobs = n_jobs
self.gamma = gamma
self.min_child_weight = min_child_weight
self.max_delta_step = max_delta_step
self.subsample = subsample
self.colsample_bytree = colsample_bytree
self.colsample_bylevel = colsample_bylevel
self.colsample_bynode = colsample_bynode
self.reg_alpha = reg_alpha
self.reg_lambda = reg_lambda
self.scale_pos_weight = scale_pos_weight
self.base_score = base_score
self.missing = missing
self.num_parallel_tree = num_parallel_tree
self.monotone_constraints = monotone_constraints
self.interaction_constraints = interaction_constraints
self.importance_type = importance_type
self.gpu_id = gpu_id
self.validate_parameters = validate_parameters
self.predictor = predictor
self.enable_categorical = enable_categorical
self.use_label_encoder = use_label_encoder
self.random_state = random_state
self.verbosity = verbosity
        self.post_processing_params: Dict[str, Any] = {}

    def _update_post_processing_params(self):
        """Update the post-processing parameters."""
self.post_processing_params = {
"n_classes_": self.n_classes_,
"n_estimators": self.n_estimators,
        }

    def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
        """Apply post-processing to the predictions.

        Args:
            y_preds (numpy.ndarray): The predictions.

        Returns:
            numpy.ndarray: The post-processed predictions.
        """
assert self.output_quantizers is not None
# Update post-processing params with their current values
self.__dict__.update(self.post_processing_params)
y_preds = self.output_quantizers[0].dequant(y_preds)
# Apply transpose
y_preds = numpy.transpose(y_preds, axes=(2, 1, 0))
        # XGBoost returns an array of shape (n_examples, n_classes, n_trees) when
        # self.n_classes_ >= 3; otherwise, it returns an array of shape (n_examples, 1, n_trees)
# Reshape to (-1, n_classes, n_trees)
# No need to reshape if n_classes = 2
if self.n_classes_ > 2:
y_preds = y_preds.reshape((-1, self.n_classes_, self.n_estimators)) # type: ignore
# Sum all tree outputs.
y_preds = numpy.sum(y_preds, axis=2)
assert_true(y_preds.ndim == 2, "y_preds should be a 2D array")
        # If this is a binary classification problem
        if self.n_classes_ == 2:
            # Apply sigmoid
            y_preds = numpy_sigmoid(y_preds)[0]
            # Transform into a 2D array where [1 - p, p] is the output, as XGBoost only
            # outputs a single value when considering 2 classes
            y_preds = numpy.concatenate((1 - y_preds, y_preds), axis=1)
# Else, it's a multi-class classification problem
else:
# Apply softmax
y_preds = numpy_softmax(y_preds)[0]
return y_preds
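
# The block below is a minimal, commented sketch of the shape flow implemented by
# XGBClassifier.post_processing above, for a binary problem. Shapes and data are
# hypothetical and only illustrate the dequantize/transpose/sum/sigmoid steps; the
# numpy_sigmoid import path is assumed from this module's relative imports.
#
#     import numpy
#     from concrete.ml.onnx.ops_impl import numpy_sigmoid
#
#     n_trees, n_examples = 20, 10
#     raw = numpy.random.rand(n_trees, 1, n_examples)   # dequantized per-tree outputs
#     preds = numpy.transpose(raw, axes=(2, 1, 0))      # (n_examples, 1, n_trees)
#     preds = numpy.sum(preds, axis=2)                  # (n_examples, 1)
#     p = numpy_sigmoid(preds)[0]                       # positive-class probabilities
#     proba = numpy.concatenate((1 - p, p), axis=1)     # (n_examples, 2), i.e. [1 - p, p]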


# Disabling invalid-name to use uppercase X
# pylint: disable=invalid-name,too-many-instance-attributes
class XGBRegressor(BaseTreeRegressorMixin):
"""Implements the XGBoost regressor."""
sklearn_alg = xgboost.sklearn.XGBRegressor
q_x_byfeatures: List[QuantizedArray]
n_bits: int
output_quantizers: List[UniformQuantizer]
_tensor_tree_predict: Optional[Callable]
sklearn_model: Any
    framework: str = "xgboost"

    # pylint: disable=too-many-arguments,too-many-locals
def __init__(
self,
n_bits: int = 6,
max_depth: Optional[int] = 3,
learning_rate: Optional[float] = 0.1,
n_estimators: Optional[int] = 20,
objective: Optional[str] = "reg:squarederror",
booster: Optional[str] = None,
tree_method: Optional[str] = None,
n_jobs: Optional[int] = None,
gamma: Optional[float] = None,
min_child_weight: Optional[float] = None,
max_delta_step: Optional[float] = None,
subsample: Optional[float] = None,
colsample_bytree: Optional[float] = None,
colsample_bylevel: Optional[float] = None,
colsample_bynode: Optional[float] = None,
reg_alpha: Optional[float] = None,
reg_lambda: Optional[float] = None,
scale_pos_weight: Optional[float] = None,
base_score: Optional[float] = None,
missing: float = numpy.nan,
num_parallel_tree: Optional[int] = None,
monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
interaction_constraints: Optional[Union[str, List[Tuple[str]]]] = None,
importance_type: Optional[str] = None,
gpu_id: Optional[int] = None,
validate_parameters: Optional[bool] = None,
predictor: Optional[str] = None,
enable_categorical: bool = False,
use_label_encoder: bool = False,
random_state: Optional[
Union[numpy.random.RandomState, int] # pylint: disable=no-member
] = None,
verbosity: Optional[int] = None,
):
# See https://xgboost.readthedocs.io/en/stable/python/python_api.html#module-xgboost.sklearn
# for more information about the parameters used.
        # base_score values other than 0.5 and None do not seem to pass our tests (see #474)
assert_true(
base_score in [0.5, None],
f"Currently, only 0.5 or None are supported for base_score. Got {base_score}",
)
# FIXME: see https://github.com/zama-ai/concrete-ml-internal/issues/503, there is currently
# an issue with n_jobs != 1 on macOS
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/506, remove this workaround
# once https://github.com/zama-ai/concrete-ml-internal/issues/503 is fixed
if platform.system() == "Darwin":
if n_jobs != 1: # pragma: no cover
warnings.warn("forcing n_jobs = 1 on mac for segfault issue") # pragma: no cover
n_jobs = 1 # pragma: no cover
BaseTreeRegressorMixin.__init__(self, n_bits=n_bits)
self.max_depth = max_depth
self.learning_rate = learning_rate
self.n_estimators = n_estimators
self.objective = objective
self.booster = booster
self.tree_method = tree_method
self.n_jobs = n_jobs
self.gamma = gamma
self.min_child_weight = min_child_weight
self.max_delta_step = max_delta_step
self.subsample = subsample
self.colsample_bytree = colsample_bytree
self.colsample_bylevel = colsample_bylevel
self.colsample_bynode = colsample_bynode
self.reg_alpha = reg_alpha
self.reg_lambda = reg_lambda
self.scale_pos_weight = scale_pos_weight
self.base_score = base_score
self.missing = missing
self.num_parallel_tree = num_parallel_tree
self.monotone_constraints = monotone_constraints
self.interaction_constraints = interaction_constraints
self.importance_type = importance_type
self.gpu_id = gpu_id
self.validate_parameters = validate_parameters
self.predictor = predictor
self.enable_categorical = enable_categorical
self.use_label_encoder = use_label_encoder
self.random_state = random_state
self.verbosity = verbosity
        self.post_processing_params: Dict[str, Any] = {}

    def _update_post_processing_params(self):
        """Update the post-processing parameters."""
self.post_processing_params = {
"n_estimators": self.n_estimators,
        }

    def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
        """Apply post-processing to the predictions.

        Args:
            y_preds (numpy.ndarray): The predictions.

        Returns:
            numpy.ndarray: The post-processed predictions.
        """
assert self.output_quantizers is not None
# Update post-processing params with their current values
self.__dict__.update(self.post_processing_params)
y_preds = self.output_quantizers[0].dequant(y_preds)
# Apply transpose
y_preds = numpy.transpose(y_preds, axes=(2, 1, 0))
        # XGBoost returns an array of shape (n_examples, n_classes, n_trees) when
        # self.n_classes_ >= 3; otherwise, it returns an array of shape (n_examples, 1, n_trees)
# Sum all tree outputs.
y_preds = numpy.sum(y_preds, axis=2)
assert_true(y_preds.ndim == 2, "y_preds should be a 2D array")
return y_preds
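
    # A small illustrative sketch of the regressor's post-processing above (hypothetical
    # shapes, not part of the class): per-tree outputs are dequantized, transposed and
    # summed, with no sigmoid/softmax since regression outputs live in the target domain.
    #
    #     import numpy
    #
    #     n_trees, n_examples = 20, 10
    #     raw = numpy.random.rand(n_trees, 1, n_examples)   # dequantized per-tree outputs
    #     preds = numpy.transpose(raw, axes=(2, 1, 0))      # (n_examples, 1, n_trees)
    #     preds = numpy.sum(preds, axis=2)                  # (n_examples, 1)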

    def fit(self, X, y, **kwargs) -> Any:
        """Fit the tree-based estimator.

        Args:
            X : training data
                By default, you should be able to pass:
                * numpy arrays
                * torch tensors
                * pandas DataFrame or Series
            y (numpy.ndarray): The target data.
            **kwargs: args for super().fit

        Returns:
            Any: The fitted model.
        """
        # Hummingbird does not correctly handle n_targets > 1
        assert_true(
            len(y.shape) == 1 or y.shape[1] == 1, "n_targets = 1 is the only supported case"
        )

        # Call super's fit, which trains the underlying sklearn model
super().fit(X, y, **kwargs)
return self
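
# A hedged, end-to-end usage sketch (illustrative only, not part of this module). It
# assumes the scikit-learn style API exposed by the BaseTree*Mixin base classes
# (fit / predict / predict_proba) and uses hypothetical data; FHE compilation and the
# exact compile/predict signatures come from those base classes and may vary by version.
#
#     import numpy
#
#     X = numpy.random.rand(100, 4)
#     y = (X[:, 0] > 0.5).astype(numpy.int64)
#
#     model = XGBClassifier(n_bits=6, n_estimators=20, max_depth=3)
#     model.fit(X, y)
#
#     # Clear (non-FHE) inference on the quantized model
#     probabilities = model.predict_proba(X)
#     predictions = model.predict(X)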