Skip to content

Commit

Permalink
[ENH] BaseForecastingErrorMetric internal interface cleanup (#4305)
Browse files Browse the repository at this point in the history
This PR cleans up some internal interfaces of `BaseForecastingErrorMetric`, by putting output format conversion and coercion concerns into a single blocks before return, in the methods `evaluate` and `_evaluate_vectorized`.
  • Loading branch information
fkiraly committed Mar 14, 2023
1 parent 7ced9d5 commit b0abe56
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions sktime/performance_metrics/forecasting/_classes.py
Expand Up @@ -84,6 +84,13 @@ def _coerce_to_df(obj):
return pd.DataFrame(obj)


def _coerce_to_1d_numpy(obj):
"""Coerce to 1D np.ndarray, from pd.DataFrame or pd.Series."""
if isinstance(obj, (pd.DataFrame, pd.Series)):
obj = obj.values
return obj.flatten()


class BaseForecastingErrorMetric(BaseMetric):
"""Base class for defining forecasting error metrics in sktime.
Expand Down Expand Up @@ -207,12 +214,9 @@ def evaluate(self, y_true, y_pred, **kwargs):
out_df = self._evaluate_vectorized(
y_true=y_true_inner, y_pred=y_pred_inner, **kwargs
)
if multilevel == "uniform_average":
out_df = out_df.mean(axis=0)
# if level is averaged, but not variables, return numpy
if multioutput == "raw_values":
out_df = out_df.values

if multilevel == "uniform_average" and multioutput == "raw_values":
out_df = _coerce_to_1d_numpy(out_df)
if multilevel == "uniform_average" and multioutput == "uniform_average":
out_df = _coerce_to_scalar(out_df)
if multilevel == "raw_values":
Expand Down Expand Up @@ -278,13 +282,16 @@ def _evaluate_vectorized(self, y_true, y_pred, **kwargs):
)

if self.multioutput == "raw_values":
return pd.DataFrame(
eval_result = pd.DataFrame(
eval_result.iloc[:, 0].to_list(),
index=eval_result.index,
columns=y_true.X.columns,
)
else:
return eval_result

if self.multilevel == "uniform_average":
eval_result = eval_result.mean(axis=0)

return eval_result

def evaluate_by_index(self, y_true, y_pred, **kwargs):
"""Return the metric evaluated at each time point.
Expand Down

0 comments on commit b0abe56

Please sign in to comment.