diff --git a/simvue/api/objects/grids.py b/simvue/api/objects/grids.py index 1641d219..42b50610 100644 --- a/simvue/api/objects/grids.py +++ b/simvue/api/objects/grids.py @@ -33,14 +33,19 @@ __all__ = ["Grid"] -def check_ordered_array(axis_ticks: list[float]) -> bool: +def check_ordered_array( + axis_ticks: list[list[float]] | numpy.ndarray, +) -> list[list[float]]: """Returns if array is ordered or reverse ordered.""" - if not isinstance(axis_ticks[0], float): - raise ValueError("Ordering can only be checked on a 1D array") - _array = numpy.array(axis_ticks) - return numpy.all(numpy.sort(_array) == _array) or numpy.all( - numpy.reversed(numpy.sort(_array)) == _array - ) + if isinstance(axis_ticks, numpy.ndarray): + axis_ticks = axis_ticks.tolist() + for i, _array in enumerate(axis_ticks): + _array = numpy.array(_array) + if not numpy.all(numpy.sort(_array) == _array) or numpy.all( + reversed(numpy.sort(_array)) == _array + ): + raise ValueError(f"Axis {i} has unordered values.") + return axis_ticks class Grid(SimvueObject): @@ -104,7 +109,13 @@ def new( cls, *, name: str, - grid: list[list[float]], + grid: typing.Annotated[ + list[list[float]], + pydantic.conlist( + pydantic.conlist(float, min_length=1), min_length=1, max_length=2 + ), + pydantic.AfterValidator(check_ordered_array), + ], labels: list[str], offline: bool = False, **kwargs, @@ -116,7 +127,8 @@ def new( name : str name for this grid. grid : list[list[float]] - define a grid as a list of axes containing tick values. + define a grid as a list of axes containing tick values + number of axes must be 1 or 2 labels : list[str] label each of the axes defined. offline: bool, optional @@ -127,11 +139,6 @@ def new( Metrics metrics object """ - if len(grid) < 1: - raise ValueError("Invalid argument for 'grid'") - - if len(labels) != len(set(labels)): - raise ValueError("Labels must be unique.") if len(labels) != len(grid): raise AssertionError( @@ -139,10 +146,6 @@ def new( f"grid dimension {len(grid)}." ) - for i, axis in enumerate(grid): - if not check_ordered_array(axis): - raise ValueError(f"Axis {i} has unordered values.") - return Grid( grid=grid, labels=labels, diff --git a/tests/unit/test_grids.py b/tests/unit/test_grids.py index 74587d25..a3764d96 100644 --- a/tests/unit/test_grids.py +++ b/tests/unit/test_grids.py @@ -26,13 +26,12 @@ def test_grid_creation_online() -> None: _run.commit() _grid_def=numpy.vstack([ numpy.linspace(0, 10, 10), - numpy.linspace(0, 20, 10), numpy.linspace(50, 60, 10), ]) _grid_list = _grid_def.tolist() _grid = Grid.new( name=f"test_grid_creation_online_{_uuid}", - labels=["x", "y", "z"], + labels=["x", "y"], grid=_grid_list ) _grid.commit() @@ -57,14 +56,13 @@ def test_grid_creation_offline() -> None: _run.commit() _grid_def=numpy.vstack([ numpy.linspace(0, 10, 10), - numpy.linspace(0, 20, 10), numpy.linspace(50, 60, 10), ]) _grid_list = _grid_def.tolist() _grid = Grid.new( name=f"test_grid_creation_online_{_uuid}", grid=_grid_list, - labels=["x", "y", "z"], + labels=["x", "y"], offline=True ) _grid.commit() @@ -105,13 +103,12 @@ def test_grid_metrics_creation_online() -> None: _run.commit() _grid_def=numpy.vstack([ numpy.linspace(0, 10, 10), - numpy.linspace(0, 20, 10), numpy.linspace(50, 60, 10), ]) _grid_list = _grid_def.tolist() _grid = Grid.new( name=f"test_grid_creation_online_{_uuid}", - labels=["x", "y", "z"], + labels=["x", "y"], grid=_grid_list ) _grid.commit() @@ -126,7 +123,7 @@ def test_grid_metrics_creation_online() -> None: ), "time": _time, "step": _step, - "array": numpy.ones((10, 10, 10)), + "array": numpy.ones((10, 10)), "grid": _grid.id, "metric": "A" } @@ -156,13 +153,12 @@ def test_grid_metrics_creation_offline() -> None: _run.commit() _grid_def=numpy.vstack([ numpy.linspace(0, 10, 10), - numpy.linspace(0, 20, 10), numpy.linspace(50, 60, 10), ]) _grid_list = _grid_def.tolist() _grid = Grid.new( name=f"test_grid_creation_offline_{_uuid}", - labels=["x", "y", "z"], + labels=["x", "y"], grid=_grid_list, offline=True ) @@ -178,7 +174,7 @@ def test_grid_metrics_creation_offline() -> None: ), "time": _time, "step": _step, - "array": numpy.ones((10, 10, 10)), + "array": numpy.ones((10, 10)), "grid": _grid.id, "metric": "A" }