Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Develop #156

Merged
merged 68 commits into from
Jul 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
01e1837
feat(4Theta): naive implementation of 4Theta model
Droxef Jul 3, 2020
f6831be
fix(theta): avoid NaN values in theta, and unnecessary season test
Droxef Jul 3, 2020
53ceeab
feat(gridsearch): add possibility to compare with model.fitted_values
Droxef Jul 3, 2020
7ae0803
feat(4theta): add a method to auto select best model
Droxef Jul 3, 2020
4001009
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 3, 2020
264ddf3
refactor(4Theta): Specify univariate model
Droxef Jul 3, 2020
e5bc790
style(4Theta): Fix linter
Droxef Jul 3, 2020
e951554
style(4Theta): fix docstring
Droxef Jul 3, 2020
f596476
style(4Theta): Fix docstring
Droxef Jul 3, 2020
b712c7d
style(4Theta): Change link
Droxef Jul 3, 2020
dee5b0d
style(4Theta): Correct docstring
Droxef Jul 3, 2020
ba64351
style(4theta): correct ticks in docstring
Droxef Jul 6, 2020
36d99fe
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 7, 2020
1dc5bb3
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 7, 2020
4d24f75
refactor(4Theta): change different modes verification and add Enum
Droxef Jul 7, 2020
240cea4
refactor(theta): replace all string modes by Enum
Droxef Jul 8, 2020
7d5dffa
test(backtesting): Add a test to verify if fitted_values exist
Droxef Jul 8, 2020
b19cf97
Fix(Theta): Correct all Enums
Droxef Jul 8, 2020
bbd58cc
fix(Theta): compare with enum members value instead. Correct some min…
Droxef Jul 8, 2020
71aa8d3
fix(4theta): move the creation of enums in init file
Droxef Jul 8, 2020
8286062
test(4theta): Add 4Theta to autoregressive test. Move Enums to top in…
Droxef Jul 8, 2020
af69e63
test(4theta): Add 4Theta specific test
Droxef Jul 8, 2020
97f6956
style(backtesting): fix lint
Droxef Jul 8, 2020
97cff2a
test(4theta): Add another exception to test
Droxef Jul 9, 2020
f6df487
ref(4Theta): mode.fitted_values is now a TimeSeries to be consistent
Droxef Jul 13, 2020
2bad4f7
style(Theta): rename mode to season_mode to be consistent w/ FourTheta
Droxef Jul 13, 2020
49fafda
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 13, 2020
fe1b320
docs(thetas): correct errors in the different docs
Droxef Jul 13, 2020
fc04e68
refactor(4Theta): Correct
Droxef Jul 13, 2020
cae49f6
refactor(backtesting): add a 'use_fitted_values' parameter
Droxef Jul 13, 2020
27b9143
fix(4theta): correct select_best_model
Droxef Jul 13, 2020
2fa7f4f
test(4Theta): add a test for zero mean and correct others
Droxef Jul 13, 2020
5ed92bd
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 14, 2020
a2b64e4
style(backtesting): linter formatting
Droxef Jul 14, 2020
98d1b7a
refactor(4Theta): change Enums names, correct theta and backtesting docs
Droxef Jul 16, 2020
f405e8b
refactor(4theta): move creation of fitted_values timeseries to backte…
Droxef Jul 16, 2020
9cad7cf
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 16, 2020
1e6be8e
refactor(statistics): include Enums in extract and remove functions
Droxef Jul 16, 2020
2f9bc5b
refactor(4Theta): check earlier if univariate
Droxef Jul 16, 2020
a405e47
test(4Theta): correct backtesting and test best_model
Droxef Jul 16, 2020
25f3de2
test(4Theta): add new modes in test models
Droxef Jul 16, 2020
e15f52b
docs(4theta): Add a disclaimer for 4theta performance
Droxef Jul 16, 2020
61c0f17
refactor(Theta): change theta to have the same behavior as FourTheta
Droxef Jul 16, 2020
e63e215
examples(darts-intro): modify notebook to give the same results
Droxef Jul 16, 2020
be1427d
style(4Theta): correct deprecation warning for logger.warn
Droxef Jul 16, 2020
34a8c41
Merge branch 'develop' into feat/FourTheta
hrzn Jul 17, 2020
d8e890f
style(4theta): move comment to backtesting
Droxef Jul 20, 2020
e02d65d
Merge branch 'feat/FourTheta' of https://github.com/unit8co/darts int…
Droxef Jul 20, 2020
8e92eba
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 20, 2020
0c04781
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 20, 2020
f2843c0
fix(residuals): correct a function argument
Droxef Jul 20, 2020
d7f90e3
fix(theta): correct bug where season period was kept between fits
Droxef Jul 20, 2020
47f16b8
cd(workflows): merge develop and PR workflows
Droxef Jul 20, 2020
c3da07b
test(4theta): bigger tolerance for macos
Droxef Jul 20, 2020
5813f3a
Merge branch 'develop' of https://github.com/unit8co/darts into feat/…
Droxef Jul 20, 2020
28b329b
Revert "cd(workflows): merge develop and PR workflows"
Droxef Jul 21, 2020
03f2ffd
cd(workflows): all tests trigger on PR
Droxef Jul 21, 2020
e77e94f
fix(workflow): remove duplicate tests
Droxef Jul 22, 2020
a074b35
fix(workflow): correct develop worklow not on master
Droxef Jul 22, 2020
9a27455
refactor(ExpSmoothing): add enum to exponential smoothing use
Droxef Jul 22, 2020
f92b1bf
Revert "fix(residuals): correct a function argument"
Droxef Jul 22, 2020
cfb0149
Revert "fix(workflow): correct develop worklow not on master"
Droxef Jul 22, 2020
7fb4f7a
Revert "fix(workflow): remove duplicate tests"
Droxef Jul 22, 2020
696abfa
fix(test/4theta): correct non-deterministic behaviour in test
Droxef Jul 22, 2020
412b995
Merge branch 'develop' into feat/FourTheta
Droxef Jul 22, 2020
bd8f1c1
Revert "cd(workflows): all tests trigger on PR"
Droxef Jul 23, 2020
8768490
Merge branch 'feat/FourTheta' of https://github.com/unit8co/darts int…
Droxef Jul 23, 2020
3b44282
Merge branch 'develop' into feat/FourTheta
TheMP Jul 27, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 16 additions & 7 deletions darts/models/exponential_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
from .forecasting_model import UnivariateForecastingModel
from ..logging import get_logger
from ..timeseries import TimeSeries
from .. import ModelMode

logger = get_logger(__name__)


class ExponentialSmoothing(UnivariateForecastingModel):
def __init__(self,
trend: Optional[str] = 'additive',
trend: Optional[ModelMode] = ModelMode.ADDITIVE,
damped: Optional[bool] = False,
seasonal: Optional[str] = 'additive',
seasonal: Optional[ModelMode] = ModelMode.ADDITIVE,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the underlying string values of the enum still supported as argument values?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the raw strings are not suported anymore. We only accept the Enum members.
But to be backward-compatible, it can be interesting to accept both.

seasonal_periods: Optional[int] = 12,
**fit_kwargs):
""" Exponential Smoothing
Expand All @@ -28,14 +29,22 @@ def __init__(self,

We refer to this link for the original and more complete documentation of the parameters.

`model_mode` must be a ModelMode Enum member.
You can access the Enum with `from darts import ModelMode`.

`ExponentialSmoothing(trend=None, seasonal=None)` corresponds to a single exponential smoothing.
`ExponentialSmoothing(trend=ModelMode.ADDITIVE, seasonal=None)` corresponds to a Holt's exponential smoothing.

Parameters
----------
trend
Type of trend component
Type of trend component. Either ModelMode.ADDITIVE or ModelMode.MULTIPLICATIVE.
Defaults to `ModelMode.ADDITIVE`.
damped
Should the trend component be damped.
Should the trend component be damped. Defaults to False.
seasonal
Type of seasonal component
Type of seasonal component. Either ModelMode.ADDITIVE or ModelMode.MULTIPLICATIVE.
Defaults to `ModelMode.ADDITIVE`.
seasonal_periods
The number of periods in a complete seasonal cycle, e.g., 4 for quarterly data or 7 for daily
data with a weekly cycle.
Expand All @@ -60,9 +69,9 @@ def fit(self, series: TimeSeries, component_index: Optional[int] = None):
super().fit(series, component_index)
series = self.training_series
hw_model = hw.ExponentialSmoothing(series.values(),
trend=self.trend,
trend=self.trend.value,
damped=self.damped,
seasonal=self.seasonal,
seasonal=self.seasonal.value,
seasonal_periods=self.seasonal_periods)

hw_results = hw_model.fit(**self.fit_kwargs)
Expand Down
10 changes: 8 additions & 2 deletions darts/models/theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def __init__(self,
self.theta = theta
self.is_seasonal = False
self.seasonality = None
self.season_period = seasonality_period
self.seasonality_period = seasonality_period
self.season_period = None
self.season_mode = season_mode

raise_if_not(season_mode in SeasonalityMode,
Expand All @@ -79,6 +80,8 @@ def fit(self, series: TimeSeries, component_index: Optional[int] = None):
# or infers season_period from the TimeSeries itself.
if self.season_mode is SeasonalityMode.NONE:
self.season_period = 1
else:
self.season_period = self.seasonality_period
if self.season_period is None:
max_lag = len(ts) // 2
self.is_seasonal, self.season_period = check_seasonality(ts, self.season_period, max_lag=max_lag)
Expand Down Expand Up @@ -195,7 +198,8 @@ def __init__(self,
self.theta = theta
self.is_seasonal = False
self.seasonality = None
self.season_period = seasonality_period
self.seasonality_period = seasonality_period
self.season_period = None
self.model_mode = model_mode
self.season_mode = season_mode
self.trend_mode = trend_mode
Expand Down Expand Up @@ -230,6 +234,8 @@ def fit(self, ts, component_index: Optional[int] = None):
# or infers season_period from the TimeSeries itself.
if self.season_mode is SeasonalityMode.NONE:
self.season_period = 1
else:
self.season_period = self.seasonality_period
if self.season_period is None:
max_lag = len(ts) // 2
self.is_seasonal, self.season_period = check_seasonality(ts, self.season_period, max_lag=max_lag)
Expand Down
6 changes: 3 additions & 3 deletions darts/tests/test_4theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ def test_zero_mean(self):
model.fit(sine_series)

def test_theta(self):
random.seed(1)
np.random.seed(1)
series = rt(length=50, mean=100)
theta_param = random.randrange(-5, 5)
theta_param = np.random.randint(1, 5)
theta = Theta(theta_param)
fourtheta = FourTheta(theta_param, normalization=False)
theta.fit(series)
fourtheta.fit(series)
forecast_theta = theta.predict(20)
forecast_fourtheta = fourtheta.predict(20)
self.assertTrue((forecast_theta - forecast_fourtheta <= 1e-12).all()[0])
self.assertTrue((forecast_theta - forecast_fourtheta <= 1e-11).all()[0])

def test_best_model(self):
random.seed(1)
Expand Down