diff --git a/skpro/regression/base/_base.py b/skpro/regression/base/_base.py index fd8d0c53..00e62a28 100644 --- a/skpro/regression/base/_base.py +++ b/skpro/regression/base/_base.py @@ -228,8 +228,16 @@ def _predict_proba(self, X): if not can_do_proba: raise NotImplementedError - # if any of the above are implemented, predict_var will have a default - # we use predict_var to get scale, and predict to get location + # defaulting logic is as follows: + # var direct deputies are proba, then interval + # proba direct deputy is var (via Normal dist) + # quantiles direct deputies are interval, then proba + # interval direct deputy is quantiles + # + # so, conditions for defaulting for proba is: + # default to var if any of the other three are implemented + + # we use predict_var to get scale, and predict to get location pred_var = self.predict_var(X=X) pred_std = np.sqrt(pred_var) pred_mean = self.predict(X=X) @@ -316,11 +324,21 @@ def _predict_interval(self, X, coverage): """ implements_quantiles = self._has_implementation_of("_predict_quantiles") implements_proba = self._has_implementation_of("_predict_proba") - can_do_proba = implements_quantiles or implements_proba + implements_var = self._has_implementation_of("_predict_var") + can_do_proba = implements_quantiles or implements_proba or implements_var if not can_do_proba: raise NotImplementedError + # defaulting logic is as follows: + # var direct deputies are proba, then interval + # proba direct deputy is var (via Normal dist) + # quantiles direct deputies are interval, then proba + # interval direct deputy is quantiles + # + # so, conditions for defaulting for interval are: + # default to quantiles if any of the other three methods are implemented + # we default to _predict_quantiles if that is implemented or _predict_proba # since _predict_quantiles will default to _predict_proba if it is not alphas = [] @@ -412,11 +430,22 @@ def _predict_quantiles(self, X, alpha): """ implements_interval = self._has_implementation_of("_predict_interval") implements_proba = self._has_implementation_of("_predict_proba") - can_do_proba = implements_interval or implements_proba + implements_var = self._has_implementation_of("_predict_var") + can_do_proba = implements_interval or implements_proba or implements_var if not can_do_proba: raise NotImplementedError + # defaulting logic is as follows: + # var direct deputies are proba, then interval + # proba direct deputy is var (via Normal dist) + # quantiles direct deputies are interval, then proba + # interval direct deputy is quantiles + # + # so, conditions for defaulting for quantiles are: + # 1. default to interval if interval implemented + # 2. default to proba if proba or var are implemented + if implements_interval: pred_int = pd.DataFrame() for a in alpha: @@ -448,7 +477,7 @@ def _predict_quantiles(self, X, alpha): int_idx = pd.MultiIndex.from_product([var_names, alpha]) pred_int.columns = int_idx - elif implements_proba: + elif implements_proba or implements_var: pred_proba = self.predict_proba(X=X) pred_int = pred_proba.quantile(alpha=alpha) @@ -517,6 +546,16 @@ def _predict_var(self, X): if not can_do_proba: raise NotImplementedError + # defaulting logic is as follows: + # var direct deputies are proba, then interval + # proba direct deputy is var (via Normal dist) + # quantiles direct deputies are interval, then proba + # interval direct deputy is quantiles + # + # so, conditions for defaulting for var are: + # 1. default to proba if proba implemented + # 2. default to interval if interval or quantiles are implemented + if implements_proba: pred_proba = self._predict_proba(X=X) pred_var = pred_proba.var()