Skip to content

Commit

Permalink
Fixed failing to accept dictionary full of nulls (jorgecarleitao#1312)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 12, 2022
1 parent 1fcfd7c commit ecc1497
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 0 deletions.
35 changes: 35 additions & 0 deletions src/array/dictionary/mod.rs
Expand Up @@ -17,6 +17,9 @@ pub(super) mod fmt;
mod iterator;
mod mutable;
use crate::array::specification::check_indexes_unchecked;
mod typed_iterator;

use crate::array::dictionary::typed_iterator::{DictValue, DictionaryValuesIterTyped};
pub use iterator::*;
pub use mutable::*;

Expand Down Expand Up @@ -237,6 +240,38 @@ impl<K: DictionaryKey> DictionaryArray<K> {
DictionaryValuesIter::new(self)
}

/// Returns an iterator over the values of this [`DictionaryArray`], each yielded as
/// `V::IterValue` after downcasting the values array to `V`.
///
/// # Errors
///
/// Propagates the error returned by `V::downcast_values` when the values array
/// cannot be viewed as a `V`.
///
/// # Panics
///
/// Panics if the keys of this [`DictionaryArray`] contain any null.
/// In that case [`DictionaryArray::iter_typed`] should be called instead.
pub fn values_iter_typed<V: DictValue>(
    &self,
) -> Result<DictionaryValuesIterTyped<K, V>, Error> {
    // Null keys cannot be represented by a plain (non-optional) value iterator.
    assert_eq!(self.keys.null_count(), 0);
    let typed_values = V::downcast_values(self.values.as_ref())?;
    // SAFETY: the iterator reads `values` at each key without bounds checks.
    // NOTE(review): presumably every key is validated against `values` when the
    // dictionary is constructed — that check is not visible here; confirm.
    let iter = unsafe { DictionaryValuesIterTyped::new(&self.keys, typed_values) };
    Ok(iter)
}

/// Returns an iterator over the optional values of this [`DictionaryArray`],
/// yielding `Option<V::IterValue>` (the typed values zipped with this array's validity).
///
/// # Errors
///
/// Propagates the error returned by `V::downcast_values` when the values array
/// cannot be viewed as a `V`.
///
/// # Panics
///
/// Panics if `V::downcast_values` panics; the `Utf8Array` implementation does so
/// when the `values` array contains nulls.
pub fn iter_typed<V: DictValue>(
    &self,
) -> Result<ZipValidity<V::IterValue<'_>, DictionaryValuesIterTyped<K, V>, BitmapIter>, Error>
{
    let typed_values = V::downcast_values(self.values.as_ref())?;
    // SAFETY: the iterator reads `values` at each key without bounds checks.
    // NOTE(review): presumably every key is validated against `values` when the
    // dictionary is constructed — that check is not visible here; confirm.
    let values_iter = unsafe { DictionaryValuesIterTyped::new(&self.keys, typed_values) };
    Ok(ZipValidity::new_with_validity(values_iter, self.validity()))
}

/// Returns the [`DataType`] of this [`DictionaryArray`]
#[inline]
pub fn data_type(&self) -> &DataType {
Expand Down
111 changes: 111 additions & 0 deletions src/array/dictionary/typed_iterator.rs
@@ -0,0 +1,111 @@
use crate::array::{Array, PrimitiveArray, Utf8Array};
use crate::error::{Error, Result};
use crate::trusted_len::TrustedLen;
use crate::types::Offset;

use super::DictionaryKey;

/// A values-array type that a dictionary's values can be downcast to and then
/// read by index, producing items of type [`DictValue::IterValue`].
pub trait DictValue {
/// The item type produced when reading a single value; it may borrow from the array.
type IterValue<'this>
where
Self: 'this;

/// Returns the value at position `item`.
///
/// # Safety
/// Will not do any bounds checks (the caller must pass an in-bounds `item`)
/// but must check validity.
unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_>;

/// Takes a `dyn` [`Array`] and tries to downcast it to the type implementing `DictValue`.
fn downcast_values(array: &dyn Array) -> Result<&Self>
where
Self: Sized;
}

impl<O: Offset> DictValue for Utf8Array<O> {
type IterValue<'a> = &'a str;

unsafe fn get_unchecked(&self, item: usize) -> Self::IterValue<'_> {
self.value_unchecked(item)
}

fn downcast_values(array: &dyn Array) -> Result<&Self>
where
Self: Sized,
{
array
.as_any()
.downcast_ref::<Self>()
.ok_or(Error::InvalidArgumentError(
"could not convert array to dictionary value".into(),
))
.map(|arr| {
assert_eq!(
arr.null_count(),
0,
"null values in values not supported in iteration"
);
arr
})
}
}

/// Double-ended iterator over the values of a dictionary, downcast to `V`:
/// each key in `keys` is used as an index into `values`.
pub struct DictionaryValuesIterTyped<'a, K: DictionaryKey, V: DictValue> {
// Keys whose raw values are used as indices into `values`.
keys: &'a PrimitiveArray<K>,
// Typed values array that the keys point into.
values: &'a V,
// Position of the next key read from the front.
index: usize,
// One past the position of the next key read from the back.
end: usize,
}

impl<'a, K: DictionaryKey, V: DictValue> DictionaryValuesIterTyped<'a, K, V> {
    /// Builds an iterator covering every key of `keys`, front to back.
    ///
    /// # Safety
    ///
    /// Iteration reads each key with `value_unchecked` and passes it to
    /// `V::get_unchecked` without bounds checks, so every key in `keys` must be
    /// a valid index into `values`.
    pub(super) unsafe fn new(keys: &'a PrimitiveArray<K>, values: &'a V) -> Self {
        let end = keys.len();
        Self {
            keys,
            values,
            index: 0,
            end,
        }
    }
}

impl<'a, K: DictionaryKey, V: DictValue> Iterator for DictionaryValuesIterTyped<'a, K, V> {
    type Item = V::IterValue<'a>;

    /// Yields the value addressed by the next key from the front, or `None`
    /// once the front and back cursors meet.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.index == self.end {
            return None;
        }
        let position = self.index;
        self.index = position + 1;
        // SAFETY: `position < self.end <= keys.len()` (`end` starts at `keys.len()`
        // and only decreases); the key being in bounds for `values` is the
        // contract of `Self::new`.
        let value = unsafe {
            let slot = self.keys.value_unchecked(position).as_usize();
            self.values.get_unchecked(slot)
        };
        Some(value)
    }

    /// Exact number of items still to be yielded.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.end - self.index;
        (remaining, Some(remaining))
    }
}

// SAFETY: `size_hint` returns `end - index` as both lower and upper bound, and
// `next`/`next_back` each move exactly one cursor by one per yielded item.
unsafe impl<'a, K: DictionaryKey, V: DictValue> TrustedLen for DictionaryValuesIterTyped<'a, K, V> {}

impl<'a, K: DictionaryKey, V: DictValue> DoubleEndedIterator
    for DictionaryValuesIterTyped<'a, K, V>
{
    /// Yields the value addressed by the next key from the back, or `None`
    /// once the front and back cursors meet.
    #[inline]
    fn next_back(&mut self) -> Option<Self::Item> {
        if self.index == self.end {
            return None;
        }
        self.end -= 1;
        let position = self.end;
        // SAFETY: `position < keys.len()` because `end` starts at `keys.len()` and
        // was just decremented; the key being in bounds for `values` is the
        // contract of `Self::new`.
        let value = unsafe {
            let slot = self.keys.value_unchecked(position).as_usize();
            self.values.get_unchecked(slot)
        };
        Some(value)
    }
}
45 changes: 45 additions & 0 deletions tests/it/array/dictionary/mod.rs
Expand Up @@ -165,3 +165,48 @@ fn keys_values_iter() {

assert_eq!(array.keys_values_iter().collect::<Vec<_>>(), vec![1, 0]);
}

/// Both typed iterators report an exact size hint and resolve keys `[1, 0, 0]`
/// against values `["a", "aa"]`.
#[test]
fn iter_values_typed() {
    let values = Utf8Array::<i32>::from_slice(["a", "aa"]);
    let array =
        DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed())
            .unwrap();

    // Neither binding needs `mut`: `size_hint` borrows and `collect` consumes.
    let iter = array.values_iter_typed::<Utf8Array<i32>>().unwrap();
    assert_eq!(iter.size_hint(), (3, Some(3)));
    assert_eq!(iter.collect::<Vec<_>>(), vec!["aa", "a", "a"]);

    let iter = array.iter_typed::<Utf8Array<i32>>().unwrap();
    assert_eq!(iter.size_hint(), (3, Some(3)));
    assert_eq!(
        iter.collect::<Vec<_>>(),
        vec![Some("aa"), Some("a"), Some("a")]
    );
}

/// `values_iter_typed` must refuse a values array that contains nulls.
#[test]
#[should_panic(expected = "null values in values not supported in iteration")]
fn iter_values_typed_panic() {
    let values = Utf8Array::<i32>::from_iter([Some("a"), Some("aa"), None]);
    let array =
        DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed())
            .unwrap();

    // The panic is raised while building the iterator (values contain a null);
    // `expected` pins it to that assertion rather than any unrelated panic.
    let iter = array.values_iter_typed::<Utf8Array<i32>>().unwrap();
    let _ = iter.collect::<Vec<_>>();
}

/// `iter_typed` must refuse a values array that contains nulls.
#[test]
#[should_panic(expected = "null values in values not supported in iteration")]
fn iter_values_typed_panic_2() {
    let values = Utf8Array::<i32>::from_iter([Some("a"), Some("aa"), None]);
    let array =
        DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0, 0]), values.boxed())
            .unwrap();

    // The panic is raised while building the iterator (values contain a null);
    // `expected` pins it to that assertion rather than any unrelated panic.
    let iter = array.iter_typed::<Utf8Array<i32>>().unwrap();
    let _ = iter.collect::<Vec<_>>();
}

0 comments on commit ecc1497

Please sign in to comment.