Skip to content

Commit

Permalink
feat: support checking equality of schema and arrow schema (#10903)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZENOTME committed Jul 17, 2023
1 parent 5833a20 commit 9587646
Showing 1 changed file with 63 additions and 1 deletion.
64 changes: 63 additions & 1 deletion src/common/src/catalog/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::ops::Index;

use arrow_schema::{DataType as ArrowDataType, Schema as ArrowSchema};
use itertools::Itertools;
use risingwave_pb::plan_common::{PbColumnDesc, PbField};

use super::ColumnDesc;
use crate::array::ArrayBuilderImpl;
use crate::types::{DataType, StructType};
use crate::util::iter_util::ZipEqFast;

/// The field in the schema of the executor's return data
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Field {
Expand Down Expand Up @@ -197,6 +198,32 @@ impl Schema {
true
}
}

/// Check if the schema can convert to arrow schema.
pub fn same_as_arrow_schema(&self, arrow_schema: &ArrowSchema) -> bool {
if self.fields.len() != arrow_schema.fields().len() {
return false;
}
let mut schema_fields = HashMap::new();
self.fields.iter().for_each(|field| {
let res = schema_fields.insert(&field.name, &field.data_type);
// This assert is to make sure there is no duplicate field name in the schema.
assert!(res.is_none())
});

arrow_schema.fields().iter().all(|arrow_field| {
schema_fields
.get(arrow_field.name())
.and_then(|data_type| {
if let Ok(data_type) = TryInto::<ArrowDataType>::try_into(*data_type) && data_type == *arrow_field.data_type() {
Some(())
} else {
None
}
})
.is_some()
})
}
}

impl Field {
Expand Down Expand Up @@ -328,3 +355,38 @@ pub mod test_utils {
decimal_n::<3>()
}
}

#[cfg(test)]
mod test {
    use arrow_schema::{DataType as ArrowDataType, Field as ArrowField};

    use super::*;

    /// Schemas with identical names and types are the same, regardless of field order.
    #[test]
    fn test_same_as_arrow_schema() {
        let risingwave_schema = Schema::new(vec![
            Field::with_name(DataType::Int32, "a"),
            Field::with_name(DataType::Int32, "b"),
            Field::with_name(DataType::Int32, "c"),
        ]);
        let arrow_schema = ArrowSchema::new(vec![
            ArrowField::new("a", ArrowDataType::Int32, false),
            ArrowField::new("b", ArrowDataType::Int32, false),
            ArrowField::new("c", ArrowDataType::Int32, false),
        ]);
        assert!(risingwave_schema.same_as_arrow_schema(&arrow_schema));

        // Same fields in a different order must still match.
        let risingwave_schema = Schema::new(vec![
            Field::with_name(DataType::Int32, "d"),
            Field::with_name(DataType::Int32, "c"),
            Field::with_name(DataType::Int32, "a"),
            Field::with_name(DataType::Int32, "b"),
        ]);
        let arrow_schema = ArrowSchema::new(vec![
            ArrowField::new("a", ArrowDataType::Int32, false),
            ArrowField::new("b", ArrowDataType::Int32, false),
            ArrowField::new("d", ArrowDataType::Int32, false),
            ArrowField::new("c", ArrowDataType::Int32, false),
        ]);
        assert!(risingwave_schema.same_as_arrow_schema(&arrow_schema));
    }

    /// Schemas that differ in length, data type, or field name are not the same.
    #[test]
    fn test_not_same_as_arrow_schema() {
        let risingwave_schema = Schema::new(vec![
            Field::with_name(DataType::Int32, "a"),
            Field::with_name(DataType::Int32, "b"),
        ]);

        // Different number of fields.
        let arrow_schema = ArrowSchema::new(vec![ArrowField::new(
            "a",
            ArrowDataType::Int32,
            false,
        )]);
        assert!(!risingwave_schema.same_as_arrow_schema(&arrow_schema));

        // Same names, but one data type differs.
        let arrow_schema = ArrowSchema::new(vec![
            ArrowField::new("a", ArrowDataType::Int32, false),
            ArrowField::new("b", ArrowDataType::Int64, false),
        ]);
        assert!(!risingwave_schema.same_as_arrow_schema(&arrow_schema));

        // Same types, but one field name differs.
        let arrow_schema = ArrowSchema::new(vec![
            ArrowField::new("a", ArrowDataType::Int32, false),
            ArrowField::new("x", ArrowDataType::Int32, false),
        ]);
        assert!(!risingwave_schema.same_as_arrow_schema(&arrow_schema));
    }
}

0 comments on commit 9587646

Please sign in to comment.