Skip to content

Commit

Permalink
[Experimental] Create extension type in arrow that can hold any in-me…
Browse files Browse the repository at this point in the history
…mory data type (#1843)
  • Loading branch information
ritchie46 committed Nov 23, 2021
1 parent ee85fb9 commit 6e6f4d5
Show file tree
Hide file tree
Showing 25 changed files with 582 additions and 44 deletions.
2 changes: 2 additions & 0 deletions polars/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@ miri:
# some tests are also filtered, because miri cannot deal with the rayon threadpool
# Miri also reports UB in prettytable.rs, so we must toggle that feature off.
MIRIFLAGS="-Zmiri-disable-isolation" \
POLARS_ALLOW_EXTENSION=1 \
cargo miri test \
--no-default-features \
--features object \
-p polars-core \
-p polars-arrow \
--
Expand Down
5 changes: 3 additions & 2 deletions polars/polars-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ description = "Arrow interfaces for Polars DataFrame library"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "70562fac652a0dab08b4b7bf3d86d2d808ea98e6", default-features = false }
# arrow = { package = "arrow2", git = "https://github.com/ritchie46/arrow2", branch = "null", default-features = false }
# arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "70562fac652a0dab08b4b7bf3d86d2d808ea98e6", default-features = false }
arrow = { package = "arrow2", git = "https://github.com/ritchie46/arrow2", default-features = false, features = ["compute"], branch = "fn_to" }
# arrow = { package = "arrow2", version = "0.7", default-features = false }
num = "^0.4"
thiserror = "^1.0"

[features]
strings = []
compute = ["arrow/compute"]
7 changes: 4 additions & 3 deletions polars/polars-arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
pub mod default_arrays;

use crate::utils::CustomIterTools;
use arrow::array::{Array, ArrayRef, BooleanArray, ListArray, PrimitiveArray, Utf8Array};
use arrow::bitmap::MutableBitmap;
use arrow::buffer::MutableBuffer;
use arrow::datatypes::DataType;
use arrow::types::{NativeType, NaturalDataType};
use std::sync::Arc;

use crate::utils::CustomIterTools;

pub mod default_arrays;

pub trait ValueSize {
/// Useful for a Utf8 or a List to get underlying value size.
/// During a rechunk this is handy
Expand Down
1 change: 1 addition & 0 deletions polars/polars-arrow/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod array;
pub mod bit_util;
#[cfg(feature = "compute")]
pub mod compute;
pub mod error;
pub mod is_valid;
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ docs-selection = [
[dependencies]
ahash = "0.7"
anyhow = "1.0"
arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "70562fac652a0dab08b4b7bf3d86d2d808ea98e6", default-features = false, features = ["compute"] }
# arrow = { package = "arrow2", git = "https://github.com/ritchie46/arrow2", default-features = false, features = ["compute"], branch = "null" }
# arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "70562fac652a0dab08b4b7bf3d86d2d808ea98e6", default-features = false, features = ["compute"] }
arrow = { package = "arrow2", git = "https://github.com/ritchie46/arrow2", default-features = false, features = ["compute"], branch = "fn_to" }
# arrow = { package = "arrow2", version = "0.7", default-features = false, features = ["compute"] }
chrono = { version = "0.4", optional = true }
comfy-table = { version = "4.0", optional = true }
Expand All @@ -143,7 +143,7 @@ lazy_static = "1.4"
ndarray = { version = "0.15", optional = true, default_features = false }
num = "^0.4"
num_cpus = "1.1"
polars-arrow = { version = "0.17.0", path = "../polars-arrow" }
polars-arrow = { version = "0.17.0", path = "../polars-arrow", features = ["compute"] }
prettytable-rs = { version = "0.8.0", optional = true }
rand = { version = "0.8", optional = true }
rand_distr = { version = "0.4", optional = true }
Expand Down
12 changes: 6 additions & 6 deletions polars/polars-core/src/chunked_array/categorical/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,19 +204,19 @@ mod test {
];
let ca = Utf8Chunked::new_from_opt_slice("a", slice);
let out = ca.cast(&DataType::Categorical)?;
let out = out.categorical().unwrap().clone();
assert_eq!(out.categorical_map.unwrap().len(), 2);
let mut out = out.categorical().unwrap().clone();
assert_eq!(out.categorical_map.take().unwrap().len(), 2);

// test the global branch
toggle_string_cache(true);
// empty global cache
let out = ca.cast(&DataType::Categorical)?;
let out = out.categorical().unwrap().clone();
assert_eq!(out.categorical_map.unwrap().len(), 2);
let mut out = out.categorical().unwrap().clone();
assert_eq!(out.categorical_map.take().unwrap().len(), 2);
// full global cache
let out = ca.cast(&DataType::Categorical)?;
let out = out.categorical().unwrap().clone();
assert_eq!(out.categorical_map.unwrap().len(), 2);
let mut out = out.categorical().unwrap().clone();
assert_eq!(out.categorical_map.take().unwrap().len(), 2);

// Check that we don't panic if we append two categorical arrays
// build under the same string cache
Expand Down
11 changes: 11 additions & 0 deletions polars/polars-core/src/chunked_array/drop.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use crate::chunked_array::object::extension::drop::drop_list;
use crate::prelude::*;

impl<T> Drop for ChunkedArray<T> {
fn drop(&mut self) {
if matches!(self.dtype(), DataType::List(_)) {
// guarded by the type system
unsafe { drop_list(std::mem::transmute(self)) }
}
}
}
2 changes: 2 additions & 0 deletions polars/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ mod ndarray;
mod bitwise;
#[cfg(feature = "dtype-categorical")]
pub(crate) mod categorical;
#[cfg(feature = "object")]
mod drop;
pub(crate) mod list;
pub(crate) mod logical;
#[cfg(feature = "object")]
Expand Down
3 changes: 1 addition & 2 deletions polars/polars-core/src/chunked_array/object/builder.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use super::*;
use crate::prelude::*;
use crate::utils::get_iter_capacity;
use arrow::bitmap::{Bitmap, MutableBitmap};
use arrow::bitmap::MutableBitmap;
use std::marker::PhantomData;
use std::sync::Arc;

Expand Down
46 changes: 46 additions & 0 deletions polars/polars-core/src/chunked_array/object/extension/drop.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use crate::chunked_array::object::extension::PolarsExtension;
use crate::prelude::*;

/// Releases the object values stored in a list array whose innermost dtype is `Object`.
///
/// Lists with any other innermost dtype are left untouched. For object lists, every chunk
/// that is uniquely referenced (and whose values array is uniquely referenced as well) has
/// its values re-wrapped in a `PolarsExtension`, which is dropped immediately so that the
/// stored content is dropped with it.
///
/// # Safety
///
/// Dropping the re-created `PolarsExtension` dereferences a raw pointer; the caller must
/// make sure that pointer is still valid.
///
/// # Panics
///
/// Panics when the objects are nested in more than one list level, when the array has no
/// chunks, when a uniquely owned `LargeList` chunk does not have an extension inner type,
/// or when a chunk cannot be downcast to the expected arrow array type.
pub(crate) unsafe fn drop_list(ca: &ListChunked) {
    // Walk down the nested dtypes until the innermost one is reached.
    let mut leaf = ca.inner_dtype();
    let mut depth = 0;
    while let Some(child) = leaf.inner_dtype() {
        depth += 1;
        leaf = child.clone()
    }

    if !matches!(leaf, DataType::Object(_)) {
        return;
    }
    if depth != 0 {
        panic!("multiple nested objects not yet supported")
    }
    // if empty the memory is leaked somewhere
    assert!(!ca.chunks.is_empty());

    for chunk in ca.chunks.iter() {
        // The list may have been cloned; only the last owner is allowed to drop the content.
        let field = match chunk.data_type() {
            ArrowDataType::LargeList(field) if Arc::strong_count(chunk) == 1 => field,
            _ => continue,
        };
        assert!(matches!(
            field.data_type(),
            ArrowDataType::Extension(_, _, _)
        ));

        let list = chunk.as_any().downcast_ref::<LargeListArray>().unwrap();
        let values = list.values();

        // The values array may be shared as well; skip it unless this is the last owner.
        if Arc::strong_count(values) != 1 {
            continue;
        }

        let extension_values = values
            .as_any()
            .downcast_ref::<FixedSizeBinaryArray>()
            .unwrap()
            .clone();
        // Re-create the polars extension and drop it right away so the content is dropped.
        drop(PolarsExtension::new(extension_values));
    }
}
223 changes: 223 additions & 0 deletions polars/polars-core/src/chunked_array/object/extension/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
pub(crate) mod drop;
pub(crate) mod polars_extension;

use crate::{prelude::*, PROCESS_ID};
use arrow::array::{Array, FixedSizeBinaryArray};
use arrow::bitmap::MutableBitmap;
use arrow::buffer::{Buffer, MutableBuffer};
use polars_extension::PolarsExtension;
use std::mem;

/// Builds a callback that drops `n_t_vals` values of type `T` stored back to back in memory.
///
/// The callback reads every value out of the allocation (unaligned reads, so the start
/// address does not have to be aligned for `T`) and drops it.
///
/// # Safety
///
/// - `ptr` must point to the start of a `T` allocation.
/// - `n_t_vals` must represent the correct number of `T` values in that allocation.
/// - The allocation must still be alive when the callback runs, and the callback must not
///   run more than once, since every value is dropped by it.
unsafe fn create_drop<T: Sized>(ptr: *const u8, n_t_vals: usize) -> Box<dyn FnMut()> {
    // Kept as a byte pointer; it is only reinterpreted as `*const T` at the read site.
    let mut cursor = ptr;
    Box::new(move || {
        let stride = std::mem::size_of::<T>();
        for _ in 0..n_t_vals {
            // Move the value out of the buffer and drop it on the spot.
            drop(std::ptr::read_unaligned(cursor.cast::<T>()));
            cursor = cursor.add(stride);
        }
    })
}

/// Bookkeeping object that travels with an extension array and cleans up the values stored
/// in it: dropping the sentinel runs `drop_fn`.
struct ExtensionSentinel {
    /// Callback that drops the values held by the extension array. It is taken out of the
    /// option when the sentinel is dropped, so it runs at most once.
    drop_fn: Option<Box<dyn FnMut()>>,
    /// Optional callback that turns the raw `FixedSizeBinaryArray` into a `Series`.
    pub(crate) to_series_fn: Option<Box<dyn Fn(&FixedSizeBinaryArray) -> Series>>,
}

impl Drop for ExtensionSentinel {
    /// Runs the drop callback, if it is still present.
    fn drop(&mut self) {
        // A destructor must not panic: panicking while already unwinding aborts the
        // process. A missing callback simply means there is nothing left to release.
        if let Some(mut drop_fn) = self.drop_fn.take() {
            drop_fn()
        }
    }
}

// Based on https://stackoverflow.com/questions/28127165/how-to-convert-struct-to-u8d
//
// Views the memory of `p` as a byte slice of `size_of::<T>()` bytes.
// It is not entirely clear whether padding bytes in `T` are initialized or not.
unsafe fn any_as_u8_slice<T: Sized>(p: &T) -> &[u8] {
    let n_bytes = std::mem::size_of::<T>();
    let start: *const u8 = (p as *const T).cast();
    std::slice::from_raw_parts(start, n_bytes)
}

/// Create an extension array that can be sent to arrow and that, once wrapped in a
/// [`PolarsExtension`], also calls drop on every stored `T` when the array is dropped.
///
/// Every item of `iter` is moved byte-for-byte into a `FixedSizeBinary` buffer. `None`
/// items are stored as `T::default()` and marked as null in the validity bitmap. The
/// address of a heap allocated `ExtensionSentinel`, which owns the drop callback, is
/// written into the extension metadata together with the process id.
///
/// # Panics
///
/// Panics if the `POLARS_ALLOW_EXTENSION` environment variable cannot be read, or if
/// `iter` does not report an upper bound in its size hint.
pub(crate) fn create_extension<
    I: IntoIterator<Item = Option<T>> + TrustedLen,
    T: Sized + Default,
>(
    iter: I,
) -> PolarsExtension {
    let env = "POLARS_ALLOW_EXTENSION";
    if std::env::var(env).is_err() {
        panic!(
            "env var: {} must be set to allow extension types to be created",
            env
        )
    }
    let t_size = std::mem::size_of::<T>();
    let t_alignment = std::mem::align_of::<T>();
    let n_t_vals = iter.size_hint().1.unwrap();

    // Reserve room for the worst-case padding as well, so that neither the padding nor the
    // values force a reallocation: a reallocation would move the buffer and invalidate
    // the alignment computed below.
    let mut buf = MutableBuffer::with_capacity(n_t_vals * t_size + t_alignment);
    let mut validity = MutableBitmap::with_capacity(n_t_vals);

    // When the bytes are later reinterpreted as `T`, `T` must be aligned correctly, so we
    // pad with bytes until the next multiple of the alignment is reached.
    let misalignment = (buf.as_ptr() as usize) % t_alignment;
    let n_padding = (t_alignment - misalignment) % t_alignment;
    buf.extend_constant(n_padding, 0);

    // Copy every value into the buffer as raw bytes.
    for opt_t in iter {
        let (value, is_valid) = match opt_t {
            Some(t) => (t, true),
            // Null slots still hold a real `T`, because the drop callback drops all
            // `n_t_vals` slots.
            None => (T::default(), false),
        };
        // SAFETY: `value` is a live `T`, so its bytes are readable; the bitmap was
        // allocated upfront with room for `n_t_vals` bits.
        unsafe {
            buf.extend_from_slice(any_as_u8_slice(&value));
            validity.push_unchecked(is_valid)
        }
        // The buffer owns the value now; dropping it here as well would be a double drop
        // once the drop callback runs.
        mem::forget(value);
    }

    // We slice the buffer because we want to ignore the padding bytes from here on;
    // they can be forgotten.
    let buf: Buffer<u8> = buf.into();
    let len = buf.len() - n_padding;
    let buf = buf.slice(n_padding, len);

    // Pointer to the start of the `T` values, not to the start of the padding.
    let ptr = buf.as_slice().as_ptr();

    // SAFETY: `ptr` points to the first of the `n_t_vals` values written above.
    let drop_fn = unsafe { create_drop::<T>(ptr, n_t_vals) };
    let et = Box::new(ExtensionSentinel {
        drop_fn: Some(drop_fn),
        to_series_fn: None,
    });
    // Leak the sentinel on purpose; its address travels in the metadata.
    let et_ptr = Box::into_raw(et);

    let metadata = format!("{};{}", *PROCESS_ID, et_ptr as usize);

    let physical_type = ArrowDataType::FixedSizeBinary(t_size);
    let extension_type = ArrowDataType::Extension(
        "POLARS_EXTENSION_TYPE".into(),
        physical_type.into(),
        Some(metadata),
    );
    let validity = if validity.null_count() > 0 {
        Some(validity.into())
    } else {
        None
    };

    let array = FixedSizeBinaryArray::from_data(extension_type, buf, validity);

    // SAFETY: we just heap allocated the `ExtensionSentinel`, so it is alive.
    unsafe { PolarsExtension::new(array) }
}

#[cfg(test)]
mod test {
    use super::*;
    use std::fmt::{Display, Formatter};

    /// Test object with plain fields and a heap allocated one, so that a missed or
    /// duplicated drop shows up when the tests are run under Miri.
    #[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
    struct Foo {
        pub a: i32,
        pub b: u8,
        pub other_heap: String,
    }

    impl Display for Foo {
        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
            write!(f, "{:?}", self)
        }
    }

    impl PolarsObject for Foo {
        fn type_name() -> &'static str {
            "object"
        }
    }

    /// Builds the `Foo` fixture used by all tests; only the heap field varies.
    fn make_foo(other_heap: &str) -> Foo {
        Foo {
            a: 1,
            b: 1,
            other_heap: other_heap.into(),
        }
    }

    /// Sets the environment variable that `create_extension` requires.
    fn allow_extension() {
        std::env::set_var("POLARS_ALLOW_EXTENSION", "1");
    }

    #[test]
    fn test_create_extension() {
        allow_extension();
        // Run this under MIRI.
        let vals = vec![Some(make_foo("foo")), Some(make_foo("bar"))];
        create_extension(vals.into_iter());
    }

    #[test]
    fn test_extension_to_list() {
        allow_extension();

        let values = &[Some(make_foo("foo")), None, Some(make_foo("bar")), None];
        let ca = ObjectChunked::new_from_opt_slice("", values);

        let groups = vec![(0u32, vec![0u32, 1]), (2, vec![2]), (3, vec![3])];
        let out = ca.agg_list(&groups).unwrap();
        assert!(matches!(out.dtype(), DataType::List(_)));
        assert_eq!(out.len(), groups.len());
    }

    #[test]
    fn test_extension_to_list_explode() {
        allow_extension();
        let foo1 = make_foo("foo");
        let foo2 = make_foo("bar");

        let values = &[Some(foo1.clone()), None, Some(foo2.clone()), None];
        let ca = ObjectChunked::new_from_opt_slice("", values);

        let groups = vec![(0u32, vec![0u32, 1]), (2, vec![2]), (3, vec![3])];
        let out = ca.agg_list(&groups).unwrap();
        let a = out.explode().unwrap();

        // Exploding the list must give back the original objects and nulls, in order.
        let ca_foo = a.as_any().downcast_ref::<ObjectChunked<Foo>>().unwrap();
        assert_eq!(ca_foo.get(0).unwrap(), &foo1);
        assert_eq!(ca_foo.get(1), None);
        assert_eq!(ca_foo.get(2).unwrap(), &foo2);
        assert_eq!(ca_foo.get(3), None);
    }
}

0 comments on commit 6e6f4d5

Please sign in to comment.