Skip to content

Commit

Permalink
feat[rust, python]: if datetime format not specified for "write_csv",…
Browse files Browse the repository at this point in the history
… infer the required precision (#4724)
  • Loading branch information
alexander-beedie committed Sep 4, 2022
1 parent a56ea7c commit 3620e06
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 5 deletions.
3 changes: 1 addition & 2 deletions polars/polars-io/src/csv/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ where
// 9f: all nanoseconds
let options = write_impl::SerializeOptions {
time_format: Some("%T%.9f".to_string()),
datetime_format: Some("%FT%H:%M:%S.%9f".to_string()),
..Default::default()
};

Expand All @@ -38,7 +37,7 @@ where
if self.header {
write_impl::write_header(&mut self.buffer, &names, &self.options)?;
}
write_impl::write(&mut self.buffer, df, self.batch_size, &self.options)
write_impl::write(&mut self.buffer, df, self.batch_size, &mut self.options)
}
}

Expand Down
24 changes: 23 additions & 1 deletion polars/polars-io/src/csv/write_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,35 @@ pub(crate) fn write<W: Write>(
writer: &mut W,
df: &DataFrame,
chunk_size: usize,
options: &SerializeOptions,
options: &mut SerializeOptions,
) -> Result<()> {
// check that the double quote is valid utf8
std::str::from_utf8(&[options.quote, options.quote])
.map_err(|_| PolarsError::ComputeError("quote char leads invalid utf8".into()))?;
let delimiter = char::from(options.delimiter);

// if datetime format not specified, infer the maximum required precision
if options.datetime_format.is_none() {
for col in df.get_columns() {
match col.dtype() {
DataType::Datetime(TimeUnit::Microseconds, _)
if options.datetime_format.is_none() =>
{
options.datetime_format = Some("%FT%H:%M:%S.%6f".to_string());
}
DataType::Datetime(TimeUnit::Nanoseconds, _) => {
options.datetime_format = Some("%FT%H:%M:%S.%9f".to_string());
break; // highest precision; no need to check further
}
_ => {}
}
}
// if still not set, no cols require higher precision than "ms" (or no datetime cols)
if options.datetime_format.is_none() {
options.datetime_format = Some("%FT%H:%M:%S.%3f".to_string());
}
}

let len = df.height();
let n_threads = POOL.current_num_threads();
let total_rows_per_pool_iter = n_threads * chunk_size;
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,7 +1814,9 @@ def write_csv(
datetime_format
A format string, with the specifiers defined by the
`chrono <https://docs.rs/chrono/latest/chrono/format/strftime/index.html>`_
Rust crate.
Rust crate. If no format is specified, the default fractional-second
precision is inferred from the maximum timeunit found in the frame's
Datetime columns (if any).
date_format
A format string, with the specifiers defined by the
`chrono <https://docs.rs/chrono/latest/chrono/format/strftime/index.html>`_
Expand Down
1 change: 1 addition & 0 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ impl PyDataFrame {
null_value: Option<String>,
) -> PyResult<()> {
let null = null_value.unwrap_or(String::new());

if let Ok(s) = py_f.extract::<&str>(py) {
let f = std::fs::File::create(s).unwrap();
// no need for a buffered writer, because the csv writer does internal buffering
Expand Down
44 changes: 43 additions & 1 deletion py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import polars as pl
from polars import DataType
from polars.internals.type_aliases import TimeUnit
from polars.testing import assert_frame_equal_local_categoricals


Expand Down Expand Up @@ -692,7 +693,8 @@ def test_csv_dtype_overwrite_bool() -> None:
@pytest.mark.parametrize(
"fmt,expected",
[
(None, "dt\n2022-01-02T00:00:00.000000000\n"),
(None, "dt\n2022-01-02T00:00:00.000000\n"),
("%F %T%.3f", "dt\n2022-01-02 00:00:00.000\n"),
("%Y", "dt\n2022\n"),
("%m", "dt\n01\n"),
("%m$%d", "dt\n01$02\n"),
Expand All @@ -705,6 +707,46 @@ def test_datetime_format(fmt: str, expected: str) -> None:
assert csv == expected


@pytest.mark.parametrize(
    "tu1,tu2,expected",
    [
        # any nanosecond column -> both columns rendered with 9 fractional digits
        (
            "ns",
            "ns",
            "x,y\n2022-09-04T10:30:45.123000000,2022-09-04T10:30:45.123000000\n",
        ),
        (
            "ns",
            "us",
            "x,y\n2022-09-04T10:30:45.123000000,2022-09-04T10:30:45.123000000\n",
        ),
        (
            "ns",
            "ms",
            "x,y\n2022-09-04T10:30:45.123000000,2022-09-04T10:30:45.123000000\n",
        ),
        # highest unit is microseconds -> 6 fractional digits
        ("us", "us", "x,y\n2022-09-04T10:30:45.123000,2022-09-04T10:30:45.123000\n"),
        ("us", "ms", "x,y\n2022-09-04T10:30:45.123000,2022-09-04T10:30:45.123000\n"),
        ("ms", "us", "x,y\n2022-09-04T10:30:45.123000,2022-09-04T10:30:45.123000\n"),
        # milliseconds only -> 3 fractional digits
        ("ms", "ms", "x,y\n2022-09-04T10:30:45.123,2022-09-04T10:30:45.123\n"),
    ],
)
def test_datetime_format_inferred_precision(
    tu1: TimeUnit, tu2: TimeUnit, expected: str
) -> None:
    """Check the CSV output when ``write_csv`` is called without a datetime format.

    Two Datetime columns ("x" with time unit ``tu1``, "y" with ``tu2``) hold the
    same timestamp; the frame is serialised with no explicit ``datetime_format``
    and the resulting text must equal ``expected``.
    """
    ts = datetime(2022, 9, 4, 10, 30, 45, 123000)
    frame = pl.DataFrame(
        data={"x": [ts], "y": [ts]},
        columns=[
            ("x", pl.Datetime(tu1)),
            ("y", pl.Datetime(tu2)),
        ],
    )
    csv_text = frame.write_csv()
    assert csv_text == expected


@pytest.mark.parametrize(
"fmt,expected",
[
Expand Down

0 comments on commit 3620e06

Please sign in to comment.