Skip to content

Commit

Permalink
Fix panics on series deserialization (#1920)
Browse files Browse the repository at this point in the history
  • Loading branch information
nmandery committed Dec 1, 2021
1 parent 151f34a commit 2f626ab
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 14 deletions.
3 changes: 3 additions & 0 deletions polars/polars-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ features = [
"compute_take",
]

[dev-dependencies]
bincode = "1"

[package.metadata.docs.rs]
# not all because arrow 4.3 does not compile with simd
# all-features = true
Expand Down
42 changes: 30 additions & 12 deletions polars/polars-core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,31 +106,49 @@ mod test {
assert!(ca.into_series().series_equal_missing(&out));
}

#[test]
fn test_serde_df() {
let s = Series::new("foo", &[1, 2, 3]);
let s1 = Series::new("bar", &[Some(true), None, Some(false)]);
let s_list = Series::new("list", &[s.clone(), s.clone(), s.clone()]);
/// Builds the four-column frame shared by the serde round-trip tests:
/// an integer column, an optional-boolean column containing a `None`,
/// a string column and a list column.
fn sample_dataframe() -> DataFrame {
    let ints = Series::new("foo", &[1, 2, 3]);
    let bools = Series::new("bar", &[Some(true), None, Some(false)]);
    let strings = Series::new("utf8", &["mouse", "elephant", "dog"]);
    // every list row holds a copy of the integer column
    let lists = Series::new("list", &[ints.clone(), ints.clone(), ints.clone()]);

    let columns = vec![ints, bools, strings, lists];
    DataFrame::new(columns).unwrap()
}

let df = DataFrame::new(vec![s, s_list, s1]).unwrap();
/// Round-trips the sample frame through a JSON string and checks the decoded
/// frame against the original with `frame_equal_missing`.
#[test]
fn test_serde_df_json() {
    let df = sample_dataframe();
    let json = serde_json::to_string(&df).unwrap();
    // Decoding from a `&str` goes through the borrowing `Deserialize<'de>` path.
    let out = serde_json::from_str::<DataFrame>(&json).unwrap(); // uses `Deserialize<'de>`
    assert!(df.frame_equal_missing(&out));
}

/// test using the `DeserializeOwned` trait
#[test]
fn test_serde_df_owned() {
let s = Series::new("foo", &[1, 2, 3]);
let s1 = Series::new("bar", &[Some(true), None, Some(false)]);
let s_list = Series::new("list", &[s.clone(), s.clone(), s.clone()]);
fn test_serde_df_bincode() {
    // Encode the sample frame to bincode bytes, then decode it straight from
    // the in-memory buffer and compare with the original.
    let original = sample_dataframe();
    let encoded = bincode::serialize(&original).unwrap();
    let decoded: DataFrame = bincode::deserialize(&encoded).unwrap(); // uses `Deserialize<'de>`
    assert!(original.frame_equal_missing(&decoded));
}

let df = DataFrame::new(vec![s, s_list, s1]).unwrap();
/// Round-trips the sample frame through JSON using the `DeserializeOwned` trait:
/// the JSON text is decoded from a reader instead of a borrowed `&str`.
#[test]
fn test_serde_df_owned_json() {
    let df = sample_dataframe();
    let json = serde_json::to_string(&df).unwrap();

    // `from_reader` cannot borrow from its input, so it requires `DeserializeOwned`.
    let out = serde_json::from_reader::<_, DataFrame>(json.as_bytes()).unwrap(); // uses `DeserializeOwned`
    assert!(df.frame_equal_missing(&out));
}

/// Round-trips the sample frame through bincode, decoding from a reader so
/// that the `DeserializeOwned` trait is used.
#[test]
fn test_serde_df_owned_bincode() {
    let expected = sample_dataframe();
    let encoded = bincode::serialize(&expected).unwrap();
    let reader = encoded.as_slice();
    let decoded: DataFrame = bincode::deserialize_from(reader).unwrap(); // uses `DeserializeOwned`
    assert!(expected.frame_equal_missing(&decoded));
}
}
4 changes: 2 additions & 2 deletions polars/polars-core/src/serde/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ impl<'de> Deserialize<'de> for Series {
Ok(Series::new(&name, values))
}
DeDataType::Utf8 => {
let values: Vec<Option<&str>> = map.next_value()?;
let values: Vec<Option<Cow<str>>> = map.next_value()?;
Ok(Series::new(&name, values))
}
DeDataType::List => {
let values: Vec<Series> = map.next_value()?;
let values: Vec<Option<Series>> = map.next_value()?;
Ok(Series::new(&name, values))
}
dt => {
Expand Down
41 changes: 41 additions & 0 deletions polars/polars-core/src/series/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::chunked_array::object::extension::polars_extension::PolarsExtension;
use crate::prelude::*;
use arrow::compute::cast::utf8_to_large_utf8;
use polars_arrow::compute::cast::cast;
use std::borrow::Cow;
use std::convert::TryFrom;

pub trait NamedFrom<T, Phantom: ?Sized> {
Expand Down Expand Up @@ -33,6 +34,24 @@ impl<'a, T: AsRef<[Option<&'a str>]>> NamedFrom<T, [Option<&'a str>]> for Series
}
}

/// Builds a string `Series` from a slice of `Cow<str>` values.
impl<'a, T: AsRef<[Cow<'a, str>]>> NamedFrom<T, [Cow<'a, str>]> for Series {
    /// Creates a `Series` called `name` from the strings in `v`.
    ///
    /// Each `Cow` is only viewed as a `&str` while it is handed to
    /// `Utf8Chunked::new_from_iter`.
    fn new(name: &str, v: T) -> Self {
        let texts = v.as_ref().iter().map(|text| &**text);
        let chunked = Utf8Chunked::new_from_iter(name, texts);
        chunked.into_series()
    }
}
/// Builds a string `Series` from a slice of optional `Cow<str>` values.
impl<'a, T: AsRef<[Option<Cow<'a, str>>]>> NamedFrom<T, [Option<Cow<'a, str>>]> for Series {
    /// Creates a `Series` called `name` from the optional strings in `v`.
    ///
    /// Every element is passed to `Utf8Chunked::new_from_opt_iter` as an
    /// `Option<&str>`; `None` entries are forwarded unchanged.
    fn new(name: &str, v: T) -> Self {
        let maybe_texts = v.as_ref().iter().map(|entry| entry.as_deref());
        let chunked = Utf8Chunked::new_from_opt_iter(name, maybe_texts);
        chunked.into_series()
    }
}

impl_named_from!([String], Utf8Type, new_from_slice);
impl_named_from!([bool], BooleanType, new_from_slice);
#[cfg(feature = "dtype-u8")]
Expand Down Expand Up @@ -80,6 +99,28 @@ impl<T: AsRef<[Series]>> NamedFrom<T, ListType> for Series {
}
}

/// Builds a list `Series` from a slice of optional `Series`.
impl<T: AsRef<[Option<Series>]>> NamedFrom<T, [Option<Series>]> for Series {
    /// Creates a list `Series` called `name` with one row per element of `s`.
    ///
    /// The inner dtype is taken from the first `Some` element.
    ///
    /// # Panics
    ///
    /// Panics if `s` contains no `Some` element (this includes an empty slice),
    /// because no inner dtype can then be determined.
    fn new(name: &str, s: T) -> Self {
        let series_slice = s.as_ref();

        // Total number of values over all present series; used as the
        // value capacity handed to the list builder.
        let values_cap: usize = series_slice
            .iter()
            .flatten()
            .map(|inner| inner.len())
            .sum();

        let inner_dtype = series_slice
            .iter()
            .find_map(|opt| opt.as_ref())
            .expect("cannot create List Series from a slice of nulls")
            .dtype();

        let mut builder = get_list_builder(inner_dtype, values_cap, series_slice.len(), name);
        for maybe_series in series_slice.iter().map(Option::as_ref) {
            builder.append_opt_series(maybe_series);
        }
        builder.finish().into_series()
    }
}

fn convert_list_inner(arr: &ArrayRef, fld: &ArrowField) -> ArrayRef {
// if inner type is Utf8, we need to convert that to large utf8
match fld.data_type() {
Expand Down

0 comments on commit 2f626ab

Please sign in to comment.