Skip to content

Commit

Permalink
fix(rust, python): fix serde for small integer dtypes (#9495)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 22, 2023
1 parent 8cd1e07 commit 9e501f9
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 90 deletions.
14 changes: 13 additions & 1 deletion polars/polars-core/src/datatypes/_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ pub enum SerializableDataType {
Struct(Vec<Field>),
// some logical types we cannot know statically, e.g. Datetime
Unknown,
#[cfg(feature = "dtype-categorical")]
Categorical,
#[cfg(feature = "object")]
Object(String),
}

impl From<&DataType> for SerializableDataType {
Expand Down Expand Up @@ -87,7 +91,11 @@ impl From<&DataType> for SerializableDataType {
Unknown => Self::Unknown,
#[cfg(feature = "dtype-struct")]
Struct(flds) => Self::Struct(flds.clone()),
_ => todo!(),
#[cfg(feature = "dtype-categorical")]
Categorical(_) => Self::Categorical,
#[cfg(feature = "object")]
Object(name) => Self::Object(name.to_string()),
dt => panic!("{dt:?} not supported"),
}
}
}
Expand Down Expand Up @@ -117,6 +125,10 @@ impl From<SerializableDataType> for DataType {
Unknown => Self::Unknown,
#[cfg(feature = "dtype-struct")]
Struct(flds) => Self::Struct(flds),
#[cfg(feature = "dtype-categorical")]
Categorical => Self::Categorical(None),
#[cfg(feature = "object")]
Object(_) => Self::Object("unknown"),
}
}
}
13 changes: 4 additions & 9 deletions polars/polars-core/src/serde/chunked_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::cell::RefCell;
use serde::ser::SerializeMap;
use serde::{Serialize, Serializer};

use super::DeDataType;
use crate::prelude::*;

pub struct IterSer<I>
Expand Down Expand Up @@ -56,8 +55,7 @@ where
{
let mut state = serializer.serialize_map(Some(3))?;
state.serialize_entry("name", name)?;
let dtype: DeDataType = dtype.into();
state.serialize_entry("datatype", &dtype)?;
state.serialize_entry("datatype", dtype)?;
state.serialize_entry("values", &IterSer::new(ca.into_iter()))?;
state.end()
}
Expand Down Expand Up @@ -107,8 +105,7 @@ macro_rules! impl_serialize {
{
let mut state = serializer.serialize_map(Some(3))?;
state.serialize_entry("name", self.name())?;
let dtype: DeDataType = self.dtype().into();
state.serialize_entry("datatype", &dtype)?;
state.serialize_entry("datatype", self.dtype())?;
state.serialize_entry("values", &IterSer::new(self.into_iter()))?;
state.end()
}
Expand All @@ -133,8 +130,7 @@ impl Serialize for CategoricalChunked {
{
let mut state = serializer.serialize_map(Some(3))?;
state.serialize_entry("name", self.name())?;
let dtype: DeDataType = self.dtype().into();
state.serialize_entry("datatype", &dtype)?;
state.serialize_entry("datatype", self.dtype())?;
state.serialize_entry("values", &IterSer::new(self.iter_str()))?;
state.end()
}
Expand All @@ -153,8 +149,7 @@ impl Serialize for StructChunked {
{
let mut state = serializer.serialize_map(Some(3))?;
state.serialize_entry("name", self.name())?;
let dtype: DeDataType = self.dtype().into();
state.serialize_entry("datatype", &dtype)?;
state.serialize_entry("datatype", self.dtype())?;
state.serialize_entry("values", self.fields())?;
state.end()
}
Expand Down
62 changes: 1 addition & 61 deletions polars/polars-core/src/serde/mod.rs
Original file line number Diff line number Diff line change
@@ -1,71 +1,11 @@
use serde::{Deserialize, Serialize};

use crate::prelude::*;

pub mod chunked_array;
mod df;
pub mod series;

/// Intermediate enum that stands in for [crate::datatypes::DataType] when a
/// data type is (de)serialized.
///
/// Needed because [crate::datatypes::DataType] holds a `&'static str`, so
/// deriving `Deserialize` on it directly would require `Deserialize<'static>`.
///
/// NOTE(review): variant order presumably matters for index-based serde
/// formats — confirm before reordering or inserting variants.
#[derive(Serialize, Deserialize, Debug)]
enum DeDataType<'a> {
// Primitive physical types.
Boolean,
UInt8,
UInt16,
UInt32,
UInt64,
Int8,
Int16,
Int32,
Int64,
Float32,
Float64,
Utf8,
Binary,
// Temporal logical types; time unit and optional time zone are carried along.
Date,
Datetime(TimeUnit, Option<TimeZone>),
Duration(TimeUnit),
Time,
// Marker only: no inner data type is stored for lists.
List,
// Borrowed string payload tied to the enum's lifetime `'a`.
Object(&'a str),
Null,
// Markers only: no categorical payload or struct fields are stored.
Categorical,
Struct,
}

/// Maps a borrowed [`DataType`] onto the serializable [`DeDataType`].
///
/// Payloads that `DeDataType` has no slot for are dropped: the inner type of
/// a `List`, the fields of a `Struct` and the payload of a `Categorical`.
/// `Datetime` and `Duration` keep their time unit (and time zone).
///
/// # Panics
///
/// Panics for any `DataType` that has no `DeDataType` counterpart.
impl From<&DataType> for DeDataType<'_> {
    fn from(dt: &DataType) -> Self {
        match dt {
            DataType::Boolean => DeDataType::Boolean,
            // The 8/16-bit integers previously fell through to the catch-all
            // and panicked, although `DeDataType` has variants for them.
            DataType::UInt8 => DeDataType::UInt8,
            DataType::UInt16 => DeDataType::UInt16,
            DataType::UInt32 => DeDataType::UInt32,
            DataType::UInt64 => DeDataType::UInt64,
            DataType::Int8 => DeDataType::Int8,
            DataType::Int16 => DeDataType::Int16,
            DataType::Int32 => DeDataType::Int32,
            DataType::Int64 => DeDataType::Int64,
            DataType::Float32 => DeDataType::Float32,
            DataType::Float64 => DeDataType::Float64,
            DataType::Utf8 => DeDataType::Utf8,
            DataType::Binary => DeDataType::Binary,
            DataType::Date => DeDataType::Date,
            DataType::Datetime(tu, tz) => DeDataType::Datetime(*tu, tz.clone()),
            DataType::Duration(tu) => DeDataType::Duration(*tu),
            DataType::Time => DeDataType::Time,
            DataType::List(_) => DeDataType::List,
            DataType::Null => DeDataType::Null,
            #[cfg(feature = "object")]
            DataType::Object(s) => DeDataType::Object(s),
            #[cfg(feature = "dtype-struct")]
            DataType::Struct(_) => DeDataType::Struct,
            #[cfg(feature = "dtype-categorical")]
            DataType::Categorical(_) => DeDataType::Categorical,
            // Anything left has no serializable representation; name the
            // offending dtype so the panic is actionable.
            dt => unimplemented!("serialization of {dt:?} is not supported"),
        }
    }
}

#[cfg(test)]
mod test {
use super::*;
use crate::prelude::*;

#[test]
fn test_serde() -> PolarsResult<()> {
Expand Down
47 changes: 28 additions & 19 deletions polars/polars-core/src/serde/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use serde::de::{MapAccess, Visitor};
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};

use crate::prelude::*;
use crate::serde::DeDataType;

impl Serialize for Series {
fn serialize<S>(
Expand Down Expand Up @@ -130,89 +129,99 @@ impl<'de> Deserialize<'de> for Series {

match dtype {
#[cfg(feature = "dtype-i8")]
DeDataType::Int8 => {
DataType::Int8 => {
let values: Vec<Option<i8>> = map.next_value()?;
Ok(Series::new(&name, values))
}
#[cfg(feature = "dtype-u8")]
DeDataType::UInt8 => {
DataType::UInt8 => {
let values: Vec<Option<u8>> = map.next_value()?;
Ok(Series::new(&name, values))
}
DeDataType::Int32 => {
#[cfg(feature = "dtype-i16")]
DataType::Int16 => {
let values: Vec<Option<i16>> = map.next_value()?;
Ok(Series::new(&name, values))
}
#[cfg(feature = "dtype-u16")]
DataType::UInt16 => {
let values: Vec<Option<u16>> = map.next_value()?;
Ok(Series::new(&name, values))
}
DataType::Int32 => {
let values: Vec<Option<i32>> = map.next_value()?;
Ok(Series::new(&name, values))
}
DeDataType::UInt32 => {
DataType::UInt32 => {
let values: Vec<Option<u32>> = map.next_value()?;
Ok(Series::new(&name, values))
}
DeDataType::Int64 => {
DataType::Int64 => {
let values: Vec<Option<i64>> = map.next_value()?;
Ok(Series::new(&name, values))
}
DeDataType::UInt64 => {
DataType::UInt64 => {
let values: Vec<Option<u64>> = map.next_value()?;
Ok(Series::new(&name, values))
}
#[cfg(feature = "dtype-date")]
DeDataType::Date => {
DataType::Date => {
let values: Vec<Option<i32>> = map.next_value()?;
Ok(Series::new(&name, values).cast(&DataType::Date).unwrap())
}
#[cfg(feature = "dtype-datetime")]
DeDataType::Datetime(tu, tz) => {
DataType::Datetime(tu, tz) => {
let values: Vec<Option<i64>> = map.next_value()?;
Ok(Series::new(&name, values)
.cast(&DataType::Datetime(tu, tz))
.unwrap())
}
#[cfg(feature = "dtype-duration")]
DeDataType::Duration(tu) => {
DataType::Duration(tu) => {
let values: Vec<Option<i64>> = map.next_value()?;
Ok(Series::new(&name, values)
.cast(&DataType::Duration(tu))
.unwrap())
}
#[cfg(feature = "dtype-time")]
DeDataType::Time => {
DataType::Time => {
let values: Vec<Option<i64>> = map.next_value()?;
Ok(Series::new(&name, values).cast(&DataType::Time).unwrap())
}
DeDataType::Boolean => {
DataType::Boolean => {
let values: Vec<Option<bool>> = map.next_value()?;
Ok(Series::new(&name, values))
}
DeDataType::Float32 => {
DataType::Float32 => {
let values: Vec<Option<f32>> = map.next_value()?;
Ok(Series::new(&name, values))
}
DeDataType::Float64 => {
DataType::Float64 => {
let values: Vec<Option<f64>> = map.next_value()?;
Ok(Series::new(&name, values))
}
DeDataType::Utf8 => {
DataType::Utf8 => {
let values: Vec<Option<Cow<str>>> = map.next_value()?;
Ok(Series::new(&name, values))
}
DeDataType::List => {
DataType::List(_) => {
let values: Vec<Option<Series>> = map.next_value()?;
Ok(Series::new(&name, values))
}
DeDataType::Binary => {
DataType::Binary => {
let values: Vec<Option<Cow<[u8]>>> = map.next_value()?;
Ok(Series::new(&name, values))
}
#[cfg(feature = "dtype-struct")]
DeDataType::Struct => {
DataType::Struct(_) => {
let values: Vec<Series> = map.next_value()?;
let ca = StructChunked::new(&name, &values).unwrap();
let mut s = ca.into_series();
s.rename(&name);
Ok(s)
}
#[cfg(feature = "dtype-categorical")]
DeDataType::Categorical => {
DataType::Categorical(_) => {
let values: Vec<Option<Cow<str>>> = map.next_value()?;
Ok(Series::new(&name, values)
.cast(&DataType::Categorical(None))
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,16 @@ def test_pickle_udf_expression() -> None:
match=r"expected output type 'Utf8', got 'Int64'; set `return_dtype` to the proper datatype",
):
df.select(e)


def test_pickle_small_integers() -> None:
    """Pickle round-trip of a frame holding 8- and 16-bit integer columns."""
    columns = [
        pl.Series(values, dtype=dtype)
        for values, dtype in (
            ([1, 2], pl.Int16),
            ([3, 2], pl.Int8),
            ([32, 2], pl.UInt8),
            ([3, 3], pl.UInt16),
        )
    ]
    frame = pl.DataFrame(columns)
    restored = pickle.loads(pickle.dumps(frame))
    assert_frame_equal(restored, frame)

0 comments on commit 9e501f9

Please sign in to comment.