Skip to content

Commit

Permalink
Backport PR #52036 on branch 2.0.x (BUG: Remove unnecessary validation to non-string columns/index in df.to_parquet) (#52044)
Browse files Browse the repository at this point in the history

BUG: Remove unnecessary validation to non-string columns/index in df.to_parquet (#52036)

Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>
  • Loading branch information
phofl and mroeschke committed Mar 17, 2023
1 parent 1cde5ec commit fea45ba
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 60 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1293,6 +1293,7 @@ I/O
- Bug in :func:`read_csv` when ``engine="pyarrow"`` where ``encoding`` parameter was not handled correctly (:issue:`51302`)
- Bug in :func:`read_xml` ignored repeated elements when iterparse is used (:issue:`51183`)
- Bug in :class:`ExcelWriter` leaving file handles open if an exception occurred during instantiation (:issue:`51443`)
- Bug in :meth:`DataFrame.to_parquet` where non-string index or columns were raising a ``ValueError`` when ``engine="pyarrow"`` (:issue:`52036`)

Period
^^^^^^
Expand Down
23 changes: 0 additions & 23 deletions pandas/io/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import pandas as pd
from pandas import (
DataFrame,
MultiIndex,
get_option,
)
from pandas.core.shared_docs import _shared_docs
Expand Down Expand Up @@ -122,28 +121,6 @@ def validate_dataframe(df: DataFrame) -> None:
if not isinstance(df, DataFrame):
raise ValueError("to_parquet only supports IO with DataFrames")

# must have value column names for all index levels (strings only)
if isinstance(df.columns, MultiIndex):
if not all(
x.inferred_type in {"string", "empty"} for x in df.columns.levels
):
raise ValueError(
"""
parquet must have string column names for all values in
each level of the MultiIndex
"""
)
else:
if df.columns.inferred_type not in {"string", "empty"}:
raise ValueError("parquet must have string column names")

# index level names must be strings
valid_names = all(
isinstance(name, str) for name in df.index.names if name is not None
)
if not valid_names:
raise ValueError("Index level names must be strings")

def write(self, df: DataFrame, path, compression, **kwargs):
raise AbstractMethodError(self)

Expand Down
115 changes: 78 additions & 37 deletions pandas/tests/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,25 +404,6 @@ def test_columns_dtypes(self, engine):
df.columns = ["foo", "bar"]
check_round_trip(df, engine)

def test_columns_dtypes_invalid(self, engine):
df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))})

msg = "parquet must have string column names"
# numeric
df.columns = [0, 1]
self.check_error_on_write(df, engine, ValueError, msg)

# bytes
df.columns = [b"foo", b"bar"]
self.check_error_on_write(df, engine, ValueError, msg)

# python object
df.columns = [
datetime.datetime(2011, 1, 1, 0, 0),
datetime.datetime(2011, 1, 1, 1, 1),
]
self.check_error_on_write(df, engine, ValueError, msg)

@pytest.mark.parametrize("compression", [None, "gzip", "snappy", "brotli"])
def test_compression(self, engine, compression):
if compression == "snappy":
Expand Down Expand Up @@ -528,16 +509,16 @@ def test_write_column_multiindex(self, engine):
# Not able to write column multi-indexes with non-string column names.
mi_columns = pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1)])
df = pd.DataFrame(np.random.randn(4, 3), columns=mi_columns)
msg = (
r"\s*parquet must have string column names for all values in\s*"
"each level of the MultiIndex"
)
self.check_error_on_write(df, engine, ValueError, msg)

def test_write_column_multiindex_nonstring(self, pa):
if engine == "fastparquet":
self.check_error_on_write(
df, engine, TypeError, "Column name must be a string"
)
elif engine == "pyarrow":
check_round_trip(df, engine)

def test_write_column_multiindex_nonstring(self, engine):
# GH #34777
# Not supported in fastparquet as of 0.1.3
engine = pa

# Not able to write column multi-indexes with non-string column names
arrays = [
Expand All @@ -546,11 +527,14 @@ def test_write_column_multiindex_nonstring(self, pa):
]
df = pd.DataFrame(np.random.randn(8, 8), columns=arrays)
df.columns.names = ["Level1", "Level2"]
msg = (
r"\s*parquet must have string column names for all values in\s*"
"each level of the MultiIndex"
)
self.check_error_on_write(df, engine, ValueError, msg)
if engine == "fastparquet":
if Version(fastparquet.__version__) < Version("0.7.0"):
err = TypeError
else:
err = ValueError
self.check_error_on_write(df, engine, err, "Column name")
elif engine == "pyarrow":
check_round_trip(df, engine)

def test_write_column_multiindex_string(self, pa):
# GH #34777
Expand Down Expand Up @@ -579,17 +563,19 @@ def test_write_column_index_string(self, pa):

check_round_trip(df, engine)

def test_write_column_index_nonstring(self, pa):
def test_write_column_index_nonstring(self, engine):
# GH #34777
# Not supported in fastparquet as of 0.1.3
engine = pa

# Write column indexes with string column names
arrays = [1, 2, 3, 4]
df = pd.DataFrame(np.random.randn(8, 4), columns=arrays)
df.columns.name = "NonStringCol"
msg = r"parquet must have string column names"
self.check_error_on_write(df, engine, ValueError, msg)
if engine == "fastparquet":
self.check_error_on_write(
df, engine, TypeError, "Column name must be a string"
)
else:
check_round_trip(df, engine)

@pytest.mark.skipif(pa_version_under7p0, reason="minimum pyarrow not installed")
def test_dtype_backend(self, engine, request):
Expand Down Expand Up @@ -1041,6 +1027,31 @@ def test_read_dtype_backend_pyarrow_config_index(self, pa):
expected=expected,
)

def test_columns_dtypes_not_invalid(self, pa):
    """Non-string column labels are accepted by the pyarrow engine (GH#52036).

    Integer and datetime labels round-trip; bytes labels can be written
    but raise ``NotImplementedError`` on ``read_parquet``.
    """
    base = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))})

    # integer column labels
    numeric_df = base.copy()
    numeric_df.columns = [0, 1]
    check_round_trip(numeric_df, pa)

    # bytes column labels
    bytes_df = base.copy()
    bytes_df.columns = [b"foo", b"bar"]
    with pytest.raises(NotImplementedError, match="|S3"):
        # Bytes fails on read_parquet
        check_round_trip(bytes_df, pa)

    # python-object (datetime) column labels
    object_df = base.copy()
    object_df.columns = [
        datetime.datetime(2011, 1, 1, 0, 0),
        datetime.datetime(2011, 1, 1, 1, 1),
    ]
    check_round_trip(object_df, pa)

def test_empty_columns(self, pa):
    """A frame with an index but zero columns round-trips (GH#52034)."""
    # GH 52034
    custom_index = pd.Index(["a", "b", "c"], name="custom name")
    df = pd.DataFrame(index=custom_index)
    check_round_trip(df, pa)


class TestParquetFastParquet(Base):
def test_basic(self, fp, df_full):
Expand All @@ -1052,6 +1063,27 @@ def test_basic(self, fp, df_full):
df["timedelta"] = pd.timedelta_range("1 day", periods=3)
check_round_trip(df, fp)

def test_columns_dtypes_invalid(self, fp):
    """fastparquet rejects every non-string column label with TypeError."""
    df = pd.DataFrame({"string": list("abc"), "int": list(range(1, 4))})

    err = TypeError
    msg = "Column name must be a string"

    # numeric, bytes, and python-object labels must all fail the same way
    for bad_labels in (
        [0, 1],
        [b"foo", b"bar"],
        [
            datetime.datetime(2011, 1, 1, 0, 0),
            datetime.datetime(2011, 1, 1, 1, 1),
        ],
    ):
        df.columns = bad_labels
        self.check_error_on_write(df, fp, err, msg)

def test_duplicate_columns(self, fp):
# not currently able to handle duplicate columns
df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=list("aaa")).copy()
Expand Down Expand Up @@ -1221,3 +1253,12 @@ def test_invalid_dtype_backend(self, engine):
df.to_parquet(path)
with pytest.raises(ValueError, match=msg):
read_parquet(path, dtype_backend="numpy")

def test_empty_columns(self, fp):
    """fastparquet round-trips a column-less frame (GH#52034).

    The read-back frame carries an empty object-dtype column index,
    so the expected result is spelled out explicitly.
    """
    # GH 52034
    df = pd.DataFrame(index=pd.Index(["a", "b", "c"], name="custom name"))
    expected = pd.DataFrame(
        index=pd.Index(["a", "b", "c"], name="custom name"),
        columns=pd.Index([], dtype=object),
    )
    check_round_trip(df, fp, expected=expected)

0 comments on commit fea45ba

Please sign in to comment.