Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Fix ARCH._check_predict_proba #5349 #5384

Merged
merged 5 commits into from Oct 11, 2023
Merged

Conversation

Vasudeva-bit
Copy link
Contributor

@Vasudeva-bit Vasudeva-bit commented Oct 8, 2023

Reference Issues/PRs

Fixes #5349

What does this implement/fix? Explain your changes.

Changes a if (check) condition in _check_predict_proba function of sktime/forecasting/tests/test_all_forecasters.py to preserve column name in univariate forecaster predictions.

Upon fixing this issue, ARCH doesn't require function _predict_proba as BaseForecaster's _predict_proba is sufficient. Hence, removed that function.

What should a reviewer concentrate their feedback on?

A few tests are expecting name to be preserved even if it is None, here. Thats why I didn't handle the None case separately. If the expected column name is 0 when original name is None, below code may be appropriate as recommended by @yarnabrina.

if isinstance(y_train, pd.Series) and y_train.name:
    assert (pred_cols == y_train.name).all()
elif isinstance(y_train, pd.Series):  # automatically implies empty name
    assert (pred_cols == pd.Index([0])).all()
else:
    assert (pred_cols == y_train.columns).all()

Additional context

Any changes in this test will affect existing univariate probabilistic forecasters only such as NaiveForecaster.

PR checklist

For all contributions
  • The PR title starts with either [ENH], [MNT], [DOC], or [BUG]. [BUG] - bugfix, [MNT] - CI, test framework, [ENH] - adding or improving code, [DOC] - writing or improving documentation or docstrings.

@fkiraly
Copy link
Collaborator

fkiraly commented Oct 8, 2023

I've updated this PR with code along the lines suggested. I will turn on tests for all forecasters so we can see what happens.

@fkiraly fkiraly added module:forecasting forecasting module: forecasting, incl probabilistic and hierarchical forecasting do not merge should not be merged - e.g., CI diagnostic, experimental diagnostics diagnostic PR to run CI with a modification, e.g., pre-release of dependencies labels Oct 8, 2023
@fkiraly
Copy link
Collaborator

fkiraly commented Oct 8, 2023

Ok, I'm confused now - nothing else seems to fail. Any explanations?

@Vasudeva-bit
Copy link
Contributor Author

After debugging for a while,

  1. Firstly, the function _check_predict_proba is only testing for y_train generated from _make_data(n_columns=n_columns). If the forecaster is univariate, then the y_train generated would always have column name None of type str. The _check_predict_proba checks whether the _predict_proba is changing None to 0 internally. Since every forecaster's _predict_proba doing it, all forecasters passed.
  2. Actually ARCH should fail this test, because _predict_proba function is deleted in this PR. But, the type(y_train.name) is str ('None'), not None. Therefore, None case is skipped and checks whether (pred_cols == y_train.name).all() which is None of str equals None of str so, ARCH also passed.

I think it's my bad, there is no bug in _check_predict_proba as it is not designed to check whether column names are preserved. Changes should only be in ARCH's _predict_proba like below:

  1. If the column name is None of type None/str, do pred_dist.name = pd.Index([0]).
  2. Otherwise, just return output of Super()._predict_proba() as it is, i.e., column names are preserved by default.

Hope it makes some sense now.

@fkiraly fkiraly added bugfix Fixes a known bug or removes unintended behavior and removed do not merge should not be merged - e.g., CI diagnostic, experimental diagnostics diagnostic PR to run CI with a modification, e.g., pre-release of dependencies labels Oct 11, 2023
Copy link
Collaborator

@fkiraly fkiraly left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, so the issue was after all isolated to uarch - makes sense. Lets merge this as a fix then.

@fkiraly fkiraly merged commit c088c6f into sktime:main Oct 11, 2023
24 checks passed
@fkiraly fkiraly changed the title [BUG] Expected fix in _check_predict_proba #5349 [BUG] Fix ARCH._check_predict_proba #5349 Oct 11, 2023
@Vasudeva-bit Vasudeva-bit deleted the proba_test branch October 12, 2023 03:27
fkiraly pushed a commit that referenced this pull request Dec 13, 2023
Related to #5349 (issue), #5384 (PR).


> After debugging for a while,
> 
> 1. Firstly, the function `_check_predict_proba` is only testing for
`y_train` generated from `_make_data(n_columns=n_columns)`. If the
forecaster is univariate, then the `y_train` generated would always have
column name `None` of type `str`. The `_check_predict_proba` checks
whether the `_predict_proba` is changing `None` to `0` internally. Since
every forecaster's `_predict_proba` doing it, all forecasters passed.
> 2. Actually `ARCH` should fail this test, because `_predict_proba`
function is deleted in this PR. But, the `type(y_train.name)` is `str`
(`'None'`), not `None`. Therefore, `None` case is skipped and checks
whether `(pred_cols == y_train.name).all()` which is `None` of `str`
equals `None` of `str` so, `ARCH` also passed.
> 
> I think it's my bad, there is no bug in `_check_predict_proba` as it
is not designed to check whether column names are preserved. Changes
should only be in `ARCH`'s `_predict_proba` like below:
> 
> 1. If the column name is `None` of type `None`/`str`, do
`pred_dist.name = pd.Index([0])`.
> 2. Otherwise, just return output of `Super()._predict_proba()` as it
is, i.e., column names are preserved by default.

I think there was some ambiguity in my explanation for #5349 (issue) in
#5384 (PR), but what I meant was exactly the changes in this PR.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bugfix Fixes a known bug or removes unintended behavior module:forecasting forecasting module: forecasting, incl probabilistic and hierarchical forecasting
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] test_predict_proba is expecting column name of pred_dist to be 0 always.
2 participants