fix[rust]: don't divide by zero on parquet write row_groups > df.size

ritchie46 committed Aug 14, 2022
1 parent 98e46ba commit 8684945

Showing 4 changed files with 20 additions and 3 deletions.
3 changes: 3 additions & 0 deletions polars/polars-core/src/utils/mod.rs
```diff
@@ -158,6 +158,9 @@ fn flatten_df(df: &DataFrame) -> impl Iterator<Item = DataFrame> + '_ {
 #[cfg(feature = "private")]
 #[doc(hidden)]
 pub fn split_df(df: &DataFrame, n: usize) -> Result<Vec<DataFrame>> {
+    if n == 0 {
+        return Ok(vec![df.clone()]);
+    }
     let total_len = df.height();
     let chunk_size = total_len / n;
 
```
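The early return matters because of the line right after it: `chunk_size = total_len / n` panics in Rust when `n == 0`. A minimal Python sketch of the same chunking arithmetic, with the guard in place (`split_ranges` is a hypothetical mirror for illustration, not the polars API):

```python
def split_ranges(total_len: int, n: int) -> list[range]:
    # Hypothetical mirror of split_df's chunking logic.
    if n == 0:
        # Without this guard, `total_len // n` below raises
        # ZeroDivisionError (the Rust integer division panics likewise).
        return [range(0, total_len)]
    chunk_size = total_len // n
    ranges = []
    for i in range(n):
        start = i * chunk_size
        # The last chunk absorbs the remainder rows.
        stop = total_len if i == n - 1 else start + chunk_size
        ranges.append(range(start, stop))
    return ranges
```

With the guard, `split_ranges(3, 0)` returns the whole row range as a single chunk, matching the new `Ok(vec![df.clone()])` early return.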
7 changes: 5 additions & 2 deletions polars/polars-io/src/parquet/write.rs
```diff
@@ -55,7 +55,7 @@ where
         self
     }
 
-    /// Set the row group size during writing. This can reduce memory pressure and improve
+    /// Set the row group size (in number of rows) during writing. This can reduce memory pressure and improve
     /// writing performance.
    pub fn with_row_group_size(mut self, size: Option<usize>) -> Self {
        self.row_group_size = size;
@@ -68,7 +68,10 @@
        df.rechunk();
 
        if let Some(n) = self.row_group_size {
-            *df = accumulate_dataframes_vertical_unchecked(split_df(df, df.height() / n)?);
+            let n_splits = df.height() / n;
+            if n_splits > 0 {
+                *df = accumulate_dataframes_vertical_unchecked(split_df(df, n_splits)?);
+            }
        };
 
        let fields = df.schema().to_arrow().fields;
```
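Integer division is what made the old call dangerous: `df.height() / n` truncates to 0 whenever the requested row group size exceeds the frame's height, and that 0 was forwarded straight into `split_df`. A small Python sketch of the fixed control flow (the `effective_splits` name is illustrative, not the Rust API):

```python
def effective_splits(height: int, row_group_size: int) -> int:
    # Truncating division: 3 rows with row_group_size=1024 yields 0 splits.
    # The fix skips re-chunking entirely instead of calling split_df(df, 0).
    return height // row_group_size

assert effective_splits(100_000, 1024) == 97  # normal case: re-chunk
assert effective_splits(3, 1024) == 0         # saturated: leave df as-is
```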
3 changes: 2 additions & 1 deletion py-polars/polars/internals/frame.py
```diff
@@ -1333,7 +1333,8 @@ def write_parquet(
         statistics
             Write statistics to the parquet headers. This requires extra compute.
         row_group_size
-            Size of the row groups. If None (default), the chunks of the `DataFrame` are
+            Size of the row groups in number of rows.
+            If None (default), the chunks of the `DataFrame` are
             used. Writing in smaller chunks may reduce memory pressure and improve
             writing speeds. This argument has no effect if 'pyarrow' is used.
         use_pyarrow
```
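For reference, a minimal usage sketch of the documented parameter (the buffer and sizes are arbitrary):

```python
import io

import polars as pl

df = pl.DataFrame({"a": list(range(10_000))})
buf = io.BytesIO()
# Split the 10_000 rows into row groups of roughly 1_000 rows each;
# row_group_size=None would keep the DataFrame's existing chunks.
df.write_parquet(buf, row_group_size=1_000)
```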
10 changes: 10 additions & 0 deletions py-polars/tests/io/test_parquet.py
```diff
@@ -222,3 +222,13 @@ def test_nested_dictionary() -> None:
 
     read_df = pl.read_parquet(f)
     assert df.frame_equal(read_df)
+
+
+def test_row_group_size_saturation() -> None:
+    df = pl.DataFrame({"a": [1, 2, 3]})
+    f = io.BytesIO()
+
+    # request larger chunk than rows in df
+    df.write_parquet(f, row_group_size=1024)
+    f.seek(0)
+    assert pl.read_parquet(f).frame_equal(df)
```
