Skip to content

Commit

Permalink
Support writing arbitrarily nested arrow arrays (apache#1744)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed May 26, 2022
1 parent 722fcfc commit b50ba82
Show file tree
Hide file tree
Showing 2 changed files with 460 additions and 1,409 deletions.
91 changes: 45 additions & 46 deletions parquet/src/arrow/arrow_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use super::schema::{
decimal_length_from_precision,
};

use crate::arrow::levels::calculate_array_levels;
use crate::column::writer::ColumnWriter;
use crate::errors::{ParquetError, Result};
use crate::file::properties::WriterProperties;
Expand Down Expand Up @@ -173,16 +174,15 @@ impl<W: Write> ArrowWriter<W> {
}
}

let mut levels: Vec<_> = arrays
let mut levels = arrays
.iter()
.map(|array| {
let batch_level = LevelInfo::new(0, array.len());
let mut levels = batch_level.calculate_array_levels(array, field);
let mut levels = calculate_array_levels(array, field)?;
// Reverse levels as we pop() them when writing arrays
levels.reverse();
levels
Ok(levels)
})
.collect();
.collect::<Result<Vec<_>>>()?;

write_leaves(&mut row_group_writer, &arrays, &mut levels)?;
}
Expand Down Expand Up @@ -341,26 +341,24 @@ fn write_leaf(
column: &ArrayRef,
levels: LevelInfo,
) -> Result<i64> {
let indices = levels.filter_array_indices();
// Slice array according to computed offset and length
let column = column.slice(levels.offset, levels.length);
// TODO: Avoid filtering when it is not needed
let indices = levels.non_null_indices();
let written = match writer {
ColumnWriter::Int32ColumnWriter(ref mut typed) => {
let values = match column.data_type() {
ArrowDataType::Date64 => {
// If the column is a Date64, we cast it to a Date32, and then interpret that as Int32
let array = if let ArrowDataType::Date64 = column.data_type() {
let array =
arrow::compute::cast(&column, &ArrowDataType::Date32)?;
let array = arrow::compute::cast(column, &ArrowDataType::Date32)?;
arrow::compute::cast(&array, &ArrowDataType::Int32)?
} else {
arrow::compute::cast(&column, &ArrowDataType::Int32)?
arrow::compute::cast(column, &ArrowDataType::Int32)?
};
let array = array
.as_any()
.downcast_ref::<arrow_array::Int32Array>()
.expect("Unable to get int32 array");
get_numeric_array_slice::<Int32Type, _>(array, &indices)
get_numeric_array_slice::<Int32Type, _>(array, indices)
}
ArrowDataType::UInt32 => {
// follow C++ implementation and use overflow/reinterpret cast from u32 to i32 which will map
Expand All @@ -373,21 +371,21 @@ fn write_leaf(
array,
|x| x as i32,
);
get_numeric_array_slice::<Int32Type, _>(&array, &indices)
get_numeric_array_slice::<Int32Type, _>(&array, indices)
}
_ => {
let array = arrow::compute::cast(&column, &ArrowDataType::Int32)?;
let array = arrow::compute::cast(column, &ArrowDataType::Int32)?;
let array = array
.as_any()
.downcast_ref::<arrow_array::Int32Array>()
.expect("Unable to get i32 array");
get_numeric_array_slice::<Int32Type, _>(array, &indices)
get_numeric_array_slice::<Int32Type, _>(array, indices)
}
};
typed.write_batch(
values.as_slice(),
Some(levels.definition.as_slice()),
levels.repetition.as_deref(),
levels.def_levels(),
levels.rep_levels(),
)?
}
ColumnWriter::BoolColumnWriter(ref mut typed) => {
Expand All @@ -396,9 +394,9 @@ fn write_leaf(
.downcast_ref::<arrow_array::BooleanArray>()
.expect("Unable to get boolean array");
typed.write_batch(
get_bool_array_slice(array, &indices).as_slice(),
Some(levels.definition.as_slice()),
levels.repetition.as_deref(),
get_bool_array_slice(array, indices).as_slice(),
levels.def_levels(),
levels.rep_levels(),
)?
}
ColumnWriter::Int64ColumnWriter(ref mut typed) => {
Expand All @@ -408,7 +406,7 @@ fn write_leaf(
.as_any()
.downcast_ref::<arrow_array::Int64Array>()
.expect("Unable to get i64 array");
get_numeric_array_slice::<Int64Type, _>(array, &indices)
get_numeric_array_slice::<Int64Type, _>(array, indices)
}
ArrowDataType::UInt64 => {
// follow C++ implementation and use overflow/reinterpret cast from u64 to i64 which will map
Expand All @@ -421,21 +419,21 @@ fn write_leaf(
array,
|x| x as i64,
);
get_numeric_array_slice::<Int64Type, _>(&array, &indices)
get_numeric_array_slice::<Int64Type, _>(&array, indices)
}
_ => {
let array = arrow::compute::cast(&column, &ArrowDataType::Int64)?;
let array = arrow::compute::cast(column, &ArrowDataType::Int64)?;
let array = array
.as_any()
.downcast_ref::<arrow_array::Int64Array>()
.expect("Unable to get i64 array");
get_numeric_array_slice::<Int64Type, _>(array, &indices)
get_numeric_array_slice::<Int64Type, _>(array, indices)
}
};
typed.write_batch(
values.as_slice(),
Some(levels.definition.as_slice()),
levels.repetition.as_deref(),
levels.def_levels(),
levels.rep_levels(),
)?
}
ColumnWriter::Int96ColumnWriter(ref mut _typed) => {
Expand All @@ -447,9 +445,9 @@ fn write_leaf(
.downcast_ref::<arrow_array::Float32Array>()
.expect("Unable to get Float32 array");
typed.write_batch(
get_numeric_array_slice::<FloatType, _>(array, &indices).as_slice(),
Some(levels.definition.as_slice()),
levels.repetition.as_deref(),
get_numeric_array_slice::<FloatType, _>(array, indices).as_slice(),
levels.def_levels(),
levels.rep_levels(),
)?
}
ColumnWriter::DoubleColumnWriter(ref mut typed) => {
Expand All @@ -458,9 +456,9 @@ fn write_leaf(
.downcast_ref::<arrow_array::Float64Array>()
.expect("Unable to get Float64 array");
typed.write_batch(
get_numeric_array_slice::<DoubleType, _>(array, &indices).as_slice(),
Some(levels.definition.as_slice()),
levels.repetition.as_deref(),
get_numeric_array_slice::<DoubleType, _>(array, indices).as_slice(),
levels.def_levels(),
levels.rep_levels(),
)?
}
ColumnWriter::ByteArrayColumnWriter(ref mut typed) => match column.data_type() {
Expand All @@ -471,8 +469,8 @@ fn write_leaf(
.expect("Unable to get BinaryArray array");
typed.write_batch(
get_binary_array(array).as_slice(),
Some(levels.definition.as_slice()),
levels.repetition.as_deref(),
levels.def_levels(),
levels.rep_levels(),
)?
}
ArrowDataType::Utf8 => {
Expand All @@ -482,8 +480,8 @@ fn write_leaf(
.expect("Unable to get LargeBinaryArray array");
typed.write_batch(
get_string_array(array).as_slice(),
Some(levels.definition.as_slice()),
levels.repetition.as_deref(),
levels.def_levels(),
levels.rep_levels(),
)?
}
ArrowDataType::LargeBinary => {
Expand All @@ -493,8 +491,8 @@ fn write_leaf(
.expect("Unable to get LargeBinaryArray array");
typed.write_batch(
get_large_binary_array(array).as_slice(),
Some(levels.definition.as_slice()),
levels.repetition.as_deref(),
levels.def_levels(),
levels.rep_levels(),
)?
}
ArrowDataType::LargeUtf8 => {
Expand All @@ -504,8 +502,8 @@ fn write_leaf(
.expect("Unable to get LargeUtf8 array");
typed.write_batch(
get_large_string_array(array).as_slice(),
Some(levels.definition.as_slice()),
levels.repetition.as_deref(),
levels.def_levels(),
levels.rep_levels(),
)?
}
_ => unreachable!("Currently unreachable because data type not supported"),
Expand All @@ -518,14 +516,14 @@ fn write_leaf(
.as_any()
.downcast_ref::<arrow_array::IntervalYearMonthArray>()
.unwrap();
get_interval_ym_array_slice(array, &indices)
get_interval_ym_array_slice(array, indices)
}
IntervalUnit::DayTime => {
let array = column
.as_any()
.downcast_ref::<arrow_array::IntervalDayTimeArray>()
.unwrap();
get_interval_dt_array_slice(array, &indices)
get_interval_dt_array_slice(array, indices)
}
_ => {
return Err(ParquetError::NYI(
Expand All @@ -541,14 +539,14 @@ fn write_leaf(
.as_any()
.downcast_ref::<arrow_array::FixedSizeBinaryArray>()
.unwrap();
get_fsb_array_slice(array, &indices)
get_fsb_array_slice(array, indices)
}
ArrowDataType::Decimal(_, _) => {
let array = column
.as_any()
.downcast_ref::<arrow_array::DecimalArray>()
.unwrap();
get_decimal_array_slice(array, &indices)
get_decimal_array_slice(array, indices)
}
_ => {
return Err(ParquetError::NYI(
Expand All @@ -559,8 +557,8 @@ fn write_leaf(
};
typed.write_batch(
bytes.as_slice(),
Some(levels.definition.as_slice()),
levels.repetition.as_deref(),
levels.def_levels(),
levels.rep_levels(),
)?
}
};
Expand Down Expand Up @@ -593,6 +591,7 @@ macro_rules! def_get_binary_array_fn {
};
}

// TODO: These methods don't handle non-null indices correctly
def_get_binary_array_fn!(get_binary_array, arrow_array::BinaryArray);
def_get_binary_array_fn!(get_string_array, arrow_array::StringArray);
def_get_binary_array_fn!(get_large_binary_array, arrow_array::LargeBinaryArray);
Expand Down

0 comments on commit b50ba82

Please sign in to comment.