Skip to content

Commit

Permalink
fix[rust]: determine supertype of structs (#4916)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 21, 2022
1 parent 8640c14 commit cc688ea
Show file tree
Hide file tree
Showing 3 changed files with 289 additions and 264 deletions.
266 changes: 2 additions & 264 deletions polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub(crate) mod series;
mod supertype;

use std::borrow::Cow;
use std::ops::{Deref, DerefMut};
Expand All @@ -7,6 +8,7 @@ use arrow::bitmap::Bitmap;
pub use polars_arrow::utils::{TrustMyLength, *};
use rayon::prelude::*;
pub use series::*;
pub use supertype::*;
pub use {arrow, rayon};

#[cfg(feature = "private")]
Expand Down Expand Up @@ -466,270 +468,6 @@ pub fn get_time_units(tu_l: &TimeUnit, tu_r: &TimeUnit) -> TimeUnit {
}
}

/// Fallible variant of [`get_supertype`]: resolves the common supertype of `l` and `r`.
///
/// # Errors
///
/// Returns `PolarsError::ComputeError` naming both input types when
/// [`get_supertype`] yields `None` for the pair.
#[cfg(feature = "private")]
pub fn try_get_supertype(l: &DataType, r: &DataType) -> PolarsResult<DataType> {
    // `ok_or_else` keeps the error message lazily built: it is only formatted
    // when no supertype was found.
    get_supertype(l, r).ok_or_else(|| {
        PolarsError::ComputeError(
            format!("Failed to determine supertype of {:?} and {:?}", l, r).into(),
        )
    })
}

/// Given two datatypes, determine the supertype that both types can safely be cast to.
///
/// Returns `None` when no rule matches the pair in either argument order.
///
/// The rule table in the nested `inner` helper is written for one argument
/// order only. This function first tries `(l, r)` and, if that yields nothing,
/// retries with `(r, l)`, so callers may pass the two types in any order.
#[cfg(feature = "private")]
pub fn get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
    // One-directional lookup for the ordered pair `(l, r)`.
    fn inner(l: &DataType, r: &DataType) -> Option<DataType> {
        use DataType::*;
        // Identical types are their own supertype. This also covers every
        // `(X, X)` pair, so the table below does not list them.
        if l == r {
            return Some(l.clone());
        }

        match (l, r) {
            #[cfg(feature = "dtype-i8")]
            (Int8, Boolean) => Some(Int8),
            #[cfg(all(feature = "dtype-i8", feature = "dtype-i16"))]
            (Int8, Int16) => Some(Int16),
            #[cfg(feature = "dtype-i8")]
            (Int8, Int32) => Some(Int32),
            #[cfg(feature = "dtype-i8")]
            (Int8, Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-i8", feature = "dtype-i16"))]
            (Int8, UInt8) => Some(Int16),
            #[cfg(all(feature = "dtype-i8", feature = "dtype-u16"))]
            (Int8, UInt16) => Some(Int32),
            #[cfg(feature = "dtype-i8")]
            (Int8, UInt32) => Some(Int64),
            #[cfg(feature = "dtype-i8")]
            (Int8, UInt64) => Some(Float64), // Follow numpy
            #[cfg(feature = "dtype-i8")]
            (Int8, Float32) => Some(Float32),
            #[cfg(feature = "dtype-i8")]
            (Int8, Float64) => Some(Float64),

            #[cfg(feature = "dtype-i16")]
            (Int16, Boolean) => Some(Int16),
            #[cfg(all(feature = "dtype-i16", feature = "dtype-i8"))]
            (Int16, Int8) => Some(Int16),
            #[cfg(feature = "dtype-i16")]
            (Int16, Int32) => Some(Int32),
            #[cfg(feature = "dtype-i16")]
            (Int16, Int64) => Some(Int64),
            #[cfg(all(feature = "dtype-i16", feature = "dtype-u8"))]
            (Int16, UInt8) => Some(Int16),
            #[cfg(all(feature = "dtype-i16", feature = "dtype-u16"))]
            (Int16, UInt16) => Some(Int32),
            #[cfg(feature = "dtype-i16")]
            (Int16, UInt32) => Some(Int64),
            #[cfg(feature = "dtype-i16")]
            (Int16, UInt64) => Some(Float64), // Follow numpy
            #[cfg(feature = "dtype-i16")]
            (Int16, Float32) => Some(Float32),
            #[cfg(feature = "dtype-i16")]
            (Int16, Float64) => Some(Float64),

            (Int32, Boolean) => Some(Int32),
            #[cfg(feature = "dtype-i8")]
            (Int32, Int8) => Some(Int32),
            #[cfg(feature = "dtype-i16")]
            (Int32, Int16) => Some(Int32),
            (Int32, Int64) => Some(Int64),
            #[cfg(feature = "dtype-u8")]
            (Int32, UInt8) => Some(Int32),
            #[cfg(feature = "dtype-u16")]
            (Int32, UInt16) => Some(Int32),
            (Int32, UInt32) => Some(Int64),
            #[cfg(not(feature = "bigidx"))]
            (Int32, UInt64) => Some(Float64), // Follow numpy
            #[cfg(feature = "bigidx")]
            (Int32, UInt64) => Some(Int64), // Needed for bigidx
            (Int32, Float32) => Some(Float64), // Follow numpy
            (Int32, Float64) => Some(Float64),

            (Int64, Boolean) => Some(Int64),
            #[cfg(feature = "dtype-i8")]
            (Int64, Int8) => Some(Int64),
            #[cfg(feature = "dtype-i16")]
            (Int64, Int16) => Some(Int64),
            (Int64, Int32) => Some(Int64),
            #[cfg(feature = "dtype-u8")]
            (Int64, UInt8) => Some(Int64),
            #[cfg(feature = "dtype-u16")]
            (Int64, UInt16) => Some(Int64),
            (Int64, UInt32) => Some(Int64),
            #[cfg(not(feature = "bigidx"))]
            (Int64, UInt64) => Some(Float64), // Follow numpy
            #[cfg(feature = "bigidx")]
            (Int64, UInt64) => Some(Int64), // Needed for bigidx
            (Int64, Float32) => Some(Float64), // Follow numpy
            (Int64, Float64) => Some(Float64),

            #[cfg(all(feature = "dtype-u16", feature = "dtype-u8"))]
            (UInt16, UInt8) => Some(UInt16),
            #[cfg(feature = "dtype-u16")]
            (UInt16, UInt32) => Some(UInt32),
            #[cfg(feature = "dtype-u16")]
            (UInt16, UInt64) => Some(UInt64),

            #[cfg(feature = "dtype-u8")]
            (UInt8, UInt32) => Some(UInt32),
            #[cfg(feature = "dtype-u8")]
            (UInt8, UInt64) => Some(UInt64),

            (UInt32, UInt64) => Some(UInt64),

            #[cfg(feature = "dtype-u8")]
            (Boolean, UInt8) => Some(UInt8),
            #[cfg(feature = "dtype-u16")]
            (Boolean, UInt16) => Some(UInt16),
            (Boolean, UInt32) => Some(UInt32),
            (Boolean, UInt64) => Some(UInt64),

            #[cfg(feature = "dtype-u8")]
            (Float32, UInt8) => Some(Float32),
            #[cfg(feature = "dtype-u16")]
            (Float32, UInt16) => Some(Float32),
            (Float32, UInt32) => Some(Float64),
            (Float32, UInt64) => Some(Float64),

            #[cfg(feature = "dtype-u8")]
            (Float64, UInt8) => Some(Float64),
            #[cfg(feature = "dtype-u16")]
            (Float64, UInt16) => Some(Float64),
            (Float64, UInt32) => Some(Float64),
            (Float64, UInt64) => Some(Float64),

            (Float64, Float32) => Some(Float64),

            // Time related dtypes
            #[cfg(feature = "dtype-date")]
            (Date, UInt32) => Some(Int64),
            #[cfg(feature = "dtype-date")]
            (Date, UInt64) => Some(Int64),
            #[cfg(feature = "dtype-date")]
            (Date, Int32) => Some(Int32),
            #[cfg(feature = "dtype-date")]
            (Date, Int64) => Some(Int64),
            #[cfg(feature = "dtype-date")]
            (Date, Float32) => Some(Float32),
            #[cfg(feature = "dtype-date")]
            (Date, Float64) => Some(Float64),
            #[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))]
            (Date, Datetime(tu, tz)) => Some(Datetime(*tu, tz.clone())),

            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), UInt32) => Some(Int64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), UInt64) => Some(Int64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), Int32) => Some(Int64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), Int64) => Some(Int64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), Float32) => Some(Float64),
            #[cfg(feature = "dtype-datetime")]
            (Datetime(_, _), Float64) => Some(Float64),
            // The feature name must be spelled "dtype-date"; a misspelled name
            // silently compiles this arm out.
            #[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))]
            (Datetime(tu, tz), Date) => Some(Datetime(*tu, tz.clone())),

            (Boolean, Float32) => Some(Float32),
            (Boolean, Float64) => Some(Float64),

            #[cfg(feature = "dtype-duration")]
            (Duration(_), UInt32) => Some(Int64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), UInt64) => Some(Int64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), Int32) => Some(Int64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), Int64) => Some(Int64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), Float32) => Some(Float64),
            #[cfg(feature = "dtype-duration")]
            (Duration(_), Float64) => Some(Float64),

            #[cfg(feature = "dtype-time")]
            (Time, Int32) => Some(Int64),
            #[cfg(feature = "dtype-time")]
            (Time, Int64) => Some(Int64),
            #[cfg(feature = "dtype-time")]
            (Time, Float32) => Some(Float64),
            #[cfg(feature = "dtype-time")]
            (Time, Float64) => Some(Float64),

            #[cfg(all(feature = "dtype-time", feature = "dtype-datetime"))]
            (Time, Datetime(_, _)) => Some(Int64),
            #[cfg(all(feature = "dtype-datetime", feature = "dtype-time"))]
            (Datetime(_, _), Time) => Some(Int64),
            #[cfg(all(feature = "dtype-time", feature = "dtype-date"))]
            (Time, Date) => Some(Int64),
            #[cfg(all(feature = "dtype-date", feature = "dtype-time"))]
            (Date, Time) => Some(Int64),

            // every known type can be casted to a string
            (dt, Utf8) if dt != &DataType::Unknown => Some(Utf8),

            // Null adopts the type of the other side.
            (dt, Null) => Some(dt.clone()),

            // Datetime combined with Duration stays a Datetime; an empty
            // timezone string is normalised to `None`.
            #[cfg(all(feature = "dtype-duration", feature = "dtype-datetime"))]
            (Duration(lu), Datetime(ru, Some(tz))) | (Datetime(lu, Some(tz)), Duration(ru)) => {
                if tz.is_empty() {
                    Some(Datetime(get_time_units(lu, ru), None))
                } else {
                    Some(Datetime(get_time_units(lu, ru), Some(tz.clone())))
                }
            }
            #[cfg(all(feature = "dtype-duration", feature = "dtype-datetime"))]
            (Duration(lu), Datetime(ru, None)) | (Datetime(lu, None), Duration(ru)) => {
                Some(Datetime(get_time_units(lu, ru), None))
            }
            #[cfg(all(feature = "dtype-duration", feature = "dtype-date"))]
            (Duration(_), Date) | (Date, Duration(_)) => Some(Date),
            #[cfg(feature = "dtype-duration")]
            (Duration(lu), Duration(ru)) => Some(Duration(get_time_units(lu, ru))),

            // None and Some("") timezones: both sides are effectively
            // timezone-less, so the result carries no timezone. The time unit
            // is reconciled by `get_time_units`.
            #[cfg(feature = "dtype-datetime")]
            (Datetime(tu_l, tz_l), Datetime(tu_r, tz_r))
                if (tz_l.is_none() || tz_l.as_deref() == Some(""))
                    && (tz_r.is_none() || tz_r.as_deref() == Some("")) =>
            {
                let tu = get_time_units(tu_l, tu_r);
                Some(Datetime(tu, None))
            }
            // None and Some("<tz>") timezones: the right-hand timezone is
            // kept. The time unit is reconciled by `get_time_units`.
            #[cfg(feature = "dtype-datetime")]
            (Datetime(tu_l, tz_l), Datetime(tu_r, tz_r)) if tz_l.is_none() && tz_r.is_some() => {
                let tu = get_time_units(tu_l, tu_r);
                Some(Datetime(tu, tz_r.clone()))
            }
            // Lists: recurse on the element types (through the symmetric
            // outer function) and wrap the result in a List again.
            (List(inner_left), List(inner_right)) => {
                let st = get_supertype(inner_left, inner_right)?;
                Some(DataType::List(Box::new(st)))
            }
            // TODO: check if this arm can be removed
            (List(inner_dt), other) | (other, List(inner_dt)) => {
                let st = get_supertype(inner_dt, other)?;
                Some(DataType::List(Box::new(st)))
            }
            (_, Unknown) => Some(Unknown),
            _ => None,
        }
    }

    // Try the pair as given, then fall back to the swapped order.
    inner(l, r).or_else(|| inner(r, l))
}

/// This takes ownership of the DataFrame so that drop is called earlier.
/// Does not check if schema is correct
pub fn accumulate_dataframes_vertical_unchecked<I>(dfs: I) -> DataFrame
Expand Down

0 comments on commit cc688ea

Please sign in to comment.