Skip to content

Commit

Permalink
#1101 Implemented min_train_series_length for Theta and FourTheta (#1111
Browse files Browse the repository at this point in the history
)

* #1101 Implemented min_train_series_length for Theta and FourTheta

* #1103 Placed changelog entry under the Unreleased section

Co-authored-by: Rijk van der Meulen <rijk.vandermeulen@eyeon.nl>
Co-authored-by: Julien Herzen <julien@unit8.co>
  • Loading branch information
3 people committed Aug 9, 2022
1 parent e9f2128 commit 5ef13ba
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
Darts is still in an early development phase and we cannot always guarantee backwards compatibility. Changes that may **break code which uses a previous release of Darts** are marked with a "&#x1F534;".

## [Unreleased](https://github.com/unit8co/darts/tree/master)
- Implemented the min_train_series_length method for the FourTheta and Theta models that overwrites the minimum default of 3 training samples by 2*seasonal_period when appropriate [#1101](https://github.com/unit8co/darts/pull/1101) by [Rijk van der Meulen](https://github.com/rijkvandermeulen)

[Full Changelog](https://github.com/unit8co/darts/compare/0.20.0...master)


Expand Down
22 changes: 22 additions & 0 deletions darts/models/forecasting/theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ def predict(self, n: int, num_samples: int = 1) -> "TimeSeries":
def __str__(self):
return f"Theta({self.theta})"

@property
def min_train_series_length(self) -> int:
if (
self.season_mode != SeasonalityMode.NONE
and self.seasonality_period
and self.seasonality_period > 1
):
return 2 * self.seasonality_period
else:
return 3


class FourTheta(ForecastingModel):
def __init__(
Expand Down Expand Up @@ -457,3 +468,14 @@ def __str__(self):
return "4Theta(theta:{}, curve:{}, model:{}, seasonality:{})".format(
self.theta, self.trend_mode, self.model_mode, self.season_mode
)

@property
def min_train_series_length(self) -> int:
if (
self.season_mode != SeasonalityMode.NONE
and self.seasonality_period
and self.seasonality_period > 1
):
return 2 * self.seasonality_period
else:
return 3
48 changes: 48 additions & 0 deletions darts/tests/models/forecasting/test_4theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,51 @@ def test_best_model(self):
self.assertTrue(
mape(val_series, forecast_best) <= mape(val_series, forecast_random)
)

def test_min_train_series_length_with_seasonality(self):
seasonality_period = 12
fourtheta = FourTheta(
model_mode=ModelMode.MULTIPLICATIVE,
trend_mode=TrendMode.EXPONENTIAL,
season_mode=SeasonalityMode.ADDITIVE,
seasonality_period=seasonality_period,
normalization=False,
)
theta = Theta(
season_mode=SeasonalityMode.ADDITIVE,
seasonality_period=seasonality_period,
)
self.assertEqual(fourtheta.min_train_series_length, 2 * seasonality_period)
self.assertEqual(theta.min_train_series_length, 2 * seasonality_period)

def test_min_train_series_length_without_seasonality(self):
fourtheta = FourTheta(
model_mode=ModelMode.MULTIPLICATIVE,
trend_mode=TrendMode.EXPONENTIAL,
season_mode=SeasonalityMode.ADDITIVE,
seasonality_period=None,
normalization=False,
)
theta = Theta(
season_mode=SeasonalityMode.ADDITIVE,
seasonality_period=None,
)
self.assertEqual(fourtheta.min_train_series_length, 3)
self.assertEqual(theta.min_train_series_length, 3)

def test_fit_insufficient_train_series_length(self):
sine_series = st(length=21, freq="MS")
with self.assertRaises(ValueError):
fourtheta = FourTheta(
model_mode=ModelMode.MULTIPLICATIVE,
trend_mode=TrendMode.EXPONENTIAL,
season_mode=SeasonalityMode.ADDITIVE,
seasonality_period=12,
)
fourtheta.fit(sine_series)
with self.assertRaises(ValueError):
theta = Theta(
season_mode=SeasonalityMode.ADDITIVE,
seasonality_period=12,
)
theta.fit(sine_series)

0 comments on commit 5ef13ba

Please sign in to comment.