Skip to content

Commit

Permalink
Fix decreasing interpolation of unsigned Series (#2552)
Browse files Browse the repository at this point in the history
  • Loading branch information
qiemem committed Feb 5, 2022
1 parent 483da28 commit 48ea1ed
Showing 1 changed file with 75 additions and 11 deletions.
86 changes: 75 additions & 11 deletions polars/polars-core/src/chunked_array/ops/interpolate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,47 @@ where
low + step * diff / steps_n
}

impl<T> Interpolate for ChunkedArray<T>
where
T: PolarsNumericType,
{
fn interpolate(&self) -> Self {
#[inline]
fn signed_interp<T: PolarsNumericType>(
low: T::Native,
high: T::Native,
steps: u32,
steps_n: T::Native,
av: &mut Vec<T::Native>,
) {
let diff = high - low;
for step_i in 1..steps {
let step_i = T::Native::from_u32(step_i).unwrap();
let v = linear_itp(low, step_i, diff, steps_n);
av.push(v)
}
}

#[inline]
fn unsigned_interp<T: PolarsNumericType>(
low: T::Native,
high: T::Native,
steps: u32,
steps_n: T::Native,
av: &mut Vec<T::Native>,
) {
if high >= low {
signed_interp::<T>(low, high, steps, steps_n, av)
} else {
let diff = low - high;
for step_i in (1..steps).rev() {
let step_i = T::Native::from_u32(step_i).unwrap();
let v = linear_itp(high, step_i, diff, steps_n);
av.push(v)
}
}
}

impl<T: PolarsNumericType> ChunkedArray<T> {
fn interpolate_impl<I>(&self, interpolation_branch: I) -> Self
where
I: Fn(T::Native, T::Native, u32, T::Native, &mut Vec<T::Native>),
{
// This implementation differs from pandas as that boundary None's are not removed
// this prevents a lot of errors due to expressions leading to different lengths
if !self.has_validity() || self.null_count() == self.len() {
Expand Down Expand Up @@ -71,13 +107,8 @@ where
// another null
Some(None) => {}
Some(Some(high)) => {
let diff = high - low;
let steps_n = T::Native::from_u32(steps).unwrap();
for step_i in 1..steps {
let step_i = T::Native::from_u32(step_i).unwrap();
let v = linear_itp(low, step_i, diff, steps_n);
av.push(v)
}
interpolation_branch(low, high, steps, steps_n, &mut av);
av.push(high);
low_val = Some(high);
break;
Expand Down Expand Up @@ -117,6 +148,32 @@ where
}
}

macro_rules! impl_interpolate {
($type:ident, $interpolation_branch:ident) => {
impl Interpolate for ChunkedArray<$type> {
fn interpolate(&self) -> Self {
self.interpolate_impl($interpolation_branch::<$type>)
}
}
};
}

#[cfg(feature = "dtype-u8")]
impl_interpolate!(UInt8Type, unsigned_interp);
#[cfg(feature = "dtype-u16")]
impl_interpolate!(UInt16Type, unsigned_interp);
impl_interpolate!(UInt32Type, unsigned_interp);
impl_interpolate!(UInt64Type, unsigned_interp);

#[cfg(feature = "dtype-i8")]
impl_interpolate!(Int8Type, signed_interp);
#[cfg(feature = "dtype-i16")]
impl_interpolate!(Int16Type, signed_interp);
impl_interpolate!(Int32Type, signed_interp);
impl_interpolate!(Int64Type, signed_interp);
impl_interpolate!(Float32Type, signed_interp);
impl_interpolate!(Float64Type, signed_interp);

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -145,6 +202,13 @@ mod test {
);
}

#[test]
fn test_interpolate_decreasing_unsigned() {
let ca = UInt32Chunked::new("", &[Some(4), None, None, Some(1)]);
let out = ca.interpolate();
assert_eq!(Vec::from(&out), &[Some(4), Some(3), Some(2), Some(1)])
}

#[test]
fn test_interpolate2() {
let ca = Float32Chunked::new(
Expand Down

0 comments on commit 48ea1ed

Please sign in to comment.