Skip to content

Commit b64569e

Browse files
committed
Revert changes in covariance; remove unneeded imports
1 parent 277ac31 commit b64569e

File tree

4 files changed

+7
-26
lines changed

4 files changed

+7
-26
lines changed

sklearnex/decomposition/pca.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444

4545
from sklearn.decomposition import PCA as _sklearn_PCA
4646

47-
from onedal.common.hyperparameters import get_hyperparameters
4847
from onedal.decomposition import PCA as onedal_PCA
4948
from onedal.utils._array_api import _is_numpy_namespace
5049

sklearnex/ensemble/_forest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
from onedal.ensemble import RandomForestRegressor as onedal_RandomForestRegressor
5959
from onedal.primitives import get_tree_state_cls, get_tree_state_reg
6060
from onedal.utils.validation import _num_features, _num_samples
61-
from sklearnex import get_hyperparameters
6261
from sklearnex._utils import register_hyperparameters
6362

6463
from .._config import get_config

sklearnex/linear_model/incremental_linear.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
if sklearn_check_version("1.2"):
3737
from sklearn.utils._param_validation import Interval
3838

39-
from onedal.common.hyperparameters import get_hyperparameters
40-
4139
from .._device_offload import dispatch, wrap_output_data
4240
from .._utils import (
4341
PatchingConditionsChain,

sklearnex/preview/covariance/covariance.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,9 @@
2020
from scipy import sparse as sp
2121
from sklearn.covariance import EmpiricalCovariance as _sklearn_EmpiricalCovariance
2222
from sklearn.utils import check_array
23-
from sklearn.utils.validation import _num_features
2423

2524
from daal4py.sklearn._n_jobs_support import control_n_jobs
2625
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
27-
from onedal.common.hyperparameters import get_hyperparameters
2826
from onedal.covariance import EmpiricalCovariance as onedal_EmpiricalCovariance
2927
from sklearnex import config_context
3028
from sklearnex.metrics import pairwise_distances
@@ -91,15 +89,11 @@ def _onedal_supported(self, method_name, *data):
9189
_onedal_gpu_supported = _onedal_supported
9290

9391
def fit(self, X, y=None):
94-
self._fit(X)
95-
return self
96-
97-
@wrap_output_data
98-
def _fit(self, X):
9992
if sklearn_check_version("1.2"):
10093
self._validate_params()
94+
X = validate_data(self, X, ensure_all_finite=False)
10195

102-
return dispatch(
96+
dispatch(
10397
self,
10498
"fit",
10599
{
@@ -109,6 +103,8 @@ def _fit(self, X):
109103
X,
110104
)
111105

106+
return self
107+
112108
# expose sklearnex pairwise_distances if mahalanobis distance eventually supported
113109
@wrap_output_data
114110
def mahalanobis(self, X):
@@ -117,20 +113,9 @@ def mahalanobis(self, X):
117113
precision = self.get_precision()
118114
with config_context(assume_finite=True):
119115
# compute mahalanobis distances
120-
try:
121-
dist = pairwise_distances(
122-
X, self.location_[np.newaxis, :], metric="mahalanobis", VI=precision
123-
)
124-
125-
except ValueError as e:
126-
# Throw the expected sklearn error in an n_feature length violation
127-
if "Incompatible dimension for X and Y matrices:" in str(e):
128-
raise ValueError(
129-
f"X has {_num_features(X)} features, but {self.__class__.__name__} "
130-
f"is expecting {self.n_features_in_} features as input."
131-
)
132-
else:
133-
raise e
116+
dist = pairwise_distances(
117+
X, self.location_[np.newaxis, :], metric="mahalanobis", VI=precision
118+
)
134119

135120
return np.reshape(dist, (len(X),)) ** 2
136121

0 commit comments

Comments
 (0)