diff --git a/.ci_support/environment-mpich.yml b/.ci_support/environment-mpich.yml index 1ad09f91..dc2efda8 100644 --- a/.ci_support/environment-mpich.yml +++ b/.ci_support/environment-mpich.yml @@ -5,9 +5,11 @@ dependencies: - coverage - codacy-coverage - lammps >=2022.06.23 - - mpi4py =3.1.4 - mpich - numpy - distributed - dask-jobqueue + - cloudpickle =2.2.1 + - mpi4py =3.1.4 + - pympipool =0.4.2 - pyzmq =25.1.0 \ No newline at end of file diff --git a/.ci_support/environment-openmpi.yml b/.ci_support/environment-openmpi.yml index 45ba34d2..1fc2a515 100644 --- a/.ci_support/environment-openmpi.yml +++ b/.ci_support/environment-openmpi.yml @@ -5,9 +5,11 @@ dependencies: - coverage - codacy-coverage - lammps >=2022.06.23 - - mpi4py =3.1.4 - openmpi - numpy - distributed - dask-jobqueue + - cloudpickle =2.2.1 + - mpi4py =3.1.4 + - pympipool =0.4.2 - pyzmq =25.1.0 \ No newline at end of file diff --git a/binder/environment.yml b/binder/environment.yml index 1880c649..8ee0e32e 100644 --- a/binder/environment.yml +++ b/binder/environment.yml @@ -8,3 +8,7 @@ dependencies: - pytest - distributed - dask-jobqueue +- cloudpickle =2.2.1 +- mpi4py =3.1.4 +- pympipool =0.4.2 +- pyzmq =25.1.0 \ No newline at end of file diff --git a/pylammpsmpi/mpi/lmpmpi.py b/pylammpsmpi/mpi/lmpmpi.py index 65e6aa30..16c2c486 100644 --- a/pylammpsmpi/mpi/lmpmpi.py +++ b/pylammpsmpi/mpi/lmpmpi.py @@ -5,7 +5,7 @@ from ctypes import c_double, c_int from mpi4py import MPI import numpy as np -import pickle +import cloudpickle import sys import zmq from lammps import lammps @@ -311,6 +311,7 @@ def installed_packages(funct_args): def set_fix_external_callback(funct_args): job.set_fix_external_callback(*funct_args) + return 1 def get_neighlist(funct_args): @@ -472,7 +473,7 @@ def _gather_data_from_all_processors(data): socket.connect("tcp://localhost:" + port_selected) while True: if MPI.COMM_WORLD.rank == 0: - input_dict = pickle.loads(socket.recv()) + input_dict = cloudpickle.loads(socket.recv()) # with open('process.txt', 'a') as file: # print('Input:', input_dict, file=file) else: @@ -488,4 +489,4 @@ def _gather_data_from_all_processors(data): if MPI.COMM_WORLD.rank == 0 and output is not None: # with open('process.txt', 'a') as file: # print('Output:', output, file=file) - socket.send(pickle.dumps(output)) + socket.send(cloudpickle.dumps({"r": output})) diff --git a/pylammpsmpi/utils/lammps.py b/pylammpsmpi/utils/lammps.py index f6471dad..f3112748 100644 --- a/pylammpsmpi/utils/lammps.py +++ b/pylammpsmpi/utils/lammps.py @@ -3,9 +3,7 @@ # Distributed under the terms of "New BSD License", see the LICENSE file. import os -import pickle -import subprocess -import zmq +from pympipool import SocketInterface __author__ = "Sarath Menon, Jan Janssen" @@ -26,19 +24,16 @@ def __init__( ): self.cores = cores self.working_directory = working_directory + self._interface = SocketInterface() self._process = None self._oversubscribe = oversubscribe self._cmdargs = cmdargs - self._socket = None - self._context = None def start_process(self): + port_selected = self._interface.bind_to_random_port() executable = os.path.join( os.path.dirname(os.path.abspath(__file__)), "../mpi", "lmpmpi.py" ) - self._context = zmq.Context() - self._socket = self._context.socket(zmq.PAIR) - port_selected = self._socket.bind_to_random_port("tcp://*") cmds = ["mpiexec"] if self._oversubscribe: cmds += ["--oversubscribe"] @@ -52,12 +47,11 @@ def start_process(self): ] if self._cmdargs is not None: cmds.extend(self._cmdargs) - self._process = subprocess.Popen( - cmds, - stdout=subprocess.PIPE, - stderr=None, - stdin=subprocess.PIPE, - cwd=self.working_directory, + self._interface.bootup(command_lst=cmds, cwd=self.working_directory) + + def _send_and_receive_dict(self, command, data=None): + return self._interface.send_and_receive_dict( + input_dict={"c": command, "d": data} ) def _send(self, command, data=None): @@ -76,7 +70,7 @@ def _send(self, command, data=None): ------- None """ - self._socket.send(pickle.dumps({"c": command, "d": data})) + self._interface.send_dict({"c": command, "d": data}) def _receive(self): """ @@ -91,8 +85,7 @@ def _receive(self): data : string data from the command """ - output = pickle.loads(self._socket.recv()) - return output + return self._interface.receive_dict() @property def version(self): @@ -108,8 +101,7 @@ def version(self): version: string version string of lammps """ - self._send(command="get_version", data=[]) - return self._receive() + return self._send_and_receive_dict(command="get_version", data=[]) def file(self, inputfile): """ @@ -126,13 +118,11 @@ def file(self, inputfile): """ if not os.path.exists(inputfile): raise FileNotFoundError("Input file does not exist") - self._send(command="get_file", data=[inputfile]) - _ = self._receive() + _ = self._send_and_receive_dict(command="get_file", data=[inputfile]) # TODO def extract_setting(self, *args): - self._send(command="extract_setting", data=list(args)) - return self._receive() + return self._send_and_receive_dict(command="extract_setting", data=list(args)) def extract_global(self, name): """ @@ -160,8 +150,7 @@ def extract_global(self, name): "mvh2r", "angstrom", "femtosecond", "qelectron" """ - self._send(command="extract_global", data=[name]) - return self._receive() + return self._send_and_receive_dict(command="extract_global", data=[name]) def extract_box(self): """ @@ -179,8 +168,7 @@ def extract_box(self): `xy, yz, xz` are the box tilts, `periodicity` is an array which shows if the box is periodic in three dimensions. """ - self._send(command="extract_box", data=[]) - return self._receive() + return self._send_and_receive_dict(command="extract_box", data=[]) def extract_atom(self, name): """ @@ -207,8 +195,7 @@ def extract_atom(self, name): -------- scatter_atoms """ - self._send(command="extract_atom", data=list([name])) - return self._receive() + return self._send_and_receive_dict(command="extract_atom", data=list([name])) def extract_fix(self, *args): """ @@ -240,9 +227,7 @@ def extract_fix(self, *args): value Fix data corresponding to the requested dimensions """ - - self._send(command="extract_fix", data=list(args)) - return self._receive() + return self._send_and_receive_dict(command="extract_fix", data=list(args)) def extract_variable(self, *args): """ @@ -270,8 +255,7 @@ def extract_variable(self, *args): Currently only returns the information provided on a single processor """ - self._send(command="extract_variable", data=list(args)) - return self._receive() + return self._send_and_receive_dict(command="extract_variable", data=list(args)) @property def natoms(self): @@ -290,9 +274,7 @@ def get_natoms(self): natoms : int number of atoms """ - - self._send(command="get_natoms", data=[]) - return self._receive() + return self._send_and_receive_dict(command="get_natoms", data=[]) def set_variable(self, *args): """ @@ -311,8 +293,7 @@ def set_variable(self, *args): flag : int 0 if successfull, -1 otherwise """ - self._send(command="set_variable", data=list(args)) - return self._receive() + return self._send_and_receive_dict(command="set_variable", data=list(args)) def reset_box(self, *args): """ @@ -329,8 +310,7 @@ def reset_box(self, *args): xy, yz, xz : floats box tilts """ - self._send(command="reset_box", data=list(args)) - _ = self._receive() + _ = self._send_and_receive_dict(command="reset_box", data=list(args)) def generate_atoms( self, ids=None, type=None, x=None, v=None, image=None, shrinkexceed=False @@ -404,44 +384,39 @@ def create_atoms(self, n, id, type, x, v=None, image=None, shrinkexceed=False): if x is not None: funct_args = [n, id, type, x, v, image, shrinkexceed] - self._send(command="create_atoms", data=funct_args) - _ = self._receive() + _ = self._send_and_receive_dict(command="create_atoms", data=funct_args) else: raise TypeError("Value of x cannot be None") @property def has_exceptions(self): """Return whether the LAMMPS shared library was compiled with C++ exceptions handling enabled""" - self._send(command="has_exceptions", data=[]) - return self._receive() + return self._send_and_receive_dict(command="has_exceptions", data=[]) @property def has_gzip_support(self): - self._send(command="has_gzip_support", data=[]) - return self._receive() + return self._send_and_receive_dict(command="has_gzip_support", data=[]) @property def has_png_support(self): - self._send(command="has_png_support", data=[]) - return self._receive() + return self._send_and_receive_dict(command="has_png_support", data=[]) @property def has_jpeg_support(self): - self._send(command="has_jpeg_support", data=[]) - return self._receive() + return self._send_and_receive_dict(command="has_jpeg_support", data=[]) @property def has_ffmpeg_support(self): - self._send(command="has_ffmpeg_support", data=[]) - return self._receive() + return self._send_and_receive_dict(command="has_ffmpeg_support", data=[]) @property def installed_packages(self): - self._send(command="get_installed_packages", data=[]) - return self._receive() + return self._send_and_receive_dict(command="get_installed_packages", data=[]) def set_fix_external_callback(self, *args): - self._send(command="set_fix_external_callback", data=list(args)) + _ = self._send_and_receive_dict( + command="set_fix_external_callback", data=list(args) + ) def get_neighlist(self, *args): """Returns an instance of :class:`NeighList` which wraps access to the neighbor list with the given index @@ -450,8 +425,7 @@ def get_neighlist(self, *args): :return: an instance of :class:`NeighList` wrapping access to neighbor list data :rtype: NeighList """ - self._send(command="get_neighlist", data=list(args)) - return self._receive() + return self._send_and_receive_dict(command="get_neighlist", data=list(args)) def find_pair_neighlist(self, *args): """Find neighbor list index of pair style neighbor list @@ -473,8 +447,9 @@ def find_pair_neighlist(self, *args): :return: neighbor list index if found, otherwise -1 :rtype: int """ - self._send(command="find_pair_neighlist", data=list(args)) - return self._receive() + return self._send_and_receive_dict( + command="find_pair_neighlist", data=list(args) + ) def find_fix_neighlist(self, *args): """Find neighbor list index of fix neighbor list @@ -485,8 +460,9 @@ def find_fix_neighlist(self, *args): :return: neighbor list index if found, otherwise -1 :rtype: int """ - self._send(command="find_fix_neighlist", data=list(args)) - return self._receive() + return self._send_and_receive_dict( + command="find_fix_neighlist", data=list(args) + ) def find_compute_neighlist(self, *args): """Find neighbor list index of compute neighbor list @@ -497,8 +473,9 @@ def find_compute_neighlist(self, *args): :return: neighbor list index if found, otherwise -1 :rtype: int """ - self._send(command="find_compute_neighlist", data=list(args)) - return self._receive() + return self._send_and_receive_dict( + command="find_compute_neighlist", data=list(args) + ) def get_neighlist_size(self, *args): """Return the number of elements in neighbor list with the given index @@ -507,12 +484,14 @@ def get_neighlist_size(self, *args): :return: number of elements in neighbor list with index idx :rtype: int """ - self._send(command="get_neighlist_size", data=list(args)) - return self._receive() + return self._send_and_receive_dict( + command="get_neighlist_size", data=list(args) + ) def get_neighlist_element_neighbors(self, *args): - self._send(command="get_neighlist_element_neighbors", data=list(args)) - return self._receive() + return self._send_and_receive_dict( + command="get_neighlist_element_neighbors", data=list(args) + ) def command(self, cmd): """ @@ -529,15 +508,12 @@ def command(self, cmd): """ if isinstance(cmd, list): for c in cmd: - self._send(command="command", data=c) - _ = self._receive() + _ = self._send_and_receive_dict(command="command", data=c) elif len(cmd.split("\n")) > 1: for c in cmd.split("\n"): - self._send(command="command", data=c) - _ = self._receive() + _ = self._send_and_receive_dict(command="command", data=c) else: - self._send(command="command", data=cmd) - _ = self._receive() + _ = self._send_and_receive_dict(command="command", data=cmd) def gather_atoms(self, *args, concat=False, ids=None): """ @@ -572,16 +548,17 @@ def gather_atoms(self, *args, concat=False, ids=None): extract_atoms """ if concat: - self._send(command="gather_atoms_concat", data=list(args)) + return self._send_and_receive_dict( + command="gather_atoms_concat", data=list(args) + ) elif ids is not None: lenids = len(ids) args = list(args) args.append(len(ids)) args.append(ids) - self._send(command="gather_atoms_subset", data=args) + return self._send_and_receive_dict(command="gather_atoms_subset", data=args) else: - self._send(command="gather_atoms", data=list(args)) - return self._receive() + return self._send_and_receive_dict(command="gather_atoms", data=list(args)) def scatter_atoms(self, *args, ids=None): """ @@ -595,11 +572,9 @@ def scatter_atoms(self, *args, ids=None): args = list(args) args.append(len(ids)) args.append(ids) - self._send(command="scatter_atoms_subset", data=args) - _ = self._receive() + _ = self._send_and_receive_dict(command="scatter_atoms_subset", data=args) else: - self._send(command="scatter_atoms", data=list(args)) - _ = self._receive() + _ = self._send_and_receive_dict(command="scatter_atoms", data=list(args)) def get_thermo(self, *args): """ @@ -616,8 +591,7 @@ def get_thermo(self, *args): value of the thermo keyword """ - self._send(command="get_thermo", data=list(args)) - return self._receive() + return self._send_and_receive_dict(command="get_thermo", data=list(args)) # TODO def extract_compute(self, id, style, type, length=0, width=0): @@ -653,8 +627,7 @@ def extract_compute(self, id, style, type, length=0, width=0): """ args = [id, style, type, length, width] - self._send(command="extract_compute", data=args) - return self._receive() + return self._send_and_receive_dict(command="extract_compute", data=args) def close(self): """ @@ -668,19 +641,7 @@ def close(self): ------- None """ - self._send(command="close") - try: - self._process.kill() - self._process.stdout.close() - self._process.stdin.close() - self._process.wait() - self._socket.close() - self._context.term() - except AttributeError: - pass - self._process = None - self._socket = None - self._context = None + self._interface.shutdown(wait=True) # TODO def __del__(self): diff --git a/setup.py b/setup.py index eb60d8af..95779c87 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ keywords='lammps, mpi4py', packages=find_packages(exclude=["*tests*"]), install_requires=[ - 'mpi4py==3.1.4', "pyzmq==25.1.0" + "cloudpickle==2.2.1", "mpi4py==3.1.4", "pympipool==0.4.2", "pyzmq==25.1.0", ], cmdclass=versioneer.get_cmdclass(), )