diff --git a/pylammpsmpi/lammps_wrapper.py b/pylammpsmpi/lammps_wrapper.py index 57d48d70..64515457 100644 --- a/pylammpsmpi/lammps_wrapper.py +++ b/pylammpsmpi/lammps_wrapper.py @@ -22,7 +22,9 @@ class LammpsLibrary: Top level class which manages the lammps library provided by LammpsBase """ - def __init__(self, cores=1, working_directory=".", client=None, mode="local"): + def __init__( + self, cores=1, working_directory=".", client=None, mode="local", cmdargs=None + ): self.cores = cores self.working_directory = working_directory self.client = client @@ -42,7 +44,9 @@ def __init__(self, cores=1, working_directory=".", client=None, mode="local"): elif self.mode == "local": self.lmp = LammpsBase( - cores=self.cores, working_directory=self.working_directory + cores=self.cores, + working_directory=self.working_directory, + cmdargs=cmdargs, ) self.lmp.start_process() diff --git a/pylammpsmpi/mpi/lmpmpi.py b/pylammpsmpi/mpi/lmpmpi.py index 4db14678..b2ed04aa 100644 --- a/pylammpsmpi/mpi/lmpmpi.py +++ b/pylammpsmpi/mpi/lmpmpi.py @@ -42,7 +42,10 @@ } # Lammps executable -job = lammps(cmdargs=["-screen", "none"]) +args = ["-screen", "none"] +if len(sys.argv) > 1: + args.extend(sys.argv[1:]) +job = lammps(cmdargs=args) def extract_compute(funct_args): diff --git a/pylammpsmpi/utils/lammps.py b/pylammpsmpi/utils/lammps.py index 285965ef..1b61bba3 100644 --- a/pylammpsmpi/utils/lammps.py +++ b/pylammpsmpi/utils/lammps.py @@ -20,17 +20,28 @@ class LammpsBase: - def __init__(self, cores=8, working_directory="."): + def __init__(self, cores=8, working_directory=".", cmdargs=None): self.cores = cores self.working_directory = working_directory self._process = None + self._cmdargs = cmdargs def start_process(self): executable = os.path.join( os.path.dirname(os.path.abspath(__file__)), "../mpi", "lmpmpi.py" ) + cmds = [ + "mpiexec", + "--oversubscribe", + "-n", + str(self.cores), + "python", + executable, + ] + if self._cmdargs is not None: + cmds.extend(self._cmdargs) self._process = subprocess.Popen( - ["mpiexec", "--oversubscribe", "-n", str(self.cores), "python", executable], + cmds, stdout=subprocess.PIPE, stderr=None, stdin=subprocess.PIPE, diff --git a/tests/test_pylammpsmpi_local.py b/tests/test_pylammpsmpi_local.py index 1fc98914..c894bc20 100644 --- a/tests/test_pylammpsmpi_local.py +++ b/tests/test_pylammpsmpi_local.py @@ -81,5 +81,10 @@ def test_extract_box(self): self.lmp.file(os.path.join(self.execution_path, "in.simple")) + def test_cmdarg_options(self): + self.lmp2 = LammpsLibrary(cores=2, mode='local', cmdargs=["-cite", "citations.txt"]) + self.lmp2.file(os.path.join(self.execution_path, "in.simple")) + assert os.path.isfile(os.path.join(self.execution_path, "citations.txt")) + if __name__ == "__main__": unittest.main()