diff --git a/arrow-buffer/src/buffer/boolean.rs b/arrow-buffer/src/buffer/boolean.rs index e2c892315b5f..92ed17ef5cfc 100644 --- a/arrow-buffer/src/buffer/boolean.rs +++ b/arrow-buffer/src/buffer/boolean.rs @@ -92,7 +92,7 @@ impl BooleanBuffer { /// Returns a `BitChunks` instance which can be used to iterate over /// this buffer's bits in `u64` chunks #[inline] - pub fn bit_chunks(&self) -> BitChunks { + pub fn bit_chunks(&self) -> BitChunks<'_> { BitChunks::new(self.values(), self.offset, self.len) } diff --git a/arrow-buffer/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs index 820ad04bf61a..6e7f713b46c4 100644 --- a/arrow-buffer/src/buffer/immutable.rs +++ b/arrow-buffer/src/buffer/immutable.rs @@ -309,7 +309,7 @@ impl Buffer { /// Returns a `BitChunks` instance which can be used to iterate over this buffers bits /// in larger chunks and starting at arbitrary bit offsets. /// Note that both `offset` and `length` are measured in bits. - pub fn bit_chunks(&self, offset: usize, len: usize) -> BitChunks { + pub fn bit_chunks(&self, offset: usize, len: usize) -> BitChunks<'_> { BitChunks::new(self.as_slice(), offset, len) } diff --git a/arrow-cast/src/display.rs b/arrow-cast/src/display.rs index 669b8a664c2b..3909e95f52d6 100644 --- a/arrow-cast/src/display.rs +++ b/arrow-cast/src/display.rs @@ -38,6 +38,17 @@ use lexical_core::FormattedSize; type TimeFormat<'a> = Option<&'a str>; +/// Format for displaying decimals +#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum DecimalFormat { + /// Render decimals as JSON numbers, e.g. 12.34 + #[default] + Number, + /// Render decimals as JSON strings, e.g. "12.34" + String, +} + /// Format for displaying durations #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[non_exhaustive] @@ -72,6 +83,8 @@ pub struct FormatOptions<'a> { time_format: TimeFormat<'a>, /// Duration format duration_format: DurationFormat, + /// Decimal rendering format + decimal_format: DecimalFormat, } impl Default for FormatOptions<'_> { @@ -92,6 +105,7 @@ impl<'a> FormatOptions<'a> { timestamp_tz_format: None, time_format: None, duration_format: DurationFormat::ISO8601, + decimal_format: DecimalFormat::Number, } } @@ -158,6 +172,14 @@ impl<'a> FormatOptions<'a> { ..self } } + + /// Set how decimal values should be formatted + pub const fn with_decimal_format(self, decimal_format: DecimalFormat) -> Self { + Self { + decimal_format, + ..self + } + } } /// Implements [`Display`] for a specific array value @@ -460,14 +482,24 @@ impl DisplayIndex for &PrimitiveArray { macro_rules! decimal_display { ($($t:ty),+) => { $(impl<'a> DisplayIndexState<'a> for &'a PrimitiveArray<$t> { - type State = (u8, i8); + type State = (u8, i8, DecimalFormat); - fn prepare(&self, _options: &FormatOptions<'a>) -> Result { - Ok((self.precision(), self.scale())) + fn prepare(&self, options: &FormatOptions<'a>) -> Result { + Ok((self.precision(), self.scale(), options.decimal_format)) } fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { - write!(f, "{}", <$t>::format_decimal(self.values()[idx], s.0, s.1))?; + let formatted = <$t>::format_decimal(self.values()[idx], s.0, s.1); + match s.2 { + DecimalFormat::String => { + // Format as quoted string + write!(f, "\"{}\"", formatted)?; + } + DecimalFormat::Number => { + // Format as number + write!(f, "{}", formatted)?; + } + } Ok(()) } })+ diff --git a/arrow-json/src/writer/encoder.rs b/arrow-json/src/writer/encoder.rs index ed430fe6a1ec..03e2dde942fd 100644 --- a/arrow-json/src/writer/encoder.rs +++ b/arrow-json/src/writer/encoder.rs @@ -19,7 +19,7 @@ use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::{ArrowNativeType, NullBuffer, OffsetBuffer, ScalarBuffer}; -use arrow_cast::display::{ArrayFormatter, FormatOptions}; +use arrow_cast::display::{ArrayFormatter, DecimalFormat, FormatOptions}; use arrow_schema::{ArrowError, DataType, FieldRef}; use half::f16; use lexical_core::FormattedSize; @@ -29,6 +29,7 @@ use std::io::Write; #[derive(Debug, Clone, Default)] pub struct EncoderOptions { pub explicit_nulls: bool, + pub decimal_format: DecimalFormat, } /// A trait to format array values as JSON values @@ -139,8 +140,10 @@ fn make_encoder_impl<'a>( (Box::new(encoder) as _, array.nulls().cloned()) } DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { - let options = FormatOptions::new().with_display_error(true); - let formatter = ArrayFormatter::try_new(array, &options)?; + let format_options = FormatOptions::new() + .with_display_error(true) + .with_decimal_format(options.decimal_format); + let formatter = ArrayFormatter::try_new(array, &format_options)?; (Box::new(RawArrayFormatter(formatter)) as _, array.nulls().cloned()) } d => match d.is_temporal() { diff --git a/arrow-json/src/writer/mod.rs b/arrow-json/src/writer/mod.rs index a37aa5ff8c2c..675364a9490b 100644 --- a/arrow-json/src/writer/mod.rs +++ b/arrow-json/src/writer/mod.rs @@ -111,6 +111,7 @@ use std::{fmt::Debug, io::Write}; use arrow_array::*; use arrow_schema::*; +use arrow_cast::display::DecimalFormat; use encoder::{make_encoder, EncoderOptions}; /// This trait defines how to format a sequence of JSON objects to a @@ -227,6 +228,11 @@ impl WriterBuilder { self.0.explicit_nulls } + /// Returns the decimal format for this writer + pub fn decimal_format(&self) -> DecimalFormat { + self.0.decimal_format + } + /// Set whether to keep keys with null values, or to omit writing them. /// /// For example, with [`LineDelimited`] format: @@ -253,6 +259,12 @@ impl WriterBuilder { self } + /// Set how decimals should be formatted in JSON output. + pub fn with_decimal_format(mut self, decimal_format: DecimalFormat) -> Self { + self.0.decimal_format = decimal_format; + self + } + /// Create a new `Writer` with specified `JsonFormat` and builder options. pub fn build(self, writer: W) -> Writer where @@ -432,6 +444,29 @@ mod tests { assert_eq!(expected, actual); } + /// Helper to assert decimal output with `Number` and `String` decimal formats + fn assert_decimal_outputs( + batch: &RecordBatch, + expected_default: &str, + expected_decimal_as_string: &str, + ) { + let mut buf = Vec::new(); + { + let mut writer = LineDelimitedWriter::new(&mut buf); + writer.write_batches(&[batch]).unwrap(); + } + assert_json_eq(&buf, expected_default); + + let mut buf = Vec::new(); + { + let mut writer = WriterBuilder::new() + .with_decimal_format(DecimalFormat::String) + .build::<_, LineDelimited>(&mut buf); + writer.write_batches(&[batch]).unwrap(); + } + assert_json_eq(&buf, expected_decimal_as_string); + } + #[test] fn write_simple_rows() { let schema = Schema::new(vec![ @@ -1887,17 +1922,15 @@ mod tests { let schema = Schema::new(vec![field]); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap(); - let mut buf = Vec::new(); - { - let mut writer = LineDelimitedWriter::new(&mut buf); - writer.write_batches(&[&batch]).unwrap(); - } - - assert_json_eq( - &buf, + assert_decimal_outputs( + &batch, r#"{"decimal":12.34} {"decimal":56.78} {"decimal":90.12} +"#, + r#"{"decimal":"12.34"} +{"decimal":"56.78"} +{"decimal":"90.12"} "#, ); } @@ -1914,18 +1947,15 @@ mod tests { let field = Arc::new(Field::new("decimal", array.data_type().clone(), true)); let schema = Schema::new(vec![field]); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap(); - - let mut buf = Vec::new(); - { - let mut writer = LineDelimitedWriter::new(&mut buf); - writer.write_batches(&[&batch]).unwrap(); - } - - assert_json_eq( - &buf, + assert_decimal_outputs( + &batch, r#"{"decimal":12.3400} {"decimal":56.7800} {"decimal":90.1200} +"#, + r#"{"decimal":"12.3400"} +{"decimal":"56.7800"} +{"decimal":"90.1200"} "#, ); } @@ -1938,18 +1968,116 @@ mod tests { let field = Arc::new(Field::new("decimal", array.data_type().clone(), true)); let schema = Schema::new(vec![field]); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap(); + assert_decimal_outputs( + &batch, + r#"{"decimal":12.34} +{} +{"decimal":56.78} +"#, + r#"{"decimal":"12.34"} +{} +{"decimal":"56.78"} +"#, + ); + } - let mut buf = Vec::new(); - { - let mut writer = LineDelimitedWriter::new(&mut buf); - writer.write_batches(&[&batch]).unwrap(); + #[test] + fn test_decimal128_list_encoder() { + let decimal_type = DataType::Decimal128(10, 2); + let item_field = FieldRef::new(Field::new("item", decimal_type.clone(), true)); + let schema = Schema::new(vec![Field::new("list", DataType::List(item_field), true)]); + + let values_builder = Decimal128Builder::new().with_data_type(decimal_type.clone()); + let mut list_builder = ListBuilder::new(values_builder); + let rows = [Some(vec![Some(1234), None]), Some(vec![Some(5678)])]; + + for row in rows { + match row { + Some(values) => { + for value in values { + match value { + Some(v) => list_builder.values().append_value(v), + None => list_builder.values().append_null(), + } + } + list_builder.append(true); + } + None => list_builder.append(false), + } } - assert_json_eq( - &buf, - r#"{"decimal":12.34} + let array = Arc::new(list_builder.finish()) as ArrayRef; + let batch = RecordBatch::try_new(Arc::new(schema), vec![array]).unwrap(); + + assert_decimal_outputs( + &batch, + r#"{"list":[12.34,null]} +{"list":[56.78]} +"#, + r#"{"list":["12.34",null]} +{"list":["56.78"]} +"#, + ); + } + + #[test] + fn test_decimal128_dictionary_encoder() { + let values = Arc::new( + Decimal128Array::from_iter_values([1234, 5678]) + .with_precision_and_scale(10, 2) + .unwrap(), + ); + let keys = Int8Array::from(vec![Some(0), None, Some(1)]); + let dict = DictionaryArray::new(keys, values.clone()); + + let schema = Schema::new(vec![Field::new( + "dict", + DataType::Dictionary(DataType::Int8.into(), DataType::Decimal128(10, 2).into()), + true, + )]); + + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)]).unwrap(); + + assert_decimal_outputs( + &batch, + r#"{"dict":12.34} {} -{"decimal":56.78} +{"dict":56.78} +"#, + r#"{"dict":"12.34"} +{} +{"dict":"56.78"} +"#, + ); + } + + #[test] + fn test_decimal256_dictionary_encoder() { + let values = Arc::new( + Decimal256Array::from_iter_values([i256::from(1234), i256::from(5678)]) + .with_precision_and_scale(10, 2) + .unwrap(), + ); + let keys = Int8Array::from(vec![Some(0), None, Some(1)]); + let dict = DictionaryArray::new(keys, values.clone()); + + let schema = Schema::new(vec![Field::new( + "dict", + DataType::Dictionary(DataType::Int8.into(), DataType::Decimal256(10, 2).into()), + true, + )]); + + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)]).unwrap(); + + assert_decimal_outputs( + &batch, + r#"{"dict":12.34} +{} +{"dict":56.78} +"#, + r#"{"dict":"12.34"} +{} +{"dict":"56.78"} "#, ); }