Skip to content

Commit

Permalink
fix[rust]: test parquet statistics and fix encountered issues (#4656)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 1, 2022
1 parent 96e897c commit 33a1c27
Show file tree
Hide file tree
Showing 11 changed files with 313 additions and 178 deletions.
11 changes: 9 additions & 2 deletions polars/polars-io/src/parquet/predicates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,18 @@ impl BatchStats {
pub(crate) fn collect_statistics(
md: &[RowGroupMetaData],
arrow_schema: &ArrowSchema,
rg: Option<usize>,
) -> ArrowResult<Option<BatchStats>> {
let mut schema = Schema::with_capacity(arrow_schema.fields.len());
let mut stats = vec![];

for fld in &arrow_schema.fields {
let st = deserialize(fld, md)?;
// note that we only select a single row group.
let st = match rg {
None => deserialize(fld, md)?,
// we select a single row group and collect only those stats
Some(rg) => deserialize(fld, &md[rg..rg + 1])?,
};
schema.with_column(fld.name.to_string(), (&fld.data_type).into());
stats.push(ColumnStats(st, Field::from(fld)));
}
Expand All @@ -93,10 +99,11 @@ pub(super) fn read_this_row_group(
predicate: Option<&Arc<dyn PhysicalIoExpr>>,
file_metadata: &arrow::io::parquet::read::FileMetaData,
schema: &ArrowSchema,
rg: usize,
) -> Result<bool> {
if let Some(pred) = &predicate {
if let Some(pred) = pred.as_stats_evaluator() {
if let Some(stats) = collect_statistics(&file_metadata.row_groups, schema)? {
if let Some(stats) = collect_statistics(&file_metadata.row_groups, schema, Some(rg))? {
let should_read = pred.should_read(&stats);
// a parquet file may not have statistics of all columns
if matches!(should_read, Ok(false)) {
Expand Down
12 changes: 7 additions & 5 deletions polars/polars-io/src/parquet/read_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ fn rg_to_dfs(
let md = &file_metadata.row_groups[rg];
let current_row_count = md.num_rows() as IdxSize;

if !read_this_row_group(predicate.as_ref(), file_metadata, schema)? {
if !read_this_row_group(predicate.as_ref(), file_metadata, schema, rg)? {
previous_row_count += current_row_count;
continue;
}
Expand Down Expand Up @@ -170,21 +170,23 @@ fn rg_to_dfs_par(
let row_groups = file_metadata
.row_groups
.iter()
.map(|rg_md| {
.enumerate()
.map(|(rg_idx, rg_md)| {
let row_count_start = previous_row_count;
let num_rows = rg_md.num_rows();
previous_row_count += num_rows;
let local_limit = remaining_rows;
remaining_rows = remaining_rows.saturating_sub(num_rows);

(rg_md, local_limit, row_count_start)
(rg_idx, rg_md, local_limit, row_count_start)
})
.collect::<Vec<_>>();

let dfs = row_groups
.into_par_iter()
.map(|(md, local_limit, row_count_start)| {
if local_limit == 0 || !read_this_row_group(predicate.as_ref(), file_metadata, schema)?
.map(|(rg_idx, md, local_limit, row_count_start)| {
if local_limit == 0
|| !read_this_row_group(predicate.as_ref(), file_metadata, schema, rg_idx)?
{
return Ok(None);
}
Expand Down
159 changes: 1 addition & 158 deletions polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ mod pow;
mod rolling;
#[cfg(feature = "row_hash")]
mod row_hash;
mod schema;
#[cfg(feature = "search_sorted")]
mod search_sorted;
mod shift_and_fill;
Expand Down Expand Up @@ -111,164 +112,6 @@ pub enum TrigonometricFunction {
ArcTanh,
}

impl FunctionExpr {
    /// Resolve the output [`Field`] (name + dtype) that this function expression
    /// will produce, given the `fields` of its inputs. This lets the lazy engine
    /// determine the schema without executing the expression.
    ///
    /// The output field's *name* is always taken from the first input field;
    /// only the dtype varies per variant.
    pub(crate) fn get_field(
        &self,
        _input_schema: &Schema,
        _cntxt: Context,
        fields: &[Field],
    ) -> Result<Field> {
        // set a dtype
        // Helper: keep the first input's name, but force a fixed output dtype.
        let with_dtype = |dtype: DataType| Ok(Field::new(fields[0].name(), dtype));

        // map a single dtype
        // Helper: derive the output dtype from the first input's dtype only.
        let map_dtype = |func: &dyn Fn(&DataType) -> DataType| {
            let dtype = func(fields[0].data_type());
            Ok(Field::new(fields[0].name(), dtype))
        };

        // map all dtypes
        // Helper: derive the output dtype from *all* input dtypes
        // (used by list operations such as concat).
        #[cfg(feature = "list")]
        let map_dtypes = |func: &dyn Fn(&[&DataType]) -> DataType| {
            let mut fld = fields[0].clone();
            let dtypes = fields.iter().map(|fld| fld.data_type()).collect::<Vec<_>>();
            let new_type = func(&dtypes);
            fld.coerce(new_type);
            Ok(fld)
        };

        #[cfg(any(feature = "rolling_window", feature = "trigonometry"))]
        // set float supertype
        // Float32 stays Float32; every other dtype is widened to Float64.
        let float_dtype = || {
            map_dtype(&|dtype| match dtype {
                DataType::Float32 => DataType::Float32,
                _ => DataType::Float64,
            })
        };

        // map to same type
        // Helper: output dtype is identical to the first input's dtype.
        let same_type = || map_dtype(&|dtype| dtype.clone());

        // get supertype of all types
        // Folds `get_supertype` over all input fields; errors if two input
        // dtypes have no common supertype.
        let super_type = || {
            let mut first = fields[0].clone();
            let mut st = first.data_type().clone();
            for field in &fields[1..] {
                st = get_supertype(&st, field.data_type())?
            }
            first.coerce(st);
            Ok(first)
        };

        // inner super type of lists
        // Computes the supertype of the *inner* dtypes, accepting a mix of
        // List and non-List inputs; non-List dtypes contribute themselves.
        // NOTE(review): failed supertype lookups are swallowed via `.ok()`,
        // and the final `.unwrap()` panics when `dts` is empty or every
        // lookup failed — presumably callers guarantee non-empty compatible
        // inputs; verify against call sites.
        #[cfg(feature = "list")]
        let inner_super_type_list = || {
            map_dtypes(&|dts| {
                let mut super_type_inner = None;

                for dt in dts {
                    match dt {
                        DataType::List(inner) => match super_type_inner {
                            None => super_type_inner = Some(*inner.clone()),
                            Some(st_inner) => {
                                super_type_inner = get_supertype(&st_inner, inner).ok()
                            }
                        },
                        dt => match super_type_inner {
                            None => super_type_inner = Some((*dt).clone()),
                            Some(st_inner) => super_type_inner = get_supertype(&st_inner, dt).ok(),
                        },
                    }
                }
                DataType::List(Box::new(super_type_inner.unwrap()))
            })
        };

        use FunctionExpr::*;
        // Dispatch: pick the dtype rule matching this variant. Feature-gated
        // variants only exist when their cargo feature is enabled.
        match self {
            NullCount => with_dtype(IDX_DTYPE),
            Pow => super_type(),
            #[cfg(feature = "row_hash")]
            Hash(..) => with_dtype(DataType::UInt64),
            #[cfg(feature = "is_in")]
            IsIn => with_dtype(DataType::Boolean),
            #[cfg(feature = "arg_where")]
            ArgWhere => with_dtype(IDX_DTYPE),
            #[cfg(feature = "search_sorted")]
            SearchSorted => with_dtype(IDX_DTYPE),
            #[cfg(feature = "strings")]
            StringExpr(s) => {
                use StringFunction::*;
                match s {
                    Contains { .. } | EndsWith(_) | StartsWith(_) => with_dtype(DataType::Boolean),
                    Extract { .. } => same_type(),
                    ExtractAll(_) => with_dtype(DataType::List(Box::new(DataType::Utf8))),
                    CountMatch(_) => with_dtype(DataType::UInt32),
                    #[cfg(feature = "string_justify")]
                    Zfill { .. } | LJust { .. } | RJust { .. } => same_type(),
                    #[cfg(feature = "temporal")]
                    // Strptime's output dtype comes from the parse options.
                    Strptime(options) => with_dtype(options.date_dtype.clone()),
                    #[cfg(feature = "concat_str")]
                    Concat(_) => with_dtype(DataType::Utf8),
                    #[cfg(feature = "regex")]
                    Replace { .. } => with_dtype(DataType::Utf8),
                    Uppercase | Lowercase => with_dtype(DataType::Utf8),
                }
            }

            #[cfg(feature = "date_offset")]
            DateOffset(_) => same_type(),
            #[cfg(feature = "trigonometry")]
            Trigonometry(_) => float_dtype(),
            #[cfg(feature = "sign")]
            Sign => with_dtype(DataType::Int64),
            // FillNull carries its pre-computed supertype in the variant.
            FillNull { super_type, .. } => with_dtype(super_type.clone()),
            #[cfg(feature = "is_in")]
            ListContains => with_dtype(DataType::Boolean),
            #[cfg(all(feature = "rolling_window", feature = "moment"))]
            RollingSkew { .. } => float_dtype(),
            ShiftAndFill { .. } => same_type(),
            // NaN-related functions delegate their own field resolution.
            Nan(n) => n.get_field(fields),
            #[cfg(feature = "round_series")]
            Clip { .. } => same_type(),
            #[cfg(feature = "list")]
            ListExpr(l) => {
                use ListFunction::*;
                match l {
                    Concat => inner_super_type_list(),
                }
            }
            #[cfg(feature = "dtype-struct")]
            StructExpr(s) => {
                use StructFunction::*;
                match s {
                    FieldByIndex(index) => {
                        // Negative indices are normalized via slice_offsets;
                        // out-of-range indices become a ComputeError.
                        let (index, _) = slice_offsets(*index, 0, fields.len());
                        fields.get(index).cloned().ok_or_else(|| {
                            PolarsError::ComputeError(
                                "index out of bounds in 'struct.field'".into(),
                            )
                        })
                    }
                    FieldByName(name) => {
                        // Look the field up inside the struct dtype of the
                        // first input; missing name -> NotFound.
                        if let DataType::Struct(flds) = &fields[0].dtype {
                            let fld = flds
                                .iter()
                                .find(|fld| fld.name() == name.as_ref())
                                .ok_or_else(|| PolarsError::NotFound(name.as_ref().to_string()))?;
                            Ok(fld.clone())
                        } else {
                            Err(PolarsError::NotFound(name.as_ref().to_string()))
                        }
                    }
                }
            }
            #[cfg(feature = "top_k")]
            TopK { .. } => same_type(),
        }
    }
}

macro_rules! wrap {
($e:expr) => {
SpecialEq::new(Arc::new($e))
Expand Down

0 comments on commit 33a1c27

Please sign in to comment.