Skip to content

Commit

Permalink
feat(rust, python): repeat_by should also support broadcasting of LHS (
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Aug 26, 2023
1 parent 3883e7a commit 2b200f0
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 61 deletions.
136 changes: 84 additions & 52 deletions crates/polars-core/src/chunked_array/ops/repeat_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ type LargeListArray = ListArray<i64>;

fn check_lengths(length_srs: usize, length_by: usize) -> PolarsResult<()> {
polars_ensure!(
(length_srs == length_by) | (length_by == 1),
ComputeError: "Length of repeat_by argument needs to be 1 or equal to the length of the Series. Series length {}, by length {}",
(length_srs == length_by) | (length_by == 1) | (length_srs == 1),
ComputeError: "repeat_by argument and the Series should have equal length, or at least one of them should have length 1. Series length {}, by length {}",
length_srs, length_by
);
Ok(())
Expand All @@ -22,95 +22,127 @@ where
fn repeat_by(&self, by: &IdxCa) -> PolarsResult<ListChunked> {
check_lengths(self.len(), by.len())?;

if (self.len() != by.len()) & (by.len() == 1) {
return self.repeat_by(&IdxCa::new(
match (self.len(), by.len()) {
(left_len, right_len) if left_len == right_len => {
Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v.copied()).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe {
LargeListArray::from_iter_primitive_trusted_len(
iter,
T::get_dtype().to_arrow(),
)
}
}))
},
(_, 1) => self.repeat_by(&IdxCa::new(
self.name(),
std::iter::repeat(by.get(0).unwrap())
.take(self.len())
.collect::<Vec<IdxSize>>(),
));
)),
(1, _) => {
let new_array = self.new_from_index(0, by.len());
new_array.repeat_by(by)
},
// we have already checked the length
_ => unreachable!(),
}

Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v.copied()).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe {
LargeListArray::from_iter_primitive_trusted_len(iter, T::get_dtype().to_arrow())
}
}))
}
}

impl RepeatBy for BooleanChunked {
fn repeat_by(&self, by: &IdxCa) -> PolarsResult<ListChunked> {
check_lengths(self.len(), by.len())?;

if (self.len() != by.len()) & (by.len() == 1) {
return self.repeat_by(&IdxCa::new(
match (self.len(), by.len()) {
(left_len, right_len) if left_len == right_len => {
Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe { LargeListArray::from_iter_bool_trusted_len(iter) }
}))
},
(_, 1) => self.repeat_by(&IdxCa::new(
self.name(),
std::iter::repeat(by.get(0).unwrap())
.take(self.len())
.collect::<Vec<IdxSize>>(),
));
)),
(1, _) => {
let new_array = self.new_from_index(0, by.len());
new_array.repeat_by(by)
},
// we have already checked the length
_ => unreachable!(),
}

Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe { LargeListArray::from_iter_bool_trusted_len(iter) }
}))
}
}
impl RepeatBy for Utf8Chunked {
fn repeat_by(&self, by: &IdxCa) -> PolarsResult<ListChunked> {
// TODO! dispatch via binary.
check_lengths(self.len(), by.len())?;

if (self.len() != by.len()) & (by.len() == 1) {
return self.repeat_by(&IdxCa::new(
match (self.len(), by.len()) {
(left_len, right_len) if left_len == right_len => {
Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe { LargeListArray::from_iter_utf8_trusted_len(iter, self.len()) }
}))
},
(_, 1) => self.repeat_by(&IdxCa::new(
self.name(),
std::iter::repeat(by.get(0).unwrap())
.take(self.len())
.collect::<Vec<IdxSize>>(),
));
)),
(1, _) => {
let new_array = self.new_from_index(0, by.len());
new_array.repeat_by(by)
},
// we have already checked the length
_ => unreachable!(),
}

Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe { LargeListArray::from_iter_utf8_trusted_len(iter, self.len()) }
}))
}
}

impl RepeatBy for BinaryChunked {
fn repeat_by(&self, by: &IdxCa) -> PolarsResult<ListChunked> {
check_lengths(self.len(), by.len())?;

if (self.len() != by.len()) & (by.len() == 1) {
return self.repeat_by(&IdxCa::new(
match (self.len(), by.len()) {
(left_len, right_len) if left_len == right_len => {
Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe { LargeListArray::from_iter_binary_trusted_len(iter, self.len()) }
}))
},
(_, 1) => self.repeat_by(&IdxCa::new(
self.name(),
std::iter::repeat(by.get(0).unwrap())
.take(self.len())
.collect::<Vec<IdxSize>>(),
));
)),
(1, _) => {
let new_array = self.new_from_index(0, by.len());
new_array.repeat_by(by)
},
// we have already checked the length
_ => unreachable!(),
}

Ok(arity::binary(self, by, |arr, by| {
let iter = arr.into_iter().zip(by).map(|(opt_v, opt_by)| {
opt_by.map(|by| std::iter::repeat(opt_v).take(*by as usize))
});

// SAFETY: length of iter is trusted.
unsafe { LargeListArray::from_iter_binary_trusted_len(iter, self.len()) }
}))
}
}
57 changes: 48 additions & 9 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,20 +1692,47 @@ def test_repeat_by_unequal_lengths_panic() -> None:
)
with pytest.raises(
pl.ComputeError,
match="""Length of repeat_by argument needs to be 1 or equal to the length of the Series.""",
match="repeat_by argument and the Series should have equal length, "
"or at least one of them should have length 1",
):
df.select(pl.col("a").repeat_by(pl.Series([2, 2])))


@pytest.mark.parametrize(
("value", "values_expect"),
[
(1.2, [[1.2], [1.2, 1.2], [1.2, 1.2, 1.2]]),
(True, [[True], [True, True], [True, True, True]]),
("x", [["x"], ["x", "x"], ["x", "x", "x"]]),
(b"a", [[b"a"], [b"a", b"a"], [b"a", b"a", b"a"]]),
],
)
def test_repeat_by_broadcast_left(
value: float | bool | str, values_expect: list[list[float | bool | str]]
) -> None:
df = pl.DataFrame(
{
"n": [1, 2, 3],
}
)
expected = pl.DataFrame({"values": values_expect})
result = df.select(pl.lit(value).repeat_by(pl.col("n")).alias("values"))
assert_frame_equal(result, expected)


@pytest.mark.parametrize(
("a", "a_expected"),
[
([1.2, 2.2, 3.3], [[1.2, 1.2, 1.2], [2.2, 2.2, 2.2], [3.3, 3.3, 3.3]]),
([True, False], [[True, True, True], [False, False, False]]),
(["x", "y", "z"], [["x", "x", "x"], ["y", "y", "y"], ["z", "z", "z"]]),
(
[b"a", b"b", b"c"],
[[b"a", b"a", b"a"], [b"b", b"b", b"b"], [b"c", b"c", b"c"]],
),
],
)
def test_repeat_by_parameterized(
def test_repeat_by_broadcast_right(
a: list[float | bool | str], a_expected: list[list[float | bool | str]]
) -> None:
df = pl.DataFrame(
Expand All @@ -1720,13 +1747,25 @@ def test_repeat_by_parameterized(
assert_frame_equal(result, expected)


def test_repeat_by() -> None:
df = pl.DataFrame({"name": ["foo", "bar"], "n": [2, 3]})
out = df.select(pl.col("n").repeat_by("n"))
s = out["n"]

assert s[0].to_list() == [2, 2]
assert s[1].to_list() == [3, 3, 3]
@pytest.mark.parametrize(
("a", "a_expected"),
[
(["foo", "bar"], [["foo", "foo"], ["bar", "bar", "bar"]]),
([1, 2], [[1, 1], [2, 2, 2]]),
([True, False], [[True, True], [False, False, False]]),
(
[b"a", b"b"],
[[b"a", b"a"], [b"b", b"b", b"b"]],
),
],
)
def test_repeat_by(
a: list[float | bool | str], a_expected: list[list[float | bool | str]]
) -> None:
df = pl.DataFrame({"a": a, "n": [2, 3]})
expected = pl.DataFrame({"a": a_expected})
result = df.select(pl.col("a").repeat_by("n"))
assert_frame_equal(result, expected)


def test_join_dates() -> None:
Expand Down

0 comments on commit 2b200f0

Please sign in to comment.