Skip to content

Commit

Permalink
ensure that schema is updated after drop
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 26, 2022
1 parent ad15e93 commit c82c177
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 85 deletions.
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ impl LazyFrame {
/// Get a hold on the schema of the current LazyFrame computation.
pub fn schema(&self) -> SchemaRef {
let logical_plan = self.clone().get_plan_builder().build();
logical_plan.schema().clone()
logical_plan.schema().into_owned()
}

pub(crate) fn get_plan_builder(self) -> LogicalPlanBuilder {
Expand Down Expand Up @@ -556,7 +556,7 @@ impl LazyFrame {

// during debug we check if the optimizations have not modified the final schema
#[cfg(debug_assertions)]
let prev_schema = logical_plan.schema().clone();
let prev_schema = logical_plan.schema().into_owned();

let mut lp_top = to_alp(logical_plan, expr_arena, lp_arena)?;

Expand Down
47 changes: 26 additions & 21 deletions polars/polars-lazy/src/logical_plan/alp.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Cow;
#[cfg(feature = "ipc")]
use crate::logical_plan::IpcScanOptionsInner;
#[cfg(feature = "parquet")]
Expand Down Expand Up @@ -132,7 +133,7 @@ pub enum ALogicalPlan {
input: Node,
function: Arc<dyn DataFrameUdf>,
options: LogicalPlanUdfOptions,
schema: Option<SchemaRef>,
schema: Option<Arc<dyn UdfSchema>>
},
Union {
inputs: Vec<Node>,
Expand Down Expand Up @@ -171,14 +172,14 @@ impl ALogicalPlan {
}

/// Get the schema of the logical plan node.
pub(crate) fn schema<'a>(&'a self, arena: &'a Arena<ALogicalPlan>) -> &'a SchemaRef {
pub(crate) fn schema<'a>(&'a self, arena: &'a Arena<ALogicalPlan>) -> Cow<'a, SchemaRef> {
use ALogicalPlan::*;
match self {
let schema = match self {
#[cfg(feature = "python")]
PythonScan { options } => &options.schema,
Union { inputs, .. } => arena.get(inputs[0]).schema(arena),
Cache { input } => arena.get(*input).schema(arena),
Sort { input, .. } => arena.get(*input).schema(arena),
Union { inputs, .. } => return arena.get(inputs[0]).schema(arena),
Cache { input } => return arena.get(*input).schema(arena),
Sort { input, .. } => return arena.get(*input).schema(arena),
Explode { schema, .. } => schema,
#[cfg(feature = "parquet")]
ParquetScan {
Expand All @@ -198,7 +199,7 @@ impl ALogicalPlan {
output_schema,
..
} => output_schema.as_ref().unwrap_or(schema),
Selection { input, .. } => arena.get(*input).schema(arena),
Selection { input, .. } => return arena.get(*input).schema(arena),
#[cfg(feature = "csv-file")]
CsvScan {
schema,
Expand All @@ -210,14 +211,18 @@ impl ALogicalPlan {
Aggregate { schema, .. } => schema,
Join { schema, .. } => schema,
HStack { schema, .. } => schema,
Distinct { input, .. } => arena.get(*input).schema(arena),
Slice { input, .. } => arena.get(*input).schema(arena),
Distinct { input, .. } => return arena.get(*input).schema(arena),
Slice { input, .. } => return arena.get(*input).schema(arena),
Melt { schema, .. } => schema,
Udf { input, schema, .. } => match schema {
Some(schema) => schema,
None => arena.get(*input).schema(arena),
Udf { input, schema, .. } => {
let input_schema = arena.get(*input).schema(arena);
return match schema {
Some(schema) => Cow::Owned(schema.get_schema(&input_schema).unwrap()),
None => input_schema,
}
},
}
};
Cow::Borrowed(schema)
}
}

Expand Down Expand Up @@ -622,7 +627,7 @@ impl<'a> ALogicalPlanBuilder<'a> {
}

pub fn melt(self, args: Arc<MeltArgs>) -> Self {
let schema = det_melt_schema(&args, self.schema());
let schema = det_melt_schema(&args, &self.schema());

let lp = ALogicalPlan::Melt {
input: self.root,
Expand All @@ -635,7 +640,7 @@ impl<'a> ALogicalPlanBuilder<'a> {

pub fn project_local(self, exprs: Vec<Node>) -> Self {
let input_schema = self.lp_arena.get(self.root).schema(self.lp_arena);
let schema = aexprs_to_schema(&exprs, input_schema, Context::Default, self.expr_arena);
let schema = aexprs_to_schema(&exprs, &input_schema, Context::Default, self.expr_arena);
let lp = ALogicalPlan::LocalProjection {
expr: exprs,
input: self.root,
Expand All @@ -647,7 +652,7 @@ impl<'a> ALogicalPlanBuilder<'a> {

pub fn project(self, exprs: Vec<Node>) -> Self {
let input_schema = self.lp_arena.get(self.root).schema(self.lp_arena);
let schema = aexprs_to_schema(&exprs, input_schema, Context::Default, self.expr_arena);
let schema = aexprs_to_schema(&exprs, &input_schema, Context::Default, self.expr_arena);

// if len == 0, no projection has to be done. This is a select all operation.
if !exprs.is_empty() {
Expand All @@ -671,19 +676,19 @@ impl<'a> ALogicalPlanBuilder<'a> {
}
}

pub(crate) fn schema(&self) -> &Schema {
pub(crate) fn schema(&'a self) -> Cow<'a, SchemaRef> {
self.lp_arena.get(self.root).schema(self.lp_arena)
}

pub(crate) fn with_columns(self, exprs: Vec<Node>) -> Self {
let schema = self.schema();
let mut new_schema = (*schema).clone();
let mut new_schema = (**schema).clone();

for e in &exprs {
let field = self
.expr_arena
.get(*e)
.to_field(schema, Context::Default, self.expr_arena)
.to_field(&schema, Context::Default, self.expr_arena)
.unwrap();

new_schema.with_column(field.name().clone(), field.data_type().clone());
Expand All @@ -710,8 +715,8 @@ impl<'a> ALogicalPlanBuilder<'a> {
// TODO! add this line if LogicalPlan is dropped in favor of ALogicalPlan
// let aggs = rewrite_projections(aggs, current_schema);

let mut schema = aexprs_to_schema(&keys, current_schema, Context::Default, self.expr_arena);
let other = aexprs_to_schema(&aggs, current_schema, Context::Aggregation, self.expr_arena);
let mut schema = aexprs_to_schema(&keys, &current_schema, Context::Default, self.expr_arena);
let other = aexprs_to_schema(&aggs, &current_schema, Context::Aggregation, self.expr_arena);
schema.merge(other);

let index_columns = &[
Expand Down
22 changes: 21 additions & 1 deletion polars/polars-lazy/src/logical_plan/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,26 @@ where

impl Debug for dyn DataFrameUdf {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "udf")
write!(f, "dyn DataFrameUdf")
}
}


pub trait UdfSchema: Send + Sync {
fn get_schema(&self, input_schema: &Schema) -> Result<SchemaRef>;
}

impl<F> UdfSchema for F
where
F: Fn(&Schema) -> Result<SchemaRef> + Send + Sync,
{
fn get_schema(&self, input_schema: &Schema) -> Result<SchemaRef> {
self(input_schema)
}
}

impl Debug for dyn UdfSchema {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "dyn UdfSchema")
}
}
21 changes: 12 additions & 9 deletions polars/polars-lazy/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ impl LogicalPlanBuilder {

pub fn project(self, exprs: Vec<Expr>) -> Self {
let (exprs, schema) =
try_delayed!(prepare_projection(exprs, self.0.schema()), &self.0, into);
try_delayed!(prepare_projection(exprs, &self.0.schema()), &self.0, into);

if exprs.is_empty() {
self.map(
Expand All @@ -251,7 +251,7 @@ impl LogicalPlanBuilder {

pub fn project_local(self, exprs: Vec<Expr>) -> Self {
let (exprs, schema) =
try_delayed!(prepare_projection(exprs, self.0.schema()), &self.0, into);
try_delayed!(prepare_projection(exprs, &self.0.schema()), &self.0, into);
LogicalPlan::LocalProjection {
expr: exprs,
input: Box::new(self.0),
Expand Down Expand Up @@ -293,10 +293,10 @@ impl LogicalPlanBuilder {
// current schema
let schema = self.0.schema();
let mut new_schema = (**schema).clone();
let (exprs, _) = try_delayed!(prepare_projection(exprs, schema), &self.0, into);
let (exprs, _) = try_delayed!(prepare_projection(exprs, &schema), &self.0, into);

for e in &exprs {
let field = e.to_field(schema, Context::Default).unwrap();
let field = e.to_field(&schema, Context::Default).unwrap();
new_schema.with_column(field.name().to_string(), field.data_type().clone());
}

Expand All @@ -315,7 +315,7 @@ impl LogicalPlanBuilder {
Expr::Wildcard | Expr::RenameAlias { .. } | Expr::Columns(_) => true,
_ => false,
}) {
let rewritten = rewrite_projections(vec![predicate], self.0.schema(), &[]);
let rewritten = rewrite_projections(vec![predicate], &self.0.schema(), &[]);
combine_predicates_expr(rewritten.into_iter())
} else {
predicate
Expand All @@ -337,6 +337,7 @@ impl LogicalPlanBuilder {
rolling_options: Option<RollingGroupOptions>,
) -> Self {
let current_schema = self.0.schema();
let current_schema = current_schema.as_ref();
let aggs = rewrite_projections(aggs.as_ref().to_vec(), current_schema, keys.as_ref());

let mut schema = try_delayed!(
Expand Down Expand Up @@ -402,7 +403,7 @@ impl LogicalPlanBuilder {
}

pub fn sort(self, by_column: Vec<Expr>, reverse: Vec<bool>, null_last: bool) -> Self {
let by_column = rewrite_projections(by_column, self.0.schema(), &[]);
let by_column = rewrite_projections(by_column, &self.0.schema(), &[]);
LogicalPlan::Sort {
input: Box::new(self.0),
by_column,
Expand All @@ -416,7 +417,7 @@ impl LogicalPlanBuilder {
}

pub fn explode(self, columns: Vec<Expr>) -> Self {
let columns = rewrite_projections(columns, self.0.schema(), &[]);
let columns = rewrite_projections(columns, &self.0.schema(), &[]);

let mut schema = (**self.0.schema()).clone();

Expand Down Expand Up @@ -445,7 +446,7 @@ impl LogicalPlanBuilder {
}

pub fn melt(self, args: Arc<MeltArgs>) -> Self {
let schema = det_melt_schema(&args, self.0.schema());
let schema = det_melt_schema(&args, &self.0.schema());
LogicalPlan::Melt {
input: Box::new(self.0),
args,
Expand Down Expand Up @@ -539,7 +540,9 @@ impl LogicalPlanBuilder {
input: Box::new(self.0),
function: Arc::new(function),
options,
schema,
schema: schema.map(|s| {
Arc::new(move |_: &Schema| Ok(s.clone())) as Arc<dyn UdfSchema>
})
}
.into()
}
Expand Down
40 changes: 22 additions & 18 deletions polars/polars-lazy/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use parking_lot::Mutex;
#[cfg(any(feature = "ipc", feature = "csv-file", feature = "parquet"))]
use std::path::PathBuf;
use std::{cell::Cell, fmt::Debug, sync::Arc};
use std::borrow::Cow;

use polars_core::prelude::*;

Expand Down Expand Up @@ -176,7 +177,7 @@ pub enum LogicalPlan {
input: Box<LogicalPlan>,
function: Arc<dyn DataFrameUdf>,
options: LogicalPlanUdfOptions,
schema: Option<SchemaRef>,
schema: Option<Arc<dyn UdfSchema>>,
},
Union {
inputs: Vec<LogicalPlan>,
Expand Down Expand Up @@ -214,35 +215,38 @@ impl LogicalPlan {
}

impl LogicalPlan {
pub(crate) fn schema(&self) -> &SchemaRef {
pub(crate) fn schema<'a>(&'a self) -> Cow<'a, SchemaRef> {
use LogicalPlan::*;
match self {
#[cfg(feature = "python")]
PythonScan { options } => &options.schema,
PythonScan { options } => Cow::Borrowed(&options.schema),
Union { inputs, .. } => inputs[0].schema(),
Cache { input } => input.schema(),
Sort { input, .. } => input.schema(),
Explode { schema, .. } => schema,
Explode { schema, .. } => Cow::Borrowed(schema),
#[cfg(feature = "parquet")]
ParquetScan { schema, .. } => schema,
ParquetScan { schema, .. } => Cow::Borrowed(schema),
#[cfg(feature = "ipc")]
IpcScan { schema, .. } => schema,
DataFrameScan { schema, .. } => schema,
AnonymousScan { schema, .. } => schema,
IpcScan { schema, .. } => Cow::Borrowed(schema),
DataFrameScan { schema, .. } => Cow::Borrowed(schema),
AnonymousScan { schema, .. } => Cow::Borrowed(schema),
Selection { input, .. } => input.schema(),
#[cfg(feature = "csv-file")]
CsvScan { schema, .. } => schema,
Projection { schema, .. } => schema,
LocalProjection { schema, .. } => schema,
Aggregate { schema, .. } => schema,
Join { schema, .. } => schema,
HStack { schema, .. } => schema,
CsvScan { schema, .. } => Cow::Borrowed(schema),
Projection { schema, .. } => Cow::Borrowed(schema),
LocalProjection { schema, .. } => Cow::Borrowed(schema),
Aggregate { schema, .. } => Cow::Borrowed(schema),
Join { schema, .. } => Cow::Borrowed(schema),
HStack { schema, .. } => Cow::Borrowed(schema),
Distinct { input, .. } => input.schema(),
Slice { input, .. } => input.schema(),
Melt { schema, .. } => schema,
Udf { input, schema, .. } => match schema {
Some(schema) => schema,
None => input.schema(),
Melt { schema, .. } => Cow::Borrowed(schema),
Udf { input, schema, .. } => {
let input_schema = input.schema();
match schema {
Some(schema) => Cow::Owned(schema.get_schema(&input_schema).unwrap()),
None => input_schema,
}
},
Error { input, .. } => input.schema(),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl OptimizationRule for AggregatePushdown {
.map(|n| {
expr_arena
.get(*n)
.to_field(input_schema, Context::Default, expr_arena)
.to_field(&input_schema, Context::Default, expr_arena)
.unwrap()
})
.collect();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ fn impl_fast_projection(
let lp = ALogicalPlan::Udf {
input,
function: Arc::new(function),
schema,
schema: schema.map(|s| {
Arc::new(move |_: &Schema| Ok(s.clone())) as Arc<dyn UdfSchema>
}),
options,
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl PredicatePushDown {
optimizer::init_hashmap(Some(acc_predicates.len()));
for (name, &predicate) in acc_predicates.iter() {
// we can pushdown the predicate
if check_input_node(predicate, input_schema, expr_arena) {
if check_input_node(predicate, &input_schema, expr_arena) {
insert_and_combine_predicate(
&mut pushdown_predicates,
name.clone(),
Expand Down Expand Up @@ -235,10 +235,10 @@ impl PredicatePushDown {
// projection from a wildcard may be dropped if the schema changes due to the optimization
let expr: Vec<_> = expr
.into_iter()
.filter(|e| check_input_node(*e, schema, expr_arena))
.filter(|e| check_input_node(*e, &schema, expr_arena))
.collect();

let schema = aexprs_to_schema(&expr, schema, Context::Default, expr_arena);
let schema = aexprs_to_schema(&expr, &schema, Context::Default, expr_arena);
Ok(ALogicalPlan::LocalProjection {
expr,
input,
Expand Down Expand Up @@ -422,8 +422,8 @@ impl PredicatePushDown {
// be influenced by join
if !predicate_is_pushdown_boundary(predicate, expr_arena) {
// no else if. predicate can be in both tables.
if check_input_node(predicate, schema_left, expr_arena) {
let name = get_insertion_name(expr_arena, predicate, schema_left);
if check_input_node(predicate, &schema_left, expr_arena) {
let name = get_insertion_name(expr_arena, predicate, &schema_left);
insert_and_combine_predicate(
&mut pushdown_left,
name,
Expand All @@ -433,8 +433,8 @@ impl PredicatePushDown {
filter_left = true;
}

if check_input_node(predicate, schema_right, expr_arena) {
let name = get_insertion_name(expr_arena, predicate, schema_right);
if check_input_node(predicate, &schema_right, expr_arena) {
let name = get_insertion_name(expr_arena, predicate, &schema_right);
insert_and_combine_predicate(
&mut pushdown_right,
name,
Expand Down
Loading

0 comments on commit c82c177

Please sign in to comment.