Skip to content

Commit

Permalink
inplace arithmetic (#3709)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 16, 2022
1 parent 30969ae commit 74081ac
Show file tree
Hide file tree
Showing 13 changed files with 333 additions and 51 deletions.
4 changes: 2 additions & 2 deletions polars/polars-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ description = "Arrow interfaces for Polars DataFrame library"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "39db6fb7514364bfea08d594793b23e1ed5a7def", features = ["compute_concatenate"], default-features = false }
# arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "39db6fb7514364bfea08d594793b23e1ed5a7def", features = ["compute_concatenate"], default-features = false }
# arrow = { package = "arrow2", path = "../../../arrow2", features = ["compute_concatenate"], default-features = false }
# arrow = { package = "arrow2", git = "https://github.com/ritchie46/arrow2", branch = "count_shared", features = ["compute_concatenate"], default-features = false }
arrow = { package = "arrow2", git = "https://github.com/ritchie46/arrow2", branch = "arity_assign", features = ["compute_concatenate"], default-features = false }
# arrow = { package = "arrow2", version = "0.12", default-features = false, features = ["compute_concatenate"] }
hashbrown = "0.12"
num = "^0.4"
Expand Down
10 changes: 5 additions & 5 deletions polars/polars-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ default = ["docs", "temporal", "private"]
lazy = ["sort_multiple"]

# ~40% faster collect, needed until trustedlength iter stabilizes
# more fast paths
# more fast paths, slower compilation
performant = []

# extra utilities for Utf8Chunked
Expand Down Expand Up @@ -171,11 +171,11 @@ thiserror = "^1.0"

[dependencies.arrow]
package = "arrow2"
git = "https://github.com/jorgecarleitao/arrow2"
# git = "https://github.com/ritchie46/arrow2"
rev = "39db6fb7514364bfea08d594793b23e1ed5a7def"
# git = "https://github.com/jorgecarleitao/arrow2"
git = "https://github.com/ritchie46/arrow2"
# rev = "39db6fb7514364bfea08d594793b23e1ed5a7def"
# path = "../../../arrow2"
# branch = "count_shared"
branch = "arity_assign"
# version = "0.12"
default-features = false
features = [
Expand Down
111 changes: 99 additions & 12 deletions polars/polars-core/src/chunked_array/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
//! Implementations of arithmetic operations on ChunkedArray's.
use crate::prelude::*;
use crate::utils::align_chunks_binary;
use crate::utils::{align_chunks_binary, align_chunks_binary_owned};
use arrow::array::PrimitiveArray;
use arrow::{compute, compute::arithmetics::basic};
use arrow::{
compute,
compute::{arithmetics::basic, arity_assign},
};
use num::{Num, NumCast, ToPrimitive};
use std::borrow::Cow;
use std::ops::{Add, Div, Mul, Rem, Sub};
Expand Down Expand Up @@ -93,6 +96,52 @@ where
ca
}

/// Applies a binary arithmetic operation to two owned arrays, writing the
/// result into the chunks of one of the operands instead of building a new array.
///
/// As noted by the original author, the assignment goes to the owned buffer
/// if its ref count is 1.
///
/// * Equal lengths: both operands are passed through `align_chunks_binary_owned`
///   and `kernel` is called on every pair of chunks; `lhs` is returned.
/// * `rhs` has length 1: its single value is broadcast over `lhs` with
///   `operation`; a null value produces an all-null array of `lhs`'s length.
/// * `lhs` has length 1: its single value is broadcast over `rhs` with
///   `operation`; a null value produces an all-null array of `rhs`'s length.
///
/// In every case the returned array carries the name of `lhs`.
///
/// # Panics
///
/// Panics when the lengths differ and neither operand has length 1.
fn arithmetic_helper_owned<T, Kernel, F>(
    mut lhs: ChunkedArray<T>,
    mut rhs: ChunkedArray<T>,
    kernel: Kernel,
    operation: F,
) -> ChunkedArray<T>
where
    T: PolarsNumericType,
    Kernel: Fn(&mut PrimitiveArray<T::Native>, &mut PrimitiveArray<T::Native>),
    F: Fn(T::Native, T::Native) -> T::Native,
{
    match (lhs.len(), rhs.len()) {
        (a, b) if a == b => {
            let (mut lhs, mut rhs) = align_chunks_binary_owned(lhs, rhs);
            lhs.downcast_iter_mut()
                .zip(rhs.downcast_iter_mut())
                .for_each(|(lhs, rhs)| kernel(lhs, rhs));
            lhs
        }
        // broadcast right path
        (_, 1) => match rhs.get(0) {
            None => ChunkedArray::full_null(lhs.name(), lhs.len()),
            Some(rhs) => {
                lhs.apply_mut(|lhs| operation(lhs, rhs));
                lhs
            }
        },
        // broadcast left path
        (1, _) => match lhs.get(0) {
            None => ChunkedArray::full_null(lhs.name(), rhs.len()),
            Some(lhs_value) => {
                rhs.apply_mut(|rhs| operation(lhs_value, rhs));
                // The buffer being reused belongs to `rhs`; keep the result
                // named after the left operand like every other branch.
                rhs.rename(lhs.name());
                rhs
            }
        },
        _ => panic!("Cannot apply operation on arrays of different lengths"),
    }
}

// Operands on ChunkedArray & ChunkedArray

impl<T> Add for &ChunkedArray<T>
Expand Down Expand Up @@ -157,7 +206,12 @@ where
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
(&self).add(&rhs)
arithmetic_helper_owned(
self,
rhs,
|a, b| arity_assign::binary(a, b, |a, b| a + b),
|lhs, rhs| lhs + rhs,
)
}
}

Expand All @@ -168,7 +222,12 @@ where
type Output = Self;

fn div(self, rhs: Self) -> Self::Output {
(&self).div(&rhs)
arithmetic_helper_owned(
self,
rhs,
|a, b| arity_assign::binary(a, b, |a, b| a / b),
|lhs, rhs| lhs / rhs,
)
}
}

Expand All @@ -179,7 +238,12 @@ where
type Output = Self;

fn mul(self, rhs: Self) -> Self::Output {
(&self).mul(&rhs)
arithmetic_helper_owned(
self,
rhs,
|a, b| arity_assign::binary(a, b, |a, b| a * b),
|lhs, rhs| lhs * rhs,
)
}
}

Expand All @@ -190,7 +254,12 @@ where
type Output = Self;

fn sub(self, rhs: Self) -> Self::Output {
(&self).sub(&rhs)
arithmetic_helper_owned(
self,
rhs,
|a, b| arity_assign::binary(a, b, |a, b| a - b),
|lhs, rhs| lhs - rhs,
)
}
}

Expand Down Expand Up @@ -279,8 +348,14 @@ where
{
type Output = ChunkedArray<T>;

fn add(self, rhs: N) -> Self::Output {
(&self).add(rhs)
/// Adds the scalar `rhs` to every value, writing the results into this
/// array's own chunks through `apply_mut`.
///
/// The in-place path is taken unconditionally: it no longer depends on the
/// `ASSIGN` environment variable, so the process environment is not read
/// on every addition.
///
/// # Panics
///
/// Panics if `rhs` cannot be converted to `T::Native`.
fn add(mut self, rhs: N) -> Self::Output {
    let adder: T::Native =
        NumCast::from(rhs).expect("could not cast scalar to the array's native type");
    self.apply_mut(|val| val + adder);
    self
}
}

Expand All @@ -291,8 +366,14 @@ where
{
type Output = ChunkedArray<T>;

fn sub(self, rhs: N) -> Self::Output {
(&self).sub(rhs)
/// Subtracts the scalar `rhs` from every value, writing the results into
/// this array's own chunks through `apply_mut`.
///
/// The in-place path is taken unconditionally: it no longer depends on the
/// `ASSIGN` environment variable, so the process environment is not read
/// on every subtraction.
///
/// # Panics
///
/// Panics if `rhs` cannot be converted to `T::Native`.
fn sub(mut self, rhs: N) -> Self::Output {
    let subber: T::Native =
        NumCast::from(rhs).expect("could not cast scalar to the array's native type");
    self.apply_mut(|val| val - subber);
    self
}
}

Expand All @@ -315,8 +396,14 @@ where
{
type Output = ChunkedArray<T>;

fn mul(self, rhs: N) -> Self::Output {
(&self).mul(rhs)
/// Multiplies every value by the scalar `rhs`, writing the results into
/// this array's own chunks through `apply_mut`.
///
/// The in-place path is taken unconditionally: it no longer depends on the
/// `ASSIGN` environment variable, so the process environment is not read
/// on every multiplication.
///
/// # Panics
///
/// Panics if `rhs` cannot be converted to `T::Native`.
fn mul(mut self, rhs: N) -> Self::Output {
    let multiplier: T::Native =
        NumCast::from(rhs).expect("could not cast scalar to the array's native type");
    self.apply_mut(|val| val * multiplier);
    self
}
}

Expand Down
10 changes: 10 additions & 0 deletions polars/polars-core/src/chunked_array/ops/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ impl<T: PolarsNumericType> ChunkedArray<T> {
}
}

impl<T: PolarsNumericType> ChunkedArray<T> {
    /// Replaces every value of the array with `f(value)`.
    ///
    /// The chunks are visited in order and each one is handed, together with
    /// `f`, to arrow's `arity_assign::unary`.
    pub(crate) fn apply_mut<F>(&mut self, f: F)
    where
        F: Fn(T::Native) -> T::Native + Copy,
    {
        for chunk in self.downcast_iter_mut() {
            arrow::compute::arity_assign::unary(chunk, f);
        }
    }
}

impl<'a, T> ChunkApply<'a, T::Native, T::Native> for ChunkedArray<T>
where
T: PolarsNumericType,
Expand Down
12 changes: 12 additions & 0 deletions polars/polars-core/src/chunked_array/ops/downcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ where
unsafe { &*(arr as *const dyn Array as *const PrimitiveArray<T::Native>) }
})
}

/// Iterates over the chunks of this array as mutable
/// `PrimitiveArray<T::Native>` references.
///
/// Each stored chunk (a `dyn Array` trait object) is reinterpreted through a
/// raw pointer cast; no runtime type check is performed.
pub(crate) fn downcast_iter_mut(
&mut self,
) -> impl Iterator<Item = &mut PrimitiveArray<T::Native>> + DoubleEndedIterator {
self.chunks.iter_mut().map(|arr| {
// SAFETY:
// The cast is only sound if every chunk really is a
// `PrimitiveArray<T::Native>`, i.e. the array type that belongs to this
// `PolarsNumericType`. Nothing in this function verifies that.
// Reborrow the element so we hold a plain `&mut dyn Array` before casting.
let arr = &mut **arr;
unsafe { &mut *(arr as *mut dyn Array as *mut PrimitiveArray<T::Native>) }
})
}

pub fn downcast_chunks(&self) -> Chunks<'_, PrimitiveArray<T::Native>> {
Chunks::new(&self.chunks)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
use crate::prelude::*;
use crate::utils::{get_supertype, get_time_units};
use num::{Num, NumCast};
use std::borrow::Cow;
use std::fmt::Debug;
use std::ops;
use super::*;

pub trait NumOpsDispatch: Debug {
fn subtract(&self, rhs: &Series) -> Result<Series> {
Expand Down
11 changes: 11 additions & 0 deletions polars/polars-core/src/series/arithmetic/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
mod borrowed;
mod owned;

use crate::prelude::*;
use crate::utils::{get_supertype, get_time_units};
use num::{Num, NumCast};
use std::borrow::Cow;
use std::fmt::Debug;
use std::ops::{self, Add, Div, Mul, Sub};

pub use borrowed::*;
83 changes: 83 additions & 0 deletions polars/polars-core/src/series/arithmetic/owned.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use super::*;
#[cfg(feature = "performant")]
use crate::utils::align_chunks_binary_owned_series;

/// Brings two owned series to their common supertype.
///
/// A series that already has the supertype is passed through untouched; the
/// other one is cast. `lhs` is handled before `rhs`.
///
/// # Errors
///
/// Returns the error from `get_supertype` or from either cast.
#[cfg(feature = "performant")]
pub fn coerce_lhs_rhs_owned(lhs: Series, rhs: Series) -> Result<(Series, Series)> {
    let dtype = get_supertype(lhs.dtype(), rhs.dtype())?;

    // Cast only when the series is not already of the target dtype.
    let into_supertype = |series: Series| -> Result<Series> {
        if series.dtype() == &dtype {
            Ok(series)
        } else {
            series.cast(&dtype)
        }
    };

    let left = into_supertype(lhs)?;
    let right = into_supertype(rhs)?;
    Ok((left, right))
}

/// Moves the typed `ChunkedArray<T>` out of both series, combines them with
/// `op` and wraps the result back into a `Series`.
///
/// The arrays are extracted with `std::mem::take`, which leaves a default
/// value behind in each (soon to be dropped) series.
#[cfg(feature = "performant")]
fn apply_operation_mut<T, F>(mut lhs: Series, mut rhs: Series, op: F) -> Series
where
    T: PolarsNumericType,
    F: Fn(ChunkedArray<T>, ChunkedArray<T>) -> ChunkedArray<T> + Copy,
    ChunkedArray<T>: IntoSeries,
{
    let left: ChunkedArray<T> = {
        let inner: &mut ChunkedArray<T> = lhs._get_inner_mut().as_mut();
        std::mem::take(inner)
    };
    let right: ChunkedArray<T> = {
        let inner: &mut ChunkedArray<T> = rhs._get_inner_mut().as_mut();
        std::mem::take(inner)
    };

    let out = op(left, right);
    out.into_series()
}

/// Implements an owned arithmetic operator (`Series op Series`) for `Series`.
///
/// * `$operation` - the `std::ops` trait to implement.
/// * `$method`    - the method of that trait.
/// * `$function`  - closure applied to the two typed `ChunkedArray`s.
///
/// With the `performant` feature enabled, and only when *both* operands are
/// non-logical physical numeric series, the operands are coerced to their
/// supertype, chunk-aligned and dispatched on dtype to `apply_operation_mut`.
/// Every other combination, and any dtype without a dispatch arm, is
/// delegated to the borrowed operator implementation.
///
/// # Panics
///
/// The fast path panics when the operands cannot be coerced to a common
/// supertype.
macro_rules! impl_operation {
    ($operation:ident, $method:ident, $function:expr) => {
        impl $operation for Series {
            type Output = Series;

            fn $method(self, rhs: Self) -> Self::Output {
                #[cfg(feature = "performant")]
                {
                    // Only physical numeric values take the mutable path. Both
                    // sides must qualify: otherwise the supertype could be a
                    // dtype that has no arm in the dispatch below.
                    let both_physical_numeric = !self.is_logical()
                        && self.is_numeric_physical()
                        && !rhs.is_logical()
                        && rhs.is_numeric_physical();

                    if both_physical_numeric {
                        let (lhs, rhs) =
                            coerce_lhs_rhs_owned(self, rhs).expect("cannot coerce datatypes");
                        let (lhs, rhs) = align_chunks_binary_owned_series(lhs, rhs);
                        use DataType::*;
                        match lhs.dtype() {
                            #[cfg(feature = "dtype-i8")]
                            Int8 => apply_operation_mut::<Int8Type, _>(lhs, rhs, $function),
                            #[cfg(feature = "dtype-i16")]
                            Int16 => apply_operation_mut::<Int16Type, _>(lhs, rhs, $function),
                            Int32 => apply_operation_mut::<Int32Type, _>(lhs, rhs, $function),
                            Int64 => apply_operation_mut::<Int64Type, _>(lhs, rhs, $function),
                            #[cfg(feature = "dtype-u8")]
                            UInt8 => apply_operation_mut::<UInt8Type, _>(lhs, rhs, $function),
                            #[cfg(feature = "dtype-u16")]
                            UInt16 => apply_operation_mut::<UInt16Type, _>(lhs, rhs, $function),
                            UInt32 => apply_operation_mut::<UInt32Type, _>(lhs, rhs, $function),
                            UInt64 => apply_operation_mut::<UInt64Type, _>(lhs, rhs, $function),
                            Float32 => apply_operation_mut::<Float32Type, _>(lhs, rhs, $function),
                            Float64 => apply_operation_mut::<Float64Type, _>(lhs, rhs, $function),
                            // No in-place kernel for this dtype: use the
                            // borrowed implementation instead of panicking.
                            _ => (&lhs).$method(&rhs),
                        }
                    } else {
                        (&self).$method(&rhs)
                    }
                }
                #[cfg(not(feature = "performant"))]
                {
                    (&self).$method(&rhs)
                }
            }
        }
    };
}

impl_operation!(Add, add, |a, b| a.add(b));
impl_operation!(Sub, sub, |a, b| a.sub(b));
impl_operation!(Mul, mul, |a, b| a.mul(b));
impl_operation!(Div, div, |a, b| a.div(b));
28 changes: 28 additions & 0 deletions polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,34 @@ where
}
}

/// Makes two owned series chunk-compatible for a binary operation.
///
/// A series that consists of exactly one chunk is returned as is; any other
/// series is replaced by the result of `rechunk()`. `left` is processed
/// before `right`.
#[cfg(feature = "performant")]
pub(crate) fn align_chunks_binary_owned_series(left: Series, right: Series) -> (Series, Series) {
    let single_chunk = |series: Series| {
        if series.chunks().len() == 1 {
            series
        } else {
            series.rechunk()
        }
    };

    let left = single_chunk(left);
    let right = single_chunk(right);
    (left, right)
}

/// Makes two owned arrays chunk-compatible for a binary operation.
///
/// An array that consists of exactly one chunk is returned as is; any other
/// array is replaced by the result of `rechunk()`. `left` is processed
/// before `right`.
pub(crate) fn align_chunks_binary_owned<T, B>(
    left: ChunkedArray<T>,
    right: ChunkedArray<B>,
) -> (ChunkedArray<T>, ChunkedArray<B>)
where
    ChunkedArray<B>: ChunkOps,
    ChunkedArray<T>: ChunkOps,
    B: PolarsDataType,
    T: PolarsDataType,
{
    let left = if left.chunks.len() == 1 {
        left
    } else {
        left.rechunk()
    };
    let right = if right.chunks.len() == 1 {
        right
    } else {
        right.rechunk()
    };
    (left, right)
}

#[allow(clippy::type_complexity)]
pub(crate) fn align_chunks_ternary<'a, A, B, C>(
a: &'a ChunkedArray<A>,
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ private = ["polars-time/private"]
[dependencies]
ahash = "0.7"
anyhow = "1.0"
arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "39db6fb7514364bfea08d594793b23e1ed5a7def", default-features = false }
# arrow = { package = "arrow2", git = "https://github.com/ritchie46/arrow2", branch = "count_shared", default-features = false }
# arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "39db6fb7514364bfea08d594793b23e1ed5a7def", default-features = false }
arrow = { package = "arrow2", git = "https://github.com/ritchie46/arrow2", branch = "arity_assign", default-features = false }
# arrow = { package = "arrow2", version = "0.12", default-features = false }
# arrow = { package = "arrow2", path = "../../../arrow2", default-features = false }
csv-core = { version = "0.1.10", optional = true }
Expand Down

0 comments on commit 74081ac

Please sign in to comment.