Skip to content

Commit 5cc55c3

Browse files
mergify[bot]Alexsandrussicfaust
authored
Fixes for sklearn 1.7 pre-release support (#2451) (#2527)
* Fixes for sklearn 1.7 pre-release support * Fix ensemble probabilities interval * Change scaling method to `clip` --------- (cherry picked from commit 95c73bd) Co-authored-by: Alexander Andreev <alexander.andreev@intel.com> Co-authored-by: icfaust <icfaust@gmail.com>
1 parent 21a8267 commit 5cc55c3

File tree

5 files changed

+19
-6
lines changed

5 files changed

+19
-6
lines changed

daal4py/sklearn/ensemble/_forest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -679,8 +679,8 @@ def _daal_predict_proba(self, X):
679679
dfc_predictionResult = dfc_algorithm.compute(X, self.daal_model_)
680680

681681
pred = dfc_predictionResult.probabilities
682-
683-
return pred
682+
# TODO: fix probabilities out of [0, 1] interval on oneDAL side
683+
return pred.clip(0.0, 1.0)
684684

685685
def _daal_fit_classifier(self, X, y, sample_weight=None):
686686
y = check_array(y, ensure_2d=False, dtype=None)

daal4py/sklearn/manifold/_t_sne.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,13 @@ def _daal_tsne(self, P, n_samples, X_embedded):
6666
[n_samples],
6767
[P.nnz],
6868
[self.n_iter_without_progress],
69-
[self._max_iter if sklearn_check_version("1.5") else self.n_iter],
69+
[
70+
(
71+
self.max_iter
72+
if sklearn_check_version("1.7")
73+
else (self._max_iter if sklearn_check_version("1.5") else self.n_iter)
74+
)
75+
],
7076
]
7177

7278
# Pass params to daal4py backend

daal4py/sklearn/metrics/_pairwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from functools import partial
1919

2020
import numpy as np
21+
from joblib import effective_n_jobs
2122
from sklearn.exceptions import DataConversionWarning
2223
from sklearn.metrics import pairwise_distances as pairwise_distances_original
2324
from sklearn.metrics.pairwise import (
@@ -28,7 +29,6 @@
2829
_parallel_pairwise,
2930
check_pairwise_arrays,
3031
)
31-
from sklearn.utils._joblib import effective_n_jobs
3232
from sklearn.utils.validation import check_non_negative
3333

3434
try:

onedal/ensemble/forest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,9 @@ def _predict_proba(self, X, hparams=None):
424424
else:
425425
result = self.infer(params, model, X)
426426

427-
return from_table(result.probabilities)
427+
# TODO: fix probabilities out of [0, 1] interval on oneDAL side
428+
pred = from_table(result.probabilities)
429+
return pred.clip(0.0, 1.0)
428430

429431

430432
class RandomForestClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta):

onedal/utils/_array_api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,12 @@ def _asarray(data, xp, *args, **kwargs):
6868

6969
def _is_numpy_namespace(xp):
7070
"""Return True if xp is backed by NumPy."""
71-
return xp.__name__ in {"numpy", "array_api_compat.numpy", "numpy.array_api"}
71+
return xp.__name__ in {
72+
"numpy",
73+
"array_api_compat.numpy",
74+
"numpy.array_api",
75+
"sklearn.externals.array_api_compat.numpy",
76+
}
7277

7378

7479
def _get_sycl_namespace(*arrays):

0 commit comments

Comments
 (0)