Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): Consistent expansion of nested struct data during DataFrame init from dict #15217

Merged
merged 1 commit into from Mar 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 16 additions & 5 deletions py-polars/polars/_utils/construction/dataframe.py
Expand Up @@ -134,7 +134,7 @@ def dict_to_pydf(
else:
data_series = [
s._s
for s in _expand_dict_scalars(
for s in _expand_dict_values(
data,
schema_overrides=schema_overrides,
strict=strict,
Expand Down Expand Up @@ -307,7 +307,7 @@ def _post_apply_columns(
return pydf


def _expand_dict_scalars(
def _expand_dict_values(
data: Mapping[str, Sequence[object] | Mapping[str, Sequence[object]] | Series],
*,
schema_overrides: SchemaDict | None = None,
Expand All @@ -334,9 +334,20 @@ def _expand_dict_scalars(
for name, val in data.items():
dtype = dtypes.get(name)
if isinstance(val, dict) and dtype != Struct:
updated_data[name] = pl.DataFrame(val, strict=strict).to_struct(
name
)
vdf = pl.DataFrame(val, strict=strict)
if (
len(vdf) == 1
and array_len > 1
and all(not d.is_nested() for d in vdf.schema.values())
):
s_vals = {
nm: vdf[nm].extend_constant(v, n=(array_len - 1))
for nm, v in val.items()
}
st = pl.DataFrame(s_vals).to_struct(name)
else:
st = vdf.to_struct(name)
updated_data[name] = st

elif isinstance(val, pl.Series):
s = val.rename(name) if name != val.name else val
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/_utils/various.py
Expand Up @@ -207,9 +207,9 @@ def _in_notebook() -> bool:


def arrlen(obj: Any) -> int | None:
    """
    Return length of (non-string/dict) sequence; returns None for non-sequences.

    Parameters
    ----------
    obj
        Any object. Strings and dicts are deliberately reported as having no
        array length even though they support ``len()``.

    Returns
    -------
    int or None
        ``len(obj)`` when `obj` is sized and is neither a ``str`` nor a ``dict``;
        otherwise ``None``.
    """
    # str and dict both implement __len__, but callers use this helper to
    # detect array-like column data, so neither should count as a sequence.
    if isinstance(obj, (str, dict)):
        return None
    try:
        return len(obj)
    except TypeError:
        # object does not define a length (eg: int, float, None)
        return None

Expand Down
41 changes: 36 additions & 5 deletions py-polars/tests/unit/dataframe/test_from_dict.py
Expand Up @@ -146,7 +146,7 @@ def test_from_dict_with_scalars() -> None:


@pytest.mark.slow()
def test_from_dict_with_scalars_mixed() -> None:
def test_from_dict_with_values_mixed() -> None:
# a bit of everything
mixed_dtype_data: dict[str, Any] = {
"a": 0,
Expand All @@ -164,11 +164,10 @@ def test_from_dict_with_scalars_mixed() -> None:
# note: deliberately set this value large; if all dtypes are
# on the fast-path it'll only take ~0.03secs. if it becomes
# even remotely noticeable that will indicate a regression.
# TODO: This is now slow (~0.15 seconds). Needs to be looked into.
n_range = 1_000_000
index_and_data: dict[str, Any] = {"idx": range(n_range)}
index_and_data.update(mixed_dtype_data.items())
df8 = pl.DataFrame(
df = pl.DataFrame(
data=index_and_data,
schema={
"idx": pl.Int32,
Expand All @@ -185,14 +184,46 @@ def test_from_dict_with_scalars_mixed() -> None:
"k": pl.String,
},
)
dfx = df8.select(pl.exclude("idx"))
dfx = df.select(pl.exclude("idx"))

assert len(df8) == n_range
assert len(df) == n_range
assert dfx[:5].rows() == dfx[5:10].rows()
assert dfx[-10:-5].rows() == dfx[-5:].rows()
assert dfx.row(n_range // 2, named=True) == mixed_dtype_data


def test_from_dict_expand_nested_struct() -> None:
    """Check that dict-valued entries become struct columns consistently.

    Covers two situations: a single-row frame built from every scalar/list
    combination of the inputs, and a bare dict value being repeated to match
    the length of a longer sibling column.
    """
    # single-row case: scalar vs one-element list must give the same frame,
    # for both the plain column and the nested (struct) column
    day = date(2077, 10, 10)
    single_row = pl.DataFrame(
        [
            pl.Series("x", [day]),
            pl.Series("nested", [{"y": -1, "z": 1}]),
        ]
    )
    for x_input in (day, [day]):
        for nested_input in ({"y": -1, "z": 1}, [{"y": -1, "z": 1}]):
            frame = pl.DataFrame({"x": x_input, "nested": nested_input})
            assert_frame_equal(single_row, frame)

    # multi-row case: a lone dict is expanded to 'n' identical struct values
    repeated = [{"y": -1, "z": 1} for _ in range(3)]
    three_rows = pl.DataFrame(
        [
            pl.Series("x", [0, 1, 2]),
            pl.Series("nested", repeated),
        ]
    )
    for x_input in (range(3), [0, 1, 2]):
        frame = pl.DataFrame({"x": x_input, "nested": {"y": -1, "z": 1}})
        assert_frame_equal(three_rows, frame)


def test_from_dict_duration_subseconds() -> None:
d = {"duration": [timedelta(seconds=1, microseconds=1000)]}
result = pl.from_dict(d)
Expand Down