Skip to content

Commit

Permalink
handle numpy integers as return values (#3063)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 5, 2022
1 parent bb3c2f4 commit 3780793
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 11 deletions.
2 changes: 1 addition & 1 deletion py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 13 additions & 10 deletions py-polars/src/apply/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,6 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>(
Some(first_value),
)
.map(|ca| ca.into_series().into())
} else if out.is_instance_of::<PyInt>().unwrap() {
let first_value = out.extract::<i64>().unwrap();
applyer
.apply_lambda_with_primitive_out_type::<Int64Type>(
py,
lambda,
null_count,
Some(first_value),
)
.map(|ca| ca.into_series().into())
} else if out.is_instance_of::<PyString>().unwrap() {
let first_value = out.extract::<&str>().unwrap();
applyer
Expand Down Expand Up @@ -80,6 +70,19 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>(
} else if out.is_instance_of::<PyTuple>().unwrap() {
let first = out.extract::<Wrap<AnyValue<'_>>>()?;
applyer.apply_to_struct(py, lambda, null_count, first.0)
}
// this succeeds for numpy ints as well, where checking if it is pyint fails
// we do this later in the chain so that we don't extract integers from string chars.
else if out.extract::<i64>().is_ok() {
let first_value = out.extract::<i64>().unwrap();
applyer
.apply_lambda_with_primitive_out_type::<Int64Type>(
py,
lambda,
null_count,
Some(first_value),
)
.map(|ca| ca.into_series().into())
} else {
applyer
.apply_lambda_with_object_out_type(
Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,19 @@ def test_apply_numpy_out_3057() -> None:
)
.frame_equal(pl.DataFrame({"id": [0, 1], "result": [1.955, 13.0]}))
)


def test_apply_numpy_int_out() -> None:
df = pl.DataFrame({"col1": [2, 4, 8, 16]})
assert df.with_column(
pl.col("col1").apply(lambda x: np.left_shift(x, 8)).alias("result")
).frame_equal(
pl.DataFrame({"col1": [2, 4, 8, 16], "result": [512, 1024, 2048, 4096]})
)
df = pl.DataFrame({"col1": [2, 4, 8, 16], "shift": [1, 1, 2, 2]})

assert df.select(
pl.struct(["col1", "shift"])
.apply(lambda cols: np.left_shift(cols["col1"], cols["shift"]))
.alias("result")
).frame_equal(pl.DataFrame({"result": [4, 8, 32, 64]}))

0 comments on commit 3780793

Please sign in to comment.