Skip to content

Commit

Permalink
[BUG] fix defaulting logic for _predict_interval and `_predict_quan…
Browse files Browse the repository at this point in the history
…tiles` when only `_predict_var` is implemented (#191)

This PR fixes a unreported bug in `BaseProbaRegressor` where
`_predict_interval` and `_predict_quantiles` would raise an
`NotImplementedError` when only `_predict_var` is implemented - despite
current defaulting logic being sufficient to obtain (normal) quantiles
and symmetric (normal) intervals from `_predict` and `_predict_var`.

This issue is present with the `sklearn` adapter and delegate
descendants, but was not uncovered due to masking by the test bug in
#189.

The failure is present in test logs of the fix
#189, and #189 is used for testing
(by merge into).
  • Loading branch information
fkiraly committed Jan 30, 2024
1 parent 3d56f69 commit a5bb1cf
Showing 1 changed file with 44 additions and 5 deletions.
49 changes: 44 additions & 5 deletions skpro/regression/base/_base.py
Expand Up @@ -228,8 +228,16 @@ def _predict_proba(self, X):
if not can_do_proba:
raise NotImplementedError

# if any of the above are implemented, predict_var will have a default
# we use predict_var to get scale, and predict to get location
# defaulting logic is as follows:
# var direct deputies are proba, then interval
# proba direct deputy is var (via Normal dist)
# quantiles direct deputies are interval, then proba
# interval direct deputy is quantiles
#
# so, conditions for defaulting for proba is:
# default to var if any of the other three are implemented

# we use predict_var to get scale, and predict to get location
pred_var = self.predict_var(X=X)
pred_std = np.sqrt(pred_var)
pred_mean = self.predict(X=X)
Expand Down Expand Up @@ -316,11 +324,21 @@ def _predict_interval(self, X, coverage):
"""
implements_quantiles = self._has_implementation_of("_predict_quantiles")
implements_proba = self._has_implementation_of("_predict_proba")
can_do_proba = implements_quantiles or implements_proba
implements_var = self._has_implementation_of("_predict_var")
can_do_proba = implements_quantiles or implements_proba or implements_var

if not can_do_proba:
raise NotImplementedError

# defaulting logic is as follows:
# var direct deputies are proba, then interval
# proba direct deputy is var (via Normal dist)
# quantiles direct deputies are interval, then proba
# interval direct deputy is quantiles
#
# so, conditions for defaulting for interval are:
# default to quantiles if any of the other three methods are implemented

# we default to _predict_quantiles if that is implemented or _predict_proba
# since _predict_quantiles will default to _predict_proba if it is not
alphas = []
Expand Down Expand Up @@ -412,11 +430,22 @@ def _predict_quantiles(self, X, alpha):
"""
implements_interval = self._has_implementation_of("_predict_interval")
implements_proba = self._has_implementation_of("_predict_proba")
can_do_proba = implements_interval or implements_proba
implements_var = self._has_implementation_of("_predict_var")
can_do_proba = implements_interval or implements_proba or implements_var

if not can_do_proba:
raise NotImplementedError

# defaulting logic is as follows:
# var direct deputies are proba, then interval
# proba direct deputy is var (via Normal dist)
# quantiles direct deputies are interval, then proba
# interval direct deputy is quantiles
#
# so, conditions for defaulting for quantiles are:
# 1. default to interval if interval implemented
# 2. default to proba if proba or var are implemented

if implements_interval:
pred_int = pd.DataFrame()
for a in alpha:
Expand Down Expand Up @@ -448,7 +477,7 @@ def _predict_quantiles(self, X, alpha):
int_idx = pd.MultiIndex.from_product([var_names, alpha])
pred_int.columns = int_idx

elif implements_proba:
elif implements_proba or implements_var:
pred_proba = self.predict_proba(X=X)
pred_int = pred_proba.quantile(alpha=alpha)

Expand Down Expand Up @@ -517,6 +546,16 @@ def _predict_var(self, X):
if not can_do_proba:
raise NotImplementedError

# defaulting logic is as follows:
# var direct deputies are proba, then interval
# proba direct deputy is var (via Normal dist)
# quantiles direct deputies are interval, then proba
# interval direct deputy is quantiles
#
# so, conditions for defaulting for var are:
# 1. default to proba if proba implemented
# 2. default to interval if interval or quantiles are implemented

if implements_proba:
pred_proba = self._predict_proba(X=X)
pred_var = pred_proba.var()
Expand Down

0 comments on commit a5bb1cf

Please sign in to comment.