Skip to content

Commit

Permalink
Improve handling of literal arguments: compression (#4296)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Aug 8, 2022
1 parent 189df82 commit ab92d15
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 44 deletions.
57 changes: 20 additions & 37 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,7 +1182,7 @@ def write_csv(
batch_size: int = 1024,
) -> str | None:
"""
Write Dataframe to comma-separated values file (csv).
Write to comma-separated values (CSV) file.
Parameters
----------
Expand Down Expand Up @@ -1254,13 +1254,12 @@ def write_avro(
----------
file
File path to which the file should be written.
compression
Compression method. Choose one of:
- "uncompressed"
- "snappy"
- "deflate"
compression : {'uncompressed', 'snappy', 'deflate'}
Compression method. Defaults to "uncompressed".
"""
if compression is None:
compression = "uncompressed"
if isinstance(file, (str, Path)):
file = format_path(file)

Expand All @@ -1283,20 +1282,17 @@ def to_avro(
def write_ipc(
self,
file: BinaryIO | BytesIO | str | Path,
compression: Literal["uncompressed", "lz4", "zstd"] | None = "uncompressed",
compression: Literal["uncompressed", "lz4", "zstd"] = "uncompressed",
) -> None:
"""
Write to Arrow IPC binary stream, or a feather file.
Write to Arrow IPC binary stream or Feather file.
Parameters
----------
file
File path to which the file should be written.
compression
Compression method. Choose one of:
- "uncompressed"
- "lz4"
- "zstd"
compression : {'uncompressed', 'lz4', 'zstd'}
Compression method. Defaults to "uncompressed".
"""
if compression is None:
Expand All @@ -1309,7 +1305,7 @@ def write_ipc(
def to_ipc(
self,
file: BinaryIO | BytesIO | str | Path,
compression: Literal["uncompressed", "lz4", "zstd"] | None = "uncompressed",
compression: Literal["uncompressed", "lz4", "zstd"] = "uncompressed",
) -> None: # pragma: no cover
"""
.. deprecated:: 0.13.12
Expand Down Expand Up @@ -1454,38 +1450,27 @@ def write_parquet(
self,
file: str | Path | BytesIO,
*,
compression: (
Literal["uncompressed", "snappy", "gzip", "lzo", "brotli", "lz4", "zstd"]
| str
| None
) = "lz4",
compression: Literal[
"lz4", "uncompressed", "snappy", "gzip", "lzo", "brotli", "zstd"
] = "lz4",
compression_level: int | None = None,
statistics: bool = False,
row_group_size: int | None = None,
use_pyarrow: bool = False,
**kwargs: Any,
) -> None:
"""
Write the DataFrame to disk in parquet format.
Write to Apache Parquet file.
Parameters
----------
file
File path to which the file should be written.
compression
Compression method. Choose one of:
- "uncompressed" (not supported by pyarrow)
- "snappy"
- "gzip"
- "lzo"
- "brotli"
- "lz4"
- "zstd"
The default compression "lz4" (actually lz4raw) has very good performance,
but may not yet been supported by older readers. If you want more
compatability guarantees, consider using "snappy".
compression : {'lz4', 'uncompressed', 'snappy', 'gzip', 'lzo', 'brotli', 'zstd'}
Compression method. The default compression "lz4" (actually lz4raw) has very
good performance, but may not yet been supported by older readers. If you
want more compatability guarantees, consider using "snappy".
Method "uncompressed" is not supported by pyarrow.
compression_level
The level of compression to use. Higher compression means smaller files on
disk.
Expand Down Expand Up @@ -1548,9 +1533,7 @@ def to_parquet(
self,
file: str | Path | BytesIO,
compression: (
Literal["uncompressed", "snappy", "gzip", "lzo", "brotli", "lz4", "zstd"]
| str
| None
Literal["lz4", "uncompressed", "snappy", "gzip", "lzo", "brotli", "zstd"]
) = "snappy",
statistics: bool = False,
use_pyarrow: bool = False,
Expand Down
21 changes: 18 additions & 3 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,12 @@ impl PyDataFrame {
"uncompressed" => None,
"snappy" => Some(AvroCompression::Snappy),
"deflate" => Some(AvroCompression::Deflate),
s => return Err(PyPolarsErr::Other(format!("compression {} not supported", s)).into()),
e => {
return Err(PyValueError::new_err(format!(
"compression must be one of {{'uncompressed', 'snappy', 'deflate'}}, got {}",
e
)))
}
};

if let Ok(s) = py_f.extract::<&str>(py) {
Expand Down Expand Up @@ -462,7 +467,12 @@ impl PyDataFrame {
"uncompressed" => None,
"lz4" => Some(IpcCompression::LZ4),
"zstd" => Some(IpcCompression::ZSTD),
s => return Err(PyPolarsErr::Other(format!("compression {} not supported", s)).into()),
e => {
return Err(PyValueError::new_err(format!(
"compression must be one of {{'uncompressed', 'lz4', 'zstd'}}, got {}",
e
)))
}
};

if let Ok(s) = py_f.extract::<&str>(py) {
Expand Down Expand Up @@ -613,7 +623,12 @@ impl PyDataFrame {
})
.transpose()?,
),
s => return Err(PyPolarsErr::Other(format!("compression {} not supported", s)).into()),
e => {
return Err(PyValueError::new_err(format!(
"compression must be one of {{'uncompressed', 'snappy', 'gzip', 'lzo', 'brotli', 'lz4', 'zstd'}}, got {}",
e
)))
}
};

if let Ok(s) = py_f.extract::<&str>(py) {
Expand Down
21 changes: 17 additions & 4 deletions py-polars/tests/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io
import os
import sys

import numpy as np
import pandas as pd
Expand All @@ -10,13 +11,25 @@
import polars as pl
from polars.testing import assert_frame_equal_local_categoricals

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal


CompressionMethod = Literal[
"lz4", "uncompressed", "snappy", "gzip", "lzo", "brotli", "zstd"
]


@pytest.fixture
def compressions() -> list[str]:
return ["uncompressed", "snappy", "gzip", "lzo", "brotli", "lz4", "zstd"]
def compressions() -> list[CompressionMethod]:
return ["lz4", "uncompressed", "snappy", "gzip", "lzo", "brotli", "zstd"]


def test_to_from_buffer(df: pl.DataFrame, compressions: list[str]) -> None:
def test_to_from_buffer(
df: pl.DataFrame, compressions: list[CompressionMethod]
) -> None:
for compression in compressions:
if compression == "lzo":
# lzo compression is not supported now
Expand Down Expand Up @@ -47,7 +60,7 @@ def test_to_from_buffer(df: pl.DataFrame, compressions: list[str]) -> None:


def test_to_from_file(
io_test_dir: str, df: pl.DataFrame, compressions: list[str]
io_test_dir: str, df: pl.DataFrame, compressions: list[CompressionMethod]
) -> None:
f = os.path.join(io_test_dir, "small.parquet")
for compression in compressions:
Expand Down

0 comments on commit ab92d15

Please sign in to comment.