Skip to content

Commit

Permalink
fix_pivot (#3199)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 20, 2022
1 parent b4e2e56 commit e37a27e
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 45 deletions.
188 changes: 143 additions & 45 deletions polars/polars-core/src/frame/groupby/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,128 @@ impl DataFrame {
self.pivot_impl(&values, &index, &columns, agg_fn, sort_columns, true)
}

fn compute_col_idx(
&self,
column: &str,
groups: &GroupsProxy,
) -> Result<(Vec<IdxSize>, Series)> {
let column_s = self.column(column)?;
let column_agg = column_s.agg_first(groups);
let column_agg_physical = column_agg.to_physical_repr();

let mut col_to_idx = PlHashMap::with_capacity(HASHMAP_INIT_SIZE);
let mut idx = 0 as IdxSize;
let col_locations = column_agg_physical
.iter()
.map(|v| {
let idx = *col_to_idx.entry(v).or_insert_with(|| {
let old_idx = idx;
idx += 1;
old_idx
});
idx
})
.collect();

drop(col_to_idx);
Ok((col_locations, column_agg))
}

fn compute_row_idx(
&self,
index: &[String],
groups: &GroupsProxy,
count: usize,
) -> Result<(Vec<IdxSize>, usize, Option<Vec<Series>>)> {
let (row_locations, n_rows, row_index) = if index.len() == 1 {
let index_s = self.column(&index[0])?;
let index_agg = index_s.agg_first(groups);
let index_agg_physical = index_agg.to_physical_repr();

let mut row_to_idx =
PlIndexMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default());
let mut idx = 0 as IdxSize;
let row_locations = index_agg_physical
.iter()
.map(|v| {
let idx = *row_to_idx.entry(v).or_insert_with(|| {
let old_idx = idx;
idx += 1;
old_idx
});
idx
})
.collect::<Vec<_>>();

let row_index = match count {
0 => Some(vec![Series::new(
&index[0],
row_to_idx.into_iter().map(|(k, _)| k).collect::<Vec<_>>(),
)]),
_ => None,
};

(row_locations, idx as usize, row_index)
} else {
let index_s = self.columns(index)?;
let index_agg_physical = index_s
.iter()
.map(|s| s.agg_first(groups).to_physical_repr().into_owned())
.collect::<Vec<_>>();
let mut iters = index_agg_physical
.iter()
.map(|s| s.iter())
.collect::<Vec<_>>();
let mut row_to_idx =
PlIndexMap::with_capacity_and_hasher(HASHMAP_INIT_SIZE, Default::default());
let mut idx = 0 as IdxSize;

let mut row_locations = Vec::with_capacity(groups.len());
loop {
match iters
.iter_mut()
.map(|it| it.next())
.collect::<Option<Vec<_>>>()
{
None => break,
Some(items) => {
let idx = *row_to_idx.entry(items).or_insert_with(|| {
let old_idx = idx;
idx += 1;
old_idx
});
row_locations.push(idx)
}
}
}
let row_index = match count {
0 => Some(
index
.iter()
.enumerate()
.map(|(i, name)| {
Series::new(
name,
row_to_idx
.iter()
.map(|(k, _)| {
debug_assert!(i < k.len());
unsafe { k.get_unchecked(i).clone() }
})
.collect::<Vec<_>>(),
)
})
.collect::<Vec<_>>(),
),
_ => None,
};

(row_locations, idx as usize, row_index)
};

Ok((row_locations, n_rows, row_index))
}

fn pivot_impl(
&self,
// these columns will be aggregated in the nested groupby
Expand All @@ -100,8 +222,6 @@ impl DataFrame {
sort_columns: bool,
stable: bool,
) -> Result<DataFrame> {
let keys = self.select_series(index)?;

let mut final_cols = vec![];

let mut count = 0;
Expand All @@ -112,37 +232,17 @@ impl DataFrame {

let groups = self.groupby_stable(groupby)?.groups;

let local_keys = keys
.par_iter()
.map(|k| k.agg_first(&groups))
.collect::<Vec<_>>();

// this are the row locations
let local_keys = DataFrame::new_no_checks(local_keys);
let local_keys_gb = local_keys.groupby_stable(index)?;
// these are the row locations
if !stable {
println!("unstable pivot not yet supported, using stable pivot");
};
let local_index_groups = &local_keys_gb.groups;

let column_s = self.column(column)?;
let column_agg = column_s.agg_first(&groups);
let column_agg_physical = column_agg.to_physical_repr();

let mut col_to_idx = PlHashMap::with_capacity(HASHMAP_INIT_SIZE);

let mut idx = 0 as IdxSize;
let col_locations = column_agg_physical
.iter()
.map(|v| {
let idx = *col_to_idx.entry(v).or_insert_with(|| {
let old_idx = idx;
idx += 1;
old_idx
});
idx
})
.collect::<Vec<_>>();
let (col, row) = POOL.join(
|| self.compute_col_idx(column, &groups),
|| self.compute_row_idx(index, &groups, count),
);
let (col_locations, column_agg) = col?;
let (row_locations, n_rows, mut row_index) = row?;

for value_col in values {
let value_col = self.column(value_col)?;
Expand All @@ -161,28 +261,26 @@ impl DataFrame {

let headers = column_agg.unique_stable()?.cast(&DataType::Utf8)?;
let headers = headers.utf8().unwrap();
let n_rows = local_index_groups.len();
let n_cols = headers.len();

let mut buf = vec![AnyValue::Null; n_rows * n_cols];

let mut col_idx_iter = col_locations.iter();
let value_agg_phys = value_agg.to_physical_repr();
let mut value_iter = value_agg_phys.iter();
for (row_idx, g) in local_index_groups.idx_ref().iter().enumerate() {
for _ in g.1 {
let val = value_iter.next().unwrap();
let col_idx = col_idx_iter.next().unwrap();

// Safety:
// in bounds
unsafe {
let idx = row_idx as usize + *col_idx as usize * n_rows;
debug_assert!(idx < buf.len());
*buf.get_unchecked_mut(idx) = val;
}

for ((row_idx, col_idx), val) in row_locations
.iter()
.zip(&col_locations)
.zip(value_agg_phys.iter())
{
// Safety:
// in bounds
unsafe {
let idx = *row_idx as usize + *col_idx as usize * n_rows;
debug_assert!(idx < buf.len());
*buf.get_unchecked_mut(idx) = val;
}
}

let headers_iter = headers.par_iter_indexed();

let mut cols = (0..n_cols)
Expand All @@ -203,7 +301,7 @@ impl DataFrame {
}

let cols = if count == 0 {
let mut final_cols = local_keys_gb.keys();
let mut final_cols = row_index.take().unwrap();
final_cols.extend(cols);
final_cols
} else {
Expand Down
20 changes: 20 additions & 0 deletions polars/tests/it/core/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,23 @@ fn test_pivot_new() -> Result<()> {

Ok(())
}

#[test]
fn test_pivot_2() -> Result<()> {
let df = df![
"name"=> ["avg", "avg", "act", "test", "test"],
"err" => [Some("name1"), Some("name2"), None, Some("name1"), Some("name2")],
"wght"=> [0.0, 0.1, 1.0, 0.4, 0.2]
]?;

let out = df.pivot_stable(["wght"], ["err"], ["name"], PivotAgg::First, false)?;
let expected = df![
"err" => [Some("name1"), Some("name2"), None],
"avg" => [Some(0.0), Some(0.1), None],
"act" => [None, None, Some(1.)],
"test" => [Some(0.4), Some(0.2), None],
]?;
assert!(out.frame_equal_missing(&expected));

Ok(())
}

0 comments on commit e37a27e

Please sign in to comment.