Skip to content

Commit

Permalink
Refact: array check in observables
Browse files Browse the repository at this point in the history
  • Loading branch information
nTrouvain committed May 9, 2022
1 parent 8f5b43e commit c8a6449
Showing 1 changed file with 16 additions and 26 deletions.
42 changes: 16 additions & 26 deletions reservoirpy/observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@
from .type import Weights


def _check_arrays(y_true, y_pred):
y_true_array = np.asarray(y_true)
y_pred_array = np.asarray(y_pred)

if not y_true_array.shape == y_pred_array.shape:
raise ValueError(
f"Shape mismatch between y_true and y_pred:
{y_true_array.shape} != {y_pred_array.shape}"
)

return y_true_array, y_pred_array


def spectral_radius(W: Weights, maxiter: int = None) -> float:
"""Compute the spectral radius of a matrix `W`.
Expand Down Expand Up @@ -98,14 +111,7 @@ def mse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
float
Mean squared error.
"""
y_true_array = np.asarray(y_true)
y_pred_array = np.asarray(y_pred)

if not y_true_array.shape == y_pred_array.shape:
raise ValueError(
f"Shape mismatch between y_true and y_pred:
{y_true_array.shape} != {y_pred_array.shape}"
)
y_true_array, y_pred_array = _check_arrays(y_true, y_pred)
return float(np.mean((y_true_array - y_pred_array) ** 2))


Expand Down Expand Up @@ -165,16 +171,7 @@ def nrmse(
-------
float
Normalized mean squared error.
"""
y_true_array = np.asarray(y_true)
y_pred_array = np.asarray(y_pred)

if not y_true_array.shape == y_pred_array.shape:
raise ValueError(
f"Shape mismatch between y_true and y_pred:
{y_true_array.shape} != {y_pred_array.shape}"
)

"""
error = rmse(y_true, y_pred)
if norm_value is not None:
return error / norm_value
Expand Down Expand Up @@ -218,14 +215,7 @@ def rsquare(y_true: np.ndarray, y_pred: np.ndarray) -> float:
float
Coefficient of determination.
"""
y_true_array = np.asarray(y_true)
y_pred_array = np.asarray(y_pred)

if not y_true_array.shape == y_pred_array.shape:
raise ValueError(
f"Shape mismatch between y_true and y_pred:
{y_true_array.shape} != {y_pred_array.shape}"
)
y_true_array, y_pred_array = _check_arrays(y_true, y_pred)

d = (y_true_array - y_pred_array) ** 2
D = (y_true_array - y_pred_array.mean()) ** 2
Expand Down

0 comments on commit c8a6449

Please sign in to comment.