Skip to content

Commit

Permalink
fix(python): fix for categorical inserts from row-oriented data (#5462)
Browse files: browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Nov 10, 2022
1 parent 55caf1f commit 9f5049f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 13 deletions.
39 changes: 28 additions & 11 deletions py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
PolarsDataType,
Time,
Unknown,
Utf8,
dtype_to_arrow_type,
dtype_to_py_type,
is_polars_dtype,
Expand Down Expand Up @@ -507,18 +508,22 @@ def _handle_columns_arg(
raise ValueError("Dimensions of columns arg must match data dimensions.")


def _post_apply_columns(pydf: PyDataFrame, columns: ColumnsType) -> PyDataFrame:
def _post_apply_columns(
pydf: PyDataFrame, columns: ColumnsType, categoricals: set[str] | None = None
) -> PyDataFrame:
"""Apply 'columns' param _after_ PyDataFrame creation (if no alternative)."""
pydf_columns, pydf_dtypes = pydf.columns(), pydf.dtypes()
columns, dtypes = _unpack_columns(columns or pydf_columns)
if columns != pydf_columns:
pydf.set_column_names(columns)

column_casts = [
pli.col(col).cast(dtypes[col])._pyexpr
for i, col in enumerate(columns)
if col in dtypes and dtypes[col] != pydf_dtypes[i]
]
column_casts = []
for i, col in enumerate(columns):
if categoricals and col in categoricals:
column_casts.append(pli.col(col).cast(Categorical)._pyexpr)
elif col in dtypes and dtypes[col] != pydf_dtypes[i]:
column_casts.append(pli.col(col).cast(dtypes[col])._pyexpr)

if column_casts:
pydf = pydf.lazy().with_columns(column_casts).collect()
return pydf
Expand Down Expand Up @@ -655,13 +660,22 @@ def sequence_to_pydf(

if orient == "row":
column_names, dtypes = _unpack_columns(columns)
schema_override = include_unknowns(dtypes, column_names) if dtypes else None
schema_override = include_unknowns(dtypes, column_names) if dtypes else {}
if column_names and data and len(data[0]) != len(column_names):
raise ShapeError("The row data does not match the number of columns")
categoricals = {
col for col, tp in schema_override.items() if tp == Categorical
}
for col in categoricals:
schema_override[col] = Utf8

pydf = PyDataFrame.read_rows(data, infer_schema_length, schema_override)
pydf = PyDataFrame.read_rows(
data,
infer_schema_length,
schema_override or None,
)
if column_names:
pydf = _post_apply_columns(pydf, column_names)
pydf = _post_apply_columns(pydf, column_names, categoricals)
return pydf

elif orient == "col" or orient is None:
Expand All @@ -685,12 +699,15 @@ def sequence_to_pydf(
col: (py_type_to_dtype(tp, raise_unmatched=False) or Unknown)
for col, tp in dataclass_type_hints(data[0].__class__).items()
}
categoricals = {col for col, tp in schema_override.items() if tp == Categorical}
for col in categoricals:
schema_override[col] = Utf8

pydf = PyDataFrame.read_rows(
[astuple(dc) for dc in data], infer_schema_length, schema_override
[astuple(dc) for dc in data], infer_schema_length, schema_override or None
)
if columns:
pydf = _post_apply_columns(pydf, columns)
pydf = _post_apply_columns(pydf, columns, categoricals)
return pydf

elif _PANDAS_TYPE(data[0]) and isinstance(data[0], (pd.Series, pd.DatetimeIndex)):
Expand Down
13 changes: 11 additions & 2 deletions py-polars/tests/unit/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,14 @@ class TradeNT(NamedTuple):
data=trades,
columns=[ # type: ignore[arg-type]
("ts", pl.Datetime("ms")),
("tk", pl.Utf8),
("tk", pl.Categorical),
("pc", pl.Float32),
("sz", pl.UInt16),
],
)
assert df.schema == {
"ts": pl.Datetime("ms"),
"tk": pl.Utf8,
"tk": pl.Categorical,
"pc": pl.Float32,
"sz": pl.UInt16,
}
Expand Down Expand Up @@ -1086,6 +1086,7 @@ def test_string_cache_eager_lazy() -> None:
df1 = pl.DataFrame(
{"region_ids": ["reg1", "reg2", "reg3", "reg4", "reg5"]}
).select([pl.col("region_ids").cast(pl.Categorical)])

df2 = pl.DataFrame(
{"seq_name": ["reg4", "reg2", "reg1"], "score": [3.0, 1.0, 2.0]}
).select([pl.col("seq_name").cast(pl.Categorical), pl.col("score")])
Expand All @@ -1101,6 +1102,14 @@ def test_string_cache_eager_lazy() -> None:
df2, left_on="region_ids", right_on="seq_name", how="left"
).frame_equal(expected, null_equal=True)

# also check row-wise categorical insert.
# (column-wise is preferred, but this shouldn't fail)
df3 = pl.DataFrame(
data=[["reg1"], ["reg2"], ["reg3"], ["reg4"], ["reg5"]],
columns=[("region_ids", pl.Categorical)],
)
assert_frame_equal(df1, df3)


def test_assign() -> None:
# check if can assign in case of a single column
Expand Down

0 comments on commit 9f5049f

Please sign in to comment.