Skip to content

Commit

Permalink
fix agg_std/agg_var for Float32; closes #1446
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 26, 2021
1 parent cd8cca9 commit bed84db
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
20 changes: 18 additions & 2 deletions polars/polars-core/src/series/implementations/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,27 @@ macro_rules! impl_dyn_series {
}

fn agg_std(&self, groups: &[(u32, Vec<u32>)]) -> Option<Series> {
self.0.agg_std(groups)
match self.0.dtype() {
DataType::Float32 => self
.0
.cast::<Float64Type>()
.unwrap()
.agg_std(groups)
.map(|s| s.cast::<Float32Type>().unwrap()),
_ => self.0.agg_std(groups),
}
}

fn agg_var(&self, groups: &[(u32, Vec<u32>)]) -> Option<Series> {
self.0.agg_var(groups)
match self.0.dtype() {
DataType::Float32 => self
.0
.cast::<Float64Type>()
.unwrap()
.agg_var(groups)
.map(|s| s.cast::<Float32Type>().unwrap()),
_ => self.0.agg_var(groups),
}
}

fn agg_n_unique(&self, groups: &[(u32, Vec<u32>)]) -> Option<UInt32Chunked> {
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/test_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np

import polars as pl


def test_std():
for dtype in [pl.Float32, pl.Float64]:
df = pl.DataFrame(
[
pl.Series("groups", ["a", "a", "b", "b"]),
pl.Series("values", [1.0, 2.0, 3.0, 4.0], dtype=dtype),
]
)

out = df.select(pl.col("values").std().over("groups"))
assert np.isclose(out["values"][0], 0.7071067690849304)

out = df.select(pl.col("values").var().over("groups"))
assert np.isclose(out["values"][0], 0.5)

0 comments on commit bed84db

Please sign in to comment.