Skip to content

Commit

Permalink
Lazy; select columns by dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 25, 2021
1 parent a911dbe commit e1a60a9
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 22 deletions.
13 changes: 13 additions & 0 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ pub enum Expr {
Alias(Box<Expr>, Arc<String>),
Column(Arc<String>),
Columns(Vec<String>),
DtypeColumn(Vec<DataType>),
Literal(LiteralValue),
BinaryExpr {
left: Box<Expr>,
Expand Down Expand Up @@ -442,6 +443,7 @@ impl fmt::Debug for Expr {
KeepName(e) => write!(f, "KEEP NAME {:?}", e),
SufPreFix { expr, .. } => write!(f, "SUF-PREFIX {:?}", expr),
Columns(names) => write!(f, "COLUMNS({:?})", names),
DtypeColumn(dt) => write!(f, "COLUMN OF DTYPE: {:?}", dt),
}
}
}
Expand Down Expand Up @@ -1675,6 +1677,17 @@ pub fn cols(names: Vec<String>) -> Expr {
Expr::Columns(names)
}

/// Select every column that has the given dtype.
///
/// Builds an [`Expr::DtypeColumn`] holding a single-element list with a clone
/// of `dtype`.
pub fn dtype_col(dtype: &DataType) -> Expr {
    let dtypes = vec![dtype.clone()];
    Expr::DtypeColumn(dtypes)
}

/// Select every column whose dtype is one of the given dtypes.
///
/// Accepts anything that can be viewed as a slice of [`DataType`] (array,
/// `Vec`, slice); the dtypes are copied into the resulting
/// [`Expr::DtypeColumn`].
pub fn dtype_cols<DT: AsRef<[DataType]>>(dtype: DT) -> Expr {
    Expr::DtypeColumn(dtype.as_ref().to_vec())
}

/// Count the number of values in this Expression.
pub fn count(name: &str) -> Expr {
match name {
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/src/logical_plan/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ pub(crate) fn to_aexpr(expr: Expr, arena: &mut Arena<AExpr>) -> Node {
Expr::Exclude(_, _) => panic!("no exclude expected at this point"),
Expr::SufPreFix { .. } => panic!("no `suffix/prefix` expected at this point"),
Expr::Columns { .. } => panic!("no `columns` expected at this point"),
Expr::DtypeColumn { .. } => panic!("no `dtype-columns` expected at this point"),
};
arena.add(v)
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-lazy/src/logical_plan/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ macro_rules! push_expr {
($current_expr:expr, $push:ident, $iter:ident) => {{
use Expr::*;
match $current_expr {
Column(_) | Literal(_) | Wildcard | Columns(_) => {}
Column(_) | Literal(_) | Wildcard | Columns(_) | DtypeColumn(_) => {}
Alias(e, _) => $push(e),
Not(e) => $push(e),
BinaryExpr { left, op: _, right } => {
Expand Down
57 changes: 39 additions & 18 deletions polars/polars-lazy/src/logical_plan/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ fn replace_regex(expr: &Expr, result: &mut Vec<Expr>, schema: &Schema) {
}
}

/// replace columns(["A", "B"]).. with col("A").., col("B")..
/// replace `columns(["A", "B"])..` with `col("A")..`, `col("B")..`
fn expand_columns(expr: &Expr, result: &mut Vec<Expr>, names: &[String]) {
for name in names {
let mut new_expr = expr.clone();
Expand All @@ -132,7 +132,28 @@ fn expand_columns(expr: &Expr, result: &mut Vec<Expr>, names: &[String]) {
}
}

fn prepare_exluded(expr: &Expr, schema: &Schema) -> Vec<Arc<String>> {
/// Replace a `DtypeColumn` expression with one `col("name")` expression per
/// matching schema field.
///
/// For every dtype in `dtypes` (in the order given) and every field of `schema`
/// whose data type equals it, a clone of `expr` is made in which each
/// `Expr::DtypeColumn` node is swapped for `Expr::Column(<field name>)`. The
/// clone is then passed through `rewrite_keep_name_and_sufprefix` and pushed
/// onto `result`. Results are therefore grouped by dtype first, and by schema
/// order within each dtype.
fn expand_dtypes(expr: &Expr, result: &mut Vec<Expr>, schema: &Schema, dtypes: &[DataType]) {
    for dtype in dtypes {
        let matching_fields = schema
            .fields()
            .iter()
            .filter(|fld| fld.data_type() == dtype);

        for fld in matching_fields {
            let column_name = fld.name();
            let mut expanded = expr.clone();

            expanded.mutate().apply(|node| {
                if matches!(node, Expr::DtypeColumn(_)) {
                    *node = Expr::Column(Arc::new(column_name.clone()));
                }
                // returning `true` keeps the traversal going over all inputs
                true
            });

            result.push(rewrite_keep_name_and_sufprefix(expanded));
        }
    }
}

fn prepare_excluded(expr: &Expr, schema: &Schema) -> Vec<Arc<String>> {
let mut exclude = vec![];
expr.into_iter().for_each(|e| {
if let Expr::Exclude(_, names) = e {
Expand Down Expand Up @@ -168,20 +189,22 @@ pub(crate) fn rewrite_projections(exprs: Vec<Expr>, schema: &Schema) -> Vec<Expr
let mut result = Vec::with_capacity(exprs.len() + schema.fields().len());

for mut expr in exprs {
// in case of multiple cols, we still want to check wildcard for function input,
// but in case of no wildcard, we don't want this expr pushed to results.
let mut push_current = true;
// has multiple column names
if let Some(e) = expr.into_iter().find(|e| matches!(e, Expr::Columns(_))) {
if let Some(e) = expr
.into_iter()
.find(|e| matches!(e, Expr::Columns(_) | Expr::DtypeColumn(_)))
{
if let Expr::Columns(names) = e {
expand_columns(&expr, &mut result, names)
} else if let Expr::DtypeColumn(dtypes) = e {
expand_dtypes(&expr, &mut result, schema, dtypes)
}
push_current = false;
continue;
}

if has_wildcard(&expr) {
// keep track of column excluded from the wildcard
let exclude = prepare_exluded(&expr, schema);
let exclude = prepare_excluded(&expr, schema);

// if count wildcard. count one column
if has_expr(&expr, |e| matches!(e, Expr::Agg(AggExpr::Count(_)))) {
Expand Down Expand Up @@ -233,16 +256,14 @@ pub(crate) fn rewrite_projections(exprs: Vec<Expr>, schema: &Schema) -> Vec<Expr
replace_wilcard(&expr, &mut result, &exclude, schema);
} else {
#[allow(clippy::collapsible_else_if)]
if push_current {
#[cfg(feature = "regex")]
{
replace_regex(&expr, &mut result, schema)
}
#[cfg(not(feature = "regex"))]
{
let expr = rewrite_keep_name_and_sufprefix(expr);
result.push(expr)
}
#[cfg(feature = "regex")]
{
replace_regex(&expr, &mut result, schema)
}
#[cfg(not(feature = "regex"))]
{
let expr = rewrite_keep_name_and_sufprefix(expr);
result.push(expr)
}
};
}
Expand Down
17 changes: 17 additions & 0 deletions polars/polars-lazy/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1982,3 +1982,20 @@ fn test_apply_multiple_columns() -> Result<()> {
assert_eq!(Vec::from(out), &[Some(16.0)]);
Ok(())
}

#[test]
pub fn test_select_by_dtypes() -> Result<()> {
    // One column per dtype, so each requested dtype matches exactly one column.
    let frame = df![
        "bools" => [true, false, true],
        "ints" => [1, 2, 3],
        "strings" => ["a", "b", "c"],
        "floats" => [1.0, 2.0, 3.0f32]
    ]?;

    let by_dtype = dtype_cols([DataType::Float32, DataType::Utf8]);
    let selected = frame.lazy().select([by_dtype]).collect()?;

    // The output follows the order of the requested dtypes, not the schema
    // order ("strings" precedes "floats" in the frame).
    assert_eq!(selected.dtypes(), &[DataType::Float32, DataType::Utf8]);

    Ok(())
}
17 changes: 15 additions & 2 deletions py-polars/polars/lazy/functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as tp
from datetime import datetime, timezone
from inspect import isclass
from typing import Any, Callable, Optional, Type, Union

import numpy as np
Expand All @@ -16,6 +17,7 @@
from polars.polars import concat_lst as _concat_lst
from polars.polars import concat_str as _concat_str
from polars.polars import cov as pycov
from polars.polars import dtype_cols as _dtype_cols
from polars.polars import fold as pyfold
from polars.polars import lit as pylit
from polars.polars import map_mul as _map_mul
Expand Down Expand Up @@ -66,7 +68,9 @@
]


def col(name: Union[str, tp.List[str]]) -> "pl.Expr":
def col(
name: Union[str, tp.List[str], tp.List[Type[DataType]], Type[DataType]]
) -> "pl.Expr":
"""
A column in a DataFrame.
Can be used to select:
Expand Down Expand Up @@ -153,8 +157,17 @@ def col(name: Union[str, tp.List[str]]) -> "pl.Expr":
╰───────────┴─────╯
"""

if isclass(name) and issubclass(name, DataType): # type: ignore
name = [name] # type: ignore

if isinstance(name, list):
return pl.lazy.expr.wrap_expr(pycols(name))
if len(name) == 0 or isinstance(name[0], str):
return pl.lazy.expr.wrap_expr(pycols(name))
elif isclass(name[0]) and issubclass(name[0], DataType):
return pl.lazy.expr.wrap_expr(_dtype_cols(name))
else:
raise ValueError("did expect argument of List[str] or List[DataType]")
return pl.lazy.expr.wrap_expr(pycol(name))


Expand Down
4 changes: 4 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,10 @@ pub fn cols(names: Vec<String>) -> PyExpr {
dsl::cols(names).into()
}

/// Build a dtype-based column selection with `dsl::dtype_cols` and convert
/// the resulting expression into a `PyExpr`.
pub fn dtype_cols(dtypes: Vec<DataType>) -> PyExpr {
    let expr = dsl::dtype_cols(dtypes);
    expr.into()
}

pub fn binary_expr(l: PyExpr, op: u8, r: PyExpr) -> PyExpr {
let left = l.inner;
let right = r.inner;
Expand Down
19 changes: 18 additions & 1 deletion py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub mod utils;
use crate::conversion::{get_df, get_pyseq, get_series, Wrap};
use crate::error::PyPolarsEr;
use crate::file::get_either_file;
use crate::prelude::DataType;
use crate::prelude::{DataType, PyDataType};
use mimalloc::MiMalloc;
use polars_core::export::arrow::io::ipc::read::read_file_metadata;
use pyo3::types::PyDict;
Expand All @@ -51,6 +51,22 @@ fn cols(names: Vec<String>) -> dsl::PyExpr {
dsl::cols(names)
}

/// Python entry point: turn a Python sequence of dtype objects into a
/// dtype-based column selection expression.
///
/// Each element of `dtypes` is extracted as a `PyDataType` and converted into
/// a `DataType`. The first element that fails iteration or extraction aborts
/// the call and its error is returned.
#[pyfunction]
fn dtype_cols(dtypes: &PyAny) -> PyResult<dsl::PyExpr> {
    let (seq, _len) = get_pyseq(dtypes)?;

    let dtypes = seq
        .iter()?
        .map(|item| -> PyResult<DataType> {
            let pydt = item?.extract::<PyDataType>()?;
            Ok(pydt.into())
        })
        .collect::<PyResult<Vec<DataType>>>()?;

    Ok(dsl::dtype_cols(dtypes))
}

#[pyfunction]
fn lit(value: &PyAny) -> dsl::PyExpr {
dsl::lit(value)
Expand Down Expand Up @@ -221,6 +237,7 @@ fn polars(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<dsl::PyExpr>().unwrap();
m.add_wrapped(wrap_pyfunction!(col)).unwrap();
m.add_wrapped(wrap_pyfunction!(cols)).unwrap();
m.add_wrapped(wrap_pyfunction!(dtype_cols)).unwrap();
m.add_wrapped(wrap_pyfunction!(lit)).unwrap();
m.add_wrapped(wrap_pyfunction!(fold)).unwrap();
m.add_wrapped(wrap_pyfunction!(binary_expr)).unwrap();
Expand Down
7 changes: 7 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,3 +1152,10 @@ def test_groupby_agg_n_unique_floats():
[pl.col("b").cast(dtype).n_unique()]
)
out["b_n_unique"].to_list() == [2, 1]


def test_select_by_dtype(df):
    # Passing a single dtype to `pl.col` selects the columns of that dtype.
    utf8_only = df.select(pl.col(pl.Utf8))
    assert utf8_only.columns == ["strings", "strings_nulls"]

    # Passing a list of dtypes returns the columns grouped in the order the
    # dtypes were given.
    utf8_and_bool = df.select(pl.col([pl.Utf8, pl.Boolean]))
    assert utf8_and_bool.columns == [
        "strings",
        "strings_nulls",
        "bools",
        "bools_nulls",
    ]

0 comments on commit e1a60a9

Please sign in to comment.