Skip to content

Commit

Permalink
feat[rust, python]: clip_min,clip_max exprs (#4534)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 22, 2022
1 parent 133f22d commit fbcbc90
Show file tree
Hide file tree
Showing 13 changed files with 451 additions and 44 deletions.
238 changes: 232 additions & 6 deletions polars/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ use arrow::types::NativeType;
use num::{Bounded, FromPrimitive, Num, NumCast, Zero};
use polars_arrow::data_types::IsFloat;
#[cfg(feature = "serde")]
use serde::de::{EnumAccess, Error, Unexpected, VariantAccess, Visitor};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")]
use serde::{Deserializer, Serializer};

pub use crate::chunked_array::logical::*;
#[cfg(feature = "object")]
Expand Down Expand Up @@ -231,7 +235,6 @@ impl PolarsFloatType for Float32Type {}
impl PolarsFloatType for Float64Type {}

#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum AnyValue<'a> {
Null,
/// A binary true or false.
Expand Down Expand Up @@ -265,7 +268,6 @@ pub enum AnyValue<'a> {
/// A 64-bit date representing the elapsed time since UNIX epoch (1970-01-01)
/// in nanoseconds (64 bits).
#[cfg(feature = "dtype-datetime")]
#[cfg_attr(feature = "serde", serde(skip))]
Datetime(i64, TimeUnit, &'a Option<TimeZone>),
// A 64-bit integer representing difference between date-times in [`TimeUnit`]
#[cfg(feature = "dtype-duration")]
Expand All @@ -274,24 +276,248 @@ pub enum AnyValue<'a> {
#[cfg(feature = "dtype-time")]
Time(i64),
#[cfg(feature = "dtype-categorical")]
#[cfg_attr(feature = "serde", serde(skip))]
Categorical(u32, &'a RevMapping),
/// Nested type, contains arrays that are filled with one of the datetypes.
List(Series),
#[cfg(feature = "object")]
/// Can be used to fmt and implements Any, so can be downcasted to the proper value type.
#[cfg_attr(feature = "serde", serde(skip))]
Object(&'a dyn PolarsObjectSafe),
#[cfg(feature = "dtype-struct")]
#[cfg_attr(feature = "serde", serde(skip))]
Struct(Vec<AnyValue<'a>>, &'a [Field]),
#[cfg(feature = "dtype-struct")]
#[cfg_attr(feature = "serde", serde(skip))]
StructOwned(Box<(Vec<AnyValue<'a>>, Vec<Field>)>),
/// A UTF8 encoded string type.
Utf8Owned(String),
}

#[cfg(feature = "serde")]
impl Serialize for AnyValue<'_> {
    /// Serializes the value as a variant of an enum named `AnyValue`.
    ///
    /// `Null` is written as a unit variant; every other supported variant is
    /// written as a newtype variant wrapping its payload. `Utf8` and
    /// `Utf8Owned` share index 13 and the name `"Utf8Owned"`, so both produce
    /// the same output.
    ///
    /// # Errors
    /// Returns a serializer error for any variant not listed in the match
    /// below, and propagates any error raised by `serializer`.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let name = "AnyValue";
        match self {
            AnyValue::Null => serializer.serialize_unit_variant(name, 0, "Null"),
            AnyValue::Int8(v) => serializer.serialize_newtype_variant(name, 1, "Int8", v),
            AnyValue::Int16(v) => serializer.serialize_newtype_variant(name, 2, "Int16", v),
            AnyValue::Int32(v) => serializer.serialize_newtype_variant(name, 3, "Int32", v),
            AnyValue::Int64(v) => serializer.serialize_newtype_variant(name, 4, "Int64", v),
            AnyValue::UInt8(v) => serializer.serialize_newtype_variant(name, 5, "UInt8", v),
            AnyValue::UInt16(v) => serializer.serialize_newtype_variant(name, 6, "UInt16", v),
            AnyValue::UInt32(v) => serializer.serialize_newtype_variant(name, 7, "UInt32", v),
            AnyValue::UInt64(v) => serializer.serialize_newtype_variant(name, 8, "UInt64", v),
            AnyValue::Float32(v) => serializer.serialize_newtype_variant(name, 9, "Float32", v),
            AnyValue::Float64(v) => serializer.serialize_newtype_variant(name, 10, "Float64", v),
            AnyValue::List(v) => serializer.serialize_newtype_variant(name, 11, "List", v),
            AnyValue::Boolean(v) => serializer.serialize_newtype_variant(name, 12, "Bool", v),
            // Both utf8 variants share the same index and name.
            AnyValue::Utf8(v) => serializer.serialize_newtype_variant(name, 13, "Utf8Owned", v),
            AnyValue::Utf8Owned(v) => {
                serializer.serialize_newtype_variant(name, 13, "Utf8Owned", v)
            }
            // Report unsupported variants to the caller instead of panicking.
            _ => Err(<S::Error as serde::ser::Error>::custom(
                "serialization of this AnyValue variant is not supported",
            )),
        }
    }
}

#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for AnyValue<'static> {
    /// Deserializes an owned `AnyValue` from an enum named `AnyValue`.
    ///
    /// The variant may be identified by its integer index (0 through the
    /// index of `Utf8Owned`) or by name. The name `"Utf8"` is accepted as an
    /// alias of `"Utf8Owned"`; string data is always returned as
    /// `AnyValue::Utf8Owned`.
    ///
    /// # Errors
    /// Returns a deserializer error when the variant identifier is unknown or
    /// out of range, or when the variant payload cannot be deserialized.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        /// Tag of each supported variant; the discriminant is the variant index.
        #[repr(u8)]
        #[derive(Clone, Copy)]
        enum AvField {
            Null,
            Int8,
            Int16,
            Int32,
            Int64,
            UInt8,
            UInt16,
            UInt32,
            UInt64,
            Float32,
            Float64,
            List,
            Bool,
            Utf8Owned,
        }

        impl AvField {
            /// Maps a variant index onto its tag; `None` when out of range.
            fn from_index(idx: u8) -> Option<Self> {
                // Must list every tag in discriminant order.
                const FIELDS: [AvField; 14] = [
                    AvField::Null,
                    AvField::Int8,
                    AvField::Int16,
                    AvField::Int32,
                    AvField::Int64,
                    AvField::UInt8,
                    AvField::UInt16,
                    AvField::UInt32,
                    AvField::UInt64,
                    AvField::Float32,
                    AvField::Float64,
                    AvField::List,
                    AvField::Bool,
                    AvField::Utf8Owned,
                ];
                FIELDS.get(usize::from(idx)).copied()
            }
        }

        // Same order as `AvField`, using the names accepted by `visit_bytes`.
        const VARIANTS: &[&str] = &[
            "Null",
            "Int8",
            "Int16",
            "Int32",
            "Int64",
            "UInt8",
            "UInt16",
            "UInt32",
            "UInt64",
            "Float32",
            "Float64",
            "List",
            "Bool",
            "Utf8Owned",
        ];
        const LAST: u8 = AvField::Utf8Owned as u8;

        struct FieldVisitor;

        impl Visitor<'_> for FieldVisitor {
            type Value = AvField;

            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
                write!(formatter, "an integer between 0-{}", LAST)
            }

            fn visit_i64<E>(self, v: i64) -> std::result::Result<Self::Value, E>
            where
                E: Error,
            {
                let idx: u8 = NumCast::from(v).ok_or_else(|| {
                    E::invalid_value(
                        Unexpected::Signed(v),
                        &"expected value that fits into u8",
                    )
                })?;
                AvField::from_index(idx).ok_or_else(|| {
                    E::invalid_value(
                        Unexpected::Signed(v),
                        &"expected value that fits into AnyValue's number of fields",
                    )
                })
            }

            // Without this, serde's default `visit_u64` rejects the input, so
            // variant indices reported as unsigned integers could not be read.
            fn visit_u64<E>(self, v: u64) -> std::result::Result<Self::Value, E>
            where
                E: Error,
            {
                <u8 as NumCast>::from(v)
                    .and_then(AvField::from_index)
                    .ok_or_else(|| {
                        E::invalid_value(
                            Unexpected::Unsigned(v),
                            &"expected value that fits into AnyValue's number of fields",
                        )
                    })
            }

            fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
            where
                E: Error,
            {
                self.visit_bytes(v.as_bytes())
            }

            fn visit_bytes<E>(self, v: &[u8]) -> std::result::Result<Self::Value, E>
            where
                E: Error,
            {
                let field = match v {
                    b"Null" => AvField::Null,
                    b"Int8" => AvField::Int8,
                    b"Int16" => AvField::Int16,
                    b"Int32" => AvField::Int32,
                    b"Int64" => AvField::Int64,
                    b"UInt8" => AvField::UInt8,
                    b"UInt16" => AvField::UInt16,
                    b"UInt32" => AvField::UInt32,
                    b"UInt64" => AvField::UInt64,
                    b"Float32" => AvField::Float32,
                    b"Float64" => AvField::Float64,
                    b"List" => AvField::List,
                    b"Bool" => AvField::Bool,
                    b"Utf8Owned" | b"Utf8" => AvField::Utf8Owned,
                    _ => {
                        return Err(E::unknown_variant(
                            &String::from_utf8_lossy(v),
                            VARIANTS,
                        ))
                    }
                };
                Ok(field)
            }
        }

        impl<'a> Deserialize<'a> for AvField {
            fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
            where
                D: Deserializer<'a>,
            {
                deserializer.deserialize_identifier(FieldVisitor)
            }
        }

        struct OuterVisitor;

        impl<'b> Visitor<'b> for OuterVisitor {
            type Value = AnyValue<'static>;

            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
                write!(formatter, "enum AnyValue")
            }

            fn visit_enum<A>(self, data: A) -> std::result::Result<Self::Value, A::Error>
            where
                A: EnumAccess<'b>,
            {
                let out = match data.variant()? {
                    (AvField::Null, variant) => {
                        // Consume the (empty) payload so the input is fully read.
                        variant.unit_variant()?;
                        AnyValue::Null
                    }
                    (AvField::Int8, variant) => AnyValue::Int8(variant.newtype_variant()?),
                    (AvField::Int16, variant) => AnyValue::Int16(variant.newtype_variant()?),
                    (AvField::Int32, variant) => AnyValue::Int32(variant.newtype_variant()?),
                    (AvField::Int64, variant) => AnyValue::Int64(variant.newtype_variant()?),
                    (AvField::UInt8, variant) => AnyValue::UInt8(variant.newtype_variant()?),
                    (AvField::UInt16, variant) => AnyValue::UInt16(variant.newtype_variant()?),
                    (AvField::UInt32, variant) => AnyValue::UInt32(variant.newtype_variant()?),
                    (AvField::UInt64, variant) => AnyValue::UInt64(variant.newtype_variant()?),
                    (AvField::Float32, variant) => AnyValue::Float32(variant.newtype_variant()?),
                    (AvField::Float64, variant) => AnyValue::Float64(variant.newtype_variant()?),
                    (AvField::Bool, variant) => AnyValue::Boolean(variant.newtype_variant()?),
                    (AvField::List, variant) => AnyValue::List(variant.newtype_variant()?),
                    (AvField::Utf8Owned, variant) => {
                        AnyValue::Utf8Owned(variant.newtype_variant()?)
                    }
                };
                Ok(out)
            }
        }
        deserializer.deserialize_enum("AnyValue", VARIANTS, OuterVisitor)
    }
}

impl<'a> AnyValue<'a> {
/// Extract a numerical value from the AnyValue
#[doc(hidden)]
Expand Down
14 changes: 1 addition & 13 deletions polars/polars-core/src/frame/asof_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@ use std::borrow::Cow;
use asof::*;
use num::Bounded;
#[cfg(feature = "serde")]
use serde::Deserializer;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::prelude::*;
use crate::utils::slice_slice;

#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize))]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AsOfOptions {
pub strategy: AsofStrategy,
/// A tolerance in the same unit as the asof column
Expand All @@ -29,16 +27,6 @@ pub struct AsOfOptions {
pub right_by: Option<Vec<String>>,
}

#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for AsOfOptions {
    /// Placeholder implementation: deserializing `AsOfOptions` is not
    /// implemented here, so every call fails.
    ///
    /// # Errors
    /// Always returns a deserializer error, so callers get a recoverable
    /// failure instead of a panic.
    // NOTE(review): `AsOfOptions` also appears with `derive(Serialize, Deserialize)`
    // in this file; a derived impl and this manual impl cannot coexist — confirm
    // which one is intended and drop the other.
    fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'a>,
    {
        Err(<D::Error as serde::de::Error>::custom(
            "deserialization of AsOfOptions is not supported",
        ))
    }
}

fn check_asof_columns(a: &Series, b: &Series) -> Result<()> {
if a.dtype() != b.dtype() {
return Err(PolarsError::ComputeError(
Expand Down
65 changes: 59 additions & 6 deletions polars/polars-core/src/series/ops/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,72 @@ impl Series {
}

#[cfg_attr(docsrs, doc(cfg(feature = "round_series")))]
/// Ceil underlying floating point array to the highest integers smaller or equal to the float value.
pub fn clip(mut self, min: f64, max: f64) -> Result<Self> {
/// Clamp underlying values to the `min` and `max` values.
pub fn clip(mut self, min: AnyValue<'_>, max: AnyValue<'_>) -> Result<Self> {
if self.dtype().is_numeric() {
macro_rules! apply_clip {
($pl_type:ty, $ca:expr, $min:expr, $max: expr) => {{
let min = min as <$pl_type as PolarsNumericType>::Native;
let max = max as <$pl_type as PolarsNumericType>::Native;
($pl_type:ty, $ca:expr) => {{
let min = min
.extract::<<$pl_type as PolarsNumericType>::Native>()
.unwrap();
let max = max
.extract::<<$pl_type as PolarsNumericType>::Native>()
.unwrap();

$ca.apply_mut(|val| val.clamp(min, max));
}};
}
let mutable = self._get_inner_mut();
downcast_as_macro_arg_physical_mut!(mutable, apply_clip, min, max);
downcast_as_macro_arg_physical_mut!(mutable, apply_clip);
Ok(self)
} else {
Err(PolarsError::SchemaMisMatch(
format!("Cannot use 'clip' on dtype {:?}, consider using a when -> then -> otherwise expression", self.dtype()).into(),
))
}
}

#[cfg_attr(docsrs, doc(cfg(feature = "round_series")))]
/// Clamp underlying values to the `max` value.
///
/// The bound is converted to the native type of the Series, then every value
/// is clamped in place with `num::traits::clamp_max`.
///
/// # Errors
/// * `SchemaMisMatch` when the Series dtype is not numeric.
/// * `ComputeError` when `max` cannot be converted to the native type of the
///   Series.
pub fn clip_max(mut self, max: AnyValue<'_>) -> Result<Self> {
    use num::traits::clamp_max;
    if self.dtype().is_numeric() {
        macro_rules! apply_clip {
            ($pl_type:ty, $ca:expr) => {{
                // Surface an unconvertible bound as an error instead of panicking.
                let upper = max
                    .extract::<<$pl_type as PolarsNumericType>::Native>()
                    .ok_or_else(|| {
                        PolarsError::ComputeError(
                            format!("cannot convert 'max' value {:?} to the dtype of the Series", max)
                                .into(),
                        )
                    })?;

                $ca.apply_mut(|val| clamp_max(val, upper));
            }};
        }
        let mutable = self._get_inner_mut();
        downcast_as_macro_arg_physical_mut!(mutable, apply_clip);
        Ok(self)
    } else {
        Err(PolarsError::SchemaMisMatch(
            format!("Cannot use 'clip' on dtype {:?}, consider using a when -> then -> otherwise expression", self.dtype()).into(),
        ))
    }
}

#[cfg_attr(docsrs, doc(cfg(feature = "round_series")))]
/// Clamp underlying values to the `min` value.
pub fn clip_min(mut self, min: AnyValue<'_>) -> Result<Self> {
use num::traits::clamp_min;

if self.dtype().is_numeric() {
macro_rules! apply_clip {
($pl_type:ty, $ca:expr) => {{
let min = min
.extract::<<$pl_type as PolarsNumericType>::Native>()
.unwrap();

$ca.apply_mut(|val| clamp_min(val, min));
}};
}
let mutable = self._get_inner_mut();
downcast_as_macro_arg_physical_mut!(mutable, apply_clip);
Ok(self)
} else {
Err(PolarsError::SchemaMisMatch(
Expand Down

0 comments on commit fbcbc90

Please sign in to comment.