Skip to content

Commit

Permalink
from_rows improve schema correctness (#4097)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 20, 2022
1 parent 46c40a8 commit 753fd12
Show file tree
Hide file tree
Showing 7 changed files with 537 additions and 476 deletions.
2 changes: 2 additions & 0 deletions polars/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,8 @@ impl PartialEq for DataType {
(Duration(tu_l), Duration(tu_r)) => tu_l == tu_r,
#[cfg(feature = "object")]
(Object(lhs), Object(rhs)) => lhs == rhs,
#[cfg(feature = "dtype-struct")]
(Struct(lhs), Struct(rhs)) => lhs == rhs,
_ => std::mem::discriminant(self) == std::mem::discriminant(other),
}
}
Expand Down
8 changes: 4 additions & 4 deletions polars/polars-core/src/frame/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ pub(crate) enum AnyValueBuffer<'a> {
Utf8(Utf8ChunkedBuilder),
#[cfg(feature = "dtype-categorical")]
Categorical(CategoricalChunkedBuilder),
All(Vec<AnyValue<'a>>),
All(DataType, Vec<AnyValue<'a>>),
}

impl<'a> AnyValueBuffer<'a> {
Expand Down Expand Up @@ -385,7 +385,7 @@ impl<'a> AnyValueBuffer<'a> {
(Utf8(builder), AnyValue::Utf8(v)) => builder.append_value(v),
(Utf8(builder), AnyValue::Null) => builder.append_null(),
// Struct and List can be recursive so use anyvalues for that
(All(vals), v) => vals.push(v),
(All(_, vals), v) => vals.push(v),
_ => return None,
};
Some(())
Expand Down Expand Up @@ -417,7 +417,7 @@ impl<'a> AnyValueBuffer<'a> {
Utf8(b) => b.finish().into_series(),
#[cfg(feature = "dtype-categorical")]
Categorical(b) => b.finish().into_series(),
All(vals) => Series::new("", vals),
All(dtype, vals) => Series::from_any_values_and_dtype("", &vals, &dtype).unwrap(),
}
}
}
Expand Down Expand Up @@ -447,7 +447,7 @@ impl From<(&DataType, usize)> for AnyValueBuffer<'_> {
#[cfg(feature = "dtype-categorical")]
Categorical(_) => AnyValueBuffer::Categorical(CategoricalChunkedBuilder::new("", len)),
// Struct and List can be recursive so use anyvalues for that
_ => AnyValueBuffer::All(Vec::with_capacity(len)),
dt => AnyValueBuffer::All(dt.clone(), Vec::with_capacity(len)),
}
}
}
Expand Down
182 changes: 100 additions & 82 deletions polars/polars-core/src/series/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,31 @@ fn any_values_to_bool(avs: &[AnyValue]) -> BooleanChunked {
.collect_trusted()
}

fn any_values_to_list(avs: &[AnyValue]) -> ListChunked {
avs.iter()
.map(|av| match av {
AnyValue::List(b) => Some(b.clone()),
_ => None,
})
.collect_trusted()
fn any_values_to_list(avs: &[AnyValue], inner_type: &DataType) -> ListChunked {
// this is handled downstream. The builder will choose the first non null type
if inner_type == &DataType::Null {
avs.iter()
.map(|av| match av {
AnyValue::List(b) => Some(b.clone()),
_ => None,
})
.collect_trusted()
}
// make sure that wrongly inferred anyvalues don't deviate from the datatype
else {
avs.iter()
.map(|av| match av {
AnyValue::List(b) => {
if b.dtype() == inner_type {
Some(b.clone())
} else {
Some(Series::full_null("", b.len(), inner_type))
}
}
_ => None,
})
.collect_trusted()
}
}

impl<'a, T: AsRef<[AnyValue<'a>]>> NamedFrom<T, [AnyValue<'a>]> for Series {
Expand All @@ -42,87 +60,87 @@ impl<'a, T: AsRef<[AnyValue<'a>]>> NamedFrom<T, [AnyValue<'a>]> for Series {
}

impl Series {
pub fn from_any_values<'a>(name: &str, av: &[AnyValue<'a>]) -> Result<Series> {
match av.iter().find(|av| !matches!(av, AnyValue::Null)) {
None => Ok(Series::full_null(name, av.len(), &DataType::Int32)),
Some(av_) => {
let mut s = match av_ {
#[cfg(feature = "dtype-i8")]
AnyValue::Int8(_) => any_values_to_primitive::<Int8Type>(av).into_series(),
#[cfg(feature = "dtype-i16")]
AnyValue::Int16(_) => any_values_to_primitive::<Int16Type>(av).into_series(),
AnyValue::Int32(_) => any_values_to_primitive::<Int32Type>(av).into_series(),
AnyValue::Int64(_) => any_values_to_primitive::<Int64Type>(av).into_series(),
#[cfg(feature = "dtype-u8")]
AnyValue::UInt8(_) => any_values_to_primitive::<UInt8Type>(av).into_series(),
#[cfg(feature = "dtype-u16")]
AnyValue::UInt16(_) => any_values_to_primitive::<UInt16Type>(av).into_series(),
AnyValue::UInt32(_) => any_values_to_primitive::<UInt32Type>(av).into_series(),
AnyValue::UInt64(_) => any_values_to_primitive::<UInt64Type>(av).into_series(),
AnyValue::Float32(_) => {
any_values_to_primitive::<Float32Type>(av).into_series()
}
AnyValue::Float64(_) => {
any_values_to_primitive::<Float64Type>(av).into_series()
}
AnyValue::Utf8(_) | AnyValue::Utf8Owned(_) => {
any_values_to_utf8(av).into_series()
}
AnyValue::Boolean(_) => any_values_to_bool(av).into_series(),
AnyValue::List(_) => any_values_to_list(av).into_series(),
#[cfg(feature = "dtype-date")]
AnyValue::Date(_) => any_values_to_primitive::<Int32Type>(av)
.into_date()
.into_series(),
#[cfg(feature = "dtype-datetime")]
AnyValue::Datetime(_, tu, tz) => any_values_to_primitive::<Int64Type>(av)
.into_datetime(*tu, (*tz).clone())
.into_series(),
#[cfg(feature = "dtype-time")]
AnyValue::Time(_) => any_values_to_primitive::<Int64Type>(av)
.into_time()
.into_series(),
#[cfg(feature = "dtype-duration")]
AnyValue::Duration(_, tu) => any_values_to_primitive::<Int64Type>(av)
.into_duration(*tu)
.into_series(),
#[cfg(feature = "dtype-struct")]
AnyValue::StructOwned(payload) => {
let vals = &payload.0;
let fields = &payload.1;

// the fields of the struct
let mut series_fields = Vec::with_capacity(vals.len());
for (i, field) in fields.iter().enumerate() {
let mut field_avs = Vec::with_capacity(av.len());
pub fn from_any_values_and_dtype<'a>(
name: &str,
av: &[AnyValue<'a>],
dtype: &DataType,
) -> Result<Series> {
let mut s = match dtype {
#[cfg(feature = "dtype-i8")]
DataType::Int8 => any_values_to_primitive::<Int8Type>(av).into_series(),
#[cfg(feature = "dtype-i16")]
DataType::Int16 => any_values_to_primitive::<Int16Type>(av).into_series(),
DataType::Int32 => any_values_to_primitive::<Int32Type>(av).into_series(),
DataType::Int64 => any_values_to_primitive::<Int64Type>(av).into_series(),
#[cfg(feature = "dtype-u8")]
DataType::UInt8 => any_values_to_primitive::<UInt8Type>(av).into_series(),
#[cfg(feature = "dtype-u16")]
DataType::UInt16 => any_values_to_primitive::<UInt16Type>(av).into_series(),
DataType::UInt32 => any_values_to_primitive::<UInt32Type>(av).into_series(),
DataType::UInt64 => any_values_to_primitive::<UInt64Type>(av).into_series(),
DataType::Float32 => any_values_to_primitive::<Float32Type>(av).into_series(),
DataType::Float64 => any_values_to_primitive::<Float64Type>(av).into_series(),
DataType::Utf8 => any_values_to_utf8(av).into_series(),
DataType::Boolean => any_values_to_bool(av).into_series(),
#[cfg(feature = "dtype-date")]
DataType::Date => any_values_to_primitive::<Int32Type>(av)
.into_date()
.into_series(),
#[cfg(feature = "dtype-datetime")]
DataType::Datetime(tu, tz) => any_values_to_primitive::<Int64Type>(av)
.into_datetime(*tu, (*tz).clone())
.into_series(),
#[cfg(feature = "dtype-time")]
DataType::Time => any_values_to_primitive::<Int64Type>(av)
.into_time()
.into_series(),
#[cfg(feature = "dtype-duration")]
DataType::Duration(tu) => any_values_to_primitive::<Int64Type>(av)
.into_duration(*tu)
.into_series(),
DataType::List(inner) => any_values_to_list(av, inner).into_series(),
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => {
// the fields of the struct
let mut series_fields = Vec::with_capacity(fields.len());
for (i, field) in fields.iter().enumerate() {
let mut field_avs = Vec::with_capacity(av.len());

for av in av.iter() {
match av {
AnyValue::StructOwned(pl) => {
for (l, r) in fields.iter().zip(pl.1.iter()) {
if l.name() != r.name() {
return Err(PolarsError::ComputeError(
"struct orders must remain the same".into(),
));
}
}

let av_val = pl.0[i].clone();
field_avs.push(av_val)
for av in av.iter() {
match av {
AnyValue::StructOwned(payload) => {
for (l, r) in fields.iter().zip(payload.1.iter()) {
if l.name() != r.name() {
return Err(PolarsError::ComputeError(
"struct orders must remain the same".into(),
));
}
_ => field_avs.push(AnyValue::Null),
}

let av_val = payload.0[i].clone();
field_avs.push(av_val)
}
series_fields.push(Series::new(field.name(), &field_avs))
_ => field_avs.push(AnyValue::Null),
}
return Ok(StructChunked::new(name, &series_fields)
.unwrap()
.into_series());
}
av => panic!("av {:?} not implemented", av),
};
s.rename(name);
Ok(s)
series_fields.push(Series::new(field.name(), &field_avs))
}
return Ok(StructChunked::new(name, &series_fields)
.unwrap()
.into_series());
}
dtype => panic!("dtype {:?} not implemented", dtype),
};
s.rename(name);
Ok(s)
}

pub fn from_any_values<'a>(name: &str, av: &[AnyValue<'a>]) -> Result<Series> {
match av.iter().find(|av| !matches!(av, AnyValue::Null)) {
None => Ok(Series::full_null(name, av.len(), &DataType::Int32)),
Some(av_) => {
let dtype: DataType = av_.into();
Series::from_any_values_and_dtype(name, av, &dtype)
}
}
}
Expand Down
8 changes: 7 additions & 1 deletion polars/polars-core/src/series/arithmetic/borrowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,14 @@ pub(crate) fn coerce_lhs_rhs<'a>(
if let Ok(result) = coerce_time_units(lhs, rhs) {
return Ok(result);
}
let dtype = match (lhs.dtype(), rhs.dtype()) {
#[cfg(feature = "dtype-struct")]
(DataType::Struct(_), DataType::Struct(_)) => {
return Ok((Cow::Borrowed(lhs), Cow::Borrowed(rhs)))
}
_ => get_supertype(lhs.dtype(), rhs.dtype())?,
};

let dtype = get_supertype(lhs.dtype(), rhs.dtype())?;
let left = if lhs.dtype() == &dtype {
Cow::Borrowed(lhs)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ impl OptimizationRule for TypeCoercionRule {
(DataType::Time, DataType::Utf8, op) if op.is_comparison() => {
print_date_str_comparison_warning()
}
// structs can be arbitrarily nested, leave the complexity to the caller for now.
#[cfg(feature = "dtype-struct")]
(DataType::Struct(_), DataType::Struct(_), _op) => return None,
_ => {}
}

Expand Down

0 comments on commit 753fd12

Please sign in to comment.