Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 21 additions & 18 deletions simvue/api/objects/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,19 @@
__all__ = ["Grid"]


def check_ordered_array(axis_ticks: list[float]) -> bool:
def check_ordered_array(
axis_ticks: list[list[float]] | numpy.ndarray,
) -> list[list[float]]:
"""Returns if array is ordered or reverse ordered."""
if not isinstance(axis_ticks[0], float):
raise ValueError("Ordering can only be checked on a 1D array")
_array = numpy.array(axis_ticks)
return numpy.all(numpy.sort(_array) == _array) or numpy.all(
numpy.reversed(numpy.sort(_array)) == _array
)
if isinstance(axis_ticks, numpy.ndarray):
axis_ticks = axis_ticks.tolist()
for i, _array in enumerate(axis_ticks):
_array = numpy.array(_array)
if not numpy.all(numpy.sort(_array) == _array) or numpy.all(
reversed(numpy.sort(_array)) == _array
):
raise ValueError(f"Axis {i} has unordered values.")
return axis_ticks


class Grid(SimvueObject):
Expand Down Expand Up @@ -104,7 +109,13 @@ def new(
cls,
*,
name: str,
grid: list[list[float]],
grid: typing.Annotated[
list[list[float]],
pydantic.conlist(
pydantic.conlist(float, min_length=1), min_length=1, max_length=2
),
pydantic.AfterValidator(check_ordered_array),
],
labels: list[str],
offline: bool = False,
**kwargs,
Expand All @@ -116,7 +127,8 @@ def new(
name : str
name for this grid.
grid : list[list[float]]
define a grid as a list of axes containing tick values.
define a grid as a list of axes containing tick values
number of axes must be 1 or 2
labels : list[str]
label each of the axes defined.
offline: bool, optional
Expand All @@ -127,22 +139,13 @@ def new(
Metrics
metrics object
"""
if len(grid) < 1:
raise ValueError("Invalid argument for 'grid'")

if len(labels) != len(set(labels)):
raise ValueError("Labels must be unique.")

if len(labels) != len(grid):
raise AssertionError(
"Length of argument 'labels' must match first "
f"grid dimension {len(grid)}."
)

for i, axis in enumerate(grid):
if not check_ordered_array(axis):
raise ValueError(f"Axis {i} has unordered values.")

return Grid(
grid=grid,
labels=labels,
Expand Down
16 changes: 6 additions & 10 deletions tests/unit/test_grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,12 @@ def test_grid_creation_online() -> None:
_run.commit()
_grid_def=numpy.vstack([
numpy.linspace(0, 10, 10),
numpy.linspace(0, 20, 10),
numpy.linspace(50, 60, 10),
])
_grid_list = _grid_def.tolist()
_grid = Grid.new(
name=f"test_grid_creation_online_{_uuid}",
labels=["x", "y", "z"],
labels=["x", "y"],
grid=_grid_list
)
_grid.commit()
Expand All @@ -57,14 +56,13 @@ def test_grid_creation_offline() -> None:
_run.commit()
_grid_def=numpy.vstack([
numpy.linspace(0, 10, 10),
numpy.linspace(0, 20, 10),
numpy.linspace(50, 60, 10),
])
_grid_list = _grid_def.tolist()
_grid = Grid.new(
name=f"test_grid_creation_online_{_uuid}",
grid=_grid_list,
labels=["x", "y", "z"],
labels=["x", "y"],
offline=True
)
_grid.commit()
Expand Down Expand Up @@ -105,13 +103,12 @@ def test_grid_metrics_creation_online() -> None:
_run.commit()
_grid_def=numpy.vstack([
numpy.linspace(0, 10, 10),
numpy.linspace(0, 20, 10),
numpy.linspace(50, 60, 10),
])
_grid_list = _grid_def.tolist()
_grid = Grid.new(
name=f"test_grid_creation_online_{_uuid}",
labels=["x", "y", "z"],
labels=["x", "y"],
grid=_grid_list
)
_grid.commit()
Expand All @@ -126,7 +123,7 @@ def test_grid_metrics_creation_online() -> None:
),
"time": _time,
"step": _step,
"array": numpy.ones((10, 10, 10)),
"array": numpy.ones((10, 10)),
"grid": _grid.id,
"metric": "A"
}
Expand Down Expand Up @@ -156,13 +153,12 @@ def test_grid_metrics_creation_offline() -> None:
_run.commit()
_grid_def=numpy.vstack([
numpy.linspace(0, 10, 10),
numpy.linspace(0, 20, 10),
numpy.linspace(50, 60, 10),
])
_grid_list = _grid_def.tolist()
_grid = Grid.new(
name=f"test_grid_creation_offline_{_uuid}",
labels=["x", "y", "z"],
labels=["x", "y"],
grid=_grid_list,
offline=True
)
Expand All @@ -178,7 +174,7 @@ def test_grid_metrics_creation_offline() -> None:
),
"time": _time,
"step": _step,
"array": numpy.ones((10, 10, 10)),
"array": numpy.ones((10, 10)),
"grid": _grid.id,
"metric": "A"
}
Expand Down