Skip to content

Commit

Permalink
update schema in udfs (#4165)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 27, 2022
1 parent 78ec23b commit b8877b0
Show file tree
Hide file tree
Showing 16 changed files with 204 additions and 146 deletions.
90 changes: 51 additions & 39 deletions polars/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl LazyFrame {
/// Get the schema of the current LazyFrame computation.
pub fn schema(&self) -> SchemaRef {
let logical_plan = self.clone().get_plan_builder().build();
logical_plan.schema().clone()
logical_plan.schema().into_owned()
}

pub(crate) fn get_plan_builder(self) -> LogicalPlanBuilder {
Expand Down Expand Up @@ -345,12 +345,16 @@ impl LazyFrame {
}
}

// schema after renaming
let mut new_schema = (*self.schema()).clone();

for (old, new) in existing.iter().zip(new.iter()) {
new_schema.rename(old, new.to_string()).unwrap();
}
let existing2 = existing.clone();
let new2 = new.clone();
let udf_schema = move |s: &Schema| {
// schema after renaming
let mut new_schema = s.clone();
for (old, new) in existing2.iter().zip(new2.iter()) {
new_schema.rename(old, new.to_string()).unwrap();
}
Ok(Arc::new(new_schema))
};

let prefix = "__POLARS_TEMP_";

Expand Down Expand Up @@ -393,17 +397,21 @@ impl LazyFrame {
DataFrame::new(cols)
},
None,
Some(new_schema),
Some(Arc::new(udf_schema)),
Some("RENAME_SWAPPING"),
)
}

fn rename_imp(self, existing: Vec<String>, new: Vec<String>) -> Self {
let mut schema = (*self.schema()).clone();

for (old, new) in existing.iter().zip(&new) {
let _ = schema.rename(old, new.clone());
}
fn rename_impl(self, existing: Vec<String>, new: Vec<String>) -> Self {
let existing2 = existing.clone();
let new2 = new.clone();
let udf_schema = move |s: &Schema| {
let mut new_schema = s.clone();
for (old, new) in existing2.iter().zip(&new2) {
let _ = new_schema.rename(old, new.clone());
}
Ok(Arc::new(new_schema))
};

self.with_columns(
existing
Expand All @@ -427,7 +435,7 @@ impl LazyFrame {
Ok(df)
},
None,
Some(schema),
Some(Arc::new(udf_schema)),
Some("RENAME"),
)
}
Expand Down Expand Up @@ -457,7 +465,7 @@ impl LazyFrame {
if new.iter().any(|name| schema.get(name).is_some()) {
self.rename_impl_swapping(existing, new)
} else {
self.rename_imp(existing, new)
self.rename_impl(existing, new)
}
}

Expand Down Expand Up @@ -556,7 +564,7 @@ impl LazyFrame {

// in debug builds we check that the optimizations have not modified the final schema
#[cfg(debug_assertions)]
let prev_schema = logical_plan.schema().clone();
let prev_schema = logical_plan.schema().into_owned();

let mut lp_top = to_alp(logical_plan, expr_arena, lp_arena)?;

Expand Down Expand Up @@ -1154,7 +1162,7 @@ impl LazyFrame {
self,
function: F,
optimizations: Option<AllowedOptimizations>,
schema: Option<Schema>,
schema: Option<Arc<dyn UdfSchema>>,
name: Option<&'static str>,
) -> LazyFrame
where
Expand All @@ -1166,7 +1174,7 @@ impl LazyFrame {
.map(
function,
optimizations.unwrap_or_default(),
schema.map(Arc::new),
schema,
name.unwrap_or("ANONYMOUS UDF"),
)
.build();
Expand Down Expand Up @@ -1208,10 +1216,12 @@ impl LazyFrame {
}
}

let new_schema = self
.schema()
.insert_index(0, name.to_string(), IDX_DTYPE)
.unwrap();
let name2 = name.to_string();
let udf_schema = move |s: &Schema| {
let new = s.insert_index(0, name2.clone(), IDX_DTYPE).unwrap();
Ok(Arc::new(new))
};

let name = name.to_owned();

// if we do the row count at scan we add a dummy map, to update the schema
Expand All @@ -1234,7 +1244,7 @@ impl LazyFrame {
}
},
Some(opt),
Some(new_schema),
Some(Arc::new(udf_schema)),
Some("WITH ROW COUNT"),
)
}
Expand All @@ -1250,27 +1260,29 @@ impl LazyFrame {

#[cfg(feature = "dtype-struct")]
fn unnest_impl(self, cols: PlHashSet<String>) -> Self {
let schema = self.schema();

let mut new_schema = Schema::with_capacity(schema.len() * 2);
for (name, dtype) in schema.iter() {
if cols.contains(name) {
if let DataType::Struct(flds) = dtype {
for fld in flds {
new_schema.with_column(fld.name().clone(), fld.data_type().clone())
let cols2 = cols.clone();
let udf_schema = move |schema: &Schema| {
let mut new_schema = Schema::with_capacity(schema.len() * 2);
for (name, dtype) in schema.iter() {
if cols.contains(name) {
if let DataType::Struct(flds) = dtype {
for fld in flds {
new_schema.with_column(fld.name().clone(), fld.data_type().clone())
}
} else {
// todo: return lazy error here.
panic!("expected struct dtype")
}
} else {
// todo: return lazy error here.
panic!("expected struct dtype")
new_schema.with_column(name.clone(), dtype.clone())
}
} else {
new_schema.with_column(name.clone(), dtype.clone())
}
}
Ok(Arc::new(new_schema))
};
self.map(
move |df| df.unnest(&cols),
move |df| df.unnest(&cols2),
Some(AllowedOptimizations::default()),
Some(new_schema),
Some(Arc::new(udf_schema)),
Some("unnest"),
)
}
Expand Down
55 changes: 33 additions & 22 deletions polars/polars-lazy/src/logical_plan/alp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::utils::{aexprs_to_schema, PushNode};
use polars_core::frame::explode::MeltArgs;
use polars_core::prelude::*;
use polars_utils::arena::{Arena, Node};
use std::borrow::Cow;
#[cfg(any(feature = "ipc", feature = "csv-file", feature = "parquet"))]
use std::path::PathBuf;
use std::sync::Arc;
Expand Down Expand Up @@ -132,7 +133,7 @@ pub enum ALogicalPlan {
input: Node,
function: Arc<dyn DataFrameUdf>,
options: LogicalPlanUdfOptions,
schema: Option<SchemaRef>,
schema: Option<Arc<dyn UdfSchema>>,
},
Union {
inputs: Vec<Node>,
Expand Down Expand Up @@ -171,14 +172,14 @@ impl ALogicalPlan {
}

/// Get the schema of the logical plan node.
pub(crate) fn schema<'a>(&'a self, arena: &'a Arena<ALogicalPlan>) -> &'a SchemaRef {
pub(crate) fn schema<'a>(&'a self, arena: &'a Arena<ALogicalPlan>) -> Cow<'a, SchemaRef> {
use ALogicalPlan::*;
match self {
let schema = match self {
#[cfg(feature = "python")]
PythonScan { options } => &options.schema,
Union { inputs, .. } => arena.get(inputs[0]).schema(arena),
Cache { input } => arena.get(*input).schema(arena),
Sort { input, .. } => arena.get(*input).schema(arena),
Union { inputs, .. } => return arena.get(inputs[0]).schema(arena),
Cache { input } => return arena.get(*input).schema(arena),
Sort { input, .. } => return arena.get(*input).schema(arena),
Explode { schema, .. } => schema,
#[cfg(feature = "parquet")]
ParquetScan {
Expand All @@ -198,7 +199,7 @@ impl ALogicalPlan {
output_schema,
..
} => output_schema.as_ref().unwrap_or(schema),
Selection { input, .. } => arena.get(*input).schema(arena),
Selection { input, .. } => return arena.get(*input).schema(arena),
#[cfg(feature = "csv-file")]
CsvScan {
schema,
Expand All @@ -210,14 +211,18 @@ impl ALogicalPlan {
Aggregate { schema, .. } => schema,
Join { schema, .. } => schema,
HStack { schema, .. } => schema,
Distinct { input, .. } => arena.get(*input).schema(arena),
Slice { input, .. } => arena.get(*input).schema(arena),
Distinct { input, .. } => return arena.get(*input).schema(arena),
Slice { input, .. } => return arena.get(*input).schema(arena),
Melt { schema, .. } => schema,
Udf { input, schema, .. } => match schema {
Some(schema) => schema,
None => arena.get(*input).schema(arena),
},
}
Udf { input, schema, .. } => {
let input_schema = arena.get(*input).schema(arena);
return match schema {
Some(schema) => Cow::Owned(schema.get_schema(&input_schema).unwrap()),
None => input_schema,
};
}
};
Cow::Borrowed(schema)
}
}

Expand Down Expand Up @@ -622,7 +627,7 @@ impl<'a> ALogicalPlanBuilder<'a> {
}

pub fn melt(self, args: Arc<MeltArgs>) -> Self {
let schema = det_melt_schema(&args, self.schema());
let schema = det_melt_schema(&args, &self.schema());

let lp = ALogicalPlan::Melt {
input: self.root,
Expand All @@ -635,7 +640,7 @@ impl<'a> ALogicalPlanBuilder<'a> {

pub fn project_local(self, exprs: Vec<Node>) -> Self {
let input_schema = self.lp_arena.get(self.root).schema(self.lp_arena);
let schema = aexprs_to_schema(&exprs, input_schema, Context::Default, self.expr_arena);
let schema = aexprs_to_schema(&exprs, &input_schema, Context::Default, self.expr_arena);
let lp = ALogicalPlan::LocalProjection {
expr: exprs,
input: self.root,
Expand All @@ -647,7 +652,7 @@ impl<'a> ALogicalPlanBuilder<'a> {

pub fn project(self, exprs: Vec<Node>) -> Self {
let input_schema = self.lp_arena.get(self.root).schema(self.lp_arena);
let schema = aexprs_to_schema(&exprs, input_schema, Context::Default, self.expr_arena);
let schema = aexprs_to_schema(&exprs, &input_schema, Context::Default, self.expr_arena);

// if len == 0, no projection has to be done. This is a select all operation.
if !exprs.is_empty() {
Expand All @@ -671,19 +676,19 @@ impl<'a> ALogicalPlanBuilder<'a> {
}
}

pub(crate) fn schema(&self) -> &Schema {
pub(crate) fn schema(&'a self) -> Cow<'a, SchemaRef> {
self.lp_arena.get(self.root).schema(self.lp_arena)
}

pub(crate) fn with_columns(self, exprs: Vec<Node>) -> Self {
let schema = self.schema();
let mut new_schema = (*schema).clone();
let mut new_schema = (**schema).clone();

for e in &exprs {
let field = self
.expr_arena
.get(*e)
.to_field(schema, Context::Default, self.expr_arena)
.to_field(&schema, Context::Default, self.expr_arena)
.unwrap();

new_schema.with_column(field.name().clone(), field.data_type().clone());
Expand All @@ -710,8 +715,14 @@ impl<'a> ALogicalPlanBuilder<'a> {
// TODO! add this line if LogicalPlan is dropped in favor of ALogicalPlan
// let aggs = rewrite_projections(aggs, current_schema);

let mut schema = aexprs_to_schema(&keys, current_schema, Context::Default, self.expr_arena);
let other = aexprs_to_schema(&aggs, current_schema, Context::Aggregation, self.expr_arena);
let mut schema =
aexprs_to_schema(&keys, &current_schema, Context::Default, self.expr_arena);
let other = aexprs_to_schema(
&aggs,
&current_schema,
Context::Aggregation,
self.expr_arena,
);
schema.merge(other);

let index_columns = &[
Expand Down
21 changes: 20 additions & 1 deletion polars/polars-lazy/src/logical_plan/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,25 @@ where

impl Debug for dyn DataFrameUdf {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "udf")
write!(f, "dyn DataFrameUdf")
}
}

/// Computes the output schema of a DataFrame UDF from the schema of its input.
///
/// Implementors must be `Send + Sync`.
pub trait UdfSchema: Send + Sync {
/// Returns the output schema derived from `input_schema`, or an error if it
/// cannot be determined.
fn get_schema(&self, input_schema: &Schema) -> Result<SchemaRef>;
}

/// Blanket implementation: any `Send + Sync` callable of shape
/// `Fn(&Schema) -> Result<SchemaRef>` can be used as a [`UdfSchema`].
impl<F: Fn(&Schema) -> Result<SchemaRef> + Send + Sync> UdfSchema for F {
    /// Forwards `input_schema` to the wrapped callable and returns its result unchanged.
    fn get_schema(&self, input_schema: &Schema) -> Result<SchemaRef> {
        let compute = self;
        compute(input_schema)
    }
}

/// `Debug` for schema-UDF trait objects: prints the fixed label `dyn UdfSchema`.
impl Debug for dyn UdfSchema {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("dyn UdfSchema")
    }
}

0 comments on commit b8877b0

Please sign in to comment.