Skip to content

Commit

Permalink
fix[rust] schema inference use supertype instead of first non null (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 19, 2022
1 parent a98e0ba commit 3c7302f
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 8 deletions.
7 changes: 7 additions & 0 deletions polars/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,13 @@ impl<'a> AnyValue<'a> {
Datetime(v, _, _) => NumCast::from(*v),
#[cfg(feature = "dtype-duration")]
Duration(v, _) => NumCast::from(*v),
Boolean(v) => {
if *v {
NumCast::from(1)
} else {
NumCast::from(0)
}
}
_ => unimplemented!(),
}
}
Expand Down
66 changes: 63 additions & 3 deletions polars/polars-core/src/frame/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl DataFrame {
/// as this is a lot slower than creating the `Series` in a columnar fashion
#[cfg_attr(docsrs, doc(cfg(feature = "rows")))]
pub fn from_rows(rows: &[Row]) -> Result<Self> {
let schema = rows_to_schema(rows, Some(50));
let schema = rows_to_schema_first_non_null(rows, Some(50));
let has_nulls = schema
.iter_dtypes()
.any(|dtype| matches!(dtype, DataType::Null));
Expand Down Expand Up @@ -251,8 +251,56 @@ fn is_nested_null(av: &AnyValue) -> bool {
}
}

/// Infer schema from rows.
pub fn rows_to_schema(rows: &[Row], infer_schema_length: Option<usize>) -> Schema {
/// Infer the dynamic dtype of a single `AnyValue`.
///
/// Nested dtypes whose values are all null are mapped to a `Null` leaf dtype,
/// so that later supertype resolution can widen them against non-null rows.
fn infer_dtype_dynamic(av: &AnyValue) -> DataType {
    match av {
        // a list with only nulls (an empty list also satisfies 0 == 0)
        // gets a Null inner dtype instead of the list's physical dtype
        AnyValue::List(s) if s.null_count() == s.len() => DataType::List(Box::new(DataType::Null)),
        #[cfg(feature = "dtype-struct")]
        AnyValue::Struct(avs, _) => {
            // recurse into every struct value; field names are unknown at this
            // point, so empty names are used
            let fields: Vec<Field> = avs
                .iter()
                .map(|inner| Field::new("", infer_dtype_dynamic(inner)))
                .collect();
            DataType::Struct(fields)
        }
        other => other.into(),
    }
}

/// Infer schema from rows and set the supertypes of the columns as column data type.
///
/// At most `infer_schema_length` rows are inspected (all rows when `None`).
/// Errors when two columns' dtypes have no common supertype.
///
/// NOTE(review): indexes `rows[0]` — panics on an empty `rows` slice; confirm
/// callers guarantee at least one row.
pub fn rows_to_schema_supertypes(
    rows: &[Row],
    infer_schema_length: Option<usize>,
) -> Result<Schema> {
    // number of rows used for dtype inference
    let max_infer = infer_schema_length.unwrap_or(rows.len());

    // one set of observed dtypes per column
    let mut dtypes: Vec<PlHashSet<DataType>> = vec![PlHashSet::with_capacity(4); rows[0].0.len()];

    for row in rows.iter().take(max_infer) {
        for (av, types_set) in row.0.iter().zip(dtypes.iter_mut()) {
            types_set.insert(infer_dtype_dynamic(av));
        }
    }

    // reduce every column's dtype set to a single supertype
    dtypes
        .into_iter()
        .enumerate()
        .map(|(i, types_set)| {
            let mut iter = types_set.into_iter();
            // set is non-empty whenever at least one row was inspected
            let first = iter.next().unwrap();
            let dtype = iter.try_fold(first, |acc, dt| get_supertype(&acc, &dt))?;
            Ok(Field::new(&format!("column_{}", i), dtype))
        })
        .collect::<Result<_>>()
}

/// Infer schema from rows and set the first non-null type as the column data type.
pub fn rows_to_schema_first_non_null(rows: &[Row], infer_schema_length: Option<usize>) -> Schema {
// number of rows used for dtype inference
let max_infer = infer_schema_length.unwrap_or(rows.len());
let mut schema: Schema = (&rows[0]).into();
Expand Down Expand Up @@ -367,6 +415,18 @@ impl<'a> AnyValueBuffer<'a> {
(Utf8(builder), AnyValue::Null) => builder.append_null(),
// Struct and List can be recursive so use anyvalues for that
(All(_, vals), v) => vals.push(v),

// dynamic types
(Float64(builder), av) => builder.append_value(av.extract()?),
(Int64(builder), av) => builder.append_value(av.extract()?),
(Utf8(builder), av) => match av {
AnyValue::Utf8(v) => builder.append_value(v),
AnyValue::Int64(v) => builder.append_value(&format!("{}", v)),
AnyValue::Float64(v) => builder.append_value(&format!("{}", v)),
AnyValue::Boolean(true) => builder.append_value(&"true"),
AnyValue::Boolean(false) => builder.append_value(&"false"),
_ => return None,
},
_ => return None,
};
Some(())
Expand Down
5 changes: 5 additions & 0 deletions polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,11 @@ fn _get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
let tu = get_time_units(tu_l, tu_r);
Some(Datetime(tu, tz_r.clone()))
}
(List(inner_left), List(inner_right)) => {
let st = _get_supertype(inner_left, inner_right)?;
Some(DataType::List(Box::new(st)))
}
// todo! check if can be removed
(List(inner), other) | (other, List(inner)) => {
let st = _get_supertype(inner, other)?;
Some(DataType::List(Box::new(st)))
Expand Down
3 changes: 2 additions & 1 deletion polars/polars-io/src/csv/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ pub fn infer_file_schema(

let header_length = headers.len();
// keep track of inferred field types
let mut column_types: Vec<PlHashSet<DataType>> = vec![PlHashSet::new(); header_length];
let mut column_types: Vec<PlHashSet<DataType>> =
vec![PlHashSet::with_capacity(4); header_length];
// keep track of columns with nulls
let mut nulls: Vec<bool> = vec![false; header_length];

Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/apply/dataframe.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use polars::prelude::*;
use polars_core::frame::row::{rows_to_schema, Row};
use polars_core::frame::row::{rows_to_schema_first_non_null, Row};
use pyo3::conversion::{FromPyObject, IntoPy};
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyString, PyTuple};
Expand Down Expand Up @@ -282,7 +282,7 @@ pub fn apply_lambda_with_rows_output<'a>(
let mut buf = Vec::with_capacity(inference_size);
buf.push(first_value);
buf.extend((&mut row_iter).take(inference_size).cloned());
let schema = rows_to_schema(&buf, Some(50));
let schema = rows_to_schema_first_non_null(&buf, Some(50));

if init_null_count > 0 {
// Safety: we know the iterators size
Expand Down
5 changes: 3 additions & 2 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ops::Deref;

use numpy::IntoPyArray;
use polars::frame::groupby::GroupBy;
use polars::frame::row::{rows_to_schema, Row};
use polars::frame::row::{rows_to_schema_supertypes, Row};
#[cfg(feature = "avro")]
use polars::io::avro::AvroCompression;
#[cfg(feature = "ipc")]
Expand Down Expand Up @@ -52,8 +52,9 @@ impl PyDataFrame {
}

fn finish_from_rows(rows: Vec<Row>, infer_schema_length: Option<usize>) -> PyResult<Self> {
let schema =
rows_to_schema_supertypes(&rows, infer_schema_length).map_err(PyPolarsErr::from)?;
// replace inferred nulls with boolean
let schema = rows_to_schema(&rows, infer_schema_length);
let fields = schema.iter_fields().map(|mut fld| match fld.data_type() {
DataType::Null => {
fld.coerce(DataType::Boolean);
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,15 @@ def test_schema_err() -> None:
df = pl.DataFrame({"foo": [None, 1, 2], "bar": [1, 2, 3]}).lazy()
with pytest.raises(pl.NotFoundError):
df.groupby("not-existent").agg(pl.col("bar").max().alias("max_bar")).schema


def test_schema_inference_from_rows() -> None:
    # mixed int/float columns must be upcast to the float supertype
    records_df = pl.from_records([[1, 2.1, 3], [4, 5, 6.4]])
    assert records_df.to_dict(False) == {
        "column_0": [1.0, 2.1, 3.0],
        "column_1": [4.0, 5.0, 6.4],
    }

    # same upcast rule when inferring from dicts with named columns
    dicts_df = pl.from_dicts([{"a": 1, "b": 2}, {"a": 3.1, "b": 4.5}])
    assert dicts_df.to_dict(False) == {"a": [1.0, 3.1], "b": [2.0, 4.5]}

0 comments on commit 3c7302f

Please sign in to comment.