Skip to content

Commit

Permalink
Handle wrong input for orient argument (#4065)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Jul 18, 2022
1 parent 5bfc886 commit 7cfa390
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
14 changes: 11 additions & 3 deletions py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal # pragma: no cover
from typing_extensions import Literal


################################
Expand Down Expand Up @@ -495,12 +495,16 @@ def sequence_to_pydf(
if columns:
pydf = _post_apply_columns(pydf, columns)
return pydf
else:
elif orient == "col" or orient is None:
columns, dtypes = _unpack_columns(columns, n_expected=len(data))
data_series = [
pli.Series(columns[i], data[i], dtypes.get(columns[i])).inner()
for i in range(len(data))
]
else:
raise ValueError(
f"orient must be one of {{'col', 'row', None}}, got {orient} instead."
)

else:
columns, dtypes = _unpack_columns(columns, n_expected=1)
Expand Down Expand Up @@ -557,11 +561,15 @@ def numpy_to_pydf(
pli.Series(columns[i], data[:, i], dtypes.get(columns[i])).inner()
for i in range(n_columns)
]
else:
elif orient == "col":
data_series = [
pli.Series(columns[i], data[i], dtypes.get(columns[i])).inner()
for i in range(n_columns)
]
else:
raise ValueError(
f"orient must be one of {{'col', 'row', None}}, got {orient} instead."
)
else:
raise ValueError("A numpy array should not have more than two dimensions.")

Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,13 @@ def test_init_ndarray() -> None:
with pytest.raises(ValueError):
_ = pl.DataFrame(np.random.randn(2, 2, 2))

# Wrong orient value
with pytest.raises(ValueError):
df = pl.DataFrame(
np.array([[1, 2, 3], [4, 5, 6]]),
orient="wrong", # type: ignore[arg-type]
)

# numpy not available
with patch("polars.internals.frame._NUMPY_AVAILABLE", False):
with pytest.raises(ValueError):
Expand Down Expand Up @@ -344,6 +351,10 @@ def test_init_seq_of_seq() -> None:
assert df.schema == {"a": pl.Float32, "b": pl.Float32}
assert df.rows() == [(1.0, 2.0), (3.0, 4.0)]

# Wrong orient value
with pytest.raises(ValueError):
df = pl.DataFrame(((1, 2), (3, 4)), orient="wrong") # type: ignore[arg-type]


def test_init_1d_sequence() -> None:
# Empty list
Expand Down

0 comments on commit 7cfa390

Please sign in to comment.