Skip to content

Commit

Permalink
don't use arrow dtype for inner list dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 8, 2021
1 parent 0592f15 commit 792ca8d
Show file tree
Hide file tree
Showing 15 changed files with 47 additions and 53 deletions.
6 changes: 3 additions & 3 deletions polars/polars-core/src/chunked_array/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ where
pub fn new(name: &str, capacity: usize, values_capacity: usize) -> Self {
let values = MutablePrimitiveArray::<T::Native>::with_capacity(values_capacity);
let builder = LargePrimitiveBuilder::<T::Native>::new_with_capacity(values, capacity);
let field = Field::new(name, DataType::List(T::get_dtype().to_arrow()));
let field = Field::new(name, DataType::List(Box::new(T::get_dtype())));

Self {
builder,
Expand Down Expand Up @@ -477,7 +477,7 @@ impl ListUtf8ChunkedBuilder {
pub fn new(name: &str, capacity: usize, values_capacity: usize) -> Self {
let values = MutableUtf8Array::<i64>::with_capacity(values_capacity);
let builder = LargeListUtf8Builder::new_with_capacity(values, capacity);
let field = Field::new(name, DataType::List(ArrowDataType::LargeUtf8));
let field = Field::new(name, DataType::List(Box::new(DataType::Utf8)));

ListUtf8ChunkedBuilder {
builder,
Expand Down Expand Up @@ -528,7 +528,7 @@ impl ListBooleanChunkedBuilder {
pub fn new(name: &str, capacity: usize, values_capacity: usize) -> Self {
let values = MutableBooleanArray::with_capacity(values_capacity);
let builder = LargeListBooleanBuilder::new_with_capacity(values, capacity);
let field = Field::new(name, DataType::List(ArrowDataType::Boolean));
let field = Field::new(name, DataType::List(Box::new(DataType::Boolean)));

Self {
builder,
Expand Down
15 changes: 6 additions & 9 deletions polars/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,12 @@ impl ChunkCast for BooleanChunked {
}
}

fn cast_inner_list_type(
list: &ListArray<i64>,
child_type: &arrow::datatypes::DataType,
) -> Result<ArrayRef> {
fn cast_inner_list_type(list: &ListArray<i64>, child_type: &DataType) -> Result<ArrayRef> {
let child = list.values();
let offsets = list.offsets();
let child = cast::cast(child.as_ref(), child_type)?.into();
let child = cast::cast(child.as_ref(), &child_type.to_arrow())?.into();

let data_type = ListArray::<i64>::default_datatype(child_type.clone());
let data_type = ListArray::<i64>::default_datatype(child_type.to_arrow());
let list = ListArray::from_data(data_type, offsets.clone(), child, list.validity().cloned());
Ok(Arc::new(list) as ArrayRef)
}
Expand All @@ -137,7 +134,7 @@ impl ChunkCast for ListChunked {
DataType::List(child_type) => {
let chunks = self
.downcast_iter()
.map(|list| cast_inner_list_type(list, child_type))
.map(|list| cast_inner_list_type(list, &**child_type))
.collect::<Result<_>>()?;
let ca = ListChunked::new_from_chunks(self.name(), chunks);
Ok(ca.into_series())
Expand All @@ -158,9 +155,9 @@ mod test {
builder.append_slice(Some(&[1i32, 2, 3]));
let ca = builder.finish();

let new = ca.cast(&DataType::List(ArrowDataType::Float64))?;
let new = ca.cast(&DataType::List(DataType::Float64.into()))?;

assert_eq!(new.dtype(), &DataType::List(ArrowDataType::Float64));
assert_eq!(new.dtype(), &DataType::List(DataType::Float64.into()));
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ impl ListChunked {
/// Get the inner data type of the list.
pub fn inner_dtype(&self) -> DataType {
match self.dtype() {
DataType::List(dt) => dt.into(),
DataType::List(dt) => *dt.clone(),
_ => unreachable!(),
}
}
Expand Down
8 changes: 4 additions & 4 deletions polars/polars-core/src/chunked_array/ops/is_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ where
// We check whether to implicitly cast to the supertype here
match other.dtype() {
DataType::List(dt) => {
let st = get_supertype(self.dtype(), &dt.into())?;
let st = get_supertype(self.dtype(), dt)?;
if &st != self.dtype() {
let left = self.cast(&st)?;
let right = other.cast(&DataType::List(st.to_arrow()))?;
let right = other.cast(&DataType::List(Box::new(st)))?;
return left.is_in(&right);
}

Expand Down Expand Up @@ -104,7 +104,7 @@ where
impl IsIn for Utf8Chunked {
fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
match other.dtype() {
DataType::List(dt) if self.dtype() == dt => {
DataType::List(dt) if self.dtype() == &**dt => {
let ca: BooleanChunked = self
.into_iter()
.zip(other.list()?.into_iter())
Expand Down Expand Up @@ -149,7 +149,7 @@ impl IsIn for Utf8Chunked {
impl IsIn for BooleanChunked {
fn is_in(&self, other: &Series) -> Result<BooleanChunked> {
match other.dtype() {
DataType::List(dt) if self.dtype() == dt => {
DataType::List(dt) if self.dtype() == &**dt => {
let ca: BooleanChunked = self
.into_iter()
.zip(other.list()?.into_iter())
Expand Down
12 changes: 6 additions & 6 deletions polars/polars-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl PolarsDataType for BooleanType {
impl PolarsDataType for ListType {
fn get_dtype() -> DataType {
// null as we cannot know anything without self.
DataType::List(ArrowDataType::Null)
DataType::List(Box::new(DataType::Null))
}
}

Expand Down Expand Up @@ -368,7 +368,7 @@ impl Display for DataType {
DataType::Utf8 => "str",
DataType::Date => "Date(days)",
DataType::Datetime => "datetime(ms)",
DataType::List(tp) => return write!(f, "list [{}]", DataType::from(tp)),
DataType::List(tp) => return write!(f, "list [{}]", tp),
#[cfg(feature = "object")]
DataType::Object(s) => s,
DataType::Categorical => "cat",
Expand Down Expand Up @@ -458,7 +458,7 @@ pub enum DataType {
/// A 64-bit date representing the elapsed time since UNIX epoch (1970-01-01)
/// in milliseconds (64 bits).
Datetime,
List(ArrowDataType),
List(Box<DataType>),
#[cfg(feature = "object")]
/// A generic type that can be used in a `Series`
/// &'static str can be used to determine/set inner type
Expand Down Expand Up @@ -498,7 +498,7 @@ impl DataType {
Datetime => ArrowDataType::Date64,
List(dt) => ArrowDataType::LargeList(Box::new(arrow::datatypes::Field::new(
"",
dt.clone(),
dt.to_arrow(),
true,
))),
Null => ArrowDataType::Null,
Expand Down Expand Up @@ -631,7 +631,7 @@ impl Schema {
f.name(),
ArrowDataType::LargeList(Box::new(ArrowField::new(
"item",
dt.clone(),
dt.to_arrow(),
true,
))),
true,
Expand Down Expand Up @@ -700,7 +700,7 @@ impl From<&ArrowDataType> for DataType {
ArrowDataType::Boolean => DataType::Boolean,
ArrowDataType::Float32 => DataType::Float32,
ArrowDataType::Float64 => DataType::Float64,
ArrowDataType::LargeList(f) => DataType::List(f.data_type().clone()),
ArrowDataType::LargeList(f) => DataType::List(Box::new(f.data_type().into())),
ArrowDataType::Date32 => DataType::Date,
ArrowDataType::Date64 => DataType::Datetime,
ArrowDataType::Utf8 => DataType::Utf8,
Expand Down
19 changes: 13 additions & 6 deletions polars/polars-core/src/series/implementations/dates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,10 @@ macro_rules! impl_dyn_series {

fn agg_list(&self, groups: &[(u32, Vec<u32>)]) -> Option<Series> {
// we cannot cast and dispatch as the inner type of the list would be incorrect
self.0
.agg_list(groups)
.map(|s| s.cast(&DataType::List(self.dtype().to_arrow())).unwrap())
self.0.agg_list(groups).map(|s| {
s.cast(&DataType::List(Box::new(self.dtype().clone())))
.unwrap()
})
}

fn agg_quantile(&self, groups: &[(u32, Vec<u32>)], quantile: f64) -> Option<Series> {
Expand Down Expand Up @@ -625,15 +626,15 @@ macro_rules! impl_dyn_series {
DataType::Date => self
.0
.repeat_by(by)
.cast(&DataType::List(ArrowDataType::Date32))
.cast(&DataType::List(Box::new(DataType::Date)))
.unwrap()
.list()
.unwrap()
.clone(),
DataType::Datetime => self
.0
.repeat_by(by)
.cast(&DataType::List(ArrowDataType::Date64))
.cast(&DataType::List(Box::new(DataType::Datetime)))
.unwrap()
.list()
.unwrap()
Expand Down Expand Up @@ -700,7 +701,13 @@ mod test {
let s = s.cast(&DataType::Datetime)?;

let l = s.agg_list(&[(0, vec![0, 1, 2])]).unwrap();
assert!(matches!(l.dtype(), DataType::List(ArrowDataType::Date64)));

match l.dtype() {
DataType::List(inner) => {
assert!(matches!(&**inner, DataType::Datetime))
}
_ => assert!(false),
}

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ impl Expr {
map_binary_lazy_field(self, by, function, |_schema, _ctxt, l, _r| {
Some(Field::new(
l.name(),
DataType::List(l.data_type().to_arrow()),
DataType::List(l.data_type().clone().into()),
))
})
}
Expand Down
5 changes: 1 addition & 4 deletions polars/polars-lazy/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,7 @@ pub fn arange(low: Expr, high: Expr, step: usize) -> Expr {
low,
high,
f,
Some(Field::new(
"arange",
DataType::List(DataType::Int64.to_arrow()),
)),
Some(Field::new("arange", DataType::List(DataType::Int64.into()))),
)
}
}
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/logical_plan/aexpr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ impl AExpr {
AggGroups(expr) => {
let field = arena.get(*expr).to_field(schema, ctxt, arena)?;
let new_name = fmt_groupby_column(field.name(), GroupByMethod::Groups);
Field::new(&new_name, DataType::List(ArrowDataType::UInt32))
Field::new(&new_name, DataType::List(DataType::UInt32.into()))
}
Quantile { expr, quantile } => {
let mut field = field_by_context(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ impl PhysicalAggregation for AggregationExpr {
let new_name = fmt_groupby_column(ca.name(), self.agg_type);

let values_type = match ca.dtype() {
DataType::List(dt) => DataType::from(dt),
DataType::List(dt) => *dt.clone(),
_ => unreachable!(),
};

Expand Down
5 changes: 2 additions & 3 deletions py-polars/src/apply/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1049,8 +1049,7 @@ impl<'a> ApplyLambda<'a> for ListChunked {

match self.dtype() {
DataType::List(dt) => {
let mut builder =
get_list_builder(&dt.into(), self.len() * 5, self.len(), self.name());
let mut builder = get_list_builder(dt, self.len() * 5, self.len(), self.name());
if self.null_count() == 0 {
let mut it = self.into_no_null_iter();
// use first value to get dtype and replace default builder
Expand All @@ -1061,7 +1060,7 @@ impl<'a> ApplyLambda<'a> for ListChunked {
builder = get_list_builder(dt, self.len() * 5, self.len(), self.name());
builder.append_opt_series(Some(&out_series));
} else {
let mut builder = get_list_builder(&dt.into(), 0, 1, self.name());
let mut builder = get_list_builder(dt, 0, 1, self.name());
let ca = builder.finish();
return Ok(PySeries::new(ca.into_series()));
}
Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl From<&DataType> for PyDataType {
DataType::Utf8 => Utf8,
DataType::List(_) => List,
DataType::Date => Date,
DataType::Datetime=> Datetime,
DataType::Datetime => Datetime,
DataType::Object(_) => Object,
DataType::Categorical => Categorical,
dt => panic!("datatype: {:?} not supported", dt),
Expand Down
6 changes: 3 additions & 3 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ impl PyExpr {
|s| Ok(s.list()?.lst_max()),
GetOutput::map_field(|f| {
if let DataType::List(adt) = f.data_type() {
Field::new(f.name(), adt.into())
Field::new(f.name(), *adt.clone())
} else {
// inner type
f.clone()
Expand All @@ -871,7 +871,7 @@ impl PyExpr {
|s| Ok(s.list()?.lst_min()),
GetOutput::map_field(|f| {
if let DataType::List(adt) = f.data_type() {
Field::new(f.name(), adt.into())
Field::new(f.name(), *adt.clone())
} else {
// inner type
f.clone()
Expand All @@ -888,7 +888,7 @@ impl PyExpr {
|s| Ok(s.list()?.lst_sum()),
GetOutput::map_field(|f| {
if let DataType::List(adt) = f.data_type() {
Field::new(f.name(), adt.into())
Field::new(f.name(), *adt.clone())
} else {
// inner type
f.clone()
Expand Down
12 changes: 3 additions & 9 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ impl PySeries {
DataType::Float32 => PyList::new(python, series.f32().unwrap()),
DataType::Float64 => PyList::new(python, series.f64().unwrap()),
DataType::Date => PyList::new(python, &series.date().unwrap().0),
DataType::Datetime=> PyList::new(python, &series.datetime().unwrap().0),
DataType::Datetime => PyList::new(python, &series.datetime().unwrap().0),
DataType::Object(_) => {
let v = PyList::empty(python);
for i in 0..series.len() {
Expand Down Expand Up @@ -1132,14 +1132,8 @@ impl PySeries {
pub fn round_datetime(&self, rule: &str, n: u32) -> PyResult<Self> {
let rule = downsample_str_to_rule(rule, n)?;
match self.series.dtype() {
DataType::Date => Ok(self
.series
.date()
.unwrap()
.round(rule)
.into_series()
.into()),
DataType::Datetime=> Ok(self
DataType::Date => Ok(self.series.date().unwrap().round(rule).into_series().into()),
DataType::Datetime => Ok(self
.series
.datetime()
.unwrap()
Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub fn str_to_polarstype(s: &str) -> DataType {
"<class 'polars.datatypes.Utf8'>" => DataType::Utf8,
"<class 'polars.datatypes.Date'>" => DataType::Date,
"<class 'polars.datatypes.Datetime'>" => DataType::Datetime,
"<class 'polars.datatypes.List'>" => DataType::List(ArrowDataType::Null),
"<class 'polars.datatypes.List'>" => DataType::List(DataType::Null.into()),
"<class 'polars.datatypes.Categorical'>" => DataType::Categorical,
"<class 'polars.datatypes.Object'>" => DataType::Object("object"),
tp => panic!("Type {} not implemented in str_to_polarstype", tp),
Expand Down

0 comments on commit 792ca8d

Please sign in to comment.