Skip to content

Commit

Permalink
The Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed May 8, 2024
1 parent 8d21c45 commit 4025ad3
Showing 1 changed file with 35 additions and 46 deletions.
81 changes: 35 additions & 46 deletions py-polars/polars/testing/parametric/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def _parse_allowed_dtypes(

return allowed_dtypes_flat, allowed_dtypes_nested


@st.composite
def _flat_dtypes(
draw: DrawFn,
Expand Down Expand Up @@ -213,6 +214,8 @@ def _nested_dtypes(
"""Create a strategy for generating nested Polars :class:`DataType` objects."""
if allowed_dtypes is None:
allowed_dtypes = _NESTED_DTYPES
if excluded_dtypes is None:
excluded_dtypes = []

dtype = draw(st.sampled_from(allowed_dtypes))
return draw(
Expand Down Expand Up @@ -259,7 +262,6 @@ def instantiate_inner(dtype: PolarsDataType) -> DataType:
raise InvalidArgument(msg)



@st.composite
def _instantiate_dtype(
draw: DrawFn,
Expand All @@ -270,56 +272,32 @@ def _instantiate_dtype(
nesting_level: int = 3,
) -> DataType:
"""Take a data type and instantiate it."""
allowed_dtypes = [dt for dt in allowed_dtypes if dt == dtype]
inner =
if not dtype.is_nested():
return _flat_dtypes(allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes)
else:
inner = _instantiate_nested_dtype(dtype, )

if dtype.is_nested():
if not flat_dtypes:
return _nested_dtypes(
inner=st.just(Null()),
allowed_dtypes=nested_dtypes,
excluded_dtypes=excluded_dtypes,
)
return st.recursive(
base=_flat_dtypes(
allowed_dtypes=flat_dtypes, excluded_dtypes=excluded_dtypes
),
extend=lambda s: _nested_dtypes(
s, allowed_dtypes=nested_dtypes, excluded_dtypes=excluded_dtypes
),
max_leaves=nesting_level,
)
else:
return _flat_dtypes(allowed_dtypes=[]flat_dtypes, excluded_dtypes=excluded_dtypes)

if not dtype.is_nested():
if isinstance(dtype, DataType):
return dtype
if allowed_dtypes is None:
allowed_dtypes = [dtype]
else:
allowed_dtypes = [dt for dt in allowed_dtypes if dt == dtype]
return draw(
dtypes(allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes)
)
return draw(
_flat_dtypes(allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes)
)

def draw_inner(dtype: PolarsDataType) -> DataType:
if isinstance(dtype, DataType):
return draw(
_instantiate_dtype(
dtype.inner,
dtype.inner, # type: ignore[attr-defined]
allowed_dtypes=allowed_dtypes,
excluded_dtypes=excluded_dtypes,
nesting_level=nesting_level - 1,
)
)
else:
return dtypes(
allowed_dtypes=allowed_dtypes,
excluded_dtypes=excluded_dtypes,
nesting_level=nesting_level - 1,
return draw(
dtypes(
allowed_dtypes=allowed_dtypes,
excluded_dtypes=excluded_dtypes,
nesting_level=nesting_level - 1,
)
)

if dtype == List:
Expand All @@ -329,14 +307,22 @@ def draw_inner(dtype: PolarsDataType) -> DataType:
inner = draw_inner(dtype)
width = draw(st.integers(min_value=1, max_value=_DEFAULT_ARRAY_WIDTH_LIMIT))
return Array(inner, width)
elif dtype == Struct:
if isinstance(dtype, DataType):
return dtype

# inner_strategy =

n_fields = draw(
st.integers(min_value=1, max_value=_DEFAULT_STRUCT_FIELDS_LIMIT)
)
inner_strategy = dtypes(
allowed_dtypes=allowed_dtypes,
excluded_dtypes=excluded_dtypes,
nesting_level=nesting_level - 1,
)
return Struct({f"f{i}": draw(inner_strategy) for i in range(n_fields)})
else:
inner = _flat_dtypes()
return _instantiate_nested_dtype(dtype, inner)


msg = f"unsupported data type: {dtype}"
raise InvalidArgument(msg)


_INTEGER_STRATEGIES: dict[bool, dict[int, SearchStrategy[int]]] = {
Expand Down Expand Up @@ -547,7 +533,6 @@ def data(dtype: PolarsDataType, **kwargs: Any) -> SearchStrategy[Any]:
raise InvalidArgument(msg)



@deprecate_renamed_function("lists", version="0.20.25")
def create_list_strategy(
inner_dtype: PolarsDataType | None = None,
Expand Down Expand Up @@ -616,11 +601,15 @@ def create_list_strategy(
if size is not None:
min_size = max_size = size

strategies = list_strategies(
if inner_dtype is None:
inner_dtype = dtypes().example()
else:
inner_dtype = _instantiate_dtype(inner_dtype).example()

return lists(
inner_dtype,
select_from=select_from,
min_size=min_size,
max_size=max_size,
unique=unique,
)
return strategies.example()

0 comments on commit 4025ad3

Please sign in to comment.