diff --git a/pylammpsmpi/lammps_wrapper.py b/pylammpsmpi/lammps_wrapper.py index 57d48d70..771c1189 100644 --- a/pylammpsmpi/lammps_wrapper.py +++ b/pylammpsmpi/lammps_wrapper.py @@ -22,7 +22,9 @@ class LammpsLibrary: Top level class which manages the lammps library provided by LammpsBase """ - def __init__(self, cores=1, working_directory=".", client=None, mode="local"): + def __init__( + self, cores=1, working_directory=".", client=None, mode="local", cmdargs=None + ): self.cores = cores self.working_directory = working_directory self.client = client @@ -33,6 +35,7 @@ def __init__(self, cores=1, working_directory=".", client=None, mode="local"): LammpsBase, cores=self.cores, working_directory=self.working_directory, + cmdargs=cmdargs, actor=True, ) self.lmp = fut.result() @@ -42,7 +45,9 @@ def __init__(self, cores=1, working_directory=".", client=None, mode="local"): elif self.mode == "local": self.lmp = LammpsBase( - cores=self.cores, working_directory=self.working_directory + cores=self.cores, + working_directory=self.working_directory, + cmdargs=cmdargs, ) self.lmp.start_process() diff --git a/pylammpsmpi/mpi/lmpmpi.py b/pylammpsmpi/mpi/lmpmpi.py index 4db14678..b2ed04aa 100644 --- a/pylammpsmpi/mpi/lmpmpi.py +++ b/pylammpsmpi/mpi/lmpmpi.py @@ -42,7 +42,10 @@ } # Lammps executable -job = lammps(cmdargs=["-screen", "none"]) +args = ["-screen", "none"] +if len(sys.argv) > 1: + args.extend(sys.argv[1:]) +job = lammps(cmdargs=args) def extract_compute(funct_args): diff --git a/pylammpsmpi/utils/lammps.py b/pylammpsmpi/utils/lammps.py index 285965ef..1b61bba3 100644 --- a/pylammpsmpi/utils/lammps.py +++ b/pylammpsmpi/utils/lammps.py @@ -20,17 +20,28 @@ class LammpsBase: - def __init__(self, cores=8, working_directory="."): + def __init__(self, cores=8, working_directory=".", cmdargs=None): self.cores = cores self.working_directory = working_directory self._process = None + self._cmdargs = cmdargs def start_process(self): executable = os.path.join( os.path.dirname(os.path.abspath(__file__)), "../mpi", "lmpmpi.py" ) + cmds = [ + "mpiexec", + "--oversubscribe", + "-n", + str(self.cores), + "python", + executable, + ] + if self._cmdargs is not None: + cmds.extend(self._cmdargs) self._process = subprocess.Popen( - ["mpiexec", "--oversubscribe", "-n", str(self.cores), "python", executable], + cmds, stdout=subprocess.PIPE, stderr=None, stdin=subprocess.PIPE, diff --git a/tests/test_pylammpsmpi_cluster.py b/tests/test_pylammpsmpi_cluster.py index d4a80eda..d1a7c015 100644 --- a/tests/test_pylammpsmpi_cluster.py +++ b/tests/test_pylammpsmpi_cluster.py @@ -13,10 +13,20 @@ class TestLocalLammpsLibrary(unittest.TestCase): @classmethod def setUpClass(cls): cls.execution_path = os.path.dirname(os.path.abspath(__file__)) - cluster = LocalCluster(n_workers=1, threads_per_worker=2) + cls.citation_file = os.path.join(cls.execution_path, "citations.txt") + cls.lammps_file = os.path.join(cls.execution_path, "in.simple") + cluster = LocalCluster( + n_workers=1, + threads_per_worker=2 + ) client = Client(cluster) - cls.lmp = LammpsLibrary(cores=2, mode='dask', client=client) - cls.lmp.file(os.path.join(cls.execution_path, "in.simple")) + cls.lmp = LammpsLibrary( + cores=2, + mode='dask', + cmdargs=["-cite", cls.citation_file], + client=client + ) + cls.lmp.file(cls.lammps_file) @classmethod def tearDownClass(cls): @@ -60,11 +70,11 @@ def test_scatter_atoms(self): f1 = self.lmp.gather_atoms("f") self.assertEqual(f1[1][0], val) - f = self.lmp.gather_atoms("f", ids=[1,2]) + f = self.lmp.gather_atoms("f", ids=[1, 2]) val = np.random.randint(0, 100) f[1][1] = val - self.lmp.scatter_atoms("f", f, ids=[1,2]) - f1 = self.lmp.gather_atoms("f", ids=[1,2]) + self.lmp.scatter_atoms("f", f, ids=[1, 2]) + f1 = self.lmp.gather_atoms("f", ids=[1, 2]) self.assertEqual(f1[1][1], val) def test_extract_box(self): @@ -75,12 +85,15 @@ def test_extract_box(self): self.assertEqual(np.round(box[1][0], 2), 6.72) self.lmp.delete_atoms("group", "all") - self.lmp.reset_box([0.0,0.0,0.0], [8.0,8.0,8.0], 0.0,0.0,0.0) + self.lmp.reset_box([0.0, 0.0, 0.0], [8.0, 8.0, 8.0], 0.0, 0.0, 0.0) box = self.lmp.extract_box() self.assertEqual(box[0][0], 0.0) self.assertEqual(np.round(box[1][0], 2), 8.0) self.lmp.clear() - self.lmp.file(os.path.join(self.execution_path, "in.simple")) + self.lmp.file(self.lammps_file) + + def test_cmdarg_options(self): + self.assertTrue(os.path.isfile(self.citation_file)) if __name__ == "__main__": diff --git a/tests/test_pylammpsmpi_local.py b/tests/test_pylammpsmpi_local.py index 1fc98914..f4ef06cc 100644 --- a/tests/test_pylammpsmpi_local.py +++ b/tests/test_pylammpsmpi_local.py @@ -12,8 +12,10 @@ class TestLocalLammpsLibrary(unittest.TestCase): @classmethod def setUpClass(cls): cls.execution_path = os.path.dirname(os.path.abspath(__file__)) - cls.lmp = LammpsLibrary(cores=2, mode='local') - cls.lmp.file(os.path.join(cls.execution_path, "in.simple")) + cls.citation_file = os.path.join(cls.execution_path, "citations.txt") + cls.lammps_file = os.path.join(cls.execution_path, "in.simple") + cls.lmp = LammpsLibrary(cores=2, mode='local', cmdargs=["-cite", cls.citation_file]) + cls.lmp.file(cls.lammps_file) @classmethod def tearDownClass(cls): @@ -78,7 +80,10 @@ def test_extract_box(self): self.assertEqual(box[0][0], 0.0) self.assertEqual(np.round(box[1][0], 2), 8.0) self.lmp.clear() - self.lmp.file(os.path.join(self.execution_path, "in.simple")) + self.lmp.file(self.lammps_file) + + def test_cmdarg_options(self): + self.assertTrue(os.path.isfile(self.citation_file)) if __name__ == "__main__":