-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add python rust compilation example (#2826)
* start with examples * add example showing custom compiled python functions
- Loading branch information
Showing
22 changed files
with
355 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[package] | ||
name = "python_rust_compiled_function" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
[lib] | ||
name = "my_polars_functions" | ||
crate-type = ["cdylib"] | ||
|
||
[dependencies] | ||
polars = { path = "../../polars" } | ||
polars-arrow = { path = "../../polars/polars-arrow" } | ||
pyo3 = "0.16" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Compile custom Rust functions and use them in Python Polars | ||
|
||
## Compile a development binary in your current environment | ||
`$ pip install -U maturin && maturin develop` | ||
|
||
## Run | ||
`$ python example.py` | ||
|
||
|
||
## Compile a **release** build | ||
`$ maturin develop --release` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Example: call the compiled Rust `hamming_distance` function on two polars Series.
import polars as pl
from my_polars_functions import hamming_distance

a = pl.Series("a", ["foo", "bar"])
b = pl.Series("b", ["fooy", "ham"])

# "foo" vs "fooy" differ in length, so that row is expected to be null;
# "bar" vs "ham" differ in two positions.
dist = hamming_distance(a, b)
expected = pl.Series("", [None, 2], dtype=pl.UInt32)

# print the value that is asserted on instead of recomputing it
print(dist)
assert dist.series_equal(expected, null_equal=True)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
[build-system] | ||
requires = ["maturin>=0.12,<0.13"] | ||
build-backend = "maturin" | ||
|
||
[project] | ||
name = "my_polars_functions" | ||
version = "0.1.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
use arrow::{array::ArrayRef, ffi}; | ||
use polars::prelude::*; | ||
use polars_arrow::export::arrow; | ||
use pyo3::exceptions::PyValueError; | ||
use pyo3::prelude::*; | ||
use pyo3::{ffi::Py_uintptr_t, PyAny, PyObject, PyResult}; | ||
|
||
/// Take an arrow array from python and convert it to a rust arrow array. | ||
/// This operation does not copy data. | ||
fn array_to_rust(arrow_array: &PyAny) -> PyResult<ArrayRef> { | ||
// prepare a pointer to receive the Array struct | ||
let array = Box::new(ffi::ArrowArray::empty()); | ||
let schema = Box::new(ffi::ArrowSchema::empty()); | ||
|
||
let array_ptr = &*array as *const ffi::ArrowArray; | ||
let schema_ptr = &*schema as *const ffi::ArrowSchema; | ||
|
||
// make the conversion through PyArrow's private API | ||
// this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds | ||
arrow_array.call_method1( | ||
"_export_to_c", | ||
(array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t), | ||
)?; | ||
|
||
unsafe { | ||
let field = ffi::import_field_from_c(schema.as_ref()).unwrap(); | ||
let array = ffi::import_array_from_c(array, field.data_type).unwrap(); | ||
Ok(array.into()) | ||
} | ||
} | ||
|
||
/// Arrow array to Python. | ||
pub(crate) fn to_py_array(py: Python, pyarrow: &PyModule, array: ArrayRef) -> PyResult<PyObject> { | ||
let array_ptr = Box::new(ffi::ArrowArray::empty()); | ||
let schema_ptr = Box::new(ffi::ArrowSchema::empty()); | ||
|
||
let array_ptr = Box::into_raw(array_ptr); | ||
let schema_ptr = Box::into_raw(schema_ptr); | ||
|
||
unsafe { | ||
ffi::export_field_to_c( | ||
&ArrowField::new("", array.data_type().clone(), true), | ||
schema_ptr, | ||
); | ||
ffi::export_array_to_c(array, array_ptr); | ||
}; | ||
|
||
let array = pyarrow.getattr("Array")?.call_method1( | ||
"_import_from_c", | ||
(array_ptr as Py_uintptr_t, schema_ptr as Py_uintptr_t), | ||
)?; | ||
|
||
unsafe { | ||
Box::from_raw(array_ptr); | ||
Box::from_raw(schema_ptr); | ||
}; | ||
|
||
Ok(array.to_object(py)) | ||
} | ||
|
||
pub fn py_series_to_rust_series(series: &PyAny) -> PyResult<Series> { | ||
// rechunk series so that they have a single arrow array | ||
let series = series.call_method0("rechunk")?; | ||
|
||
let name = series.getattr("name")?.extract::<String>()?; | ||
|
||
// retrieve pyarrow array | ||
let array = series.call_method0("to_arrow")?; | ||
|
||
// retrieve rust arrow array | ||
let array = array_to_rust(array)?; | ||
|
||
Series::try_from((name.as_str(), array)).map_err(|e| PyValueError::new_err(format!("{}", e))) | ||
} | ||
|
||
pub fn rust_series_to_py_series(series: &Series) -> PyResult<PyObject> { | ||
// ensure we have a single chunk | ||
let series = series.rechunk(); | ||
let array = series.to_arrow(0); | ||
|
||
// acquire the gil | ||
let gil = Python::acquire_gil(); | ||
let py = gil.python(); | ||
// import pyarrow | ||
let pyarrow = py.import("pyarrow")?; | ||
|
||
// pyarrow array | ||
let pyarrow_array = to_py_array(py, pyarrow, array)?; | ||
|
||
// import polars | ||
let polars = py.import("polars")?; | ||
let out = polars.call_method1("from_arrow", (pyarrow_array,))?; | ||
Ok(out.to_object(py)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
mod ffi; | ||
|
||
use polars::prelude::*; | ||
use pyo3::exceptions::PyValueError; | ||
use pyo3::prelude::*; | ||
|
||
#[pyfunction] | ||
fn hamming_distance(series_a: &PyAny, series_b: &PyAny) -> PyResult<PyObject> { | ||
let series_a = ffi::py_series_to_rust_series(series_a)?; | ||
let series_b = ffi::py_series_to_rust_series(series_b)?; | ||
|
||
let out = hamming_distance_impl(&series_a, &series_b) | ||
.map_err(|e| PyValueError::new_err(format!("Something went wrong: {:?}", e)))?; | ||
ffi::rust_series_to_py_series(&out.into_series()) | ||
} | ||
|
||
/// This function iterates over 2 `Utf8Chunked` arrays and computes the hamming distance between the values . | ||
fn hamming_distance_impl(a: &Series, b: &Series) -> Result<UInt32Chunked> { | ||
Ok(a.utf8()? | ||
.into_iter() | ||
.zip(b.utf8()?.into_iter()) | ||
.map(|(lhs, rhs)| hamming_distance_strs(lhs, rhs)) | ||
.collect()) | ||
} | ||
|
||
/// Compute the hamming distance between 2 string values.
///
/// The distance is the number of positions at which the two strings hold a
/// different character. Length and positions are both measured in `char`s,
/// so multi-byte characters count as a single position.
///
/// Returns `None` when either value is missing or when the strings do not
/// have the same number of characters.
fn hamming_distance_strs(a: Option<&str>, b: Option<&str>) -> Option<u32> {
    let (a, b) = (a?, b?);

    let mut lhs = a.chars();
    let mut rhs = b.chars();
    let mut distance = 0u32;

    // Single pass: count mismatches and detect a length difference at the
    // same time, without relying on byte lengths.
    loop {
        match (lhs.next(), rhs.next()) {
            (Some(x), Some(y)) => distance += u32::from(x != y),
            (None, None) => return Some(distance),
            // one string ran out before the other: lengths differ
            _ => return None,
        }
    }
}
|
||
#[pymodule] | ||
fn my_polars_functions(_py: Python, m: &PyModule) -> PyResult<()> { | ||
m.add_wrapped(wrap_pyfunction!(hamming_distance)).unwrap(); | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
[package] | ||
name = "read_csv" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[features] | ||
write_output = ["polars/ipc", "polars/parquet"] | ||
default = ["write_output"] | ||
|
||
[dependencies] | ||
polars = { path = "../../polars", features = ["lazy", "csv-file", "pretty_fmt"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use polars::prelude::*; | ||
|
||
fn main() -> Result<()> { | ||
let mut df = LazyCsvReader::new("../datasets/foods1.csv".into()) | ||
.finish()? | ||
.select([ | ||
// select all columns | ||
all(), | ||
// and do some aggregations | ||
cols(["fats_g", "sugars_g"]).sum().suffix("_summed"), | ||
]) | ||
.collect()?; | ||
|
||
dbg!(&df); | ||
|
||
write_other_formats(&mut df)?; | ||
Ok(()) | ||
} | ||
|
||
fn write_other_formats(df: &mut DataFrame) -> Result<()> { | ||
let parquet_out = "../datasets/foods1.parquet"; | ||
if std::fs::metadata(&parquet_out).is_err() { | ||
let f = std::fs::File::create(&parquet_out).unwrap(); | ||
ParquetWriter::new(f).with_statistics(true).finish(df)?; | ||
} | ||
let ipc_out = "../datasets/foods1.ipc"; | ||
if std::fs::metadata(&ipc_out).is_err() { | ||
let f = std::fs::File::create(&ipc_out).unwrap(); | ||
IpcWriter::new(f).finish(df)? | ||
} | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
[package] | ||
name = "read_parquet" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] | ||
polars = { path = "../../polars", features = ["lazy", "parquet", "pretty_fmt"] } |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
use polars::prelude::*; | ||
|
||
fn main() -> Result<()> { | ||
let df = LazyFrame::scan_parquet( | ||
"../datasets/foods1.parquet".into(), | ||
ScanArgsParquet::default(), | ||
)? | ||
.select([ | ||
// select all columns | ||
all(), | ||
// and do some aggregations | ||
cols(["fats_g", "sugars_g"]).sum().suffix("_summed"), | ||
]) | ||
.collect()?; | ||
|
||
dbg!(df); | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.