Skip to content

Commit

Permalink
feat[rust, python]: struct access ergonomics (#4570)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 25, 2022
1 parent 30d6270 commit 4d7640d
Show file tree
Hide file tree
Showing 14 changed files with 277 additions and 214 deletions.
40 changes: 40 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ mod shift_and_fill;
mod sign;
#[cfg(feature = "strings")]
mod strings;
#[cfg(feature = "dtype-struct")]
mod struct_;
#[cfg(any(feature = "temporal", feature = "date_offset"))]
mod temporal;
#[cfg(feature = "trigonometry")]
Expand All @@ -28,12 +30,15 @@ mod trigonometry;
#[cfg(feature = "list")]
pub(super) use list::ListFunction;
use polars_core::prelude::*;
use polars_core::utils::slice_offsets;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

pub(super) use self::nan::NanFunction;
#[cfg(feature = "strings")]
pub(super) use self::strings::StringFunction;
#[cfg(feature = "dtype-struct")]
pub(super) use self::struct_::StructFunction;
use super::*;

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down Expand Up @@ -79,6 +84,8 @@ pub enum FunctionExpr {
},
#[cfg(feature = "list")]
ListExpr(ListFunction),
#[cfg(feature = "dtype-struct")]
StructExpr(StructFunction),
}

#[cfg(feature = "trigonometry")]
Expand Down Expand Up @@ -226,6 +233,31 @@ impl FunctionExpr {
Concat => inner_super_type_list(),
}
}
#[cfg(feature = "dtype-struct")]
StructExpr(s) => {
use StructFunction::*;
match s {
FieldByIndex(index) => {
let (index, _) = slice_offsets(*index, 0, fields.len());
fields.get(index).cloned().ok_or_else(|| {
PolarsError::ComputeError(
"index out of bounds in 'struct.field'".into(),
)
})
}
FieldByName(name) => {
if let DataType::Struct(flds) = &fields[0].dtype {
let fld = flds
.iter()
.find(|fld| fld.name() == name.as_ref())
.ok_or_else(|| PolarsError::NotFound(name.as_ref().to_string()))?;
Ok(fld.clone())
} else {
Err(PolarsError::NotFound(name.as_ref().to_string()))
}
}
}
}
}
}
}
Expand Down Expand Up @@ -369,6 +401,14 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
Concat => wrap!(list::concat),
}
}
#[cfg(feature = "dtype-struct")]
StructExpr(sf) => {
use StructFunction::*;
match sf {
FieldByIndex(index) => map_with_args!(struct_::get_by_index, index),
FieldByName(name) => map_with_args!(struct_::get_by_name, name.clone()),
}
}
}
}
}
Expand Down
23 changes: 23 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/struct_.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use polars_core::utils::slice_offsets;

use super::*;

/// Operations that pull a single field out of a struct column.
///
/// Each variant carries the selector used to pick the field.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StructFunction {
    /// Select a field by its position among the struct's fields.
    FieldByIndex(i64),
    /// Select a field by its name.
    FieldByName(Arc<str>),
}

/// Returns a clone of the field found at `index` inside the struct series `s`.
///
/// The raw `index` is first normalised by `slice_offsets` against the number
/// of fields, and the resulting offset is used for the lookup.
/// NOTE(review): presumably this lets negative indices count from the end;
/// confirm against `slice_offsets`.
///
/// # Errors
/// Propagates the error from `Series::struct_`, and returns
/// `PolarsError::ComputeError` when no field exists at the resolved offset.
pub(super) fn get_by_index(s: &Series, index: i64) -> Result<Series> {
    let struct_ca = s.struct_()?;
    let fields = struct_ca.fields();
    let (position, _) = slice_offsets(index, 0, fields.len());
    match fields.get(position) {
        Some(field) => Ok(field.clone()),
        None => Err(PolarsError::ComputeError(
            "index out of bounds in 'struct.field'".into(),
        )),
    }
}
/// Returns the field called `name` from the struct series `s`.
///
/// # Errors
/// Propagates the error from `Series::struct_`, and whatever error
/// `field_by_name` reports for the given name.
pub(super) fn get_by_name(s: &Series, name: Arc<str>) -> Result<Series> {
    s.struct_()
        .and_then(|struct_ca| struct_ca.field_by_name(name.as_ref()))
}
20 changes: 7 additions & 13 deletions polars/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,8 @@ pub fn argsort_by<E: AsRef<[Expr]>>(by: E, reverse: &[bool]) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: true,
auto_explode: false,
fmt_str: "argsort_by",
cast_to_supertypes: false,
..Default::default()
},
}
}
Expand All @@ -245,7 +244,7 @@ pub fn concat_str<E: AsRef<[Expr]>>(s: E, sep: &str) -> Expr {
input_wildcard_expansion: true,
auto_explode: true,
fmt_str: "concat_by",
cast_to_supertypes: false,
..Default::default()
},
}
}
Expand All @@ -262,9 +261,8 @@ pub fn concat_lst<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(s: E) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
auto_explode: false,
fmt_str: "concat_list",
cast_to_supertypes: false,
..Default::default()
},
}
}
Expand Down Expand Up @@ -450,9 +448,8 @@ pub fn datetime(args: DatetimeArgs) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
auto_explode: false,
fmt_str: "datetime",
cast_to_supertypes: false,
..Default::default()
},
}
.alias("datetime")
Expand Down Expand Up @@ -528,9 +525,8 @@ pub fn duration(args: DurationArgs) -> Expr {
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
auto_explode: false,
fmt_str: "duration",
cast_to_supertypes: false,
..Default::default()
},
}
.alias("duration")
Expand Down Expand Up @@ -731,7 +727,7 @@ where
input_wildcard_expansion: true,
auto_explode: true,
fmt_str: "",
cast_to_supertypes: false,
..Default::default()
},
}
} else {
Expand Down Expand Up @@ -904,10 +900,8 @@ pub fn arg_where<E: Into<Expr>>(condition: E) -> Expr {
function: FunctionExpr::ArgWhere,
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: false,
auto_explode: false,
fmt_str: "arg_where",
cast_to_supertypes: false,
..Default::default()
},
}
}
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ impl ListNameSpace {
input_wildcard_expansion: true,
auto_explode: true,
fmt_str: "arr.contains",
cast_to_supertypes: false,
..Default::default()
},
}
}
Expand Down

0 comments on commit 4d7640d

Please sign in to comment.