Skip to content

Commit

Permalink
Fix/head tail (#942)
Browse files Browse the repository at this point in the history
* fix head & tail

* added unit tests
  • Loading branch information
hrzn committed May 18, 2022
1 parent 8d9bd19 commit 699bf17
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
10 changes: 10 additions & 0 deletions darts/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,11 @@ def test_head_overshot_sample_axis(self):
result = self.ts.head(20, axis="sample")
self.assertEqual(10, result.n_samples)

def test_head_numeric_time_index(self):
s = TimeSeries.from_values(self.ts.values())
# taking the head should not crash
s.head()

def test_tail_overshot_time_axis(self):
result = self.ts.tail(20)
self.assertEqual(10, result.n_timesteps)
Expand All @@ -1371,6 +1376,11 @@ def test_tail_overshot_sample_axis(self):
result = self.ts.tail(20, axis="sample")
self.assertEqual(10, result.n_samples)

def test_tail_numeric_time_index(self):
s = TimeSeries.from_values(self.ts.values())
# taking the tail should not crash
s.tail()


class TimeSeriesFromDataFrameTestCase(DartsBaseTestClass):
def test_from_dataframe_sunny_day(self):
Expand Down
16 changes: 12 additions & 4 deletions darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,8 +1261,12 @@ def head(
"""

axis_str = self._get_dim_name(axis)
display_n = range(min(size, self._xa.sizes[axis_str]))
return self.__class__(self._xa[{axis_str: display_n}])
display_n = min(size, self._xa.sizes[axis_str])

if axis_str == self._time_dim:
return self[:display_n]
else:
return self.__class__(self._xa[{axis_str: range(display_n)}])

def tail(
self, size: Optional[int] = 5, axis: Optional[Union[int, str]] = 0
Expand All @@ -1284,8 +1288,12 @@ def tail(
"""

axis_str = self._get_dim_name(axis)
display_n = range(-min(size, self._xa.sizes[axis_str]), 0)
return self.__class__(self._xa[{axis_str: display_n}])
display_n = min(size, self._xa.sizes[axis_str])

if axis_str == self._time_dim:
return self[-display_n:]
else:
return self.__class__(self._xa[{axis_str: range(-display_n, 0)}])

def concatenate(
self,
Expand Down

0 comments on commit 699bf17

Please sign in to comment.