Skip to content

Commit

Permalink
fix boolean/numeric supertypes (#4252)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 4, 2022
1 parent 912b6d5 commit c4bacff
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
7 changes: 7 additions & 0 deletions polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,13 @@ fn _get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {

(UInt32, UInt64) => Some(UInt64),

#[cfg(feature = "dtype-u8")]
(Boolean, UInt8) => Some(UInt8),
#[cfg(feature = "dtype-u16")]
(Boolean, UInt16) => Some(UInt16),
(Boolean, UInt32) => Some(UInt32),
(Boolean, UInt64) => Some(UInt64),

#[cfg(feature = "dtype-u8")]
(Float32, UInt8) => Some(Float32),
#[cfg(feature = "dtype-u16")]
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,21 @@ def test_pow_dtype() -> None:
]
)
assert df.collect().dtypes == [pl.UInt32, pl.UInt32, pl.Float64]


def test_bool_numeric_supertype() -> None:
df = pl.DataFrame({"v": [1, 2, 3, 4, 5, 6]})
for dt in [
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
]:
assert (
df.select([(pl.col("v") < 3).sum().cast(dt) / pl.count()])[0, 0] - 0.3333333
<= 0.00001
)

0 comments on commit c4bacff

Please sign in to comment.