Skip to content

Commit

Permalink
Improving Python DataType support for Struct and repr (#3471)
Browse files Browse the repository at this point in the history
  • Loading branch information
cjermain committed May 23, 2022
1 parent 12405a3 commit c6d11d9
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 60 deletions.
2 changes: 1 addition & 1 deletion polars/polars-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ impl Display for DataType {
}
DataType::Duration(tu) => return write!(f, "duration[{}]", tu),
DataType::Time => "time",
DataType::List(tp) => return write!(f, "list [{}]", tp),
DataType::List(tp) => return write!(f, "list[{}]", tp),
#[cfg(feature = "object")]
DataType::Object(s) => s,
#[cfg(feature = "dtype-categorical")]
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def version() -> str:
Date,
Datetime,
Duration,
Field,
Float32,
Float64,
Int8,
Expand Down Expand Up @@ -153,6 +154,7 @@ def version() -> str:
"Time",
"Object",
"Categorical",
"Field",
"Struct",
# polars.io
"read_csv",
Expand Down
29 changes: 24 additions & 5 deletions py-polars/polars/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@

class DataType:
@classmethod
def string_repr(self) -> str:
def string_repr(cls) -> str:
return dtype_str_repr(cls)

def __repr__(self) -> str:
return dtype_str_repr(self)


Expand Down Expand Up @@ -130,9 +133,25 @@ class Categorical(DataType):
pass


class Field:
def __init__(self, name: str, dtype: Type[DataType]):
self.name = name
self.dtype = py_type_to_dtype(dtype)

def __eq__(self, other: "Field") -> bool: # type: ignore
return (self.name == other.name) & (self.dtype == other.dtype)

def __repr__(self) -> str:
if isinstance(self.dtype, type):
dtype_str = self.dtype.string_repr()
else:
dtype_str = repr(self.dtype)
return f'Field("{self.name}": {dtype_str})'


class Struct(DataType):
def __init__(self, inner_types: Sequence[Type[DataType]]):
self.inner_types = [py_type_to_dtype(dt) for dt in inner_types]
def __init__(self, fields: Sequence[Field]):
self.fields = fields

def __eq__(self, other: Type[DataType]) -> bool: # type: ignore
# The comparison allows comparing objects to classes
Expand All @@ -143,10 +162,10 @@ def __eq__(self, other: Type[DataType]) -> bool: # type: ignore
if type(other) is type and issubclass(other, Struct):
return True
if isinstance(other, Struct):
if self.inner_types is None or other.inner_types is None:
if self.fields is None or other.fields is None:
return True
else:
return self.inner_types == other.inner_types
return self.fields == other.fields
else:
return False

Expand Down
117 changes: 64 additions & 53 deletions py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,16 @@ impl ToPyObject for Wrap<DataType> {
DataType::Object(_) => pl.getattr("Object").unwrap().into(),
DataType::Categorical(_) => pl.getattr("Categorical").unwrap().into(),
DataType::Time => pl.getattr("Time").unwrap().into(),
DataType::Struct(inners) => {
let iter = inners
.iter()
.map(|fld| Wrap(fld.data_type().clone()).to_object(py));
let inners = PyList::new(py, iter);
DataType::Struct(fields) => {
let field_class = pl.getattr("Field").unwrap();
let iter = fields.iter().map(|fld| {
let name = fld.name().clone();
let dtype = Wrap(fld.data_type().clone()).to_object(py);
field_class.call1((name, dtype)).unwrap()
});
let fields = PyList::new(py, iter);
let struct_class = pl.getattr("Struct").unwrap();
struct_class.call1((inners,)).unwrap().into()
struct_class.call1((fields,)).unwrap().into()
}
DataType::Null => pl.getattr("Null").unwrap().into(),
dt => panic!("{} not supported", dt),
Expand Down Expand Up @@ -308,57 +311,65 @@ impl FromPyObject<'_> for Wrap<QuantileInterpolOptions> {
}
}

static PREFIX_LEN: usize = "<class 'polars.datatypes.".len();
impl FromPyObject<'_> for Wrap<Field> {
fn extract(ob: &PyAny) -> PyResult<Self> {
let name = ob.getattr("name")?.str()?.to_str()?;
let dtype = ob.getattr("dtype")?.extract::<Wrap<DataType>>()?;
Ok(Wrap(Field::new(name, dtype.0)))
}
}

impl FromPyObject<'_> for Wrap<DataType> {
fn extract(ob: &PyAny) -> PyResult<Self> {
let str_rep = ob.repr().unwrap().to_str().unwrap();

// slice off unneeded parts
let dtype = match &str_rep[PREFIX_LEN..str_rep.len() - 2] {
"UInt8" => DataType::UInt8,
"UInt16" => DataType::UInt16,
"UInt32" => DataType::UInt32,
"UInt64" => DataType::UInt64,
"Int8" => DataType::Int8,
"Int16" => DataType::Int16,
"Int32" => DataType::Int32,
"Int64" => DataType::Int64,
"Utf8" => DataType::Utf8,
"Boolean" => DataType::Boolean,
"Categorical" => DataType::Categorical(None),
"Date" => DataType::Date,
"Datetime" => DataType::Datetime(TimeUnit::Microseconds, None),
"Time" => DataType::Time,
"Duration" => DataType::Duration(TimeUnit::Microseconds),
"Float32" => DataType::Float32,
"Float64" => DataType::Float64,
"Object" => DataType::Object("unknown"),
// just the class, not an object
"List" => DataType::List(Box::new(DataType::Boolean)),
"Null" => DataType::Null,
let type_name = ob.get_type().name()?;

let dtype = match type_name {
"type" => {
// just the class, not an object
let name = ob.getattr("__name__")?.str()?.to_str()?;
match name {
"UInt8" => DataType::UInt8,
"UInt16" => DataType::UInt16,
"UInt32" => DataType::UInt32,
"UInt64" => DataType::UInt64,
"Int8" => DataType::Int8,
"Int16" => DataType::Int16,
"Int32" => DataType::Int32,
"Int64" => DataType::Int64,
"Utf8" => DataType::Utf8,
"Boolean" => DataType::Boolean,
"Categorical" => DataType::Categorical(None),
"Date" => DataType::Date,
"Datetime" => DataType::Datetime(TimeUnit::Microseconds, None),
"Time" => DataType::Time,
"Duration" => DataType::Duration(TimeUnit::Microseconds),
"Float32" => DataType::Float32,
"Float64" => DataType::Float64,
"Object" => DataType::Object("unknown"),
"List" => DataType::List(Box::new(DataType::Boolean)),
"Null" => DataType::Null,
dt => panic!("{} not expected as Python type for dtype conversion", dt),
}
}
"List" => {
let inner = ob.getattr("inner")?;
let inner = inner.extract::<Wrap<DataType>>()?;
DataType::List(Box::new(inner.0))
}
"Struct" => {
let fields = ob.getattr("fields")?;
let fields = fields
.extract::<Vec<Wrap<Field>>>()?
.into_iter()
.map(|f| f.0)
.collect::<Vec<Field>>();
DataType::Struct(fields)
}
dt => {
let out: PyResult<_> = Python::with_gil(|py| {
let builtins = PyModule::import(py, "builtins")?;
let polars = PyModule::import(py, "polars")?;
let list_class = polars.getattr("List").unwrap();
if builtins
.getattr("isinstance")
.unwrap()
.call1((ob, list_class))?
.extract::<bool>()?
{
let inner = ob.getattr("inner")?;
let inner = inner.extract::<Wrap<DataType>>()?;
Ok(DataType::List(Box::new(inner.0)))
} else {
panic!(
"{} not expected in python dtype to rust dtype conversion",
dt
)
}
});
out?
panic!(
"{} not expected in Python dtype to Rust dtype conversion",
dt
)
}
};
Ok(Wrap(dtype))
Expand Down
7 changes: 7 additions & 0 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ fn dtype_cols(dtypes: &PyAny) -> PyResult<dsl::PyExpr> {
Ok(dsl::dtype_cols(dtypes))
}

#[pyfunction]
fn dtype_str_repr(dtype: Wrap<DataType>) -> PyResult<String> {
let dtype = dtype.0;
Ok(dtype.to_string())
}

#[pyfunction]
fn lit(value: &PyAny) -> dsl::PyExpr {
dsl::lit(value)
Expand Down Expand Up @@ -454,6 +460,7 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(last)).unwrap();
m.add_wrapped(wrap_pyfunction!(cols)).unwrap();
m.add_wrapped(wrap_pyfunction!(dtype_cols)).unwrap();
m.add_wrapped(wrap_pyfunction!(dtype_str_repr)).unwrap();
m.add_wrapped(wrap_pyfunction!(lit)).unwrap();
m.add_wrapped(wrap_pyfunction!(fold)).unwrap();
m.add_wrapped(wrap_pyfunction!(binary_expr)).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_nested_struct() -> None:
nest_l2 = nest_l1.to_struct("a").to_frame()

assert isinstance(nest_l2.dtypes[0], pl.datatypes.Struct)
assert nest_l2.dtypes[0].inner_types == nest_l1.dtypes
assert [f.dtype for f in nest_l2.dtypes[0].fields] == nest_l1.dtypes
assert isinstance(nest_l1.dtypes[0], pl.datatypes.Struct)


Expand Down

0 comments on commit c6d11d9

Please sign in to comment.