Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/fit args #161

Merged
merged 43 commits into from
Sep 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
425a479
add support for columns to the TimeSeries object
grll Jul 15, 2020
beb6432
add colum support indexing to timeseries
grll Jul 16, 2020
20064aa
fix wrong docstring
grll Jul 16, 2020
8c9f224
refactor indexing, fix docstring, columns as last arg
grll Jul 16, 2020
a8021ed
clean indexing method
grll Jul 16, 2020
4bde24b
refactor indexing only based on loc and iloc
grll Jul 16, 2020
f9d89a8
Update darts/timeseries.py
grll Jul 16, 2020
9c0c9e6
use underlying columns by default
grll Jul 20, 2020
cd6df5d
fix column added on intern _df and use self.freq_str
grll Jul 20, 2020
6bef192
fix parameter position in from_times_and_values
grll Jul 20, 2020
fb8b78d
fix the tests to use str columns
grll Jul 20, 2020
9ad5c46
fix docstring timeseries
grll Jul 20, 2020
1cde216
remove None check on df that should exists
grll Jul 20, 2020
66105c7
Merge branch 'develop' into features/indexing
grll Jul 21, 2020
8ddc228
Merge branch 'develop' into features/indexing
TheMP Jul 28, 2020
cade004
add comment for clarifying that _df is a copy
grll Jul 28, 2020
38635ab
add separate function to process columns
grll Jul 28, 2020
f96f169
Merge branch 'features/indexing' of github.com:unit8co/darts into fea…
grll Jul 28, 2020
28173d1
Merge branch 'develop' into features/indexing
grll Jul 29, 2020
b515be0
adapt map with str col indexing
grll Jul 30, 2020
0556c97
univariate fcast model only support univariate ts
grll Jul 21, 2020
6b61b3f
MultivariateFcasModel fits on the whole training ts
grll Jul 21, 2020
3fb1bb2
refactor torch forcasting model to use covariate_series
grll Jul 21, 2020
8166e27
fix unused imports
grll Jul 21, 2020
6b73a2a
allow to specify only covaraite_series
grll Jul 21, 2020
4d0304e
enforce covariate_series and target_series inputs for multivariate model
grll Jul 21, 2020
dad9607
adapt torch datasets to use covariate / target series
grll Jul 21, 2020
0de6892
adapt validation series provided as a Tuple
grll Jul 21, 2020
71d6c46
fix typo
grll Jul 21, 2020
1ebb94b
adapt create_dataset on tcn model
grll Jul 21, 2020
bca4ac0
remove component index from fit function
grll Jul 21, 2020
530af67
adapt tests to new syntax
grll Jul 21, 2020
b671085
refacotr metaclasses
grll Jul 23, 2020
c04fd3b
abstract a new method make fitable series
grll Jul 23, 2020
1d914e1
adapt torchforcastingmodel to parent class changes
grll Jul 23, 2020
b4111e2
keep covariate/target seires for Multivariate models only
grll Jul 24, 2020
2ae9a1c
fix typos with new implementation
grll Jul 24, 2020
7544010
move series length check in forcasting model
grll Jul 24, 2020
8a9aa8b
rename covariate into training series
grll Jul 30, 2020
d909155
adapt old backtesting to support the new fit args syntax
grll Jul 30, 2020
f598f80
Merge branch 'develop' into refactor/fit-args
pennfranc Aug 7, 2020
c75c9c4
Refactor/backtesting (#125)
guillaumeraille Sep 17, 2020
29abe9a
Merge branch 'develop' into refactor/fit-args
pennfranc Sep 17, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 0 additions & 12 deletions darts/backtesting/__init__.py

This file was deleted.

585 changes: 0 additions & 585 deletions darts/backtesting/backtesting.py

This file was deleted.

9 changes: 4 additions & 5 deletions darts/models/arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
.. [1] https://wikipedia.org/wiki/Autoregressive_integrated_moving_average
"""

from typing import Optional
from statsmodels.tsa.arima_model import ARMA as staARMA
from statsmodels.tsa.arima_model import ARIMA as staARIMA
from pmdarima import AutoARIMA as PmdAutoARIMA
Expand Down Expand Up @@ -45,8 +44,8 @@ def __init__(self, p: int = 12, d: int = 1, q: int = 0):
def __str__(self):
return 'ARIMA({},{},{})'.format(self.p, self.d, self.q)

def fit(self, series: TimeSeries, component_index: Optional[int] = None):
super().fit(series, component_index)
def fit(self, series: TimeSeries):
super().fit(series)
series = self.training_series
m = staARIMA(series.values(),
order=(self.p, self.d, self.q)) if self.d > 0 else staARMA(series.values(), order=(self.p, self.q))
Expand Down Expand Up @@ -90,8 +89,8 @@ def __init__(self, *autoarima_args, **autoarima_kwargs):
def __str__(self):
return 'Auto-ARIMA'

def fit(self, series: TimeSeries, component_index: Optional[int] = None):
super().fit(series, component_index)
def fit(self, series: TimeSeries):
super().fit(series)
series = self.training_series
self.model.fit(series.values())

Expand Down
13 changes: 6 additions & 7 deletions darts/models/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
A collection of simple benchmark models.
"""

from typing import Optional
import numpy as np

from .forecasting_model import UnivariateForecastingModel
Expand All @@ -28,8 +27,8 @@ def __init__(self):
def __str__(self):
return 'Naive mean predictor model'

def fit(self, series: TimeSeries, component_index: Optional[int] = None):
super().fit(series, component_index)
def fit(self, series: TimeSeries):
super().fit(series)
self.mean_val = np.mean(series.univariate_values())

def predict(self, n: int):
Expand Down Expand Up @@ -62,8 +61,8 @@ def min_train_series_length(self):
def __str__(self):
return 'Naive seasonal model, with K={}'.format(self.K)

def fit(self, series: TimeSeries, component_index: Optional[int] = None):
super().fit(series, component_index)
def fit(self, series: TimeSeries):
super().fit(series)
raise_if_not(len(series) >= self.K, 'The time series requires at least K={} points'.format(self.K), logger)
self.last_k_vals = series.univariate_values()[-self.K:]

Expand All @@ -87,8 +86,8 @@ def __init__(self):
def __str__(self):
return 'Naive drift model'

def fit(self, series: TimeSeries, component_index: Optional[int] = None):
super().fit(series, component_index)
def fit(self, series: TimeSeries):
super().fit(series)
series = self.training_series

def predict(self, n: int):
Expand Down
4 changes: 2 additions & 2 deletions darts/models/exponential_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def __init__(self,
def __str__(self):
return 'Exponential smoothing'

def fit(self, series: TimeSeries, component_index: Optional[int] = None):
super().fit(series, component_index)
def fit(self, series: TimeSeries):
super().fit(series)
series = self.training_series
hw_model = hw.ExponentialSmoothing(series.values(),
trend=self.trend.value,
Expand Down
4 changes: 2 additions & 2 deletions darts/models/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ def __init__(self,
def __str__(self):
return 'FFT(nr_freqs_to_keep=' + str(self.nr_freqs_to_keep) + ', trend=' + str(self.trend) + ')'

def fit(self, series: TimeSeries, component_index: Optional[int] = None):
super().fit(series, component_index)
def fit(self, series: TimeSeries):
super().fit(series)
series = self.training_series

# determine trend
Expand Down