-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Create preprocessing.py * add comments to preprocessing.py * [pre-commit.ci] auto fixes from pre-commit.com hooks * use ndscaler in plot_financial_data * typo * missing import * [pre-commit.ci] auto fixes from pre-commit.com hooks * tests * [pre-commit.ci] auto fixes from pre-commit.com hooks * missing import * pytest import not required * api doc * correct docstring * Change NDRobustScaler -> NdRobustScaler * Update pyriemann_qiskit/utils/preprocessing.py Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com> * Update pyriemann_qiskit/utils/preprocessing.py Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com> * fix: syntax error * [pre-commit.ci] auto fixes from pre-commit.com hooks * Update pyriemann_qiskit/utils/preprocessing.py Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com> * Update pyriemann_qiskit/utils/preprocessing.py Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com> * Update pyriemann_qiskit/utils/preprocessing.py Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com> * Update tests/test_preprocessing.py Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
- Loading branch information
1 parent
402bbd2
commit 6a27692
Showing
5 changed files
with
117 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from sklearn.base import TransformerMixin | ||
from sklearn.preprocessing import RobustScaler | ||
|
||
|
||
class NdRobustScaler(TransformerMixin): | ||
"""Apply one robust scaler by feature. | ||
RobustScaler of scikit-learn [1]_ is adapted to 3d inputs [2]_. | ||
References | ||
---------- | ||
.. [1] \ | ||
https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html | ||
.. [2] \ | ||
https://stackoverflow.com/questions/50125844/how-to-standard-scale-a-3d-matrix | ||
Notes | ||
----- | ||
.. versionadded:: 0.2.0 | ||
""" | ||
|
||
def __init__(self): | ||
self._scalers = [] | ||
|
||
"""Fits one robust scaler on each feature of the training data. | ||
Parameters | ||
---------- | ||
X : ndarray, shape (n_matrices, n_features, n_samples) | ||
Training matrices. | ||
_y : ndarray, shape (n_samples,) | ||
Unused. Kept for scikit-learn compatibility. | ||
Returns | ||
------- | ||
self : NdRobustScaler instance | ||
The NdRobustScaler instance. | ||
""" | ||
|
||
def fit(self, X, _y=None, **kwargs): | ||
_, n_features, _ = X.shape | ||
self._scalers = [] | ||
for i in range(n_features): | ||
scaler = RobustScaler().fit(X[:, i, :]) | ||
self._scalers.append(scaler) | ||
return self | ||
|
||
"""Apply the previously trained robust scalers (on scaler by feature) | ||
Parameters | ||
---------- | ||
X : ndarray, shape (n_matrices, n_features, n_samples) | ||
Matrices to scale. | ||
_y : ndarray, shape (n_samples,) | ||
Unused. Kept for scikit-learn compatibility. | ||
Returns | ||
------- | ||
self : NdRobustScaler instance | ||
The NdRobustScaler instance. | ||
""" | ||
|
||
def transform(self, X, **kwargs): | ||
_, n_features, _ = X.shape | ||
if n_features != len(self._scalers): | ||
raise ValueError( | ||
"Input has not the same number of features as the fitted scaler" | ||
) | ||
for i in range(n_features): | ||
X[:, i, :] = self._scalers[i].transform(X[:, i, :]) | ||
return X |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import numpy as np | ||
from pyriemann_qiskit.utils.preprocessing import NdRobustScaler | ||
from sklearn.preprocessing import RobustScaler | ||
|
||
|
||
def test_ndrobustscaler(get_covmats): | ||
n_matrices, n_features = 5, 3 | ||
|
||
X = get_covmats(n_matrices, n_features) | ||
|
||
scaler = NdRobustScaler() | ||
transformed_X = scaler.fit_transform(X) | ||
|
||
assert transformed_X.shape == X.shape | ||
|
||
# Check that each feature is scaled using RobustScaler | ||
for i in range(n_features): | ||
feature_before_scaling = X[:, i, :] | ||
feature_after_scaling = transformed_X[:, i, :] | ||
|
||
# Use RobustScaler to manually scale the feature and compare | ||
manual_scaler = RobustScaler() | ||
manual_scaler.fit(feature_before_scaling) | ||
manual_scaled_feature = manual_scaler.transform(feature_before_scaling) | ||
|
||
np.testing.assert_allclose( | ||
feature_after_scaling, manual_scaled_feature, rtol=1e-5 | ||
) |