Skip to content

Commit

Permalink
melt extra arguments (#3133)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 13, 2022
1 parent 977c46e commit 2b7e463
Show file tree
Hide file tree
Showing 17 changed files with 178 additions and 95 deletions.
50 changes: 44 additions & 6 deletions polars/polars-core/src/frame/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ fn get_exploded(series: &Series) -> Result<(Series, Buffer<i64>)> {
}
}

/// Arguments for `[DataFrame::melt]` function
#[derive(Clone, Default, Debug)]
pub struct MeltArgs {
pub id_vars: Vec<String>,
pub value_vars: Vec<String>,
pub variable_name: Option<String>,
pub value_name: Option<String>,
}

impl DataFrame {
pub fn explode_impl(&self, mut columns: Vec<Series>) -> Result<DataFrame> {
if self.height() == 0 {
Expand Down Expand Up @@ -241,12 +250,22 @@ impl DataFrame {
{
let id_vars = id_vars.into_vec();
let value_vars = value_vars.into_vec();
self.melt2(id_vars, value_vars)
self.melt2(MeltArgs {
id_vars,
value_vars,
..Default::default()
})
}

/// Similar to melt, but without generics. This may be easier if you want to pass
/// an empty `id_vars` or empty `value_vars`.
pub fn melt2(&self, id_vars: Vec<String>, mut value_vars: Vec<String>) -> Result<Self> {
pub fn melt2(&self, args: MeltArgs) -> Result<Self> {
let id_vars = args.id_vars;
let mut value_vars = args.value_vars;

let value_name = args.value_name.as_deref().unwrap_or("value");
let variable_name = args.variable_name.as_deref().unwrap_or("variable");

let len = self.height();

// if value vars is empty we take all columns that are not in id_vars.
Expand Down Expand Up @@ -306,13 +325,17 @@ impl DataFrame {
// Safety
// The give dtype is correct
let values =
unsafe { Series::from_chunks_and_dtype_unchecked("value", vec![values_arr], &st) };
unsafe { Series::from_chunks_and_dtype_unchecked(value_name, vec![values_arr], &st) };

let variable_col = variable_col.into_arc();
// Safety
// The give dtype is correct
let variables = unsafe {
Series::from_chunks_and_dtype_unchecked("variable", vec![variable_col], &DataType::Utf8)
Series::from_chunks_and_dtype_unchecked(
variable_name,
vec![variable_col],
&DataType::Utf8,
)
};

ids.hstack_mut(&[variables, values])?;
Expand All @@ -323,6 +346,7 @@ impl DataFrame {

#[cfg(test)]
mod test {
use crate::frame::explode::MeltArgs;
use crate::prelude::*;

#[test]
Expand Down Expand Up @@ -413,7 +437,14 @@ mod test {
&[Some(10), Some(11), Some(12), Some(2), Some(4), Some(6)]
);

let melted = df.melt2(vec![], vec![]).unwrap();
let args = MeltArgs {
id_vars: vec![],
value_vars: vec![],
variable_name: None,
value_name: None,
};

let melted = df.melt2(args).unwrap();
let value = melted.column("value")?;
// utf8 because of supertype
let value = value.utf8()?;
Expand All @@ -423,7 +454,14 @@ mod test {
&["a", "b", "a", "1", "3", "5", "10", "11", "12", "2", "4", "6"]
);

let melted = df.melt2(vec!["A".into()], vec![]).unwrap();
let args = MeltArgs {
id_vars: vec!["A".into()],
value_vars: vec![],
variable_name: None,
value_name: None,
};

let melted = df.melt2(args).unwrap();
let value = melted.column("value")?;
let value = value.i32()?;
let value = value.into_no_null_iter().collect::<Vec<_>>();
Expand Down
1 change: 1 addition & 0 deletions polars/polars-core/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub use crate::{
df,
error::{PolarsError, Result},
frame::{
explode::MeltArgs,
groupby::{GroupsIdx, GroupsProxy, GroupsSlice},
hash_join::JoinType,
*,
Expand Down
8 changes: 3 additions & 5 deletions polars/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use crate::prelude::{
use crate::logical_plan::FETCH_ROWS;
use crate::utils::{combine_predicates_expr, expr_to_root_column_names};
use polars_arrow::prelude::QuantileInterpolOptions;
use polars_core::frame::explode::MeltArgs;
use polars_io::RowCount;

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -1012,12 +1013,9 @@ impl LazyFrame {
}

/// Melt the DataFrame from wide to long format
pub fn melt(self, id_vars: Vec<String>, value_vars: Vec<String>) -> LazyFrame {
pub fn melt(self, args: MeltArgs) -> LazyFrame {
let opt_state = self.get_opt_state();
let lp = self
.get_plan_builder()
.melt(Arc::new(id_vars), Arc::new(value_vars))
.build();
let lp = self.get_plan_builder().melt(Arc::new(args)).build();
Self::from_logical_plan(lp, opt_state)
}

Expand Down
21 changes: 7 additions & 14 deletions polars/polars-lazy/src/logical_plan/alp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::logical_plan::ParquetOptions;
use crate::logical_plan::{det_melt_schema, Context, CsvParserOptions};
use crate::prelude::*;
use crate::utils::{aexprs_to_schema, PushNode};
use polars_core::frame::explode::MeltArgs;
use polars_core::prelude::*;
use polars_utils::arena::{Arena, Node};
#[cfg(any(feature = "csv-file", feature = "parquet"))]
Expand All @@ -16,8 +17,7 @@ use std::sync::Arc;
pub enum ALogicalPlan {
Melt {
input: Node,
id_vars: Arc<Vec<String>>,
value_vars: Arc<Vec<String>>,
args: Arc<MeltArgs>,
schema: SchemaRef,
},
Slice {
Expand Down Expand Up @@ -193,15 +193,9 @@ impl ALogicalPlan {
inputs,
options: *options,
},
Melt {
id_vars,
value_vars,
schema,
..
} => Melt {
Melt { args, schema, .. } => Melt {
input: inputs[0],
id_vars: id_vars.clone(),
value_vars: value_vars.clone(),
args: args.clone(),
schema: schema.clone(),
},
Slice { offset, len, .. } => Slice {
Expand Down Expand Up @@ -548,13 +542,12 @@ impl<'a> ALogicalPlanBuilder<'a> {
}
}

pub fn melt(self, id_vars: Arc<Vec<String>>, value_vars: Arc<Vec<String>>) -> Self {
let schema = det_melt_schema(&id_vars, &value_vars, self.schema());
pub fn melt(self, args: Arc<MeltArgs>) -> Self {
let schema = det_melt_schema(&args, self.schema());

let lp = ALogicalPlan::Melt {
input: self.root,
id_vars,
value_vars,
args,
schema,
};
let node = self.lp_arena.add(lp);
Expand Down
37 changes: 22 additions & 15 deletions polars/polars-lazy/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::prelude::*;
use crate::utils;
use crate::utils::{combine_predicates_expr, has_expr};
use parking_lot::Mutex;
use polars_core::frame::explode::MeltArgs;
use polars_core::prelude::*;
use polars_core::utils::get_supertype;
use polars_io::csv::CsvEncoding;
Expand Down Expand Up @@ -402,12 +403,11 @@ impl LogicalPlanBuilder {
.into()
}

pub fn melt(self, id_vars: Arc<Vec<String>>, value_vars: Arc<Vec<String>>) -> Self {
let schema = det_melt_schema(&id_vars, &value_vars, self.0.schema());
pub fn melt(self, args: Arc<MeltArgs>) -> Self {
let schema = det_melt_schema(&args, self.0.schema());
LogicalPlan::Melt {
input: Box::new(self.0),
id_vars,
value_vars,
args,
schema,
}
.into()
Expand Down Expand Up @@ -504,24 +504,31 @@ impl LogicalPlanBuilder {
}
}

pub(crate) fn det_melt_schema(
id_vars: &[String],
value_vars: &[String],
input_schema: &Schema,
) -> SchemaRef {
pub(crate) fn det_melt_schema(args: &MeltArgs, input_schema: &Schema) -> SchemaRef {
let mut new_schema = Schema::from(
id_vars
args.id_vars
.iter()
.map(|id| Field::new(id, input_schema.get(id).unwrap().clone())),
);
new_schema.with_column("variable".to_string(), DataType::Utf8);
let variable_name = args
.variable_name
.as_ref()
.cloned()
.unwrap_or_else(|| "variable".to_string());
let value_name = args
.value_name
.as_ref()
.cloned()
.unwrap_or_else(|| "value".to_string());

new_schema.with_column(variable_name, DataType::Utf8);

// We need to determine the supertype of all value columns.
let mut st = None;

// take all columns that are not in `id_vars` as `value_var`
if value_vars.is_empty() {
let id_vars = PlHashSet::from_iter(id_vars);
if args.value_vars.is_empty() {
let id_vars = PlHashSet::from_iter(&args.id_vars);
for (name, dtype) in input_schema.iter() {
if !id_vars.contains(name) {
match &st {
Expand All @@ -531,14 +538,14 @@ pub(crate) fn det_melt_schema(
}
}
} else {
for name in value_vars {
for name in &args.value_vars {
let dtype = input_schema.get(name).unwrap();
match &st {
None => st = Some(dtype.clone()),
Some(st_) => st = Some(get_supertype(st_, dtype).unwrap()),
}
}
}
new_schema.with_column("value".to_string(), st.unwrap());
new_schema.with_column(value_name, st.unwrap());
Arc::new(new_schema)
}
12 changes: 4 additions & 8 deletions polars/polars-lazy/src/logical_plan/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,13 @@ pub(crate) fn to_alp(
}
LogicalPlan::Melt {
input,
id_vars,
value_vars,
args,
schema,
} => {
let input = to_alp(*input, expr_arena, lp_arena)?;
ALogicalPlan::Melt {
input,
id_vars,
value_vars,
args,
schema,
}
}
Expand Down Expand Up @@ -826,15 +824,13 @@ pub(crate) fn node_to_lp(
}
ALogicalPlan::Melt {
input,
id_vars,
value_vars,
args,
schema,
} => {
let input = node_to_lp(input, expr_arena, lp_arena);
LogicalPlan::Melt {
input: Box::new(input),
id_vars,
value_vars,
args,
schema,
}
}
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ mod projection;
pub(crate) use apply::*;
pub(crate) use builder::*;
pub use lit::*;
use polars_core::frame::explode::MeltArgs;

// Will be set/ unset in the fetch operation to communicate overwriting the number of rows to scan.
thread_local! {pub(crate) static FETCH_ROWS: Cell<Option<usize>> = Cell::new(None)}
Expand Down Expand Up @@ -148,8 +149,7 @@ pub enum LogicalPlan {
/// A Melt operation
Melt {
input: Box<LogicalPlan>,
id_vars: Arc<Vec<String>>,
value_vars: Arc<Vec<String>>,
args: Arc<MeltArgs>,
schema: SchemaRef,
},
/// A User Defined Function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,18 @@ impl PredicatePushDown {

Melt {
input,
id_vars,
value_vars,
args,
schema,
} => {
let variable_name = args.variable_name.as_deref().unwrap_or("variable");
let value_name = args.value_name.as_deref().unwrap_or("value_name");

// predicates that will be done at this level
let condition = |name: Arc<str>| {
let name = &*name;
name == "variable"
|| name == "value"
|| value_vars.iter().any(|s| s.as_str() == name)
name == variable_name
|| name == value_name
|| args.value_vars.iter().any(|s| s.as_str() == name)
};
let local_predicates =
transfer_to_local(expr_arena, &mut acc_predicates, condition);
Expand All @@ -203,8 +205,7 @@ impl PredicatePushDown {

let lp = ALogicalPlan::Melt {
input,
id_vars,
value_vars,
args,
schema,
};
Ok(self.optional_apply_predicate(lp, local_predicates, lp_arena, expr_arena))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,12 +564,7 @@ impl ProjectionPushDown {
)?;
Ok(Selection { predicate, input })
}
Melt {
input,
id_vars,
value_vars,
..
} => {
Melt { input, args, .. } => {
let (mut acc_projections, mut local_projections, names) = split_acc_projections(
acc_projections,
lp_arena.get(input).schema(lp_arena),
Expand All @@ -581,15 +576,15 @@ impl ProjectionPushDown {
}

// make sure that the requested columns are projected
id_vars.iter().for_each(|name| {
args.id_vars.iter().for_each(|name| {
add_str_to_accumulated(
name,
&mut acc_projections,
&mut projected_names,
expr_arena,
)
});
value_vars.iter().for_each(|name| {
args.value_vars.iter().for_each(|name| {
add_str_to_accumulated(
name,
&mut acc_projections,
Expand All @@ -607,8 +602,7 @@ impl ProjectionPushDown {
expr_arena,
)?;

let builder =
ALogicalPlanBuilder::new(input, expr_arena, lp_arena).melt(id_vars, value_vars);
let builder = ALogicalPlanBuilder::new(input, expr_arena, lp_arena).melt(args);
Ok(self.finish_node(local_projections, builder))
}
Aggregate {
Expand Down

0 comments on commit 2b7e463

Please sign in to comment.