Commit 2eb476f
add callback to lazy csv scanner so that column names can be modified
ritchie46 committed Nov 26, 2021
1 parent 4f921e2 commit 2eb476f
Showing 5 changed files with 135 additions and 3 deletions.
4 changes: 4 additions & 0 deletions polars/polars-core/src/datatypes.rs
@@ -575,6 +575,10 @@ impl Field {
self.data_type = dtype;
}

pub fn set_name(&mut self, name: String) {
self.name = name;
}

pub fn to_arrow(&self) -> ArrowField {
ArrowField::new(&self.name, self.data_type.to_arrow(), true)
}
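For reference, a minimal sketch of the new setter in isolation (the field name and dtype here are only illustrative; `Field::new` and `DataType::Int64` are assumed from the existing polars-core API):

    use polars::prelude::*;

    fn rename_field() {
        // build an ad-hoc field, then rename it in place with the new setter
        let mut fld = Field::new("BrEeZaH", DataType::Int64);
        fld.set_name("breezah".to_string());
        assert_eq!(fld.name(), "breezah");
    }

The py-polars binding further down uses this setter to rewrite the inferred CSV column names in place.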
28 changes: 28 additions & 0 deletions polars/polars-lazy/src/frame.rs
@@ -26,6 +26,10 @@ use crate::prelude::simplify_expr::SimplifyBooleanRule;
use crate::utils::{combine_predicates_expr, expr_to_root_column_names};
use crate::{logical_plan::FETCH_ROWS, prelude::*};
use polars_io::csv::NullValues;
#[cfg(feature = "csv-file")]
use polars_io::csv_core::utils::get_reader_bytes;
#[cfg(feature = "csv-file")]
use polars_io::csv_core::utils::infer_file_schema;

#[derive(Clone)]
#[cfg(feature = "csv-file")]
@@ -149,6 +153,30 @@ impl<'a> LazyCsvReader<'a> {
self
}

/// Modify a schema before we run the lazy scanning.
///
/// Important! Run this function last in the builder!
pub fn with_schema_modify<F>(mut self, f: F) -> Result<Self>
where
F: Fn(Schema) -> Result<Schema>,
{
let mut file = std::fs::File::open(&self.path)?;
let reader_bytes = get_reader_bytes(&mut file).expect("could not mmap file");

let (schema, _) = infer_file_schema(
&reader_bytes,
self.delimiter,
self.infer_schema_length,
self.has_header,
self.schema_overwrite,
&mut self.skip_rows,
self.comment_char,
self.quote_char,
)?;
let schema = f(schema)?;
Ok(self.with_schema(Arc::new(schema)))
}

pub fn finish(self) -> Result<LazyFrame> {
let mut lf: LazyFrame = LogicalPlanBuilder::scan_csv(
self.path,
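For reference, a minimal usage sketch of the new hook (hypothetical file "data.csv" with upper-cased headers). The hook opens the file and infers the schema eagerly, so it should be the last call in the builder, as the doc comment above warns:

    use polars::lazy::frame::{LazyCsvReader, LazyFrame};
    use polars::prelude::*;

    fn scan_with_lowercase_columns() -> Result<LazyFrame> {
        LazyCsvReader::new("data.csv".to_string())
            .has_header(true)
            // runs eagerly: infer the schema from the file, apply the closure,
            // and store the modified schema for the lazy scan
            .with_schema_modify(|mut schema| {
                for fld in schema.fields_mut().iter_mut() {
                    let lower = fld.name().to_lowercase();
                    fld.set_name(lower);
                }
                Ok(schema)
            })?
            .finish()
    }

The Python binding below builds exactly such a closure from a user-supplied `with_column_names` callable.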
2 changes: 2 additions & 0 deletions py-polars/polars/internals/lazy_frame.py
@@ -53,6 +53,7 @@ def scan_csv(
comment_char: Optional[str] = None,
quote_char: Optional[str] = r'"',
null_values: Optional[Union[str, tp.List[str], Dict[str, str]]] = None,
with_column_names: Optional[Callable[[tp.List[str]], tp.List[str]]] = None,
) -> "LazyFrame":
"""
See Also: `pl.scan_csv`
@@ -79,6 +80,7 @@
quote_char,
processed_null_values,
infer_schema_length,
with_column_names,
)
return self

41 changes: 41 additions & 0 deletions py-polars/polars/io.py
@@ -4,6 +4,7 @@
from typing import (
Any,
BinaryIO,
Callable,
ContextManager,
Dict,
Iterator,
@@ -395,6 +396,7 @@ def scan_csv(
comment_char: Optional[str] = None,
quote_char: Optional[str] = r'"',
null_values: Optional[Union[str, List[str], Dict[str, str]]] = None,
with_column_names: Optional[Callable[[List[str]], List[str]]] = None,
) -> LazyFrame:
"""
Lazily read from a csv file.
@@ -436,6 +438,44 @@
- str -> all values encountered equal to this string will be null
- List[str] -> A null value per column.
- Dict[str, str] -> A dictionary that maps column name to a null value string.
with_column_names
Apply a function over the column names. This can be used to update the schema just in time, i.e. right before scanning.

Examples
--------
>>> (pl.scan_csv("my_long_file.csv") # lazy, doesn't do a thing
>>> .select(["a", "c"]) # select only 2 columns (other columns will not be read)
>>> .filter(pl.col("a") > 10) # the filter is pushed down to the scan, so less data is read into memory
>>> .fetch(100) # push a limit of 100 rows down to the scan level
>>> )
>>> # we can use `with_column_names` to modify the header before scanning
>>> df = pl.DataFrame({
>>> "BrEeZaH": [1, 2, 3, 4],
>>> "LaNgUaGe": ["is", "terrible", "to", "read"]
>>> })
>>> df.to_csv("mydf.csv")
>>> (pl.scan_csv("mydf.csv",
>>> with_column_names=lambda cols: [col.lower() for col in cols])
>>> .fetch()
>>> )
shape: (4, 2)
┌─────────┬──────────┐
│ breezah ┆ language │
│ --- ┆ --- │
│ i64 ┆ str │
╞═════════╪══════════╡
│ 1 ┆ is │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤
│ 2 ┆ terrible │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤
│ 3 ┆ to │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌┤
│ 4 ┆ read │
└─────────┴──────────┘
"""
if isinstance(file, Path):
file = str(file)
@@ -453,6 +493,7 @@
quote_char=quote_char,
null_values=null_values,
infer_schema_length=infer_schema_length,
with_column_names=with_column_names,
)


63 changes: 60 additions & 3 deletions py-polars/src/lazy/dataframe.rs
@@ -8,6 +8,7 @@ use polars::lazy::frame::{AllowedOptimizations, LazyCsvReader, LazyFrame, LazyGr
use polars::lazy::prelude::col;
use polars::prelude::{DataFrame, Field, JoinType, Schema};
use pyo3::prelude::*;
use pyo3::types::PyList;

#[pyclass]
#[repr(transparent)]
@@ -81,6 +82,38 @@ impl From<LazyFrame> for PyLazyFrame {
}
}

// pub fn apply(&mut self, lambda: PyObject) -> PyLazyFrame {
// let lgb = self.lgb.take().unwrap();
//
// let function = move |df: DataFrame| {
// let gil = Python::acquire_gil();
// let py = gil.python();
// // get the pypolars module
// let pypolars = PyModule::import(py, "polars").unwrap();
//
// // create a PyDataFrame struct/object for Python
// let pydf = PyDataFrame::new(df);
//
// // Wrap this PyDataFrame object in the python side DataFrame wrapper
// let python_df_wrapper = pypolars.getattr("wrap_df").unwrap().call1((pydf,)).unwrap();
//
// // call the lambda and get a python side DataFrame wrapper
// let result_df_wrapper = match lambda.call1(py, (python_df_wrapper,)) {
// Ok(pyobj) => pyobj,
// Err(e) => panic!("UDF failed: {}", e.pvalue(py).to_string()),
// };
// // unpack the wrapper in a PyDataFrame
// let py_pydf = result_df_wrapper.getattr(py, "_df").expect(
// "Could net get DataFrame attribute '_df'. Make sure that you return a DataFrame object.",
// );
// // Downcast to Rust
// let pydf = py_pydf.extract::<PyDataFrame>(py).unwrap();
// // Finally get the actual DataFrame
// Ok(pydf.df)
// };
// lgb.apply(function).into()
// }

#[pymethods]
#[allow(clippy::should_implement_trait)]
impl PyLazyFrame {
@@ -100,6 +133,7 @@ impl PyLazyFrame {
quote_char: Option<&str>,
null_values: Option<Wrap<NullValues>>,
infer_schema_length: Option<usize>,
with_schema_modify: Option<PyObject>,
) -> PyResult<Self> {
let null_values = null_values.map(|w| w.0);
let comment_char = comment_char.map(|s| s.as_bytes()[0]);
@@ -117,8 +151,7 @@
.collect();
Schema::new(fields)
});

Ok(LazyCsvReader::new(path)
let mut r = LazyCsvReader::new(path)
.with_infer_schema_length(infer_schema_length)
.with_delimiter(delimiter)
.has_header(has_header)
@@ -130,7 +163,31 @@
.low_memory(low_memory)
.with_comment_char(comment_char)
.with_quote_char(quote_char)
.with_null_values(null_values)
.with_null_values(null_values);

if let Some(lambda) = with_schema_modify {
let f = |mut schema: Schema| {
let gil = Python::acquire_gil();
let py = gil.python();

let iter = schema.fields().iter().map(|fld| fld.name().as_str());
let names = PyList::new(py, iter);

let out = lambda.call1(py, (names,)).expect("python function failed");
let new_names = out.extract::<Vec<String>>(py).expect("python function should return List[str]");
assert_eq!(new_names.len(), schema.fields().len(), "The length of the new names list should be equal to the original column length");

schema.fields_mut().iter_mut().zip(new_names).for_each(|(fld, new_name)| {
fld.set_name(new_name)
});

Ok(schema)
};
r = r.with_schema_modify(f).map_err(PyPolarsEr::from)?
}

Ok(r.finish().map_err(PyPolarsEr::from)?.into())