Skip to content

Commit

Permalink
Fix from_records/from_numpy orient type (#3961)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Jul 17, 2022
1 parent 6695cce commit 19cc512
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 8 deletions.
10 changes: 8 additions & 2 deletions py-polars/polars/convert.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
import warnings
from typing import Any, Mapping, Sequence, overload

Expand All @@ -26,6 +27,11 @@
except ImportError: # pragma: no cover
_PANDAS_AVAILABLE = False

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal # pragma: no cover


def from_dict(
data: Mapping[str, Sequence | Mapping],
Expand Down Expand Up @@ -110,7 +116,7 @@ def from_dicts(
def from_records(
data: Sequence[Sequence[Any]],
columns: Sequence[str] | None = None,
orient: str | None = None,
orient: Literal["col", "row"] | None = None,
) -> DataFrame:
"""
Construct a DataFrame from a numpy ndarray or sequence of sequences.
Expand Down Expand Up @@ -166,7 +172,7 @@ def from_records(
def from_numpy(
data: np.ndarray,
columns: Sequence[str] | None = None,
orient: str | None = None,
orient: Literal["col", "row"] | None = None,
) -> DataFrame:
"""
Construct a DataFrame from a numpy ndarray.
Expand Down
10 changes: 8 additions & 2 deletions py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
import warnings
from datetime import date, datetime, time, timedelta
from itertools import zip_longest
Expand Down Expand Up @@ -53,6 +54,11 @@
except ImportError: # pragma: no cover
_PYARROW_AVAILABLE = False

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal # pragma: no cover


################################
# Series constructor interface #
Expand Down Expand Up @@ -449,7 +455,7 @@ def dict_to_pydf(
def sequence_to_pydf(
data: Sequence[Any],
columns: ColumnsType | None = None,
orient: str | None = None,
orient: Literal["col", "row"] | None = None,
) -> PyDataFrame:
"""
Construct a PyDataFrame from a sequence.
Expand Down Expand Up @@ -507,7 +513,7 @@ def sequence_to_pydf(
def numpy_to_pydf(
data: np.ndarray,
columns: ColumnsType | None = None,
orient: str | None = None,
orient: Literal["col", "row"] | None = None,
) -> PyDataFrame:
"""
Construct a PyDataFrame from a numpy ndarray.
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def __init__(
| pli.Series
) = None,
columns: ColumnsType | None = None,
orient: str | None = None,
orient: Literal["col", "row"] | None = None,
):
if data is None:
self._df = dict_to_pydf({}, columns=columns)
Expand Down Expand Up @@ -392,7 +392,7 @@ def _from_records(
cls: type[DF],
data: Sequence[Sequence[Any]],
columns: Sequence[str] | None = None,
orient: str | None = None,
orient: Literal["col", "row"] | None = None,
) -> DF:
"""
Construct a DataFrame from a sequence of sequences.
Expand Down Expand Up @@ -420,7 +420,7 @@ def _from_numpy(
cls: type[DF],
data: np.ndarray,
columns: Sequence[str] | None = None,
orient: str | None = None,
orient: Literal["col", "row"] | None = None,
) -> DF:
"""
Construct a DataFrame from a numpy ndarray.
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_init_ndarray() -> None:
assert df.frame_equal(truth)

# 2D array - default to column orientation
df = pl.DataFrame(np.array([[1, 2], [3, 4]]), orient="column")
df = pl.DataFrame(np.array([[1, 2], [3, 4]]), orient="col")
truth = pl.DataFrame({"column_0": [1, 2], "column_1": [3, 4]})
assert df.frame_equal(truth)

Expand Down

0 comments on commit 19cc512

Please sign in to comment.