From d1dd90d79c26619ed452db374a86239c10b77331 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Tue, 12 Apr 2022 16:51:30 +0100 Subject: [PATCH] Support empty RecordBatch (#1536) --- arrow/src/record_batch.rs | 89 ++++++++++++++++++++++++++------------ arrow/src/util/data_gen.rs | 1 + 2 files changed, 62 insertions(+), 28 deletions(-) diff --git a/arrow/src/record_batch.rs b/arrow/src/record_batch.rs index 10bd4c5c561..4557998acc4 100644 --- a/arrow/src/record_batch.rs +++ b/arrow/src/record_batch.rs @@ -41,6 +41,7 @@ use crate::error::{ArrowError, Result}; pub struct RecordBatch { schema: SchemaRef, columns: Vec<Arc<dyn Array>>, + row_count: usize, } impl RecordBatch { @@ -77,8 +78,7 @@ impl RecordBatch { /// ``` pub fn try_new(schema: SchemaRef, columns: Vec<ArrayRef>) -> Result<Self> { let options = RecordBatchOptions::default(); - Self::validate_new_batch(&schema, columns.as_slice(), &options)?; - Ok(RecordBatch { schema, columns }) + Self::try_new_impl(schema, columns, &options) } /// Creates a `RecordBatch` from a schema and columns, with additional options, @@ -90,8 +90,7 @@ impl RecordBatch { columns: Vec<ArrayRef>, options: &RecordBatchOptions, ) -> Result<Self> { - Self::validate_new_batch(&schema, columns.as_slice(), options)?; - Ok(RecordBatch { schema, columns }) + Self::try_new_impl(schema, columns, options) } /// Creates a new empty [`RecordBatch`]. @@ -101,23 +100,21 @@ impl RecordBatch { .iter() .map(|field| new_empty_array(field.data_type())) .collect(); - RecordBatch { schema, columns } + + RecordBatch { + schema, + columns, + row_count: 0, + } } /// Validate the schema and columns using [`RecordBatchOptions`]. Returns an error - /// if any validation check fails. 
- fn validate_new_batch( - schema: &SchemaRef, - columns: &[ArrayRef], + /// if any validation check fails, otherwise returns the created [`RecordBatch`] + fn try_new_impl( + schema: SchemaRef, + columns: Vec<ArrayRef>, options: &RecordBatchOptions, - ) -> Result<()> { - // check that there are some columns - if columns.is_empty() { - return Err(ArrowError::InvalidArgumentError( - "at least one column must be defined to create a record batch" - .to_string(), - )); - } + ) -> Result<Self> { // check that number of fields in schema match column length if schema.fields().len() != columns.len() { return Err(ArrowError::InvalidArgumentError(format!( @@ -128,7 +125,13 @@ impl RecordBatch { } // check that all columns have the same row count - let row_count = columns[0].data().len(); + let row_count = options + .row_count + .or(columns.first().map(|col| col.len())) + .ok_or(ArrowError::InvalidArgumentError( + "must either specify a row count or at least one column".to_string(), + ))?; + if columns.iter().any(|c| c.len() != row_count) { return Err(ArrowError::InvalidArgumentError( "all columns in a record batch must have the same length".to_string(), @@ -163,7 +166,11 @@ impl RecordBatch { i))); } - Ok(()) + Ok(RecordBatch { + schema, + columns, + row_count, + }) } /// Returns the [`Schema`](crate::datatypes::Schema) of the record batch. @@ -218,10 +225,6 @@ impl RecordBatch { /// Returns the number of rows in each column. /// - /// # Panics - /// - /// Panics if the `RecordBatch` contains no columns. - /// /// # Example /// /// ``` @@ -243,7 +246,7 @@ impl RecordBatch { /// # } /// ``` pub fn num_rows(&self) -> usize { - self.columns[0].data().len() + self.row_count } /// Get a reference to a column's array by index. @@ -267,10 +270,6 @@ impl RecordBatch { /// /// Panics if `offset` with `length` is greater than column length. 
pub fn slice(&self, offset: usize, length: usize) -> RecordBatch { - if self.schema.fields().is_empty() { - assert!((offset + length) == 0); - return RecordBatch::new_empty(self.schema.clone()); - } assert!((offset + length) <= self.num_rows()); let columns = self @@ -282,6 +281,7 @@ impl RecordBatch { Self { schema: self.schema.clone(), columns, + row_count: length, } } @@ -402,15 +402,20 @@ impl RecordBatch { /// Options that control the behaviour used when creating a [`RecordBatch`]. #[derive(Debug)] +#[non_exhaustive] pub struct RecordBatchOptions { /// Match field names of structs and lists. If set to `true`, the names must match. pub match_field_names: bool, + + /// Optional row count, useful for specifying a row count for a RecordBatch with no columns + pub row_count: Option<usize>, } impl Default for RecordBatchOptions { fn default() -> Self { Self { match_field_names: true, + row_count: None, } } } @@ -426,6 +431,7 @@ impl From<&StructArray> for RecordBatch { let columns = struct_array.boxed_fields.clone(); RecordBatch { schema: Arc::new(schema), + row_count: struct_array.len(), columns, } } else { @@ -644,6 +650,7 @@ mod tests { // creating the batch without field name validation should pass let options = RecordBatchOptions { match_field_names: false, + row_count: None, }; let batch = RecordBatch::try_new_with_options(schema, vec![a], &options); assert!(batch.is_ok()); @@ -934,4 +941,30 @@ mod tests { assert_eq!(expected, record_batch.project(&[0, 2]).unwrap()); } + + #[test] + fn test_no_column_record_batch() { + let schema = Arc::new(Schema::new(vec![])); + + let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err(); + assert!(err + .to_string() + .contains("must either specify a row count or at least one column")); + + let mut options = RecordBatchOptions::default(); + options.row_count = Some(10); + + let ok = + RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap(); + assert_eq!(ok.num_rows(), 10); + + let a = ok.slice(2, 5); 
+ assert_eq!(a.num_rows(), 5); + + let b = ok.slice(5, 0); + assert_eq!(b.num_rows(), 0); + + assert_ne!(a, b); + assert_eq!(b, RecordBatch::new_empty(schema)) + } } diff --git a/arrow/src/util/data_gen.rs b/arrow/src/util/data_gen.rs index 35b65ef303d..21b8ee8c9fd 100644 --- a/arrow/src/util/data_gen.rs +++ b/arrow/src/util/data_gen.rs @@ -49,6 +49,7 @@ pub fn create_random_batch( columns, &RecordBatchOptions { match_field_names: false, + row_count: None, }, ) }