Skip to content

Commit

Permalink
Native implementation of the sign function (#4147)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Jul 28, 2022
1 parent 8f07335 commit 57cc243
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 31 deletions.
1 change: 1 addition & 0 deletions polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ string_justify = ["polars-lazy/string_justify", "polars-ops/string_justify"]
arg_where = ["polars-lazy/arg_where"]
date_offset = ["polars-lazy/date_offset"]
trigonometry = ["polars-lazy/trigonometry"]
sign = ["polars-lazy/sign"]

test = [
"lazy",
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dtype-struct = ["polars-core/dtype-struct"]
object = ["polars-core/object"]
date_offset = []
trigonometry = []
sign = []

true_div = []

Expand Down
22 changes: 22 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ mod is_in;
mod pow;
#[cfg(feature = "row_hash")]
mod row_hash;
#[cfg(feature = "sign")]
mod sign;
#[cfg(feature = "strings")]
mod strings;
#[cfg(any(feature = "temporal", feature = "date_offset"))]
Expand Down Expand Up @@ -42,6 +44,8 @@ pub enum FunctionExpr {
DateOffset(Duration),
#[cfg(feature = "trigonometry")]
Trigonometry(TrigonometricFunction),
#[cfg(feature = "sign")]
Sign,
FillNull {
super_type: DataType,
},
Expand Down Expand Up @@ -105,6 +109,8 @@ impl FunctionExpr {
DateOffset(_) => same_type(),
#[cfg(feature = "trigonometry")]
Trigonometry(_) => float_dtype(),
#[cfg(feature = "sign")]
Sign => with_dtype(DataType::Int64),
FillNull { super_type, .. } => with_dtype(super_type.clone()),
}
}
Expand All @@ -129,6 +135,18 @@ macro_rules! map_as_slice {
}};
}

// Fn(&Series)
macro_rules! map_without_args {
($func:path) => {{
let f = move |s: &mut [Series]| {
let s = &s[0];
$func(s)
};

SpecialEq::new(Arc::new(f))
}};
}

// Fn(&Series, args)
macro_rules! map_with_args {
($func:path, $($args:expr),*) => {{
Expand Down Expand Up @@ -199,6 +217,10 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
Trigonometry(trig_function) => {
map_with_args!(trigonometry::apply_trigonometric_function, trig_function)
}
#[cfg(feature = "sign")]
Sign => {
map_without_args!(sign::sign)
}
FillNull { super_type } => {
map_as_slice!(fill_null::fill_null, &super_type)
}
Expand Down
42 changes: 42 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/sign.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use super::*;
use polars_core::export::num;
use DataType::*;

pub(super) fn sign(s: &Series) -> Result<Series> {
match s.dtype() {
Float32 => {
let ca = s.f32().unwrap();
sign_float(ca)
}
Float64 => {
let ca = s.f64().unwrap();
sign_float(ca)
}
dt if dt.is_numeric() => {
let s = s.cast(&Float64)?;
sign(&s)
}
dt => Err(PolarsError::ComputeError(
format!("cannot use 'sign' on Series of dtype: {:?}", dt).into(),
)),
}
}

fn sign_float<T>(ca: &ChunkedArray<T>) -> Result<Series>
where
T: PolarsFloatType,
T::Native: num::Float,
ChunkedArray<T>: IntoSeries,
{
ca.apply(signum_improved).into_series().cast(&Int64)
}

// Wrapper for the signum function that handles +/-0.0 inputs differently
// See discussion here: https://github.com/rust-lang/rust/issues/57543
fn signum_improved<F: num::Float>(v: F) -> F {
if v.is_zero() {
v
} else {
v.signum()
}
}
15 changes: 15 additions & 0 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,21 @@ impl Expr {
}
}

/// Compute the sign of the given expression
#[cfg(feature = "sign")]
pub fn sign(self) -> Self {
Expr::Function {
input: vec![self],
function: FunctionExpr::Sign,
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: false,
auto_explode: false,
fmt_str: "sign",
},
}
}

/// Filter a single column
/// Should be used in aggregation context. If you want to filter on a DataFrame level, use
/// [LazyFrame::filter](LazyFrame::filter)
Expand Down
1 change: 1 addition & 0 deletions polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@
//! - `argwhere` Get indices where condition holds.
//! - `date_offset` Add an offset to dates that take months and leap years into account.
//! - `trigonometry` Trigonometric functions.
//! - `sign` Compute the element-wise sign of a Series.
//! * `DataFrame` pretty printing
//! - `fmt` - Activate DataFrame formatting
//!
Expand Down
2 changes: 2 additions & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ ipc = ["polars/ipc"]
is_in = ["polars/is_in"]
json = ["polars/serde", "serde_json"]
trigonometry = ["polars/trigonometry"]
sign = ["polars/sign"]
asof_join = ["polars/asof_join"]
cross_join = ["polars/cross_join"]
pct_change = ["polars/pct_change"]
Expand All @@ -58,6 +59,7 @@ all = [
"json",
"repeat_by",
"trigonometry",
"sign",
"asof_join",
"cross_join",
"pct_change",
Expand Down
40 changes: 20 additions & 20 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4288,31 +4288,31 @@ def upper_bound(self) -> Expr:

def sign(self) -> Expr:
"""
Return an element-wise indication of the sign of a number.
Compute the element-wise indication of the sign.
Examples
--------
>>> df = pl.DataFrame({"foo": [-9, -8, 0, 4]})
>>> df.select(pl.col("foo").sign())
shape: (4, 1)
┌─────┐
│ foo │
│ --- │
│ i64 │
╞═════╡
│ -1 │
├╌╌╌╌╌┤
│ -1 │
├╌╌╌╌╌┤
│ 0 │
├╌╌╌╌╌┤
│ 1 │
└─────┘
>>> df = pl.DataFrame({"a": [-9.0, -0.0, 0.0, 4.0, None]})
>>> df.select(pl.col("a").sign())
shape: (5, 1)
┌──────┐
│ a │
│ --- │
│ i64 │
╞══════╡
│ -1 │
├╌╌╌╌╌╌┤
│ 0 │
├╌╌╌╌╌╌┤
│ 0 │
├╌╌╌╌╌╌┤
│ 1 │
├╌╌╌╌╌╌┤
│ null │
└──────┘
"""
if not _NUMPY_AVAILABLE:
raise ImportError("'numpy' is required for this functionality.")
return np.sign(self) # type: ignore[call-overload]
return wrap_expr(self._pyexpr.sign())

def sin(self) -> Expr:
"""
Expand Down
17 changes: 8 additions & 9 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2723,25 +2723,24 @@ def mode(self) -> Series:

def sign(self) -> Series:
"""
Return an element-wise indication of the sign of a number.
Compute the element-wise indication of the sign.
Examples
--------
>>> s = pl.Series("foo", [-9, -8, 0, 4])
>>> s.sign() #
shape: (4,)
Series: 'foo' [i64]
>>> s = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, None])
>>> s.sign()
shape: (5,)
Series: 'a' [i64]
[
-1
-1
0
0
1
null
]
"""
if not _NUMPY_AVAILABLE:
raise ImportError("'numpy' is required for this functionality.")
return np.sign(self) # type: ignore[return-value]
return self.to_frame().select(pli.col(self.name).sign()).to_series()

def sin(self) -> Series:
"""
Expand Down
5 changes: 5 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,11 @@ impl PyExpr {
self.clone().inner.arctanh().into()
}

#[cfg(feature = "sign")]
pub fn sign(&self) -> PyExpr {
self.clone().inner.sign().into()
}

pub fn is_duplicated(&self) -> PyExpr {
self.clone().inner.is_duplicated().into()
}
Expand Down
15 changes: 13 additions & 2 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1702,10 +1702,21 @@ def test_str_split() -> None:


def test_sign() -> None:
a = pl.Series("a", [10, -20, None])
expected = pl.Series("a", [1, -1, None])
# Integers
a = pl.Series("a", [-9, -0, 0, 4, None])
expected = pl.Series("a", [-1, 0, 0, 1, None])
verify_series_and_expr_api(a, expected, "sign")

# Floats
a = pl.Series("a", [-9.0, -0.0, 0.0, 4.0, None])
expected = pl.Series("a", [-1, 0, 0, 1, None])
verify_series_and_expr_api(a, expected, "sign")

# Invalid input
a = pl.Series("a", [date(1950, 2, 1), date(1970, 1, 1), date(2022, 12, 12), None])
with pytest.raises(pl.ComputeError):
a.sign()


def test_exp() -> None:
a = pl.Series("a", [0.1, 0.01, None])
Expand Down

0 comments on commit 57cc243

Please sign in to comment.