Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 35 additions & 25 deletions properties/test_index_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,9 @@

import xarray.testing.strategies as xrst


@st.composite
def unique(draw, strategy):
# https://stackoverflow.com/questions/73737073/create-hypothesis-strategy-that-returns-unique-values
seen = draw(st.shared(st.builds(set), key="key-for-unique-elems"))
return draw(
strategy.filter(lambda x: x not in seen).map(lambda x: seen.add(x) or x)
)


# Share to ensure we get unique names on each draw,
# so we don't try to add two variables with the same name
# or stack to a dimension with a name that already exists in the Dataset.
UNIQUE_NAME = unique(strategy=xrst.names())
DIM_NAME = xrst.dimension_names(name_strategy=UNIQUE_NAME, min_dims=1, max_dims=1)
# Strategy for generating names - uniqueness is enforced by the state machine
NAME_STRATEGY = xrst.names()
DIM_NAME = xrst.dimension_names(name_strategy=NAME_STRATEGY, min_dims=1, max_dims=1)
index_variables = st.builds(
xr.Variable,
data=npst.arrays(
Expand Down Expand Up @@ -75,25 +63,45 @@ def __init__(self):
self.indexed_dims = []
self.multi_indexed_dims = []

@initialize(var=index_variables)
def init_ds(self, var):
# Track all used names to ensure uniqueness (avoids flaky Hypothesis tests)
self.used_names: set[str] = set()

def _draw_unique_name(self, data) -> str:
"""Draw a name that hasn't been used yet in this test case."""
name = data.draw(NAME_STRATEGY.filter(lambda x: x not in self.used_names))
self.used_names.add(name)
return name

def _draw_unique_var(self, data) -> xr.Variable:
"""Draw an index variable with a unique dimension name."""
var = data.draw(index_variables)
# Replace with a guaranteed unique name
new_name = self._draw_unique_name(data)
return xr.Variable(dims=(new_name,), data=var.data, attrs=var.attrs)

@initialize(data=st.data())
def init_ds(self, data):
"""Initialize the Dataset so that at least one rule will always fire."""
var = self._draw_unique_var(data)
(name,) = var.dims
note(f"initializing with dimension coordinate {name}")
add_dim_coord_and_data_var(self.dataset, var)

self.indexed_dims.append(name)

# TODO: stacking with a timedelta64 index and unstacking converts it to object
@rule(var=index_variables)
def add_dim_coord(self, var):
@rule(data=st.data())
def add_dim_coord(self, data):
var = self._draw_unique_var(data)
(name,) = var.dims
note(f"adding dimension coordinate {name}")
add_dim_coord_and_data_var(self.dataset, var)

self.indexed_dims.append(name)

@rule(var=index_variables)
def assign_coords(self, var):
@rule(data=st.data())
def assign_coords(self, data):
var = self._draw_unique_var(data)
(name,) = var.dims
note(f"assign_coords: {name}")
self.dataset = self.dataset.assign_coords({name: var})
Expand All @@ -117,9 +125,10 @@ def reset_index(self, data):
elif dim in self.multi_indexed_dims:
del self.multi_indexed_dims[self.multi_indexed_dims.index(dim)]

@rule(newname=UNIQUE_NAME, data=st.data(), create_index=st.booleans())
@rule(data=st.data(), create_index=st.booleans())
@precondition(lambda self: bool(self.indexed_dims))
def stack(self, newname, data, create_index):
def stack(self, data, create_index):
newname = self._draw_unique_name(data)
oldnames = data.draw(
st.lists(
st.sampled_from(self.indexed_dims),
Expand Down Expand Up @@ -158,9 +167,10 @@ def unstack(self, data):
# TODO: fix this
pass

@rule(newname=UNIQUE_NAME, data=st.data())
@rule(data=st.data())
@precondition(lambda self: bool(self.dataset.variables))
def rename_vars(self, newname, data):
def rename_vars(self, data):
newname = self._draw_unique_name(data)
dim = data.draw(st.sampled_from(sorted(self.dataset.variables)))
# benbovy: "skip the default indexes invariant test when the name of an
# existing dimension coordinate is passed as input kwarg or dict key
Expand Down
Loading