Skip to content

Commit

Permalink
[ENH] add length option to _bottom_hier_datagen hierarchical data g…
Browse files Browse the repository at this point in the history
…enerator, speed up `ReconcilerForecaster` doctest (#4979)

This PR:
* adds a length option to `_bottom_hier_datagen` hierarchical data
generator, allowing to generate much smaller data sets
* speeds up the `ReconcilerForecaster` doctest by using this option, to
generate a much smaller data set in the doctest

Due to the test speed-up, related to
#2890
  • Loading branch information
fkiraly committed Aug 4, 2023
1 parent 0d9d1a5 commit 2268a9b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
1 change: 1 addition & 0 deletions sktime/forecasting/reconcile.py
Expand Up @@ -72,6 +72,7 @@ class ReconcilerForecaster(BaseForecaster):
... no_bottom_nodes=3,
... no_levels=1,
... random_seed=123,
... length=7,
... )
>>> y = agg.fit_transform(y)
>>> forecaster = NaiveForecaster(strategy="drift")
Expand Down
14 changes: 8 additions & 6 deletions sktime/utils/_testing/hierarchical.py
Expand Up @@ -105,19 +105,19 @@ def _bottom_hier_datagen(
coef_1_max=20,
coef_2_max=0.1,
random_seed=None,
length=144,
):
"""Hierarchical data generator using the flights dataset.
"""Hierarchical data generator using the airline dataset.
This function generates bottom level, i.e. not aggregated, time-series
from the flights dataset.
from the airline dataset (sktime.datasets.load_airline).
Each series is generated from the flights dataset using a linear model,
Each series is generated from the airline dataset using a linear model,
y = c0 + c1x + c2x^(c3), where the coefficients, intercept, and exponent
are randomly sampled for each series. The coefficients and intercept are
sampled between np.arange(0, *_max, 0.01) to keep the values positive. The
exponent is sampled from [0.5, 1, 1.5, 2].
Parameters
----------
no_levels : int, optional
Expand All @@ -128,7 +128,9 @@ def _bottom_hier_datagen(
Maximum possible value of the coefficient or intercept value.
random_seed : int, optional
Random seed for reproducability.
length : int between 1 and 144, optional, default = 144
length of base time series. If lowe than 144,
the airline dataet is truncated to the specified length, cutting from the end.
Returns
-------
Expand All @@ -139,7 +141,7 @@ def _bottom_hier_datagen(

rng = np.random.default_rng(random_seed)

base_ts = load_airline()
base_ts = load_airline()[:length]
df = pd.DataFrame(base_ts, index=base_ts.index)
df.index.rename(None, inplace=True)

Expand Down

0 comments on commit 2268a9b

Please sign in to comment.