Skip to content

Commit

Permalink
feat(python): Implement support for Struct types in parametric tests (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed May 13, 2024
1 parent 9bfa30c commit 81cc802
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 9 deletions.
33 changes: 30 additions & 3 deletions py-polars/polars/testing/parametric/strategies/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import decimal
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Literal, Sequence
from typing import TYPE_CHECKING, Any, Literal, Mapping, Sequence

import hypothesis.strategies as st
from hypothesis.errors import InvalidArgument
Expand Down Expand Up @@ -34,6 +34,7 @@
Decimal,
Duration,
Enum,
Field,
Float32,
Float64,
Int8,
Expand All @@ -43,6 +44,7 @@
List,
Null,
String,
Struct,
Time,
UInt8,
UInt16,
Expand All @@ -58,10 +60,10 @@
if TYPE_CHECKING:
from datetime import date, datetime, time

from hypothesis.strategies import SearchStrategy
from hypothesis.strategies import DrawFn, SearchStrategy

from polars.datatypes import DataType, DataTypeClass
from polars.type_aliases import PolarsDataType, TimeUnit
from polars.type_aliases import PolarsDataType, SchemaDict, TimeUnit

_DEFAULT_LIST_LEN_LIMIT = 3
_DEFAULT_N_CATEGORIES = 10
Expand Down Expand Up @@ -278,6 +280,28 @@ def lists(
)


@st.composite
def structs( # noqa: D417
draw: DrawFn, /, fields: Sequence[Field] | SchemaDict, **kwargs: Any
) -> dict[str, Any]:
"""
Create a strategy for generating structs with the given fields.
Parameters
----------
fields
The fields that make up the struct. Can be either a sequence of Field
objects or a mapping of column names to data types.
**kwargs
Additional arguments that are passed to nested data generation strategies.
"""
if isinstance(fields, Mapping):
fields = [Field(name, dtype) for name, dtype in fields.items()]

strats = {f.name: data(f.dtype, **kwargs) for f in fields}
return {col: draw(strat) for col, strat in strats.items()}


def nulls() -> SearchStrategy[None]:
"""Create a strategy for generating null values."""
return st.none()
Expand Down Expand Up @@ -360,6 +384,9 @@ def data(
allow_null=allow_null,
**kwargs,
)
elif dtype == Struct:
fields = getattr(dtype, "fields", None) or [Field("f0", Null())]
strategy = structs(fields, **kwargs)
else:
msg = f"unsupported data type: {dtype}"
raise InvalidArgument(msg)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/testing/parametric/strategies/dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
# TODO: Enable nested types by default when various issues are solved.
# List,
# Array,
# Struct,
Struct,
]
# Supported data type classes that do not contain other data types
_FLAT_DTYPES = _SIMPLE_DTYPES + _COMPLEX_DTYPES
Expand Down
6 changes: 5 additions & 1 deletion py-polars/tests/unit/dataframe/test_null_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
min_size=1,
min_cols=1,
allow_null=True,
excluded_dtypes=[pl.String, pl.List],
excluded_dtypes=[
pl.String,
pl.List,
pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462
],
)
)
@example(df=pl.DataFrame(schema=["x", "y", "z"]))
Expand Down
8 changes: 7 additions & 1 deletion py-polars/tests/unit/dataframe/test_to_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@
from polars.testing.parametric import dataframes


@given(df=dataframes())
@given(
df=dataframes(
excluded_dtypes=[
pl.Categorical, # Bug: https://github.com/pola-rs/polars/issues/16196
]
)
)
def test_to_dict(df: pl.DataFrame) -> None:
d = df.to_dict(as_series=False)
result = pl.from_dict(d, schema=df.schema)
Expand Down
14 changes: 12 additions & 2 deletions py-polars/tests/unit/operations/test_clear.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,24 @@
from polars.testing.parametric import series


@given(s=series(), n=st.integers(min_value=0, max_value=10))
def test_clear_series_parametric(s: pl.Series, n: int) -> None:
@given(s=series())
def test_clear_series_parametric(s: pl.Series) -> None:
result = s.clear()

assert result.dtype == s.dtype
assert result.name == s.name
assert result.is_empty()


@given(
s=series(
excluded_dtypes=[
pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462
]
),
n=st.integers(min_value=0, max_value=10),
)
def test_clear_series_n_parametric(s: pl.Series, n: int) -> None:
result = s.clear(n)

assert result.dtype == s.dtype
Expand Down
9 changes: 8 additions & 1 deletion py-polars/tests/unit/operations/test_drop_nulls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
from polars.testing.parametric import series


@given(s=series(allow_null=True))
@given(
s=series(
allow_null=True,
excluded_dtypes=[
pl.Struct, # See: https://github.com/pola-rs/polars/issues/3462
],
)
)
def test_drop_nulls_parametric(s: pl.Series) -> None:
result = s.drop_nulls()
assert result.len() == s.len() - s.null_count()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,9 @@ def test_data_enum(cat: str) -> None:
@given(cat=data(pl.Enum(["hello", "world"])))
def test_data_enum_instantiated(cat: str) -> None:
assert cat in ("hello", "world")


@given(struct=data(pl.Struct({"a": pl.Int8, "b": pl.String})))
def test_data_struct(struct: dict[str, int | str]) -> None:
assert isinstance(struct["a"], int)
assert isinstance(struct["b"], str)

0 comments on commit 81cc802

Please sign in to comment.