Skip to content

Commit

Permalink
feat[rust, python]: add float_precision parameter to `DataFrame.wri…
Browse files Browse the repository at this point in the history
…te_csv` (#4504)
  • Loading branch information
matteosantama committed Aug 19, 2022
1 parent 9376b9b commit cc9667c
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 27 deletions.
10 changes: 9 additions & 1 deletion polars/polars-io/src/csv/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,22 @@ where
self
}

/// Set the CSV file's timestamp format array
/// Set the CSV file's datetime format
pub fn with_datetime_format(mut self, format: Option<String>) -> Self {
if format.is_some() {
self.options.datetime_format = format;
}
self
}

/// Set the CSV file's float precision
pub fn with_float_precision(mut self, precision: Option<usize>) -> Self {
if precision.is_some() {
self.options.float_precision = precision;
}
self
}

/// Set the single byte character used for quoting
pub fn with_quoting_char(mut self, char: u8) -> Self {
self.options.quote = char;
Expand Down
50 changes: 24 additions & 26 deletions polars/polars-io/src/csv/write_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ fn fmt_and_escape_str(f: &mut Vec<u8>, v: &str, options: &SerializeOptions) -> s
}
}

fn fast_float_write<N: ToLexical>(f: &mut Vec<u8>, n: N, write_size: usize) -> std::io::Result<()> {
let len = f.len();
f.reserve(write_size);
unsafe {
let buffer = std::slice::from_raw_parts_mut(f.as_mut_ptr().add(len), write_size);
let written_n = n.to_lexical(buffer).len();
f.set_len(len + written_n);
}
Ok(())
}

fn write_anyvalue(f: &mut Vec<u8>, value: AnyValue, options: &SerializeOptions) {
match value {
AnyValue::Null => write!(f, ""),
Expand All @@ -45,30 +56,14 @@ fn write_anyvalue(f: &mut Vec<u8>, value: AnyValue, options: &SerializeOptions)
AnyValue::UInt16(v) => write!(f, "{}", v),
AnyValue::UInt32(v) => write!(f, "{}", v),
AnyValue::UInt64(v) => write!(f, "{}", v),
AnyValue::Float32(v) => {
let len = f.len();
let write_size = f32::FORMATTED_SIZE_DECIMAL;
f.reserve(write_size);
unsafe {
let buf = std::slice::from_raw_parts_mut(f.as_mut_ptr().add(len), write_size);

let written_n = v.to_lexical(buf).len();
f.set_len(len + written_n);
}
Ok(())
}
AnyValue::Float64(v) => {
let len = f.len();
let write_size = f64::FORMATTED_SIZE_DECIMAL;
f.reserve(write_size);
unsafe {
let buf = std::slice::from_raw_parts_mut(f.as_mut_ptr().add(len), write_size);

let written_n = v.to_lexical(buf).len();
f.set_len(len + written_n);
}
Ok(())
}
AnyValue::Float32(v) => match &options.float_precision {
None => fast_float_write(f, v, f32::FORMATTED_SIZE_DECIMAL),
Some(precision) => write!(f, "{v:.precision$}", v = v, precision = precision),
},
AnyValue::Float64(v) => match &options.float_precision {
None => fast_float_write(f, v, f64::FORMATTED_SIZE_DECIMAL),
Some(precision) => write!(f, "{v:.precision$}", v = v, precision = precision),
},
AnyValue::Boolean(v) => write!(f, "{}", v),
AnyValue::Utf8(v) => fmt_and_escape_str(f, v, options),
#[cfg(feature = "dtype-categorical")]
Expand Down Expand Up @@ -126,10 +121,12 @@ fn write_anyvalue(f: &mut Vec<u8>, value: AnyValue, options: &SerializeOptions)
pub struct SerializeOptions {
/// used for [`DataType::Date`]
pub date_format: Option<String>,
/// used for [`DataType::Time64`]
/// used for [`DataType::Time`]
pub time_format: Option<String>,
/// used for [`DataType::Timestamp`]
/// used for [`DataType::Datetime]
pub datetime_format: Option<String>,
/// used for [`DataType::Float64`] and [`DataType::Float32`]
pub float_precision: Option<usize>,
/// used as separator/delimiter
pub delimiter: u8,
/// quoting character
Expand All @@ -142,6 +139,7 @@ impl Default for SerializeOptions {
date_format: None,
time_format: None,
datetime_format: None,
float_precision: None,
delimiter: b',',
quote: b'"',
}
Expand Down
6 changes: 6 additions & 0 deletions py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,7 @@ def write_csv(
datetime_format: str | None = None,
date_format: str | None = None,
time_format: str | None = None,
float_precision: int | None = None,
) -> str | None:
"""
Write to comma-separated values (CSV) file.
Expand Down Expand Up @@ -1117,6 +1118,9 @@ def write_csv(
A format string, with the specifiers defined by the
`chrono <https://docs.rs/chrono/latest/chrono/format/strftime/index.html>`_
Rust crate.
float_precision
Number of decimal places to write, applied to both ``Float32`` and
``Float64`` datatypes.
Examples
--------
Expand Down Expand Up @@ -1148,6 +1152,7 @@ def write_csv(
datetime_format,
date_format,
time_format,
float_precision,
)
return str(buffer.getvalue(), encoding="utf-8")

Expand All @@ -1163,6 +1168,7 @@ def write_csv(
datetime_format,
date_format,
time_format,
float_precision,
)
return None

Expand Down
3 changes: 3 additions & 0 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ impl PyDataFrame {
datetime_format: Option<String>,
date_format: Option<String>,
time_format: Option<String>,
float_precision: Option<usize>,
) -> PyResult<()> {
if let Ok(s) = py_f.extract::<&str>(py) {
let f = std::fs::File::create(s).unwrap();
Expand All @@ -442,6 +443,7 @@ impl PyDataFrame {
.with_datetime_format(datetime_format)
.with_date_format(date_format)
.with_time_format(time_format)
.with_float_precision(float_precision)
.finish(&mut self.df)
.map_err(PyPolarsErr::from)?;
} else {
Expand All @@ -454,6 +456,7 @@ impl PyDataFrame {
.with_datetime_format(datetime_format)
.with_date_format(date_format)
.with_time_format(time_format)
.with_float_precision(float_precision)
.finish(&mut self.df)
.map_err(PyPolarsErr::from)?;
}
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,3 +730,14 @@ def test_time_format(fmt: str, expected: str) -> None:
df = pl.DataFrame({"dt": [time(16, 15, 30)]})
csv = df.write_csv(time_format=fmt)
assert csv == expected


@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_float_precision(dtype: pl.Float32 | pl.Float64) -> None:
df = pl.Series("col", [1.0, 2.2, 3.33], dtype=dtype).to_frame()

assert df.write_csv(float_precision=None) == "col\n1.0\n2.2\n3.33\n"
assert df.write_csv(float_precision=0) == "col\n1\n2\n3\n"
assert df.write_csv(float_precision=1) == "col\n1.0\n2.2\n3.3\n"
assert df.write_csv(float_precision=2) == "col\n1.00\n2.20\n3.33\n"
assert df.write_csv(float_precision=3) == "col\n1.000\n2.200\n3.330\n"

0 comments on commit cc9667c

Please sign in to comment.