Skip to content

Commit

Permalink
check external file handle
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Jun 20, 2024
1 parent 6c1ff2e commit 0fee63f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 10 deletions.
19 changes: 19 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,22 @@ def test_two_datasets(tmp_path, s22_all_properties, s22_mixed_pbc_cell):

for a, b in zip(s22_mixed_pbc_cell, io_b[:]):
npt.assert_array_equal(a.get_positions(), b.get_positions())


def test_two_datasets_external(tmp_path, s22_all_properties, s22_mixed_pbc_cell):
with h5py.File(tmp_path / "test.h5", "w") as f:
io_a = znh5md.IO(file_handle=f, particle_group="a")
io_b = znh5md.IO(file_handle=f, particle_group="b")

io_a.extend(s22_all_properties)
io_b.extend(s22_mixed_pbc_cell)

assert len(io_a) == len(s22_all_properties)
assert len(io_b) == len(s22_mixed_pbc_cell)

with h5py.File(tmp_path / "test.h5", "r") as f:
io_a = znh5md.IO(file_handle=f, particle_group="a")
io_b = znh5md.IO(file_handle=f, particle_group="b")

assert len(io_a) == len(s22_all_properties)
assert len(io_b) == len(s22_mixed_pbc_cell)
53 changes: 43 additions & 10 deletions znh5md/io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import dataclasses
import os
import pathlib
import typing as t
from collections.abc import MutableSequence
from typing import List, Optional, Union

Expand All @@ -18,11 +20,23 @@
# TODO: allow external file handles instead of providing filename


@contextlib.contextmanager
def _open_file(
filename: str | os.PathLike | None, file_handle: h5py.File | None, **kwargs
) -> t.Generator[h5py.File, None, None]:
if file_handle is not None:
yield file_handle
else:
with h5py.File(filename, **kwargs) as f:
yield f


@dataclasses.dataclass
class IO(MutableSequence):
"""A class for handling H5MD files for ASE Atoms objects."""

filename: Union[str, os.PathLike]
filename: Optional[str | os.PathLike] = None
file_handle: Optional[h5py.File] = None
pbc_group: bool = True # Specify PBC per step (Not H5MD conform)
save_units: bool = True # Export ASE units into the H5MD file
author: str = "N/A"
Expand All @@ -32,18 +46,31 @@ class IO(MutableSequence):
particle_group: Optional[str] = None

def __post_init__(self):
self.filename = pathlib.Path(self.filename)
if self.filename is None and self.file_handle is None:
raise ValueError("Either filename or file_handle must be provided")
if self.filename is not None and self.file_handle is not None:
raise ValueError("Only one of filename or file_handle can be provided")
if self.filename is not None:
self.filename = pathlib.Path(self.filename)
self._set_particle_group()

def _set_particle_group(self):
if self.particle_group and self.filename.exists():
with h5py.File(self.filename, "r") as f:
if self.particle_group is not None:
pass
elif self.filename is not None and self.filename.exists():
with _open_file(self.filename, self.file_handle, mode="r") as f:
self.particle_group = next(iter(f["particles"].keys()))
elif (
self.file_handle is not None
and pathlib.Path(self.file_handle.filename).exists()
):
with _open_file(self.filename, self.file_handle, mode="r") as f:
self.particle_group = next(iter(f["particles"].keys()))
elif not self.particle_group:
else:
self.particle_group = "atoms"

def create_file(self):
with h5py.File(self.filename, "w") as f:
with _open_file(self.filename, self.file_handle, mode="w") as f:
g_h5md = f.create_group("h5md")
g_h5md.attrs["version"] = np.array([1, 1])
g_author = g_h5md.create_group("author")
Expand All @@ -55,7 +82,7 @@ def create_file(self):
f.create_group("particles")

def __len__(self) -> int:
with h5py.File(self.filename, "r") as f:
with _open_file(self.filename, self.file_handle, mode="r") as f:
return len(f["particles"][self.particle_group]["species"]["value"])

def __getitem__(
Expand All @@ -68,7 +95,7 @@ def __getitem__(
calc_data = {}
info_data = {}

with h5py.File(self.filename, "r") as f:
with _open_file(self.filename, self.file_handle, mode="r") as f:
atomic_numbers = fmt.get_atomic_numbers(
f["particles"], self.particle_group, index
)
Expand Down Expand Up @@ -153,13 +180,19 @@ def _build_structures( # noqa: C901
return structures

def extend(self, images: List[ase.Atoms]):
if not self.filename.exists():
if self.filename is not None and not self.filename.exists():
self.create_file()
if self.file_handle is not None:
needs_creation = False
with _open_file(self.filename, self.file_handle, mode="r") as f:
needs_creation = "h5md" not in f
if needs_creation:
self.create_file()

data = [fmt.extract_atoms_data(atoms) for atoms in images]
combined_data = fmt.combine_asedata(data)

with h5py.File(self.filename, "a") as f:
with _open_file(self.filename, self.file_handle, mode="a") as f:
if self.particle_group not in f["particles"]:
self._create_particle_group(f, combined_data)
else:
Expand Down

0 comments on commit 0fee63f

Please sign in to comment.