Skip to content

Commit

Permalink
Update sklearn api (#60)
Browse files Browse the repository at this point in the history
* update interface in air_pls

* update airpls

* update arpls api

* fix documentation in index_selector

* fix typo in index_selector

* fix api in constant baseline correction

* update cubic splines

* update api in linear correction

* update init for linear correction

* fix api in non negative

* fix api in polynomial and in subtract reference

* fix norris williams

* fix api in savitzky golay

* fix minmax scaler

* fix norm scaler

* fix point scaler

* fix emsc

* update msc

* update rnv

* update snv

* update mean filter

* update median filter

* update savgol filter

* update whittaker
  • Loading branch information
paucablop committed Nov 22, 2023
1 parent 84a6b62 commit 5cb0ff0
Show file tree
Hide file tree
Showing 27 changed files with 122 additions and 376 deletions.
16 changes: 8 additions & 8 deletions chemotools/baseline/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .air_pls import AirPls
from .ar_pls import ArPls
from .constant_baseline_correction import ConstantBaselineCorrection
from .cubic_spline_correction import CubicSplineCorrection
from .linear_correction import LinearCorrection
from .non_negative import NonNegative
from .polynomial_correction import PolynomialCorrection
from .subtract_reference import SubtractReference
from ._air_pls import AirPls
from ._ar_pls import ArPls
from ._constant_baseline_correction import ConstantBaselineCorrection
from ._cubic_spline_correction import CubicSplineCorrection
from ._linear_correction import LinearCorrection
from ._non_negative import NonNegative
from ._polynomial_correction import PolynomialCorrection
from ._subtract_reference import SubtractReference
18 changes: 2 additions & 16 deletions chemotools/baseline/air_pls.py → chemotools/baseline/_air_pls.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,6 @@ class AirPls(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
The number of iterations used to calculate the baseline. Increasing the number of iterations can improve the
accuracy of the baseline correction, but also increases the computation time.
Attributes
----------
n_features_in_ : int
The number of features in the input data.
_is_fitted : bool
A flag indicating whether the estimator has been fitted to data.
Methods
-------
fit(X, y=None)
Expand Down Expand Up @@ -85,13 +77,7 @@ def fit(self, X: np.ndarray, y=None) -> "AirPls":
Returns the instance itself.
"""
# Check that X is a 2D array and has only finite values
X = check_input(X)

# Set the number of features
self.n_features_in_ = X.shape[1]

# Set the fitted attribute to True
self._is_fitted = True
X = self._validate_data(X)

return self

Expand All @@ -113,7 +99,7 @@ def transform(self, X: np.ndarray, y=None) -> np.ndarray:
"""

# Check that the estimator is fitted
check_is_fitted(self, "_is_fitted")
check_is_fitted(self, "n_features_in_")

# Check that X is a 2D array and has only finite values
X = check_input(X)
Expand Down
17 changes: 2 additions & 15 deletions chemotools/baseline/ar_pls.py → chemotools/baseline/_ar_pls.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,6 @@ class ArPls(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
nr_iterations : int, optional (default=100)
The maximum number of iterations for the weight updating scheme.
Attributes
----------
n_features_in_ : int
The number of input features.
_is_fitted : bool
Whether the estimator has been fitted.
Methods
-------
Expand Down Expand Up @@ -86,13 +79,7 @@ def fit(self, X: np.ndarray, y=None) -> "ArPls":
"""

# Check that X is a 2D array and has only finite values
X = check_input(X)

# Set the number of features
self.n_features_in_ = X.shape[1]

# Set the fitted attribute to True
self._is_fitted = True
X = self._validate_data(X)

return self

Expand All @@ -114,7 +101,7 @@ def transform(self, X: np.ndarray, y=None) -> np.ndarray:
"""

# Check that the estimator is fitted
check_is_fitted(self, "_is_fitted")
check_is_fitted(self, "n_features_in_")

# Check that X is a 2D array and has only finite values
X = check_input(X)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@ class ConstantBaselineCorrection(OneToOneFeatureMixin, BaseEstimator, Transforme
end_index_ : int
The index of the end of the range. It is 1 if the wavenumbers are not provided.
n_features_in_ : int
The number of features in the input data.
_is_fitted : bool
Whether the transformer has been fitted to data.
Methods
-------
fit(X, y=None)
Expand All @@ -46,7 +40,10 @@ class ConstantBaselineCorrection(OneToOneFeatureMixin, BaseEstimator, Transforme
"""

def __init__(
self, start: int = 0, end: int = 1, wavenumbers: np.ndarray = None,
self,
start: int = 0,
end: int = 1,
wavenumbers: np.ndarray = None,
) -> None:
self.start = start
self.end = end
Expand All @@ -70,13 +67,7 @@ def fit(self, X: np.ndarray, y=None) -> "ConstantBaselineCorrection":
The fitted transformer.
"""
# Check that X is a 2D array and has only finite values
X = check_input(X)

# Set the number of features
self.n_features_in_ = X.shape[1]

# Set the fitted attribute to True
self._is_fitted = True
X = self._validate_data(X)

# Set the start and end indices
if self.wavenumbers is None:
Expand Down Expand Up @@ -109,7 +100,7 @@ def transform(self, X: np.ndarray, y=0, copy=True) -> np.ndarray:
The transformed input data.
"""
# Check that the estimator is fitted
check_is_fitted(self, "_is_fitted")
check_is_fitted(self, ["start_index_", "end_index_"])

# Check that X is a 2D array and has only finite values
X = check_input(X)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from chemotools.utils.check_inputs import check_input


class CubicSplineCorrection(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
"""
A transformer that corrects a baseline by subtracting a cubic spline through the
A transformer that corrects a baseline by subtracting a cubic spline through the
points defined by the indices.
Parameters
Expand All @@ -32,6 +33,7 @@ class CubicSplineCorrection(OneToOneFeatureMixin, BaseEstimator, TransformerMixi
Transform the input data by subtracting the cubic spline baseline.
"""

def __init__(self, indices: list = None) -> None:
self.indices = indices

Expand All @@ -53,13 +55,7 @@ def fit(self, X: np.ndarray, y=None) -> "CubicSplineCorrection":
The fitted transformer.
"""
# Check that X is a 2D array and has only finite values
X = check_input(X)

# Set the number of features
self.n_features_in_ = X.shape[1]

# Set the fitted attribute to True
self._is_fitted = True
X = self._validate_data(X)

if self.indices is None:
self.indices_ = [0, len(X[0]) - 1]
Expand Down Expand Up @@ -89,15 +85,17 @@ def transform(self, X: np.ndarray, y=None, copy=True):
The transformed data.
"""
# Check that the estimator is fitted
check_is_fitted(self, "_is_fitted")
check_is_fitted(self, "indices_")

# Check that X is a 2D array and has only finite values
X = check_input(X)
X_ = X.copy()

# Check that the number of features is the same as the fitted data
if X_.shape[1] != self.n_features_in_:
raise ValueError(f"Expected {self.n_features_in_} features but got {X_.shape[1]}")
raise ValueError(
f"Expected {self.n_features_in_} features but got {X_.shape[1]}"
)

# Calculate spline baseline correction
for i, x in enumerate(X_):
Expand All @@ -106,7 +104,7 @@ def transform(self, X: np.ndarray, y=None, copy=True):

def _spline_baseline_correct(self, x: np.ndarray) -> np.ndarray:
indices = self.indices_
intensity = x[indices]
intensity = x[indices]
spl = CubicSpline(indices, intensity)
baseline = spl(range(len(x)))
return x - baseline
baseline = spl(range(len(x)))
return x - baseline
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,6 @@ class LinearCorrection(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
A transformer that corrects a baseline by subtracting a linear baseline through the
initial and final points of the spectrum.
Parameters
----------
Attributes
----------
n_features_in_ : int
The number of features in the input data.
_is_fitted : bool
Whether the transformer has been fitted to data.
Methods
-------
fit(X, y=None)
Expand Down Expand Up @@ -68,13 +57,7 @@ def fit(self, X: np.ndarray, y=None) -> "LinearCorrection":
The fitted transformer.
"""
# Check that X is a 2D array and has only finite values
X = check_input(X)

# Set the number of features
self.n_features_in_ = X.shape[1]

# Set the fitted attribute to True
self._is_fitted = True
X = self._validate_data(X)

return self

Expand All @@ -99,7 +82,7 @@ def transform(self, X: np.ndarray, y=0, copy=True) -> np.ndarray:
The transformed data.
"""
# Check that the estimator is fitted
check_is_fitted(self, "_is_fitted")
check_is_fitted(self, "n_features_in_")

# Check that X is a 2D array and has only finite values
X = check_input(X)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@ class NonNegative(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
mode : str, optional
The mode to use for the non-negative values. Can be "zero" or "abs".
Attributes
----------
n_features_in_ : int
The number of features in the input data.
_is_fitted : bool
Whether the transformer has been fitted to data.
Methods
-------
fit(X, y=None)
Expand Down Expand Up @@ -52,13 +44,7 @@ def fit(self, X: np.ndarray, y=None) -> "NonNegative":
The fitted transformer.
"""
# Check that X is a 2D array and has only finite values
X = check_input(X)

# Set the number of features
self.n_features_in_ = X.shape[1]

# Set the fitted attribute to True
self._is_fitted = True
X = self._validate_data(X)

return self

Expand All @@ -80,7 +66,7 @@ def transform(self, X: np.ndarray, y=None) -> np.ndarray:
The transformed data.
"""
# Check that the estimator is fitted
check_is_fitted(self, "_is_fitted")
check_is_fitted(self, "n_features_in_")

# Check that X is a 2D array and has only finite values
X = check_input(X)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from chemotools.utils.check_inputs import check_input


class PolynomialCorrection(OneToOneFeatureMixin, BaseEstimator, TransformerMixin):
"""
A transformer that subtracts a polynomial baseline from the input data. The polynomial is
A transformer that subtracts a polynomial baseline from the input data. The polynomial is
fitted to the points in the spectrum specified by the indices parameter.
Parameters
Expand All @@ -18,14 +19,6 @@ class PolynomialCorrection(OneToOneFeatureMixin, BaseEstimator, TransformerMixin
The indices of the points in the spectrum to fit the polynomial to. Defaults to None,
which fits the polynomial to all points in the spectrum (equivalent to detrend).
Attributes
----------
n_features_in_ : int
The number of features in the input data.
_is_fitted : bool
Whether the transformer has been fitted to data.
Methods
-------
fit(X, y=None)
Expand All @@ -37,6 +30,7 @@ class PolynomialCorrection(OneToOneFeatureMixin, BaseEstimator, TransformerMixin
_baseline_correct_spectrum(x)
Subtract the polynomial baseline from a single spectrum.
"""

def __init__(self, order: int = 1, indices: list = None) -> None:
self.order = order
self.indices = indices
Expand All @@ -59,22 +53,16 @@ def fit(self, X: np.ndarray, y=None) -> "PolynomialCorrection":
The fitted transformer.
"""
# Check that X is a 2D array and has only finite values
X = check_input(X)

# Set the number of features
self.n_features_in_ = X.shape[1]

# Set the fitted attribute to True
self._is_fitted = True
X = self._validate_data(X)

if self.indices is None:
self.indices_ = range(0, len(X[0]))
else:
self.indices_ = self.indices

return self
def transform(self, X: np.ndarray, y:int=0, copy:bool=True) -> np.ndarray:

def transform(self, X: np.ndarray, y: int = 0, copy: bool = True) -> np.ndarray:
"""
Transform the input data by subtracting the polynomial baseline.
Expand All @@ -95,21 +83,23 @@ def transform(self, X: np.ndarray, y:int=0, copy:bool=True) -> np.ndarray:
The transformed data.
"""
# Check that the estimator is fitted
check_is_fitted(self, "_is_fitted")
check_is_fitted(self, "indices_")

# Check that X is a 2D array and has only finite values
X = check_input(X)
X_ = X.copy()

# Check that the number of features is the same as the fitted data
if X_.shape[1] != self.n_features_in_:
raise ValueError(f"Expected {self.n_features_in_} features but got {X_.shape[1]}")
raise ValueError(
f"Expected {self.n_features_in_} features but got {X_.shape[1]}"
)

# Calculate polynomial baseline correction
for i, x in enumerate(X_):
X_[i] = self._baseline_correct_spectrum(x)
return X_.reshape(-1, 1) if X_.ndim == 1 else X_

def _baseline_correct_spectrum(self, x: np.ndarray) -> np.ndarray:
"""
Subtract the polynomial baseline from a single spectrum.
Expand All @@ -126,5 +116,5 @@ def _baseline_correct_spectrum(self, x: np.ndarray) -> np.ndarray:
"""
intensity = x[self.indices_]
poly = np.polyfit(self.indices_, intensity, self.order)
baseline = [np.polyval(poly, i) for i in range(0, len(x))]
return x - baseline
baseline = [np.polyval(poly, i) for i in range(0, len(x))]
return x - baseline
Loading

0 comments on commit 5cb0ff0

Please sign in to comment.