Skip to content

Commit

Permalink
fix(rust, python): ndjson struct inference (#6049)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 committed Jan 5, 2023
1 parent f6ee650 commit 42bfd87
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 97 deletions.
121 changes: 28 additions & 93 deletions polars/polars-io/src/ndjson_core/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::hash::{Hash, Hasher};

use arrow::types::NativeType;
use num::traits::NumCast;
use polars_core::frame::row::AnyValueBuffer;
use polars_core::prelude::*;
use polars_time::prelude::utf8::infer::{infer_pattern_single, DatetimeInfer};
use polars_time::prelude::utf8::Pattern;
Expand All @@ -17,102 +18,19 @@ impl<'a> Hash for BufferKey<'a> {
}
}

pub(crate) fn init_buffers(
schema: &Schema,
capacity: usize,
) -> PolarsResult<PlIndexMap<BufferKey, Buffer>> {
schema
.iter()
.map(|(name, dtype)| {
let builder = match dtype {
&DataType::Boolean => Buffer::Boolean(BooleanChunkedBuilder::new(name, capacity)),
&DataType::Int32 => Buffer::Int32(PrimitiveChunkedBuilder::new(name, capacity)),
&DataType::Int64 => Buffer::Int64(PrimitiveChunkedBuilder::new(name, capacity)),
&DataType::UInt32 => Buffer::UInt32(PrimitiveChunkedBuilder::new(name, capacity)),
&DataType::UInt64 => Buffer::UInt64(PrimitiveChunkedBuilder::new(name, capacity)),
&DataType::Float32 => Buffer::Float32(PrimitiveChunkedBuilder::new(name, capacity)),
&DataType::Float64 => Buffer::Float64(PrimitiveChunkedBuilder::new(name, capacity)),
&DataType::Utf8 => {
Buffer::Utf8(Utf8ChunkedBuilder::new(name, capacity, capacity * 25))
}
#[cfg(feature = "dtype-datetime")]
&DataType::Datetime(_, _) => {
Buffer::Datetime(PrimitiveChunkedBuilder::new(name, capacity))
}
#[cfg(feature = "dtype-date")]
&DataType::Date => Buffer::Date(PrimitiveChunkedBuilder::new(name, capacity)),
_ => Buffer::All((Vec::with_capacity(capacity), name)),
};
let key = KnownKey::from(name);

Ok((BufferKey(key), builder))
})
.collect()
}

#[allow(clippy::large_enum_variant)]
pub(crate) enum Buffer<'a> {
Boolean(BooleanChunkedBuilder),
Int32(PrimitiveChunkedBuilder<Int32Type>),
Int64(PrimitiveChunkedBuilder<Int64Type>),
UInt32(PrimitiveChunkedBuilder<UInt32Type>),
UInt64(PrimitiveChunkedBuilder<UInt64Type>),
Float32(PrimitiveChunkedBuilder<Float32Type>),
Float64(PrimitiveChunkedBuilder<Float64Type>),
Utf8(Utf8ChunkedBuilder),
#[cfg(feature = "dtype-datetime")]
Datetime(PrimitiveChunkedBuilder<Int64Type>),
#[cfg(feature = "dtype-date")]
Date(PrimitiveChunkedBuilder<Int32Type>),
All((Vec<AnyValue<'a>>, &'a str)),
}

impl<'a> Buffer<'a> {
pub(crate) fn into_series(self) -> PolarsResult<Series> {
let s = match self {
Buffer::Boolean(v) => v.finish().into_series(),
Buffer::Int32(v) => v.finish().into_series(),
Buffer::Int64(v) => v.finish().into_series(),
Buffer::UInt32(v) => v.finish().into_series(),
Buffer::UInt64(v) => v.finish().into_series(),
Buffer::Float32(v) => v.finish().into_series(),
Buffer::Float64(v) => v.finish().into_series(),
#[cfg(feature = "dtype-datetime")]
Buffer::Datetime(v) => v
.finish()
.into_series()
.cast(&DataType::Datetime(TimeUnit::Microseconds, None))
.unwrap(),
#[cfg(feature = "dtype-date")]
Buffer::Date(v) => v.finish().into_series().cast(&DataType::Date).unwrap(),
Buffer::Utf8(v) => v.finish().into_series(),
Buffer::All((vals, name)) => Series::new(name, vals),
};
Ok(s)
}
/// One ndjson output column: the column name paired with the value
/// accumulator (`AnyValueBuffer`) that collects parsed values for it.
pub(crate) struct Buffer<'a>(&'a str, AnyValueBuffer<'a>);

pub(crate) fn add_null(&mut self) {
match self {
Buffer::Boolean(v) => v.append_null(),
Buffer::Int32(v) => v.append_null(),
Buffer::Int64(v) => v.append_null(),
Buffer::UInt32(v) => v.append_null(),
Buffer::UInt64(v) => v.append_null(),
Buffer::Float32(v) => v.append_null(),
Buffer::Float64(v) => v.append_null(),
Buffer::Utf8(v) => v.append_null(),
#[cfg(feature = "dtype-datetime")]
Buffer::Datetime(v) => v.append_null(),
#[cfg(feature = "dtype-date")]
Buffer::Date(v) => v.append_null(),
Buffer::All((v, _)) => v.push(AnyValue::Null),
};
impl Buffer<'_> {
/// Consume the buffer, materialize the accumulated values as a `Series`,
/// and rename it to the stored column name (the inner buffer does not
/// carry the name itself).
pub fn into_series(self) -> Series {
let mut s = self.1.into_series();
s.rename(self.0);
s
}

#[inline]
pub(crate) fn add(&mut self, value: &Value) -> PolarsResult<()> {
use Buffer::*;
match self {
use AnyValueBuffer::*;
match &mut self.1 {
Boolean(buf) => {
match value {
Value::Static(StaticNode::Bool(b)) => buf.append_value(*b),
Expand Down Expand Up @@ -177,7 +95,7 @@ impl<'a> Buffer<'a> {
Ok(())
}
#[cfg(feature = "dtype-datetime")]
Datetime(buf) => {
Datetime(buf, _, _) => {
let v = deserialize_datetime::<Int64Type>(value);
buf.append_option(v);
Ok(())
Expand All @@ -188,13 +106,30 @@ impl<'a> Buffer<'a> {
buf.append_option(v);
Ok(())
}
All((buf, _)) => {
All(_, buf) => {
let av = deserialize_all(value);
buf.push(av);
Ok(())
}
_ => panic!("unexpected dtype when deserializing ndjson"),
}
}
pub fn add_null(&mut self) {
self.1.add(AnyValue::Null).expect("should not fail");
}
}
pub(crate) fn init_buffers(
schema: &Schema,
capacity: usize,
) -> PolarsResult<PlIndexMap<BufferKey, Buffer>> {
schema
.iter()
.map(|(name, dtype)| {
let av_buf = (dtype, capacity).into();
let key = KnownKey::from(name);
Ok((BufferKey(key), Buffer(name, av_buf)))
})
.collect()
}

fn deserialize_number<T: NativeType + NumCast>(value: &Value) -> Option<T> {
Expand Down
6 changes: 2 additions & 4 deletions polars/polars-io/src/ndjson_core/ndjson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ impl<'a> CoreJsonReader<'a> {
}
}

if total_rows == 128 {
if total_rows <= 128 {
n_threads = 1;
}

Expand All @@ -208,20 +208,18 @@ impl<'a> CoreJsonReader<'a> {
} else {
std::cmp::min(rows_per_thread, max_proxy)
};

let file_chunks = get_file_chunks_json(bytes, n_threads);
let dfs = POOL.install(|| {
file_chunks
.into_par_iter()
.map(|(start_pos, stop_at_nbytes)| {
let mut buffers = init_buffers(&self.schema, capacity)?;
let _ = parse_lines(&bytes[start_pos..stop_at_nbytes], &mut buffers);

DataFrame::new(
buffers
.into_values()
.map(|buf| buf.into_series())
.collect::<PolarsResult<_>>()?,
.collect::<_>(),
)
})
.collect::<PolarsResult<Vec<_>>>()
Expand Down
71 changes: 71 additions & 0 deletions polars/tests/it/io/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,74 @@ fn read_ndjson_with_trailing_newline() {
.unwrap();
assert!(expected.frame_equal(&df));
}
#[test]
fn test_read_ndjson_iss_5875() {
    // Regression test for issue #5875: the second line omits two of the
    // struct's fields, so struct inference must merge field sets across
    // lines instead of erroring.
    let jsonlines = r#"
{"struct": {"int_inner": [1, 2, 3], "float_inner": 5.0, "str_inner": ["a", "b", "c"]}}
{"struct": {"int_inner": [4, 5, 6]}, "float": 4.0}
"#;
    let cursor = Cursor::new(jsonlines);

    let df = JsonLineReader::new(cursor).finish();
    assert!(df.is_ok());

    // The fields are moved straight into the struct dtype — they are not
    // referenced afterwards, so the previous `.clone()` calls were redundant.
    let field_int_inner = Field::new("int_inner", DataType::List(Box::new(DataType::Int64)));
    let field_float_inner = Field::new("float_inner", DataType::Float64);
    let field_str_inner = Field::new("str_inner", DataType::List(Box::new(DataType::Utf8)));

    let mut schema = Schema::new();
    schema.with_column(
        "struct".to_owned(),
        DataType::Struct(vec![field_int_inner, field_float_inner, field_str_inner]),
    );
    schema.with_column("float".to_owned(), DataType::Float64);

    // Inferred schema: the merged struct plus the top-level float column.
    assert_eq!(schema, df.unwrap().schema());
}

#[test]
fn test_read_ndjson_iss_5875_part2() {
    // Same scenario as part 1, but the line carrying the richer struct
    // comes second, and extra top-level columns appear only on that line.
    let jsonlines = r#"
{"struct": {"int_list_inner": [4, 5, 6]}}
{"struct": {"int_list_inner": [1, 2, 3], "float_inner": 5.0, "str_list_inner": ["a", "b", "c"]}, "int_opt": null, "float_list_outer": [1.1, 2.2]}
"#;

    let df = JsonLineReader::new(Cursor::new(jsonlines)).finish();
    assert!(df.is_ok());

    // Expected merged struct fields, in inference order.
    let inner_fields = vec![
        Field::new("int_list_inner", DataType::List(Box::new(DataType::Int64))),
        Field::new("float_inner", DataType::Float64),
        Field::new("str_list_inner", DataType::List(Box::new(DataType::Utf8))),
    ];

    let mut schema = Schema::new();
    schema.with_column("struct".to_owned(), DataType::Struct(inner_fields));
    schema.with_column(
        "float_list_outer".to_owned(),
        DataType::List(Box::new(DataType::Float64)),
    );

    assert_eq!(schema, df.unwrap().schema());
}
#[test]
fn test_read_ndjson_iss_5875_part3() {
    // Struct keys differ per line ("k3" vs "k5", "key2" missing on the
    // last line); reading must still succeed without a schema error.
    let jsonlines = r#"
{"key1":"value1", "key2": "value2", "key3": {"k1": 2, "k3": "value5", "k10": 5}}
{"key1":"value5", "key2": "value4", "key3": {"k1": 2, "k5": "value5", "k10": 4}}
{"key1":"value6", "key3": {"k1": 5, "k3": "value5"}}"#;

    let result = JsonLineReader::new(Cursor::new(jsonlines)).finish();
    assert!(result.is_ok());
}

0 comments on commit 42bfd87

Please sign in to comment.