forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sklearn.py
1986 lines (1748 loc) · 74.4 KB
/
sklearn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, too-many-lines
"""Scikit-Learn Wrapper interface for XGBoost."""
import copy
import json
import os
import warnings
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)
import numpy as np
from scipy.special import softmax
from ._typing import ArrayLike, FeatureNames, FeatureTypes
from .callback import TrainingCallback
# Do not use class names on scikit-learn directly. Re-define the classes on
# .compat to guarantee the behavior without scikit-learn
from .compat import (
SKLEARN_INSTALLED,
XGBClassifierBase,
XGBModelBase,
XGBoostLabelEncoder,
XGBRegressorBase,
)
from .config import config_context
from .core import (
Booster,
DMatrix,
Metric,
QuantileDMatrix,
XGBoostError,
_convert_ntree_limit,
_deprecate_positional_args,
)
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
from .training import train
class XGBRankerMixIn:  # pylint: disable=too-few-public-methods
    """Ranking mix-in.

    Supplies the ``_estimator_type`` tag, which is usually defined by the
    scikit-learn base classes.
    """

    _estimator_type = "ranker"
def _check_rf_callback(
early_stopping_rounds: Optional[int],
callbacks: Optional[Sequence[TrainingCallback]],
) -> None:
if early_stopping_rounds is not None or callbacks is not None:
raise NotImplementedError(
"`early_stopping_rounds` and `callbacks` are not implemented for"
" random forest."
)
_SklObjective = Optional[
Union[str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]
]
def _objective_decorator(
func: Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
) -> Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]:
"""Decorate an objective function
Converts an objective function using the typical sklearn metrics
signature so that it is usable with ``xgboost.training.train``
Parameters
----------
func:
Expects a callable with signature ``func(y_true, y_pred)``:
y_true: array_like of shape [n_samples]
The target values
y_pred: array_like of shape [n_samples]
The predicted values
Returns
-------
new_func:
The new objective function as expected by ``xgboost.training.train``.
The signature is ``new_func(preds, dmatrix)``:
preds: array_like, shape [n_samples]
The predicted values
dmatrix: ``DMatrix``
The training set from which the labels will be extracted using
``dmatrix.get_label()``
"""
def inner(preds: np.ndarray, dmatrix: DMatrix) -> Tuple[np.ndarray, np.ndarray]:
"""internal function"""
labels = dmatrix.get_label()
return func(labels, preds)
return inner
def _metric_decorator(func: Callable) -> Metric:
"""Decorate a metric function from sklearn.
Converts an metric function that uses the typical sklearn metric signature so that it
is compatible with :py:func:`train`
"""
def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
y_true = dmatrix.get_label()
return func.__name__, func(y_true, y_score)
return inner
# Shared docstring fragment describing ``n_estimators``; selected by the
# "estimators" item of ``xgboost_model_doc``.
__estimator_doc = """
n_estimators : int
Number of gradient boosted trees. Equivalent to number of boosting
rounds.
"""
# Shared docstring fragment for the remaining parameters; selected by the "model"
# item of ``xgboost_model_doc``. This is an f-string: ``{_SklObjective}`` is
# interpolated when this assignment runs.
__model_doc = f"""
max_depth : Optional[int]
Maximum tree depth for base learners.
max_leaves :
Maximum number of leaves; 0 indicates no limit.
max_bin :
If using histogram-based algorithm, maximum number of bins per feature
grow_policy :
Tree growing policy. 0: favor splitting at nodes closest to the node, i.e. grow
depth-wise. 1: favor splitting at nodes with highest loss change.
learning_rate : Optional[float]
Boosting learning rate (xgb's "eta")
verbosity : Optional[int]
The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
objective : {_SklObjective}
Specify the learning task and the corresponding learning objective or
a custom objective function to be used (see note below).
booster: Optional[str]
Specify which booster to use: gbtree, gblinear or dart.
tree_method: Optional[str]
Specify which tree method to use. Default to auto. If this parameter is set to
default, XGBoost will choose the most conservative option available. It's
recommended to study this option from the parameters document :doc:`tree method
</treemethod>`
n_jobs : Optional[int]
Number of parallel threads used to run xgboost. When used with other
Scikit-Learn algorithms like grid search, you may choose which algorithm to
parallelize and balance the threads. Creating thread contention will
significantly slow down both algorithms.
gamma : Optional[float]
(min_split_loss) Minimum loss reduction required to make a further partition on a
leaf node of the tree.
min_child_weight : Optional[float]
Minimum sum of instance weight(hessian) needed in a child.
max_delta_step : Optional[float]
Maximum delta step we allow each tree's weight estimation to be.
subsample : Optional[float]
Subsample ratio of the training instance.
sampling_method :
Sampling method. Used only by `gpu_hist` tree method.
- `uniform`: select random training instances uniformly.
- `gradient_based` select random training instances with higher probability when
the gradient and hessian are larger. (cf. CatBoost)
colsample_bytree : Optional[float]
Subsample ratio of columns when constructing each tree.
colsample_bylevel : Optional[float]
Subsample ratio of columns for each level.
colsample_bynode : Optional[float]
Subsample ratio of columns for each split.
reg_alpha : Optional[float]
L1 regularization term on weights (xgb's alpha).
reg_lambda : Optional[float]
L2 regularization term on weights (xgb's lambda).
scale_pos_weight : Optional[float]
Balancing of positive and negative weights.
base_score : Optional[float]
The initial prediction score of all instances, global bias.
random_state : Optional[Union[numpy.random.RandomState, int]]
Random number seed.
.. note::
Using gblinear booster with shotgun updater is nondeterministic as
it uses Hogwild algorithm.
missing : float, default np.nan
Value in the data which needs to be present as a missing value.
num_parallel_tree: Optional[int]
Used for boosting random forest.
monotone_constraints : Optional[Union[Dict[str, int], str]]
Constraint of variable monotonicity. See :doc:`tutorial </tutorials/monotonic>`
for more information.
interaction_constraints : Optional[Union[str, List[Tuple[str]]]]
Constraints for interaction representing permitted interactions. The
constraints must be specified in the form of a nested list, e.g. ``[[0, 1], [2,
3, 4]]``, where each inner list is a group of indices of features that are
allowed to interact with each other. See :doc:`tutorial
</tutorials/feature_interaction_constraint>` for more information
importance_type: Optional[str]
The feature importance type for the feature_importances\\_ property:
* For tree model, it's either "gain", "weight", "cover", "total_gain" or
"total_cover".
* For linear model, only "weight" is defined and it's the normalized coefficients
without bias.
gpu_id : Optional[int]
Device ordinal.
validate_parameters : Optional[bool]
Give warnings for unknown parameter.
predictor : Optional[str]
Force XGBoost to use specific predictor, available choices are [cpu_predictor,
gpu_predictor].
enable_categorical : bool
.. versionadded:: 1.5.0
.. note:: This parameter is experimental
Experimental support for categorical data. When enabled, cudf/pandas.DataFrame
should be used to specify categorical data type. Also, JSON/UBJSON
serialization format is required.
feature_types : FeatureTypes
.. versionadded:: 2.0.0
Used for specifying feature types without constructing a dataframe. See
:py:class:`DMatrix` for details.
max_cat_to_onehot : Optional[int]
.. versionadded:: 1.6.0
.. note:: This parameter is experimental
A threshold for deciding whether XGBoost should use one-hot encoding based split
for categorical data. When number of categories is lesser than the threshold
then one-hot encoding is chosen, otherwise the categories will be partitioned
into children nodes. Only relevant for regression and binary classification.
See :doc:`Categorical Data </tutorials/categorical>` for details.
eval_metric : Optional[Union[str, List[str], Callable]]
.. versionadded:: 1.6.0
Metric used for monitoring the training result and early stopping. It can be a
string or list of strings as names of predefined metric in XGBoost (See
doc/parameter.rst), one of the metrics in :py:mod:`sklearn.metrics`, or any other
user defined metric that looks like `sklearn.metrics`.
If custom objective is also provided, then custom metric should implement the
corresponding reverse link function.
Unlike the `scoring` parameter commonly used in scikit-learn, when a callable
object is provided, it's assumed to be a cost function and by default XGBoost will
minimize the result during early stopping.
For advanced usage on Early stopping like directly choosing to maximize instead of
minimize, see :py:obj:`xgboost.callback.EarlyStopping`.
See :doc:`Custom Objective and Evaluation Metric </tutorials/custom_metric_obj>`
for more.
.. note::
This parameter replaces `eval_metric` in :py:meth:`fit` method. The old one
receives un-transformed prediction regardless of whether custom objective is
being used.
.. code-block:: python
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_absolute_error
X, y = load_diabetes(return_X_y=True)
reg = xgb.XGBRegressor(
tree_method="hist",
eval_metric=mean_absolute_error,
)
reg.fit(X, y, eval_set=[(X, y)])
early_stopping_rounds : Optional[int]
.. versionadded:: 1.6.0
Activates early stopping. Validation metric needs to improve at least once in
every **early_stopping_rounds** round(s) to continue training. Requires at least
one item in **eval_set** in :py:meth:`fit`.
The method returns the model from the last iteration (not the best one). If
there's more than one item in **eval_set**, the last entry will be used for early
stopping. If there's more than one metric in **eval_metric**, the last metric
will be used for early stopping.
If early stopping occurs, the model will have three additional fields:
:py:attr:`best_score`, :py:attr:`best_iteration` and
:py:attr:`best_ntree_limit`.
.. note::
This parameter replaces `early_stopping_rounds` in :py:meth:`fit` method.
callbacks : Optional[List[TrainingCallback]]
List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using
:ref:`Callback API <callback_api>`.
.. note::
States in callback are not preserved during training, which means callback
objects can not be reused for multiple training sessions without
reinitialization or deepcopy.
.. code-block:: python
for params in parameters_grid:
# be sure to (re)initialize the callbacks before each run
callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]
xgboost.train(params, Xy, callbacks=callbacks)
kwargs : dict, optional
Keyword arguments for XGBoost Booster object. Full documentation of parameters
can be found :doc:`here </parameter>`.
Attempting to set a parameter via the constructor args and \\*\\*kwargs
dict simultaneously will result in a TypeError.
.. note:: \\*\\*kwargs unsupported by scikit-learn
\\*\\*kwargs is unsupported by scikit-learn. We do not guarantee
that parameters passed via this argument will interact properly
with scikit-learn.
"""
# Shared note on custom objective functions; selected by the "objective" item of
# ``xgboost_model_doc``.
__custom_obj_note = """
.. note:: Custom objective function
A custom objective function can be provided for the ``objective``
parameter. In this case, it should have the signature
``objective(y_true, y_pred) -> grad, hess``:
y_true: array_like of shape [n_samples]
The target values
y_pred: array_like of shape [n_samples]
The predicted values
grad: array_like of shape [n_samples]
The value of the gradient for each sample point.
hess: array_like of shape [n_samples]
The value of the second derivative for each sample point
"""
def xgboost_model_doc(
    header: str,
    items: List[str],
    extra_parameters: Optional[str] = None,
    end_note: Optional[str] = None,
) -> Callable[[Type], Type]:
    """Build a class decorator that assembles docs for the Scikit-Learn wrappers.

    Parameters
    ----------
    header: str
        An introduction to the class.
    items : list
        Names of the shared doc fragments to include, in order. Available
        items are:

        - estimators: the meaning of n_estimators
        - model: All the other parameters
        - objective: note for customized objective
    extra_parameters: str
        Document for class specific parameters, placed right after the
        "Parameters" heading.
    end_note: str
        Extra notes appended at the very end.

    Returns
    -------
    A decorator that overwrites ``__doc__`` of the class it receives and returns
    that same class.
    """

    def lookup_fragment(item: str) -> str:
        """Map an item name to its shared documentation fragment."""
        fragments = {
            "estimators": __estimator_doc,
            "model": __model_doc,
            "objective": __custom_obj_note,
        }
        return fragments[item]

    def decorate(cls: Type) -> Type:
        # Header first, then the parameter heading, then the optional pieces.
        pieces = [header + "\n\n", "\nParameters\n----------\n"]
        if extra_parameters:
            pieces.append(extra_parameters)
        pieces.extend([lookup_fragment(name) for name in items])
        if end_note:
            pieces.append(end_note)
        cls.__doc__ = "".join(pieces)
        return cls

    return decorate
def _wrap_evaluation_matrices(
missing: float,
X: Any,
y: Any,
group: Optional[Any],
qid: Optional[Any],
sample_weight: Optional[Any],
base_margin: Optional[Any],
feature_weights: Optional[Any],
eval_set: Optional[Sequence[Tuple[Any, Any]]],
sample_weight_eval_set: Optional[Sequence[Any]],
base_margin_eval_set: Optional[Sequence[Any]],
eval_group: Optional[Sequence[Any]],
eval_qid: Optional[Sequence[Any]],
create_dmatrix: Callable,
enable_categorical: bool,
feature_types: Optional[FeatureTypes],
) -> Tuple[Any, List[Tuple[Any, str]]]:
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the
way."""
train_dmatrix = create_dmatrix(
data=X,
label=y,
group=group,
qid=qid,
weight=sample_weight,
base_margin=base_margin,
feature_weights=feature_weights,
missing=missing,
enable_categorical=enable_categorical,
feature_types=feature_types,
ref=None,
)
n_validation = 0 if eval_set is None else len(eval_set)
def validate_or_none(meta: Optional[Sequence], name: str) -> Sequence:
if meta is None:
return [None] * n_validation
if len(meta) != n_validation:
raise ValueError(
f"{name}'s length does not equal `eval_set`'s length, "
+ f"expecting {n_validation}, got {len(meta)}"
)
return meta
if eval_set is not None:
sample_weight_eval_set = validate_or_none(
sample_weight_eval_set, "sample_weight_eval_set"
)
base_margin_eval_set = validate_or_none(
base_margin_eval_set, "base_margin_eval_set"
)
eval_group = validate_or_none(eval_group, "eval_group")
eval_qid = validate_or_none(eval_qid, "eval_qid")
evals = []
for i, (valid_X, valid_y) in enumerate(eval_set):
# Skip the duplicated entry.
if all(
(
valid_X is X,
valid_y is y,
sample_weight_eval_set[i] is sample_weight,
base_margin_eval_set[i] is base_margin,
eval_group[i] is group,
eval_qid[i] is qid,
)
):
evals.append(train_dmatrix)
else:
m = create_dmatrix(
data=valid_X,
label=valid_y,
weight=sample_weight_eval_set[i],
group=eval_group[i],
qid=eval_qid[i],
base_margin=base_margin_eval_set[i],
missing=missing,
enable_categorical=enable_categorical,
feature_types=feature_types,
ref=train_dmatrix,
)
evals.append(m)
nevals = len(evals)
eval_names = [f"validation_{i}" for i in range(nevals)]
evals = list(zip(evals, eval_names))
else:
if any(
meta is not None
for meta in [
sample_weight_eval_set,
base_margin_eval_set,
eval_group,
eval_qid,
]
):
raise ValueError(
"`eval_set` is not set but one of the other evaluation meta info is "
"not None."
)
evals = []
return train_dmatrix, evals
@xgboost_model_doc(
"""Implementation of the Scikit-Learn API for XGBoost.""",
["estimators", "model", "objective"],
)
class XGBModel(XGBModelBase):
# pylint: disable=too-many-arguments, too-many-instance-attributes, missing-docstring
def __init__(
    self,
    max_depth: Optional[int] = None,
    max_leaves: Optional[int] = None,
    max_bin: Optional[int] = None,
    grow_policy: Optional[str] = None,
    learning_rate: Optional[float] = None,
    n_estimators: int = 100,
    verbosity: Optional[int] = None,
    objective: _SklObjective = None,
    booster: Optional[str] = None,
    tree_method: Optional[str] = None,
    n_jobs: Optional[int] = None,
    gamma: Optional[float] = None,
    min_child_weight: Optional[float] = None,
    max_delta_step: Optional[float] = None,
    subsample: Optional[float] = None,
    sampling_method: Optional[str] = None,
    colsample_bytree: Optional[float] = None,
    colsample_bylevel: Optional[float] = None,
    colsample_bynode: Optional[float] = None,
    reg_alpha: Optional[float] = None,
    reg_lambda: Optional[float] = None,
    scale_pos_weight: Optional[float] = None,
    base_score: Optional[float] = None,
    random_state: Optional[Union[np.random.RandomState, int]] = None,
    missing: float = np.nan,
    num_parallel_tree: Optional[int] = None,
    monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
    interaction_constraints: Optional[Union[str, Sequence[Sequence[str]]]] = None,
    importance_type: Optional[str] = None,
    gpu_id: Optional[int] = None,
    validate_parameters: Optional[bool] = None,
    predictor: Optional[str] = None,
    enable_categorical: bool = False,
    feature_types: FeatureTypes = None,
    max_cat_to_onehot: Optional[int] = None,
    eval_metric: Optional[Union[str, List[str], Callable]] = None,
    early_stopping_rounds: Optional[int] = None,
    callbacks: Optional[List[TrainingCallback]] = None,
    **kwargs: Any,
) -> None:
    # Store every named argument unchanged on the instance under the same name.
    # No validation or conversion happens here.
    #
    # Raises ImportError when ``SKLEARN_INSTALLED`` is false.
    if not SKLEARN_INSTALLED:
        raise ImportError(
            "sklearn needs to be installed in order to use this module"
        )
    self.n_estimators = n_estimators
    self.objective = objective
    self.max_depth = max_depth
    self.max_leaves = max_leaves
    self.max_bin = max_bin
    self.grow_policy = grow_policy
    self.learning_rate = learning_rate
    self.verbosity = verbosity
    self.booster = booster
    self.tree_method = tree_method
    self.gamma = gamma
    self.min_child_weight = min_child_weight
    self.max_delta_step = max_delta_step
    self.subsample = subsample
    self.sampling_method = sampling_method
    self.colsample_bytree = colsample_bytree
    self.colsample_bylevel = colsample_bylevel
    self.colsample_bynode = colsample_bynode
    self.reg_alpha = reg_alpha
    self.reg_lambda = reg_lambda
    self.scale_pos_weight = scale_pos_weight
    self.base_score = base_score
    self.missing = missing
    self.num_parallel_tree = num_parallel_tree
    self.random_state = random_state
    self.n_jobs = n_jobs
    self.monotone_constraints = monotone_constraints
    self.interaction_constraints = interaction_constraints
    self.importance_type = importance_type
    self.gpu_id = gpu_id
    self.validate_parameters = validate_parameters
    self.predictor = predictor
    self.enable_categorical = enable_categorical
    self.feature_types = feature_types
    self.max_cat_to_onehot = max_cat_to_onehot
    self.eval_metric = eval_metric
    self.early_stopping_rounds = early_stopping_rounds
    self.callbacks = callbacks
    # The ``kwargs`` attribute only exists when extra keyword arguments were
    # given; other methods probe for it with ``hasattr``.
    if kwargs:
        self.kwargs = kwargs
def _more_tags(self) -> Dict[str, bool]:
"""Tags used for scikit-learn data validation."""
return {"allow_nan": True, "no_validation": True}
def __sklearn_is_fitted__(self) -> bool:
return hasattr(self, "_Booster")
def get_booster(self) -> Booster:
    """Get the underlying xgboost Booster of this model.

    Returns
    -------
    booster : the xgboost booster stored on this model

    Raises
    ------
    sklearn.exceptions.NotFittedError
        If neither ``fit`` nor ``load_model`` has provided a booster yet.
    """
    if self.__sklearn_is_fitted__():
        return self._Booster

    # Imported lazily, only on the failure path.
    from sklearn.exceptions import NotFittedError

    raise NotFittedError("need to call fit or load_model beforehand")
def set_params(self, **params: Any) -> "XGBModel":
    """Set the parameters of this estimator.

    Variation of the sklearn method that tolerates unknown names: a name that is
    not already an attribute of the estimator is stored in the ``kwargs``
    dictionary, so the full range of xgboost parameters can be used in sklearn
    grid search. When a booster is already present, the resulting xgboost
    parameters are pushed to it as well.

    Returns
    -------
    self
    """
    if not params:
        # Simple optimization to gain speed (inspect is slow)
        return self

    for name in params:
        value = params[name]
        if hasattr(self, name):
            setattr(self, name, value)
            continue
        # Unknown names are collected in `kwargs`, which lets `get_params`
        # report parameters that were given as keyword parameters.
        if not hasattr(self, "kwargs"):
            self.kwargs = {}
        self.kwargs[name] = value

    if hasattr(self, "_Booster"):
        booster_params = self.get_xgb_params()
        self.get_booster().set_param(booster_params)

    return self
def get_params(self, deep: bool = True) -> Dict[str, Any]:
    # pylint: disable=attribute-defined-outside-init
    """Get parameters.

    Combines the parameters reported by the parent ``get_params`` for this class
    and for its first base class with the entries of ``self.kwargs``. A
    ``RandomState`` under ``random_state`` is replaced by an integer drawn from
    it. When the booster configuration can be read, parameters that are still
    ``None`` are filled in from it.
    """
    # Based on: https://stackoverflow.com/questions/59248211
    # The basic flow in `get_params` is:
    # 0. Return parameters in subclass first, by using inspect.
    # 1. Return parameters in `XGBModel` (the base class).
    # 2. Return whatever in `**kwargs`.
    # 3. Merge them.
    params = super().get_params(deep)
    # Re-type a shallow copy as the first base class of the current class and
    # ask that class for its parameters too; `self` keeps its own class.
    cp = copy.copy(self)
    cp.__class__ = cp.__class__.__bases__[0]
    params.update(cp.__class__.get_params(cp, deep))
    # if kwargs is a dict, update params accordingly
    if hasattr(self, "kwargs") and isinstance(self.kwargs, dict):
        params.update(self.kwargs)
    # Replace a RandomState by an integer drawn from it (this advances the
    # generator's state).
    if isinstance(params["random_state"], np.random.RandomState):
        params["random_state"] = params["random_state"].randint(
            np.iinfo(np.int32).max
        )

    def parse_parameter(value: Any) -> Optional[Union[int, float, str]]:
        # Try int, then float, then str; the first conversion that does not
        # raise ValueError wins.
        for t in (int, float, str):
            try:
                ret = t(value)
                return ret
            except ValueError:
                continue
        return None

    # Get internal parameter values
    try:
        config = json.loads(self.get_booster().save_config())
        # Walk the nested configuration iteratively: the entries of every
        # mapping stored under a key ending in "_param" are collected, other
        # nested dicts are descended into.
        stack = [config]
        internal = {}
        while stack:
            obj = stack.pop()
            for k, v in obj.items():
                if k.endswith("_param"):
                    for p_k, p_v in v.items():
                        internal[p_k] = p_v
                elif isinstance(v, dict):
                    stack.append(v)

        # Only fill in parameters that are known and still unset.
        for k, v in internal.items():
            if k in params and params[k] is None:
                params[k] = parse_parameter(v)
    except ValueError:
        # NOTE(review): presumably this covers the unfitted case, where
        # `get_booster` raises an error deriving from ValueError - confirm.
        pass
    return params
def get_xgb_params(self) -> Dict[str, Any]:
    """Get xgboost specific parameters.

    Starts from ``get_params()`` and drops both the entries that only concern the
    scikit-learn wrapper and every entry whose value is callable.
    """
    # Parameters that should not go into native learner.
    wrapper_only = {
        "importance_type",
        "kwargs",
        "missing",
        "n_estimators",
        "use_label_encoder",
        "enable_categorical",
        "early_stopping_rounds",
        "callbacks",
        "feature_types",
    }
    return {
        name: value
        for name, value in self.get_params().items()
        if not (name in wrapper_only or callable(value))
    }
def get_num_boosting_rounds(self) -> int:
    """Return the number of xgboost boosting rounds, i.e. ``n_estimators``."""
    rounds = self.n_estimators
    return rounds
def _get_type(self) -> str:
if not hasattr(self, "_estimator_type"):
raise TypeError(
"`_estimator_type` undefined. "
"Please use appropriate mixin to define estimator type."
)
return self._estimator_type # pylint: disable=no-member
def save_model(self, fname: Union[str, os.PathLike]) -> None:
    # No docstring here on purpose: ``save_model.__doc__`` is assigned from
    # ``Booster.save_model`` right after this definition.
    #
    # The estimator state is collected into a JSON string, attached to the
    # booster as the ``scikit_learn`` attribute for the duration of the save,
    # and removed again afterwards, even if saving fails.
    meta: Dict[str, Any] = {}
    for k, v in self.__dict__.items():
        if k == "_Booster":
            # The booster itself is what gets saved; it is not meta data.
            continue
        if k == "_le":
            meta["_le"] = self._le.to_json()
        elif k == "classes_":
            # numpy array is not JSON serializable
            meta["classes_"] = self.classes_.tolist()
        elif k == "feature_types":
            # Use the `feature_types` attribute from booster instead.
            meta["feature_types"] = None
        else:
            # Keep only what JSON can represent; warn about the rest.
            try:
                json.dumps({k: v})
            except TypeError:
                warnings.warn(
                    str(k) + " is not saved in Scikit-Learn meta.", UserWarning
                )
            else:
                meta[k] = v
    meta["_estimator_type"] = self._get_type()
    meta_str = json.dumps(meta)

    booster = self.get_booster()
    booster.set_attr(scikit_learn=meta_str)
    try:
        booster.save_model(fname)
    finally:
        # Delete the attribute after save, also when saving raised, so the
        # in-memory booster never keeps the temporary meta data.
        booster.set_attr(scikit_learn=None)
# Reuse the docstring of ``Booster.save_model`` for the wrapper method.
save_model.__doc__ = f"""{Booster.save_model.__doc__}"""
def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
    # pylint: disable=attribute-defined-outside-init
    # No docstring here on purpose: ``load_model.__doc__`` is assigned from
    # ``Booster.load_model`` right after this definition.
    if not hasattr(self, "_Booster"):
        self._Booster = Booster({"n_jobs": self.n_jobs})
    booster = self.get_booster()
    booster.load_model(fname)

    serialized = booster.attr("scikit_learn")
    if serialized is None:
        # FIXME(jiaming): This doesn't have to be a problem as most of the needed
        # information like num_class and objective is in Learner class.
        warnings.warn("Loading a native XGBoost model with Scikit-Learn interface.")
        return

    # Entries without special handling are gathered here and applied in one go
    # after the loop.
    plain_state = {}
    for key, value in json.loads(serialized).items():
        if key == "_le":
            self._le = XGBoostLabelEncoder()
            self._le.from_json(value)
        elif key == "classes_":
            # FIXME(jiaming): This can be removed once label encoder is gone since
            # we can generate it from `np.arange(self.n_classes_)`
            self.classes_ = np.array(value)
        elif key == "feature_types":
            # The stored value is ignored; the booster's own value is used.
            self.feature_types = booster.feature_types
        elif key == "_estimator_type":
            expected = self._get_type()
            if expected != value:
                raise TypeError(
                    "Loading an estimator with different type. "
                    f"Expecting: {expected}, got: {value}"
                )
        else:
            plain_state[key] = value
    self.__dict__.update(plain_state)
    # Delete the attribute after load
    booster.set_attr(scikit_learn=None)
# Reuse the docstring of ``Booster.load_model`` for the wrapper method.
load_model.__doc__ = f"""{Booster.load_model.__doc__}"""
# pylint: disable=too-many-branches
def _configure_fit(
    self,
    booster: Optional[Union[Booster, "XGBModel", str]],
    eval_metric: Optional[Union[Callable, str, Sequence[str]]],
    params: Dict[str, Any],
    early_stopping_rounds: Optional[int],
    callbacks: Optional[Sequence[TrainingCallback]],
) -> Tuple[
    Optional[Union[Booster, str, "XGBModel"]],
    Optional[Metric],
    Dict[str, Any],
    Optional[int],
    Optional[Sequence[TrainingCallback]],
]:
    """Configure parameters for :py:meth:`fit`.

    Reconciles the arguments given to ``fit`` with the values stored on the
    estimator. For ``eval_metric``, ``early_stopping_rounds`` and ``callbacks``,
    passing the value to ``fit`` emits a deprecation ``UserWarning``, and
    supplying it in both places raises ``ValueError``. A value stored on the
    estimator is used when present, otherwise the ``fit`` argument is used.

    Returns
    -------
    The model to continue training from, the callable metric (if any), the
    possibly updated ``params``, and the effective ``early_stopping_rounds`` and
    ``callbacks``.

    Raises
    ------
    ValueError
        On duplicated settings, or when ``enable_categorical`` is set while
        ``tree_method`` in ``params`` is not one of the supporting methods.
    """

    def warn_deprecated(parameter: str) -> None:
        warnings.warn(
            f"`{parameter}` in `fit` method is deprecated for better compatibility "
            f"with scikit-learn, use `{parameter}` in constructor or`set_params` "
            "instead.",
            UserWarning,
        )

    def fail_duplicated(parameter: str) -> None:
        raise ValueError(
            f"2 different `{parameter}` are provided. Use the one in constructor "
            "or `set_params` instead."
        )

    # A wrapper instance is unwrapped to its booster; anything else passes through.
    model: Optional[Union[Booster, str]] = (
        booster.get_booster() if isinstance(booster, XGBModel) else booster
    )

    # Configure evaluation metric.
    if eval_metric is not None:
        warn_deprecated("eval_metric")
        if self.eval_metric is not None:
            fail_duplicated("eval_metric")
    # - track where does the evaluation metric come from
    from_fit = self.eval_metric is None
    if not from_fit:
        eval_metric = self.eval_metric
    # - configure callable evaluation metric
    metric: Optional[Metric] = None
    if callable(eval_metric):
        # A callable from `fit` (old parameter) is used as is; one from the
        # constructor or `set_params` is wrapped first.
        metric = eval_metric if from_fit else _metric_decorator(eval_metric)
    elif eval_metric is not None:
        params.update({"eval_metric": eval_metric})

    # Configure early_stopping_rounds
    if early_stopping_rounds is not None:
        warn_deprecated("early_stopping_rounds")
        if self.early_stopping_rounds is not None:
            fail_duplicated("early_stopping_rounds")
    if self.early_stopping_rounds is not None:
        early_stopping_rounds = self.early_stopping_rounds

    # Configure callbacks
    if callbacks is not None:
        warn_deprecated("callbacks")
        if self.callbacks is not None:
            fail_duplicated("callbacks")
    if self.callbacks is not None:
        callbacks = self.callbacks

    # Categorical data is only accepted together with a supporting tree method.
    tree_method = params.get("tree_method", None)
    cat_support = {"gpu_hist", "approx", "hist"}
    if self.enable_categorical and tree_method not in cat_support:
        raise ValueError(
            "Experimental support for categorical data is not implemented for"
            " current tree method yet."
        )
    return model, metric, params, early_stopping_rounds, callbacks
def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
    """Create the data matrix for training or evaluation.

    For the ``hist`` and ``gpu_hist`` tree methods a ``QuantileDMatrix`` is tried
    first. If that raises ``TypeError``, or for any other tree method, a plain
    ``DMatrix`` is built, which does not receive ``ref``.
    """
    threads = self.n_jobs
    quantile_methods = ("hist", "gpu_hist")
    # Use `QuantileDMatrix` to save memory.
    if self.tree_method in quantile_methods:
        try:
            return QuantileDMatrix(
                **kwargs, ref=ref, nthread=threads, max_bin=self.max_bin
            )
        except TypeError:  # `QuantileDMatrix` supports lesser types than DMatrix
            pass
    return DMatrix(**kwargs, nthread=threads)
def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None:
if evals_result:
self.evals_result_ = cast(Dict[str, Dict[str, List[float]]], evals_result)
@_deprecate_positional_args
def fit(
self,
X: ArrayLike,
y: ArrayLike,
*,
sample_weight: Optional[ArrayLike] = None,
base_margin: Optional[ArrayLike] = None,
eval_set: Optional[Sequence[Tuple[ArrayLike, ArrayLike]]] = None,
eval_metric: Optional[Union[str, Sequence[str], Metric]] = None,
early_stopping_rounds: Optional[int] = None,
verbose: Optional[Union[bool, int]] = True,
xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None,
sample_weight_eval_set: Optional[Sequence[ArrayLike]] = None,
base_margin_eval_set: Optional[Sequence[ArrayLike]] = None,
feature_weights: Optional[ArrayLike] = None,
callbacks: Optional[Sequence[TrainingCallback]] = None,
) -> "XGBModel":
# pylint: disable=invalid-name,attribute-defined-outside-init
"""Fit gradient boosting model.
Note that calling ``fit()`` multiple times will cause the model object to be
re-fit from scratch. To resume training from a previous checkpoint, explicitly
pass ``xgb_model`` argument.
Parameters
----------
X :
Feature matrix
y :
Labels
sample_weight :
instance weights
base_margin :
global bias for each instance.
eval_set :
A list of (X, y) tuple pairs to use as validation sets, for which
metrics will be computed.
Validation metrics will help us track the performance of the model.
eval_metric : str, list of str, or callable, optional
.. deprecated:: 1.6.0
Use `eval_metric` in :py:meth:`__init__` or :py:meth:`set_params` instead.
early_stopping_rounds : int
.. deprecated:: 1.6.0
Use `early_stopping_rounds` in :py:meth:`__init__` or
:py:meth:`set_params` instead.
verbose :
If `verbose` is True and an evaluation set is used, the evaluation metric
measured on the validation set is printed to stdout at each boosting stage.
If `verbose` is an integer, the evaluation metric is printed at each `verbose`
boosting stage. The last boosting stage / the boosting stage found by using
`early_stopping_rounds` is also printed.
xgb_model :
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
loaded before training (allows training continuation).
sample_weight_eval_set :
A list of the form [L_1, L_2, ..., L_n], where each L_i is an array like
object storing instance weights for the i-th validation set.
base_margin_eval_set :
A list of the form [M_1, M_2, ..., M_n], where each M_i is an array like
object storing base margin for the i-th validation set.
feature_weights :
Weight for each feature, defines the probability of each feature being
selected when colsample is being used. All values must be greater than 0,
otherwise a `ValueError` is thrown.
callbacks :
.. deprecated:: 1.6.0
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
"""
with config_context(verbosity=self.verbosity):
evals_result: TrainingCallback.EvalsLog = {}
train_dmatrix, evals = _wrap_evaluation_matrices(