diff --git a/pylammpsmpi/utils/lammps.py b/pylammpsmpi/utils/lammps.py index bc7ebca2..9a838c83 100644 --- a/pylammpsmpi/utils/lammps.py +++ b/pylammpsmpi/utils/lammps.py @@ -18,6 +18,28 @@ __date__ = "Feb 28, 2020" +def _initialize_socket(interface, cmdargs, cwd, cores, oversubscribe=False): + port_selected = interface.bind_to_random_port() + executable = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../mpi", "lmpmpi.py" + ) + cmds = ["mpiexec"] + if oversubscribe: + cmds += ["--oversubscribe"] + cmds += [ + "-n", + str(cores), + "python", + executable, + "--zmqport", + str(port_selected), + ] + if cmdargs is not None: + cmds.extend(cmdargs) + interface.bootup(command_lst=cmds, cwd=cwd) + return interface + + class LammpsBase: def __init__( self, cores=8, oversubscribe=False, working_directory=".", cmdargs=None @@ -30,24 +52,13 @@ def __init__( self._cmdargs = cmdargs def start_process(self): - port_selected = self._interface.bind_to_random_port() - executable = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../mpi", "lmpmpi.py" + self._interface = _initialize_socket( + interface=self._interface, + cmdargs=self._cmdargs, + cwd=self.working_directory, + cores=self.cores, + oversubscribe=self._oversubscribe, ) - cmds = ["mpiexec"] - if self._oversubscribe: - cmds += ["--oversubscribe"] - cmds += [ - "-n", - str(self.cores), - "python", - executable, - "--zmqport", - str(port_selected), - ] - if self._cmdargs is not None: - cmds.extend(self._cmdargs) - self._interface.bootup(command_lst=cmds, cwd=self.working_directory) def _send_and_receive_dict(self, command, data=None): return self._interface.send_and_receive_dict(