Skip to content

Commit

Permalink
bugfix "velocity" which is not stored in ASE
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Jun 26, 2024
1 parent 78b84b2 commit d24d571
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 3 deletions.
17 changes: 15 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

@pytest.fixture
def s22() -> list[ase.Atoms]:
return ase.collections.s22
return list(ase.collections.s22)


@pytest.fixture
Expand Down Expand Up @@ -62,9 +62,10 @@ def s22_all_properties() -> list[ase.Atoms]:


@pytest.fixture
def s22_info_arrays_calc():
def s22_info_arrays_calc() -> list[ase.Atoms]:
images = []
for atoms in ase.collections.s22:
atoms: ase.Atoms
atoms.info.update(
{
"mlip_energy": np.random.rand(),
Expand All @@ -74,6 +75,7 @@ def s22_info_arrays_calc():
)
atoms.new_array("mlip_forces", np.random.rand(len(atoms), 3))
atoms.new_array("mlip_forces_2", np.random.rand(len(atoms), 3))
atoms.set_velocities(np.random.rand(len(atoms), 3))
calc = SinglePointCalculator(
atoms, energy=np.random.rand(), forces=np.random.rand(len(atoms), 3)
)
Expand All @@ -92,6 +94,17 @@ def s22_mixed_pbc_cell() -> list[ase.Atoms]:
return images


@pytest.fixture
def s22_illegal_calc_results() -> list[ase.Atoms]:
images = []
for atoms in ase.collections.s22:
atoms.calc = SinglePointCalculator(atoms)
atoms.calc.results["mlip_energy"] = np.random.rand()

images.append(atoms)
return images


@pytest.fixture
def water() -> list[ase.Atoms]:
"""Get a dataset without positions."""
Expand Down
7 changes: 7 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"s22_all_properties",
"s22_info_arrays_calc",
"s22_mixed_pbc_cell",
# "s22_illegal_calc_results",
"water",
],
)
Expand Down Expand Up @@ -64,6 +65,8 @@ def test_datasets_h5py(tmp_path, dataset, request):
assert "particles/atoms/species/value" in f
assert "particles/atoms/force/value" in f
assert "observables/atoms/force/value" not in f
assert "particles/atoms/velocity/value" in f
assert "observables/atoms/velocity/value" not in f

assert "particles/atoms/energy/value" not in f
assert "observables/atoms/energy/value" in f
Expand All @@ -75,6 +78,10 @@ def test_datasets_h5py(tmp_path, dataset, request):
assert "observables/atoms/mlip_energy_2/value" in f
assert "observables/atoms/mlip_stress/value" in f

assert f["particles/atoms/velocity/value"].attrs["unit"] == "Angstrom/fs"
assert f["particles/atoms/force/value"].attrs["unit"] == "eV/Angstrom"
# assert f["observables/atoms/energy/value"].attrs["unit"] == "eV"

npt.assert_array_equal(
f["particles/atoms/box"].attrs["boundary"], ["none", "none", "none"]
)
Expand Down
1 change: 1 addition & 0 deletions znh5md/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def extract_atoms_data(atoms: ase.Atoms) -> ASEData:
)


# TODO highlight that an additional dimension is added to ASEData here
def combine_asedata(data: List[ASEData]) -> ASEData:
"""Combine multiple ASEData objects into one."""
atomic_numbers = concatenate_varying_shape_arrays([x.atomic_numbers for x in data])
Expand Down
2 changes: 1 addition & 1 deletion znh5md/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def extend(self, images: List[ase.Atoms]):
else:
self._extend_existing_data(f, combined_data)

def _create_particle_group(self, f, data):
def _create_particle_group(self, f, data: fmt.ASEData):
g_particle_grp = f["particles"].create_group(self.particle_group)
self._create_group(g_particle_grp, "species", data.atomic_numbers)
self._create_group(g_particle_grp, "position", data.positions, "Angstrom")
Expand Down
3 changes: 3 additions & 0 deletions znh5md/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@ def build_structures(
info_data,
) -> list[ase.Atoms]:
structures = []
# ASE does not store "velocity" but only "momenta"
if "velocity" in arrays_data:
del arrays_data["velocity"]
if atomic_numbers is not None:
for idx in range(len(atomic_numbers)):
# ruff thinks, this is less complex than doing it in place ... ??
Expand Down

0 comments on commit d24d571

Please sign in to comment.