Skip to content

Commit

Permalink
feat(rust, python): make cast recursive (#5596)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 23, 2022
1 parent 6d74407 commit 082374e
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 59 deletions.
39 changes: 29 additions & 10 deletions polars/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,36 @@ impl ChunkCast for ListChunked {
match data_type {
DataType::List(child_type) => {
let phys_child = child_type.to_physical();
let mut ca = if child_type.to_physical() != self.inner_dtype().to_physical() {
let chunks = self
.downcast_iter()
.map(|list| cast_inner_list_type(list, &phys_child))
.collect::<PolarsResult<_>>()?;
ListChunked::from_chunks(self.name(), chunks)

if phys_child.is_primitive() {
let mut ca = if child_type.to_physical() != self.inner_dtype().to_physical() {
let chunks = self
.downcast_iter()
.map(|list| cast_inner_list_type(list, &phys_child))
.collect::<PolarsResult<_>>()?;
ListChunked::from_chunks(self.name(), chunks)
} else {
self.clone()
};
ca.set_inner_dtype(*child_type.clone());
Ok(ca.into_series())
} else {
self.clone()
};
ca.set_inner_dtype(*child_type.clone());
Ok(ca.into_series())
let ca = self.rechunk();
let arr = ca.downcast_iter().next().unwrap();
let s = Series::try_from(("", arr.values().clone())).unwrap();
let new_inner = s.cast(child_type)?;
let new_values = new_inner.array_ref(0).clone();

let data_type =
ListArray::<i64>::default_datatype(new_values.data_type().clone());
let new_arr = ListArray::<i64>::new(
data_type,
arr.offsets().clone(),
new_values,
arr.validity().cloned(),
);
Series::try_from((s.name(), Box::new(new_arr) as ArrayRef))
}
}
_ => Err(PolarsError::ComputeError("Cannot cast list type".into())),
}
Expand Down
27 changes: 21 additions & 6 deletions polars/polars-core/src/chunked_array/logical/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,26 @@ impl LogicalType for StructChunked {

// in case of a struct, a cast will coerce the inner types
fn cast(&self, dtype: &DataType) -> PolarsResult<Series> {
let fields = self
.fields
.iter()
.map(|s| s.cast(dtype))
.collect::<PolarsResult<Vec<_>>>()?;
Ok(Self::new_unchecked(self.field.name(), &fields).into_series())
match dtype {
DataType::Struct(dtype_fields) => {
let mut new_fields = Vec::with_capacity(self.fields().len());
for (s_field, fld) in self.fields().iter().zip(dtype_fields) {
let mut new_s = s_field.cast(fld.data_type())?;
if new_s.name() != fld.name {
new_s.rename(&fld.name);
}
new_fields.push(new_s);
}
StructChunked::new(self.name(), &new_fields).map(|ca| ca.into_series())
}
_ => {
let fields = self
.fields
.iter()
.map(|s| s.cast(dtype))
.collect::<PolarsResult<Vec<_>>>()?;
Ok(Self::new_unchecked(self.field.name(), &fields).into_series())
}
}
}
}
15 changes: 15 additions & 0 deletions polars/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,21 @@ impl DataType {
matches!(self, Date | Datetime(_, _) | Duration(_) | Time)
}

/// Check if datatype is a primitive type. By that we mean that
/// it is not a container type.
pub fn is_primitive(&self) -> bool {
#[cfg(feature = "dtype-binary")]
{
self.is_numeric()
| matches!(self, DataType::Boolean | DataType::Utf8 | DataType::Binary)
}

#[cfg(not(feature = "dtype-binary"))]
{
self.is_numeric() | matches!(self, DataType::Boolean | DataType::Utf8)
}
}

/// Check if this [`DataType`] is a numeric type
pub fn is_numeric(&self) -> bool {
// allow because it cannot be replaced when object feature is activated
Expand Down
47 changes: 4 additions & 43 deletions polars/polars-core/src/series/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,48 +36,6 @@ fn any_values_to_bool(avs: &[AnyValue]) -> BooleanChunked {
.collect_trusted()
}

fn coerce_recursively(a: &Series, dtype: &DataType) -> Series {
match (a.dtype(), dtype) {
(lhs, rhs) if lhs == rhs => a.clone(),
#[cfg(feature = "dtype-struct")]
(DataType::Struct(_), DataType::Struct(dtype_fields)) => {
let a = a.struct_().unwrap();
let mut new_fields = Vec::with_capacity(a.fields().len());
for (s_field, fld) in a.fields().iter().zip(dtype_fields) {
let mut new_s = coerce_recursively(s_field, fld.data_type());
if new_s.name() != fld.name {
new_s.rename(&fld.name);
}
new_fields.push(new_s);
}
StructChunked::new(a.name(), &new_fields)
.unwrap()
.into_series()
}
(DataType::List(_), DataType::List(inner_type)) => {
let a = a.list().unwrap();
let a = a.rechunk();
let arr = a.downcast_iter().next().unwrap();
let s = Series::try_from(("", arr.values().clone())).unwrap();
let new_inner = coerce_recursively(&s, inner_type);
let new_values = new_inner.array_ref(0).clone();

let data_type = ListArray::<i64>::default_datatype(new_values.data_type().clone());
let new_arr = ListArray::<i64>::new(
data_type,
arr.offsets().clone(),
new_values,
arr.validity().cloned(),
);
Series::try_from((s.name(), Box::new(new_arr) as ArrayRef)).unwrap()
}
_ => match a.cast(dtype) {
Ok(s) => s,
_ => Series::full_null("", a.len(), dtype),
},
}
}

fn any_values_to_list(avs: &[AnyValue], inner_type: &DataType) -> ListChunked {
// this is handled downstream. The builder will choose the first non null type
if inner_type == &DataType::Null {
Expand All @@ -96,7 +54,10 @@ fn any_values_to_list(avs: &[AnyValue], inner_type: &DataType) -> ListChunked {
if b.dtype() == inner_type {
Some(b.clone())
} else {
Some(coerce_recursively(b, inner_type))
match b.cast(inner_type) {
Ok(out) => Some(out),
Err(_) => Some(Series::full_null(b.name(), b.len(), inner_type)),
}
}
}
_ => None,
Expand Down

0 comments on commit 082374e

Please sign in to comment.