Skip to content

Commit

Permalink
Python: nested datatypes; Allow casting list inner types. (#2882)
Browse files Browse the repository at this point in the history
* Python: nest datatypes

* python: add null type
  • Loading branch information
ritchie46 committed Mar 12, 2022
1 parent bc6632c commit 85999b9
Show file tree
Hide file tree
Showing 17 changed files with 182 additions and 129 deletions.
3 changes: 3 additions & 0 deletions polars/polars-core/src/series/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ impl Series {
Struct(_) => Series::try_from_arrow_unchecked(name, chunks, &dtype.to_arrow()).unwrap(),
#[cfg(feature = "object")]
Object(_) => todo!(),
Null => panic!("null type not supported"),
Unknown => panic!("uh oh, somehow we don't know the dtype?"),
#[allow(unreachable_patterns)]
_ => unreachable!(),
}
}
Expand Down
1 change: 1 addition & 0 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def version() -> str:
Int32,
Int64,
List,
Null,
Object,
Struct,
Time,
Expand Down
58 changes: 55 additions & 3 deletions py-polars/polars/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,39 @@ class Utf8(DataType):
pass


class List(DataType):
class Null(DataType):
pass


class List(DataType):
def __init__(self, inner: Type[DataType]):
self.inner = py_type_to_dtype(inner)

def __eq__(self, other: Type[DataType]) -> bool: # type: ignore
# The comparison allows comparing objects to classes
# and specific inner types to none specific.
# if one of the arguments is not specific about its inner type
# we infer it as being equal.
# List[i64] == List[i64] == True
# List[i64] == List == True
# List[i64] == List[None] == True
# List[i64] == List[f32] == False

# allow comparing object instances to class
if type(other) is type and issubclass(other, List):
return True
if isinstance(other, List):
if self.inner is None or other.inner is None:
return True
else:
return self.inner == other.inner
else:
return False

def __hash__(self) -> int:
return hash(List)


class Date(DataType):
pass

Expand All @@ -94,7 +123,26 @@ class Categorical(DataType):

class Struct(DataType):
def __init__(self, inner_types: Sequence[Type[DataType]]):
self.inner_types = inner_types
self.inner_types = [py_type_to_dtype(dt) for dt in inner_types]

def __eq__(self, other: Type[DataType]) -> bool: # type: ignore
# The comparison allows comparing objects to classes
# and specific inner types to none specific.
# if one of the arguments is not specific about its inner type
# we infer it as being equal.
# See the list type for more info
if type(other) is type and issubclass(other, Struct):
return True
if isinstance(other, Struct):
if self.inner_types is None or other.inner_types is None:
return True
else:
return self.inner_types == other.inner_types
else:
return False

def __hash__(self) -> int:
return hash(Struct)


_DTYPE_TO_FFINAME: Dict[Type[DataType], str] = {
Expand Down Expand Up @@ -193,7 +241,11 @@ def dtype_to_py_type(dtype: Type[DataType]) -> Type:

def py_type_to_dtype(data_type: Type[Any]) -> Type[DataType]:
# when the passed in is already a Polars datatype, return that
if issubclass(data_type, DataType):
if (
type(data_type) is type
and issubclass(data_type, DataType)
or isinstance(data_type, DataType)
):
return data_type

try:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def sequence_to_pyseries(
pyseries = constructor(name, values, strict)

if dtype in (Date, Datetime, Duration, Time, Categorical):
pyseries = pyseries.cast(str(dtype), True)
pyseries = pyseries.cast(dtype, True)

return pyseries

Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ def mode(self) -> "Expr":
"""
return wrap_expr(self._pyexpr.mode())

def cast(self, dtype: Type[Any], strict: bool = True) -> "Expr":
def cast(self, dtype: Union[Type[Any], DataType], strict: bool = True) -> "Expr":
"""
Cast between data types.
Expand Down Expand Up @@ -944,7 +944,7 @@ def cast(self, dtype: Type[Any], strict: bool = True) -> "Expr":
└─────┴─────┘
"""
dtype = py_type_to_dtype(dtype)
dtype = py_type_to_dtype(dtype) # type: ignore
return wrap_expr(self._pyexpr.cast(dtype, strict))

def sort(self, reverse: bool = False, nulls_last: bool = False) -> "Expr":
Expand Down
8 changes: 5 additions & 3 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1786,7 +1786,9 @@ def __len__(self) -> int:

def cast(
self,
dtype: Union[Type[DataType], Type[int], Type[float], Type[str], Type[bool]],
dtype: Union[
Type[DataType], Type[int], Type[float], Type[str], Type[bool], DataType
],
strict: bool = True,
) -> "Series":
"""
Expand Down Expand Up @@ -1821,8 +1823,8 @@ def cast(
]
"""
pl_dtype = py_type_to_dtype(dtype)
return wrap_s(self._s.cast(str(pl_dtype), strict))
pl_dtype = py_type_to_dtype(dtype) # type: ignore
return wrap_s(self._s.cast(pl_dtype, strict))

def to_physical(self) -> "Series":
"""
Expand Down
88 changes: 63 additions & 25 deletions py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,14 +259,26 @@ impl ToPyObject for Wrap<DataType> {
DataType::Float64 => pl.getattr("Float64").unwrap().into(),
DataType::Boolean => pl.getattr("Boolean").unwrap().into(),
DataType::Utf8 => pl.getattr("Utf8").unwrap().into(),
DataType::List(_) => pl.getattr("List").unwrap().into(),
DataType::List(inner) => {
let inner = Wrap(*inner.clone()).to_object(py);
let list_class = pl.getattr("List").unwrap();
list_class.call1((inner,)).unwrap().into()
}
DataType::Date => pl.getattr("Date").unwrap().into(),
DataType::Datetime(_, _) => pl.getattr("Datetime").unwrap().into(),
DataType::Duration(_) => pl.getattr("Duration").unwrap().into(),
DataType::Object(_) => pl.getattr("Object").unwrap().into(),
DataType::Categorical(_) => pl.getattr("Categorical").unwrap().into(),
DataType::Time => pl.getattr("Time").unwrap().into(),
DataType::Struct(_) => pl.getattr("Struct").unwrap().into(),
DataType::Struct(inners) => {
let iter = inners
.iter()
.map(|fld| Wrap(fld.data_type().clone()).to_object(py));
let inners = PyList::new(py, iter);
let struct_class = pl.getattr("Struct").unwrap();
struct_class.call1((inners,)).unwrap().into()
}
DataType::Null => pl.getattr("Null").unwrap().into(),
dt => panic!("{} not supported", dt),
}
}
Expand Down Expand Up @@ -299,32 +311,58 @@ impl FromPyObject<'_> for Wrap<QuantileInterpolOptions> {
}
}

static PREFIX_LEN: usize = "<class 'polars.datatypes.".len();

impl FromPyObject<'_> for Wrap<DataType> {
fn extract(ob: &PyAny) -> PyResult<Self> {
let dtype = match ob.repr().unwrap().to_str().unwrap() {
"<class 'polars.datatypes.UInt8'>" => DataType::UInt8,
"<class 'polars.datatypes.UInt16'>" => DataType::UInt16,
"<class 'polars.datatypes.UInt32'>" => DataType::UInt32,
"<class 'polars.datatypes.UInt64'>" => DataType::UInt64,
"<class 'polars.datatypes.Int8'>" => DataType::Int8,
"<class 'polars.datatypes.Int16'>" => DataType::Int16,
"<class 'polars.datatypes.Int32'>" => DataType::Int32,
"<class 'polars.datatypes.Int64'>" => DataType::Int64,
"<class 'polars.datatypes.Utf8'>" => DataType::Utf8,
"<class 'polars.datatypes.List'>" => DataType::List(Box::new(DataType::Boolean)),
"<class 'polars.datatypes.Boolean'>" => DataType::Boolean,
"<class 'polars.datatypes.Categorical'>" => DataType::Categorical(None),
"<class 'polars.datatypes.Date'>" => DataType::Date,
"<class 'polars.datatypes.Datetime'>" => {
DataType::Datetime(TimeUnit::Milliseconds, None)
let str_rep = ob.repr().unwrap().to_str().unwrap();

// slice off unneeded parts
let dtype = match &str_rep[PREFIX_LEN..str_rep.len() - 2] {
"UInt8" => DataType::UInt8,
"UInt16" => DataType::UInt16,
"UInt32" => DataType::UInt32,
"UInt64" => DataType::UInt64,
"Int8" => DataType::Int8,
"Int16" => DataType::Int16,
"Int32" => DataType::Int32,
"Int64" => DataType::Int64,
"Utf8" => DataType::Utf8,
"Boolean" => DataType::Boolean,
"Categorical" => DataType::Categorical(None),
"Date" => DataType::Date,
"Datetime" => DataType::Datetime(TimeUnit::Microseconds, None),
"Time" => DataType::Time,
"Duration" => DataType::Duration(TimeUnit::Microseconds),
"Float32" => DataType::Float32,
"Float64" => DataType::Float64,
"Object" => DataType::Object("unknown"),
// just the class, not an object
"List" => DataType::List(Box::new(DataType::Boolean)),
"Null" => DataType::Null,
dt => {
let out: PyResult<_> = Python::with_gil(|py| {
let builtins = PyModule::import(py, "builtins")?;
let polars = PyModule::import(py, "polars")?;
let list_class = polars.getattr("List").unwrap();
if builtins
.getattr("isinstance")
.unwrap()
.call1((ob, list_class))?
.extract::<bool>()?
{
let inner = ob.getattr("inner")?;
let inner = inner.extract::<Wrap<DataType>>()?;
Ok(DataType::List(Box::new(inner.0)))
} else {
panic!(
"{} not expected in python dtype to rust dtype conversion",
dt
)
}
});
out?
}
"<class 'polars.datatypes.Float32'>" => DataType::Float32,
"<class 'polars.datatypes.Float64'>" => DataType::Float64,
"<class 'polars.datatypes.Object'>" => DataType::Object("unknown"),
dt => panic!(
"{} not expected in python dtype to rust dtype conversion",
dt
),
};
Ok(Wrap(dtype))
}
Expand Down
23 changes: 6 additions & 17 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use crate::conversion::{ObjectValue, Wrap};
use crate::file::get_mmap_bytes_reader;
use crate::lazy::dataframe::PyLazyFrame;
use crate::prelude::{dicts_to_rows, str_to_null_strategy};
use crate::utils::str_to_polarstype;
use crate::{
arrow_interop,
error::PyPolarsErr,
Expand Down Expand Up @@ -96,8 +95,8 @@ impl PyDataFrame {
encoding: &str,
n_threads: Option<usize>,
path: Option<String>,
overwrite_dtype: Option<Vec<(&str, &PyAny)>>,
overwrite_dtype_slice: Option<Vec<&PyAny>>,
overwrite_dtype: Option<Vec<(&str, Wrap<DataType>)>>,
overwrite_dtype_slice: Option<Vec<Wrap<DataType>>>,
low_memory: bool,
comment_char: Option<&str>,
quote_char: Option<&str>,
Expand Down Expand Up @@ -132,8 +131,7 @@ impl PyDataFrame {

let overwrite_dtype = overwrite_dtype.map(|overwrite_dtype| {
let fields = overwrite_dtype.iter().map(|(name, dtype)| {
let str_repr = dtype.str().unwrap().to_str().unwrap();
let dtype = str_to_polarstype(str_repr);
let dtype = dtype.0.clone();
Field::new(name, dtype)
});
Schema::from(fields)
Expand All @@ -142,10 +140,7 @@ impl PyDataFrame {
let overwrite_dtype_slice = overwrite_dtype_slice.map(|overwrite_dtype| {
overwrite_dtype
.iter()
.map(|dt| {
let str_repr = dt.str().unwrap().to_str().unwrap();
str_to_polarstype(str_repr)
})
.map(|dt| dt.0.clone())
.collect::<Vec<_>>()
});

Expand Down Expand Up @@ -1149,20 +1144,14 @@ impl PyDataFrame {
pub fn apply(
&self,
lambda: &PyAny,
output_type: &PyAny,
output_type: Option<Wrap<DataType>>,
inference_size: usize,
) -> PyResult<(PyObject, bool)> {
let gil = Python::acquire_gil();
let py = gil.python();
let df = &self.df;

let output_type = match output_type.is_none() {
true => None,
false => {
let str_repr = output_type.str().unwrap().to_str().unwrap();
Some(str_to_polarstype(str_repr))
}
};
let output_type = output_type.map(|dt| dt.0);
let out = match output_type {
Some(DataType::Int32) => {
apply_lambda_with_primitive_out_type::<Int32Type>(df, py, lambda, 0, None)
Expand Down
6 changes: 3 additions & 3 deletions py-polars/src/datatypes.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::utils::str_to_polarstype;
use crate::Wrap;
use polars::prelude::*;
use pyo3::{FromPyObject, PyAny, PyResult};

Expand Down Expand Up @@ -94,8 +94,8 @@ impl Into<DataType> for PyDataType {

impl FromPyObject<'_> for PyDataType {
fn extract(ob: &PyAny) -> PyResult<Self> {
let str_repr = ob.str().unwrap().to_str().unwrap();
Ok(str_to_polarstype(str_repr).into())
let dt = ob.extract::<Wrap<DataType>>()?;
Ok(dt.0.into())
}
}

Expand Down
13 changes: 4 additions & 9 deletions py-polars/src/lazy/apply.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::lazy::dsl::PyExpr;
use crate::prelude::PyDataType;
use crate::series::PySeries;
use crate::utils::str_to_polarstype;
use polars::prelude::*;
use pyo3::prelude::*;
use pyo3::types::PyList;
Expand Down Expand Up @@ -114,18 +113,14 @@ pub fn binary_function(
input_a: PyExpr,
input_b: PyExpr,
lambda: PyObject,
output_type: &PyAny,
output_type: Option<DataType>,
) -> PyExpr {
let input_a = input_a.inner;
let input_b = input_b.inner;

let output_field = match output_type.is_none() {
true => Field::new("binary_function", DataType::Null),
false => {
let str_repr = output_type.str().unwrap().to_str().unwrap();
let data_type = str_to_polarstype(str_repr);
Field::new("binary_function", data_type)
}
let output_field = match output_type {
None => Field::new("binary_function", DataType::Null),
Some(dt) => Field::new("binary_function", dt),
};

let func = move |a: Series, b: Series| binary_lambda(&lambda, a, b);
Expand Down

0 comments on commit 85999b9

Please sign in to comment.