Skip to content

Commit

Permalink
improve dtype selection (#3664)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 11, 2022
1 parent 0c30f86 commit c48823a
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 24 deletions.
7 changes: 1 addition & 6 deletions polars/polars-lazy/src/logical_plan/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,7 @@ fn expand_columns(expr: &Expr, result: &mut Vec<Expr>, names: &[String]) {
/// replace `DtypeColumn` with `col("foo")..col("bar")`
fn expand_dtypes(expr: &Expr, result: &mut Vec<Expr>, schema: &Schema, dtypes: &[DataType]) {
for dtype in dtypes {
// we compare by variant not by exact datatype as units/ refmaps etc may differ.
let variant = std::mem::discriminant(dtype);
for field in schema
.iter_fields()
.filter(|f| std::mem::discriminant(f.data_type()) == variant)
{
for field in schema.iter_fields().filter(|f| f.data_type() == dtype) {
let name = field.name();

let mut new_expr = expr.clone();
Expand Down
3 changes: 1 addition & 2 deletions polars/tests/it/lazy/expressions/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ fn test_expand_datetimes_3042() -> Result<()> {
"dt2" => date_range,
]?
.lazy()
// this tests if we expand datetimes even though the units differ
.with_column(
dtype_col(&DataType::Datetime(TimeUnit::Microseconds, None))
dtype_col(&DataType::Datetime(TimeUnit::Milliseconds, None))
.dt()
.strftime("%m/%d/%Y"),
)
Expand Down
1 change: 1 addition & 0 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def version() -> str:
List,
Null,
Object,
PolarsDataType,
Struct,
Time,
UInt8,
Expand Down
26 changes: 23 additions & 3 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@
import numpy as np

from polars import internals as pli
from polars.datatypes import DataType, Date, Datetime, Duration, py_type_to_dtype
from polars.datatypes import (
DataType,
Date,
Datetime,
Duration,
PolarsDataType,
py_type_to_dtype,
)
from polars.utils import (
_datetime_to_pl_timestamp,
_timedelta_to_pl_timedelta,
Expand Down Expand Up @@ -50,7 +57,13 @@


def col(
name: Union[str, List[str], List[Type[DataType]], "pli.Series", Type[DataType]]
name: Union[
str,
List[str],
Sequence[PolarsDataType],
"pli.Series",
PolarsDataType,
]
) -> "pli.Expr":
"""
A column in a DataFrame.
Expand Down Expand Up @@ -151,10 +164,17 @@ def col(
if isclass(name) and issubclass(cast(type, name), DataType):
name = [cast(type, name)]

if isinstance(name, DataType):
return pli.wrap_expr(_dtype_cols([name]))

if isinstance(name, list):
if len(name) == 0 or isinstance(name[0], str):
return pli.wrap_expr(pycols(name))
elif isclass(name[0]) and issubclass(name[0], DataType):
elif (
isclass(name[0])
and issubclass(name[0], DataType)
or isinstance(name[0], DataType)
):
return pli.wrap_expr(_dtype_cols(name))
else:
raise ValueError("did expect argument of List[str] or List[DataType]")
Expand Down
6 changes: 6 additions & 0 deletions py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ pub(crate) fn slice_extract_wrapped<T>(slice: &[Wrap<T>]) -> &[T] {
unsafe { std::mem::transmute(slice) }
}

/// Converts a `Vec<Wrap<T>>` into a `Vec<T>` without copying or reallocating.
///
/// The returned vector takes over the allocation of `buf`, keeping the same
/// length and capacity; the elements are not moved or cloned.
pub(crate) fn vec_extract_wrapped<T>(buf: Vec<Wrap<T>>) -> Vec<T> {
    // Transmuting the `Vec` itself would rely on `Vec<Wrap<T>>` and `Vec<T>`
    // having the same layout, which is not guaranteed for different type
    // parameters. Rebuild the vector from its raw parts instead.
    //
    // `ManuallyDrop` stops `buf` from freeing the allocation we hand over.
    let mut buf = std::mem::ManuallyDrop::new(buf);
    let ptr = buf.as_mut_ptr().cast::<T>();
    let len = buf.len();
    let cap = buf.capacity();
    // SAFETY:
    // `Wrap<T>` is `#[repr(transparent)]` over `T`, so both types have the
    // same size and alignment and the allocation is valid for `cap` values of
    // `T`. `ptr`, `len` and `cap` come from a live `Vec`, so `len <= cap` and
    // the first `len` elements are initialized. Ownership is transferred
    // exactly once because `buf` is never dropped.
    unsafe { Vec::from_raw_parts(ptr, len, cap) }
}

#[repr(transparent)]
pub struct Wrap<T>(pub T);

Expand Down
18 changes: 5 additions & 13 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ use crate::error::{
SchemaError,
};
use crate::file::get_either_file;
use crate::prelude::{ClosedWindow, DataType, DatetimeArgs, Duration, DurationArgs, PyDataType};
use crate::prelude::{
vec_extract_wrapped, ClosedWindow, DataType, DatetimeArgs, Duration, DurationArgs,
};
use dsl::ToExprs;
#[cfg(target_os = "linux")]
use jemallocator::Jemalloc;
Expand Down Expand Up @@ -88,18 +90,8 @@ fn cols(names: Vec<String>) -> dsl::PyExpr {
}

#[pyfunction]
fn dtype_cols(dtypes: &PyAny) -> PyResult<dsl::PyExpr> {
let (seq, len) = get_pyseq(dtypes)?;
let iter = seq.iter()?;

let mut dtypes = Vec::with_capacity(len);

for res in iter {
let item = res?;
let pydt = item.extract::<PyDataType>()?;
let dt: DataType = pydt.into();
dtypes.push(dt)
}
fn dtype_cols(dtypes: Vec<Wrap<DataType>>) -> PyResult<dsl::PyExpr> {
let dtypes = vec_extract_wrapped(dtypes);
Ok(dsl::dtype_cols(dtypes))
}

Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,3 +963,21 @@ def test_datetime_units() -> None:
len(set(df.select([pl.all().exclude(pl.Datetime(unit))]).columns) - subset)
== 0
)


def test_datetime_instance_selection() -> None:
    """Selecting with a `pl.Datetime` instance keeps only columns of that time unit."""
    units = ["ns", "us", "ms"]
    stamps = [
        datetime(2022, 12, 31, 1, 2, 3),
        datetime(2022, 12, 31, 4, 5, 6),
        datetime(2022, 12, 31, 7, 8, 9),
    ]
    # One single-row column per time unit, named after the unit it is stored in.
    df = pl.DataFrame(
        data={unit: [stamp] for unit, stamp in zip(units, stamps)},
        columns=[(unit, pl.Datetime(unit)) for unit in units],
    )

    for tu in units:
        wanted = pl.Datetime(tu)
        selected = df.select(pl.col([wanted]))
        assert selected.dtypes == [pl.Datetime(tu)]

0 comments on commit c48823a

Please sign in to comment.