-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Experimental] Create extension type in arrow that can hold any in-memory data type (#1843)
- Loading branch information
Showing
25 changed files
with
582 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
pub mod array; | ||
pub mod bit_util; | ||
#[cfg(feature = "compute")] | ||
pub mod compute; | ||
pub mod error; | ||
pub mod is_valid; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
use crate::chunked_array::object::extension::drop::drop_list; | ||
use crate::prelude::*; | ||
|
||
impl<T> Drop for ChunkedArray<T> { | ||
fn drop(&mut self) { | ||
if matches!(self.dtype(), DataType::List(_)) { | ||
// guarded by the type system | ||
unsafe { drop_list(std::mem::transmute(self)) } | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
46 changes: 46 additions & 0 deletions
46
polars/polars-core/src/chunked_array/object/extension/drop.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
use crate::chunked_array::object::extension::PolarsExtension; | ||
use crate::prelude::*; | ||
|
||
/// This will dereference a raw ptr when dropping the PolarsExtension, make sure that it's valid. | ||
pub(crate) unsafe fn drop_list(ca: &ListChunked) { | ||
let mut inner = ca.inner_dtype(); | ||
let mut nested_count = 0; | ||
|
||
while let Some(a) = inner.inner_dtype() { | ||
nested_count += 1; | ||
inner = a.clone() | ||
} | ||
|
||
if matches!(inner, DataType::Object(_)) { | ||
if nested_count != 0 { | ||
panic!("multiple nested objects not yet supported") | ||
} | ||
// if empty the memory is leaked somewhere | ||
assert!(!ca.chunks.is_empty()); | ||
for lst_arr in &ca.chunks { | ||
// This list can be cloned, so we check the ref count before we drop | ||
if let (ArrowDataType::LargeList(fld), 1) = | ||
(lst_arr.data_type(), Arc::strong_count(lst_arr)) | ||
{ | ||
let dtype = fld.data_type(); | ||
|
||
assert!(matches!(dtype, ArrowDataType::Extension(_, _, _))); | ||
|
||
// recreate the polars extension so that the content is dropped | ||
let arr = lst_arr.as_any().downcast_ref::<LargeListArray>().unwrap(); | ||
|
||
let values = arr.values(); | ||
|
||
// The inner value also may be cloned, check the ref count | ||
if Arc::strong_count(values) == 1 { | ||
let arr = values | ||
.as_any() | ||
.downcast_ref::<FixedSizeBinaryArray>() | ||
.unwrap() | ||
.clone(); | ||
PolarsExtension::new(arr); | ||
} | ||
} | ||
} | ||
} | ||
} |
223 changes: 223 additions & 0 deletions
223
polars/polars-core/src/chunked_array/object/extension/mod.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
pub(crate) mod drop; | ||
pub(crate) mod polars_extension; | ||
|
||
use crate::{prelude::*, PROCESS_ID}; | ||
use arrow::array::{Array, FixedSizeBinaryArray}; | ||
use arrow::bitmap::MutableBitmap; | ||
use arrow::buffer::{Buffer, MutableBuffer}; | ||
use polars_extension::PolarsExtension; | ||
use std::mem; | ||
|
||
/// Builds a closure that drops `n_t_vals` values of type `T` stored back to back at `ptr`.
///
/// Each value is read with an unaligned read and dropped at once. The closure advances
/// its captured pointer while it runs, so it is meant to be invoked a single time.
///
/// # Safety
///
/// * `ptr` must point to the start of a `T` allocation.
/// * `n_t_vals` must represent the correct number of `T` values in that allocation.
unsafe fn create_drop<T: Sized>(mut ptr: *const u8, n_t_vals: usize) -> Box<dyn FnMut()> {
    Box::new(move || {
        let stride = std::mem::size_of::<T>();
        let mut remaining = n_t_vals;
        while remaining > 0 {
            // Move the value out of the byte buffer and let it go out of scope.
            drop(std::ptr::read_unaligned(ptr.cast::<T>()));
            ptr = ptr.add(stride);
            remaining -= 1;
        }
    })
}
|
||
struct ExtensionSentinel { | ||
drop_fn: Option<Box<dyn FnMut()>>, | ||
pub(crate) to_series_fn: Option<Box<dyn Fn(&FixedSizeBinaryArray) -> Series>>, | ||
} | ||
|
||
impl Drop for ExtensionSentinel { | ||
fn drop(&mut self) { | ||
let mut drop_fn = self.drop_fn.take().unwrap(); | ||
drop_fn() | ||
} | ||
} | ||
|
||
// https://stackoverflow.com/questions/28127165/how-to-convert-struct-to-u8d
// Not entirely sure whether padding bytes in T are initialized or not.
/// Views the memory of `p` as a byte slice of `size_of::<T>()` bytes.
///
/// The returned slice borrows from `p`.
unsafe fn any_as_u8_slice<T: Sized>(p: &T) -> &[u8] {
    let first_byte = p as *const T as *const u8;
    let n_bytes = std::mem::size_of::<T>();
    std::slice::from_raw_parts(first_byte, n_bytes)
}
|
||
/// Create an extension Array that can be sent to arrow and (once wrapped in `[PolarsExtension]` will | ||
/// also call drop on `T`, when the array is dropped. | ||
pub(crate) fn create_extension< | ||
I: IntoIterator<Item = Option<T>> + TrustedLen, | ||
T: Sized + Default, | ||
>( | ||
iter: I, | ||
) -> PolarsExtension { | ||
let env = "POLARS_ALLOW_EXTENSION"; | ||
std::env::var(env).unwrap_or_else(|_| { | ||
panic!( | ||
"env var: {} must be set to allow extension types to be created", | ||
env | ||
) | ||
}); | ||
let t_size = std::mem::size_of::<T>(); | ||
let t_alignment = std::mem::align_of::<T>(); | ||
let n_t_vals = iter.size_hint().1.unwrap(); | ||
|
||
let mut buf = MutableBuffer::with_capacity(n_t_vals * t_size); | ||
let mut validity = MutableBitmap::with_capacity(n_t_vals); | ||
|
||
// when we transmute from &[u8] to T, T must be aligned correctly, | ||
// so we pad with bytes until the alignment matches | ||
let n_padding = (buf.as_ptr() as usize) % t_alignment; | ||
buf.extend_constant(n_padding, 0); | ||
|
||
// transmute T as bytes and copy in buffer | ||
for opt_t in iter.into_iter() { | ||
match opt_t { | ||
Some(t) => { | ||
unsafe { | ||
buf.extend_from_slice(any_as_u8_slice(&t)); | ||
// Safety: we allocated upfront | ||
validity.push_unchecked(true) | ||
} | ||
mem::forget(t); | ||
} | ||
None => { | ||
unsafe { | ||
buf.extend_from_slice(any_as_u8_slice(&T::default())); | ||
// Safety: we allocated upfront | ||
validity.push_unchecked(false) | ||
} | ||
} | ||
} | ||
} | ||
|
||
// we slice the buffer because we want to ignore the padding bytes from here | ||
// they can be forgotten | ||
let buf: Buffer<u8> = buf.into(); | ||
let len = buf.len() - n_padding; | ||
let buf = buf.slice(n_padding, len); | ||
|
||
// ptr to start of T, not to start of padding | ||
let ptr = buf.as_slice().as_ptr(); | ||
|
||
// Safety: | ||
// ptr and t are correct | ||
let drop_fn = unsafe { create_drop::<T>(ptr, n_t_vals) }; | ||
let et = Box::new(ExtensionSentinel { | ||
drop_fn: Some(drop_fn), | ||
to_series_fn: None, | ||
}); | ||
let et_ptr = &*et as *const ExtensionSentinel; | ||
std::mem::forget(et); | ||
|
||
let metadata = format!("{};{}", *PROCESS_ID, et_ptr as usize); | ||
|
||
let physical_type = ArrowDataType::FixedSizeBinary(t_size); | ||
let extension_type = ArrowDataType::Extension( | ||
"POLARS_EXTENSION_TYPE".into(), | ||
physical_type.into(), | ||
Some(metadata), | ||
); | ||
let validity = if validity.null_count() > 0 { | ||
Some(validity.into()) | ||
} else { | ||
None | ||
}; | ||
|
||
let array = FixedSizeBinaryArray::from_data(extension_type, buf, validity); | ||
|
||
// Safety: | ||
// we just heap allocated the ExtensionSentinel, so its alive. | ||
unsafe { PolarsExtension::new(array) } | ||
} | ||
|
||
#[cfg(test)]
mod test {
    use super::*;
    use std::fmt::{Display, Formatter};

    /// Test object with inline fields and a heap allocation.
    #[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
    struct Foo {
        pub a: i32,
        pub b: u8,
        pub other_heap: String,
    }

    impl Display for Foo {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{:?}", self)
        }
    }

    impl PolarsObject for Foo {
        fn type_name() -> &'static str {
            "object"
        }
    }

    /// Sets the environment variable that `create_extension` requires.
    fn allow_extension() {
        std::env::set_var("POLARS_ALLOW_EXTENSION", "1");
    }

    /// Builds a `Foo` that only differs in its heap-allocated string.
    fn make_foo(other_heap: &str) -> Foo {
        Foo {
            a: 1,
            b: 1,
            other_heap: other_heap.into(),
        }
    }

    /// Groups used by the list aggregation tests: `(first, indices)` pairs.
    fn sample_groups() -> Vec<(u32, Vec<u32>)> {
        vec![(0u32, vec![0u32, 1]), (2, vec![2]), (3, vec![3])]
    }

    #[test]
    fn test_create_extension() {
        allow_extension();
        // Run this under MIRI.
        let vals = vec![Some(make_foo("foo")), Some(make_foo("bar"))];
        create_extension(vals.into_iter());
    }

    #[test]
    fn test_extension_to_list() {
        allow_extension();
        let values = &[Some(make_foo("foo")), None, Some(make_foo("bar")), None];
        let ca = ObjectChunked::new_from_opt_slice("", values);

        let groups = sample_groups();
        let out = ca.agg_list(&groups).unwrap();
        assert!(matches!(out.dtype(), DataType::List(_)));
        assert_eq!(out.len(), groups.len());
    }

    #[test]
    fn test_extension_to_list_explode() {
        allow_extension();
        let foo1 = make_foo("foo");
        let foo2 = make_foo("bar");

        let values = &[Some(foo1.clone()), None, Some(foo2.clone()), None];
        let ca = ObjectChunked::new_from_opt_slice("", values);

        let groups = sample_groups();
        let out = ca.agg_list(&groups).unwrap();
        let a = out.explode().unwrap();

        let ca_foo = a.as_any().downcast_ref::<ObjectChunked<Foo>>().unwrap();
        assert_eq!(ca_foo.get(0).unwrap(), &foo1);
        assert_eq!(ca_foo.get(1), None);
        assert_eq!(ca_foo.get(2).unwrap(), &foo2);
        assert_eq!(ca_foo.get(3), None);
    }
}
Oops, something went wrong.