From 4025ad3349224fa93698b15b84a4025246d776ed Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 8 May 2024 07:18:53 +0200 Subject: [PATCH] The Fix --- .../polars/testing/parametric/strategies.py | 81 ++++++++----------- 1 file changed, 35 insertions(+), 46 deletions(-) diff --git a/py-polars/polars/testing/parametric/strategies.py b/py-polars/polars/testing/parametric/strategies.py index 31890968d968..9c9ae56469ea 100644 --- a/py-polars/polars/testing/parametric/strategies.py +++ b/py-polars/polars/testing/parametric/strategies.py @@ -159,6 +159,7 @@ def _parse_allowed_dtypes( return allowed_dtypes_flat, allowed_dtypes_nested + @st.composite def _flat_dtypes( draw: DrawFn, @@ -213,6 +214,8 @@ def _nested_dtypes( """Create a strategy for generating nested Polars :class:`DataType` objects.""" if allowed_dtypes is None: allowed_dtypes = _NESTED_DTYPES + if excluded_dtypes is None: + excluded_dtypes = [] dtype = draw(st.sampled_from(allowed_dtypes)) return draw( @@ -259,7 +262,6 @@ def instantiate_inner(dtype: PolarsDataType) -> DataType: raise InvalidArgument(msg) - @st.composite def _instantiate_dtype( draw: DrawFn, @@ -270,56 +272,32 @@ def _instantiate_dtype( nesting_level: int = 3, ) -> DataType: """Take a data type and instantiate it.""" - allowed_dtypes = [dt for dt in allowed_dtypes if dt == dtype] - inner = if not dtype.is_nested(): - return _flat_dtypes(allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes) - else: - inner = _instantiate_nested_dtype(dtype, ) - - if dtype.is_nested(): - if not flat_dtypes: - return _nested_dtypes( - inner=st.just(Null()), - allowed_dtypes=nested_dtypes, - excluded_dtypes=excluded_dtypes, - ) - return st.recursive( - base=_flat_dtypes( - allowed_dtypes=flat_dtypes, excluded_dtypes=excluded_dtypes - ), - extend=lambda s: _nested_dtypes( - s, allowed_dtypes=nested_dtypes, excluded_dtypes=excluded_dtypes - ), - max_leaves=nesting_level, - ) - else: - return _flat_dtypes(allowed_dtypes=[]flat_dtypes, excluded_dtypes=excluded_dtypes) - - if not dtype.is_nested(): - if isinstance(dtype, DataType): - return dtype + if allowed_dtypes is None: + allowed_dtypes = [dtype] else: allowed_dtypes = [dt for dt in allowed_dtypes if dt == dtype] - return draw( - dtypes(allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes) - ) + return draw( + _flat_dtypes(allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes) + ) def draw_inner(dtype: PolarsDataType) -> DataType: if isinstance(dtype, DataType): return draw( _instantiate_dtype( - dtype.inner, + dtype.inner, # type: ignore[attr-defined] allowed_dtypes=allowed_dtypes, excluded_dtypes=excluded_dtypes, nesting_level=nesting_level - 1, ) ) else: - return dtypes( - allowed_dtypes=allowed_dtypes, - excluded_dtypes=excluded_dtypes, - nesting_level=nesting_level - 1, + return draw( + dtypes( + allowed_dtypes=allowed_dtypes, + excluded_dtypes=excluded_dtypes, + nesting_level=nesting_level - 1, + ) ) if dtype == List: @@ -329,14 +307,22 @@ def draw_inner(dtype: PolarsDataType) -> DataType: inner = draw_inner(dtype) width = draw(st.integers(min_value=1, max_value=_DEFAULT_ARRAY_WIDTH_LIMIT)) return Array(inner, width) + elif dtype == Struct: + if isinstance(dtype, DataType): + return dtype - # inner_strategy = - + n_fields = draw( + st.integers(min_value=1, max_value=_DEFAULT_STRUCT_FIELDS_LIMIT) + ) + inner_strategy = dtypes( + allowed_dtypes=allowed_dtypes, + excluded_dtypes=excluded_dtypes, + nesting_level=nesting_level - 1, + ) + return Struct({f"f{i}": draw(inner_strategy) for i in range(n_fields)}) else: - inner = _flat_dtypes() - return _instantiate_nested_dtype(dtype, inner) - - + msg = f"unsupported data type: {dtype}" + raise InvalidArgument(msg) _INTEGER_STRATEGIES: dict[bool, dict[int, SearchStrategy[int]]] = { @@ -547,7 +533,6 @@ def data(dtype: PolarsDataType, **kwargs: Any) -> SearchStrategy[Any]: raise InvalidArgument(msg) - @deprecate_renamed_function("lists", version="0.20.25") def create_list_strategy( inner_dtype: PolarsDataType | None = None, @@ -616,11 +601,15 @@ def create_list_strategy( if size is not None: min_size = max_size = size - strategies = list_strategies( + if inner_dtype is None: + inner_dtype = dtypes().example() + else: + inner_dtype = _instantiate_dtype(inner_dtype).example() + + return lists( inner_dtype, select_from=select_from, min_size=min_size, max_size=max_size, unique=unique, ) - return strategies.example()