
Commit 7af76c2

update arrow -> remove RecordBatch logic (#2263)

ritchie46 committed Jan 4, 2022
1 parent 4380682 commit 7af76c2

Showing 18 changed files with 150 additions and 229 deletions.
2 changes: 1 addition & 1 deletion polars/polars-arrow/Cargo.toml
@@ -9,7 +9,7 @@ description = "Arrow interfaces for Polars DataFrame library"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

 [dependencies]
-arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "b617331354fd8c64c2126b6f4fc6f9935f7971ab", default-features = false }
+arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "7add9d31bff7a65076efbf1c4f7732be702f0e2b", default-features = false }
 hashbrown = "0.11"
 # arrow = { package = "arrow2", git = "https://github.com/ritchie46/arrow2", default-features = false, features = ["compute"], branch = "offset_pub" }
 # arrow = { package = "arrow2", version = "0.8", default-features = false }
2 changes: 1 addition & 1 deletion polars/polars-core/Cargo.toml
@@ -162,7 +162,7 @@ unsafe_unwrap = "^0.1.0"
 package = "arrow2"
 git = "https://github.com/jorgecarleitao/arrow2"
 # git = "https://github.com/ritchie46/arrow2"
-rev = "b617331354fd8c64c2126b6f4fc6f9935f7971ab"
+rev = "7add9d31bff7a65076efbf1c4f7732be702f0e2b"
 # branch = "offset_pub"
 # version = "0.8"
 default-features = false
21 changes: 21 additions & 0 deletions polars/polars-core/src/frame/chunks.rs
@@ -0,0 +1,21 @@
+use crate::prelude::*;
+use arrow::array::ArrayRef;
+use arrow::chunk::Chunk;
+
+pub type ArrowChunk = Chunk<ArrayRef>;
+
+impl std::convert::TryFrom<(ArrowChunk, &[ArrowField])> for DataFrame {
+    type Error = PolarsError;
+
+    fn try_from(arg: (ArrowChunk, &[ArrowField])) -> Result<DataFrame> {
+        let columns: Result<Vec<Series>> = arg
+            .0
+            .columns()
+            .iter()
+            .zip(arg.1)
+            .map(|(arr, field)| Series::try_from((field.name().as_ref(), arr.clone())))
+            .collect();
+
+        DataFrame::new(columns?)
+    }
+}
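
Note: this new conversion takes one arrow2 Chunk together with the ArrowFields describing its columns, replacing the schema that a RecordBatch used to carry inline. A minimal usage sketch; the Int32Array construction and the ArrowField/ArrowDataType prelude aliases are assumptions about the surrounding crate, not part of this diff:

    use std::convert::TryFrom;
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array};
    use polars_core::prelude::*;

    fn chunk_to_df() -> Result<DataFrame> {
        // One chunk holding a single three-row column.
        let array = Int32Array::from_slice(&[1, 2, 3]);
        let chunk = ArrowChunk::new(vec![Arc::new(array) as ArrayRef]);
        // Names and dtypes now travel separately from the data.
        let fields = vec![ArrowField::new("a", ArrowDataType::Int32, false)];
        DataFrame::try_from((chunk, fields.as_slice()))
    }
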
142 changes: 15 additions & 127 deletions polars/polars-core/src/frame/mod.rs
@@ -4,24 +4,21 @@ use std::collections::HashSet;
 use std::iter::{FromIterator, Iterator};
 use std::mem;
 use std::ops;
-use std::sync::Arc;

 use ahash::{AHashSet, RandomState};
-use arrow::record_batch::RecordBatch;
 use polars_arrow::prelude::QuantileInterpolOptions;
 use rayon::prelude::*;

 use crate::chunked_array::ops::unique::is_unique_helper;
 use crate::frame::select::Selection;
 use crate::prelude::*;
-use crate::utils::{
-    accumulate_dataframes_horizontal, accumulate_dataframes_vertical, split_ca, split_df, NoNull,
-};
+use crate::utils::{accumulate_dataframes_horizontal, split_ca, split_df, NoNull};

 #[cfg(feature = "dataframe_arithmetic")]
 mod arithmetic;
 #[cfg(feature = "asof_join")]
 pub(crate) mod asof_join;
+mod chunks;
 #[cfg(feature = "cross_join")]
 pub(crate) mod cross_join;
 pub mod explode;
@@ -43,6 +40,8 @@ use hashbrown::HashMap;
 use serde::{Deserialize, Serialize};
 use std::hash::{BuildHasher, Hash, Hasher};

+pub use chunks::*;
+
 #[derive(Copy, Clone, Debug)]
 pub enum NullStrategy {
     Ignore,
@@ -464,6 +463,11 @@ impl DataFrame {
         self.columns.iter().map(|s| s.name()).collect()
     }

+    /// Get the `Vec<String>` representing the column names.
+    pub fn get_column_names_owned(&self) -> Vec<String> {
+        self.columns.iter().map(|s| s.name().to_string()).collect()
+    }
+
     /// Set the column names.
     /// # Example
     ///
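
Note: the owned variant matters when the caller needs names that outlive a mutable borrow of the frame. A short sketch; the rename call is illustrative only, not part of this change:

    let mut df = df!("days" => &[0, 1, 2], "temp" => &[22.1, 19.9, 7.0]).unwrap();
    let names: Vec<String> = df.get_column_names_owned();
    // A Vec<&str> from get_column_names() would keep `df` borrowed here
    // and block the mutation below.
    df.rename("days", "day").unwrap();
    assert_eq!(names[0], "days");
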
@@ -2050,17 +2054,10 @@ impl DataFrame {
         DataFrame::new_no_checks(col)
     }

-    /// Transform the underlying chunks in the `DataFrame` to Arrow RecordBatches.
-    pub fn as_record_batches(&self) -> Result<Vec<RecordBatch>> {
-        self.n_chunks()?;
-        Ok(self.iter_record_batches().collect())
-    }
-
     /// Iterator over the rows in this `DataFrame` as Arrow RecordBatches.
-    pub fn iter_record_batches(&self) -> impl Iterator<Item = RecordBatch> + '_ {
+    pub fn iter_chunks(&self) -> impl Iterator<Item = ArrowChunk> + '_ {
         RecordBatchIter {
             columns: &self.columns,
-            schema: Arc::new(self.schema().to_arrow()),
             idx: 0,
             n_chunks: self.n_chunks().unwrap_or(0),
         }
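
Note: iter_chunks yields bare arrow2 Chunks, so callers that relied on each RecordBatch carrying its schema now fetch the schema from the frame once. A sketch:

    let df = df!("foo" => &[1, 2, 3, 4, 5]).unwrap();
    // The schema a RecordBatch used to carry now comes from the frame, if needed.
    let _schema = df.schema().to_arrow();
    for chunk in df.iter_chunks() {
        // chunk.len() is the row count; chunk.columns() the per-column arrays.
        println!("{} rows, {} columns", chunk.len(), chunk.columns().len());
    }
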
@@ -2664,13 +2661,12 @@

 pub struct RecordBatchIter<'a> {
     columns: &'a Vec<Series>,
-    schema: Arc<ArrowSchema>,
     idx: usize,
     n_chunks: usize,
 }

 impl<'a> Iterator for RecordBatchIter<'a> {
-    type Item = RecordBatch;
+    type Item = ArrowChunk;

     fn next(&mut self) -> Option<Self::Item> {
         if self.idx >= self.n_chunks {
@@ -2680,7 +2676,7 @@ impl<'a> Iterator for RecordBatchIter<'a> {
             let batch_cols = self.columns.iter().map(|s| s.to_arrow(self.idx)).collect();
             self.idx += 1;

-            Some(RecordBatch::try_new(self.schema.clone(), batch_cols).unwrap())
+            Some(ArrowChunk::new(batch_cols))
         }
     }
 }
@@ -2697,137 +2693,29 @@ impl From<DataFrame> for Vec<Series> {
     }
 }

-/// Conversion from Vec<RecordBatch> into DataFrame
-///
-///
-impl std::convert::TryFrom<RecordBatch> for DataFrame {
-    type Error = PolarsError;
-
-    fn try_from(batch: RecordBatch) -> Result<DataFrame> {
-        let columns: Result<Vec<Series>> = batch
-            .columns()
-            .iter()
-            .zip(batch.schema().fields())
-            .map(|(arr, field)| Series::try_from((field.name().as_ref(), arr.clone())))
-            .collect();
-
-        DataFrame::new(columns?)
-    }
-}
-
-/// Conversion from Vec<RecordBatch> into DataFrame
-///
-/// If batch-size is small it might be advisable to call rechunk
-/// to ensure predictable performance
-impl std::convert::TryFrom<Vec<RecordBatch>> for DataFrame {
-    type Error = PolarsError;
-
-    fn try_from(batches: Vec<RecordBatch>) -> Result<DataFrame> {
-        let mut batch_iter = batches.iter();
-
-        // Non empty array
-        let first_batch = batch_iter.next().ok_or_else(|| {
-            PolarsError::NoData("At least one record batch is needed to create a dataframe".into())
-        })?;
-
-        // Validate all record batches have the same schema
-        let schema = first_batch.schema();
-        for batch in batch_iter {
-            if batch.schema() != schema {
-                return Err(PolarsError::SchemaMisMatch(
-                    "All record batches must have the same schema".into(),
-                ));
-            }
-        }
-
-        let dfs: Result<Vec<DataFrame>> = batches
-            .iter()
-            .map(|batch| DataFrame::try_from(batch.clone()))
-            .collect();
-
-        accumulate_dataframes_vertical(dfs?)
-    }
-}
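
Note: with both TryFrom impls deleted, multi-batch input now goes through the new (ArrowChunk, &[ArrowField]) conversion one chunk at a time. A hypothetical replacement for the removed Vec<RecordBatch> path; chunks_to_df is an illustrative helper, not an API added by this commit:

    use std::convert::TryFrom;

    fn chunks_to_df(chunks: Vec<ArrowChunk>, fields: &[ArrowField]) -> Result<DataFrame> {
        let mut iter = chunks.into_iter();
        let first = iter
            .next()
            .ok_or_else(|| PolarsError::NoData("need at least one chunk".into()))?;
        // Convert the first chunk, then stack the rest below it.
        let mut df = DataFrame::try_from((first, fields))?;
        for chunk in iter {
            df = df.vstack(&DataFrame::try_from((chunk, fields))?)?;
        }
        Ok(df)
    }
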

 #[cfg(test)]
 mod test {
-    use std::convert::TryFrom;
-
-    use arrow::array::{Float64Array, Int64Array};
-    use arrow::datatypes::{DataType, Field, Schema};
-    use arrow::record_batch::RecordBatch;
-
     use super::*;
     use crate::frame::NullStrategy;
     use crate::prelude::*;

     fn create_frame() -> DataFrame {
         let s0 = Series::new("days", [0, 1, 2].as_ref());
         let s1 = Series::new("temp", [22.1, 19.9, 7.].as_ref());
         DataFrame::new(vec![s0, s1]).unwrap()
     }

-    fn create_record_batches() -> Vec<RecordBatch> {
-        // Creates a dataframe using 2 record-batches
-        //
-        // | foo | bar |
-        // -------------------
-        // | 1.0 | 1   |
-        // | 2.0 | 2   |
-        // | 3.0 | 3   |
-        // | 4.0 | 4   |
-        // | 5.0 | 5   |
-        // -------------------
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("foo", DataType::Float64, false),
-            Field::new("bar", DataType::Int64, false),
-        ]));
-
-        let batch0 = RecordBatch::try_new(
-            schema.clone(),
-            vec![
-                Arc::new(Float64Array::from_slice(&[1.0, 2.0, 3.0])),
-                Arc::new(Int64Array::from_slice(&[1, 2, 3])),
-            ],
-        )
-        .unwrap();
-
-        let batch1 = RecordBatch::try_new(
-            schema,
-            vec![
-                Arc::new(Float64Array::from_slice(&[4.0, 5.0])),
-                Arc::new(Int64Array::from_slice(&[4, 5])),
-            ],
-        )
-        .unwrap();
-
-        return vec![batch0, batch1];
-    }

     #[test]
     #[cfg_attr(miri, ignore)]
     fn test_recordbatch_iterator() {
         let df = df!(
             "foo" => &[1, 2, 3, 4, 5]
         )
         .unwrap();
-        let mut iter = df.iter_record_batches();
-        assert_eq!(5, iter.next().unwrap().num_rows());
+        let mut iter = df.iter_chunks();
+        assert_eq!(5, iter.next().unwrap().len());
         assert!(iter.next().is_none());
     }

-    #[test]
-    #[cfg_attr(miri, ignore)]
-    fn test_frame_from_recordbatch() {
-        let record_batches: Vec<RecordBatch> = create_record_batches();
-
-        let df = DataFrame::try_from(record_batches).expect("frame can be initialized");
-
-        assert_eq!(
-            Vec::from(df.column("bar").unwrap().i64().unwrap()),
-            &[Some(1), Some(2), Some(3), Some(4), Some(5)]
-        );
-    }

     #[test]
     #[cfg_attr(miri, ignore)]
     fn test_select() {
2 changes: 1 addition & 1 deletion polars/polars-io/Cargo.toml
@@ -31,7 +31,7 @@ private = []
 [dependencies]
 ahash = "0.7"
 anyhow = "1.0"
-arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "b617331354fd8c64c2126b6f4fc6f9935f7971ab", default-features = false }
+arrow = { package = "arrow2", git = "https://github.com/jorgecarleitao/arrow2", rev = "7add9d31bff7a65076efbf1c4f7732be702f0e2b", default-features = false }
 # arrow = { package = "arrow2", git = "https://github.com/ritchie46/arrow2", default-features = false, features = ["compute"], branch = "offset_pub" }
 # arrow = { package = "arrow2", version = "0.8", default-features = false }
 csv-core = { version = "0.1.10", optional = true }
7 changes: 4 additions & 3 deletions polars/polars-io/src/csv.rs
@@ -90,12 +90,13 @@ where

     fn finish(self, df: &DataFrame) -> Result<()> {
         let mut writer = self.writer_builder.from_writer(self.buffer);
-        let iter = df.iter_record_batches();
+        let iter = df.iter_chunks();
+        let names = df.get_column_names();
         if self.header {
-            write::write_header(&mut writer, &df.schema().to_arrow())?;
+            write::write_header(&mut writer, &names)?;
         }
         for batch in iter {
-            write::write_batch(&mut writer, &batch, &self.options)?;
+            write::write_chunk(&mut writer, &batch, &self.options)?;
         }
         Ok(())
     }
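
Note: for context, a sketch of the user-facing write path this change feeds, assuming the SerWriter API of polars-io at the time (the has_header builder method is an assumption from that era, not shown in this diff):

    use polars_core::prelude::*;
    use polars_io::prelude::*;

    fn to_csv_bytes(df: &DataFrame) -> Result<Vec<u8>> {
        let mut buf = Vec::new();
        // finish() walks df.iter_chunks() and serializes each Arrow chunk;
        // the header row now comes from the column names, not an Arrow schema.
        CsvWriter::new(&mut buf).has_header(true).finish(df)?;
        Ok(buf)
    }
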
