fix[rust]: keep list<cat> on rechunk (#4691)
ritchie46 committed Sep 2, 2022
1 parent ac4b9a9 commit 9496a92
Showing 12 changed files with 125 additions and 74 deletions.
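
In short: concatenating frames with a List(Categorical) column and rechunking no longer loses the inner categorical dtype. A minimal sketch of the fixed behaviour (it mirrors the regression test added in py-polars/tests/test_lists.py below; the expression API shown assumes the polars release of this period):

import polars as pl

# A single-column frame whose dtype is List(Categorical).
df = pl.DataFrame({"cats": ["foo", "bar"]}).select(
    pl.col(pl.Utf8).cast(pl.Categorical).list()
)

# Concatenating with rechunk=True previously dropped the categorical inner
# dtype of the list column; with this fix it stays List(Categorical).
out = pl.concat([df, df], rechunk=True)
assert out.dtypes == [pl.List(pl.Categorical)]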
120 changes: 64 additions & 56 deletions polars/polars-core/src/chunked_array/logical/categorical/merge.rs
@@ -4,71 +4,79 @@ use arrow::bitmap::MutableBitmap;
 
 use super::*;
 
-impl CategoricalChunked {
-    pub(crate) fn merge_categorical_map(&self, other: &Self) -> Result<Arc<RevMapping>> {
-        match (
-            &**self.get_rev_map(),
-            &**other.get_rev_map()
-        ) {
-            (
-                RevMapping::Global(l_map, l_slots, l_id),
-                RevMapping::Global(r_map, r_slots, r_id),
-            ) => {
-                if l_id != r_id {
-                    return Err(PolarsError::ComputeError("The two categorical arrays are not created under the same global string cache. They cannot be merged".into()))
-                }
-                let mut new_map = (*l_map).clone();
+pub(crate) fn merge_categorical_map(
+    left: &Arc<RevMapping>,
+    right: &Arc<RevMapping>,
+) -> Result<Arc<RevMapping>> {
+    match (&**left, &**right) {
+        (RevMapping::Global(l_map, l_slots, l_id), RevMapping::Global(r_map, r_slots, r_id)) => {
+            if l_id != r_id {
+                return Err(PolarsError::ComputeError("The two categorical arrays are not created under the same global string cache. They cannot be merged".into()));
+            }
+            let mut new_map = (*l_map).clone();
 
-                let offset_buf = l_slots.offsets().as_slice().to_vec();
-                let values_buf = l_slots.values().as_slice().to_vec();
+            let offset_buf = l_slots.offsets().as_slice().to_vec();
+            let values_buf = l_slots.values().as_slice().to_vec();
 
-                let validity_buf = if let Some(validity) = l_slots.validity() {
-                    let mut validity_buf = MutableBitmap::new();
-                    let (b, offset, len) = validity.as_slice();
-                    validity_buf.extend_from_slice(b, offset, len);
-                    Some(validity_buf)
-                } else {
-                    None
-                };
+            let validity_buf = if let Some(validity) = l_slots.validity() {
+                let mut validity_buf = MutableBitmap::new();
+                let (b, offset, len) = validity.as_slice();
+                validity_buf.extend_from_slice(b, offset, len);
+                Some(validity_buf)
+            } else {
+                None
+            };
 
-                // Safety
-                // all offsets are valid and the u8 data is valid utf8
-                let mut new_slots = unsafe {
-                    MutableUtf8Array::from_data_unchecked(
-                        DataType::Utf8.to_arrow(),
-                        offset_buf,
-                        values_buf,
-                        validity_buf,
-                    )
-                };
+            // Safety
+            // all offsets are valid and the u8 data is valid utf8
+            let mut new_slots = unsafe {
+                MutableUtf8Array::from_data_unchecked(
+                    DataType::Utf8.to_arrow(),
+                    offset_buf,
+                    values_buf,
+                    validity_buf,
+                )
+            };
 
-                for (cat, idx) in r_map.iter() {
-                    new_map.entry(*cat).or_insert_with(|| {
-                        // Safety
-                        // within bounds
-                        let str_val = unsafe { r_slots.value_unchecked(*idx as usize) };
-                        let new_idx = new_slots.len() as u32;
-                        new_slots.push(Some(str_val));
+            for (cat, idx) in r_map.iter() {
+                new_map.entry(*cat).or_insert_with(|| {
+                    // Safety
+                    // within bounds
+                    let str_val = unsafe { r_slots.value_unchecked(*idx as usize) };
+                    let new_idx = new_slots.len() as u32;
+                    new_slots.push(Some(str_val));
 
-                        new_idx
-                    });
-                }
-                let new_rev = RevMapping::Global(new_map, new_slots.into(), *l_id);
-                Ok(Arc::new(new_rev))
-            }
-            (RevMapping::Local(arr_l), RevMapping::Local(arr_r)) => {
-                // they are from the same source, just clone
-                if std::ptr::eq(arr_l, arr_r) {
-                    return Ok(self.get_rev_map().clone())
-                }
+                    new_idx
+                });
+            }
+            let new_rev = RevMapping::Global(new_map, new_slots.into(), *l_id);
+            Ok(Arc::new(new_rev))
+        }
+        (RevMapping::Local(arr_l), RevMapping::Local(arr_r)) => {
+            // they are from the same source, just clone
+            if std::ptr::eq(arr_l, arr_r) {
+                return Ok(left.clone());
+            }
 
-                let arr = arrow::compute::concatenate::concatenate(&[arr_l, arr_r]).unwrap();
-                let arr = arr.as_any().downcast_ref::<Utf8Array<i64>>().unwrap().clone();
+            let arr = arrow::compute::concatenate::concatenate(&[arr_l, arr_r]).unwrap();
+            let arr = arr
+                .as_any()
+                .downcast_ref::<Utf8Array<i64>>()
+                .unwrap()
+                .clone();
 
-                Ok(Arc::new(RevMapping::Local(arr)))
-            }
-            _ => Err(PolarsError::ComputeError("cannot combine categorical under a global string cache with a non cached categorical".into()))
-        }
-    }
-}
+            Ok(Arc::new(RevMapping::Local(arr)))
+        }
+        _ => Err(PolarsError::ComputeError(
+            "cannot combine categorical under a global string cache with a non cached categorical"
+                .into(),
+        )),
+    }
+}
+
+impl CategoricalChunked {
+    pub(crate) fn merge_categorical_map(&self, other: &Self) -> Result<Arc<RevMapping>> {
+        merge_categorical_map(self.get_rev_map(), other.get_rev_map())
+    }
+}

@@ -5,6 +5,7 @@ mod ops;
 pub mod stringcache;
 
 pub use builder::*;
+pub(crate) use merge::*;
 pub(crate) use ops::{CategoricalTakeRandomGlobal, CategoricalTakeRandomLocal};
 
 use super::*;
20 changes: 19 additions & 1 deletion polars/polars-core/src/chunked_array/ops/append.rs
@@ -46,10 +46,28 @@ impl Utf8Chunked {
 
 #[doc(hidden)]
 impl ListChunked {
-    pub fn append(&mut self, other: &Self) {
+    pub fn append(&mut self, other: &Self) -> Result<()> {
+        // todo! there should be a merge-dtype function, that goes all the way down.
+
+        #[cfg(feature = "dtype-categorical")]
+        use DataType::*;
+        #[cfg(feature = "dtype-categorical")]
+        if let (Categorical(Some(rev_map_l)), Categorical(Some(rev_map_r))) =
+            (self.inner_dtype(), other.inner_dtype())
+        {
+            if !rev_map_l.same_src(rev_map_r.as_ref()) {
+                let rev_map = merge_categorical_map(&rev_map_l, &rev_map_r)?;
+                self.field = Arc::new(Field::new(
+                    self.name(),
+                    DataType::List(Box::new(DataType::Categorical(Some(rev_map)))),
+                ));
+            }
+        }
+
         let len = self.len();
         self.length += other.length;
         new_chunks(&mut self.chunks, &other.chunks, len);
+        Ok(())
     }
 }
 #[cfg(feature = "object")]
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/ops/chunkops.rs
@@ -93,7 +93,7 @@ impl<T: PolarsDataType> ChunkedArray<T> {
             self.clone()
         } else {
             let chunks = inner_rechunk(&self.chunks);
-            ChunkedArray::from_chunks(self.name(), chunks)
+            self.copy_with_chunks(chunks, true)
         }
     }
 }
4 changes: 2 additions & 2 deletions polars/polars-core/src/chunked_array/ops/extend.rs
@@ -157,10 +157,10 @@ impl BooleanChunked {
 
 #[doc(hidden)]
 impl ListChunked {
-    pub fn extend(&mut self, other: &Self) {
+    pub fn extend(&mut self, other: &Self) -> Result<()> {
         // TODO! properly implement mutation
         // this is harder because we don't know the inner type of the list
-        self.append(other);
+        self.append(other)
     }
 }
 
4 changes: 2 additions & 2 deletions polars/polars-core/src/chunked_array/ops/shift.rs
@@ -84,10 +84,10 @@ impl ChunkShiftFill<ListType, Option<&Series>> for ListChunked {
         };
 
         if periods < 0 {
-            slice.append(&fill);
+            slice.append(&fill).unwrap();
             slice
         } else {
-            fill.append(&slice);
+            fill.append(&slice).unwrap();
             fill
         }
     }
6 changes: 2 additions & 4 deletions polars/polars-core/src/series/implementations/list.rs
@@ -83,8 +83,7 @@ impl SeriesTrait for SeriesWrap<ListChunked> {
 
     fn append(&mut self, other: &Series) -> Result<()> {
         if self.0.dtype() == other.dtype() {
-            self.0.append(other.as_ref().as_ref());
-            Ok(())
+            self.0.append(other.as_ref().as_ref())
         } else {
             Err(PolarsError::SchemaMisMatch(
                 "cannot append Series; data types don't match".into(),
@@ -93,8 +92,7 @@ impl SeriesTrait for SeriesWrap<ListChunked> {
     }
     fn extend(&mut self, other: &Series) -> Result<()> {
         if self.0.dtype() == other.dtype() {
-            self.0.extend(other.as_ref().as_ref());
-            Ok(())
+            self.0.extend(other.as_ref().as_ref())
         } else {
             Err(PolarsError::SchemaMisMatch(
                 "cannot extend Series; data types don't match".into(),
18 changes: 18 additions & 0 deletions polars/polars-lazy/src/tests/queries.rs
@@ -1,6 +1,7 @@
 use polars_arrow::prelude::QuantileInterpolOptions;
 use polars_core::frame::explode::MeltArgs;
 use polars_core::series::ops::NullBehavior;
+use polars_core::utils::{concat_df, concat_df_unchecked};
 use polars_time::prelude::DateMethods;
 
 use super::*;
@@ -2088,3 +2089,20 @@ fn test_partitioned_gb_ternary() -> Result<()> {
 
     Ok(())
 }
+
+#[test]
+fn test_foo() -> Result<()> {
+    let df = df![
+        "a" => ["a", "b"]
+    ]?;
+
+    let df = df
+        .lazy()
+        .select([all().cast(DataType::Categorical(None)).list()])
+        .collect()?;
+
+    let mut out = concat_df(&[df.clone(), df.clone()])?;
+    dbg!(out.agg_chunks());
+
+    Ok(())
+}
4 changes: 2 additions & 2 deletions py-polars/polars/internals/dataframe/frame.py
@@ -2290,7 +2290,7 @@ def columns(self, columns: Sequence[str]) -> None:
         self._df.set_column_names(columns)
 
     @property
-    def dtypes(self) -> list[type[DataType]]:
+    def dtypes(self) -> list[PolarsDataType]:
         """
         Get dtypes of columns in DataFrame. Dtypes can also be found in column headers when printing the DataFrame.
@@ -2327,7 +2327,7 @@ def dtypes(self) -> list[type[DataType]]:
         return self._df.dtypes()
 
     @property
-    def schema(self) -> dict[str, type[DataType]]:
+    def schema(self) -> dict[str, PolarsDataType]:
         """
         Get a dict[column name, DataType].
4 changes: 2 additions & 2 deletions py-polars/tests/test_constructors.py
@@ -29,7 +29,7 @@ def test_init_dict() -> None:
     # List of empty list/tuple
     df = pl.DataFrame({"a": [[]], "b": [()]})
    expected = {"a": pl.List(pl.Float64), "b": pl.List(pl.Float64)}
-    assert df.schema == expected  # type: ignore[comparison-overlap]
+    assert df.schema == expected
     assert df.rows() == [([], [])]
 
     # Mixed dtypes
@@ -389,7 +389,7 @@ def test_init_only_columns() -> None:
     assert df.shape == (0, 4)
     assert df.frame_equal(truth, null_equal=True)
     assert df.dtypes == [pl.Date, pl.UInt64, pl.Int8, pl.List]
-    assert df.schema["d"].inner == pl.UInt8  # type: ignore[attr-defined]
+    assert df.schema["d"].inner == pl.UInt8  # type: ignore[union-attr]
 
     dfe = df.cleared()
     assert (df.schema == dfe.schema) and (dfe.shape == df.shape)
4 changes: 2 additions & 2 deletions py-polars/tests/test_datelike.py
@@ -215,7 +215,7 @@ def test_datetime_consistency() -> None:
             pl.lit(dt).cast(pl.Datetime("ns")).alias("dt_ns"),
         ]
     )
-    assert ddf.schema == {  # type: ignore[comparison-overlap]
+    assert ddf.schema == {
         "date": pl.Datetime("us"),
         "dt": pl.Datetime("us"),
         "dt_ms": pl.Datetime("ms"),
@@ -1109,7 +1109,7 @@ def test_datetime_instance_selection() -> None:
     )
     for tu in DTYPE_TEMPORAL_UNITS:
         res = df.select(pl.col([pl.Datetime(tu)])).dtypes
-        assert res == [pl.Datetime(tu)]  # type: ignore[comparison-overlap]
+        assert res == [pl.Datetime(tu)]
         assert len(df.filter(pl.col(tu) == test_data[tu][0])) == 1
 
 
12 changes: 10 additions & 2 deletions py-polars/tests/test_lists.py
@@ -78,13 +78,13 @@ def test_dtype() -> None:
             ("dtm", pl.List(pl.Datetime)),
         ],
     )
-    assert df.schema == {  # type: ignore[comparison-overlap]
+    assert df.schema == {
         "i": pl.List(pl.Int8),
         "tm": pl.List(pl.Time),
         "dt": pl.List(pl.Date),
         "dtm": pl.List(pl.Datetime),
     }
-    assert df.schema["i"].inner == pl.Int8  # type: ignore[attr-defined]
+    assert df.schema["i"].inner == pl.Int8  # type: ignore[union-attr]
     assert df.rows() == [
         (
             [1, 2, 3],
@@ -447,3 +447,11 @@ def test_is_in_empty_list_4639() -> None:
     assert df.with_columns(
         [pl.lit(None).cast(pl.Int64).is_in(empty_list).alias("in_empty_list")]
     ).to_dict(False) == {"in_empty_list": [False]}
+
+
+def test_inner_type_categorical_on_rechunk() -> None:
+    df = pl.DataFrame({"cats": ["foo", "bar"]}).select(
+        pl.col(pl.Utf8).cast(pl.Categorical).list()
+    )
+
+    assert pl.concat([df, df], rechunk=True).dtypes == [pl.List(pl.Categorical)]
