Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix getitem with rangeindex start != 0 and freq != 1 #1868

Merged
merged 3 commits into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Fixed an issue with `TorchForecastingModel.load_from_checkpoint()` not properly loading the loss function and metrics. [#1749](https://github.com/unit8co/darts/pull/1749) by [Antoine Madrona](https://github.com/madtoinou).
- Fixed a bug when loading the weights of a `TorchForecastingModel` trained with encoders or a Likelihood. [#1744](https://github.com/unit8co/darts/pull/1744) by [Antoine Madrona](https://github.com/madtoinou).
- Fixed a bug when using selected `target_components` with `ShapExplainer. [#1803](https://github.com/unit8co/darts/pull/#1803) by [Dennis Bader](https://github.com/dennisbader).
- Fixed `TimeSeries.__getitem__()` for series with a RangeIndex with start != 0 and freq != 1. [#1868](https://github.com/unit8co/darts/pull/#1868) by [Dennis Bader](https://github.com/dennisbader).

## [0.24.0](https://github.com/unit8co/darts/tree/0.24.0) (2023-04-12)
### For users of the library:
Expand Down
46 changes: 45 additions & 1 deletion darts/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def test_ops(self):
# Cannot divide by 0.
self.series1 / 0

def test_getitem(self):
def test_getitem_datetime_index(self):
seriesA: TimeSeries = self.series1.drop_after(pd.Timestamp("20130105"))
self.assertEqual(self.series1[pd.date_range("20130101", " 20130104")], seriesA)
self.assertEqual(self.series1[:4], seriesA)
Expand All @@ -890,6 +890,50 @@ def test_getitem(self):
with self.assertRaises(IndexError):
self.series1[::-1]

def test_getitem_integer_index(self):
freq = 3
start = 1
end = start + (len(self.series1) - 1) * freq
idx_int = pd.RangeIndex(start=start, stop=end + freq, step=freq)
series = TimeSeries.from_times_and_values(
times=idx_int, values=self.series1.values()
)
assert series.freq == freq
assert series.start_time() == start
assert series.end_time() == end
assert series[idx_int] == series == series[0 : len(series)]

series_single = series.drop_after(start + 2 * freq)
assert (
series[pd.RangeIndex(start=start, stop=start + 2 * freq, step=freq)]
== series_single
)
assert series[:2] == series_single
assert series_single.freq == freq
assert series_single.start_time() == start
assert series_single.end_time() == start + freq

idx_single = pd.RangeIndex(start=start + freq, stop=start + 2 * freq, step=freq)
assert series[idx_single].time_index == idx_single
assert series[idx_single].pd_series().equals(series.pd_series()[1:2])
assert series[idx_single] == series[1:2] == series[1]

# cannot slice with two RangeIndex
with pytest.raises(IndexError):
_ = series[idx_single : idx_single + freq]

# RangeIndex not in time_index
with pytest.raises(KeyError):
_ = series[idx_single - 1]

# RangeIndex start is out of bounds
with pytest.raises(KeyError):
_ = series[pd.RangeIndex(start - freq, stop=end + freq, step=freq)]

# RangeIndex end is out of bounds
with pytest.raises(KeyError):
_ = series[pd.RangeIndex(start, stop=end + 2 * freq, step=freq)]

def test_fill_missing_dates(self):
with self.assertRaises(ValueError):
# Series cannot have date holes without automatic filling
Expand Down
4 changes: 3 additions & 1 deletion darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4916,7 +4916,9 @@ def _set_freq_in_xa(xa_: xr.DataArray):
xa_ = xa_.assign_coords(
{
self._time_dim: pd.RangeIndex(
start=key, stop=key + self.freq, step=self.freq
start=time_idx[0],
stop=time_idx[0] + self.freq,
step=self.freq,
)
}
)
Expand Down