
Commit

benheid [BUG] allow alpha and coverage to be passed again via metrics to `evaluate` (#5354)

This PR ensures that the pre-existing syntax for passing `alpha` and `coverage` via
metrics to `evaluate` works again, fixing #5336.

I am not commenting here on whether the status quo is a good idea (I think removing
it was, or is, cleaner in the long run), but such a change should not happen without
deprecation.

Depends on #5337, so this change
should trigger the test that is failing on `main`.
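
For readers who want to see the restored call pattern in context, here is a minimal sketch. The dataset, forecaster, splitter, and the concrete `alpha`/`coverage` values are illustrative choices, not taken from the PR, and the module paths assume a recent sktime version:

```python
from sktime.datasets import load_airline
from sktime.forecasting.model_evaluation import evaluate
from sktime.forecasting.model_selection import ExpandingWindowSplitter
from sktime.forecasting.naive import NaiveForecaster
from sktime.performance_metrics.forecasting.probabilistic import (
    EmpiricalCoverage,
    PinballLoss,
)

y = load_airline()
cv = ExpandingWindowSplitter(initial_window=24, step_length=12, fh=[1, 2, 3])

# alpha and coverage travel on the metric objects; evaluate reads them off
# and forwards them to predict_quantiles / predict_interval internally
results = evaluate(
    forecaster=NaiveForecaster(strategy="last"),
    cv=cv,
    y=y,
    scoring=[PinballLoss(alpha=0.1), EmpiricalCoverage(coverage=0.9)],
    return_data=True,
)
print(results.columns)
```

The point of the fix is that `alpha` and `coverage` are carried by the metric objects themselves, and `evaluate` picks them up and passes them on to the probabilistic prediction methods.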
fkiraly committed Oct 5, 2023
1 parent aadc8a6 commit c552caf
Showing 1 changed file with 68 additions and 14 deletions.
sktime/forecasting/model_evaluation/_functions.py (82 changes: 68 additions & 14 deletions)
@@ -73,7 +73,7 @@ def _check_scores(metrics) -> Dict:
 
 
 def _get_column_order_and_datatype(
-    metric_types: Dict, return_data: bool = True, cutoff_dtype=None
+    metric_types: Dict, return_data: bool = True, cutoff_dtype=None, old_naming=True
 ) -> Dict:
     """Get the ordered column name and input datatype of results."""
     others_metadata = {
@@ -86,11 +86,21 @@ def _get_column_order_and_datatype(
     }
     fit_metadata, metrics_metadata = {"fit_time": "float"}, {}
     for scitype in metric_types:
-        fit_metadata[f"{scitype}_time"] = "float"
-        if return_data:
-            y_metadata[f"y_{scitype}"] = "object"
         for metric in metric_types.get(scitype):
-            metrics_metadata[f"test_{metric.name}"] = "float"
+            pred_args = _get_pred_args_from_metric(scitype, metric)
+            if pred_args == {} or old_naming:
+                time_key = f"{scitype}_time"
+                result_key = f"test_{metric.name}"
+                y_pred_key = f"y_{scitype}"
+            else:
+                argval = list(pred_args.values())[0]
+                time_key = f"{scitype}_{argval}_time"
+                result_key = f"test_{metric.name}_{argval}"
+                y_pred_key = f"y_{scitype}_{argval}"
+            fit_metadata[time_key] = "float"
+            metrics_metadata[result_key] = "float"
+            if return_data:
+                y_metadata[y_pred_key] = "object"
     fit_metadata.update(others_metadata)
     if return_data:
         fit_metadata.update(y_metadata)
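
As an aside, a small, purely illustrative sketch of the column keys the branch above builds; the metric name `PinballLoss` and the alpha value 0.05 are assumed examples and do not appear in the diff:

```python
# Illustrative only: how the time/result/prediction keys differ between the
# old naming scheme and the new argument-disambiguated scheme.
scitype, metric_name, argval = "pred_quantiles", "PinballLoss", 0.05

old_keys = (f"{scitype}_time", f"test_{metric_name}", f"y_{scitype}")
new_keys = (
    f"{scitype}_{argval}_time",
    f"test_{metric_name}_{argval}",
    f"y_{scitype}_{argval}",
)

print(old_keys)  # ('pred_quantiles_time', 'test_PinballLoss', 'y_pred_quantiles')
print(new_keys)  # ('pred_quantiles_0.05_time', 'test_PinballLoss_0.05', 'y_pred_quantiles_0.05')
```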
@@ -160,6 +170,18 @@ def _select_fh_from_y(y):
     return fh
 
 
+def _get_pred_args_from_metric(scitype, metric):
+    pred_args = {
+        "pred_quantiles": "alpha",
+        "pred_interval": "coverage",
+    }
+    if scitype in pred_args.keys():
+        val = getattr(metric, pred_args[scitype], None)
+        if val is not None:
+            return {pred_args[scitype]: val}
+    return {}
+
+
 def _evaluate_window(
     y_train,
     y_test,
@@ -181,7 +203,9 @@ def _evaluate_window(
     cutoff = pd.Period(pd.NaT) if cutoff_dtype.startswith("period") else pd.NA
     y_pred = pd.NA
     temp_result = dict()
-
+    y_preds_cache = dict()
+    old_naming = True
+    old_name_mapping = {}
     if fh is None:
         fh = _select_fh_from_y(y_test)
 
@@ -206,15 +230,40 @@ def _evaluate_window(
         # cache prediction from the first scitype and reuse it to compute other metrics
         for scitype in scoring:
             method = getattr(forecaster, pred_type[scitype])
-            start_pred = time.perf_counter()
-            y_pred = method(fh, X_test)
-            pred_time = time.perf_counter() - start_pred
-            temp_result[f"{scitype}_time"] = [pred_time]
+            if len(set(map(lambda metric: metric.name, scoring.get(scitype)))) != len(
+                scoring.get(scitype)
+            ):
+                old_naming = False
             for metric in scoring.get(scitype):
+                pred_args = _get_pred_args_from_metric(scitype, metric)
+                if pred_args == {}:
+                    time_key = f"{scitype}_time"
+                    result_key = f"test_{metric.name}"
+                    y_pred_key = f"y_{scitype}"
+                else:
+                    argval = list(pred_args.values())[0]
+                    time_key = f"{scitype}_{argval}_time"
+                    result_key = f"test_{metric.name}_{argval}"
+                    y_pred_key = f"y_{scitype}_{argval}"
+                    old_name_mapping[f"{scitype}_{argval}_time"] = f"{scitype}_time"
+                    old_name_mapping[
+                        f"test_{metric.name}_{argval}"
+                    ] = f"test_{metric.name}"
+                    old_name_mapping[f"y_{scitype}_{argval}"] = f"y_{scitype}"
+
+                # make prediction
+                if y_pred_key not in y_preds_cache.keys():
+                    start_pred = time.perf_counter()
+                    y_pred = method(fh, X_test, **pred_args)
+                    pred_time = time.perf_counter() - start_pred
+                    temp_result[time_key] = [pred_time]
+                    y_preds_cache[y_pred_key] = [y_pred]
+                else:
+                    y_pred = y_preds_cache[y_pred_key][0]
+
                 score = metric(y_test, y_pred, y_train=y_train)
-                temp_result[f"test_{metric.name}"] = [score]
-            if return_data:
-                temp_result[f"y_{scitype}"] = [y_pred]
+                temp_result[result_key] = [score]
+
         # get cutoff
         cutoff = forecaster.cutoff
 
@@ -254,9 +303,14 @@ def _evaluate_window(
     if return_data:
         temp_result["y_train"] = [y_train]
         temp_result["y_test"] = [y_test]
+        temp_result.update(y_preds_cache)
     result = pd.DataFrame(temp_result)
     result = result.astype({"len_train_window": int, "cutoff": cutoff_dtype})
-    column_order = _get_column_order_and_datatype(scoring, return_data, cutoff_dtype)
+    if old_naming:
+        result = result.rename(columns=old_name_mapping)
+    column_order = _get_column_order_and_datatype(
+        scoring, return_data, cutoff_dtype, old_naming=old_naming
+    )
     result = result.reindex(columns=column_order.keys())
 
     # Return forecaster if "update"
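
To make the new helper concrete, here is a small, self-contained sketch of what `_get_pred_args_from_metric` returns. The function body is copied from the commit above, while the stand-in metric objects are plain namespaces rather than sktime metrics:

```python
from types import SimpleNamespace


def _get_pred_args_from_metric(scitype, metric):
    # body copied from the commit above, reproduced here for illustration
    pred_args = {
        "pred_quantiles": "alpha",
        "pred_interval": "coverage",
    }
    if scitype in pred_args.keys():
        val = getattr(metric, pred_args[scitype], None)
        if val is not None:
            return {pred_args[scitype]: val}
    return {}


# stand-ins for metric objects carrying alpha / coverage attributes
quantile_metric = SimpleNamespace(alpha=0.1)
interval_metric = SimpleNamespace(coverage=0.9)
point_metric = SimpleNamespace()

print(_get_pred_args_from_metric("pred_quantiles", quantile_metric))  # {'alpha': 0.1}
print(_get_pred_args_from_metric("pred_interval", interval_metric))   # {'coverage': 0.9}
print(_get_pred_args_from_metric("pred", point_metric))               # {}
```

The returned dict is splatted into the prediction call, so a quantile metric carrying `alpha=0.1` leads to `predict_quantiles(fh, X_test, alpha=0.1)`, and predictions are cached under `y_pred_key` so that metrics sharing a scitype and argument value reuse a single forecast.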
