Skip to content

Commit

Permalink
n_unique aggregate method; #44
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 3, 2020
1 parent 3decf57 commit 6d885a3
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 11 deletions.
2 changes: 1 addition & 1 deletion polars/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ impl<T> ChunkedArray<T> {
/// Get the index of the chunk and the index of the value in that chunk
#[inline]
pub(crate) fn index_to_chunked_index(&self, index: usize) -> (usize, usize) {
if self.chunk_id().len() == 1 {
if self.chunks.len() == 1 {
return (0, index);
}
let mut index_remainder = index;
Expand Down
105 changes: 104 additions & 1 deletion polars/src/frame/group_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ use super::hash_join::prepare_hashed_relation;
use crate::chunked_array::builder::PrimitiveChunkedBuilder;
use crate::frame::select::Selection;
use crate::prelude::*;
use crate::utils::Xob;
use arrow::array::{PrimitiveBuilder, StringBuilder};
use enum_dispatch::enum_dispatch;
use fnv::FnvBuildHasher;
use itertools::Itertools;
use num::{Num, NumCast, ToPrimitive, Zero};
use rayon::prelude::*;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::{
fmt::{Debug, Formatter},
Expand Down Expand Up @@ -477,6 +478,64 @@ impl AggLast for LargeListChunked {
}
}

#[enum_dispatch(Series)]
trait AggNUnique {
fn agg_n_unique(&self, _groups: &Vec<(usize, Vec<usize>)>) -> UInt32Chunked {
unimplemented!()
}
}

macro_rules! impl_agg_n_unique {
($self:ident, $groups:ident, $ca_type:ty) => {{
$groups
.into_iter()
.map(|(_first, idx)| {
if $self.null_count() == 0 {
let mut set = HashSet::with_hasher(FnvBuildHasher::default());
for i in idx {
let v = unsafe { $self.get_unchecked(*i) };
set.insert(v);
}
set.len() as u32
} else {
let mut set = HashSet::with_hasher(FnvBuildHasher::default());
for i in idx {
let opt_v = $self.get(*i);
set.insert(opt_v);
}
set.len() as u32
}
})
.collect::<$ca_type>()
.into_inner()
}};
}

impl<T> AggNUnique for ChunkedArray<T>
where
T: PolarsIntegerType + Send,
T::Native: Hash + Eq,
{
fn agg_n_unique(&self, groups: &Vec<(usize, Vec<usize>)>) -> UInt32Chunked {
impl_agg_n_unique!(self, groups, Xob<UInt32Chunked>)
}
}

impl AggNUnique for Float32Chunked {}
impl AggNUnique for Float64Chunked {}
impl AggNUnique for LargeListChunked {}
impl AggNUnique for BooleanChunked {
fn agg_n_unique(&self, groups: &Vec<(usize, Vec<usize>)>) -> UInt32Chunked {
impl_agg_n_unique!(self, groups, Xob<UInt32Chunked>)
}
}

impl AggNUnique for Utf8Chunked {
fn agg_n_unique(&self, groups: &Vec<(usize, Vec<usize>)>) -> UInt32Chunked {
impl_agg_n_unique!(self, groups, Xob<UInt32Chunked>)
}
}

impl<'df, 'selection_str> GroupBy<'df, 'selection_str> {
/// Select the column by which the determine the groups.
/// You can select a single column or a slice of columns.
Expand Down Expand Up @@ -738,6 +797,42 @@ impl<'df, 'selection_str> GroupBy<'df, 'selection_str> {
DataFrame::new(cols)
}

/// Aggregate grouped `Series` by counting the number of unique values.
///
/// # Example
///
/// ```rust
/// # use polars::prelude::*;
/// fn example(df: DataFrame) -> Result<DataFrame> {
/// df.groupby("date")?.select("temp").n_unique()
/// }
/// ```
/// Returns:
///
/// ```text
/// +------------+---------------+
/// | date | temp_n_unique |
/// | --- | --- |
/// | date32 | u32 |
/// +============+===============+
/// | 2020-08-23 | 1 |
/// +------------+---------------+
/// | 2020-08-22 | 2 |
/// +------------+---------------+
/// | 2020-08-21 | 2 |
/// +------------+---------------+
/// ```
pub fn n_unique(&self) -> Result<DataFrame> {
let (mut cols, agg_cols) = self.prepare_agg()?;
for agg_col in agg_cols {
let new_name = format!["{}_n_unique", agg_col.name()];
let mut agg = agg_col.agg_n_unique(&self.groups).into_series();
agg.rename(&new_name);
cols.push(agg);
}
DataFrame::new(cols)
}

/// Aggregate grouped series and compute the number of values per group.
///
/// # Example
Expand Down Expand Up @@ -1267,6 +1362,14 @@ mod test {
"{:?}",
df.groupby("date").unwrap().select("temp").last().unwrap()
);
println!(
"{:?}",
df.groupby("date")
.unwrap()
.select("temp")
.n_unique()
.unwrap()
);
}

#[test]
Expand Down
27 changes: 18 additions & 9 deletions py-polars/pypolars/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,41 +516,50 @@ def __init__(self, df: DataFrame, by: List[str], selection: List[str]):
self.by = by
self.selection = selection

def first(self):
def first(self) -> DataFrame:
"""
Aggregate the first value in the group.
Aggregate the first values in the group.
"""
return wrap_df(self._df.groupby(self.by, self.selection, "first"))

def last(self):
def last(self) -> DataFrame:
"""
Aggregate the first value in the group.
Aggregate the last values in the group.
"""
return wrap_df(self._df.groupby(self.by, self.selection, "last"))

def sum(self):
def sum(self) -> DataFrame:
"""
Reduce the groups to the sum.
"""
return wrap_df(self._df.groupby(self.by, self.selection, "sum"))

def min(self):
def min(self) -> DataFrame:
"""
Reduce the groups to the minimal value.
"""
return wrap_df(self._df.groupby(self.by, self.selection, "min"))

def max(self):
def max(self) -> DataFrame:
"""
Reduce the groups to the maximal value.
"""
return wrap_df(self._df.groupby(self.by, self.selection, "max"))

def count(self):
def count(self) -> DataFrame:
"""
Count the number of values in each group.
"""
return wrap_df(self._df.groupby(self.by, self.selection, "count"))

def mean(self):
def mean(self) -> DataFrame:
"""
Reduce the groups to the mean values.
"""
return wrap_df(self._df.groupby(self.by, self.selection, "mean"))

def n_unique(self) -> DataFrame:
"""
Count the unique values per group.
"""
return wrap_df(self._df.groupby(self.by, self.selection, "n_unique"))
1 change: 1 addition & 0 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ impl PyDataFrame {
"last" => selection.last(),
"sum" => selection.sum(),
"count" => selection.count(),
"n_unique" => selection.n_unique(),
a => Err(PolarsError::Other(format!("agg fn {} does not exists", a))),
};
let df = df.map_err(PyPolarsEr::from)?;
Expand Down

0 comments on commit 6d885a3

Please sign in to comment.