Skip to content

Commit

Permalink
fix(rust, python): fix panic in hmean (#5808)
Browse files: browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 14, 2022
1 parent c0c3a08 commit 58ce98d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
18 changes: 15 additions & 3 deletions polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2917,10 +2917,22 @@ impl DataFrame {
0 => Ok(None),
1 => Ok(Some(self.columns[0].clone())),
_ => {
let sum = || self.hsum(none_strategy);
let columns = self
.columns
.iter()
.cloned()
.filter(|s| {
let dtype = s.dtype();
dtype.is_numeric() || matches!(dtype, DataType::Boolean)
})
.collect();
let numeric_df = DataFrame::new_no_checks(columns);

let sum = || numeric_df.hsum(none_strategy);

let null_count = || {
self.columns
numeric_df
.columns
.par_iter()
.map(|s| s.is_null().cast(&DataType::UInt32).unwrap())
.reduce_with(|l, r| &l + &r)
Expand All @@ -2934,7 +2946,7 @@ impl DataFrame {

// value lengths: len - null_count
let value_length: UInt32Chunked =
(self.width().sub(&null_count)).u32().unwrap().clone();
(numeric_df.width().sub(&null_count)).u32().unwrap().clone();

// make sure that we do not divide by zero
// by replacing with None
Expand Down
15 changes: 3 additions & 12 deletions py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5962,22 +5962,13 @@ def mean(
╞═════╪═════╪══════╪══════╡
│ 2.0 ┆ 7.0 ┆ null ┆ 0.5 │
└─────┴─────┴──────┴──────┘
Note: a PanicException is raised with axis = 1 and a string column.
>>> df = pl.DataFrame(
... {
... "foo": [1, 2, 3],
... "bar": [6, 7, 8],
... }
... )
>>> df.mean(axis=1)
shape: (3,)
Series: 'foo' [f64]
[
3.5
4.5
5.5
2.666667
3.0
5.5
]
"""
Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/unit/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,9 @@ def test_duration_aggs() -> None:
"literal": [1],
"time_difference": [timedelta(days=31)],
}


def test_hmean_with_str_column() -> None:
    # Row-wise mean over a frame that mixes integer, boolean and string columns.
    # The expected values match averaging only the "int" and "bool" columns,
    # with the null in "bool" left out of the last row:
    # (1 + 1) / 2, (2 + 1) / 2 and 3 / 1.
    frame = pl.DataFrame(
        {"int": [1, 2, 3], "bool": [True, True, None], "str": ["a", "b", "c"]}
    )
    row_means = frame.mean(axis=1).to_list()
    assert row_means == [1.0, 1.5, 3.0]

0 comments on commit 58ce98d

Please sign in to comment.