Skip to content

Commit

Permalink
python: improve arithmetic consistency (#3001)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 29, 2022
1 parent fe10d4a commit 1a2bf9b
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 26 deletions.
5 changes: 5 additions & 0 deletions py-polars/build.requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,8 @@ numpydoc==1.1.0

# Stub files
pandas-stubs==1.2.0.39


# pinned third rate deps
# to be removed later
click==8.0.4
18 changes: 14 additions & 4 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,19 @@ def __le__(self, other: Any) -> "Series":
def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> "Series":
if isinstance(other, Series):
return wrap_s(getattr(self._s, op_s)(other._s))
other = maybe_cast(other, self.dtype, self.time_unit)
f = get_ffi_func(op_ffi, self.dtype, self._s)
if isinstance(other, float) and not self.is_float():
_s = sequence_to_pyseries("", [other])
if "rhs" in op_ffi:
return wrap_s(getattr(_s, op_s)(self._s))
else:
return wrap_s(getattr(self._s, op_s)(_s))
else:
other = maybe_cast(other, self.dtype, self.time_unit)
f = get_ffi_func(op_ffi, self.dtype, self._s)
if f is None:
return NotImplemented
raise ValueError(
f"cannot do arithmetic with series of dtype: {self.dtype} and argument of type: {type(other)}"
)
return wrap_s(f(other))

def __add__(self, other: Any) -> "Series":
Expand All @@ -374,7 +383,8 @@ def __floordiv__(self, other: Any) -> "Series":
if self.is_datelike():
raise ValueError("first cast to integer before dividing datelike dtypes")
result = self._arithmetic(other, "div", "div_<>")
if self.is_float():
# todo! in place, saves allocation
if self.is_float() or isinstance(other, float):
result = result.floor()
return result

Expand Down
10 changes: 0 additions & 10 deletions py-polars/tests/db-benchmark/various.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# may contain many things that seemed to go wrong at scale

import os
import time

import numpy as np
Expand Down Expand Up @@ -53,12 +52,3 @@
)
assert computed[0, "min"] == minimum
assert computed[0, "max"] == maximum

# test home directory support
# https://github.com/pola-rs/polars/pull/2940
filename = "~/test.parquet"

df.to_parquet(filename)
df = pl.read_parquet(filename)

os.remove(filename)
7 changes: 7 additions & 0 deletions py-polars/tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,10 @@ def test_apply_infer_list() -> None:
}
)
assert df.select([pl.all().apply(lambda x: [x])]).dtypes == [pl.List] * 3


def test_apply_arithmetic_consistency() -> None:
df = pl.DataFrame({"A": ["a", "a"], "B": [2, 3]})
assert df.groupby("A").agg(pl.col("B").apply(lambda x: x + 1.0))["B"].to_list() == [
[3.0, 4.0]
]
25 changes: 13 additions & 12 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_arithmetic(s: pl.Series) -> None:
# negate
assert (-a == [-1, -2]).sum() == 2
# wrong dtypes in rhs operands
assert ((1.0 - a) == [0, -1]).sum() == 2
assert ((1.0 - a) == [0.0, -1.0]).sum() == 2
assert ((1.0 / a) == [1.0, 0.5]).sum() == 2
assert ((1.0 * a) == [1, 2]).sum() == 2
assert ((1.0 + a) == [2, 3]).sum() == 2
Expand Down Expand Up @@ -902,15 +902,14 @@ def test_from_sequences() -> None:

def test_comparisons_int_series_to_float() -> None:
srs_int = pl.Series([1, 2, 3, 4])
testing.assert_series_equal(srs_int - 1.0, pl.Series([0, 1, 2, 3]))
testing.assert_series_equal(srs_int + 1.0, pl.Series([2, 3, 4, 5]))
testing.assert_series_equal(srs_int * 2.0, pl.Series([2, 4, 6, 8]))
# todo: this is inconsistent
testing.assert_series_equal(srs_int - 1.0, pl.Series([0.0, 1.0, 2.0, 3.0]))
testing.assert_series_equal(srs_int + 1.0, pl.Series([2.0, 3.0, 4.0, 5.0]))
testing.assert_series_equal(srs_int * 2.0, pl.Series([2.0, 4.0, 6.0, 8.0]))
testing.assert_series_equal(srs_int / 2.0, pl.Series([0.5, 1.0, 1.5, 2.0]))
testing.assert_series_equal(srs_int % 2.0, pl.Series([1, 0, 1, 0]))
testing.assert_series_equal(4.0 % srs_int, pl.Series([0, 0, 1, 0]))
testing.assert_series_equal(srs_int % 2.0, pl.Series([1.0, 0.0, 1.0, 0.0]))
testing.assert_series_equal(4.0 % srs_int, pl.Series([0.0, 0.0, 1.0, 0.0]))

testing.assert_series_equal(srs_int // 2.0, pl.Series([0, 1, 1, 2]))
testing.assert_series_equal(srs_int // 2.0, pl.Series([0.0, 1.0, 1.0, 2.0]))
testing.assert_series_equal(srs_int < 3.0, pl.Series([True, True, False, False]))
testing.assert_series_equal(srs_int <= 3.0, pl.Series([True, True, True, False]))
testing.assert_series_equal(srs_int > 3.0, pl.Series([False, False, False, True]))
Expand Down Expand Up @@ -941,13 +940,15 @@ def test_comparisons_bool_series_to_int() -> None:
srs_bool = pl.Series([True, False])
# todo: do we want this to work?
testing.assert_series_equal(srs_bool / 1, pl.Series([True, False], dtype=Float64))
with pytest.raises(TypeError, match=r"\-: 'Series' and 'int'"):
match = r"cannot do arithmetic with series of dtype: <class 'polars.datatypes.Boolean'> and argument of type: <class 'bool'>"
with pytest.raises(ValueError, match=match):
srs_bool - 1
with pytest.raises(TypeError, match=r"\+: 'Series' and 'int'"):
with pytest.raises(ValueError, match=match):
srs_bool + 1
with pytest.raises(TypeError, match=r"\%: 'Series' and 'int'"):
match = r"cannot do arithmetic with series of dtype: <class 'polars.datatypes.Boolean'> and argument of type: <class 'bool'>"
with pytest.raises(ValueError, match=match):
srs_bool % 2
with pytest.raises(TypeError, match=r"\*: 'Series' and 'int'"):
with pytest.raises(ValueError, match=match):
srs_bool * 1
with pytest.raises(
TypeError, match=r"'<' not supported between instances of 'Series' and 'int'"
Expand Down

0 comments on commit 1a2bf9b

Please sign in to comment.