Skip to content

Commit

Permalink
cast (u)int8 (u)int16 -> i64 in sum/cumsum/cumprod operations
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 1, 2021
1 parent aef8aa9 commit b7123a4
Show file tree
Hide file tree
Showing 13 changed files with 57 additions and 14 deletions.
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ impl SeriesTrait for SeriesWrap<BooleanChunked> {
ChunkFillNull::fill_null(&self.0, strategy).map(|ca| ca.into_series())
}

fn sum_as_series(&self) -> Series {
fn _sum_as_series(&self) -> Series {
ChunkAggSeries::sum_as_series(&self.0)
}
fn max_as_series(&self) -> Series {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ impl SeriesTrait for SeriesWrap<CategoricalChunked> {
ChunkFillNull::fill_null(&self.0, strategy).map(|ca| ca.into_series())
}

fn sum_as_series(&self) -> Series {
fn _sum_as_series(&self) -> Series {
CategoricalChunked::full_null(self.name(), 1).into_series()
}
fn max_as_series(&self) -> Series {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ macro_rules! impl_dyn_series {
.map(|ca| ca.$into_logical().into_series())
}

fn sum_as_series(&self) -> Series {
fn _sum_as_series(&self) -> Series {
Int32Chunked::full_null(self.name(), 1)
.cast(self.dtype())
.unwrap()
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ macro_rules! impl_dyn_series {
ChunkFillNull::fill_null(&self.0, strategy).map(|ca| ca.into_series())
}

fn sum_as_series(&self) -> Series {
fn _sum_as_series(&self) -> Series {
ChunkAggSeries::sum_as_series(&self.0)
}
fn max_as_series(&self) -> Series {
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ impl SeriesTrait for SeriesWrap<ListChunked> {
ChunkFillNull::fill_null(&self.0, strategy).map(|ca| ca.into_series())
}

fn sum_as_series(&self) -> Series {
fn _sum_as_series(&self) -> Series {
ChunkAggSeries::sum_as_series(&self.0)
}
fn max_as_series(&self) -> Series {
Expand Down
8 changes: 6 additions & 2 deletions polars/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,11 @@ macro_rules! impl_dyn_series {
}

fn agg_sum(&self, groups: &[(u32, Vec<u32>)]) -> Option<Series> {
self.0.agg_sum(groups)
use DataType::*;
match self.dtype() {
Int8 | UInt8 | Int16 | UInt16 => self.cast(&Int64).unwrap().agg_sum(groups),
_ => self.0.agg_sum(groups),
}
}

fn agg_first(&self, groups: &[(u32, Vec<u32>)]) -> Series {
Expand Down Expand Up @@ -684,7 +688,7 @@ macro_rules! impl_dyn_series {
ChunkFillNull::fill_null(&self.0, strategy).map(|ca| ca.into_series())
}

fn sum_as_series(&self) -> Series {
fn _sum_as_series(&self) -> Series {
ChunkAggSeries::sum_as_series(&self.0)
}
fn max_as_series(&self) -> Series {
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ where
&self.0
}

fn sum_as_series(&self) -> Series {
fn _sum_as_series(&self) -> Series {
ObjectChunked::<T>::full_null(self.name(), 1).into_series()
}
fn max_as_series(&self) -> Series {
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ impl SeriesTrait for SeriesWrap<Utf8Chunked> {
ChunkFillNull::fill_null(&self.0, strategy).map(|ca| ca.into_series())
}

fn sum_as_series(&self) -> Series {
fn _sum_as_series(&self) -> Series {
ChunkAggSeries::sum_as_series(&self.0)
}
fn max_as_series(&self) -> Series {
Expand Down
34 changes: 31 additions & 3 deletions polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,16 @@ impl Series {
pub fn cast(&self, dtype: &DataType) -> Result<Self> {
self.0.cast(dtype)
}

/// Compute the sum of all values in this Series.
/// Returns `None` if the array is empty or only contains null values.
///
/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is
/// first cast to `Int64` to prevent overflow issues.
///
/// ```
/// # use polars_core::prelude::*;
/// let s = Series::new("days", [1, 2, 3].as_ref());
/// let s = Series::new("days", &[1, 2, 3]);
/// assert_eq!(s.sum(), Some(6));
/// ```
pub fn sum<T>(&self) -> Option<T>
Expand Down Expand Up @@ -482,6 +488,18 @@ impl Series {
UInt64Chunked::new_from_aligned_vec(self.name(), self.0.vec_hash(build_hasher))
}

/// Sum the values of this Series and return the result as a new Series of length 1.
///
/// When the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is
/// cast to `Int64` before summing to prevent overflow issues; every other dtype
/// is handed to `_sum_as_series` unchanged.
pub fn sum_as_series(&self) -> Series {
    use DataType::*;
    let is_small_int = matches!(self.dtype(), Int8 | UInt8 | Int16 | UInt16);
    if is_small_int {
        // Re-enter through this public method on the widened Series; `Int64` is not
        // in the small-int set, so that call takes the direct branch below.
        let widened = self.cast(&Int64).unwrap();
        widened.sum_as_series()
    } else {
        self._sum_as_series()
    }
}

/// Get an array with the cumulative max computed at every element
#[cfg_attr(docsrs, doc(cfg(feature = "cum_agg")))]
pub fn cummax(&self, _reverse: bool) -> Series {
Expand Down Expand Up @@ -509,12 +527,17 @@ impl Series {
}

/// Get an array with the cumulative sum computed at every element
///
/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is
/// first cast to `Int64` to prevent overflow issues.
#[cfg_attr(docsrs, doc(cfg(feature = "cum_agg")))]
pub fn cumsum(&self, _reverse: bool) -> Series {
#[cfg(feature = "cum_agg")]
{
use DataType::*;
match self.dtype() {
DataType::Boolean => self.cast(&DataType::UInt32).unwrap()._cumsum(_reverse),
Boolean => self.cast(&DataType::UInt32).unwrap()._cumsum(_reverse),
Int8 | UInt8 | Int16 | UInt16 => self.cast(&Int64).unwrap()._cumsum(_reverse),
_ => self._cumsum(_reverse),
}
}
Expand All @@ -525,12 +548,17 @@ impl Series {
}

/// Get an array with the cumulative product computed at every element
///
/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is
/// first cast to `Int64` to prevent overflow issues.
#[cfg_attr(docsrs, doc(cfg(feature = "cum_agg")))]
pub fn cumprod(&self, _reverse: bool) -> Series {
#[cfg(feature = "cum_agg")]
{
use DataType::*;
match self.dtype() {
DataType::Boolean => self.cast(&DataType::UInt32).unwrap()._cumprod(_reverse),
Boolean => self.cast(&UInt32).unwrap()._cumprod(_reverse),
Int8 | UInt8 | Int16 | UInt16 => self.cast(&Int64).unwrap()._cumprod(_reverse),
_ => self._cumprod(_reverse),
}
}
Expand Down
7 changes: 6 additions & 1 deletion polars/polars-core/src/series/series_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ pub(crate) mod private {
fn agg_max(&self, _groups: &[(u32, Vec<u32>)]) -> Option<Series> {
None
}
/// If the [`DataType`] is one of `{Int8, UInt8, Int16, UInt16}` the `Series` is
/// first cast to `Int64` to prevent overflow issues.
fn agg_sum(&self, _groups: &[(u32, Vec<u32>)]) -> Option<Series> {
None
}
Expand Down Expand Up @@ -818,7 +820,10 @@ pub trait SeriesTrait:
}

/// Get the sum of the Series as a new Series of length 1.
fn sum_as_series(&self) -> Series {
///
/// This method performs no cast itself. The public `Series::sum_as_series` casts
/// `{Int8, UInt8, Int16, UInt16}` to `Int64` before calling it, to prevent
/// overflow issues.
fn _sum_as_series(&self) -> Series {
invalid_operation_panic!(self)
}
/// Get the max of the Series as a new Series of length 1.
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1779,7 +1779,7 @@ fn test_groupby_small_ints() -> Result<()> {
.sort("foo", true)
.collect()?;

assert_eq!(Vec::from(out.column("foo")?.i16()?), &[Some(2), Some(1)]);
assert_eq!(Vec::from(out.column("foo")?.i64()?), &[Some(2), Some(1)]);
Ok(())
}

Expand Down
3 changes: 3 additions & 0 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,9 @@ def min(self) -> "Expr":
def sum(self) -> "Expr":
"""
Get sum value.
Note that dtypes in {Int8, UInt8, Int16, UInt16} are cast to
Int64 before summing to prevent overflow issues.
"""
return wrap_expr(self._pyexpr.sum())

Expand Down
3 changes: 3 additions & 0 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,9 @@ def sum(self) -> Union[int, float]:
"""
Reduce this Series to the sum value.
Note that dtypes in {Int8, UInt8, Int16, UInt16} are cast to
Int64 before summing to prevent overflow issues.
Examples
--------
>>> s = pl.Series("a", [1, 2, 3])
Expand Down

0 comments on commit b7123a4

Please sign in to comment.