Skip to content

Commit

Permalink
python: don't lock gil in arr.contains (#4210)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 2, 2022
1 parent 95074d9 commit 53edb19
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 3 deletions.
12 changes: 12 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use super::*;

#[cfg(feature = "is_in")]
pub(super) fn contains(args: &mut [Series]) -> Result<Series> {
let list = &args[0];
let is_in = &args[1];

is_in.is_in(list).map(|mut ca| {
ca.rename(list.name());
ca.into_series()
})
}
9 changes: 9 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod arg_where;
mod fill_null;
#[cfg(feature = "is_in")]
mod is_in;
mod list;
mod pow;
#[cfg(feature = "row_hash")]
mod row_hash;
Expand Down Expand Up @@ -49,6 +50,8 @@ pub enum FunctionExpr {
FillNull {
super_type: DataType,
},
#[cfg(feature = "is_in")]
ListContains,
}

#[cfg(feature = "trigonometry")]
Expand Down Expand Up @@ -112,6 +115,8 @@ impl FunctionExpr {
#[cfg(feature = "sign")]
Sign => with_dtype(DataType::Int64),
FillNull { super_type, .. } => with_dtype(super_type.clone()),
#[cfg(feature = "is_in")]
ListContains => with_dtype(DataType::Boolean),
}
}
}
Expand Down Expand Up @@ -224,6 +229,10 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
FillNull { super_type } => {
map_as_slice!(fill_null::fill_null, &super_type)
}
#[cfg(feature = "is_in")]
ListContains => {
wrap!(list::contains)
}
}
}
}
21 changes: 21 additions & 0 deletions polars/polars-lazy/src/dsl/list.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::dsl::eval::prepare_eval_expr;
use crate::dsl::function_expr::FunctionExpr;
use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;
use parking_lot::Mutex;
Expand Down Expand Up @@ -212,6 +213,7 @@ impl ListNameSpace {

/// Run any [`Expr`] on these lists elements
#[cfg(feature = "list_eval")]
#[cfg_attr(docsrs, doc(cfg(feature = "list_eval")))]
pub fn eval(self, expr: Expr, parallel: bool) -> Expr {
let expr = prepare_eval_expr(expr);

Expand Down Expand Up @@ -312,6 +314,7 @@ impl ListNameSpace {
}

#[cfg(feature = "list_to_struct")]
#[cfg_attr(docsrs, doc(cfg(feature = "list_to_struct")))]
#[allow(clippy::wrong_self_convention)]
/// Convert this `List` to a `Series` of type `Struct`. The width will be determined according to
/// `ListToStructWidthStrategy` and the names of the fields determined by the given `name_generator`.
Expand All @@ -332,4 +335,22 @@ impl ListNameSpace {
)
.with_fmt("arr.to_struct")
}

#[cfg(feature = "is_in")]
#[cfg_attr(docsrs, doc(cfg(feature = "is_in")))]
/// Check if the list array contain an element
pub fn contains<E: Into<Expr>>(self, other: E) -> Expr {
let other = other.into();

Expr::Function {
input: vec![self.0, other],
function: FunctionExpr::ListContains,
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
auto_explode: true,
fmt_str: "arr.contains",
},
}
}
}
7 changes: 4 additions & 3 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from polars import internals as pli
from polars.datatypes import (
DTYPE_TEMPORAL_UNITS,
Boolean,
DataType,
Date,
Datetime,
Expand Down Expand Up @@ -5353,7 +5352,7 @@ def last(self) -> Expr:
"""
return self.get(-1)

def contains(self, item: float | str | bool | int | date | datetime) -> Expr:
def contains(self, item: float | str | bool | int | date | datetime | Expr) -> Expr:
"""
Check if sublists contain the given item.
Expand Down Expand Up @@ -5384,7 +5383,7 @@ def contains(self, item: float | str | bool | int | date | datetime) -> Expr:
└───────┘
"""
return wrap_expr(self._pyexpr).map(lambda s: s.arr.contains(item), Boolean)
return wrap_expr(self._pyexpr.arr_contains(expr_to_lit_or_expr(item)._pyexpr))

def join(self, separator: str) -> Expr:
"""
Expand Down Expand Up @@ -7114,6 +7113,8 @@ def expr_to_lit_or_expr(
| str
| pli.Series
| None
| date
| datetime
| Sequence[(int | float | str | None)]
),
str_to_lit: bool = True,
Expand Down
4 changes: 4 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,10 @@ impl PyExpr {
self.inner.clone().arr().lengths().into()
}

pub fn arr_contains(&self, other: PyExpr) -> PyExpr {
self.inner.clone().arr().contains(other.inner).into()
}

pub fn year(&self) -> PyExpr {
self.clone().inner.dt().year().into()
}
Expand Down

0 comments on commit 53edb19

Please sign in to comment.