Skip to content

Commit

Permalink
Fix invalid inputs for trigonometric functions (#4164)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Jul 27, 2022
1 parent ad15e93 commit 78ec23b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
9 changes: 8 additions & 1 deletion polars/polars-lazy/src/dsl/function_expr/trigonometry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@ pub(super) fn apply_trigonometric_function(
let ca = s.f64().unwrap();
apply_trigonometric_function_to_float(ca, trig_function)
}
_ => {
dt if dt.is_numeric() => {
let s = s.cast(&DataType::Float64)?;
apply_trigonometric_function(&s, trig_function)
}
dt => Err(PolarsError::ComputeError(
format!(
"cannot use trigonometric function on Series of dtype: {:?}",
dt
)
.into(),
)),
}
}

Expand Down
27 changes: 14 additions & 13 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import math
from datetime import date, datetime
from typing import Any
from unittest.mock import patch
Expand Down Expand Up @@ -1127,18 +1128,6 @@ def test_comparisons_bool_series_to_int() -> None:
srs_bool > 2 # noqa: B015


def test_trigonometry_functions() -> None:
srs_float = pl.Series("t", [0.0, np.pi])
assert np.allclose(srs_float.sin(), np.array([0.0, 0.0]))
assert np.allclose(srs_float.cos(), np.array([1.0, -1.0]))
assert np.allclose(srs_float.tan(), np.array([0.0, -0.0]))

srs_float = pl.Series("t", [1.0, 0.0, -1])
assert np.allclose(srs_float.arcsin(), np.array([1.571, 0.0, -1.571]), atol=0.01)
assert np.allclose(srs_float.arccos(), np.array([0.0, 1.571, 3.142]), atol=0.01)
assert np.allclose(srs_float.arctan(), np.array([0.785, 0.0, -0.785]), atol=0.01)


def test_abs() -> None:
# ints
s = pl.Series([1, -2, 3, -4])
Expand Down Expand Up @@ -1530,11 +1519,23 @@ def test_is_between_datetime() -> None:
)
@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
def test_trigonometric(f: str) -> None:
s = pl.Series("a", [0.0])
s = pl.Series("a", [0.0, math.pi])
expected = pl.Series("a", getattr(np, f)(s.to_numpy()))
verify_series_and_expr_api(s, expected, f)


def test_trigonometric_invalid_input() -> None:
# String
s = pl.Series("a", ["1", "2", "3"])
with pytest.raises(pl.ComputeError):
s.sin()

# Date
s = pl.Series("a", [date(1990, 2, 28), date(2022, 7, 26)])
with pytest.raises(pl.ComputeError):
s.cosh()


def test_ewm_mean() -> None:
a = pl.Series("a", [2, 5, 3])
expected = pl.Series(
Expand Down

0 comments on commit 78ec23b

Please sign in to comment.