Skip to content

Commit

Permalink
groupby select multiple columns; closes #65
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 2, 2020
1 parent da80534 commit 1680dd2
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 105 deletions.
1 change: 1 addition & 0 deletions polars/src/doc/changelog/v0_5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
//! * `DataFrame.column` returns `Result<_>` **breaking change**.
//! * Define idiomatic way to do inplace operations on a `DataFrame` with `apply`, `may_apply` and `ChunkSet`
//! * `ChunkSet` Trait.
//! * Groupby can be done on a selection of multiple columns.
//!
183 changes: 110 additions & 73 deletions polars/src/frame/group_by.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::hash_join::prepare_hashed_relation;
use crate::chunked_array::builder::PrimitiveChunkedBuilder;
use crate::frame::select::Selection;
use crate::prelude::*;
use arrow::array::{PrimitiveBuilder, StringBuilder};
use enum_dispatch::enum_dispatch;
Expand Down Expand Up @@ -142,13 +143,13 @@ impl DataFrame {
/// ```
///
#[derive(Debug, Clone)]
pub struct GroupBy<'a> {
pub struct GroupBy<'a, 'b> {
df: &'a DataFrame,
/// By which column should the grouping operation be performed.
pub by: String,
// [first idx, [other idx]]
groups: Vec<(usize, Vec<usize>)>,
selection: Option<String>,
selection: Option<Vec<&'b str>>,
}

#[enum_dispatch(Series)]
Expand Down Expand Up @@ -283,10 +284,14 @@ where
}
}

impl<'a> GroupBy<'a> {
impl<'a, 'b> GroupBy<'a, 'b> {
/// Select the column by which the determine the groups.
pub fn select(mut self, name: &str) -> Self {
self.selection = Some(name.to_string());
/// You can select a single column or a slice of columns.
pub fn select<S>(mut self, selection: S) -> Self
where
S: Selection<'b>,
{
self.selection = Some(selection.to_selection_vec());
self
}

Expand All @@ -299,15 +304,15 @@ impl<'a> GroupBy<'a> {
}
}

fn prepare_agg(&self) -> Result<(&String, Series, &Series)> {
let name = match &self.selection {
Some(name) => name,
fn prepare_agg(&self) -> Result<(Series, Vec<Series>)> {
let selection = match &self.selection {
Some(selection) => selection,
None => return Err(PolarsError::NoSelection),
};

let keys = self.keys();
let agg_col = self.df.column(name)?;
Ok((name, keys, agg_col))
let agg_col = self.df.select_series(selection)?;
Ok((keys, agg_col))
}

/// Aggregate grouped series and compute the mean per group.
Expand All @@ -317,31 +322,36 @@ impl<'a> GroupBy<'a> {
/// ```rust
/// # use polars::prelude::*;
/// fn example(df: DataFrame) -> Result<DataFrame> {
/// df.groupby("date")?.select("temp").mean()
/// df.groupby("date")?.select(&["temp", "rain"]).mean()
/// }
/// ```
/// Returns:
///
/// ```text
/// +------------+-----------+
/// | date | temp_mean |
/// | --- | --- |
/// | date32 | f64 |
/// +============+===========+
/// | 2020-08-23 | 9 |
/// +------------+-----------+
/// | 2020-08-22 | 4 |
/// +------------+-----------+
/// | 2020-08-21 | 15 |
/// +------------+-----------+
/// +------------+-----------+-----------+
/// | date | temp_mean | rain_mean |
/// | --- | --- | --- |
/// | date32 | f64 | f64 |
/// +============+===========+===========+
/// | 2020-08-23 | 9 | 0.1 |
/// +------------+-----------+-----------+
/// | 2020-08-22 | 4 | 0.155 |
/// +------------+-----------+-----------+
/// | 2020-08-21 | 15 | 0.15 |
/// +------------+-----------+-----------+
/// ```
pub fn mean(&self) -> Result<DataFrame> {
let (name, keys, agg_col) = self.prepare_agg()?;
let new_name = format!["{}_mean", name];

let mut agg = agg_col.agg_mean(&self.groups);
agg.rename(&new_name);
DataFrame::new(vec![keys, agg])
let (keys, agg_cols) = self.prepare_agg()?;

let mut cols = Vec::with_capacity(agg_cols.len() + 1);
cols.push(keys);
for agg_col in agg_cols {
let new_name = format!["{}_mean", agg_col.name()];
let mut agg = agg_col.agg_mean(&self.groups);
agg.rename(&new_name);
cols.push(agg);
}
DataFrame::new(cols)
}

/// Aggregate grouped series and compute the sum per group.
Expand Down Expand Up @@ -370,11 +380,16 @@ impl<'a> GroupBy<'a> {
/// +------------+----------+
/// ```
pub fn sum(&self) -> Result<DataFrame> {
let (name, keys, agg_col) = self.prepare_agg()?;
let new_name = format!["{}_sum", name];
let mut agg = agg_col.agg_sum(&self.groups);
agg.rename(&new_name);
DataFrame::new(vec![keys, agg])
let (keys, agg_cols) = self.prepare_agg()?;
let mut cols = Vec::with_capacity(agg_cols.len() + 1);
cols.push(keys);
for agg_col in agg_cols {
let new_name = format!["{}_sum", agg_col.name()];
let mut agg = agg_col.agg_sum(&self.groups);
agg.rename(&new_name);
cols.push(agg);
}
DataFrame::new(cols)
}

/// Aggregate grouped series and compute the minimal value per group.
Expand Down Expand Up @@ -403,11 +418,16 @@ impl<'a> GroupBy<'a> {
/// +------------+----------+
/// ```
pub fn min(&self) -> Result<DataFrame> {
let (name, keys, agg_col) = self.prepare_agg()?;
let new_name = format!["{}_min", name];
let mut agg = apply_method_numeric_series!(agg_col, agg_min, &self.groups);
agg.rename(&new_name);
DataFrame::new(vec![keys, agg])
let (keys, agg_cols) = self.prepare_agg()?;
let mut cols = Vec::with_capacity(agg_cols.len() + 1);
cols.push(keys);
for agg_col in agg_cols {
let new_name = format!["{}_min", agg_col.name()];
let mut agg = agg_col.agg_min(&self.groups);
agg.rename(&new_name);
cols.push(agg);
}
DataFrame::new(cols)
}

/// Aggregate grouped series and compute the maximum value per group.
Expand Down Expand Up @@ -436,11 +456,16 @@ impl<'a> GroupBy<'a> {
/// +------------+----------+
/// ```
pub fn max(&self) -> Result<DataFrame> {
let (name, keys, agg_col) = self.prepare_agg()?;
let new_name = format!["{}_max", name];
let mut agg = agg_col.agg_max(&self.groups);
agg.rename(&new_name);
DataFrame::new(vec![keys, agg])
let (keys, agg_cols) = self.prepare_agg()?;
let mut cols = Vec::with_capacity(agg_cols.len() + 1);
cols.push(keys);
for agg_col in agg_cols {
let new_name = format!["{}_max", agg_col.name()];
let mut agg = agg_col.agg_max(&self.groups);
agg.rename(&new_name);
cols.push(agg);
}
DataFrame::new(cols)
}

/// Aggregate grouped series and compute the number of values per group.
Expand Down Expand Up @@ -469,18 +494,23 @@ impl<'a> GroupBy<'a> {
/// +------------+------------+
/// ```
pub fn count(&self) -> Result<DataFrame> {
let (name, keys, agg_col) = self.prepare_agg()?;
let new_name = format!["{}_count", name];

let mut builder = PrimitiveChunkedBuilder::new(&new_name, self.groups.len());
for (_first, idx) in &self.groups {
let s =
unsafe { agg_col.take_iter_unchecked(idx.into_iter().copied(), Some(idx.len())) };
builder.append_value(s.len() as u32);
let (keys, agg_cols) = self.prepare_agg()?;
let mut cols = Vec::with_capacity(agg_cols.len() + 1);
cols.push(keys);
for agg_col in agg_cols {
let new_name = format!["{}_count", agg_col.name()];
let mut builder = PrimitiveChunkedBuilder::new(&new_name, self.groups.len());
for (_first, idx) in &self.groups {
let s = unsafe {
agg_col.take_iter_unchecked(idx.into_iter().copied(), Some(idx.len()))
};
builder.append_value(s.len() as u32);
}
let ca = builder.finish();
let agg = Series::UInt32(ca);
cols.push(agg);
}
let ca = builder.finish();
let agg = Series::UInt32(ca);
DataFrame::new(vec![keys, agg])
DataFrame::new(cols)
}

/// Aggregate the groups of the groupby operation into lists.
Expand Down Expand Up @@ -510,44 +540,47 @@ impl<'a> GroupBy<'a> {
/// +------------+---------------+
/// ```
pub fn agg_list(&self) -> Result<DataFrame> {
let (name, keys, agg_col) = self.prepare_agg()?;
let new_name = format!["{}_agg_list", name];

macro_rules! impl_gb {
($type:ty) => {{
($type:ty, $agg_col:expr) => {{
let values_builder = PrimitiveBuilder::<$type>::new(self.groups.len());
let mut builder = LargeListPrimitiveChunkedBuilder::new(
&new_name,
values_builder,
self.groups.len(),
);
let mut builder =
LargeListPrimitiveChunkedBuilder::new("", values_builder, self.groups.len());
for (_first, idx) in &self.groups {
let s = unsafe {
agg_col.take_iter_unchecked(idx.into_iter().copied(), Some(idx.len()))
$agg_col.take_iter_unchecked(idx.into_iter().copied(), Some(idx.len()))
};
builder.append_opt_series(Some(&s))
}
let list = builder.finish().into_series();
DataFrame::new(vec![keys, list])
builder.finish().into_series()
}};
}

macro_rules! impl_gb_utf8 {
() => {{
($agg_col:expr) => {{
let values_builder = StringBuilder::new(self.groups.len());
let mut builder =
LargeListUtf8ChunkedBuilder::new(&new_name, values_builder, self.groups.len());
LargeListUtf8ChunkedBuilder::new("", values_builder, self.groups.len());
for (_first, idx) in &self.groups {
let s = unsafe {
agg_col.take_iter_unchecked(idx.into_iter().copied(), Some(idx.len()))
$agg_col.take_iter_unchecked(idx.into_iter().copied(), Some(idx.len()))
};
builder.append_opt_series(Some(&s))
}
let list = builder.finish().into_series();
DataFrame::new(vec![keys, list])
builder.finish().into_series()
}};
}
match_arrow_data_type_apply_macro!(agg_col.dtype(), impl_gb, impl_gb_utf8)

let (keys, agg_cols) = self.prepare_agg()?;
let mut cols = Vec::with_capacity(agg_cols.len() + 1);
cols.push(keys);
for agg_col in agg_cols {
let new_name = format!["{}_agg_list", agg_col.name()];
let mut agg =
match_arrow_data_type_apply_macro!(agg_col.dtype(), impl_gb, impl_gb_utf8, agg_col);
agg.rename(&new_name);
cols.push(agg);
}
DataFrame::new(cols)
}
}

Expand Down Expand Up @@ -580,7 +613,11 @@ mod test {
);
println!(
"{:?}",
df.groupby("date").unwrap().select("temp").mean().unwrap()
df.groupby("date")
.unwrap()
.select(&["temp", "rain"])
.mean()
.unwrap()
);
println!(
"{:?}",
Expand Down
15 changes: 12 additions & 3 deletions polars/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ impl DataFrame {
Ok(self.select_at_idx(idx).unwrap())
}

/// Select column(s) from this DataFrame.
/// Select column(s) from this DataFrame and return a new DataFrame.
///
/// # Examples
///
Expand All @@ -359,6 +359,16 @@ impl DataFrame {
/// }
/// ```
pub fn select<'a, S>(&self, selection: S) -> Result<Self>
where
S: Selection<'a>,
{
let selected = self.select_series(selection)?;
let df = DataFrame::new(selected)?;
Ok(df)
}

/// Select column(s) from this DataFrame and return them into a Vector.
pub fn select_series<'a, S>(&self, selection: S) -> Result<Vec<Series>>
where
S: Selection<'a>,
{
Expand All @@ -367,8 +377,7 @@ impl DataFrame {
.iter()
.map(|c| self.column(c).map(|s| s.clone()))
.collect::<Result<Vec<_>>>()?;
let df = DataFrame::new(selected)?;
Ok(df)
Ok(selected)
}

/// Select a mutable series by name.
Expand Down

0 comments on commit 1680dd2

Please sign in to comment.