Skip to content

Commit

Permalink
chore[python]: Standardize conversion of Python string literals to Ru…
Browse files Browse the repository at this point in the history
…st enums [Part 2] (#4394)
  • Loading branch information
stinodego committed Aug 13, 2022
1 parent b14e6fb commit 0f912ba
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 197 deletions.
208 changes: 191 additions & 17 deletions py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ use pyo3::{PyAny, PyResult};
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};

#[cfg(feature = "avro")]
use polars::io::avro::AvroCompression;
#[cfg(feature = "ipc")]
use polars::io::ipc::IpcCompression;

pub(crate) fn slice_to_wrapped<T>(slice: &[T]) -> &[Wrap<T>] {
// Safety:
// Wrap is transparent.
Expand Down Expand Up @@ -724,23 +729,55 @@ pub(crate) fn dicts_to_rows(records: &PyAny) -> PyResult<(Vec<Row>, Vec<String>)
Ok((rows, keys_first))
}

pub(crate) fn parse_strategy(strat: &str, limit: FillNullLimit) -> PyResult<FillNullStrategy> {
let strat = match strat {
"forward" => FillNullStrategy::Forward(limit),
"backward" => FillNullStrategy::Backward(limit),
"min" => FillNullStrategy::Min,
"max" => FillNullStrategy::Max,
"mean" => FillNullStrategy::Mean,
"zero" => FillNullStrategy::Zero,
"one" => FillNullStrategy::One,
e => {
return Err(PyValueError::new_err(format!(
"strategy must be one of {{'forward', 'backward', 'min', 'max', 'mean', 'zero', 'one'}}, got {}",
e,
)))
}
};
Ok(strat)
#[cfg(feature = "asof_join")]
impl FromPyObject<'_> for Wrap<AsofStrategy> {
fn extract(ob: &PyAny) -> PyResult<Self> {
let parsed = match ob.extract::<&str>()? {
"backward" => AsofStrategy::Backward,
"forward" => AsofStrategy::Forward,
v => {
return Err(PyValueError::new_err(format!(
"strategy must be one of {{'backward', 'forward'}}, got {}",
v
)))
}
};
Ok(Wrap(parsed))
}
}

#[cfg(feature = "avro")]
impl FromPyObject<'_> for Wrap<Option<AvroCompression>> {
fn extract(ob: &PyAny) -> PyResult<Self> {
let parsed = match ob.extract::<&str>()? {
"uncompressed" => None,
"snappy" => Some(AvroCompression::Snappy),
"deflate" => Some(AvroCompression::Deflate),
v => {
return Err(PyValueError::new_err(format!(
"compression must be one of {{'uncompressed', 'snappy', 'deflate'}}, got {}",
v
)))
}
};
Ok(Wrap(parsed))
}
}

impl FromPyObject<'_> for Wrap<CategoricalOrdering> {
fn extract(ob: &PyAny) -> PyResult<Self> {
let parsed = match ob.extract::<&str>()? {
"physical" => CategoricalOrdering::Physical,
"lexical" => CategoricalOrdering::Lexical,
v => {
return Err(PyValueError::new_err(format!(
"ordering must be one of {{'physical', 'lexical'}}, got {}",
v
)))
}
};
Ok(Wrap(parsed))
}
}

impl FromPyObject<'_> for Wrap<ClosedWindow> {
Expand All @@ -761,6 +798,77 @@ impl FromPyObject<'_> for Wrap<ClosedWindow> {
}
}

impl FromPyObject<'_> for Wrap<CsvEncoding> {
fn extract(ob: &PyAny) -> PyResult<Self> {
let parsed = match ob.extract::<&str>()? {
"utf8" => CsvEncoding::Utf8,
"utf8-lossy" => CsvEncoding::LossyUtf8,
v => {
return Err(PyValueError::new_err(format!(
"encoding must be one of {{'utf8', 'utf8-lossy'}}, got {}",
v
)))
}
};
Ok(Wrap(parsed))
}
}

#[cfg(feature = "ipc")]
impl FromPyObject<'_> for Wrap<Option<IpcCompression>> {
fn extract(ob: &PyAny) -> PyResult<Self> {
let parsed = match ob.extract::<&str>()? {
"uncompressed" => None,
"lz4" => Some(IpcCompression::LZ4),
"zstd" => Some(IpcCompression::ZSTD),
v => {
return Err(PyValueError::new_err(format!(
"compression must be one of {{'uncompressed', 'lz4', 'zstd'}}, got {}",
v
)))
}
};
Ok(Wrap(parsed))
}
}

impl FromPyObject<'_> for Wrap<JoinType> {
fn extract(ob: &PyAny) -> PyResult<Self> {
let parsed = match ob.extract::<&str>()? {
"inner" => JoinType::Inner,
"left" => JoinType::Left,
"outer" => JoinType::Outer,
"semi" => JoinType::Semi,
"anti" => JoinType::Anti,
#[cfg(feature = "cross_join")]
"cross" => JoinType::Cross,
v => {
return Err(PyValueError::new_err(format!(
"how must be one of {{'inner', 'left', 'outer', 'semi', 'anti', 'cross'}}, got {}",
v
)))
}
};
Ok(Wrap(parsed))
}
}

impl FromPyObject<'_> for Wrap<ListToStructWidthStrategy> {
fn extract(ob: &PyAny) -> PyResult<Self> {
let parsed = match ob.extract::<&str>()? {
"first_non_null" => ListToStructWidthStrategy::FirstNonNull,
"max_width" => ListToStructWidthStrategy::MaxWidth,
v => {
return Err(PyValueError::new_err(format!(
"n_field_strategy must be one of {{'first_non_null', 'max_width'}}, got {}",
v
)))
}
};
Ok(Wrap(parsed))
}
}

impl FromPyObject<'_> for Wrap<NullBehavior> {
fn extract(ob: &PyAny) -> PyResult<Self> {
let parsed = match ob.extract::<&str>()? {
Expand Down Expand Up @@ -905,3 +1013,69 @@ impl FromPyObject<'_> for Wrap<UniqueKeepStrategy> {
Ok(Wrap(parsed))
}
}

pub(crate) fn parse_fill_null_strategy(
strategy: &str,
limit: FillNullLimit,
) -> PyResult<FillNullStrategy> {
let parsed = match strategy {
"forward" => FillNullStrategy::Forward(limit),
"backward" => FillNullStrategy::Backward(limit),
"min" => FillNullStrategy::Min,
"max" => FillNullStrategy::Max,
"mean" => FillNullStrategy::Mean,
"zero" => FillNullStrategy::Zero,
"one" => FillNullStrategy::One,
e => {
return Err(PyValueError::new_err(format!(
"strategy must be one of {{'forward', 'backward', 'min', 'max', 'mean', 'zero', 'one'}}, got {}",
e,
)))
}
};
Ok(parsed)
}

#[cfg(feature = "parquet")]
pub(crate) fn parse_parquet_compression(
compression: &str,
compression_level: Option<i32>,
) -> PyResult<ParquetCompression> {
let parsed = match compression {
"uncompressed" => ParquetCompression::Uncompressed,
"snappy" => ParquetCompression::Snappy,
"gzip" => ParquetCompression::Gzip(
compression_level
.map(|lvl| {
GzipLevel::try_new(lvl as u8)
.map_err(|e| PyValueError::new_err(format!("{:?}", e)))
})
.transpose()?,
),
"lzo" => ParquetCompression::Lzo,
"brotli" => ParquetCompression::Brotli(
compression_level
.map(|lvl| {
BrotliLevel::try_new(lvl as u32)
.map_err(|e| PyValueError::new_err(format!("{:?}", e)))
})
.transpose()?,
),
"lz4" => ParquetCompression::Lz4Raw,
"zstd" => ParquetCompression::Zstd(
compression_level
.map(|lvl| {
ZstdLevel::try_new(lvl as i32)
.map_err(|e| PyValueError::new_err(format!("{:?}", e)))
})
.transpose()?,
),
e => {
return Err(PyValueError::new_err(format!(
"compression must be one of {{'uncompressed', 'snappy', 'gzip', 'lzo', 'brotli', 'lz4', 'zstd'}}, got {}",
e
)))
}
};
Ok(parsed)
}

0 comments on commit 0f912ba

Please sign in to comment.