feat(rust): binary search and rechunk in chunked gather (#11199)
orlp committed Sep 21, 2023
1 parent 431f85f commit c9ee1e2
Showing 9 changed files with 289 additions and 165 deletions.
136 changes: 136 additions & 0 deletions crates/nano-arrow/src/bitmap/bitmask.rs
@@ -0,0 +1,136 @@
#[cfg(feature = "simd")]
use std::simd::ToBitMask;

#[cfg(feature = "simd")]
use num_traits::AsPrimitive;

use crate::bitmap::Bitmap;

// Loads a u64 from the given byteslice, as if it were padded with zeros.
fn load_padded_le_u64(bytes: &[u8]) -> u64 {
let len = bytes.len();
if len >= 8 {
return u64::from_le_bytes(bytes[0..8].try_into().unwrap());
}

if len >= 4 {
let lo = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
let hi = u32::from_le_bytes(bytes[len - 4..len].try_into().unwrap());
return (lo as u64) | ((hi as u64) << (8 * (len - 4)));
}

if len == 0 {
return 0;
}

let lo = bytes[0] as u64;
let mid = (bytes[len / 2] as u64) << (8 * (len / 2));
let hi = (bytes[len - 1] as u64) << (8 * (len - 1));
lo | mid | hi
}

pub struct BitMask<'a> {
bytes: &'a [u8],
offset: usize,
len: usize,
}

impl<'a> BitMask<'a> {
pub fn from_bitmap(bitmap: &'a Bitmap) -> Self {
let (bytes, offset, len) = bitmap.as_slice();
// Check length so we can use unsafe access in our get.
assert!(bytes.len() * 8 >= len + offset);
Self { bytes, offset, len }
}

#[inline(always)]
pub fn len(&self) -> usize {
self.len
}

#[inline]
pub fn split_at(&self, idx: usize) -> (Self, Self) {
assert!(idx <= self.len);
unsafe { self.split_at_unchecked(idx) }
}

/// # Safety
/// The index must be in-bounds.
#[inline]
pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {
debug_assert!(idx <= self.len);
let left = Self { len: idx, ..*self };
let right = Self {
len: self.len - idx,
offset: self.offset + idx,
..*self
};
(left, right)
}

#[cfg(feature = "simd")]
#[inline]
pub fn get_simd<T>(&self, idx: usize) -> T
where
T: ToBitMask,
<T as ToBitMask>::BitMask: Copy + 'static,
u64: AsPrimitive<<T as ToBitMask>::BitMask>,
{
// We don't support 64-lane masks because then we couldn't load our
// bitwise mask as a u64 and then do the byteshift on it.

let lanes = std::mem::size_of::<T::BitMask>() * 8;
assert!(lanes < 64);

let start_byte_idx = (self.offset + idx) / 8;
let byte_shift = (self.offset + idx) % 8;
if idx + lanes <= self.len {
// SAFETY: fast path, we know this is completely in-bounds.
let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
T::from_bitmask((mask >> byte_shift).as_())
} else if idx < self.len {
// SAFETY: we know that at least the first byte is in-bounds.
// This is partially out of bounds, we have to do extra masking.
let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
let num_out_of_bounds = idx + lanes - self.len;
let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift);
T::from_bitmask(shifted.as_())
} else {
T::from_bitmask((0u64).as_())
}
}

#[inline]
pub fn get_u32(&self, idx: usize) -> u32 {
let start_byte_idx = (self.offset + idx) / 8;
let byte_shift = (self.offset + idx) % 8;
if idx + 32 <= self.len {
// SAFETY: fast path, we know this is completely in-bounds.
let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
(mask >> byte_shift) as u32
} else if idx < self.len {
// SAFETY: we know that at least the first byte is in-bounds.
// This is partially out of bounds, we have to do extra masking.
let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
let num_out_of_bounds = idx + 32 - self.len;
let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift);
shifted as u32
} else {
0
}
}

#[inline]
pub fn get(&self, idx: usize) -> bool {
let byte_idx = (self.offset + idx) / 8;
let byte_shift = (self.offset + idx) % 8;

if idx < self.len {
// SAFETY: we know this is in-bounds.
let byte = unsafe { *self.bytes.get_unchecked(byte_idx) };
(byte >> byte_shift) & 1 == 1
} else {
false
}
}
}
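
For a feel of the new type's semantics, here is a minimal usage sketch. It is not part of the commit: it assumes nano-arrow is reachable under the `arrow` alias that polars-core uses below, and that a Bitmap can be collected from a bool iterator as in arrow2.

use arrow::bitmap::bitmask::BitMask;
use arrow::bitmap::Bitmap;

fn main() {
    // 10 validity bits, alternating: indices 0, 2, 4, 6, 8 are set.
    let bitmap: Bitmap = (0..10).map(|i| i % 2 == 0).collect();
    let mask = BitMask::from_bitmap(&bitmap);
    assert_eq!(mask.len(), 10);

    // Single-bit reads; out-of-bounds indices are defined and return false.
    assert!(mask.get(0));
    assert!(!mask.get(1));
    assert!(!mask.get(123));

    // get_u32 gathers 32 bits starting at an index; here the bitmap's
    // trailing padding is zero, so only the low 10 bits are set.
    assert_eq!(mask.get_u32(0), 0b01_0101_0101);

    // A partially out-of-bounds read at index 4: only the 6 remaining
    // in-bounds bits are returned.
    assert_eq!(mask.get_u32(4), 0b01_0101);

    // split_at yields two masks viewing the same underlying bytes.
    let (left, right) = mask.split_at(4);
    assert_eq!((left.len(), right.len()), (4, 6));
}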
2 changes: 2 additions & 0 deletions crates/nano-arrow/src/bitmap/mod.rs
@@ -15,3 +15,5 @@ mod assign_ops;
 pub use assign_ops::*;
 
 pub mod utils;
+
+pub mod bitmask;
110 changes: 5 additions & 105 deletions crates/polars-core/src/chunked_array/ops/aggregate/float_sum.rs
@@ -1,113 +1,13 @@
 use std::ops::{Add, IndexMut};
 #[cfg(feature = "simd")]
-use std::simd::{Mask, Simd, SimdElement, ToBitMask};
+use std::simd::{Mask, Simd, SimdElement};
 
+use arrow::bitmap::bitmask::BitMask;
 use arrow::bitmap::Bitmap;
 #[cfg(feature = "simd")]
 use num_traits::AsPrimitive;
 
 const STRIPE: usize = 16;
 const PAIRWISE_RECURSION_LIMIT: usize = 128;
 
-// Load 8 bytes as little-endian into a u64, padding with zeros if it's too short.
-#[cfg(feature = "simd")]
-pub fn load_padded_le_u64(bytes: &[u8]) -> u64 {
-    let len = bytes.len();
-    if len >= 8 {
-        return u64::from_le_bytes(bytes[0..8].try_into().unwrap());
-    }
-
-    if len >= 4 {
-        let lo = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
-        let hi = u32::from_le_bytes(bytes[len - 4..len].try_into().unwrap());
-        return (lo as u64) | ((hi as u64) << (8 * (len - 4)));
-    }
-
-    if len == 0 {
-        return 0;
-    }
-
-    let lo = bytes[0] as u64;
-    let mid = (bytes[len / 2] as u64) << (8 * (len / 2));
-    let hi = (bytes[len - 1] as u64) << (8 * (len - 1));
-    lo | mid | hi
-}
-
-struct BitMask<'a> {
-    bytes: &'a [u8],
-    offset: usize,
-    len: usize,
-}
-
-impl<'a> BitMask<'a> {
-    pub fn new(bitmap: &'a Bitmap) -> Self {
-        let (bytes, offset, len) = bitmap.as_slice();
-        // Check length so we can use unsafe access in our get.
-        assert!(bytes.len() * 8 >= len + offset);
-        Self { bytes, offset, len }
-    }
-
-    fn split_at(&self, idx: usize) -> (Self, Self) {
-        assert!(idx <= self.len);
-        unsafe { self.split_at_unchecked(idx) }
-    }
-
-    unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {
-        debug_assert!(idx <= self.len);
-        let left = Self { len: idx, ..*self };
-        let right = Self {
-            len: self.len - idx,
-            offset: self.offset + idx,
-            ..*self
-        };
-        (left, right)
-    }
-
-    #[cfg(feature = "simd")]
-    pub fn get_simd<T>(&self, idx: usize) -> T
-    where
-        T: ToBitMask,
-        <T as ToBitMask>::BitMask: Copy + 'static,
-        u64: AsPrimitive<<T as ToBitMask>::BitMask>,
-    {
-        // We don't support 64-lane masks because then we couldn't load our
-        // bitwise mask as a u64 and then do the byteshift on it.
-
-        let lanes = std::mem::size_of::<T::BitMask>() * 8;
-        assert!(lanes < 64);
-
-        let start_byte_idx = (self.offset + idx) / 8;
-        let byte_shift = (self.offset + idx) % 8;
-        if idx + lanes <= self.len {
-            // SAFETY: fast path, we know this is completely in-bounds.
-            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
-            T::from_bitmask((mask >> byte_shift).as_())
-        } else if idx < self.len {
-            // SAFETY: we know that at least the first byte is in-bounds.
-            // This is partially out of bounds, we have to do extra masking.
-            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
-            let num_out_of_bounds = idx + lanes - self.len;
-            let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift);
-            T::from_bitmask(shifted.as_())
-        } else {
-            T::from_bitmask((0u64).as_())
-        }
-    }
-
-    pub fn get(&self, idx: usize) -> bool {
-        let byte_idx = (self.offset + idx) / 8;
-        let byte_shift = (self.offset + idx) % 8;
-
-        if idx < self.len {
-            // SAFETY: we know this is in-bounds.
-            let byte = unsafe { *self.bytes.get_unchecked(byte_idx) };
-            (byte >> byte_shift) & 1 == 1
-        } else {
-            false
-        }
-    }
-}
-
 fn vector_horizontal_sum<V, T>(mut v: V) -> T
 where
     V: IndexMut<usize, Output = T>,
@@ -222,7 +122,7 @@ macro_rules! def_sum {
         /// Also, f.len() == mask.len().
         unsafe fn pairwise_sum_with_mask(f: &[$T], mask: BitMask<'_>) -> f64 {
             debug_assert!(f.len() > 0 && f.len() % PAIRWISE_RECURSION_LIMIT == 0);
-            debug_assert!(f.len() == mask.len);
+            debug_assert!(f.len() == mask.len());
 
             if let Ok(block) = f.try_into() {
                 return sum_block_vectorized_with_mask(block, mask) as f64;
@@ -253,8 +153,8 @@
         }
 
         pub fn sum_with_validity(f: &[$T], validity: &Bitmap) -> f64 {
-            let mask = BitMask::new(validity);
-            assert!(f.len() == mask.len);
+            let mask = BitMask::from_bitmap(validity);
+            assert!(f.len() == mask.len());
 
             let remainder = f.len() % PAIRWISE_RECURSION_LIMIT;
             let (rest, main) = f.split_at(remainder);
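
The control flow that survives this diff is unchanged: sum_with_validity peels off a scalar remainder so the main slice length is an exact multiple of PAIRWISE_RECURSION_LIMIT, then the pairwise helper recurses on block-aligned halves. Below is a minimal standalone sketch of that structure with a plain scalar kernel standing in for the SIMD block sums; the names and the exact split point are illustrative, not the commit's.

const PAIRWISE_RECURSION_LIMIT: usize = 128;

// Scalar stand-in for the vectorized block kernel in the diff.
fn sum_block(f: &[f64]) -> f64 {
    f.iter().sum()
}

// Pairwise recursion over block-aligned halves. Summing this way keeps
// floating-point error growth at O(log n) rather than the O(n) of one
// long running sum.
fn pairwise_sum(f: &[f64]) -> f64 {
    debug_assert!(!f.is_empty() && f.len() % PAIRWISE_RECURSION_LIMIT == 0);
    if f.len() == PAIRWISE_RECURSION_LIMIT {
        return sum_block(f);
    }
    let nblocks = f.len() / PAIRWISE_RECURSION_LIMIT;
    let (a, b) = f.split_at((nblocks / 2) * PAIRWISE_RECURSION_LIMIT);
    pairwise_sum(a) + pairwise_sum(b)
}

fn sum(f: &[f64]) -> f64 {
    // Peel off the ragged head so the tail is an exact multiple of the
    // block size, mirroring the remainder/main split in sum_with_validity.
    let remainder = f.len() % PAIRWISE_RECURSION_LIMIT;
    let (rest, main) = f.split_at(remainder);
    let mut out: f64 = rest.iter().sum();
    if !main.is_empty() {
        out += pairwise_sum(main);
    }
    out
}

fn main() {
    let xs: Vec<f64> = (0..1000).map(|i| i as f64).collect();
    assert_eq!(sum(&xs), 499500.0);
}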
