Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enhances `NeuralForecastRNN` to interpret `freq` from `ForecastingHorizon` when passed as `"auto"` <!-- Welcome to sktime, and thanks for contributing! Please have a look at our contribution guide: https://www.sktime.net/en/latest/get_involved/contributing.html --> #### Reference Issues/PRs <!-- Example: Fixes #1234. See also #3456. Please use keywords (e.g., Fixes) to create link to the issues or pull requests you resolved, so that they will automatically be closed when your pull request is merged. See https://github.com/blog/1506-closing-issues-via-pull-requests. If no issue exists, you can open one here: https://github.com/sktime/sktime/issues --> Fixes #6003. #### What does this implement/fix? Explain your changes. <!-- A clear and concise description of what you have implemented. --> The `NeuralForecastRNN` constructor previously required a `freq` argument, which is now proposed to default to `"auto"` in which case it interprets `freq` from `ForecastingHorizon`, leveraging `fh.freq` in the `fit` method. #### What should a reviewer concentrate their feedback on? <!-- This section is particularly useful if you have a pull request that is still in development. You can guide the reviews to focus on the parts that are ready for their comments. We suggest using bullets (indicated by * or -) and filled checkboxes [x] here --> I have run the tests with the updated estimator ```py results = check_estimator(NeuralForecastRNN) # All tests PASSED! ``` `freq` can now be passed like this: ```py y, X = load_longley() y_train, y_test, X_train, X_test = temporal_train_test_split(y, X, test_size=4) model = NeuralForecastRNN( "auto", # interprets to be "A-DEC" futr_exog_list=["ARMED", "POP"], max_steps=5) model.fit(y_train, X=X_train, fh=[1, 2, 3, 4]) model.predict(X=X_test) # Seed set to 1 # 1959 66241.984375 # 1960 66700.132812 # 1961 66550.195312 # 1962 67310.007812 # Freq: A-DEC, Name: TOTEMP, dtype: float64 ```
- Loading branch information
Showing
4 changed files
with
138 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters