Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/historical_forecasts accept negative integer as start value #1866

Merged
merged 27 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9dbe81f
feat: historical_forecasts accept negative integer as start value
madtoinou Jun 29, 2023
344e929
fix: improved the negative start unit test
madtoinou Jun 29, 2023
50a2b1b
fix: simplified the logic around exception raising
madtoinou Jun 29, 2023
a58c89f
Merge branch 'master' into feat/historical_forecast_neg_int_start
madtoinou Jun 29, 2023
df176eb
Merge branch 'master' into feat/historical_forecast_neg_int_start
dennisbader Jul 4, 2023
b3ee491
merging master and using dict type to convey index from the end of th…
madtoinou Aug 9, 2023
db17f94
Merge branch 'master' into feat/historical_forecast_neg_int_start
madtoinou Aug 9, 2023
eee099e
fix: instead of adding capabilities to get_index_at_point, use a new …
madtoinou Aug 11, 2023
c4fcd58
test: updated tests accordingly
madtoinou Aug 11, 2023
5fda99e
Merge branch 'master' into feat/historical_forecast_neg_int_start
madtoinou Aug 11, 2023
eef4089
doc: updated changelog
madtoinou Aug 11, 2023
a727152
test: added test for historical forecast on ts using a rangeindex sta…
madtoinou Aug 11, 2023
c1cccc3
Apply suggestions from code review
madtoinou Aug 11, 2023
7beb2a6
fix: changed the literal to 'positional_index' and 'value_index'
madtoinou Aug 11, 2023
4e40275
feat: making the error messages more informative, adapted the tests a…
madtoinou Aug 11, 2023
90b2e62
feat: extending the new argument to backtest and gridsearch
madtoinou Aug 11, 2023
ce4b669
fix: import of Literal for python 3.8
madtoinou Aug 14, 2023
292af54
doc: updated changelog
madtoinou Aug 14, 2023
86c3b84
fix: shortened the literal for start_format, updated tests accordingly
madtoinou Aug 14, 2023
b3d5929
doc: updated start docstring
madtoinou Aug 14, 2023
94842b0
test: limited the dependency on unittest in anticipation of the refac…
madtoinou Aug 14, 2023
f6f95bd
doc: updated changelog
madtoinou Aug 14, 2023
35bf096
fix: fixed typo
madtoinou Aug 14, 2023
ba13934
fix: fixed typo
madtoinou Aug 14, 2023
6a91897
doc: copy start and start_format docstring from hist_fct to backtest …
madtoinou Aug 15, 2023
462c51a
Apply suggestions from code review
madtoinou Aug 15, 2023
a8ac1f8
Merge branch 'master' into feat/historical_forecast_neg_int_start
madtoinou Aug 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 18 additions & 4 deletions darts/tests/models/forecasting/test_historical_forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,16 @@ def test_historical_forecasts_local_models(self):
"LocalForecastingModel does not support historical forecasting with `retrain` set to `False`"
)

def test_historical_forecasts_negative_start(self):
    """Check that a negative integer `start` is counted from the end of the series.

    With `start=-2` on a 10-point series, the test expects exactly two
    forecasts, the first one aligned with the second-to-last timestamp.
    """
    points_from_end = 2
    series = tg.sine_timeseries(length=10)

    # fit once on the first 8 points; `retrain=False` below reuses this fit
    model = LinearRegressionModel(lags=2)
    model.fit(series[:8])

    forecasts = model.historical_forecasts(
        series=series,
        start=-points_from_end,
        retrain=False,
    )

    expected_first_timestamp = series.time_index[-points_from_end]
    self.assertEqual(len(forecasts), points_from_end)
    self.assertEqual(expected_first_timestamp, forecasts.time_index[0])

def test_historical_forecasts(self):
train_length = 10
forecast_horizon = 8
Expand Down Expand Up @@ -550,14 +560,18 @@ def test_sanity_check_invalid_start(self):
)
with pytest.raises(ValueError) as msg:
LinearRegressionModel(lags=1).historical_forecasts(
rangeidx_step1, start=rangeidx_step1.start_time() - rangeidx_step1.freq
rangeidx_step1, start=-11
)
assert str(msg.value).startswith("if `start` is an integer, must be `>= 0`")
assert str(msg.value).startswith(
"`start` index `-11` is out of bounds for series of length 10"
)
with pytest.raises(ValueError) as msg:
LinearRegressionModel(lags=1).historical_forecasts(
rangeidx_step2, start=rangeidx_step2.start_time() - rangeidx_step2.freq
rangeidx_step2, start=-11
)
assert str(msg.value).startswith("if `start` is an integer, must be `>= 0`")
assert str(msg.value).startswith(
"`start` index `-11` is out of bounds for series of length 10"
)

# value too high
with pytest.raises(ValueError) as msg:
Expand Down
18 changes: 17 additions & 1 deletion darts/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def test_integer_range_indexing(self):

# getting index for idx should return i s.t., series[i].time == idx
self.assertEqual(series.get_index_at_point(101), 91)
# getting index for a negative idx should return idx + len(ts)
self.assertEqual(series.get_index_at_point(-3), 97)

# slicing outside of the index range should return an empty ts
self.assertEqual(len(series[120:125]), 0)
Expand All @@ -130,6 +132,8 @@ def test_integer_range_indexing(self):

# getting index for idx should return i s.t., series[i].time == idx
self.assertEqual(series.get_index_at_point(100), 50)
# getting index for a negative idx should return idx + len(ts)
self.assertEqual(series.get_index_at_point(-1), 99)

# getting index outside of the index range should raise an exception
with self.assertRaises(IndexError):
Expand Down Expand Up @@ -158,6 +162,8 @@ def test_integer_range_indexing(self):

# getting index for idx should return i s.t., series[i].time == idx
self.assertEqual(series.get_index_at_point(16), 3)
# getting index for a negative idx should return idx + len(ts)
self.assertEqual(series.get_index_at_point(-2), 8)

def test_integer_indexing(self):
n = 10
Expand Down Expand Up @@ -493,15 +499,25 @@ def helper_test_split(test_case, test_series: TimeSeries):
test_case.assertEqual(len(seriesK), 5)
test_case.assertEqual(len(seriesL), len(test_series) - 5)

seriesM, seriesN = test_series.split_after(-2)
test_case.assertEqual(len(seriesM), len(test_series) - len(seriesN))
test_case.assertEqual(len(seriesN), 1)

seriesO, seriesP = test_series.split_before(-2)
test_case.assertEqual(len(seriesO), len(test_series) - len(seriesP))
test_case.assertEqual(len(seriesP), 2)

test_case.assertEqual(test_series.freq_str, seriesA.freq_str)
test_case.assertEqual(test_series.freq_str, seriesC.freq_str)
test_case.assertEqual(test_series.freq_str, seriesE.freq_str)
test_case.assertEqual(test_series.freq_str, seriesG.freq_str)
test_case.assertEqual(test_series.freq_str, seriesI.freq_str)
test_case.assertEqual(test_series.freq_str, seriesK.freq_str)
test_case.assertEqual(test_series.freq_str, seriesM.freq_str)
test_case.assertEqual(test_series.freq_str, seriesO.freq_str)

# Test split points outside of range
for value in [-5, 1.1, pd.Timestamp("21300104")]:
for value in [1.1, pd.Timestamp("21300104")]:
with test_case.assertRaises(ValueError):
test_series.split_before(value)

Expand Down
6 changes: 4 additions & 2 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def __init__(self, xa: xr.DataArray):
logger,
)
else:
self._freq = self._time_index.step
self._freq: int = self._time_index.step
self._freq_str = None

# check static covariates
Expand Down Expand Up @@ -2085,7 +2085,9 @@ def get_index_at_point(
)
point_index = int((len(self) - 1) * point)
elif isinstance(point, (int, np.int64)):
if self.has_datetime_index or (self.start_time() == 0 and self.freq == 1):
if point < 0:
madtoinou marked this conversation as resolved.
Show resolved Hide resolved
point_index = point + len(self)
elif self.has_datetime_index or (self.start_time() == 0 and self.freq == 1):
point_index = point
else:
point_index_float = (point - self.start_time()) / self.freq
Expand Down
60 changes: 28 additions & 32 deletions darts/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from tqdm.notebook import tqdm as tqdm_notebook

from darts import TimeSeries
from darts.logging import get_logger, raise_if_not, raise_log
from darts.logging import get_logger, raise_if, raise_if_not, raise_log
from darts.utils.timeseries_generation import generate_index

try:
Expand Down Expand Up @@ -230,9 +230,7 @@ def _historical_forecasts_general_checks(series, kwargs):
0.0 <= n.start <= 1.0, "`start` should be between 0.0 and 1.0.", logger
)
elif isinstance(n.start, (int, np.int64)):
raise_if_not(
n.start >= 0, "if `start` is an integer, must be `>= 0`.", logger
)
pass

# verbose error messages
if not isinstance(n.start, pd.Timestamp):
Expand All @@ -259,37 +257,35 @@ def _historical_forecasts_general_checks(series, kwargs):
logger,
)
elif isinstance(n.start, (int, np.int64)):
if (
series_.has_datetime_index
or (series_.has_range_index and series_.freq == 1)
) and n.start >= len(series_):
raise_log(
ValueError(
f"`start` index `{n.start}` is out of bounds for series of length {len(series_)} "
f"at index: {idx}."
),
logger,
)
elif (
series_.has_range_index and series_.freq > 1
) and n.start > series_.time_index[-1]:
raise_log(
ValueError(
f"`start` index `{n.start}` is larger than the last index `{series_.time_index[-1]}` "
f"for series at index: {idx}."
),
logger,
)

start = series_.get_timestamp_at_point(n.start)
if n.retrain is not False and start == series_.start_time():
raise_log(
ValueError(
f"{start_value_msg} `{start}` is the first timestamp of the series {idx}, resulting in an "
f"empty training set."
raise_if(
(n.start < 0 and np.abs(n.start) > len(series_))
or (
(
series_.has_datetime_index
or (series_.has_range_index and series_.freq == 1)
)
and n.start >= len(series_)
),
f"`start` index `{n.start}` is out of bounds for series of length {len(series_)} "
f"at index: {idx}.",
logger,
)
raise_if(
series_.has_range_index
and series_.freq > 1
and n.start > series_.time_index[-1],
f"`start` index `{n.start}` is larger than the last index `{series_.time_index[-1]}` "
f"for series at index: {idx}.",
logger,
)

start = series_.get_timestamp_at_point(n.start)
raise_if(
n.retrain is not False and start == series_.start_time(),
f"{start_value_msg} `{start}` is the first timestamp of the series {idx}, resulting in an "
f"empty training set.",
logger,
)

# check that overlap_end and start together form a valid combination
overlap_end = n.overlap_end
Expand Down