Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] add length option to _bottom_hier_datagen hierarchical data generator, speed up ReconcilerForecaster doctest #4979

Merged
merged 1 commit into from Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions sktime/forecasting/reconcile.py
Expand Up @@ -72,6 +72,7 @@ class ReconcilerForecaster(BaseForecaster):
... no_bottom_nodes=3,
... no_levels=1,
... random_seed=123,
... length=7,
... )
>>> y = agg.fit_transform(y)
>>> forecaster = NaiveForecaster(strategy="drift")
Expand Down
14 changes: 8 additions & 6 deletions sktime/utils/_testing/hierarchical.py
Expand Up @@ -105,19 +105,19 @@ def _bottom_hier_datagen(
coef_1_max=20,
coef_2_max=0.1,
random_seed=None,
length=144,
):
"""Hierarchical data generator using the flights dataset.
"""Hierarchical data generator using the airline dataset.

This function generates bottom level, i.e. not aggregated, time-series
from the flights dataset.
from the airline dataset (sktime.datasets.load_airline).

Each series is generated from the flights dataset using a linear model,
Each series is generated from the airline dataset using a linear model,
y = c0 + c1x + c2x^(c3), where the coefficients, intercept, and exponent
are randomly sampled for each series. The coefficients and intercept are
sampled between np.arange(0, *_max, 0.01) to keep the values positive. The
exponent is sampled from [0.5, 1, 1.5, 2].


Parameters
----------
no_levels : int, optional
Expand All @@ -128,7 +128,9 @@ def _bottom_hier_datagen(
Maximum possible value of the coefficient or intercept value.
random_seed : int, optional
Random seed for reproducability.

length : int between 1 and 144, optional, default = 144
length of base time series. If lowe than 144,
the airline dataet is truncated to the specified length, cutting from the end.

Returns
-------
Expand All @@ -139,7 +141,7 @@ def _bottom_hier_datagen(

rng = np.random.default_rng(random_seed)

base_ts = load_airline()
base_ts = load_airline()[:length]
df = pd.DataFrame(base_ts, index=base_ts.index)
df.index.rename(None, inplace=True)

Expand Down