Skip to content

Commit

Permalink
rename cast_with_datatype and allow List inner type cast
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 7, 2021
1 parent 973e58f commit 47539ad
Show file tree
Hide file tree
Showing 12 changed files with 185 additions and 135 deletions.
166 changes: 151 additions & 15 deletions polars/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use crate::chunked_array::builder::CategoricalChunkedBuilder;
use crate::chunked_array::kernels::{cast_numeric_from_dtype, transmute_array_from_dtype};
use crate::prelude::*;
use arrow::array::{make_array, Array, ArrayDataBuilder};
use arrow::compute::cast;
use num::{NumCast, ToPrimitive};

Expand Down Expand Up @@ -38,6 +39,52 @@ macro_rules! cast_from_dtype {
}};
}

/// Dispatches a runtime `DataType` to the statically typed `ChunkCast::cast::<T>`.
///
/// `$self` is the array to cast and `$data_type` the requested target type.
/// The expansion is a `match` that evaluates to a `Result<Series>`: each
/// supported arm calls `ChunkCast::cast` with the matching Polars type
/// parameter and wraps the resulting array with `into_series()`.
///
/// Arms carrying a `#[cfg(feature = "...")]` attribute only exist when that
/// feature is enabled; without it the data type falls through to the final
/// catch-all arm, which returns `PolarsError::Other`.
macro_rules! cast_with_dtype {
($self:expr, $data_type:expr) => {{
// Bring the `DataType` variants into scope so the arms can name them directly.
use DataType::*;
match $data_type {
Boolean => ChunkCast::cast::<BooleanType>($self).map(|ca| ca.into_series()),
Utf8 => ChunkCast::cast::<Utf8Type>($self).map(|ca| ca.into_series()),
#[cfg(feature = "dtype-u8")]
UInt8 => ChunkCast::cast::<UInt8Type>($self).map(|ca| ca.into_series()),
#[cfg(feature = "dtype-u16")]
UInt16 => ChunkCast::cast::<UInt16Type>($self).map(|ca| ca.into_series()),
UInt32 => ChunkCast::cast::<UInt32Type>($self).map(|ca| ca.into_series()),
#[cfg(feature = "dtype-u64")]
UInt64 => ChunkCast::cast::<UInt64Type>($self).map(|ca| ca.into_series()),
#[cfg(feature = "dtype-i8")]
Int8 => ChunkCast::cast::<Int8Type>($self).map(|ca| ca.into_series()),
#[cfg(feature = "dtype-i16")]
Int16 => ChunkCast::cast::<Int16Type>($self).map(|ca| ca.into_series()),
Int32 => ChunkCast::cast::<Int32Type>($self).map(|ca| ca.into_series()),
Int64 => ChunkCast::cast::<Int64Type>($self).map(|ca| ca.into_series()),
Float32 => ChunkCast::cast::<Float32Type>($self).map(|ca| ca.into_series()),
Float64 => ChunkCast::cast::<Float64Type>($self).map(|ca| ca.into_series()),
#[cfg(feature = "dtype-date32")]
Date32 => ChunkCast::cast::<Date32Type>($self).map(|ca| ca.into_series()),
#[cfg(feature = "dtype-date64")]
Date64 => ChunkCast::cast::<Date64Type>($self).map(|ca| ca.into_series()),
#[cfg(feature = "dtype-time64-ns")]
Time64(TimeUnit::Nanosecond) => {
ChunkCast::cast::<Time64NanosecondType>($self).map(|ca| ca.into_series())
}
#[cfg(feature = "dtype-duration-ns")]
Duration(TimeUnit::Nanosecond) => {
ChunkCast::cast::<DurationNanosecondType>($self).map(|ca| ca.into_series())
}
#[cfg(feature = "dtype-duration-ms")]
Duration(TimeUnit::Millisecond) => {
ChunkCast::cast::<DurationMillisecondType>($self).map(|ca| ca.into_series())
}
// The inner type of the requested list is ignored by this arm.
List(_) => ChunkCast::cast::<ListType>($self).map(|ca| ca.into_series()),
Categorical => ChunkCast::cast::<CategoricalType>($self).map(|ca| ca.into_series()),
// Anything not listed above (or compiled out by a disabled feature) is rejected.
dt => Err(PolarsError::Other(
format!("Casting to {:?} is not supported", dt).into(),
)),
}
}};
}

impl ChunkCast for CategoricalChunked {
fn cast<N>(&self) -> Result<ChunkedArray<N>>
where
Expand Down Expand Up @@ -72,6 +119,9 @@ impl ChunkCast for CategoricalChunked {
_ => cast_ca(self),
}
}
/// Casts this array to the runtime `data_type` and returns the result as a `Series`.
///
/// The work is done by the `cast_with_dtype!` macro, which selects the
/// statically typed `ChunkCast::cast` matching `data_type`.
fn cast_with_dtype(&self, data_type: &DataType) -> Result<Series> {
cast_with_dtype!(self, data_type)
}
}

impl<T> ChunkCast for ChunkedArray<T>
Expand All @@ -98,7 +148,7 @@ where
// to float32
(Duration(_), Float32) | (Date32, Float32) | (Date64, Float32)
// to float64
| (Duration(_), Float64) | (Date32, Float64) | (Date64, Float64)
| (Duration(_), Float64) | (Date32, Float64) | (Date64, Float64)
// underlying type: i64
| (Duration(_), UInt64)
=> {
Expand All @@ -121,19 +171,10 @@ where
ca
})
}
}

/// Generates a `ChunkCast` implementation for the array type `$ca_type`
/// whose `cast` simply forwards to `cast_ca`.
///
/// NOTE(review): the generated impl only provides `cast`; confirm that this
/// still covers every required method of `ChunkCast`.
macro_rules! impl_chunkcast {
($ca_type:ident) => {
impl ChunkCast for $ca_type {
fn cast<N>(&self) -> Result<ChunkedArray<N>>
where
N: PolarsDataType,
{
cast_ca(self)
}
}
};
/// Casts this array to the runtime `data_type` and returns the result as a `Series`.
///
/// The work is done by the `cast_with_dtype!` macro, which selects the
/// statically typed `ChunkCast::cast` matching `data_type`.
fn cast_with_dtype(&self, data_type: &DataType) -> Result<Series> {
cast_with_dtype!(self, data_type)
}
}

impl ChunkCast for Utf8Chunked {
Expand All @@ -153,7 +194,102 @@ impl ChunkCast for Utf8Chunked {
_ => cast_ca(self),
}
}
/// Casts this array to the runtime `data_type` and returns the result as a `Series`.
///
/// The work is done by the `cast_with_dtype!` macro, which selects the
/// statically typed `ChunkCast::cast` matching `data_type`.
fn cast_with_dtype(&self, data_type: &DataType) -> Result<Series> {
cast_with_dtype!(self, data_type)
}
}

impl ChunkCast for BooleanChunked {
    /// Casts to the data type chosen at runtime and wraps the result in a `Series`.
    ///
    /// Dispatch to the matching typed cast is performed by `cast_with_dtype!`.
    fn cast_with_dtype(&self, data_type: &DataType) -> Result<Series> {
        cast_with_dtype!(self, data_type)
    }

    /// Casts to the statically known Polars type `N` by delegating to `cast_ca`.
    fn cast<N: PolarsDataType>(&self) -> Result<ChunkedArray<N>> {
        cast_ca(self)
    }
}

impl_chunkcast!(BooleanChunked);
impl_chunkcast!(ListChunked);
/// We cannot cast anything to or from List/LargeList
/// So this implementation casts the inner tyupe
impl ChunkCast for ListChunked {
fn cast<N>(&self) -> Result<ChunkedArray<N>>
where
N: PolarsDataType,
{
match N::get_dtype() {
DataType::List(child_type) => {
let chunks = self
.downcast_iter()
.map(|list| {
let ad = list.data().clone();
let child = ad.child_data()[0].clone();
let child = make_array(child);
let child = cast(&child, &child_type)?;

let new = ArrayDataBuilder::new(ArrowDataType::LargeList(Box::new(
ArrowField::new("", child.data_type().clone(), true),
)))
.len(list.len())
.buffers(ad.buffers().to_vec())
.add_child_data(child.data().clone())
.build();
Ok(make_array(new))
})
.collect::<Result<_>>()?;
let ca = ListChunked::new_from_chunks(self.name(), chunks);
let ca = unsafe { std::mem::transmute(ca) };
Ok(ca)
}
_ => Err(PolarsError::Other("Cannot cast list type".into())),
}
}
fn cast_with_dtype(&self, data_type: &DataType) -> Result<Series> {
match data_type {
DataType::List(child_type) => {
let chunks = self
.downcast_iter()
.map(|list| {
let ad = list.data().clone();
let child = ad.child_data()[0].clone();
let child = make_array(child);
let child = cast(&child, &child_type)?;

let new = ArrayDataBuilder::new(ArrowDataType::LargeList(Box::new(
ArrowField::new("", child.data_type().clone(), true),
)))
.len(list.len())
.buffers(ad.buffers().to_vec())
.add_child_data(child.data().clone())
.build();
Ok(make_array(new))
})
.collect::<Result<_>>()?;
let ca = ListChunked::new_from_chunks(self.name(), chunks);
Ok(ca.into_series())
}
_ => Err(PolarsError::Other("Cannot cast list type".into())),
}
}
}

#[cfg(test)]
mod test {
    use crate::prelude::*;
    use polars_arrow::builder::PrimitiveArrayBuilder;

    /// Casting a list of `Int32` to a list of `Float64` must report the new inner dtype.
    #[test]
    fn test_cast_list() -> Result<()> {
        let target = DataType::List(ArrowDataType::Float64);

        let values = PrimitiveArrayBuilder::new(10);
        let mut lists = ListPrimitiveChunkedBuilder::<Int32Type>::new("a", values, 10);
        // Two identical rows are enough to exercise the cast.
        for _ in 0..2 {
            lists.append_slice(Some(&[1i32, 2, 3]));
        }
        let ints = lists.finish();

        let floats = ints.cast_with_dtype(&target)?;

        assert_eq!(floats.dtype(), &target);
        Ok(())
    }
}
2 changes: 2 additions & 0 deletions polars/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ pub trait ChunkCast {
fn cast<N>(&self) -> Result<ChunkedArray<N>>
where
N: PolarsDataType;

fn cast_with_dtype(&self, data_type: &DataType) -> Result<Series>;
}

/// Fastest way to do elementwise operations on a ChunkedArray<T> when the operation is cheaper than
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/frame/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ fn get_supertype_all(df: &DataFrame, rhs: &Series) -> Result<DataType> {
macro_rules! impl_arithmetic {
($self:expr, $rhs:expr, $operand: tt) => {{
let st = get_supertype_all($self, $rhs)?;
let rhs = $rhs.cast_with_datatype(&st)?;
let rhs = $rhs.cast_with_dtype(&st)?;
let cols = $self.columns.par_iter().map(|s| {
Ok(&s.cast_with_datatype(&st)? $operand &rhs)
Ok(&s.cast_with_dtype(&st)? $operand &rhs)
}).collect::<Result<_>>()?;
Ok(DataFrame::new_no_checks(cols))
}}
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,12 +796,12 @@ impl DataFrame {
.into_iter()
.map(|s| match s.dtype() {
DataType::Utf8 => s,
_ => s.cast_with_datatype(&dtype).expect("supertype is known"),
_ => s.cast_with_dtype(&dtype).expect("supertype is known"),
})
.collect::<Vec<_>>();

if !matches!(first.dtype(), DataType::Utf8) {
first = first.cast_with_datatype(&dtype)?;
first = first.cast_with_dtype(&dtype)?;
}
}

Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/series/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ pub(crate) fn coerce_lhs_rhs<'a>(
let left = if lhs.dtype() == &dtype {
Cow::Borrowed(lhs)
} else {
Cow::Owned(lhs.cast_with_datatype(&dtype)?)
Cow::Owned(lhs.cast_with_dtype(&dtype)?)
};
let right = if rhs.dtype() == &dtype {
Cow::Borrowed(rhs)
} else {
Cow::Owned(rhs.cast_with_datatype(&dtype)?)
Cow::Owned(rhs.cast_with_dtype(&dtype)?)
};
Ok((left, right))
}
Expand Down

0 comments on commit 47539ad

Please sign in to comment.