Skip to content

Commit

Permalink
Update scikit-learn to 1.4
Browse files Browse the repository at this point in the history
Mostly deals with parameters that were deprecated in scikit-learn 1.2
and are now no longer available.

Update end version of deprecation

Update version in which deprecation expires

Change argument name to deal with deprecation
  • Loading branch information
betatim committed May 8, 2024
1 parent 68d4336 commit 5deed40
Show file tree
Hide file tree
Showing 18 changed files with 262 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def fit(self, X, y=None) -> "KBinsDiscretizer":
if 'onehot' in self.encode:
self._encoder = OneHotEncoder(
categories=np.array([np.arange(i) for i in self.n_bins_]),
sparse=self.encode == 'onehot', output_type='cupy')
sparse_output=self.encode == 'onehot', output_type='cupy')
# Fit the OneHotEncoder with toy datasets
# so that it's ready for use after the KBinsDiscretizer is fitted
self._encoder.fit(np.zeros((1, len(self.n_bins_)), dtype=int))
Expand Down
53 changes: 44 additions & 9 deletions python/cuml/cluster/agglomerative.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,8 @@

# distutils: language = c++

import warnings

from libc.stdint cimport uintptr_t

from cuml.internals.safe_imports import cpu_only_import
Expand Down Expand Up @@ -103,6 +105,17 @@ class AgglomerativeClustering(Base, ClusterMixin, CMajorInputTagMixin):
Metric used to compute the linkage. Can be "euclidean", "l1",
"l2", "manhattan", or "cosine". If connectivity is "knn" only
"euclidean" is accepted.
.. deprecated:: 24.06
`affinity` was deprecated in version 24.06 and will be removed in
25.08. Use `metric` instead.
metric : str, default=None
Metric used to compute the linkage. Can be "euclidean", "l1",
"l2", "manhattan", or "cosine". If set to `None` then "euclidean"
is used. If connectivity is "knn" only "euclidean" is accepted.
.. versionadded:: 24.06
linkage : {"single"}, default="single"
Which linkage criterion to use. The linkage criterion determines
which distance to use between sets of observations. The algorithm
Expand Down Expand Up @@ -136,9 +149,9 @@ class AgglomerativeClustering(Base, ClusterMixin, CMajorInputTagMixin):
labels_ = CumlArrayDescriptor()
children_ = CumlArrayDescriptor()

def __init__(self, *, n_clusters=2, affinity="euclidean", linkage="single",
handle=None, verbose=False, connectivity='knn',
n_neighbors=10, output_type=None):
def __init__(self, *, n_clusters=2, affinity="deprecated", metric=None,
linkage="single", handle=None, verbose=False,
connectivity='knn', n_neighbors=10, output_type=None):

super().__init__(handle=handle,
verbose=verbose,
Expand All @@ -159,11 +172,12 @@ class AgglomerativeClustering(Base, ClusterMixin, CMajorInputTagMixin):
raise ValueError("'n_neighbors' must be a positive number "
"between 2 and 1023")

if affinity not in _metrics_mapping:
raise ValueError("'affinity' %s is not supported." % affinity)
if metric is not None and metric not in _metrics_mapping:
raise ValueError("Metric '%s' is not supported." % affinity)

self.n_clusters = n_clusters
self.affinity = affinity
self.metric = metric
self.linkage = linkage
self.n_neighbors = n_neighbors
self.connectivity = connectivity
Expand All @@ -178,6 +192,26 @@ class AgglomerativeClustering(Base, ClusterMixin, CMajorInputTagMixin):
"""
Fit the hierarchical clustering from features.
"""
if self.affinity != "deprecated":
if self.metric is not None:
raise ValueError(
"Both `affinity` and `metric` attributes were set. Attribute"
" `affinity` was deprecated in version 24.06 and will be removed in"
" 25.08. To avoid this error, only set the `metric` attribute."
)
warnings.warn(
(
"Attribute `affinity` was deprecated in version 24.06 and will be"
" removed in 25.08. Use `metric` instead."
),
FutureWarning,
)
metric_name = self.affinity
else:
if self.metric is None:
metric_name = "euclidean"
else:
metric_name = self.metric

X_m, n_rows, n_cols, self.dtype = \
input_to_cuml_array(X, order='C',
Expand Down Expand Up @@ -209,10 +243,10 @@ class AgglomerativeClustering(Base, ClusterMixin, CMajorInputTagMixin):
linkage_output.labels = <int*>labels_ptr

cdef DistanceType metric
if self.affinity in _metrics_mapping:
metric = _metrics_mapping[self.affinity]
if metric_name in _metrics_mapping:
metric = _metrics_mapping[metric_name]
else:
raise ValueError("'affinity' %s not supported." % self.affinity)
raise ValueError("Metric '%s' not supported." % metric_name)

if self.connectivity == 'knn':
single_linkage_neighbors(
Expand Down Expand Up @@ -249,6 +283,7 @@ class AgglomerativeClustering(Base, ClusterMixin, CMajorInputTagMixin):
return super().get_param_names() + [
"n_clusters",
"affinity",
"metric",
"linkage",
"connectivity",
"n_neighbors"
Expand Down
18 changes: 16 additions & 2 deletions python/cuml/ensemble/randomforest_common.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,7 +68,7 @@ class BaseRandomForestModel(Base):
classes_ = CumlArrayDescriptor()

def __init__(self, *, split_criterion, n_streams=4, n_estimators=100,
max_depth=16, handle=None, max_features='auto', n_bins=128,
max_depth=16, handle=None, max_features='sqrt', n_bins=128,
bootstrap=True,
verbose=False, min_samples_leaf=1, min_samples_split=2,
max_samples=1.0, max_leaves=-1, accuracy_metric=None,
Expand Down Expand Up @@ -166,8 +166,22 @@ class BaseRandomForestModel(Base):
return math.log2(self.n_cols)/self.n_cols
elif self.max_features == 'auto':
if self.RF_type == CLASSIFICATION:
warnings.warn(
"`max_features='auto'` has been deprecated in 24.06 "
"and will be removed in 25.08. To keep the past behaviour "
"and silence this warning, explicitly set "
"`max_features='sqrt'`.",
FutureWarning
)
return 1/np.sqrt(self.n_cols)
else:
warnings.warn(
"`max_features='auto'` has been deprecated in 24.06 "
"and will be removed in 25.08. To keep the past behaviour "
"and silence this warning, explicitly set "
"`max_features=1.0`.",
FutureWarning
)
return 1.0
else:
raise ValueError(
Expand Down
9 changes: 6 additions & 3 deletions python/cuml/ensemble/randomforestclassifier.pyx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

#
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -172,15 +172,18 @@ class RandomForestClassifier(BaseRandomForestModel,
max_leaves : int (default = -1)
Maximum leaf nodes per tree. Soft constraint. Unlimited,
if ``-1``.
max_features : int, float, or string (default = 'auto')
max_features : int, float, or string (default = 'sqrt')
Ratio of number of features (columns) to consider per node
split.\n
* If type ``int`` then ``max_features`` is the absolute count of
features to be used
* If type ``float`` then ``max_features`` is used as a fraction.
* If ``'auto'`` then ``max_features=1/sqrt(n_features)``.
* If ``'sqrt'`` then ``max_features=1/sqrt(n_features)``.
* If ``'log2'`` then ``max_features=log2(n_features)/n_features``.
.. versionchanged:: 24.06
The default of `max_features` changed from `"auto"` to `"sqrt"`.
n_bins : int (default = 128)
Maximum number of bins used by the split algorithm per feature.
For large problems, particularly those with highly-skewed input data,
Expand Down
7 changes: 4 additions & 3 deletions python/cuml/ensemble/randomforestregressor.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -168,15 +168,16 @@ class RandomForestRegressor(BaseRandomForestModel,
max_leaves : int (default = -1)
Maximum leaf nodes per tree. Soft constraint. Unlimited,
if ``-1``.
max_features : int, float, or string (default = 'auto')
max_features : int, float, or string (default = 1.0)
Ratio of number of features (columns) to consider
per node split.\n
* If type ``int`` then ``max_features`` is the absolute count of
features to be used.
* If type ``float`` then ``max_features`` is used as a fraction.
* If ``'auto'`` then ``max_features=1.0``.
* If ``'sqrt'`` then ``max_features=1/sqrt(n_features)``.
* If ``'log2'`` then ``max_features=log2(n_features)/n_features``.
.. versionchanged:: 24.06
The default of `max_features` changed from `"auto"` to 1.0.
n_bins : int (default = 128)
Maximum number of bins used by the split algorithm per feature.
For large problems, particularly those with highly-skewed input data,
Expand Down
6 changes: 4 additions & 2 deletions python/cuml/experimental/linear_model/lars.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -85,11 +85,13 @@ class Lars(Base, RegressorMixin):
fit_intercept : boolean (default = True)
If True, Lars tries to correct for the global mean of y.
If False, the model expects that you have centered the data.
normalize : boolean (default = True)
normalize : boolean (default = False)
This parameter is ignored when `fit_intercept` is set to False.
If True, the predictors in X will be normalized by removing its mean
and dividing by its variance. If False, then the solver expects that
the data is already normalized.
.. versionchanged:: 24.06
The default of `normalize` changed from `True` to `False`.
copy_X : boolean (default = True)
The solver permutes the columns of X. Set `copy_X` to True to prevent
changing the input data.
Expand Down
19 changes: 15 additions & 4 deletions python/cuml/linear_model/logistic_regression.pyx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,8 @@

# distutils: language = c++

import warnings

from cuml.internals.safe_imports import cpu_only_import
from cuml.internals.safe_imports import gpu_only_import
import pprint
Expand All @@ -36,7 +38,7 @@ cp = gpu_only_import('cupy')
np = cpu_only_import('numpy')


supported_penalties = ["l1", "l2", "none", "elasticnet"]
supported_penalties = ["l1", "l2", None, "none", "elasticnet"]

supported_solvers = ["qn"]

Expand Down Expand Up @@ -210,15 +212,24 @@ class LogisticRegression(UniversalBase,
output_type=output_type)

if penalty not in supported_penalties:
raise ValueError("`penalty` " + str(penalty) + "not supported.")
raise ValueError("`penalty` " + str(penalty) + " not supported.")

if solver not in supported_solvers:
raise ValueError("Only quasi-newton `qn` solver is "
" supported, not %s" % solver)
self.solver = solver

self.C = C

if penalty == "none":
warnings.warn(
"The 'none' option was deprecated in version 24.06, and will "
"be removed in 25.08. Use None instead.",
FutureWarning
)
penalty = None
self.penalty = penalty

self.tol = tol
self.fit_intercept = fit_intercept
self.max_iter = max_iter
Expand Down Expand Up @@ -452,7 +463,7 @@ class LogisticRegression(UniversalBase,
return proba

def _get_qn_params(self):
if self.penalty == "none":
if self.penalty is None:
l1_strength = 0.0
l2_strength = 0.0

Expand Down
40 changes: 35 additions & 5 deletions python/cuml/preprocessing/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,21 @@ class OneHotEncoder(BaseEncoder):
- dict/list : ``drop[col]`` is the category in feature col that
should be dropped.
sparse : bool, default=True
sparse_output : bool, default=True
This feature is not fully supported by cupy
yet, causing incorrect values when computing one hot encodings.
See https://github.com/cupy/cupy/issues/3223
.. versionadded:: 24.06
`sparse` was renamed to `sparse_output`
sparse : bool, default=True
Will return sparse matrix if set True else will return an array.
.. deprecated:: 24.06
`sparse` is deprecated in 24.06 and will be removed in 25.08. Use
`sparse_output` instead.
dtype : number type, default=np.float32
Desired datatype of transform's output.
handle_unknown : {'error', 'ignore'}, default='error'
Expand Down Expand Up @@ -246,7 +257,8 @@ def __init__(
*,
categories="auto",
drop=None,
sparse=True,
sparse="deprecated",
sparse_output=True,
dtype=np.float32,
handle_unknown="error",
handle=None,
Expand All @@ -257,7 +269,9 @@ def __init__(
handle=handle, verbose=verbose, output_type=output_type
)
self.categories = categories
# TODO(25.08): Remove self.sparse
self.sparse = sparse
self.sparse_output = sparse_output
self.dtype = dtype
self.handle_unknown = handle_unknown
self.drop = drop
Expand All @@ -266,10 +280,14 @@ def __init__(
self._features = None
self._encoders = None
self.input_type = None
if sparse and np.dtype(dtype) not in ["f", "d", "F", "D"]:
# This parameter validation should be performed in `fit` instead
# of in the constructor. Hence the awkward `if` clause
if ((sparse != "deprecated" and sparse) or sparse_output) and np.dtype(
dtype
) not in ["f", "d", "F", "D"]:
raise ValueError(
"Only float32, float64, complex64 and complex128 "
"are supported when using sparse"
"are supported when using sparse_output"
)

def _validate_keywords(self):
Expand All @@ -289,6 +307,17 @@ def _validate_keywords(self):
"zero."
)

if self.sparse != "deprecated":
warnings.warn(
(
"`sparse` was renamed to `sparse_output` in version 24.06"
" and will be removed in 25.08. `sparse_output` is ignored"
" unless you leave `sparse` set to its default value."
),
FutureWarning,
)
self.sparse_output = self.sparse

def _check_is_fitted(self):
if not self._fitted:
msg = (
Expand Down Expand Up @@ -440,7 +469,7 @@ def transform(self, X):
(val, (rows, cols)), shape=(len(X), j), dtype=self.dtype
)

if not self.sparse:
if not self.sparse_output:
ohe = ohe.toarray()

return ohe
Expand Down Expand Up @@ -578,6 +607,7 @@ def get_param_names(self):
"categories",
"drop",
"sparse",
"sparse_output",
"dtype",
"handle_unknown",
]
Expand Down
Loading

0 comments on commit 5deed40

Please sign in to comment.