Skip to content

Commit

Permalink
feat(rust, python): shrink_type expression (#5351)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 27, 2022
1 parent c2a029d commit 0f207f0
Show file tree
Hide file tree
Showing 10 changed files with 156 additions and 18 deletions.
4 changes: 4 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod schema;
#[cfg(feature = "search_sorted")]
mod search_sorted;
mod shift_and_fill;
mod shrink_type;
#[cfg(feature = "sign")]
mod sign;
#[cfg(feature = "strings")]
Expand Down Expand Up @@ -105,6 +106,7 @@ pub enum FunctionExpr {
IsUnique,
IsDuplicated,
Coalesce,
ShrinkType,
}

impl Display for FunctionExpr {
Expand Down Expand Up @@ -157,6 +159,7 @@ impl Display for FunctionExpr {
IsUnique => "is_unique",
IsDuplicated => "is_duplicated",
Coalesce => "coalesce",
ShrinkType => "shrink_dtype",
};
write!(f, "{}", s)
}
Expand Down Expand Up @@ -327,6 +330,7 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
IsUnique => map!(dispatch::is_unique),
IsDuplicated => map!(dispatch::is_duplicated),
Coalesce => map_as_slice!(fill_null::coalesce),
ShrinkType => map_owned!(shrink_type::shrink),
}
}
}
Expand Down
23 changes: 23 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,29 @@ impl FunctionExpr {
TopK { .. } => same_type(),
Shift(..) | Reverse => same_type(),
IsNotNull | IsNull | Not | IsUnique | IsDuplicated => with_dtype(DataType::Boolean),
ShrinkType => {
// we return the smallest type this can return
// this might not be correct once the actual data
// comes in, but if we set the smallest datatype
// we have the least chance that the smaller dtypes
// get cast to larger types in type-coercion
// this will lead to an incorrect schema in polars
// but we because only the numeric types deviate in
// bit size this will likely not lead to issues
map_dtype(&|dt| {
if dt.is_numeric() {
if dt.is_float() {
DataType::Float32
} else if dt.is_unsigned() {
DataType::Int8
} else {
DataType::UInt8
}
} else {
dt.clone()
}
})
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use super::*;

pub(super) fn shrink(s: Series) -> PolarsResult<Series> {
if s.dtype().is_numeric() {
if s.dtype().is_float() {
s.cast(&DataType::Float32)
} else if s.dtype().is_unsigned() {
let max = s.max_as_series().get(0).extract::<u64>().unwrap();
if max <= u8::MAX as u64 {
s.cast(&DataType::UInt8)
} else if max <= u16::MAX as u64 {
s.cast(&DataType::UInt16)
} else if max <= u32::MAX as u64 {
s.cast(&DataType::UInt32)
} else {
Ok(s)
}
} else {
let min = s.min_as_series().get(0).extract::<i64>().unwrap();
let max = s.max_as_series().get(0).extract::<i64>().unwrap();

if min >= i8::MIN as i64 && max <= i8::MAX as i64 {
s.cast(&DataType::Int8)
} else if min >= i16::MIN as i64 && max <= i16::MAX as i64 {
s.cast(&DataType::Int16)
} else if min >= i32::MIN as i64 && max <= i32::MAX as i64 {
s.cast(&DataType::Int32)
} else {
Ok(s)
}
}
} else {
Ok(s)
}
}
25 changes: 7 additions & 18 deletions polars/polars-lazy/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2086,6 +2086,13 @@ impl Expr {
})
}

/// Shrink numeric columns to the minimal required datatype
/// needed to fit the extrema of this [`Series`].
/// This can be used to reduce memory pressure.
pub fn shrink_dtype(self) -> Self {
self.map_private(FunctionExpr::ShrinkType)
}

/// Check if all boolean values are `true`
pub fn all(self) -> Self {
self.apply(
Expand All @@ -2106,24 +2113,6 @@ impl Expr {
})
}

/// This is useful if an `apply` function needs a floating point type.
/// Because this cast is done on a `map` level, it will be faster.
pub fn to_float(self) -> Self {
self.map(
|s| match s.dtype() {
DataType::Float32 | DataType::Float64 => Ok(s),
_ => s.cast(&DataType::Float64),
},
GetOutput::map_dtype(|dt| {
if matches!(dt, DataType::Float32) {
DataType::Float32
} else {
DataType::Float64
}
}),
)
}

#[cfg(feature = "dtype-struct")]
#[cfg_attr(docsrs, doc(cfg(feature = "dtype-struct")))]
/// Count all unique values and create a struct mapping value to count
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ Manipulation/ selection
Expr.sample
Expr.shift
Expr.shift_and_fill
Expr.shrink_dtype
Expr.shuffle
Expr.slice
Expr.sort
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ Manipulation/ selection
Series.shift
Series.shift_and_fill
Series.shrink_to_fit
Series.shrink_dtype
Series.shuffle
Series.slice
Series.sort
Expand Down
37 changes: 37 additions & 0 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5993,6 +5993,43 @@ def list(self) -> Expr:
"""
return wrap_expr(self._pyexpr.list())

def shrink_dtype(self) -> Expr:
"""
Shrink numeric columns to the minimal required datatype.
Shrink to the dtype needed to fit the extrema of this [`Series`].
This can be used to reduce memory pressure.
Examples
--------
>>> pl.DataFrame(
... {
... "a": [1, 2, 3],
... "b": [1, 2, 2 << 32],
... "c": [-1, 2, 1 << 30],
... "d": [-112, 2, 112],
... "e": [-112, 2, 129],
... "f": ["a", "b", "c"],
... "g": [0.1, 1.32, 0.12],
... "h": [True, None, False],
... }
... ).select(pl.all().shrink_dtype())
shape: (3, 8)
┌─────┬────────────┬────────────┬──────┬──────┬─────┬──────┬───────┐
│ a ┆ b ┆ c ┆ d ┆ e ┆ f ┆ g ┆ h │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i8 ┆ i64 ┆ i32 ┆ i8 ┆ i16 ┆ str ┆ f32 ┆ bool │
╞═════╪════════════╪════════════╪══════╪══════╪═════╪══════╪═══════╡
│ 1 ┆ 1 ┆ -1 ┆ -112 ┆ -112 ┆ a ┆ 0.1 ┆ true │
├╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 2 ┆ 2 ┆ 2 ┆ 2 ┆ 2 ┆ b ┆ 1.32 ┆ null │
├╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3 ┆ 8589934592 ┆ 1073741824 ┆ 112 ┆ 129 ┆ c ┆ 0.12 ┆ false │
└─────┴────────────┴────────────┴──────┴──────┴─────┴──────┴───────┘
"""
return wrap_expr(self._pyexpr.shrink_dtype())

@property
def str(self) -> ExprStringNameSpace:
"""
Expand Down
8 changes: 8 additions & 0 deletions py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4640,6 +4640,14 @@ def new_from_index(self, index: int, length: int) -> pli.Series:
"""Create a new Series filled with values from the given index."""
return wrap_s(self._s.new_from_index(index, length))

def shrink_dtype(self) -> Series:
"""
Shrink numeric columns to the minimal required datatype.
Shrink to the dtype needed to fit the extrema of this [`Series`].
This can be used to reduce memory pressure.
"""

# Below are the namespaces defined. Do not move these up in the definition of
# Series, as it confuses mypy between the type annotation `str` and the
# namespace `str`
Expand Down
4 changes: 4 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,10 @@ impl PyExpr {
self.clone().inner.product().into()
}

pub fn shrink_dtype(&self) -> PyExpr {
self.inner.clone().shrink_dtype().into()
}

pub fn str_parse_date(&self, fmt: Option<String>, strict: bool, exact: bool) -> PyExpr {
self.inner
.clone()
Expand Down
36 changes: 36 additions & 0 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,39 @@ def test_fill_null_static_schema_4843() -> None:
df2 = df1.select([pl.col(pl.Int64).fill_null(0)])
df3 = df2.select(pl.col(pl.Int64))
assert df3.schema == {"a": pl.Int64, "b": pl.Int64}


def test_shrink_dtype() -> None:
out = pl.DataFrame(
{
"a": [1, 2, 3],
"b": [1, 2, 2 << 32],
"c": [-1, 2, 1 << 30],
"d": [-112, 2, 112],
"e": [-112, 2, 129],
"f": ["a", "b", "c"],
"g": [0.1, 1.32, 0.12],
"h": [True, None, False],
}
).select(pl.all().shrink_dtype())
assert out.dtypes == [
pl.Int8,
pl.Int64,
pl.Int32,
pl.Int8,
pl.Int16,
pl.Utf8,
pl.Float32,
pl.Boolean,
]

assert out.to_dict(False) == {
"a": [1, 2, 3],
"b": [1, 2, 8589934592],
"c": [-1, 2, 1073741824],
"d": [-112, 2, 112],
"e": [-112, 2, 129],
"f": ["a", "b", "c"],
"g": [0.10000000149011612, 1.3200000524520874, 0.11999999731779099],
"h": [True, None, False],
}

0 comments on commit 0f207f0

Please sign in to comment.