Skip to content

Commit

Permalink
feat(python): r-associative support for commutative DataFrame opera…
Browse files Browse the repository at this point in the history
…tors (#5394)
  • Loading branch information
alexander-beedie committed Nov 2, 2022
1 parent 445c550 commit 446050b
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 20 deletions.
13 changes: 11 additions & 2 deletions py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,9 @@ def __mul__(self: DF, other: DataFrame | pli.Series | int | float | bool) -> DF:
other = _prepare_other_arg(other)
return self._from_pydf(self._df.mul(other._s))

def __rmul__(self: DF, other: DataFrame | pli.Series | int | float | bool) -> DF:
return self * other

def __truediv__(self: DF, other: DataFrame | pli.Series | int | float | bool) -> DF:
if isinstance(other, DataFrame):
return self._from_pydf(self._df.div_df(other._df))
Expand All @@ -1079,14 +1082,20 @@ def __truediv__(self: DF, other: DataFrame | pli.Series | int | float | bool) ->
return self._from_pydf(self._df.div(other._s))

def __add__(
self: DF,
other: DataFrame | pli.Series | int | float | bool | str,
self: DF, other: DataFrame | pli.Series | int | float | bool | str
) -> DF:
if isinstance(other, DataFrame):
return self._from_pydf(self._df.add_df(other._df))
other = _prepare_other_arg(other)
return self._from_pydf(self._df.add(other._s))

def __radd__(
self: DF, other: DataFrame | pli.Series | int | float | bool | str
) -> DF:
if isinstance(other, str):
return self.select((pli.lit(other) + pli.col("*")).keep_name())
return self + other

def __sub__(self: DF, other: DataFrame | pli.Series | int | float | bool) -> DF:
if isinstance(other, DataFrame):
return self._from_pydf(self._df.sub_df(other._df))
Expand Down
36 changes: 29 additions & 7 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,19 @@ def _arithmetic(self, other: Any, op_s: str, op_ffi: str) -> Series:
)
return wrap_s(f(other))

@overload
def __add__(self, other: pli.DataFrame) -> pli.DataFrame: # type: ignore[misc]
...

@overload
def __add__(self, other: Any) -> Series:
...

def __add__(self, other: Any) -> Series | pli.DataFrame:
if isinstance(other, str):
other = Series("", [other])
elif isinstance(other, pli.DataFrame):
return other + self
return self._arithmetic(other, "add", "add_<>")

def __sub__(self, other: Any) -> Series:
Expand All @@ -464,10 +474,26 @@ def __floordiv__(self, other: Any) -> Series:
result = result.floor()
return result

def __invert__(self) -> Series:
if self.dtype == Boolean:
return wrap_s(self._s._not())
return NotImplemented

@overload
def __mul__(self, other: pli.DataFrame) -> pli.DataFrame: # type: ignore[misc]
...

@overload
def __mul__(self, other: Any) -> Series:
...

def __mul__(self, other: Any) -> Series | pli.DataFrame:
if self.is_datelike():
raise ValueError("first cast to integer before multiplying datelike dtypes")
return self._arithmetic(other, "mul", "mul_<>")
elif isinstance(other, pli.DataFrame):
return other * self
else:
return self._arithmetic(other, "mul", "mul_<>")

def __mod__(self, other: Any) -> Series:
if self.is_datelike():
Expand All @@ -484,16 +510,13 @@ def __rmod__(self, other: Any) -> Series:
return self._arithmetic(other, "rem", "rem_<>_rhs")

def __radd__(self, other: Any) -> Series:
if isinstance(other, str):
return (other + self.to_frame()).to_series()
return self._arithmetic(other, "add", "add_<>_rhs")

def __rsub__(self, other: Any) -> Series:
return self._arithmetic(other, "sub", "sub_<>_rhs")

def __invert__(self) -> Series:
if self.dtype == Boolean:
return wrap_s(self._s._not())
return NotImplemented

def __rtruediv__(self, other: Any) -> Series:
if self.is_datelike():
raise ValueError("first cast to integer before dividing datelike dtypes")
Expand All @@ -502,7 +525,6 @@ def __rtruediv__(self, other: Any) -> Series:

if isinstance(other, int):
other = float(other)

return self.cast(Float64).__rfloordiv__(other)

def __rfloordiv__(self, other: Any) -> Series:
Expand Down
22 changes: 13 additions & 9 deletions py-polars/tests/unit/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1786,18 +1786,18 @@ def test_shrink_to_fit() -> None:
def test_arithmetic() -> None:
df = pl.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]})

df_mul = df * 2
expected = pl.DataFrame({"a": [2, 4], "b": [6, 8]})
assert df_mul.frame_equal(expected)
for df_mul in (df * 2, 2 * df):
expected = pl.DataFrame({"a": [2, 4], "b": [6, 8]})
assert df_mul.frame_equal(expected)

for df_plus in (df + 2, 2 + df):
expected = pl.DataFrame({"a": [3, 4], "b": [5, 6]})
assert df_plus.frame_equal(expected)

df_div = df / 2
expected = pl.DataFrame({"a": [0.5, 1.0], "b": [1.5, 2.0]})
assert df_div.frame_equal(expected)

df_plus = df + 2
expected = pl.DataFrame({"a": [3, 4], "b": [5, 6]})
assert df_plus.frame_equal(expected)

df_minus = df - 2
expected = pl.DataFrame({"a": [-1, 0], "b": [1, 2]})
assert df_minus.frame_equal(expected)
Expand Down Expand Up @@ -1845,11 +1845,15 @@ def test_arithmetic() -> None:

def test_add_string() -> None:
df = pl.DataFrame({"a": ["hi", "there"], "b": ["hello", "world"]})
result = df + " hello"
expected = pl.DataFrame(
{"a": ["hi hello", "there hello"], "b": ["hello hello", "world hello"]}
)
assert result.frame_equal(expected)
assert (df + " hello").frame_equal(expected)

expected = pl.DataFrame(
{"a": ["hello hi", "hello there"], "b": ["hello hello", "hello world"]}
)
assert ("hello " + df).frame_equal(expected)


def test_get_item() -> None:
Expand Down
8 changes: 6 additions & 2 deletions py-polars/tests/unit/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,10 @@ def test_arithmetic(s: pl.Series) -> None:
assert ((a / 1) == [1.0, 2.0]).sum() == 2
assert ((a // 2) == [0, 1]).sum() == 2
assert ((a * 2) == [2, 4]).sum() == 2
assert ((1 + a) == [2, 3]).sum() == 2
assert ((2 + a) == [3, 4]).sum() == 2
assert ((1 - a) == [0, -1]).sum() == 2
assert ((1 * a) == [1, 2]).sum() == 2
assert ((2 * a) == [2, 4]).sum() == 2

# integer division
assert_series_equal(1 / a, pl.Series([1.0, 0.5]))
if s.dtype == Int64:
Expand Down Expand Up @@ -344,6 +345,9 @@ def test_add_string() -> None:
result = s + " world"
assert_series_equal(result, pl.Series(["hello world", "weird world"]))

result = "pfx:" + s
assert_series_equal(result, pl.Series(["pfx:hello", "pfx:weird"]))


def test_append_extend() -> None:
a = pl.Series("a", [1, 2])
Expand Down

0 comments on commit 446050b

Please sign in to comment.