Skip to content

Commit

Permalink
recursively convert arrow (#3200)
Browse files: browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 20, 2022
1 parent e37a27e commit 1357724
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 30 deletions.
69 changes: 40 additions & 29 deletions polars/polars-core/src/series/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,8 @@ impl Series {
let chunks = cast_chunks(&chunks, &DataType::Utf8).unwrap();
Ok(Utf8Chunked::from_chunks(name, chunks).into_series())
}
ArrowDataType::List(fld) => {
let chunks = chunks
.iter()
.map(|arr| {
let arr: ArrayRef =
cast(arr.as_ref(), &ArrowDataType::LargeList(fld.clone()))
.unwrap()
.into();
convert_list_inner(&arr, fld)
})
.collect();
ArrowDataType::List(_) => {
let chunks = chunks.iter().map(convert_inner_types).collect();
Ok(ListChunked::from_chunks(name, chunks).into_series())
}
ArrowDataType::Boolean => Ok(BooleanChunked::from_chunks(name, chunks).into_series()),
Expand Down Expand Up @@ -174,11 +165,8 @@ impl Series {
ArrowTimeUnit::Nanosecond => s,
})
}
ArrowDataType::LargeList(fld) => {
let chunks = chunks
.iter()
.map(|arr| convert_list_inner(arr, fld))
.collect();
ArrowDataType::LargeList(_) => {
let chunks = chunks.iter().map(convert_inner_types).collect();
Ok(ListChunked::from_chunks(name, chunks).into_series())
}
ArrowDataType::Null => {
Expand Down Expand Up @@ -390,6 +378,7 @@ impl Series {
} else {
chunks[0].clone()
};
let arr = convert_inner_types(&arr);
let struct_arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
assert!(
struct_arr.validity().is_none(),
Expand All @@ -416,22 +405,44 @@ impl Series {
}
}

fn convert_list_inner(arr: &ArrayRef, fld: &ArrowField) -> ArrayRef {
// if inner type is Utf8, we need to convert that to large utf8
match fld.data_type() {
fn convert_inner_types(arr: &ArrayRef) -> ArrayRef {
match arr.data_type() {
ArrowDataType::Utf8 => {
let arr = arr.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
Arc::from(utf8_to_large_utf8(arr))
}
ArrowDataType::List(field) => {
let out = cast(&**arr, &ArrowDataType::LargeList(field.clone())).unwrap();
convert_inner_types(&(Arc::from(out) as ArrayRef))
}
ArrowDataType::LargeList(_) => {
let arr = arr.as_any().downcast_ref::<ListArray<i64>>().unwrap();
let offsets = arr.offsets().iter().map(|x| *x as i64).collect();
let values = arr.values();
let values =
utf8_to_large_utf8(values.as_any().downcast_ref::<Utf8Array<i32>>().unwrap());
let values = convert_inner_types(arr.values());
let dtype = ListArray::<i64>::default_datatype(values.data_type().clone());
unsafe {
Arc::from(ListArray::<i64>::new_unchecked(
dtype,
arr.offsets().clone(),
values,
arr.validity().cloned(),
))
}
}
ArrowDataType::Struct(fields) => {
let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
let values = arr
.values()
.iter()
.map(convert_inner_types)
.collect::<Vec<_>>();

Arc::new(LargeListArray::from_data(
ArrowDataType::LargeList(
ArrowField::new(&fld.name, ArrowDataType::LargeUtf8, true).into(),
),
offsets,
Arc::new(values),
let fields = fields
.iter()
.map(|f| ArrowField::new(&f.name, DataType::from(&f.data_type).to_arrow(), true))
.collect();
Arc::new(StructArray::new(
ArrowDataType::Struct(fields),
values,
arr.validity().cloned(),
))
}
Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/apply/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use polars::chunked_array::builder::get_list_builder;
use polars::prelude::*;
use polars_core::utils::CustomIterTools;
use polars_core::{export::rayon::prelude::*, POOL};
use pyo3::types::{PyDict, PyTuple};
use pyo3::types::PyDict;
use pyo3::{PyAny, PyResult};

pub trait PyArrowPrimitiveType: PolarsNumericType {}
Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,12 @@ def test_struct_logical_types_to_pandas() -> None:
timestamp = datetime(2022, 1, 1)
df = pd.DataFrame([{"struct": {"timestamp": timestamp}}])
assert pl.from_pandas(df).dtypes == [pl.Struct]


def test_recursive_arrow_conversion() -> None:
    """Check that a list-of-struct column built via pandas survives conversion.

    A single-row pandas DataFrame whose only column holds a list of dicts with
    string values is passed to ``pl.DataFrame``. The resulting frame is packed
    into a struct Series and converted to a Python list, which must equal the
    original nested records.
    """
    records = [{"list_of_struct": [{"a": "1"}, {"a": "2"}]}]
    pandas_frame = pd.DataFrame(records)
    polars_frame = pl.DataFrame(pandas_frame)

    # Pack every column into one struct Series, then materialise as Python objects.
    round_tripped = polars_frame.to_struct("struct").to_list()

    expected = [{"list_of_struct": [{"a": "1"}, {"a": "2"}]}]
    assert round_tripped == expected

0 comments on commit 1357724

Please sign in to comment.