Skip to content

Commit

Permalink
Add tests series (#1783)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj committed Nov 16, 2021
1 parent f0c2aef commit 12ff82f
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 18 deletions.
51 changes: 41 additions & 10 deletions py-polars/polars/eager/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,19 @@ def filter(self, predicate: "Series") -> "Series":
----------
predicate
Boolean mask.
Examples
--------
>>> s = pl.Series("a", [1, 2, 3])
>>> mask = pl.Series("", [True, False, True])
>>> s.filter(mask)
shape: (2,)
Series: 'a' [i64]
[
1
3
]
"""
if isinstance(predicate, list):
predicate = Series("", predicate)
Expand Down Expand Up @@ -1199,7 +1212,7 @@ def take_every(self, n: int) -> "Series":
Examples
--------
>>> s = pl.Series("a", [1, 2, 3, 4])
>>> s.take_every(2))
>>> s.take_every(2)
shape: (2,)
Series: '' [i64]
[
Expand Down Expand Up @@ -1258,6 +1271,20 @@ def argsort(self, reverse: bool = False) -> "Series":
-------
indexes
Indexes that can be used to sort this array.
Examples
--------
>>> s = pl.Series("a", [5, 3, 4, 1, 2])
>>> s.argsort()
shape: (5,)
Series: 'a' [i64]
[
3
4
1
2
0
]
"""
return wrap_s(self._s.argsort(reverse))

Expand Down Expand Up @@ -1407,14 +1434,15 @@ def is_finite(self) -> "Series":
Examples
--------
>>> s = pl.Series("a", [1.0, 2.0, 3.0])
>>> import numpy as np
>>> s = pl.Series("a", [1.0, 2.0, np.inf])
>>> s.is_finite()
shape: (3,)
Series: 'a' [bool]
[
true
true
true
false
]
"""
Expand All @@ -1430,14 +1458,15 @@ def is_infinite(self) -> "Series":
Examples
--------
>>> s = pl.Series("a", [1.0, 2.0, 3.0])
>>> import numpy as np
>>> s = pl.Series("a", [1.0, 2.0, np.inf])
>>> s.is_infinite()
shape: (3,)
Series: 'a' [bool]
[
false
false
false
true
]
"""
Expand All @@ -1455,12 +1484,14 @@ def is_nan(self) -> "Series":
--------
>>> import numpy as np
>>> s = pl.Series("a", [1.0, 2.0, 3.0, np.NaN])
>>> s.take([1, 3])
shape: (2,)
Series: 'a' [i64]
>>> s.is_nan()
shape: (4,)
Series: 'a' [bool]
[
2
4
false
false
false
true
]
"""
Expand Down
185 changes: 177 additions & 8 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_various():
assert not a.is_numeric()


def test_filter():
def test_filter_ops():
a = pl.Series("a", range(20))
assert a[a > 1].len() == 18
assert a[a < 1].len() == 1
Expand All @@ -181,6 +181,10 @@ def test_to_python():
assert isinstance(b, list)
assert len(b) == 20

b = a.to_list(use_pyarrow=True)
assert isinstance(b, list)
assert len(b) == 20

a = pl.Series("a", [1, None, 2])
assert a.null_count() == 1
assert a.to_list() == [1, None, 2]
Expand Down Expand Up @@ -304,6 +308,7 @@ def test_rolling():
assert a.rolling_min(2).to_list() == [None, 1, 2, 2, 1]
assert a.rolling_max(2).to_list() == [None, 2, 3, 3, 2]
assert a.rolling_sum(2).to_list() == [None, 3, 5, 5, 3]
assert a.rolling_mean(2).to_list() == [None, 1.5, 2.5, 2.5, 1.5]
assert np.isclose(a.rolling_std(2).to_list()[1], 0.7071067811865476)
assert np.isclose(a.rolling_var(2).to_list()[1], 0.5)
assert a.rolling_median(4).to_list() == [None, None, None, 2, 2]
Expand Down Expand Up @@ -429,13 +434,6 @@ def test_arange_expr():
assert out[0].to_list() == [0, 2]


def test_strftime():
a = pl.Series("a", [10000, 20000, 30000], dtype=pl.Date)
assert a.dtype == pl.Date
a = a.dt.strftime("%F")
assert a[2] == "2052-02-20"


def test_round():
a = pl.Series("f", [1.003, 2.003])
b = a.round(2)
Expand Down Expand Up @@ -746,6 +744,177 @@ def test_abs():
assert np.abs(s).to_list() == [1, 2, 3, 4]


def test_to_dummies():
    """``to_dummies`` produces one indicator column per value of the series."""
    dummies = pl.Series("a", [1, 2, 3]).to_dummies()
    want = pl.DataFrame(
        {
            "a_1": [1, 0, 0],
            "a_2": [0, 1, 0],
            "a_3": [0, 0, 1],
        }
    )
    assert dummies.frame_equal(want)


def test_value_counts():
    """``value_counts`` pairs each distinct value with its number of occurrences."""
    counts = pl.Series("a", [1, 2, 2, 3]).value_counts()
    want = pl.DataFrame({"a": [1, 2, 3], "counts": [1, 2, 1]})
    # Sort by value before comparing so the check does not depend on row order.
    ordered = counts.sort("a")
    assert ordered.frame_equal(want)


def test_chunk_lengths():
    """A freshly built four-element series reports a single chunk of length 4."""
    series = pl.Series("a", [1, 2, 2, 3])
    # The series is expected to consist of one chunk holding all four values.
    expected_chunk_count = 1
    expected_lengths = [4]
    assert series.n_chunks() == expected_chunk_count
    assert series.chunk_lengths() == expected_lengths


def test_limit():
    """``limit(2)`` keeps only the first two values of the series."""
    limited = pl.Series("a", [1, 2, 3]).limit(2)
    want = pl.Series("a", [1, 2])
    assert limited.series_equal(want)


def test_filter():
    """``filter`` keeps the values whose mask entry is ``True``."""
    values = pl.Series("a", [1, 2, 3])
    keep = pl.Series("", [True, False, True])
    kept = values.filter(keep)
    want = pl.Series("a", [1, 3])
    assert kept.series_equal(want)


def test_take_every():
    """``take_every(2)`` returns every second value, starting with the first."""
    every_other = pl.Series("a", [1, 2, 3, 4]).take_every(2)
    want = pl.Series([1, 3])
    assert every_other.series_equal(want)


def test_argsort():
    """``argsort`` returns the indexes that order the series, in either direction."""
    series = pl.Series("a", [5, 3, 4, 1, 2])

    # Default direction: indexes of the values from smallest to largest.
    ascending = series.argsort()
    assert ascending.series_equal(pl.Series([3, 4, 1, 2, 0]))

    # Passing True reverses the ordering.
    descending = series.argsort(True)
    assert descending.series_equal(pl.Series([0, 2, 1, 4, 3]))


def test_arg_min_and_arg_max():
    """``arg_min`` / ``arg_max`` return the positions of the extreme values."""
    series = pl.Series("a", [5, 3, 4, 1, 2])
    smallest_at = series.arg_min()
    assert smallest_at == 3  # value 1
    largest_at = series.arg_max()
    assert largest_at == 0  # value 5


def test_is_null_is_not_null():
    """``is_null`` and ``is_not_null`` produce complementary boolean masks."""
    series = pl.Series("a", [1.0, 2.0, 3.0, None])

    null_mask = series.is_null()
    assert null_mask.series_equal(pl.Series([False, False, False, True]))

    not_null_mask = series.is_not_null()
    assert not_null_mask.series_equal(pl.Series([True, True, True, False]))


def test_is_finite_is_infinite():
    """``is_finite`` and ``is_infinite`` flag ``inf`` values in a float series."""
    s = pl.Series("a", [1.0, 2.0, np.inf])

    # The comparisons must be asserted: a bare ``series_equal`` call discards
    # its boolean result, so the test could never fail.
    assert s.is_finite().series_equal(pl.Series([True, True, False]))
    assert s.is_infinite().series_equal(pl.Series([False, False, True]))


def test_is_nan_is_not_nan():
    """``is_nan`` and ``is_not_nan`` produce complementary masks for NaN values."""
    # Use ``np.nan``: the ``np.NaN`` alias was removed in NumPy 2.0.
    s = pl.Series("a", [1.0, 2.0, 3.0, np.nan])
    assert s.is_nan().series_equal(pl.Series([False, False, False, True]))
    assert s.is_not_nan().series_equal(pl.Series([True, True, True, False]))


def test_is_unique():
    """``is_unique`` marks the values that occur exactly once."""
    unique_mask = pl.Series("a", [1, 2, 2, 3]).is_unique()
    want = pl.Series([True, False, False, True])
    assert unique_mask.series_equal(want)


def test_is_duplicated():
    """``is_duplicated`` marks the values that occur more than once."""
    duplicated_mask = pl.Series("a", [1, 2, 2, 3]).is_duplicated()
    want = pl.Series([False, True, True, False])
    assert duplicated_mask.series_equal(want)


def test_dot():
    """``dot`` computes the dot product of an integer and a float series."""
    ints = pl.Series("a", [1, 2, 3])
    floats = pl.Series("b", [4.0, 5.0, 6.0])
    product = ints.dot(floats)
    # 1*4 + 2*5 + 3*6
    assert product == 32


def test_sample():
    """``sample`` returns the requested number of items, by count or by fraction."""
    series = pl.Series("a", [1, 2, 3, 4, 5])

    by_count = series.sample(n=2)
    assert len(by_count) == 2
    by_fraction = series.sample(frac=0.4)
    assert len(by_fraction) == 2

    with_repeats = series.sample(n=2, with_replacement=True)
    assert len(with_repeats) == 2

    # Without replacement, a series of length 5 cannot yield 10 items ...
    with pytest.raises(Exception):
        series.sample(n=10, with_replacement=False)
    # ... whereas sampling with replacement can.
    oversampled = series.sample(n=10, with_replacement=True)
    assert len(oversampled) == 10


def test_peak_max_peak_min():
    """``peak_min`` / ``peak_max`` flag the local minima and maxima of the series."""
    series = pl.Series("a", [4, 1, 3, 2, 5])

    minima = series.peak_min()
    assert minima.series_equal(pl.Series([False, True, False, True, False]))

    maxima = series.peak_max()
    assert maxima.series_equal(pl.Series([True, False, True, False, True]))


def test_shrink_to_fit():
    """``shrink_to_fit`` returns ``None`` when in place and a Series otherwise."""
    values = [4, 1, 3, 2, 5]

    in_place_result = pl.Series("a", values).shrink_to_fit(in_place=True)
    assert in_place_result is None

    shrunk_copy = pl.Series("a", values).shrink_to_fit(in_place=False)
    assert isinstance(shrunk_copy, pl.Series)


def test_str_concat():
    """``str_concat`` joins all values, nulls included, into a single string."""
    joined = pl.Series(["1", None, "2"]).str_concat()
    first = joined[0]
    assert first == "1-null-2"


def test_str_lengths():
    """``str.lengths`` returns the length of each string and keeps nulls as null."""
    names = pl.Series(["messi", "ronaldo", None])
    lengths = names.str.lengths().to_list()
    assert lengths == [5, 7, None]


def test_str_contains():
    """``str.contains`` flags the strings that contain the given pattern."""
    names = pl.Series(["messi", "ronaldo", "ibrahimovic"])
    matches = names.str.contains("mes").to_list()
    assert matches == [True, False, False]


def test_str_replace_str_replace_all():
    """``str.replace`` swaps the first match per string; ``str.replace_all`` every match."""
    s = pl.Series(["hello", "world", "test"])
    assert s.str.replace("o", "0").to_list() == ["hell0", "w0rld", "test"]

    s = pl.Series(["hello", "world", "test"])
    assert s.str.replace_all("o", "0").to_list() == ["hell0", "w0rld", "test"]

    # The inputs above contain at most one "o" each, so they cannot tell the
    # two methods apart. Strings with repeated matches pin down the difference.
    s = pl.Series(["foo", "oboe"])
    assert s.str.replace("o", "0").to_list() == ["f0o", "0boe"]
    assert s.str.replace_all("o", "0").to_list() == ["f00", "0b0e"]


def test_str_to_lowercase():
    """``str.to_lowercase`` lower-cases every string in the series."""
    words = pl.Series(["Hello", "WORLD"])
    lowered = words.str.to_lowercase().to_list()
    assert lowered == ["hello", "world"]


def test_str_to_uppercase():
    """``str.to_uppercase`` upper-cases every string in the series."""
    words = pl.Series(["Hello", "WORLD"])
    raised = words.str.to_uppercase().to_list()
    assert raised == ["HELLO", "WORLD"]


def test_str_rstrip():
    """``str.rstrip`` removes trailing whitespace and leaves leading whitespace alone."""
    padded = pl.Series([" hello ", "world\t "])
    stripped = padded.str.rstrip().to_list()
    assert stripped == [" hello", "world"]


def test_str_lstrip():
    """``str.lstrip`` removes leading whitespace and leaves trailing whitespace alone."""
    padded = pl.Series([" hello ", "\t world"])
    stripped = padded.str.lstrip().to_list()
    assert stripped == ["hello ", "world"]


def test_dt_strftime():
    """``dt.strftime("%F")`` renders a Date value as a ``YYYY-MM-DD`` string."""
    dates = pl.Series("a", [10000, 20000, 30000], dtype=pl.Date)
    assert dates.dtype == pl.Date
    formatted = dates.dt.strftime("%F")
    third = formatted[2]
    assert third == "2052-02-20"


def test_dt_year_month_week_day_ordinal_day():
    """Check the year, month, weekday, day and ordinal-day accessors of a Date series."""
    dates = pl.Series("a", [10000, 20000, 30000], dtype=pl.Date)
    # (accessor name, expected values), checked in the listed order.
    expectations = (
        ("year", [1997, 2024, 2052]),
        ("month", [5, 10, 2]),
        ("weekday", [0, 4, 1]),
        ("day", [19, 4, 20]),
        ("ordinal_day", [139, 278, 51]),
    )
    for accessor, want in expectations:
        got = getattr(dates.dt, accessor)().to_list()
        assert got == want

0 comments on commit 12ff82f

Please sign in to comment.