diff --git a/tests/test_ase_interface.py b/tests/test_ase_interface.py index dcd77d1..29b07b1 100644 --- a/tests/test_ase_interface.py +++ b/tests/test_ase_interface.py @@ -6,7 +6,13 @@ import numpy as np from pylammpsmpi import LammpsASELibrary, LammpsLibrary -from pylammpsmpi.wrapper.ase import get_species_symbols, get_structure_indices, cell_is_skewed, set_selective_dynamics +from pylammpsmpi.wrapper.ase import ( + cell_is_skewed, + get_species_symbols, + get_structure_indices, + get_lammps_indicies_from_ase_structure, + set_selective_dynamics, +) class TestLammpsASELibrary(unittest.TestCase): @@ -48,6 +54,10 @@ def test_static(self): self.assertEqual(lmp.interactive_temperatures_getter(), 0) self.assertTrue(np.isclose(np.sum(lmp.interactive_pressures_getter()), -0.015661731917941832)) self.assertEqual(np.sum(lmp.interactive_velocities_getter()), 0.0) + self.assertTrue(np.isclose(np.sum(lmp.interactive_positions_getter()), 291.6)) + lmp.interactive_cells_setter(cell=1.01 * structure.cell.array) + lmp.interactive_lib_command("run 0") + self.assertTrue(np.all(np.isclose(lmp.interactive_cells_getter(), 1.01 * structure.cell.array))) lmp.close() def test_static_with_statement(self): @@ -87,6 +97,7 @@ def test_static_with_statement(self): self.assertEqual(lmp.interactive_temperatures_getter(), 0) self.assertTrue(np.isclose(np.sum(lmp.interactive_pressures_getter()), -0.00937227406237915)) self.assertEqual(np.sum(lmp.interactive_velocities_getter()), 0.0) + self.assertTrue(np.isclose(np.sum(lmp.interactive_positions_getter()), 73.60669397577456)) class TestASEHelperFunctions(unittest.TestCase): @@ -108,6 +119,15 @@ def test_cell_is_skewed(self): self.assertTrue(cell_is_skewed(cell=self.structure_skewed.cell)) self.assertFalse(cell_is_skewed(cell=self.structure_cubic.cell)) + def test_get_lammps_indicies_from_ase_structure(self): + indicies = get_lammps_indicies_from_ase_structure( + structure=self.structure_cubic, + el_eam_lst=["Al", "H"] + ) + self.assertEqual(len(indicies), len(self.structure_cubic)) + self.assertEqual(len(set(indicies)), 1) + self.assertEqual(set(indicies), {1}) + class TestConstraints(unittest.TestCase): @classmethod