Skip to content

Commit

Permalink
make sure to always rechunk before converting to arrow chunks (#2464)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 25, 2022
1 parent ebcc035 commit c15ecdd
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 28 deletions.
8 changes: 8 additions & 0 deletions polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2075,6 +2075,14 @@ impl DataFrame {
}

/// Iterator over the rows in this `DataFrame` as Arrow RecordBatches.
///
/// # Panics
///
/// Panics if the `DataFrame` that is passed is not rechunked.
///
/// This responsibility is left to the caller: we don't want to take a mutable reference here,
/// and we don't want to rechunk here either, as that operation is costly and the caller
/// benefits from performing it once up front.
pub fn iter_chunks(&self) -> impl Iterator<Item = ArrowChunk> + '_ {
RecordBatchIter {
columns: &self.columns,
Expand Down
11 changes: 6 additions & 5 deletions polars/polars-io/src/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,11 @@ where
}
}

fn finish(self, df: &DataFrame) -> Result<()> {
fn finish(self, df: &mut DataFrame) -> Result<()> {
df.rechunk();
let mut writer = self.writer_builder.from_writer(self.buffer);
let iter = df.iter_chunks();
let names = df.get_column_names();
let iter = df.iter_chunks();
if self.header {
write::write_header(&mut writer, &names)?;
}
Expand Down Expand Up @@ -635,19 +636,19 @@ mod test {
#[test]
fn write_csv() {
let mut buf: Vec<u8> = Vec::new();
let df = create_df();
let mut df = create_df();

CsvWriter::new(&mut buf)
.has_header(true)
.finish(&df)
.finish(&mut df)
.expect("csv written");
let csv = std::str::from_utf8(&buf).unwrap();
assert_eq!("days,temp\n0,22.1\n1,19.9\n2,7.0\n3,2.0\n4,3.0\n", csv);

let mut buf: Vec<u8> = Vec::new();
CsvWriter::new(&mut buf)
.has_header(false)
.finish(&df)
.finish(&mut df)
.expect("csv written");
let csv = std::str::from_utf8(&buf).unwrap();
assert_eq!("0,22.1\n1,19.9\n2,7.0\n3,2.0\n4,3.0\n", csv);
Expand Down
26 changes: 16 additions & 10 deletions polars/polars-io/src/ipc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ where
}
}

fn finish(mut self, df: &DataFrame) -> Result<()> {
fn finish(mut self, df: &mut DataFrame) -> Result<()> {
let mut ipc_writer = write::FileWriter::try_new(
&mut self.writer,
&df.schema().to_arrow(),
Expand All @@ -266,7 +266,7 @@ where
compression: self.compression,
},
)?;

df.rechunk();
let iter = df.iter_chunks();

for batch in iter {
Expand All @@ -290,9 +290,11 @@ mod test {
// Vec<T> : Write + Read
// Cursor<Vec<_>>: Seek
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
let df = create_df();
let mut df = create_df();

IpcWriter::new(&mut buf).finish(&df).expect("ipc writer");
IpcWriter::new(&mut buf)
.finish(&mut df)
.expect("ipc writer");

buf.set_position(0);

Expand All @@ -303,9 +305,11 @@ mod test {
#[test]
fn test_read_ipc_with_projection() {
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
let df = df!("a" => [1, 2, 3], "b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap();
let mut df = df!("a" => [1, 2, 3], "b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap();

IpcWriter::new(&mut buf).finish(&df).expect("ipc writer");
IpcWriter::new(&mut buf)
.finish(&mut df)
.expect("ipc writer");
buf.set_position(0);

let expected = df!("b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap();
Expand All @@ -320,9 +324,11 @@ mod test {
#[test]
fn test_read_ipc_with_columns() {
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
let df = df!("a" => [1, 2, 3], "b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap();
let mut df = df!("a" => [1, 2, 3], "b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap();

IpcWriter::new(&mut buf).finish(&df).expect("ipc writer");
IpcWriter::new(&mut buf)
.finish(&mut df)
.expect("ipc writer");
buf.set_position(0);

let expected = df!("b" => [2, 3, 4], "c" => [3, 4, 5]).unwrap();
Expand All @@ -336,7 +342,7 @@ mod test {

#[test]
fn test_write_with_compression() {
let df = create_df();
let mut df = create_df();

let compressions = vec![
None,
Expand All @@ -348,7 +354,7 @@ mod test {
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
IpcWriter::new(&mut buf)
.with_compression(compression)
.finish(&df)
.finish(&mut df)
.expect("ipc writer");
buf.set_position(0);

Expand Down
3 changes: 2 additions & 1 deletion polars/polars-io/src/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ where
}
}

fn finish(mut self, df: &DataFrame) -> Result<()> {
fn finish(mut self, df: &mut DataFrame) -> Result<()> {
df.rechunk();
let batches = df.iter_chunks().map(Ok);
let names = df.get_column_names_owned();

Expand Down
2 changes: 1 addition & 1 deletion polars/polars-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ where
W: Write,
{
fn new(writer: W) -> Self;
fn finish(self, df: &DataFrame) -> Result<()>;
fn finish(self, df: &mut DataFrame) -> Result<()>;
}

pub trait ArrowReader {
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fn init_files() {

for out_path in [out_path1, out_path2] {
if std::fs::metadata(&out_path).is_err() {
let df = CsvReader::from_path(path).unwrap().finish().unwrap();
let mut df = CsvReader::from_path(path).unwrap().finish().unwrap();

if out_path.ends_with("parquet") {
let f = std::fs::File::create(&out_path).unwrap();
Expand All @@ -59,7 +59,7 @@ fn init_files() {
.unwrap();
} else {
let f = std::fs::File::create(&out_path).unwrap();
IpcWriter::new(f).finish(&df).unwrap();
IpcWriter::new(f).finish(&mut df).unwrap();
}
}
}
Expand Down
20 changes: 11 additions & 9 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ impl PyDataFrame {

#[cfg(feature = "json")]
pub fn to_json(
&self,
&mut self,
py_f: PyObject,
pretty: bool,
row_oriented: bool,
Expand All @@ -250,10 +250,10 @@ impl PyDataFrame {
(_, true, true) => panic!("{}", "only one of {row_oriented, json_lines} should be set"),
(_, _, true) => JsonWriter::new(file)
.with_json_format(JsonFormat::JsonLines)
.finish(&self.df),
.finish(&mut self.df),
(_, true, false) => JsonWriter::new(file)
.with_json_format(JsonFormat::Json)
.finish(&self.df),
.finish(&mut self.df),
(true, _, _) => serde_json::to_writer_pretty(file, &self.df)
.map_err(|e| PolarsError::ComputeError(format!("{:?}", e).into())),
(false, _, _) => serde_json::to_writer(file, &self.df)
Expand Down Expand Up @@ -286,18 +286,18 @@ impl PyDataFrame {
Ok(pydf)
}

pub fn to_csv(&self, py_f: PyObject, has_header: bool, sep: u8) -> PyResult<()> {
pub fn to_csv(&mut self, py_f: PyObject, has_header: bool, sep: u8) -> PyResult<()> {
let mut buf = get_file_like(py_f, true)?;
CsvWriter::new(&mut buf)
.has_header(has_header)
.with_delimiter(sep)
.finish(&self.df)
.finish(&mut self.df)
.map_err(PyPolarsEr::from)?;
Ok(())
}

#[cfg(feature = "ipc")]
pub fn to_ipc(&self, py_f: PyObject, compression: &str) -> PyResult<()> {
pub fn to_ipc(&mut self, py_f: PyObject, compression: &str) -> PyResult<()> {
let compression = match compression {
"uncompressed" => None,
"lz4" => Some(IpcCompression::LZ4),
Expand All @@ -308,7 +308,7 @@ impl PyDataFrame {

IpcWriter::new(&mut buf)
.with_compression(compression)
.finish(&self.df)
.finish(&mut self.df)
.map_err(PyPolarsEr::from)?;
Ok(())
}
Expand Down Expand Up @@ -421,7 +421,8 @@ impl PyDataFrame {
Ok(())
}

pub fn to_arrow(&self) -> PyResult<Vec<PyObject>> {
pub fn to_arrow(&mut self) -> PyResult<Vec<PyObject>> {
self.df.rechunk();
let gil = Python::acquire_gil();
let py = gil.python();
let pyarrow = py.import("pyarrow")?;
Expand All @@ -435,7 +436,8 @@ impl PyDataFrame {
Ok(rbs)
}

pub fn to_pandas(&self) -> PyResult<Vec<PyObject>> {
pub fn to_pandas(&mut self) -> PyResult<Vec<PyObject>> {
self.df.rechunk();
let gil = Python::acquire_gil();
let py = gil.python();
let pyarrow = py.import("pyarrow")?;
Expand Down
15 changes: 15 additions & 0 deletions py-polars/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,3 +566,18 @@ def test_csv_schema_offset() -> None:
""".encode()
df = pl.read_csv(csv, offset_schema_inference=4, skip_rows=4)
assert df.columns == ["a", "b", "c"]


def test_from_different_chunks() -> None:
    # "a" is a single-chunk Series; "b" is deliberately built from two chunks
    # so the DataFrame's columns have mismatched chunk layouts.
    single_chunk = pl.Series("a", [1, 2, 3, 4, None])
    multi_chunk = pl.Series("b", [1, 2])
    multi_chunk.append(pl.Series("b", [1, 2, 3]))

    # Converting to arrow must not panic despite the differing chunk counts.
    frame = pl.DataFrame([single_chunk, multi_chunk])
    frame.to_arrow()

    # Converting to pandas must also succeed and keep the expected layout.
    frame = pl.DataFrame([single_chunk, multi_chunk])
    result = frame.to_pandas()
    assert list(result.columns) == ["a", "b"]
    assert result.shape == (5, 2)

0 comments on commit c15ecdd

Please sign in to comment.