-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
test_random.py
125 lines (91 loc) · 4.04 KB
/
test_random.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from __future__ import annotations
import pytest
import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal
def test_shuffle_group_by_reseed() -> None:
    """Shuffling inside a group-by varies per group unless a seed is given."""

    def unique_shuffle_groups(n: int, seed: int | None) -> int:
        """Return the number of distinct orderings among ``n`` shuffled groups.

        Builds ``n`` groups that each hold the values 1, 2, 3, shuffles the
        values within every group using ``seed``, and counts how many
        different shuffled lists result.
        """
        # Values repeat as 1, 2, 3, 1, 2, 3, ...
        values = [1, 2, 3] * n
        # Group ids run 0, 0, 0, 1, 1, 1, ... (three rows per group).
        group_ids = [g for g in range(n) for _ in range(3)]
        frame = pl.DataFrame({"l": values, "group": group_ids})

        per_group = frame.group_by("group", maintain_order=True).agg(
            pl.col("l").shuffle(seed)
        )
        # Grouping on the shuffled list column merges identical orderings,
        # so the row count afterwards is the number of distinct orderings.
        distinct = per_group.group_by("l").agg(pl.lit(0))
        counted = distinct.select(pl.count())
        return int(counted[0, 0])

    unseeded = unique_shuffle_groups(50, None)
    assert unseeded > 1  # Astronomically unlikely.

    seeded = unique_shuffle_groups(50, 0xDEADBEEF)
    assert seeded == 1  # Fixed seed should be always the same.
def test_sample_expr() -> None:
    """Sample a literal series by fraction, by count, and via the global seed."""
    a = pl.Series("a", range(20))

    # Half of the 20 values, drawn without replacement.
    fraction_expr = pl.lit(a).sample(fraction=0.5, with_replacement=False, seed=1)
    by_fraction = pl.select(fraction_expr).to_series()

    assert by_fraction.shape == (10,)
    # The sample must not come back in sorted order.
    assert by_fraction.to_list() != by_fraction.sort().to_list()
    assert by_fraction.unique().shape == (10,)
    assert set(by_fraction) <= set(a)

    # Ten values requested explicitly, drawn without replacement.
    count_expr = pl.lit(a).sample(n=10, with_replacement=False, seed=1)
    by_count = pl.select(count_expr).to_series()

    assert by_count.shape == (10,)
    assert by_count.to_list() != by_count.sort().to_list()
    assert by_count.unique().shape == (10,)

    # pl.set_random_seed should lead to reproducible results.
    draws = []
    for _ in range(2):
        pl.set_random_seed(1)
        draws.append(pl.select(pl.lit(a).sample(n=10)).to_series())

    first_draw, second_draw = draws
    assert_series_equal(first_draw, second_draw)
def test_sample_df() -> None:
    """Check the shapes of DataFrame samples taken by count and by fraction."""
    data = {"foo": [1, 2, 3], "bar": [6, 7, 8], "ham": ["a", "b", "c"]}
    df = pl.DataFrame(data)

    by_count = df.sample(n=2, seed=0)
    assert by_count.shape == (2, 3)

    by_fraction = df.sample(fraction=0.4, seed=0)
    assert by_fraction.shape == (1, 3)
def test_sample_empty_df() -> None:
    """Sample from an empty DataFrame with and without replacement."""
    empty = pl.DataFrame({"foo": []})
    empty_shape = (0, 1)

    # With replacement, both sampling modes are expected to give an empty frame.
    with_repl_n = empty.sample(n=3, with_replacement=True)
    assert with_repl_n.shape == empty_shape
    with_repl_frac = empty.sample(fraction=0.4, with_replacement=True)
    assert with_repl_frac.shape == empty_shape

    # Without replacement, sampling by count is expected to raise a shape
    # error, while sampling by fraction still gives an empty frame.
    with pytest.raises(pl.ShapeError):
        empty.sample(n=3, with_replacement=False)

    without_repl_frac = empty.sample(fraction=0.4, with_replacement=False)
    assert without_repl_frac.shape == empty_shape
def test_sample_series() -> None:
    """Check Series sample lengths and the oversampling error."""
    s = pl.Series("a", [1, 2, 3, 4, 5])

    by_count = s.sample(n=2, seed=0)
    assert len(by_count) == 2

    by_fraction = s.sample(fraction=0.4, seed=0)
    assert len(by_fraction) == 2

    by_count_repl = s.sample(n=2, with_replacement=True, seed=0)
    assert len(by_count_repl) == 2

    # on a series of length 5, you cannot sample more than 5 items
    with pytest.raises(pl.ShapeError):
        s.sample(n=10, with_replacement=False, seed=0)

    # unless you use with_replacement=True
    oversampled = s.sample(n=10, with_replacement=True, seed=0)
    assert len(oversampled) == 10
def test_rank_random_expr() -> None:
    """A seeded random rank over a window gives the same frame when repeated."""
    df = pl.from_dict(
        {
            "a": [1] * 5,
            "b": [1, 2, 3, 4, 5],
            "c": [200, 100, 100, 50, 100],
        }
    )

    def with_seeded_rank() -> pl.DataFrame:
        """Attach a ``rank`` column built from a freshly constructed expression."""
        rank_expr = pl.col("c").rank(method="random", seed=1).over("a")
        return df.with_columns(rank_expr.alias("rank"))

    # Evaluate the same seeded ranking twice; the results must match.
    first = with_seeded_rank()
    second = with_seeded_rank()

    assert_frame_equal(first, second)
def test_rank_random_series() -> None:
    """A random rank with seed 1 matches the pinned UInt32 ranking."""
    s = pl.Series("a", [1, 2, 3, 2, 2, 3, 0])

    ranked = s.rank("random", seed=1)
    expected = pl.Series("a", [2, 4, 7, 3, 5, 6, 1], dtype=pl.UInt32)

    assert_series_equal(ranked, expected)
def test_shuffle_expr() -> None:
    """Shuffling after resetting the global seed gives the same result twice."""
    # pl.set_random_seed should lead to reproducible results.
    s = pl.Series("a", range(20))

    shuffles = []
    for _ in range(2):
        # Reset the global seed before every shuffle so both runs start equal.
        pl.set_random_seed(1)
        shuffles.append(pl.select(pl.lit(s).shuffle()).to_series())

    first, second = shuffles
    assert_series_equal(first, second)
def test_shuffle_series() -> None:
    """Shuffling with seed 2 gives the pinned order via Series and expression."""
    a = pl.Series("a", [1, 2, 3])
    expected = pl.Series("a", [2, 1, 3])

    # Series API.
    via_series = a.shuffle(2)
    assert_series_equal(via_series, expected)

    # Expression API on the same data with the same seed.
    via_expr = pl.select(pl.lit(a).shuffle(2)).to_series()
    assert_series_equal(via_expr, expected)