Skip to content

Commit

Permalink
Added StorageView to extract fields from storage on the fly
Browse files Browse the repository at this point in the history
  • Loading branch information
david-zwicker committed Feb 20, 2024
1 parent ff0ce7d commit 3b11f74
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 3 deletions.
58 changes: 55 additions & 3 deletions pde/storage/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from numpy.typing import DTypeLike

from ..fields import FieldCollection, ScalarField, Tensor2Field, VectorField
from ..fields.base import FieldBase
from ..fields.base import DataFieldBase, FieldBase
from ..grids.base import GridBase
from ..tools.docstrings import fill_in_docstring
from ..tools.output import display_progress
Expand All @@ -29,7 +29,8 @@

WriteModeType = Literal[
"append",
"readonly" "truncate",
"readonly",
"truncate",
"truncate_once",
]

Expand All @@ -42,7 +43,7 @@ class StorageBase(metaclass=ABCMeta):
will return the fields in order and individual time points can also be accessed.
"""

times: Sequence[float] # :class:`~numpy.ndarray`): stored time points
times: Sequence[float] # stored time points
data: Any # actual data for all the stored times
write_mode: WriteModeType # mode determining how the storage behaves

Expand Down Expand Up @@ -602,3 +603,54 @@ def finalize(self, info: InfoDict | None = None) -> None:
"""
super().finalize(info)
self.storage.end_writing()


class StorageView:
"""represents a view into a storage that extracts a particular field"""

has_collection: bool = False

def __init__(self, storage: StorageBase, *, field: int | str):
"""
Args:
storage (:class:`~pde.storage.base.StorageBase`):
The storage providing the basic data
field (int or str):
The index into the field collection determining which field of the
collection is returned. Instead of a numerical index, the field label
can also be supplied. If there are multiple fields with the same label,
only the first field is returned.
"""
self.storage = storage
if not self.storage.has_collection:
raise RuntimeError("Can only create view into Storage of field collection")

Check warning on line 626 in pde/storage/base.py

View check run for this annotation

Codecov / codecov/patch

pde/storage/base.py#L626

Added line #L626 was not covered by tests

if isinstance(field, str):
self.field_index = self.storage._field.labels.index(field) # type: ignore
else:
self.field_index = field

@property
def times(self) -> Sequence[float]:
return self.storage.times

@property
def grid(self) -> GridBase | None:
return self.storage.grid

Check warning on line 639 in pde/storage/base.py

View check run for this annotation

Codecov / codecov/patch

pde/storage/base.py#L639

Added line #L639 was not covered by tests

def __len__(self):
return len(self.storage)

def __getitem__(self, key: int) -> DataFieldBase:
"""return field at given index or a list of fields for a slice"""
return self.storage[key][self.field_index] # type: ignore

def __iter__(self) -> Iterator[DataFieldBase]:
"""iterate over all stored fields"""
for fields in self.storage:
yield fields[self.field_index] # type: ignore

def items(self) -> Iterator[tuple[float, DataFieldBase]]:
"""iterate over all times and stored fields, returning pairs"""
for k, v in self.storage.items():
yield k, v[self.field_index] # type: ignore
33 changes: 33 additions & 0 deletions tests/storage/test_generic_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pde.fields import FieldCollection, ScalarField, Tensor2Field, VectorField
from pde.tools import mpi
from pde.tools.misc import module_available
from pde.storage.base import StorageView

STORAGE_CLASSES = [MemoryStorage, FileStorage]

Expand Down Expand Up @@ -300,3 +301,35 @@ def test_storing_transformation_scalar(storage_factory, rng):
assert not storage.has_collection
for sol in storage:
np.testing.assert_allclose(sol.data, field.data**2)


@pytest.mark.parametrize("storage_class", STORAGE_CLASSES)
def test_storage_view(storage_factory, rng):
"""test StorageView"""
grid = UnitGrid([2, 2])
f1 = ScalarField.random_uniform(grid, 0.1, 0.4, label="a", rng=rng)
f2 = VectorField.random_uniform(grid, 0.1, 0.4, label="b", rng=rng)
fc = FieldCollection([f1, f2])

# store some data
storage = storage_factory()
storage.start_writing(fc)
storage.append(fc, 0)
storage.append(fc, 1)
storage.append(fc, 2)
storage.end_writing()

view = StorageView(storage, field=0)
assert not view.has_collection
np.testing.assert_allclose(view.times, range(3))
assert len(view) == 3
assert view[0] == f1
for field in view:
assert field == f1
for i, (j, field) in enumerate(view.items()):
assert i == j
assert field == f1

assert StorageView(storage, field="a")[0] == f1
assert StorageView(storage, field="b")[0] == f2
assert StorageView(storage, field=1)[0] == f2

0 comments on commit 3b11f74

Please sign in to comment.