Skip to content

Commit

Permalink
feat[python]: Expose format parameters to DataFrame.write_csv (#4403)
Browse files Browse the repository at this point in the history
  • Loading branch information
matteosantama committed Aug 14, 2022
1 parent 4f48bb5 commit 5ac1b7b
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 12 deletions.
16 changes: 11 additions & 5 deletions polars/polars-io/src/csv/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,25 @@ where

/// Set the CSV file's date format
pub fn with_date_format(mut self, format: Option<String>) -> Self {
self.options.date_format = format;
if format.is_some() {
self.options.date_format = format;
}
self
}

/// Set the CSV file's time format
pub fn with_time_format(mut self, format: Option<String>) -> Self {
self.options.time_format = format;
if format.is_some() {
self.options.time_format = format;
}
self
}

/// Set the CSV file's timestamp format array in
pub fn with_datetime(mut self, format: Option<String>) -> Self {
self.options.datetime_format = format;
/// Set the CSV file's datetime format
pub fn with_datetime_format(mut self, format: Option<String>) -> Self {
if format.is_some() {
self.options.datetime_format = format;
}
self
}

Expand Down
41 changes: 37 additions & 4 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,9 @@ def write_csv(
sep: str = ",",
quote: str = '"',
batch_size: int = 1024,
datetime_format: str | None = None,
date_format: str | None = None,
time_format: str | None = None,
) -> str | None:
"""
Write to comma-separated values (CSV) file.
Expand All @@ -1090,9 +1093,21 @@ def write_csv(
sep
Separate CSV fields with this symbol.
quote
byte to use as quoting character
Byte to use as quoting character.
batch_size
rows that will be processed per thread
Number of rows that will be processed per thread.
datetime_format
A format string, with the specifiers defined by the
`chrono <https://docs.rs/chrono/latest/chrono/format/strftime/index.html>`_
Rust crate.
date_format
A format string, with the specifiers defined by the
`chrono <https://docs.rs/chrono/latest/chrono/format/strftime/index.html>`_
Rust crate.
time_format
A format string, with the specifiers defined by the
`chrono <https://docs.rs/chrono/latest/chrono/format/strftime/index.html>`_
Rust crate.
Examples
--------
Expand All @@ -1115,13 +1130,31 @@ def write_csv(
raise ValueError("only single byte quote char is allowed")
if file is None:
buffer = BytesIO()
self._df.write_csv(buffer, has_header, ord(sep), ord(quote), batch_size)
self._df.write_csv(
buffer,
has_header,
ord(sep),
ord(quote),
batch_size,
datetime_format,
date_format,
time_format,
)
return str(buffer.getvalue(), encoding="utf-8")

if isinstance(file, (str, Path)):
file = format_path(file)

self._df.write_csv(file, has_header, ord(sep), ord(quote), batch_size)
self._df.write_csv(
file,
has_header,
ord(sep),
ord(quote),
batch_size,
datetime_format,
date_format,
time_format,
)
return None

def write_avro(
Expand Down
9 changes: 9 additions & 0 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,9 @@ impl PyDataFrame {
sep: u8,
quote: u8,
batch_size: usize,
datetime_format: Option<String>,
date_format: Option<String>,
time_format: Option<String>,
) -> PyResult<()> {
if let Ok(s) = py_f.extract::<&str>(py) {
let f = std::fs::File::create(s).unwrap();
Expand All @@ -435,6 +438,9 @@ impl PyDataFrame {
.with_delimiter(sep)
.with_quoting_char(quote)
.with_batch_size(batch_size)
.with_datetime_format(datetime_format)
.with_date_format(date_format)
.with_time_format(time_format)
.finish(&mut self.df)
.map_err(PyPolarsErr::from)?;
} else {
Expand All @@ -444,6 +450,9 @@ impl PyDataFrame {
.with_delimiter(sep)
.with_quoting_char(quote)
.with_batch_size(batch_size)
.with_datetime_format(datetime_format)
.with_date_format(date_format)
.with_time_format(time_format)
.finish(&mut self.df)
.map_err(PyPolarsErr::from)?;
}
Expand Down
50 changes: 47 additions & 3 deletions py-polars/tests/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import textwrap
import zlib
from datetime import date
from datetime import date, datetime, time
from pathlib import Path

import pytest
Expand Down Expand Up @@ -482,7 +482,7 @@ def test_escaped_null_values() -> None:
assert df[0, "c"] is None


def quoting_round_trip() -> None:
def test_quoting_round_trip() -> None:
f = io.BytesIO()
df = pl.DataFrame(
{
Expand All @@ -500,7 +500,7 @@ def quoting_round_trip() -> None:
assert read_df.frame_equal(df)


def fallback_chrono_parser() -> None:
def test_fallback_chrono_parser() -> None:
data = textwrap.dedent(
"""\
date_1,date_2
Expand Down Expand Up @@ -616,3 +616,47 @@ def test_csv_dtype_overwrite_bool() -> None:
dtypes={"a": pl.Boolean, "b": pl.Boolean},
)
assert df.dtypes == [pl.Boolean, pl.Boolean]


# Each case pairs a ``datetime_format`` value with the exact CSV text expected
# for a frame holding the single datetime 2022-01-02 00:00:00 in column "dt".
# The ``None`` case pins the output produced when no format is supplied.
# NOTE(review): ``fmt`` is also parametrized with ``None``, so the ``str``
# annotation below looks too narrow (``str | None``?) -- confirm.
@pytest.mark.parametrize(
"fmt,expected",
[
(None, "dt\n2022-01-02T00:00:00.000000000\n"),
("%Y", "dt\n2022\n"),
("%m", "dt\n01\n"),
("%m$%d", "dt\n01$02\n"),
("%R", "dt\n00:00\n"),
],
)
def test_datetime_format(fmt: str, expected: str) -> None:
df = pl.DataFrame({"dt": [datetime(2022, 1, 2)]})
# write_csv is called without a file argument; the test compares its return
# value against the expected CSV text (header line plus one data line).
csv = df.write_csv(datetime_format=fmt)
assert csv == expected


# Each case pairs a ``date_format`` value with the exact CSV text expected for
# a frame holding the single date 2022-01-02 in column "dt". The ``None`` case
# pins the output produced when no format is supplied.
# NOTE(review): ``fmt`` is also parametrized with ``None``, so the ``str``
# annotation below looks too narrow (``str | None``?) -- confirm.
@pytest.mark.parametrize(
"fmt,expected",
[
(None, "dt\n2022-01-02\n"),
("%Y", "dt\n2022\n"),
("%m", "dt\n01\n"),
("%m$%d", "dt\n01$02\n"),
],
)
def test_date_format(fmt: str, expected: str) -> None:
df = pl.DataFrame({"dt": [date(2022, 1, 2)]})
# write_csv is called without a file argument; the test compares its return
# value against the expected CSV text (header line plus one data line).
csv = df.write_csv(date_format=fmt)
assert csv == expected


# Each case pairs a ``time_format`` value with the exact CSV text expected for
# a frame holding the single time 16:15:30 in column "dt". The ``None`` case
# pins the output produced when no format is supplied.
# NOTE(review): ``fmt`` is also parametrized with ``None``, so the ``str``
# annotation below looks too narrow (``str | None``?) -- confirm.
@pytest.mark.parametrize(
"fmt,expected",
[
(None, "dt\n16:15:30.000000000\n"),
("%R", "dt\n16:15\n"),
],
)
def test_time_format(fmt: str, expected: str) -> None:
df = pl.DataFrame({"dt": [time(16, 15, 30)]})
# write_csv is called without a file argument; the test compares its return
# value against the expected CSV text (header line plus one data line).
csv = df.write_csv(time_format=fmt)
assert csv == expected

0 comments on commit 5ac1b7b

Please sign in to comment.