Skip to content

Commit

Permalink
fix[python]: default numpy array keep dimensions (#4507)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 20, 2022
1 parent cc9667c commit 6b56b7d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
21 changes: 16 additions & 5 deletions py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,13 +563,24 @@ def numpy_to_pydf(
n_columns = 1

elif len(shape) == 2:
# Infer orientation
if orient is None and columns is not None:
orient = "col" if len(columns) == shape[0] else "row"
# default convention
# first axis is rows, second axis is columns
if orient is None and columns is None:
n_columns = shape[1]
orient = "row"

if orient == "row":
# Infer orientation if columns argument is given
elif orient is None and columns is not None:
if len(columns) == shape[0]:
orient = "col"
n_columns = shape[0]
else:
orient = "row"
n_columns = shape[1]

elif orient == "row":
n_columns = shape[1]
elif orient == "col" or orient is None:
elif orient == "col":
n_columns = shape[0]
else:
raise ValueError(
Expand Down
6 changes: 5 additions & 1 deletion py-polars/tests/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,13 @@ def test_init_ndarray(monkeypatch: Any) -> None:

# 2D array - default to column orientation
df = pl.DataFrame(np.array([[1, 2], [3, 4]]))
truth = pl.DataFrame({"column_0": [1, 2], "column_1": [3, 4]})
truth = pl.DataFrame({"column_0": [1, 3], "column_1": [2, 4]})
assert df.frame_equal(truth)

# no orientation is numpy convention
df = pl.DataFrame(np.ones((3, 1)))
assert df.shape == (3, 1)

# 2D array - row orientation inferred
df = pl.DataFrame(np.array([[1, 2, 3], [4, 5, 6]]), columns=["a", "b", "c"])
truth = pl.DataFrame({"a": [1, 4], "b": [2, 5], "c": [3, 6]})
Expand Down

0 comments on commit 6b56b7d

Please sign in to comment.